Repository: google/cel-cpp Branch: master Commit: 9e73d93f77a1 Files: 1115 Total size: 8.7 MB Directory structure: gitextract_qb36pkd3/ ├── .bazelrc ├── .bazelversion ├── .bcr/ │ ├── README.md │ ├── metadata.template.json │ ├── presubmit.yml │ └── source.template.json ├── .github/ │ └── workflows/ │ └── publish_to_bcr.yml ├── .gitignore ├── BUILD.bazel ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── MODULE.bazel ├── README.md ├── base/ │ ├── BUILD │ ├── ast.h │ ├── attribute.cc │ ├── attribute.h │ ├── attribute_set.h │ ├── builtins.h │ ├── function.h │ ├── function_adapter.h │ ├── function_descriptor.h │ ├── function_result.h │ ├── function_result_set.cc │ ├── function_result_set.h │ ├── internal/ │ │ ├── BUILD │ │ ├── memory_manager_testing.cc │ │ ├── memory_manager_testing.h │ │ ├── message_wrapper.h │ │ ├── operators.h │ │ ├── unknown_set.cc │ │ └── unknown_set.h │ ├── kind.h │ ├── operators.cc │ ├── operators.h │ ├── operators_test.cc │ └── type_provider.h ├── bazel/ │ ├── BUILD │ ├── antlr.bzl │ ├── antlr.patch │ ├── cat_param_file.cc │ ├── cel_cc_embed.bzl │ ├── cel_cc_embed.cc │ ├── cel_proto_transitive_descriptor_set.bzl │ └── deps.bzl ├── checker/ │ ├── BUILD │ ├── checker_options.h │ ├── internal/ │ │ ├── BUILD │ │ ├── builtins_arena.cc │ │ ├── builtins_arena.h │ │ ├── descriptor_pool_type_introspector.cc │ │ ├── descriptor_pool_type_introspector.h │ │ ├── descriptor_pool_type_introspector_test.cc │ │ ├── format_type_name.cc │ │ ├── format_type_name.h │ │ ├── format_type_name_test.cc │ │ ├── namespace_generator.cc │ │ ├── namespace_generator.h │ │ ├── namespace_generator_test.cc │ │ ├── test_ast_helpers.cc │ │ ├── test_ast_helpers.h │ │ ├── test_ast_helpers_test.cc │ │ ├── type_check_env.cc │ │ ├── type_check_env.h │ │ ├── type_checker_builder_impl.cc │ │ ├── type_checker_builder_impl.h │ │ ├── type_checker_builder_impl_test.cc │ │ ├── type_checker_impl.cc │ │ ├── type_checker_impl.h │ │ ├── type_checker_impl_test.cc │ │ ├── 
type_inference_context.cc │ │ ├── type_inference_context.h │ │ └── type_inference_context_test.cc │ ├── optional.cc │ ├── optional.h │ ├── optional_test.cc │ ├── standard_library.cc │ ├── standard_library.h │ ├── standard_library_test.cc │ ├── type_check_issue.cc │ ├── type_check_issue.h │ ├── type_check_issue_test.cc │ ├── type_checker.cc │ ├── type_checker.h │ ├── type_checker_builder.h │ ├── type_checker_builder_factory.cc │ ├── type_checker_builder_factory.h │ ├── type_checker_builder_factory_test.cc │ ├── type_checker_subset_factory.cc │ ├── type_checker_subset_factory.h │ ├── type_checker_subset_factory_test.cc │ ├── validation_result.cc │ ├── validation_result.h │ └── validation_result_test.cc ├── cloudbuild.yaml ├── codelab/ │ ├── BUILD │ ├── Dockerfile │ ├── README.md │ ├── cel_compiler.h │ ├── cel_compiler_test.cc │ ├── exercise1.cc │ ├── exercise1.h │ ├── exercise10.cc │ ├── exercise10.h │ ├── exercise10_test.cc │ ├── exercise1_test.cc │ ├── exercise2.cc │ ├── exercise2.h │ ├── exercise2_test.cc │ ├── exercise3_test.cc │ ├── exercise4.cc │ ├── exercise4.h │ ├── exercise4_test.cc │ ├── network_functions.cc │ ├── network_functions.h │ ├── network_functions_test.cc │ └── solutions/ │ ├── BUILD │ ├── exercise1.cc │ ├── exercise10.cc │ ├── exercise2.cc │ ├── exercise3_test.cc │ └── exercise4.cc ├── common/ │ ├── BUILD │ ├── allocator.h │ ├── allocator_test.cc │ ├── any.cc │ ├── any.h │ ├── any_test.cc │ ├── arena.h │ ├── arena_string.h │ ├── arena_string_pool.h │ ├── arena_string_pool_test.cc │ ├── arena_string_test.cc │ ├── arena_string_view.h │ ├── arena_string_view_test.cc │ ├── ast/ │ │ ├── BUILD │ │ ├── constant_proto.cc │ │ ├── constant_proto.h │ │ ├── expr_proto.cc │ │ ├── expr_proto.h │ │ ├── expr_proto_test.cc │ │ ├── metadata.cc │ │ ├── metadata.h │ │ ├── metadata_test.cc │ │ ├── navigable_ast_internal.h │ │ ├── navigable_ast_internal_test.cc │ │ ├── navigable_ast_kinds.cc │ │ ├── navigable_ast_kinds.h │ │ ├── source_info_proto.cc │ │ └── 
source_info_proto.h │ ├── ast.cc │ ├── ast.h │ ├── ast_proto.cc │ ├── ast_proto.h │ ├── ast_proto_test.cc │ ├── ast_rewrite.cc │ ├── ast_rewrite.h │ ├── ast_rewrite_test.cc │ ├── ast_test.cc │ ├── ast_traverse.cc │ ├── ast_traverse.h │ ├── ast_traverse_test.cc │ ├── ast_visitor.h │ ├── ast_visitor_base.h │ ├── casting.h │ ├── constant.cc │ ├── constant.h │ ├── constant_test.cc │ ├── container.cc │ ├── container.h │ ├── container_test.cc │ ├── data.h │ ├── data_test.cc │ ├── decl.cc │ ├── decl.h │ ├── decl_proto.cc │ ├── decl_proto.h │ ├── decl_proto_test.cc │ ├── decl_proto_v1alpha1.cc │ ├── decl_proto_v1alpha1.h │ ├── decl_test.cc │ ├── expr.cc │ ├── expr.h │ ├── expr_factory.h │ ├── expr_test.cc │ ├── function_descriptor.cc │ ├── function_descriptor.h │ ├── internal/ │ │ ├── BUILD │ │ ├── byte_string.cc │ │ ├── byte_string.h │ │ ├── byte_string_test.cc │ │ ├── casting.h │ │ ├── metadata.h │ │ ├── reference_count.cc │ │ ├── reference_count.h │ │ ├── reference_count_test.cc │ │ ├── signature.cc │ │ ├── signature.h │ │ ├── signature_test.cc │ │ ├── value_conversion.cc │ │ └── value_conversion.h │ ├── json.h │ ├── kind.cc │ ├── kind.h │ ├── kind_test.cc │ ├── legacy_value.cc │ ├── legacy_value.h │ ├── memory.cc │ ├── memory.h │ ├── memory_test.cc │ ├── memory_testing.h │ ├── minimal_descriptor_database.cc │ ├── minimal_descriptor_database.h │ ├── minimal_descriptor_database_test.cc │ ├── minimal_descriptor_pool.cc │ ├── minimal_descriptor_pool.h │ ├── minimal_descriptor_pool_test.cc │ ├── native_type.h │ ├── navigable_ast.cc │ ├── navigable_ast.h │ ├── navigable_ast_test.cc │ ├── operators.cc │ ├── operators.h │ ├── optional_ref.h │ ├── reference.cc │ ├── reference.h │ ├── reference_count.h │ ├── reference_test.cc │ ├── source.cc │ ├── source.h │ ├── source_test.cc │ ├── standard_definitions.h │ ├── type.cc │ ├── type.h │ ├── type_introspector.cc │ ├── type_introspector.h │ ├── type_kind.h │ ├── type_proto.cc │ ├── type_proto.h │ ├── type_proto_test.cc │ ├── 
type_reflector.h │ ├── type_reflector_test.cc │ ├── type_test.cc │ ├── type_testing.h │ ├── typeinfo.cc │ ├── typeinfo.h │ ├── typeinfo_test.cc │ ├── types/ │ │ ├── any_type.h │ │ ├── any_type_test.cc │ │ ├── basic_struct_type.cc │ │ ├── basic_struct_type.h │ │ ├── basic_struct_type_test.cc │ │ ├── bool_type.h │ │ ├── bool_type_test.cc │ │ ├── bool_wrapper_type.h │ │ ├── bool_wrapper_type_test.cc │ │ ├── bytes_type.h │ │ ├── bytes_type_test.cc │ │ ├── bytes_wrapper_type.h │ │ ├── bytes_wrapper_type_test.cc │ │ ├── double_type.h │ │ ├── double_type_test.cc │ │ ├── double_wrapper_type.h │ │ ├── double_wrapper_type_test.cc │ │ ├── duration_type.h │ │ ├── duration_type_test.cc │ │ ├── dyn_type.h │ │ ├── dyn_type_test.cc │ │ ├── enum_type.cc │ │ ├── enum_type.h │ │ ├── enum_type_test.cc │ │ ├── error_type.h │ │ ├── error_type_test.cc │ │ ├── function_type.cc │ │ ├── function_type.h │ │ ├── function_type_pool.cc │ │ ├── function_type_pool.h │ │ ├── function_type_test.cc │ │ ├── int_type.h │ │ ├── int_type_test.cc │ │ ├── int_wrapper_type.h │ │ ├── int_wrapper_type_test.cc │ │ ├── legacy_type_introspector.h │ │ ├── list_type.cc │ │ ├── list_type.h │ │ ├── list_type_pool.cc │ │ ├── list_type_pool.h │ │ ├── list_type_test.cc │ │ ├── map_type.cc │ │ ├── map_type.h │ │ ├── map_type_pool.cc │ │ ├── map_type_pool.h │ │ ├── map_type_test.cc │ │ ├── message_type.cc │ │ ├── message_type.h │ │ ├── message_type_test.cc │ │ ├── null_type.h │ │ ├── null_type_test.cc │ │ ├── opaque_type.cc │ │ ├── opaque_type.h │ │ ├── opaque_type_pool.cc │ │ ├── opaque_type_pool.h │ │ ├── opaque_type_test.cc │ │ ├── optional_type.cc │ │ ├── optional_type.h │ │ ├── optional_type_test.cc │ │ ├── string_type.h │ │ ├── string_type_test.cc │ │ ├── string_wrapper_type.h │ │ ├── string_wrapper_type_test.cc │ │ ├── struct_type.cc │ │ ├── struct_type.h │ │ ├── struct_type_test.cc │ │ ├── timestamp_type.h │ │ ├── timestamp_type_test.cc │ │ ├── type_param_type.h │ │ ├── type_param_type_test.cc │ │ ├── 
type_pool.cc │ │ ├── type_pool.h │ │ ├── type_pool_test.cc │ │ ├── type_type.cc │ │ ├── type_type.h │ │ ├── type_type_pool.cc │ │ ├── type_type_pool.h │ │ ├── type_type_test.cc │ │ ├── types.h │ │ ├── uint_type.h │ │ ├── uint_type_test.cc │ │ ├── uint_wrapper_type.h │ │ ├── uint_wrapper_type_test.cc │ │ ├── unknown_type.h │ │ └── unknown_type_test.cc │ ├── unknown.h │ ├── value.cc │ ├── value.h │ ├── value_kind.h │ ├── value_test.cc │ ├── value_testing.cc │ ├── value_testing.h │ ├── value_testing_test.cc │ └── values/ │ ├── bool_value.cc │ ├── bool_value.h │ ├── bool_value_test.cc │ ├── bytes_value.cc │ ├── bytes_value.h │ ├── bytes_value_input_stream.h │ ├── bytes_value_output_stream.h │ ├── bytes_value_test.cc │ ├── custom_list_value.cc │ ├── custom_list_value.h │ ├── custom_list_value_test.cc │ ├── custom_map_value.cc │ ├── custom_map_value.h │ ├── custom_map_value_test.cc │ ├── custom_struct_value.cc │ ├── custom_struct_value.h │ ├── custom_struct_value_test.cc │ ├── custom_value.h │ ├── double_value.cc │ ├── double_value.h │ ├── double_value_test.cc │ ├── duration_value.cc │ ├── duration_value.h │ ├── duration_value_test.cc │ ├── enum_value.h │ ├── error_value.cc │ ├── error_value.h │ ├── error_value_test.cc │ ├── int_value.cc │ ├── int_value.h │ ├── int_value_test.cc │ ├── legacy_list_value.cc │ ├── legacy_list_value.h │ ├── legacy_map_value.cc │ ├── legacy_map_value.h │ ├── legacy_struct_value.cc │ ├── legacy_struct_value.h │ ├── list_value.cc │ ├── list_value.h │ ├── list_value_builder.h │ ├── list_value_test.cc │ ├── list_value_variant.h │ ├── map_value.cc │ ├── map_value.h │ ├── map_value_builder.h │ ├── map_value_test.cc │ ├── map_value_variant.h │ ├── message_value.cc │ ├── message_value.h │ ├── message_value_test.cc │ ├── mutable_list_value_test.cc │ ├── mutable_map_value_test.cc │ ├── null_value.cc │ ├── null_value.h │ ├── null_value_test.cc │ ├── opaque_value.cc │ ├── opaque_value.h │ ├── optional_value.cc │ ├── optional_value.h │ ├── 
optional_value_test.cc │ ├── parsed_json_list_value.cc │ ├── parsed_json_list_value.h │ ├── parsed_json_list_value_test.cc │ ├── parsed_json_map_value.cc │ ├── parsed_json_map_value.h │ ├── parsed_json_map_value_test.cc │ ├── parsed_json_value.cc │ ├── parsed_json_value.h │ ├── parsed_json_value_test.cc │ ├── parsed_map_field_value.cc │ ├── parsed_map_field_value.h │ ├── parsed_map_field_value_test.cc │ ├── parsed_message_value.cc │ ├── parsed_message_value.h │ ├── parsed_message_value_test.cc │ ├── parsed_repeated_field_value.cc │ ├── parsed_repeated_field_value.h │ ├── parsed_repeated_field_value_test.cc │ ├── string_value.cc │ ├── string_value.h │ ├── string_value_test.cc │ ├── struct_value.cc │ ├── struct_value.h │ ├── struct_value_builder.cc │ ├── struct_value_builder.h │ ├── struct_value_test.cc │ ├── struct_value_variant.h │ ├── timestamp_value.cc │ ├── timestamp_value.h │ ├── timestamp_value_test.cc │ ├── type_value.cc │ ├── type_value.h │ ├── type_value_test.cc │ ├── uint_value.cc │ ├── uint_value.h │ ├── uint_value_test.cc │ ├── unknown_value.cc │ ├── unknown_value.h │ ├── unknown_value_test.cc │ ├── value_builder.cc │ ├── value_builder.h │ ├── value_variant.cc │ ├── value_variant.h │ ├── value_variant_test.cc │ └── values.h ├── compiler/ │ ├── BUILD │ ├── compiler.h │ ├── compiler_factory.cc │ ├── compiler_factory.h │ ├── compiler_factory_test.cc │ ├── compiler_library_subset_factory.cc │ ├── compiler_library_subset_factory.h │ ├── compiler_library_subset_factory_test.cc │ ├── optional.cc │ ├── optional.h │ ├── optional_test.cc │ ├── standard_library.cc │ └── standard_library.h ├── conformance/ │ ├── BUILD │ ├── run.bzl │ ├── run.cc │ ├── service.cc │ ├── service.h │ └── utils.h ├── env/ │ ├── BUILD │ ├── config.cc │ ├── config.h │ ├── config_test.cc │ ├── env.cc │ ├── env.h │ ├── env_runtime.cc │ ├── env_runtime.h │ ├── env_runtime_test.cc │ ├── env_std_extensions.cc │ ├── env_std_extensions.h │ ├── env_std_extensions_test.cc │ ├── env_test.cc │ ├── 
env_yaml.cc │ ├── env_yaml.h │ ├── env_yaml_test.cc │ ├── internal/ │ │ ├── BUILD │ │ ├── ext_registry.cc │ │ ├── ext_registry.h │ │ ├── ext_registry_test.cc │ │ ├── runtime_ext_registry.cc │ │ ├── runtime_ext_registry.h │ │ └── runtime_ext_registry_test.cc │ ├── runtime_std_extensions.cc │ ├── runtime_std_extensions.h │ ├── runtime_std_extensions_test.cc │ ├── type_info.cc │ ├── type_info.h │ └── type_info_test.cc ├── eval/ │ ├── BUILD │ ├── LICENSE │ ├── README.md │ ├── compiler/ │ │ ├── BUILD │ │ ├── LICENSE │ │ ├── cel_expression_builder_flat_impl.cc │ │ ├── cel_expression_builder_flat_impl.h │ │ ├── cel_expression_builder_flat_impl_test.cc │ │ ├── check_ast_extensions.cc │ │ ├── check_ast_extensions.h │ │ ├── check_ast_extensions_test.cc │ │ ├── comprehension_vulnerability_check.cc │ │ ├── comprehension_vulnerability_check.h │ │ ├── constant_folding.cc │ │ ├── constant_folding.h │ │ ├── constant_folding_test.cc │ │ ├── flat_expr_builder.cc │ │ ├── flat_expr_builder.h │ │ ├── flat_expr_builder_comprehensions_test.cc │ │ ├── flat_expr_builder_extensions.cc │ │ ├── flat_expr_builder_extensions.h │ │ ├── flat_expr_builder_extensions_test.cc │ │ ├── flat_expr_builder_short_circuiting_conformance_test.cc │ │ ├── flat_expr_builder_test.cc │ │ ├── instrumentation.cc │ │ ├── instrumentation.h │ │ ├── instrumentation_test.cc │ │ ├── qualified_reference_resolver.cc │ │ ├── qualified_reference_resolver.h │ │ ├── qualified_reference_resolver_test.cc │ │ ├── regex_precompilation_optimization.cc │ │ ├── regex_precompilation_optimization.h │ │ ├── regex_precompilation_optimization_test.cc │ │ ├── resolver.cc │ │ ├── resolver.h │ │ └── resolver_test.cc │ ├── eval/ │ │ ├── BUILD │ │ ├── LICENSE │ │ ├── attribute_trail.cc │ │ ├── attribute_trail.h │ │ ├── attribute_trail_test.cc │ │ ├── attribute_utility.cc │ │ ├── attribute_utility.h │ │ ├── attribute_utility_test.cc │ │ ├── cel_expression_flat_impl.cc │ │ ├── cel_expression_flat_impl.h │ │ ├── compiler_constant_step.cc │ │ ├── 
compiler_constant_step.h │ │ ├── compiler_constant_step_test.cc │ │ ├── comprehension_slots.h │ │ ├── comprehension_slots_test.cc │ │ ├── comprehension_step.cc │ │ ├── comprehension_step.h │ │ ├── comprehension_step_test.cc │ │ ├── const_value_step.h │ │ ├── container_access_step.cc │ │ ├── container_access_step.h │ │ ├── container_access_step_test.cc │ │ ├── create_list_step.cc │ │ ├── create_list_step.h │ │ ├── create_list_step_test.cc │ │ ├── create_map_step.cc │ │ ├── create_map_step.h │ │ ├── create_map_step_test.cc │ │ ├── create_struct_step.cc │ │ ├── create_struct_step.h │ │ ├── create_struct_step_test.cc │ │ ├── direct_expression_step.cc │ │ ├── direct_expression_step.h │ │ ├── equality_steps.cc │ │ ├── equality_steps.h │ │ ├── equality_steps_test.cc │ │ ├── evaluator_core.cc │ │ ├── evaluator_core.h │ │ ├── evaluator_core_test.cc │ │ ├── evaluator_stack.cc │ │ ├── evaluator_stack.h │ │ ├── evaluator_stack_test.cc │ │ ├── expression_step_base.h │ │ ├── function_step.cc │ │ ├── function_step.h │ │ ├── function_step_test.cc │ │ ├── ident_step.cc │ │ ├── ident_step.h │ │ ├── ident_step_test.cc │ │ ├── iterator_stack.h │ │ ├── jump_step.cc │ │ ├── jump_step.h │ │ ├── lazy_init_step.cc │ │ ├── lazy_init_step.h │ │ ├── lazy_init_step_test.cc │ │ ├── logic_step.cc │ │ ├── logic_step.h │ │ ├── logic_step_test.cc │ │ ├── optional_or_step.cc │ │ ├── optional_or_step.h │ │ ├── optional_or_step_test.cc │ │ ├── regex_match_step.cc │ │ ├── regex_match_step.h │ │ ├── regex_match_step_test.cc │ │ ├── select_step.cc │ │ ├── select_step.h │ │ ├── select_step_test.cc │ │ ├── shadowable_value_step.cc │ │ ├── shadowable_value_step.h │ │ ├── shadowable_value_step_test.cc │ │ ├── ternary_step.cc │ │ ├── ternary_step.h │ │ ├── ternary_step_test.cc │ │ └── trace_step.h │ ├── internal/ │ │ ├── BUILD │ │ ├── adapter_activation_impl.cc │ │ ├── adapter_activation_impl.h │ │ ├── cel_value_equal.cc │ │ ├── cel_value_equal.h │ │ ├── cel_value_equal_test.cc │ │ ├── errors.cc │ │ ├── 
errors.h │ │ └── interop.h │ ├── public/ │ │ ├── BUILD │ │ ├── LICENSE │ │ ├── activation.cc │ │ ├── activation.h │ │ ├── activation_bind_helper.cc │ │ ├── activation_bind_helper.h │ │ ├── activation_bind_helper_test.cc │ │ ├── activation_test.cc │ │ ├── ast_rewrite.cc │ │ ├── ast_rewrite.h │ │ ├── ast_rewrite_test.cc │ │ ├── ast_traverse.cc │ │ ├── ast_traverse.h │ │ ├── ast_traverse_test.cc │ │ ├── ast_visitor.h │ │ ├── ast_visitor_base.h │ │ ├── base_activation.h │ │ ├── builtin_func_registrar.cc │ │ ├── builtin_func_registrar.h │ │ ├── builtin_func_registrar_test.cc │ │ ├── builtin_func_test.cc │ │ ├── cel_attribute.cc │ │ ├── cel_attribute.h │ │ ├── cel_attribute_test.cc │ │ ├── cel_builtins.h │ │ ├── cel_expr_builder_factory.cc │ │ ├── cel_expr_builder_factory.h │ │ ├── cel_expression.h │ │ ├── cel_function.cc │ │ ├── cel_function.h │ │ ├── cel_function_adapter.h │ │ ├── cel_function_adapter_impl.h │ │ ├── cel_function_adapter_test.cc │ │ ├── cel_function_registry.cc │ │ ├── cel_function_registry.h │ │ ├── cel_function_registry_test.cc │ │ ├── cel_number.cc │ │ ├── cel_number.h │ │ ├── cel_number_test.cc │ │ ├── cel_options.cc │ │ ├── cel_options.h │ │ ├── cel_type_registry.cc │ │ ├── cel_type_registry.h │ │ ├── cel_type_registry_protobuf_reflection_test.cc │ │ ├── cel_type_registry_test.cc │ │ ├── cel_value.cc │ │ ├── cel_value.h │ │ ├── cel_value_internal.h │ │ ├── cel_value_producer.h │ │ ├── cel_value_test.cc │ │ ├── comparison_functions.cc │ │ ├── comparison_functions.h │ │ ├── comparison_functions_test.cc │ │ ├── container_function_registrar.cc │ │ ├── container_function_registrar.h │ │ ├── container_function_registrar_test.cc │ │ ├── containers/ │ │ │ ├── BUILD │ │ │ ├── container_backed_list_impl.h │ │ │ ├── container_backed_map_impl.cc │ │ │ ├── container_backed_map_impl.h │ │ │ ├── container_backed_map_impl_test.cc │ │ │ ├── field_access.cc │ │ │ ├── field_access.h │ │ │ ├── field_access_test.cc │ │ │ ├── field_backed_list_impl.h │ │ │ ├── 
field_backed_list_impl_test.cc │ │ │ ├── field_backed_map_impl.h │ │ │ ├── field_backed_map_impl_test.cc │ │ │ ├── internal_field_backed_list_impl.cc │ │ │ ├── internal_field_backed_list_impl.h │ │ │ ├── internal_field_backed_list_impl_test.cc │ │ │ ├── internal_field_backed_map_impl.cc │ │ │ ├── internal_field_backed_map_impl.h │ │ │ └── internal_field_backed_map_impl_test.cc │ │ ├── equality_function_registrar.cc │ │ ├── equality_function_registrar.h │ │ ├── equality_function_registrar_test.cc │ │ ├── extension_func_registrar.cc │ │ ├── extension_func_registrar.h │ │ ├── extension_func_test.cc │ │ ├── logical_function_registrar.cc │ │ ├── logical_function_registrar.h │ │ ├── logical_function_registrar_test.cc │ │ ├── message_wrapper.h │ │ ├── message_wrapper_test.cc │ │ ├── portable_cel_function_adapter.h │ │ ├── set_util.cc │ │ ├── set_util.h │ │ ├── set_util_test.cc │ │ ├── source_position.cc │ │ ├── source_position.h │ │ ├── source_position_test.cc │ │ ├── string_extension_func_registrar.cc │ │ ├── string_extension_func_registrar.h │ │ ├── string_extension_func_registrar_test.cc │ │ ├── structs/ │ │ │ ├── BUILD │ │ │ ├── cel_proto_descriptor_pool_builder.cc │ │ │ ├── cel_proto_descriptor_pool_builder.h │ │ │ ├── cel_proto_descriptor_pool_builder_test.cc │ │ │ ├── cel_proto_wrap_util.cc │ │ │ ├── cel_proto_wrap_util.h │ │ │ ├── cel_proto_wrap_util_test.cc │ │ │ ├── cel_proto_wrapper.cc │ │ │ ├── cel_proto_wrapper.h │ │ │ ├── cel_proto_wrapper_test.cc │ │ │ ├── dynamic_descriptor_pool_end_to_end_test.cc │ │ │ ├── field_access_impl.cc │ │ │ ├── field_access_impl.h │ │ │ ├── field_access_impl_test.cc │ │ │ ├── legacy_type_adapter.h │ │ │ ├── legacy_type_adapter_test.cc │ │ │ ├── legacy_type_info_apis.h │ │ │ ├── legacy_type_provider.cc │ │ │ ├── legacy_type_provider.h │ │ │ ├── legacy_type_provider_test.cc │ │ │ ├── proto_message_type_adapter.cc │ │ │ ├── proto_message_type_adapter.h │ │ │ ├── proto_message_type_adapter_test.cc │ │ │ ├── 
protobuf_descriptor_type_provider.cc │ │ │ ├── protobuf_descriptor_type_provider.h │ │ │ ├── protobuf_descriptor_type_provider_test.cc │ │ │ ├── protobuf_value_factory.h │ │ │ ├── trivial_legacy_type_info.h │ │ │ └── trivial_legacy_type_info_test.cc │ │ ├── testing/ │ │ │ ├── BUILD │ │ │ ├── matchers.cc │ │ │ ├── matchers.h │ │ │ └── matchers_test.cc │ │ ├── transform_utility.cc │ │ ├── transform_utility.h │ │ ├── unknown_attribute_set.h │ │ ├── unknown_attribute_set_test.cc │ │ ├── unknown_function_result_set.cc │ │ ├── unknown_function_result_set.h │ │ ├── unknown_function_result_set_test.cc │ │ ├── unknown_set.h │ │ ├── unknown_set_test.cc │ │ ├── value_export_util.cc │ │ ├── value_export_util.h │ │ └── value_export_util_test.cc │ ├── tests/ │ │ ├── BUILD │ │ ├── LICENSE │ │ ├── README.md │ │ ├── allocation_benchmark_test.cc │ │ ├── benchmark_test.cc │ │ ├── end_to_end_test.cc │ │ ├── expression_builder_benchmark_test.cc │ │ ├── memory_safety_test.cc │ │ ├── mock_cel_expression.h │ │ ├── modern_benchmark_test.cc │ │ ├── request_context.proto │ │ └── unknowns_end_to_end_test.cc │ └── testutil/ │ ├── BUILD │ ├── test_extensions.proto │ └── test_message.proto ├── extensions/ │ ├── BUILD │ ├── bindings_ext.cc │ ├── bindings_ext.h │ ├── bindings_ext_benchmark_test.cc │ ├── bindings_ext_test.cc │ ├── comprehensions_v2.cc │ ├── comprehensions_v2.h │ ├── comprehensions_v2_functions.cc │ ├── comprehensions_v2_functions.h │ ├── comprehensions_v2_macros.cc │ ├── comprehensions_v2_macros.h │ ├── comprehensions_v2_test.cc │ ├── encoders.cc │ ├── encoders.h │ ├── encoders_test.cc │ ├── formatting.cc │ ├── formatting.h │ ├── formatting_test.cc │ ├── lists_functions.cc │ ├── lists_functions.h │ ├── lists_functions_test.cc │ ├── math_ext.cc │ ├── math_ext.h │ ├── math_ext_decls.cc │ ├── math_ext_decls.h │ ├── math_ext_macros.cc │ ├── math_ext_macros.h │ ├── math_ext_test.cc │ ├── proto_ext.cc │ ├── proto_ext.h │ ├── protobuf/ │ │ ├── BUILD │ │ ├── ast_converters.h │ │ ├── 
bind_proto_to_activation.cc │ │ ├── bind_proto_to_activation.h │ │ ├── bind_proto_to_activation_test.cc │ │ ├── enum_adapter.cc │ │ ├── enum_adapter.h │ │ ├── internal/ │ │ │ ├── BUILD │ │ │ ├── map_reflection.cc │ │ │ ├── map_reflection.h │ │ │ ├── qualify.cc │ │ │ └── qualify.h │ │ ├── memory_manager.cc │ │ ├── memory_manager.h │ │ ├── memory_manager_test.cc │ │ ├── runtime_adapter.cc │ │ ├── runtime_adapter.h │ │ ├── value.h │ │ ├── value_end_to_end_test.cc │ │ ├── value_test.cc │ │ ├── value_testing.h │ │ └── value_testing_test.cc │ ├── regex_ext.cc │ ├── regex_ext.h │ ├── regex_ext_test.cc │ ├── regex_functions.cc │ ├── regex_functions.h │ ├── regex_functions_test.cc │ ├── select_optimization.cc │ ├── select_optimization.h │ ├── select_optimization_test.cc │ ├── sets_functions.cc │ ├── sets_functions.h │ ├── sets_functions_benchmark_test.cc │ ├── sets_functions_test.cc │ ├── strings.cc │ ├── strings.h │ └── strings_test.cc ├── internal/ │ ├── BUILD │ ├── align.h │ ├── align_test.cc │ ├── benchmark.h │ ├── casts.h │ ├── empty_descriptors.cc │ ├── empty_descriptors.h │ ├── empty_descriptors_test.cc │ ├── equals_text_proto.cc │ ├── equals_text_proto.h │ ├── exceptions.h │ ├── json.cc │ ├── json.h │ ├── json_test.cc │ ├── lexis.cc │ ├── lexis.h │ ├── lexis_test.cc │ ├── manual.h │ ├── message_equality.cc │ ├── message_equality.h │ ├── message_equality_test.cc │ ├── message_type_name.h │ ├── message_type_name_test.cc │ ├── minimal_descriptor_database.h │ ├── minimal_descriptor_pool.h │ ├── minimal_descriptors.cc │ ├── names.cc │ ├── names.h │ ├── names_test.cc │ ├── new.cc │ ├── new.h │ ├── new_test.cc │ ├── noop_delete.h │ ├── number.h │ ├── number_test.cc │ ├── overflow.cc │ ├── overflow.h │ ├── overflow_test.cc │ ├── parse_text_proto.h │ ├── proto_file_util.h │ ├── proto_matchers.h │ ├── proto_time_encoding.cc │ ├── proto_time_encoding.h │ ├── proto_time_encoding_test.cc │ ├── proto_util.h │ ├── proto_util_test.cc │ ├── protobuf_runtime_version.h │ ├── 
re2_options.h │ ├── status_builder.h │ ├── status_macros.h │ ├── string_pool.cc │ ├── string_pool.h │ ├── string_pool_test.cc │ ├── strings.cc │ ├── strings.h │ ├── strings_test.cc │ ├── testing.cc │ ├── testing.h │ ├── testing_descriptor_pool.cc │ ├── testing_descriptor_pool.h │ ├── testing_descriptor_pool_test.cc │ ├── testing_message_factory.cc │ ├── testing_message_factory.h │ ├── time.cc │ ├── time.h │ ├── time_test.cc │ ├── to_address.h │ ├── to_address_test.cc │ ├── unicode.h │ ├── utf8.cc │ ├── utf8.h │ ├── utf8_test.cc │ ├── well_known_types.cc │ ├── well_known_types.h │ └── well_known_types_test.cc ├── parser/ │ ├── BUILD │ ├── internal/ │ │ ├── BUILD │ │ ├── Cel.g4 │ │ └── options.h │ ├── macro.cc │ ├── macro.h │ ├── macro_expr_factory.cc │ ├── macro_expr_factory.h │ ├── macro_expr_factory_test.cc │ ├── macro_registry.cc │ ├── macro_registry.h │ ├── macro_registry_test.cc │ ├── options.h │ ├── parser.cc │ ├── parser.h │ ├── parser_benchmarks.cc │ ├── parser_interface.h │ ├── parser_subset_factory.cc │ ├── parser_subset_factory.h │ ├── parser_test.cc │ ├── source_factory.h │ ├── standard_macros.cc │ ├── standard_macros.h │ └── standard_macros_test.cc ├── runtime/ │ ├── BUILD │ ├── activation.cc │ ├── activation.h │ ├── activation_interface.h │ ├── activation_test.cc │ ├── comprehension_vulnerability_check.cc │ ├── comprehension_vulnerability_check.h │ ├── comprehension_vulnerability_check_test.cc │ ├── constant_folding.cc │ ├── constant_folding.h │ ├── constant_folding_test.cc │ ├── embedder_context.h │ ├── embedder_context_test.cc │ ├── function.h │ ├── function_adapter.h │ ├── function_adapter_test.cc │ ├── function_overload_reference.h │ ├── function_provider.h │ ├── function_registry.cc │ ├── function_registry.h │ ├── function_registry_test.cc │ ├── internal/ │ │ ├── BUILD │ │ ├── activation_attribute_matcher_access.cc │ │ ├── activation_attribute_matcher_access.h │ │ ├── attribute_matcher.h │ │ ├── convert_constant.cc │ │ ├── convert_constant.h │ │ 
├── errors.cc │ │ ├── errors.h │ │ ├── function_adapter.h │ │ ├── function_adapter_test.cc │ │ ├── issue_collector.h │ │ ├── issue_collector_test.cc │ │ ├── legacy_runtime_type_provider.h │ │ ├── runtime_env.cc │ │ ├── runtime_env.h │ │ ├── runtime_env_testing.cc │ │ ├── runtime_env_testing.h │ │ ├── runtime_friend_access.h │ │ ├── runtime_impl.cc │ │ ├── runtime_impl.h │ │ ├── runtime_type_provider.cc │ │ └── runtime_type_provider.h │ ├── memory_safety_test.cc │ ├── optional_types.cc │ ├── optional_types.h │ ├── optional_types_test.cc │ ├── reference_resolver.cc │ ├── reference_resolver.h │ ├── reference_resolver_test.cc │ ├── regex_precompilation.cc │ ├── regex_precompilation.h │ ├── regex_precompilation_test.cc │ ├── register_function_helper.h │ ├── runtime.h │ ├── runtime_builder.h │ ├── runtime_builder_factory.cc │ ├── runtime_builder_factory.h │ ├── runtime_issue.h │ ├── runtime_options.h │ ├── standard/ │ │ ├── BUILD │ │ ├── arithmetic_functions.cc │ │ ├── arithmetic_functions.h │ │ ├── arithmetic_functions_test.cc │ │ ├── comparison_functions.cc │ │ ├── comparison_functions.h │ │ ├── comparison_functions_test.cc │ │ ├── container_functions.cc │ │ ├── container_functions.h │ │ ├── container_functions_test.cc │ │ ├── container_membership_functions.cc │ │ ├── container_membership_functions.h │ │ ├── container_membership_functions_test.cc │ │ ├── equality_functions.cc │ │ ├── equality_functions.h │ │ ├── equality_functions_test.cc │ │ ├── logical_functions.cc │ │ ├── logical_functions.h │ │ ├── logical_functions_test.cc │ │ ├── regex_functions.cc │ │ ├── regex_functions.h │ │ ├── regex_functions_test.cc │ │ ├── string_functions.cc │ │ ├── string_functions.h │ │ ├── string_functions_test.cc │ │ ├── time_functions.cc │ │ ├── time_functions.h │ │ ├── time_functions_test.cc │ │ ├── type_conversion_functions.cc │ │ ├── type_conversion_functions.h │ │ └── type_conversion_functions_test.cc │ ├── standard_functions.cc │ ├── standard_functions.h │ ├── 
standard_runtime_builder_factory.cc │ ├── standard_runtime_builder_factory.h │ ├── standard_runtime_builder_factory_test.cc │ ├── type_registry.cc │ └── type_registry.h ├── testing/ │ └── testrunner/ │ ├── BUILD │ ├── cel_cc_test.bzl │ ├── cel_expression_source.h │ ├── cel_test_context.h │ ├── cel_test_factories.h │ ├── coverage_index.cc │ ├── coverage_index.h │ ├── coverage_index_test.cc │ ├── coverage_reporting.cc │ ├── coverage_reporting.h │ ├── resources/ │ │ ├── BUILD │ │ ├── simple_tests.textproto │ │ ├── test.cel │ │ └── test_environment.textproto │ ├── runner_bin.cc │ ├── runner_lib.cc │ ├── runner_lib.h │ ├── runner_lib_test.cc │ └── user_tests/ │ ├── BUILD │ ├── checked_expr_test.cc │ ├── raw_expr_and_cel_file_test.cc │ ├── raw_expression_test.cc │ └── simple.cc ├── testutil/ │ ├── BUILD │ ├── baseline_tests.cc │ ├── baseline_tests.h │ ├── baseline_tests_test.cc │ ├── expr_printer.cc │ ├── expr_printer.h │ ├── expr_printer_test.cc │ ├── test_json_names.proto │ └── util.h ├── tools/ │ ├── BUILD │ ├── branch_coverage.cc │ ├── branch_coverage.h │ ├── branch_coverage_test.cc │ ├── cel_field_extractor.cc │ ├── cel_field_extractor.h │ ├── cel_field_extractor_test.cc │ ├── cel_unparser.cc │ ├── cel_unparser.h │ ├── cel_unparser_test.cc │ ├── descriptor_pool_builder.cc │ ├── descriptor_pool_builder.h │ ├── descriptor_pool_builder_test.cc │ ├── flatbuffers_backed_impl.cc │ ├── flatbuffers_backed_impl.h │ ├── flatbuffers_backed_impl_test.cc │ ├── navigable_ast.cc │ ├── navigable_ast.h │ ├── navigable_ast_test.cc │ └── testdata/ │ ├── BUILD │ ├── checked_expr_and.textproto │ ├── const_str.textproto │ ├── coverage_example.textproto │ ├── exists_macro.textproto │ ├── flatbuffers.fbs │ ├── macro_multiple_references.textproto │ ├── macro_nested_macro_call.textproto │ ├── macro_single_reference.textproto │ ├── msg_new_field.textproto │ └── msg_new_field_int.textproto └── validator/ ├── BUILD ├── ast_depth_validator.cc ├── ast_depth_validator.h ├── 
ast_depth_validator_test.cc ├── comprehension_nesting_validator.cc ├── comprehension_nesting_validator.h ├── comprehension_nesting_validator_test.cc ├── homogeneous_literal_validator.cc ├── homogeneous_literal_validator.h ├── homogeneous_literal_validator_test.cc ├── regex_validator.cc ├── regex_validator.h ├── regex_validator_test.cc ├── timestamp_literal_validator.cc ├── timestamp_literal_validator.h ├── timestamp_literal_validator_test.cc ├── validator.cc ├── validator.h └── validator_test.cc ================================================ FILE CONTENTS ================================================ ================================================ FILE: .bazelrc ================================================ common --enable_platform_specific_config build --enable_bzlmod build --compilation_mode=fastbuild build:linux --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 build:linux --cxxopt=-fsized-deallocation build:linux --copt=-Wno-deprecated-declarations # you will typically need to spell out the compiler for local dev # BAZEL_VC= # BAZEL_VC_FULL_VERSION=14.44.3520 build:msvc --cxxopt="-std:c++20" --cxxopt="-utf-8" --host_cxxopt="-std:c++20" build:msvc --define=protobuf_allow_msvc=true build:msvc --test_tag_filters=-benchmark,-notap,-no_test_msvc build:msvc --build_tag_filters=-no_test_msvc build:macos --cxxopt=-faligned-allocation build:macos --cxxopt=-mmacosx-version-min=10.13 build:macos --linkopt=-mmacosx-version-min=10.13 # ANTLR tool requires Java 17+. 
build --java_runtime_version=remotejdk_17 test --test_output=errors # Enable matchers in googletest build --define absl=1 build:asan --linkopt -ldl build:asan --linkopt -fsanitize=address build:asan --copt -fsanitize=address build:asan --copt -DADDRESS_SANITIZER=1 build:asan --copt -D__SANITIZE_ADDRESS__ build:asan --test_env=ASAN_OPTIONS=handle_abort=1:allow_addr2line=true:check_initialization_order=true:strict_init_order=true:detect_odr_violation=1 build:asan --test_env=ASAN_SYMBOLIZER_PATH build:asan --copt -O1 build:asan --copt -fno-optimize-sibling-calls build:asan --linkopt=-fuse-ld=lld try-import %workspace%/clang.bazelrc try-import %workspace%/user.bazelrc try-import %workspace%/local_tsan.bazelrc ================================================ FILE: .bazelversion ================================================ 7.3.2 ================================================ FILE: .bcr/README.md ================================================ # BCR Publishing Templates This directory contains templates used by the [Publish to BCR](https://github.com/bazel-contrib/publish-to-bcr) GitHub Action to automatically publish new versions of cel-cpp to the [Bazel Central Registry (BCR)](https://github.com/bazelbuild/bazel-central-registry). ## Files - **metadata.template.json**: Contains repository metadata including homepage, maintainers, and repository location - **source.template.json**: Template for generating the source.json file that tells BCR where to download release archives - **presubmit.yml**: Defines build and test tasks that BCR will run to verify each published version ## How it works When a new tag matching the pattern `v*.*.*` is created: 1. The GitHub Actions workflow `.github/workflows/publish_to_bcr.yml` is triggered 2. The workflow uses these templates to generate a BCR entry 3. A pull request is automatically created against the Bazel Central Registry 4. 
Once merged, the new version becomes available to Bazel users via bzlmod ## Template Variables The following variables are automatically substituted: - `{OWNER}`: Repository owner (google) - `{REPO}`: Repository name (cel-cpp) - `{VERSION}`: Version number extracted from the tag (e.g., `0.14.0` from `v0.14.0`) - `{TAG}`: Full tag name (e.g., `v0.14.0`) ## More Information - [Publish to BCR documentation](https://github.com/bazel-contrib/publish-to-bcr) - [BCR documentation](https://bazel.build/external/registry) ================================================ FILE: .bcr/metadata.template.json ================================================ { "homepage": "https://cel.dev", "maintainers": [ { "email": "ferstl@intrinsic.ai", "github": "ferstlf", "github_user_id": 64520639, "name": "Florian Ferstl" }, { "email": "cel-lang-discuss@googlegroups.com", "github": "cel-expr", "github_user_id": 186625994, "name": "CEL Team" }, { "github": "jnthntatum", "github_user_id": 733856 }, { "github": "jcking", "github_user_id": 997958 }, { "github": "tristonianjones", "github_user_id": 483300 } ], "repository": [ "github:google/cel-cpp" ], "versions": [], "yanked_versions": {} } ================================================ FILE: .bcr/presubmit.yml ================================================ matrix: platform: - debian11 - ubuntu2004 bazel: - 8.x - 7.x tasks: verify_targets: name: Verify build targets platform: ${{ platform }} bazel: ${{ bazel }} build_flags: - '--cxxopt=-std=c++17' - '--host_cxxopt=-std=c++17' - '--copt=-Wno-deprecated-declarations' - '--define=absl=1' build_targets: - '@cel-cpp//...' 
================================================ FILE: .bcr/source.template.json ================================================ { "integrity": "", "strip_prefix": "cel-cpp-{VERSION}", "url": "https://github.com/{OWNER}/{REPO}/archive/refs/tags/{TAG}.tar.gz" } ================================================ FILE: .github/workflows/publish_to_bcr.yml ================================================ name: Publish to BCR on: push: tags: - "v*.*.*" permissions: id-token: write attestations: write contents: write jobs: publish: uses: bazel-contrib/publish-to-bcr/.github/workflows/publish.yaml@v1.0.0 with: tag_name: ${{ github.ref_name }} secrets: publish_token: ${{ secrets.BCR_PUBLISH_TOKEN }} ================================================ FILE: .gitignore ================================================ bazel-bin bazel-eval bazel-genfiles bazel-out bazel-testlogs bazel-cel-cpp *~ clang.bazelrc user.bazelrc local_tsan.bazelrc MODULE.bazel.lock ================================================ FILE: BUILD.bazel ================================================ package(default_visibility = ["//visibility:public"]) ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Contributor Code of Conduct ## Version 0.1.1 (adapted from 0.3b-angular) As contributors and maintainers of the Common Expression Language (CEL) project, we pledge to respect everyone who contributes by posting issues, updating documentation, submitting pull requests, providing feedback in comments, and any other activities. Communication through any of CEL's channels (GitHub, Gitter, IRC, mailing lists, Google+, Twitter, etc.) must be constructive and never resort to personal attacks, trolling, public or private harassment, insults, or other unprofessional conduct. 
We promise to extend courtesy and respect to everyone involved in this project regardless of gender, gender identity, sexual orientation, disability, age, race, ethnicity, religion, or level of experience. We expect anyone contributing to the project to do the same. If any member of the community violates this code of conduct, the maintainers of the CEL project may take action, removing issues, comments, and PRs or blocking accounts as deemed appropriate. If you are subject to or witness unacceptable behavior, or have any other concerns, please email us at [cel-conduct@google.com](mailto:cel-conduct@google.com). ================================================ FILE: CONTRIBUTING.md ================================================ # How to Contribute We'd love to accept your patches and contributions to this project. There are a few guidelines you need to follow. ## Contributor License Agreement Contributions to this project must be accompanied by a Contributor License Agreement. You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to https://cla.developers.google.com/ to see your current agreements on file or to sign a new one. You generally only need to submit a CLA once, so if you've already submitted one (even if it was for a different project), you probably don't need to do it again. ## Code reviews All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. ## What to expect from maintainers Expect maintainers to respond to new issues or pull requests within a week. For outstanding and ongoing issues and particularly for long-running pull requests, expect the maintainers to review within a week of a contributor asking for a new review.
There is no commitment to resolution -- merging or closing a pull request, or fixing or closing an issue -- because some issues will require more discussion than others. ================================================ FILE: Dockerfile ================================================ # This Dockerfile is used to create a container around gcc9 and bazel for # building the CEL C++ library on GitHub. # # To update a new version of this container, use gcloud. You may need to run # `gcloud auth login` and `gcloud auth configure-docker` first. # # Note, if you need to run docker using `sudo` use the following commands # instead: # # sudo gcloud auth login --no-launch-browser # sudo gcloud auth configure-docker # # Run the following command from the root of the CEL repository: # # gcloud builds submit --region=us -t gcr.io/cel-analysis/gcc9 . # # Once complete get the sha256 digest from the output using the following # command: # # gcloud artifacts versions list --package=gcc9 --repository=gcr.io \ # --location=us # # The cloudbuild.yaml file must be updated to use the new digest like so: # # - name: 'gcr.io/cel-analysis/gcc9@' FROM gcc:9 # Install Bazel prerequisites and required tools. # See https://docs.bazel.build/versions/master/install-ubuntu.html RUN apt-get update && \ apt-get upgrade -y && \ apt-get install -y --no-install-recommends \ ca-certificates \ git \ libssl-dev \ make \ pkg-config \ python3 \ unzip \ wget \ zip \ zlib1g-dev \ default-jdk-headless \ clang-11 && \ apt-get clean # Install Bazel.
# https://github.com/bazelbuild/bazel/releases ARG BAZEL_VERSION="7.3.2" ADD https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh /tmp/install_bazel.sh RUN /bin/bash /tmp/install_bazel.sh && rm /tmp/install_bazel.sh RUN mkdir -p /workspace RUN mkdir -p /bazel ENTRYPOINT ["/usr/local/bin/bazel"] ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: MODULE.bazel ================================================ module( name = "cel-cpp", ) bazel_dep( name = "bazel_skylib", version = "1.9.0", ) bazel_dep( name = "googleapis", version = "0.0.0-20241220-5e258e33.bcr.1", repo_name = "com_google_googleapis", ) bazel_dep( name = "googleapis-cc", version = "1.0.0", ) bazel_dep( name = "rules_cc", version = "0.2.14", ) bazel_dep( name = "rules_java", version = "8.6.1", ) bazel_dep( name = "rules_proto", version = "7.1.0", ) bazel_dep( name = "rules_python", version = "1.6.3", ) bazel_dep( name = "protobuf", version = "33.4", repo_name = "com_google_protobuf", ) bazel_dep( name = "abseil-cpp", version = "20260107.0", repo_name = "com_google_absl", ) bazel_dep( name = "googletest", version = "1.17.0.bcr.2", dev_dependency = True, repo_name = "com_google_googletest", ) bazel_dep( name = "google_benchmark", version = "1.9.2", dev_dependency = True, repo_name = "com_github_google_benchmark", ) bazel_dep( name = "re2", version = "2025-11-05.bcr.1", repo_name = "com_googlesource_code_re2", ) bazel_dep( name = "flatbuffers", version = "25.9.23", repo_name = "com_github_google_flatbuffers", ) bazel_dep( name = "cel-spec", version = "0.25.1", repo_name = "com_google_cel_spec", ) bazel_dep( name = "platforms", version = "1.0.0", ) ANTLR4_VERSION = "4.13.2" bazel_dep( name = "antlr4-cpp-runtime", version = ANTLR4_VERSION, ) 
single_version_override( module_name = "antlr4-cpp-runtime", patches = ["//bazel:antlr.patch"], ) python = use_extension("@rules_python//python/extensions:python.bzl", "python") python.toolchain( configure_coverage_tool = False, ignore_root_user_error = True, python_version = "3.11", ) http_jar = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_jar") http_jar( name = "antlr4_jar", sha256 = "eae2dfa119a64327444672aff63e9ec35a20180dc5b8090b7a6ab85125df4d76", urls = ["https://www.antlr.org/download/antlr-" + ANTLR4_VERSION + "-complete.jar"], ) bazel_dep( name = "yaml-cpp", version = "0.9.0", ) ================================================ FILE: README.md ================================================ # C++ Implementations of the Common Expression Language For background on the Common Expression Language see the [cel-spec][1] repo. This is a C++ implementation of a [Common Expression Language][1] runtime, parser, and type checker. Released under the [Apache License](LICENSE). [1]: https://github.com/google/cel-spec ================================================ FILE: base/BUILD ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") package( # Under active development, not yet being released. 
default_visibility = ["//visibility:public"], ) licenses(["notice"]) cc_library( name = "attributes", srcs = [ "attribute.cc", ], hdrs = [ "attribute.h", "attribute_set.h", ], deps = [ ":kind", "//internal:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", ], ) cc_library( name = "kind", hdrs = ["kind.h"], deps = [ "//common:kind", "//common:type_kind", "//common:value_kind", ], ) cc_library( name = "operators", srcs = ["operators.cc"], hdrs = ["operators.h"], deps = [ "//base/internal:operators", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) cc_test( name = "operators_test", srcs = ["operators_test.cc"], deps = [ ":operators", "//base/internal:operators", "//internal:testing", "@com_google_absl//absl/hash:hash_testing", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) # Build target encompassing cel::Type, cel::Value, and their related classes. 
cc_library( name = "data", hdrs = [ "type_provider.h", ], deps = [ "//common:value", ], ) cc_library( name = "function", hdrs = [ "function.h", ], deps = [ "//runtime:function", ], ) cc_library( name = "function_descriptor", hdrs = [ "function_descriptor.h", ], deps = [ "//common:function_descriptor", ], ) cc_library( name = "function_result", hdrs = [ "function_result.h", ], deps = [":function_descriptor"], ) cc_library( name = "function_result_set", srcs = [ "function_result_set.cc", ], hdrs = [ "function_result_set.h", ], deps = [ ":function_result", "@com_google_absl//absl/container:btree", ], ) cc_library( name = "ast", hdrs = ["ast.h"], deps = ["//common:ast"], ) cc_library( name = "function_adapter", hdrs = ["function_adapter.h"], deps = [ "//runtime:function_adapter", ], ) cc_library( name = "builtins", hdrs = ["builtins.h"], ) ================================================ FILE: base/ast.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_BASE_AST_H_ #define THIRD_PARTY_CEL_CPP_BASE_AST_H_ #include "common/ast.h" // IWYU pragma: export #endif // THIRD_PARTY_CEL_CPP_BASE_AST_H_ ================================================ FILE: base/attribute.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "base/attribute.h" #include #include #include #include "absl/base/macros.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/variant.h" #include "base/kind.h" #include "internal/status_macros.h" namespace cel { namespace { // Visitor for appending string representation for different qualifier kinds. class AttributeStringPrinter { public: // String representation for the given qualifier is appended to output. // output must be non-null. explicit AttributeStringPrinter(std::string* output, Kind type) : output_(*output), type_(type) {} absl::Status operator()(const Kind& ignored) const { // Attributes are represented as a variant, with illegal attribute // qualifiers represented with their type as the first alternative. 
return absl::InvalidArgumentError( absl::StrCat("Unsupported attribute qualifier ", KindToString(type_))); } absl::Status operator()(int64_t index) { absl::StrAppend(&output_, "[", index, "]"); return absl::OkStatus(); } absl::Status operator()(uint64_t index) { absl::StrAppend(&output_, "[", index, "]"); return absl::OkStatus(); } absl::Status operator()(bool bool_key) { absl::StrAppend(&output_, "[", (bool_key) ? "true" : "false", "]"); return absl::OkStatus(); } absl::Status operator()(const std::string& field) { absl::StrAppend(&output_, ".", field); return absl::OkStatus(); } private: std::string& output_; Kind type_; }; // Visitor for appending string representation for different qualifier kinds. class AttributeQualifierStringPrinter { public: // String representation for the given qualifier is appended to output. explicit AttributeQualifierStringPrinter(std::string* absl_nonnull output, Kind type) : output_(*output), type_(type) {} absl::Status operator()(const Kind& ignored) const { // Attributes are represented as a variant, with illegal attribute // qualifiers represented with their type as the first alternative. return absl::InvalidArgumentError( absl::StrCat("Unsupported attribute qualifier ", KindToString(type_))); } absl::Status operator()(int64_t index) { absl::StrAppend(&output_, index); return absl::OkStatus(); } absl::Status operator()(uint64_t index) { absl::StrAppend(&output_, index); return absl::OkStatus(); } absl::Status operator()(bool bool_key) { absl::StrAppend(&output_, (bool_key) ? 
"true" : "false"); return absl::OkStatus(); } absl::Status operator()(const std::string& field) { absl::StrAppend(&output_, field); return absl::OkStatus(); } private: std::string& output_; Kind type_; }; struct AttributeQualifierTypeVisitor final { Kind operator()(const Kind& type) const { return type; } Kind operator()(int64_t ignored) const { static_cast(ignored); return Kind::kInt64; } Kind operator()(uint64_t ignored) const { static_cast(ignored); return Kind::kUint64; } Kind operator()(const std::string& ignored) const { static_cast(ignored); return Kind::kString; } Kind operator()(bool ignored) const { static_cast(ignored); return Kind::kBool; } }; struct AttributeQualifierTypeComparator final { const Kind lhs; bool operator()(const Kind& rhs) const { return static_cast(lhs) < static_cast(rhs); } bool operator()(int64_t) const { return false; } bool operator()(uint64_t other) const { return false; } bool operator()(const std::string&) const { return false; } bool operator()(bool other) const { return false; } }; struct AttributeQualifierIntComparator final { const int64_t lhs; bool operator()(const Kind&) const { return true; } bool operator()(int64_t rhs) const { return lhs < rhs; } bool operator()(uint64_t) const { return true; } bool operator()(const std::string&) const { return true; } bool operator()(bool) const { return false; } }; struct AttributeQualifierUintComparator final { const uint64_t lhs; bool operator()(const Kind&) const { return true; } bool operator()(int64_t) const { return false; } bool operator()(uint64_t rhs) const { return lhs < rhs; } bool operator()(const std::string&) const { return true; } bool operator()(bool) const { return false; } }; struct AttributeQualifierStringComparator final { const std::string& lhs; bool operator()(const Kind&) const { return true; } bool operator()(int64_t) const { return false; } bool operator()(uint64_t) const { return false; } bool operator()(const std::string& rhs) const { return lhs < rhs; } bool 
operator()(bool) const { return false; } }; struct AttributeQualifierBoolComparator final { const bool lhs; bool operator()(const Kind&) const { return true; } bool operator()(int64_t) const { return true; } bool operator()(uint64_t) const { return true; } bool operator()(const std::string&) const { return true; } bool operator()(bool rhs) const { return lhs < rhs; } }; } // namespace struct AttributeQualifier::ComparatorVisitor final { const AttributeQualifier::Variant& rhs; bool operator()(const Kind& lhs) const { return absl::visit(AttributeQualifierTypeComparator{lhs}, rhs); } bool operator()(int64_t lhs) const { return absl::visit(AttributeQualifierIntComparator{lhs}, rhs); } bool operator()(uint64_t lhs) const { return absl::visit(AttributeQualifierUintComparator{lhs}, rhs); } bool operator()(const std::string& lhs) const { return absl::visit(AttributeQualifierStringComparator{lhs}, rhs); } bool operator()(bool lhs) const { return absl::visit(AttributeQualifierBoolComparator{lhs}, rhs); } }; Kind AttributeQualifier::kind() const { return absl::visit(AttributeQualifierTypeVisitor{}, value_); } bool AttributeQualifier::operator<(const AttributeQualifier& other) const { // The order is not publicly documented because it is subject to change. // Currently we sort in the following order, with each type being sorted // against itself: bool, int, uint, string, type. return absl::visit(ComparatorVisitor{other.value_}, value_); } bool Attribute::operator==(const Attribute& other) const { // We cannot check pointer equality as a short circuit because we have to // treat all invalid AttributeQualifier as not equal to each other. // TODO(issues/41) we only support Ident-rooted attributes at the moment. 
// Equality requires identical variable names and element-wise equal qualifiers.
  if (variable_name() != other.variable_name()) {
    return false;
  }
  if (qualifier_path().size() != other.qualifier_path().size()) {
    return false;
  }
  for (size_t i = 0; i < qualifier_path().size(); i++) {
    if (!(qualifier_path()[i] == other.qualifier_path()[i])) {
      return false;
    }
  }
  return true;
}

// Ordering used by ordered containers of attributes (e.g. AttributeSet).
// Qualifier paths are compared element-wise first; when one path is a prefix
// of the other, the shorter one orders first; only when the paths are equal do
// the variable names decide. Two attributes sharing one Impl are never less
// than each other.
//
// NOTE(review): the element equality check below uses
// AttributeQualifier::operator==, which delegates to AttributeQualifier::IsMatch
// and reports "not equal" whenever either side holds the alternative IsMatch
// rejects, even for identical values. Confirm AttributeQualifier::operator< is
// consistent with that so this remains a strict weak ordering.
bool Attribute::operator<(const Attribute& other) const {
  if (impl_.get() == other.impl_.get()) {
    return false;
  }
  auto lhs_begin = qualifier_path().begin();
  auto lhs_end = qualifier_path().end();
  auto rhs_begin = other.qualifier_path().begin();
  auto rhs_end = other.qualifier_path().end();
  while (lhs_begin != lhs_end && rhs_begin != rhs_end) {
    if (*lhs_begin < *rhs_begin) {
      return true;
    }
    if (!(*lhs_begin == *rhs_begin)) {
      return false;
    }
    lhs_begin++;
    rhs_begin++;
  }
  if (lhs_begin == lhs_end && rhs_begin == rhs_end) {
    // Neither has any elements left, they are equal. Compare variable names.
    return variable_name() < other.variable_name();
  }
  if (lhs_begin == lhs_end) {
    // Left has no more elements. Right is greater.
    return true;
  }
  // Right has no more elements. Left is greater.
  ABSL_ASSERT(rhs_begin == rhs_end);
  return false;
}

// Renders the attribute as its variable name followed by each qualifier as
// formatted by AttributeStringPrinter. Returns InvalidArgumentError when the
// attribute has no variable name; errors from the printer are propagated via
// CEL_RETURN_IF_ERROR.
const absl::StatusOr Attribute::AsString() const {
  if (variable_name().empty()) {
    return absl::InvalidArgumentError(
        "Only ident rooted attributes are supported.");
  }
  std::string result = std::string(variable_name());
  for (const auto& qualifier : qualifier_path()) {
    CEL_RETURN_IF_ERROR(absl::visit(
        AttributeStringPrinter(&result, qualifier.kind()), qualifier.value_));
  }
  return result;
}

// Exact qualifier match: false whenever either qualifier holds the alternative
// tested below, otherwise the stored variants are compared for equality.
// NOTE(review): the holds_alternative template argument is missing in this
// copy; presumably it is the unsupported first alternative described on value_.
bool AttributeQualifier::IsMatch(const AttributeQualifier& other) const {
  if (absl::holds_alternative(value_) ||
      absl::holds_alternative(other.value_)) {
    return false;
  }
  return value_ == other.value_;
}

// Renders this single qualifier using AttributeQualifierStringPrinter.
absl::StatusOr AttributeQualifier::AsString() const {
  std::string result;
  CEL_RETURN_IF_ERROR(
      absl::visit(AttributeQualifierStringPrinter(&result, kind()), value_));
  return result;
}

}  // namespace cel
================================================
FILE: base/attribute.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_BASE_ATTRIBUTE_H_
#define THIRD_PARTY_CEL_CPP_BASE_ATTRIBUTE_H_

// NOTE(review): the system header names of the next six #include directives,
// and many template argument lists in this header, are missing from this copy
// and must be restored from the upstream file.
#include
#include
#include
#include
#include
#include
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "base/kind.h"

namespace cel {

// AttributeQualifier represents a segment in an
// attribute resolution path. A segment can be qualified by values of the
// following types: string/int64_t/uint64_t/bool.
class AttributeQualifier final {
 private:
  struct ComparatorVisitor;
  using Variant = absl::variant;

 public:
  static AttributeQualifier OfInt(int64_t value) {
    return AttributeQualifier(absl::in_place_type, std::move(value));
  }

  static AttributeQualifier OfUint(uint64_t value) {
    return AttributeQualifier(absl::in_place_type, std::move(value));
  }

  static AttributeQualifier OfString(std::string value) {
    return AttributeQualifier(absl::in_place_type, std::move(value));
  }

  static AttributeQualifier OfBool(bool value) {
    return AttributeQualifier(absl::in_place_type, std::move(value));
  }

  AttributeQualifier() = default;

  AttributeQualifier(const AttributeQualifier&) = default;
  AttributeQualifier(AttributeQualifier&&) = default;

  AttributeQualifier& operator=(const AttributeQualifier&) = default;
  AttributeQualifier& operator=(AttributeQualifier&&) = default;

  Kind kind() const;

  // Family of Get... methods. Return values if requested type matches the
  // stored one. Each accessor reads a fixed variant index:
  // 1 = int64, 2 = uint64, 3 = string, 4 = bool.
  absl::optional GetInt64Key() const {
    return absl::holds_alternative(value_)
               ? absl::optional(absl::get<1>(value_))
               : absl::nullopt;
  }

  absl::optional GetUint64Key() const {
    return absl::holds_alternative(value_)
               ? absl::optional(absl::get<2>(value_))
               : absl::nullopt;
  }

  absl::optional GetStringKey() const {
    return absl::holds_alternative(value_)
               ? absl::optional(absl::get<3>(value_))
               : absl::nullopt;
  }

  absl::optional GetBoolKey() const {
    return absl::holds_alternative(value_)
               ? absl::optional(absl::get<4>(value_))
               : absl::nullopt;
  }

  // Delegates to the private IsMatch, so qualifiers holding the alternative
  // IsMatch rejects never compare equal, not even to themselves.
  bool operator==(const AttributeQualifier& other) const {
    return IsMatch(other);
  }

  // Defined out of line.
  bool operator<(const AttributeQualifier& other) const;

  // True only when this qualifier holds a string key equal to `other_key`.
  bool IsMatch(absl::string_view other_key) const {
    absl::optional key = GetStringKey();
    return (key.has_value() && key.value() == other_key);
  }

  absl::StatusOr AsString() const;

 private:
  friend class Attribute;
  friend struct ComparatorVisitor;

  template
  AttributeQualifier(absl::in_place_type_t in_place_type, T&& value)
      : value_(in_place_type, std::forward(value)) {}

  bool IsMatch(const AttributeQualifier& other) const;

  // The previous implementation of Attribute preserved all value
  // instances, regardless of whether they are supported in this context or not.
  // We represented unsupported types by using the first alternative and thus
  // preserve backwards compatibility with the result of `kind()` above.
  Variant value_;
};

// AttributeQualifierPattern matches a segment in an
// attribute resolution path. AttributeQualifierPattern is capable of
// matching path elements of types string/int64/uint64/bool.
class AttributeQualifierPattern final {
 private:
  // Qualifier value. If not set, treated as wildcard.
  std::optional value_;

  explicit AttributeQualifierPattern(std::optional value)
      : value_(std::move(value)) {}

 public:
  static AttributeQualifierPattern OfInt(int64_t value) {
    return AttributeQualifierPattern(AttributeQualifier::OfInt(value));
  }

  static AttributeQualifierPattern OfUint(uint64_t value) {
    return AttributeQualifierPattern(AttributeQualifier::OfUint(value));
  }

  static AttributeQualifierPattern OfString(std::string value) {
    return AttributeQualifierPattern(
        AttributeQualifier::OfString(std::move(value)));
  }

  static AttributeQualifierPattern OfBool(bool value) {
    return AttributeQualifierPattern(AttributeQualifier::OfBool(value));
  }

  // A pattern with no value; it matches every qualifier and every key.
  static AttributeQualifierPattern CreateWildcard() {
    return AttributeQualifierPattern(std::nullopt);
  }

  explicit AttributeQualifierPattern(AttributeQualifier qualifier)
      : AttributeQualifierPattern(
            std::optional(std::move(qualifier))) {}

  bool IsWildcard() const { return !value_.has_value(); }

  // Wildcards match everything; otherwise uses AttributeQualifier equality.
  bool IsMatch(const AttributeQualifier& qualifier) const {
    if (IsWildcard()) return true;
    return value_.value() == qualifier;
  }

  // Wildcards match any key; otherwise the stored qualifier must be a string
  // equal to `other_key`.
  bool IsMatch(absl::string_view other_key) const {
    if (!value_.has_value()) return true;
    return value_->IsMatch(other_key);
  }
};

// Attribute represents a resolved attribute path.
class Attribute final {
 public:
  // An attribute rooted at `variable_name` with no qualifiers.
  explicit Attribute(std::string variable_name)
      : Attribute(std::move(variable_name), {}) {}

  Attribute(std::string variable_name,
            std::vector qualifier_path)
      : impl_(std::make_shared(std::move(variable_name),
                               std::move(qualifier_path))) {}

  absl::string_view variable_name() const { return impl_->variable_name; }

  bool has_variable_name() const { return !impl_->variable_name.empty(); }

  absl::Span qualifier_path() const {
    return impl_->qualifier_path;
  }

  bool operator==(const Attribute& other) const;

  bool operator<(const Attribute& other) const;

  // NOTE(review): returning a const value prevents callers from moving out of
  // the result; dropping the const must be done here and in the out-of-line
  // definition together.
  const absl::StatusOr AsString() const;

 private:
  struct Impl final {
    Impl(std::string variable_name,
         std::vector qualifier_path)
        : variable_name(std::move(variable_name)),
          qualifier_path(std::move(qualifier_path)) {}

    std::string variable_name;
    std::vector qualifier_path;
  };

  // Shared state: copying an Attribute copies this shared_ptr, so copies share
  // one Impl. The class exposes no mutators.
  std::shared_ptr impl_;
};

// AttributePattern is a fully-qualified absolute attribute path pattern.
// Supported segment steps in the path are:
// - field selection;
// - map lookup by key;
// - list access by index.
class AttributePattern final {
 public:
  // MatchType enum specifies how closely pattern is matching the attribute:
  enum class MatchType {
    NONE,     // Pattern does not match attribute itself nor its children
    PARTIAL,  // Pattern matches an entity nested within attribute;
    FULL      // Pattern matches an attribute itself.
  };

  AttributePattern(std::string variable,
                   std::vector qualifier_path)
      : variable_(std::move(variable)),
        qualifier_path_(std::move(qualifier_path)) {}

  absl::string_view variable() const { return variable_; }

  absl::Span qualifier_path() const {
    return qualifier_path_;
  }

  // Matches the pattern to an attribute.
  // Distinguishes between no-match, partial match and full match cases:
  //  - NONE if the variable names differ or any compared qualifier differs;
  //  - PARTIAL if the pattern has more qualifiers than the attribute and all
  //    of the attribute's qualifiers match the pattern's leading ones;
  //  - FULL otherwise, which includes an attribute that is longer than the
  //    pattern when the pattern matches its leading qualifiers.
  MatchType IsMatch(const Attribute& attribute) const {
    MatchType result = MatchType::NONE;
    if (attribute.variable_name() != variable_) {
      return result;
    }

    auto max_index = qualifier_path().size();
    result = MatchType::FULL;
    if (qualifier_path().size() > attribute.qualifier_path().size()) {
      max_index = attribute.qualifier_path().size();
      result = MatchType::PARTIAL;
    }

    for (size_t i = 0; i < max_index; i++) {
      if (!(qualifier_path()[i].IsMatch(attribute.qualifier_path()[i]))) {
        return MatchType::NONE;
      }
    }
    return result;
  }

 private:
  std::string variable_;
  std::vector qualifier_path_;
};

// Identifies a field by its number and its name.
struct FieldSpecifier {
  int64_t number;
  std::string name;
};

// NOTE(review): the variant's alternatives are missing from this copy;
// presumably FieldSpecifier and AttributeQualifier — confirm upstream.
using SelectQualifier = absl::variant;

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_BASE_ATTRIBUTE_H_
================================================
FILE: base/attribute_set.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_BASE_ATTRIBUTE_SET_H_
#define THIRD_PARTY_CEL_CPP_BASE_ATTRIBUTE_SET_H_

#include "absl/container/btree_set.h"
#include "absl/types/span.h"
#include "base/attribute.h"

namespace google::api::expr::runtime {
class AttributeUtility;
}  // namespace google::api::expr::runtime

namespace cel {

class UnknownValue;

namespace base_internal {
class UnknownSet;
}

// AttributeSet is a container for CEL attributes that are identified as
// unknown during expression evaluation.
class AttributeSet final {
 private:
  // Ordered, de-duplicated storage. NOTE(review): the template arguments are
  // missing in this copy; presumably btree_set of Attribute with the default
  // comparator, i.e. ordering by Attribute::operator<.
  using Container = absl::btree_set;

 public:
  using value_type = typename Container::value_type;
  using size_type = typename Container::size_type;
  // Iteration is read-only: both aliases are the container's const_iterator.
  using iterator = typename Container::const_iterator;
  using const_iterator = typename Container::const_iterator;

  AttributeSet() = default;
  AttributeSet(const AttributeSet&) = default;
  AttributeSet(AttributeSet&&) = default;
  AttributeSet& operator=(const AttributeSet&) = default;
  AttributeSet& operator=(AttributeSet&&) = default;

  explicit AttributeSet(absl::Span attributes) {
    for (const auto& attr : attributes) {
      Add(attr);
    }
  }

  // Union constructor: copies set1 and inserts every element of set2.
  AttributeSet(const AttributeSet& set1, const AttributeSet& set2)
      : attributes_(set1.attributes_) {
    for (const auto& attr : set2.attributes_) {
      Add(attr);
    }
  }

  iterator begin() const { return attributes_.begin(); }

  const_iterator cbegin() const { return attributes_.cbegin(); }

  iterator end() const { return attributes_.end(); }

  const_iterator cend() const { return attributes_.cend(); }

  size_type size() const { return attributes_.size(); }

  bool empty() const { return attributes_.empty(); }

  // Compares the underlying containers; short-circuits on self-comparison.
  bool operator==(const AttributeSet& other) const {
    return this == &other || attributes_ == other.attributes_;
  }

  bool operator!=(const AttributeSet& other) const {
    return !operator==(other);
  }

  // Returns the union of both sets.
  static AttributeSet Merge(const AttributeSet& set1,
                            const AttributeSet& set2) {
    return AttributeSet(set1, set2);
  }

 private:
  friend class google::api::expr::runtime::AttributeUtility;
  friend class UnknownValue;
  friend class base_internal::UnknownSet;

  // In-place insertion is reserved for the friends above.
  void Add(const Attribute& attribute) { attributes_.insert(attribute); }

  void Add(const AttributeSet& other) {
    for (const auto& attribute : other) {
      Add(attribute);
    }
  }

  // Attribute container.
  Container attributes_;
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_BASE_ATTRIBUTE_SET_H_
================================================
FILE: base/builtins.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_BASE_BUILTINS_H_
#define THIRD_PARTY_CEL_CPP_BASE_BUILTINS_H_

namespace cel {

// Constants specifying names for CEL builtins.
//
// Prefer to use the constants in `common/standard_definitions.h`.
namespace builtin {

// Comparison
constexpr char kEqual[] = "_==_";
constexpr char kInequal[] = "_!=_";
constexpr char kLess[] = "_<_";
constexpr char kLessOrEqual[] = "_<=_";
constexpr char kGreater[] = "_>_";
constexpr char kGreaterOrEqual[] = "_>=_";

// Logical
constexpr char kAnd[] = "_&&_";
constexpr char kOr[] = "_||_";
constexpr char kNot[] = "!_";

// Strictness
constexpr char kNotStrictlyFalse[] = "@not_strictly_false";
// Deprecated '__not_strictly_false__' function. Preserved for backwards
// compatibility with stored expressions.
constexpr char kNotStrictlyFalseDeprecated[] = "__not_strictly_false__"; // Arithmetical constexpr char kAdd[] = "_+_"; constexpr char kSubtract[] = "_-_"; constexpr char kNeg[] = "-_"; constexpr char kMultiply[] = "_*_"; constexpr char kDivide[] = "_/_"; constexpr char kModulo[] = "_%_"; // String operations constexpr char kRegexMatch[] = "matches"; constexpr char kStringContains[] = "contains"; constexpr char kStringEndsWith[] = "endsWith"; constexpr char kStringStartsWith[] = "startsWith"; // Container operations constexpr char kIn[] = "@in"; // Deprecated '_in_' operator. Preserved for backwards compatibility with stored // expressions. constexpr char kInDeprecated[] = "_in_"; // Deprecated 'in()' function. Preserved for backwards compatibility with stored // expressions. constexpr char kInFunction[] = "in"; constexpr char kIndex[] = "_[_]"; constexpr char kSize[] = "size"; constexpr char kTernary[] = "_?_:_"; // Timestamp and Duration constexpr char kDuration[] = "duration"; constexpr char kTimestamp[] = "timestamp"; constexpr char kFullYear[] = "getFullYear"; constexpr char kMonth[] = "getMonth"; constexpr char kDayOfYear[] = "getDayOfYear"; constexpr char kDayOfMonth[] = "getDayOfMonth"; constexpr char kDate[] = "getDate"; constexpr char kDayOfWeek[] = "getDayOfWeek"; constexpr char kHours[] = "getHours"; constexpr char kMinutes[] = "getMinutes"; constexpr char kSeconds[] = "getSeconds"; constexpr char kMilliseconds[] = "getMilliseconds"; // Type conversions constexpr char kBool[] = "bool"; constexpr char kBytes[] = "bytes"; constexpr char kDouble[] = "double"; constexpr char kDyn[] = "dyn"; constexpr char kInt[] = "int"; constexpr char kString[] = "string"; constexpr char kType[] = "type"; constexpr char kUint[] = "uint"; // Runtime-only functions. // The convention for runtime-only functions where only the runtime needs to // differentiate behavior is to prefix the function with `#`. 
// Note, this is a different convention from CEL internal functions where the // whole stack needs to be aware of the function id. constexpr char kRuntimeListAppend[] = "#list_append"; } // namespace builtin } // namespace cel #endif // THIRD_PARTY_CEL_CPP_BASE_BUILTINS_H_ ================================================ FILE: base/function.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_BASE_FUNCTION_H_ #define THIRD_PARTY_CEL_CPP_BASE_FUNCTION_H_ #include "runtime/function.h" // IWYU pragma: export #endif // THIRD_PARTY_CEL_CPP_BASE_FUNCTION_H_ ================================================ FILE: base/function_adapter.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_BASE_FUNCTION_ADAPTER_H_ #define THIRD_PARTY_CEL_CPP_BASE_FUNCTION_ADAPTER_H_ #include "runtime/function_adapter.h" // IWYU pragma: export #endif // THIRD_PARTY_CEL_CPP_BASE_FUNCTION_ADAPTER_H_ ================================================ FILE: base/function_descriptor.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_BASE_FUNCTION_DESCRIPTOR_H_ #define THIRD_PARTY_CEL_CPP_BASE_FUNCTION_DESCRIPTOR_H_ #include "common/function_descriptor.h" // IWYU pragma: export #endif // THIRD_PARTY_CEL_CPP_BASE_FUNCTION_DESCRIPTOR_H_ ================================================ FILE: base/function_result.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_BASE_FUNCTION_RESULT_H_
#define THIRD_PARTY_CEL_CPP_BASE_FUNCTION_RESULT_H_

// NOTE(review): the system header names of the next two directives are missing
// from this copy (presumably the headers for int64_t and std::move).
#include
#include

#include "base/function_descriptor.h"

namespace cel {

// Represents a function result that is unknown at the time of execution. This
// allows for lazy evaluation of expensive functions.
class FunctionResult final {
 public:
  FunctionResult() = delete;
  FunctionResult(const FunctionResult&) = default;
  FunctionResult(FunctionResult&&) = default;
  FunctionResult& operator=(const FunctionResult&) = default;
  FunctionResult& operator=(FunctionResult&&) = default;

  FunctionResult(FunctionDescriptor descriptor, int64_t expr_id)
      : descriptor_(std::move(descriptor)), expr_id_(expr_id) {}

  // The descriptor of the called function that returned Unknown.
  const FunctionDescriptor& descriptor() const { return descriptor_; }

  // The id of the |Expr| that triggered the function call step. Provided
  // informationally -- if two different |Expr|s generate the same unknown call,
  // they will be treated as the same unknown function result.
  int64_t call_expr_id() const { return expr_id_; }

  // Equality operator provided for testing. Compatible with set less-than
  // comparator.
  // Currently compares descriptors only: call_expr_id() is ignored and call
  // arguments are not captured (see the TODO below).
  bool IsEqualTo(const FunctionResult& other) const {
    return descriptor() == other.descriptor();
  }

  // TODO(uncreated-issue/5): re-implement argument capture

 private:
  FunctionDescriptor descriptor_;
  int64_t expr_id_;
};

// Same as lhs.IsEqualTo(rhs): descriptor equality only.
inline bool operator==(const FunctionResult& lhs, const FunctionResult& rhs) {
  return lhs.IsEqualTo(rhs);
}

// Orders by descriptor only, consistent with IsEqualTo().
inline bool operator<(const FunctionResult& lhs, const FunctionResult& rhs) {
  return lhs.descriptor() < rhs.descriptor();
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_BASE_FUNCTION_RESULT_H_
================================================
FILE: base/function_result_set.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "base/function_result_set.h"

namespace cel {

// Implementation for merge constructor: starts from a copy of lhs and inserts
// every element of rhs, yielding the union of both sets.
FunctionResultSet::FunctionResultSet(const FunctionResultSet& lhs,
                                     const FunctionResultSet& rhs)
    : function_results_(lhs.function_results_) {
  for (const auto& function_result : rhs) {
    function_results_.insert(function_result);
  }
}

}  // namespace cel
================================================
FILE: base/function_result_set.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_BASE_FUNCTION_RESULT_SET_H_
#define THIRD_PARTY_CEL_CPP_BASE_FUNCTION_RESULT_SET_H_

// NOTE(review): the system header names of the next two directives are missing
// from this copy (presumably the headers for std::initializer_list and
// std::move).
#include
#include

#include "absl/container/btree_set.h"
#include "base/function_result.h"

namespace google::api::expr::runtime {
class AttributeUtility;
}  // namespace google::api::expr::runtime

namespace cel {

class UnknownValue;

namespace base_internal {
class UnknownSet;
}

// Represents a collection of unknown function results at a particular point in
// execution. Execution should advance further if this set of unknowns is
// provided. It may not advance if only a subset is provided.
// Set semantics come from the container's ordering of FunctionResult
// (presumably the default comparator, i.e. operator<, which compares
// descriptors only and therefore agrees with |IsEqualTo()|).
class FunctionResultSet final {
 private:
  using Container = absl::btree_set;

 public:
  using value_type = typename Container::value_type;
  using size_type = typename Container::size_type;
  // Iteration is read-only: both aliases are the container's const_iterator.
  using iterator = typename Container::const_iterator;
  using const_iterator = typename Container::const_iterator;

  FunctionResultSet() = default;
  FunctionResultSet(const FunctionResultSet&) = default;
  FunctionResultSet(FunctionResultSet&&) = default;
  FunctionResultSet& operator=(const FunctionResultSet&) = default;
  FunctionResultSet& operator=(FunctionResultSet&&) = default;

  // Merge constructor -- effectively union(lhs, rhs).
  FunctionResultSet(const FunctionResultSet& lhs, const FunctionResultSet& rhs);

  // Initialize with a single FunctionResult.
  explicit FunctionResultSet(FunctionResult initial)
      : function_results_{std::move(initial)} {}

  // Initialize from a list; duplicates collapse per the container ordering.
  FunctionResultSet(std::initializer_list il)
      : function_results_(il) {}

  iterator begin() const { return function_results_.begin(); }

  const_iterator cbegin() const { return function_results_.cbegin(); }

  iterator end() const { return function_results_.end(); }

  const_iterator cend() const { return function_results_.cend(); }

  size_type size() const { return function_results_.size(); }

  bool empty() const { return function_results_.empty(); }

  // Compares the underlying containers; short-circuits on self-comparison.
  bool operator==(const FunctionResultSet& other) const {
    return this == &other || function_results_ == other.function_results_;
  }

  bool operator!=(const FunctionResultSet& other) const {
    return !operator==(other);
  }

 private:
  friend class google::api::expr::runtime::AttributeUtility;
  friend class UnknownValue;
  friend class base_internal::UnknownSet;

  // In-place insertion is reserved for the friends above.
  void Add(const FunctionResult& function_result) {
    function_results_.insert(function_result);
  }

  void Add(const FunctionResultSet& other) {
    for (const auto& function_result : other) {
      Add(function_result);
    }
  }

  Container function_results_;
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_BASE_FUNCTION_RESULT_SET_H_
================================================
FILE: base/internal/BUILD
================================================
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_cc//cc:cc_library.bzl", "cc_library")

package(default_visibility = ["//visibility:public"])

licenses(["notice"])

# Test-only helpers for parameterizing tests over MemoryManagerTestMode.
cc_library(
    name = "memory_manager_testing",
    testonly = True,
    srcs = ["memory_manager_testing.cc"],
    hdrs = ["memory_manager_testing.h"],
    deps = [
        "//internal:testing",
    ],
)

# Header-only tag constants for message wrappers.
cc_library(
    name = "message_wrapper",
    hdrs = ["message_wrapper.h"],
)

# Header-only operator metadata (OperatorData and the operator X-macros).
cc_library(
    name = "operators",
    hdrs = ["operators.h"],
    deps = [
        "@com_google_absl//absl/strings",
    ],
)

# Unknown-set representation built on AttributeSet and FunctionResultSet.
cc_library(
    name = "unknown_set",
    srcs = ["unknown_set.cc"],
    hdrs = ["unknown_set.h"],
    deps = [
        "//base:attributes",
        "//base:function_result_set",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
    ],
)
================================================
FILE: base/internal/memory_manager_testing.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/internal/memory_manager_testing.h"

#include <string>

namespace cel::base_internal {

// Returns the display name of `mode` ("Global" or "Arena"). Used by the test
// parameter-name helpers and AbslStringify declared in the header.
std::string MemoryManagerTestModeToString(MemoryManagerTestMode mode) {
  switch (mode) {
    case MemoryManagerTestMode::kGlobal:
      return "Global";
    case MemoryManagerTestMode::kArena:
      return "Arena";
  }
  // The switch covers every enumerator, but an out-of-range value (e.g. one
  // produced by static_cast) would otherwise flow off the end of this non-void
  // function, which is undefined behavior and trips -Wreturn-type.
  return "Unknown";
}

}  // namespace cel::base_internal
================================================
FILE: base/internal/memory_manager_testing.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_TESTING_H_
#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_TESTING_H_

// NOTE(review): the system header names of the next two directives are missing
// from this copy (presumably the headers for std::string and std::tuple).
#include
#include

#include "internal/testing.h"

namespace cel::base_internal {

// Memory manager modes a parameterized test can be instantiated with.
enum class MemoryManagerTestMode {
  kGlobal = 0,
  kArena,
};

// Returns the display name of `mode`.
std::string MemoryManagerTestModeToString(MemoryManagerTestMode mode);

// Abseil stringification hook for MemoryManagerTestMode.
template
void AbslStringify(S& sink, MemoryManagerTestMode mode) {
  sink.Append(MemoryManagerTestModeToString(mode));
}

// Parameter generator yielding every MemoryManagerTestMode.
inline auto MemoryManagerTestModeAll() {
  return testing::Values(MemoryManagerTestMode::kGlobal,
                         MemoryManagerTestMode::kArena);
}

// Test-name generator for tests parameterized directly on the mode.
inline std::string MemoryManagerTestModeName(
    const testing::TestParamInfo& info) {
  return MemoryManagerTestModeToString(info.param);
}

// Test-name generator for tuple-parameterized tests; names the test after the
// tuple's first element.
inline std::string MemoryManagerTestModeTupleName(
    const testing::TestParamInfo>& info) {
  return MemoryManagerTestModeToString(std::get<0>(info.param));
}

}  // namespace cel::base_internal

#endif  // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MEMORY_MANAGER_TESTING_H_
================================================
FILE: base/internal/message_wrapper.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MESSAGE_WRAPPER_H_
#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MESSAGE_WRAPPER_H_

// NOTE(review): the header name is missing from this copy (presumably the
// header declaring uintptr_t).
#include

namespace cel::base_internal {

// Constants for tagging a uintptr_t value in its lowest bit:
// kMessageWrapperTagMask selects the single tag bit and kMessageWrapperPtrMask
// clears it.
inline constexpr uintptr_t kMessageWrapperTagMask = 0b1;
inline constexpr uintptr_t kMessageWrapperPtrMask = ~kMessageWrapperTagMask;
// Number of tag bits.
inline constexpr int kMessageWrapperTagSize = 1;
// Tag values. Per their names, 0 presumably marks a type-info value and 1 a
// message value — confirm against the users of these constants.
inline constexpr uintptr_t kMessageWrapperTagTypeInfoValue = 0b0;
inline constexpr uintptr_t kMessageWrapperTagMessageValue = 0b1;

}  // namespace cel::base_internal

#endif  // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_MESSAGE_WRAPPER_H_
================================================
FILE: base/internal/operators.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_OPERATORS_H_
#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_OPERATORS_H_

#include "absl/strings/string_view.h"

namespace cel {

enum class OperatorId;

namespace base_internal {

// Description of one CEL operator. All members are const and every copy, move,
// and default constructor/assignment is deleted, so instances are immutable
// and can only be created through the constexpr constructor.
struct OperatorData final {
  OperatorData() = delete;
  OperatorData(const OperatorData&) = delete;
  OperatorData(OperatorData&&) = delete;
  OperatorData& operator=(const OperatorData&) = delete;
  OperatorData& operator=(OperatorData&&) = delete;

  constexpr OperatorData(cel::OperatorId id, absl::string_view name,
                         absl::string_view display_name, int precedence,
                         int arity)
      : id(id),
        name(name),
        display_name(display_name),
        precedence(precedence),
        arity(arity) {}

  const cel::OperatorId id;
  const absl::string_view name;
  const absl::string_view display_name;
  const int precedence;
  const int arity;
};

#define CEL_INTERNAL_UNARY_OPERATORS_ENUM(XX)             \
  XX(LogicalNot, "!", "!_", 2, 1)                         \
  XX(Negate, "-", "-_", 2, 1)                             \
  XX(NotStrictlyFalse, "", "@not_strictly_false", 0, 1)  \
  XX(OldNotStrictlyFalse, "", "__not_strictly_false__", 0, 1)

#define CEL_INTERNAL_BINARY_OPERATORS_ENUM(XX) \
  XX(Equals, "==", "_==_", 5, 2)               \
  XX(NotEquals, "!=", "_!=_", 5, 2)            \
  XX(Less, "<", "_<_", 5, 2)                   \
  XX(LessEquals, "<=", "_<=_", 5, 2)           \
  XX(Greater, ">", "_>_", 5, 2)                \
  XX(GreaterEquals, ">=", "_>=_", 5, 2)        \
  XX(In, "in", "@in", 5, 2)                    \
  XX(OldIn, "in", "_in_", 5, 2)                \
  XX(Index, "", "_[_]", 1, 2)                  \
  XX(LogicalOr, "||", "_||_", 7, 2)            \
  XX(LogicalAnd, "&&", "_&&_", 6, 2)           \
  XX(Add, "+", "_+_", 4, 2)                    \
  XX(Subtract, "-", "_-_", 4, 2)               \
  XX(Multiply, "*", "_*_", 3, 2)               \
  XX(Divide, "/", "_/_", 3, 2)                 \
  XX(Modulo, "%", "_%_", 3, 2)

#define CEL_INTERNAL_TERNARY_OPERATORS_ENUM(XX) \
  XX(Conditional, "", "_?_:_", 8, 3)

// Macro defining all the operators and their properties.
// (1) - The identifier.
// (2) - The display name if applicable, otherwise an empty string.
// (3) - The name.
// (4) - The precedence if applicable, otherwise 0.
// (5) - The arity.
//
// NOTE(review): arguments (2) and (3) are (display name, name), the reverse of
// OperatorData's constructor order (name, display_name); every expansion must
// map them explicitly. Precedence values increase from `_[_]` (1) to `_?_:_`
// (8), so smaller values appear to bind tighter.
#define CEL_INTERNAL_OPERATORS_ENUM(XX)   \
  CEL_INTERNAL_TERNARY_OPERATORS_ENUM(XX) \
  CEL_INTERNAL_BINARY_OPERATORS_ENUM(XX)  \
  CEL_INTERNAL_UNARY_OPERATORS_ENUM(XX)

}  // namespace base_internal

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_OPERATORS_H_
================================================
FILE: base/internal/unknown_set.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "base/internal/unknown_set.h"

#include "absl/base/no_destructor.h"

namespace cel::base_internal {

// Returns a shared empty AttributeSet. The function-local static is
// initialized on first call and, being a NoDestructor, never destroyed.
const AttributeSet& EmptyAttributeSet() {
  static const absl::NoDestructor empty_attribute_set;
  return *empty_attribute_set;
}

// Returns a shared empty FunctionResultSet. The function-local static is
// initialized on first call and, being a NoDestructor, never destroyed.
const FunctionResultSet& EmptyFunctionResultSet() {
  static const absl::NoDestructor empty_function_result_set;
  return *empty_function_result_set;
}

}  // namespace cel::base_internal
================================================
FILE: base/internal/unknown_set.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_BASE_INTERNAL_UNKNOWN_SET_H_
#define THIRD_PARTY_CEL_CPP_BASE_INTERNAL_UNKNOWN_SET_H_

#include
#include

#include "absl/base/attributes.h"
#include "base/attribute_set.h"
#include "base/function_result_set.h"

namespace cel::base_internal {

// For compatibility with the old API and to avoid unnecessary copying when
// converting between the old and new representations, we store the historical
// members of google::api::expr::runtime::UnknownSet in this struct for use with
// std::shared_ptr.
struct UnknownSetRep final {
  UnknownSetRep() = default;

  // Each constructor takes its sets by value and moves them into place.
  UnknownSetRep(AttributeSet attributes, FunctionResultSet function_results)
      : attributes(std::move(attributes)),
        function_results(std::move(function_results)) {}

  explicit UnknownSetRep(AttributeSet attributes)
      : attributes(std::move(attributes)) {}

  explicit UnknownSetRep(FunctionResultSet function_results)
      : function_results(std::move(function_results)) {}

  AttributeSet attributes;
  FunctionResultSet function_results;
};

// Shared empty sets. UnknownSet's accessors return references to these when
// the set has no rep.
const AttributeSet& EmptyAttributeSet();
const FunctionResultSet& EmptyFunctionResultSet();

struct UnknownSetAccess;

// A set of unknown attributes together with a set of unknown function results.
//
// The state lives behind a std::shared_ptr to an UnknownSetRep, and a null
// pointer stands for the empty set. The defaulted copy operations copy that
// pointer, so copies share one rep instead of duplicating the contained sets.
//
// NOTE(review): Add() (reachable through UnknownSetAccess::Add) mutates the rep
// in place without un-sharing it first. The addition is therefore also visible
// through every other UnknownSet sharing that rep, e.g. a copy taken before the
// call. It also adds `other`'s sets into this set's sets even when both refer
// to the same rep (self-add), and whether that is safe depends on
// AttributeSet::Add / FunctionResultSet::Add. Confirm callers rely on, or at
// least tolerate, this aliasing.
class UnknownSet final {
 private:
  using Rep = UnknownSetRep;

 public:
  // Construct the empty set.
  // Uses singletons instead of allocating new containers.
  UnknownSet() = default;
  UnknownSet(const UnknownSet&) = default;
  UnknownSet(UnknownSet&&) = default;
  UnknownSet& operator=(const UnknownSet&) = default;
  UnknownSet& operator=(UnknownSet&&) = default;

  // Initialization specifying subcontainers
  explicit UnknownSet(AttributeSet attributes)
      : rep_(std::make_shared(std::move(attributes))) {}

  explicit UnknownSet(FunctionResultSet function_results)
      : rep_(std::make_shared(std::move(function_results))) {}

  UnknownSet(AttributeSet attributes, FunctionResultSet function_results)
      : rep_(std::make_shared(std::move(attributes),
                              std::move(function_results))) {}

  // Merge constructor. It builds each member with the two-argument
  // AttributeSet / FunctionResultSet constructors, which presumably produce the
  // union of the two inputs (confirm in attribute_set.h and
  // function_result_set.h).
  UnknownSet(const UnknownSet& set1, const UnknownSet& set2)
      : UnknownSet(
            AttributeSet(set1.unknown_attributes(), set2.unknown_attributes()),
            FunctionResultSet(set1.unknown_function_results(),
                              set2.unknown_function_results())) {}

  // Returns this set's attributes, or the shared empty set when there is no
  // rep.
  const AttributeSet& unknown_attributes() const {
    return rep_ != nullptr ? rep_->attributes : EmptyAttributeSet();
  }

  // Returns this set's function results, or the shared empty set when there is
  // no rep.
  const FunctionResultSet& unknown_function_results() const {
    return rep_ != nullptr ? rep_->function_results : EmptyFunctionResultSet();
  }

  // Compares contents, not rep identity. The address check is a fast path for
  // self-comparison.
  bool operator==(const UnknownSet& other) const {
    return this == &other ||
           (unknown_attributes() == other.unknown_attributes() &&
            unknown_function_results() == other.unknown_function_results());
  }

  bool operator!=(const UnknownSet& other) const {
    return !operator==(other);
  }

 private:
  friend struct UnknownSetAccess;

  // Adopts an existing rep, possibly shared with other UnknownSets.
  explicit UnknownSet(std::shared_ptr impl) : rep_(std::move(impl)) {}

  // Merges `other`'s attributes and function results into this set's rep
  // through AttributeSet::Add / FunctionResultSet::Add. It allocates a fresh
  // rep first when this set has none.
  void Add(const UnknownSet& other) {
    if (rep_ == nullptr) {
      rep_ = std::make_shared();
    }
    rep_->attributes.Add(other.unknown_attributes());
    rep_->function_results.Add(other.unknown_function_results());
  }

  // May be null, which represents the empty set.
  std::shared_ptr rep_;
};

// Gives internal code access to UnknownSet's rep-adopting constructor, its
// Add() member, and its rep pointer.
struct UnknownSetAccess final {
  static UnknownSet Construct(std::shared_ptr rep) {
    return UnknownSet(std::move(rep));
  }

  static void Add(UnknownSet& dest, const UnknownSet& src) { dest.Add(src); }

  static const std::shared_ptr& Rep(const UnknownSet& value) {
    return value.rep_;
  }
};

}  // namespace cel::base_internal

#endif  // THIRD_PARTY_CEL_CPP_BASE_INTERNAL_UNKNOWN_SET_H_

================================================
FILE: base/kind.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_BASE_KIND_H_
#define THIRD_PARTY_CEL_CPP_BASE_KIND_H_

// This header exists for compatibility and should be removed once all includes
// have been updated.
#include "common/kind.h"        // IWYU pragma: export
#include "common/type_kind.h"   // IWYU pragma: export
#include "common/value_kind.h"  // IWYU pragma: export

#endif  // THIRD_PARTY_CEL_CPP_BASE_KIND_H_

================================================
FILE: base/operators.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "base/operators.h"

#include <algorithm>
#include <array>

#include "absl/base/attributes.h"
#include "absl/base/call_once.h"
#include "absl/log/absl_check.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "base/internal/operators.h"

namespace cel {

namespace {

using base_internal::OperatorData;

// Orders OperatorData pointers by `name`. It is transparent
// (`is_transparent`), so a table sorted with it can be searched with a plain
// absl::string_view key.
struct OperatorDataNameComparer {
  using is_transparent = void;

  bool operator()(const OperatorData* lhs, const OperatorData* rhs) const {
    return lhs->name < rhs->name;
  }

  bool operator()(const OperatorData* lhs, absl::string_view rhs) const {
    return lhs->name < rhs;
  }

  bool operator()(absl::string_view lhs, const OperatorData* rhs) const {
    return lhs < rhs->name;
  }
};

// Same as OperatorDataNameComparer, but orders by `display_name`.
struct OperatorDataDisplayNameComparer {
  using is_transparent = void;

  bool operator()(const OperatorData* lhs, const OperatorData* rhs) const {
    return lhs->display_name < rhs->display_name;
  }

  bool operator()(const OperatorData* lhs, absl::string_view rhs) const {
    return lhs->display_name < rhs;
  }

  bool operator()(absl::string_view lhs, const OperatorData* rhs) const {
    return lhs < rhs->display_name;
  }
};

// One constant-initialized OperatorData per operator, named `<Id>_storage`.
#define CEL_OPERATORS_DATA(id, symbol, name, precedence, arity) \
  ABSL_CONST_INIT const OperatorData id##_storage = {            \
      OperatorId::k##id, name, symbol, precedence, arity};
CEL_INTERNAL_OPERATORS_ENUM(CEL_OPERATORS_DATA)
#undef CEL_OPERATORS_DATA

// Expands to `+1` per operator, so `0 CEL_..._ENUM(CEL_OPERATORS_COUNT)` is the
// number of entries in that table.
#define CEL_OPERATORS_COUNT(id, symbol, name, precedence, arity) +1

using OperatorsArray =
    std::array<const OperatorData*,
               0 CEL_INTERNAL_OPERATORS_ENUM(CEL_OPERATORS_COUNT)>;

using UnaryOperatorsArray =
    std::array<const OperatorData*,
               0 CEL_INTERNAL_UNARY_OPERATORS_ENUM(CEL_OPERATORS_COUNT)>;

using BinaryOperatorsArray =
    std::array<const OperatorData*,
               0 CEL_INTERNAL_BINARY_OPERATORS_ENUM(CEL_OPERATORS_COUNT)>;

using TernaryOperatorsArray =
    std::array<const OperatorData*,
               0 CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_OPERATORS_COUNT)>;

#undef CEL_OPERATORS_COUNT

// Guards the one-time sort of the lookup tables below.
ABSL_CONST_INIT absl::once_flag operators_once_flag;

// Lookup tables of pointers to the `<Id>_storage` objects. They start out in
// declaration order and are sorted in place by InitializeOperators().
#define CEL_OPERATORS_DO(id, symbol, name, precedence, arity) &id##_storage,

OperatorsArray operators_by_name = {
    CEL_INTERNAL_OPERATORS_ENUM(CEL_OPERATORS_DO)};

OperatorsArray operators_by_display_name = {
    CEL_INTERNAL_OPERATORS_ENUM(CEL_OPERATORS_DO)};

UnaryOperatorsArray unary_operators_by_name = {
    CEL_INTERNAL_UNARY_OPERATORS_ENUM(CEL_OPERATORS_DO)};

UnaryOperatorsArray unary_operators_by_display_name = {
    CEL_INTERNAL_UNARY_OPERATORS_ENUM(CEL_OPERATORS_DO)};

BinaryOperatorsArray binary_operators_by_name = {
    CEL_INTERNAL_BINARY_OPERATORS_ENUM(CEL_OPERATORS_DO)};

BinaryOperatorsArray binary_operators_by_display_name = {
    CEL_INTERNAL_BINARY_OPERATORS_ENUM(CEL_OPERATORS_DO)};

TernaryOperatorsArray ternary_operators_by_name = {
    CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_OPERATORS_DO)};

TernaryOperatorsArray ternary_operators_by_display_name = {
    CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_OPERATORS_DO)};

#undef CEL_OPERATORS_DO

// Sorts every lookup table by its key. std::stable_sort keeps declaration order
// among equal keys. When several operators share a display name (e.g. "-" for
// Subtract and Negate, or "in" for In and OldIn), the lookups below therefore
// return the one listed first in the table's X-macro.
void InitializeOperators() {
  std::stable_sort(operators_by_name.begin(), operators_by_name.end(),
                   OperatorDataNameComparer{});
  std::stable_sort(operators_by_display_name.begin(),
                   operators_by_display_name.end(),
                   OperatorDataDisplayNameComparer{});
  std::stable_sort(unary_operators_by_name.begin(),
                   unary_operators_by_name.end(), OperatorDataNameComparer{});
  std::stable_sort(unary_operators_by_display_name.begin(),
                   unary_operators_by_display_name.end(),
                   OperatorDataDisplayNameComparer{});
  std::stable_sort(binary_operators_by_name.begin(),
                   binary_operators_by_name.end(), OperatorDataNameComparer{});
  std::stable_sort(binary_operators_by_display_name.begin(),
                   binary_operators_by_display_name.end(),
                   OperatorDataDisplayNameComparer{});
  std::stable_sort(ternary_operators_by_name.begin(),
                   ternary_operators_by_name.end(), OperatorDataNameComparer{});
  std::stable_sort(ternary_operators_by_display_name.begin(),
                   ternary_operators_by_display_name.end(),
                   OperatorDataDisplayNameComparer{});
}

// Binary-searches `sorted` for the entry whose key equals `input`. `sorted` is
// one of the tables above and must be ordered by `comparer`. The tables are
// sorted first if that has not happened yet.
//
// Returns nullptr when `input` is empty or has no exact match. The end check
// always uses `sorted`'s own end iterator, so a key that sorts after every
// entry is reported as missing instead of being dereferenced.
template <typename SortedArray, typename Comparer>
const OperatorData* FindInSortedTable(const SortedArray& sorted,
                                      absl::string_view input,
                                      Comparer comparer) {
  absl::call_once(operators_once_flag, InitializeOperators);
  if (input.empty()) {
    return nullptr;
  }
  const auto it =
      std::lower_bound(sorted.cbegin(), sorted.cend(), input, comparer);
  // lower_bound guarantees !(key < input). The key matches exactly iff it is
  // also true that !(input < key).
  if (it == sorted.cend() || comparer(input, *it)) {
    return nullptr;
  }
  return *it;
}

}  // namespace

UnaryOperator::UnaryOperator(Operator op) : data_(op.data_) {
  ABSL_CHECK(op.arity() == Arity::kUnary);  // Crash OK
}

BinaryOperator::BinaryOperator(Operator op) : data_(op.data_) {
  ABSL_CHECK(op.arity() == Arity::kBinary);  // Crash OK
}

TernaryOperator::TernaryOperator(Operator op) : data_(op.data_) {
  ABSL_CHECK(op.arity() == Arity::kTernary);  // Crash OK
}

#define CEL_UNARY_OPERATOR(id, symbol, name, precedence, arity) \
  UnaryOperator Operator::id() { return UnaryOperator(&id##_storage); }

CEL_INTERNAL_UNARY_OPERATORS_ENUM(CEL_UNARY_OPERATOR)

#undef CEL_UNARY_OPERATOR

#define CEL_BINARY_OPERATOR(id, symbol, name, precedence, arity) \
  BinaryOperator Operator::id() { return BinaryOperator(&id##_storage); }

CEL_INTERNAL_BINARY_OPERATORS_ENUM(CEL_BINARY_OPERATOR)

#undef CEL_BINARY_OPERATOR

#define CEL_TERNARY_OPERATOR(id, symbol, name, precedence, arity) \
  TernaryOperator Operator::id() { return TernaryOperator(&id##_storage); }

CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_TERNARY_OPERATOR)

#undef CEL_TERNARY_OPERATOR

// Returns the operator whose name (e.g. "_&&_") equals `input`, or nullopt.
absl::optional<Operator> Operator::FindByName(absl::string_view input) {
  const OperatorData* data =
      FindInSortedTable(operators_by_name, input, OperatorDataNameComparer{});
  if (data == nullptr) {
    return absl::nullopt;
  }
  return Operator(data);
}

// Returns the operator whose display name (e.g. "&&") equals `input`, or
// nullopt.
absl::optional<Operator> Operator::FindByDisplayName(absl::string_view input) {
  const OperatorData* data = FindInSortedTable(
      operators_by_display_name, input, OperatorDataDisplayNameComparer{});
  if (data == nullptr) {
    return absl::nullopt;
  }
  return Operator(data);
}

absl::optional<UnaryOperator> UnaryOperator::FindByName(
    absl::string_view input) {
  const OperatorData* data = FindInSortedTable(unary_operators_by_name, input,
                                               OperatorDataNameComparer{});
  if (data == nullptr) {
    return absl::nullopt;
  }
  return UnaryOperator(data);
}

absl::optional<UnaryOperator> UnaryOperator::FindByDisplayName(
    absl::string_view input) {
  const OperatorData* data =
      FindInSortedTable(unary_operators_by_display_name, input,
                        OperatorDataDisplayNameComparer{});
  if (data == nullptr) {
    return absl::nullopt;
  }
  return UnaryOperator(data);
}

absl::optional<BinaryOperator> BinaryOperator::FindByName(
    absl::string_view input) {
  const OperatorData* data = FindInSortedTable(binary_operators_by_name, input,
                                               OperatorDataNameComparer{});
  if (data == nullptr) {
    return absl::nullopt;
  }
  return BinaryOperator(data);
}

absl::optional<BinaryOperator> BinaryOperator::FindByDisplayName(
    absl::string_view input) {
  const OperatorData* data =
      FindInSortedTable(binary_operators_by_display_name, input,
                        OperatorDataDisplayNameComparer{});
  if (data == nullptr) {
    return absl::nullopt;
  }
  return BinaryOperator(data);
}

absl::optional<TernaryOperator> TernaryOperator::FindByName(
    absl::string_view input) {
  const OperatorData* data = FindInSortedTable(ternary_operators_by_name, input,
                                               OperatorDataNameComparer{});
  if (data == nullptr) {
    return absl::nullopt;
  }
  return TernaryOperator(data);
}

absl::optional<TernaryOperator> TernaryOperator::FindByDisplayName(
    absl::string_view input) {
  const OperatorData* data =
      FindInSortedTable(ternary_operators_by_display_name, input,
                        OperatorDataDisplayNameComparer{});
  if (data == nullptr) {
    return absl::nullopt;
  }
  return TernaryOperator(data);
}

}  // namespace cel

================================================
FILE: base/operators.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_BASE_OPERATORS_H_
#define THIRD_PARTY_CEL_CPP_BASE_OPERATORS_H_

#include

#include "absl/base/attributes.h"
#include "absl/base/macros.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "base/internal/operators.h"

namespace cel {

// Number of operands an operator takes.
enum class Arity {
  kUnary = 1,
  kBinary = 2,
  kTernary = 3,
};

// Identifies every CEL operator, regardless of arity. Values start at 1.
enum class OperatorId {
  kConditional = 1,
  kLogicalAnd,
  kLogicalOr,
  kLogicalNot,
  kEquals,
  kNotEquals,
  kLess,
  kLessEquals,
  kGreater,
  kGreaterEquals,
  kAdd,
  kSubtract,
  kMultiply,
  kDivide,
  kModulo,
  kNegate,
  kIndex,
  kIn,
  kNotStrictlyFalse,
  kOldIn,
  kOldNotStrictlyFalse,
};

// The unary subset of OperatorId. Each enumerator has the same numeric value as
// the OperatorId enumerator of the same name.
enum class UnaryOperatorId {
  kLogicalNot = static_cast(OperatorId::kLogicalNot),
  kNegate = static_cast(OperatorId::kNegate),
  kNotStrictlyFalse = static_cast(OperatorId::kNotStrictlyFalse),
  kOldNotStrictlyFalse = static_cast(OperatorId::kOldNotStrictlyFalse),
};

// The binary subset of OperatorId. Each enumerator has the same numeric value
// as the OperatorId enumerator of the same name.
enum class BinaryOperatorId {
  kLogicalAnd = static_cast(OperatorId::kLogicalAnd),
  kLogicalOr = static_cast(OperatorId::kLogicalOr),
  kEquals = static_cast(OperatorId::kEquals),
  kNotEquals = static_cast(OperatorId::kNotEquals),
  kLess = static_cast(OperatorId::kLess),
  kLessEquals = static_cast(OperatorId::kLessEquals),
  kGreater = static_cast(OperatorId::kGreater),
  kGreaterEquals = static_cast(OperatorId::kGreaterEquals),
  kAdd = static_cast(OperatorId::kAdd),
  kSubtract = static_cast(OperatorId::kSubtract),
  kMultiply = static_cast(OperatorId::kMultiply),
  kDivide = static_cast(OperatorId::kDivide),
  kModulo = static_cast(OperatorId::kModulo),
  kIndex = static_cast(OperatorId::kIndex),
  kIn = static_cast(OperatorId::kIn),
  kOldIn = static_cast(OperatorId::kOldIn),
};

// The ternary subset of OperatorId. Each enumerator has the same numeric value
// as the OperatorId enumerator of the same name.
enum class TernaryOperatorId {
  kConditional = static_cast(OperatorId::kConditional),
};

class UnaryOperator;
class BinaryOperator;
class TernaryOperator;

// A handle to one CEL operator of any arity. It holds a single pointer to the
// operator's base_internal::OperatorData, so it is cheap to copy. Obtain one
// from the named factories below (e.g. Operator::Add()), from the FindBy*
// lookups, or by converting a Unary/Binary/TernaryOperator. There is no default
// constructor.
class Operator final {
 public:
  ABSL_ATTRIBUTE_PURE_FUNCTION static TernaryOperator Conditional();
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LogicalAnd();
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LogicalOr();
  ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator LogicalNot();
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Equals();
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator NotEquals();
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Less();
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LessEquals();
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Greater();
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator GreaterEquals();
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Add();
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Subtract();
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Multiply();
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Divide();
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Modulo();
  ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator Negate();
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Index();
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator In();
  ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator NotStrictlyFalse();
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator OldIn();
  ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator OldNotStrictlyFalse();

  // Returns the operator whose name (e.g. "_&&_") equals `input`, or nullopt
  // if there is none.
  ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional FindByName(
      absl::string_view input);

  // Returns the operator whose display name (e.g. "&&") equals `input`, or
  // nullopt if there is none. Display names are not unique (e.g. "-" belongs
  // to both Subtract() and Negate()), so only one of the matches is returned.
  ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional
  FindByDisplayName(absl::string_view input);

  Operator() = delete;
  Operator(const Operator&) = default;
  Operator(Operator&&) = default;
  Operator& operator=(const Operator&) = default;
  Operator& operator=(Operator&&) = default;

  constexpr OperatorId id() const { return data_->id; }

  // Returns the name of the operator. This is the managed representation of the
  // operator, for example "_&&_".
  constexpr absl::string_view name() const { return data_->name; }

  // Returns the source text representation of the operator. This is the
  // unmanaged text representation of the operator, for example "&&".
  //
  // Note that this will be empty for operators like Conditional() and Index().
  constexpr absl::string_view display_name() const {
    return data_->display_name;
  }

  // 0 for operators without a meaningful precedence.
  constexpr int precedence() const { return data_->precedence; }

  constexpr Arity arity() const { return static_cast(data_->arity); }

 private:
  friend class UnaryOperator;
  friend class BinaryOperator;
  friend class TernaryOperator;

  constexpr explicit Operator(const base_internal::OperatorData* data)
      : data_(data) {}

  const base_internal::OperatorData* data_;
};

// Operators compare equal when their ids are equal. They can also be compared
// directly against an OperatorId.
constexpr bool operator==(const Operator& lhs, const Operator& rhs) {
  return lhs.id() == rhs.id();
}

constexpr bool operator==(OperatorId lhs, const Operator& rhs) {
  return lhs == rhs.id();
}

constexpr bool operator==(const Operator& lhs, OperatorId rhs) {
  return operator==(rhs, lhs);
}

constexpr bool operator!=(const Operator& lhs, const Operator& rhs) {
  return !operator==(lhs, rhs);
}

constexpr bool operator!=(OperatorId lhs, const Operator& rhs) {
  return !operator==(lhs, rhs);
}

constexpr bool operator!=(const Operator& lhs, OperatorId rhs) {
  return !operator==(lhs, rhs);
}

// Hashes by id, consistent with operator== above.
template
H AbslHashValue(H state, const Operator& op) {
  return H::combine(std::move(state), static_cast(op.id()));
}

// An Operator statically known to take one operand. It converts implicitly to
// Operator. Going the other way requires the explicit constructor, which checks
// the arity.
class UnaryOperator final {
 public:
  ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator LogicalNot() {
    return Operator::LogicalNot();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator Negate() {
    return Operator::Negate();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator NotStrictlyFalse() {
    return Operator::NotStrictlyFalse();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static UnaryOperator OldNotStrictlyFalse() {
    return Operator::OldNotStrictlyFalse();
  }

  // Like Operator::FindByName, restricted to unary operators.
  ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional FindByName(
      absl::string_view input);

  // Like Operator::FindByDisplayName, restricted to unary operators.
  ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional
  FindByDisplayName(absl::string_view input);

  UnaryOperator() = delete;
  UnaryOperator(const UnaryOperator&) = default;
  UnaryOperator(UnaryOperator&&) = default;
  UnaryOperator& operator=(const UnaryOperator&) = default;
  UnaryOperator& operator=(UnaryOperator&&) = default;

  // Support for explicit casting of Operator to UnaryOperator.
  // `Operator::arity()` must return `Arity::kUnary`, or this will crash.
  explicit UnaryOperator(Operator op);

  constexpr UnaryOperatorId id() const {
    return static_cast(data_->id);
  }

  // Returns the name of the operator. This is the managed representation of the
  // operator, for example "!_".
  constexpr absl::string_view name() const { return data_->name; }

  // Returns the source text representation of the operator. This is the
  // unmanaged text representation of the operator, for example "!".
  //
  // Note that this will be empty for NotStrictlyFalse() and
  // OldNotStrictlyFalse().
  constexpr absl::string_view display_name() const {
    return data_->display_name;
  }

  constexpr int precedence() const { return data_->precedence; }

  constexpr Arity arity() const {
    ABSL_ASSERT(data_->arity == 1);
    return Arity::kUnary;
  }

  constexpr operator Operator() const {  // NOLINT(google-explicit-constructor)
    return Operator(data_);
  }

 private:
  friend class Operator;

  constexpr explicit UnaryOperator(const base_internal::OperatorData* data)
      : data_(data) {}

  const base_internal::OperatorData* data_;
};

constexpr bool operator==(const UnaryOperator& lhs, const UnaryOperator& rhs) {
  return lhs.id() == rhs.id();
}

constexpr bool operator==(UnaryOperatorId lhs, const UnaryOperator& rhs) {
  return lhs == rhs.id();
}

constexpr bool operator==(const UnaryOperator& lhs, UnaryOperatorId rhs) {
  return operator==(rhs, lhs);
}

constexpr bool operator!=(const UnaryOperator& lhs, const UnaryOperator& rhs) {
  return !operator==(lhs, rhs);
}

constexpr bool operator!=(UnaryOperatorId lhs, const UnaryOperator& rhs) {
  return !operator==(lhs, rhs);
}

constexpr bool operator!=(const UnaryOperator& lhs, UnaryOperatorId rhs) {
  return !operator==(lhs, rhs);
}

// Hashes by id, consistent with the equality operators above.
template
H AbslHashValue(H state, const UnaryOperator& op) {
  return H::combine(std::move(state), static_cast(op.id()));
}

// An Operator statically known to take two operands. It converts implicitly to
// Operator. Going the other way requires the explicit constructor, which checks
// the arity.
class BinaryOperator final {
 public:
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LogicalAnd() {
    return Operator::LogicalAnd();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LogicalOr() {
    return Operator::LogicalOr();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Equals() {
    return Operator::Equals();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator NotEquals() {
    return Operator::NotEquals();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Less() {
    return Operator::Less();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator LessEquals() {
    return Operator::LessEquals();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Greater() {
    return Operator::Greater();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator GreaterEquals() {
    return Operator::GreaterEquals();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Add() {
    return Operator::Add();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Subtract() {
    return Operator::Subtract();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Multiply() {
    return Operator::Multiply();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Divide() {
    return Operator::Divide();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Modulo() {
    return Operator::Modulo();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator Index() {
    return Operator::Index();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator In() {
    return Operator::In();
  }
  ABSL_ATTRIBUTE_PURE_FUNCTION static BinaryOperator OldIn() {
    return Operator::OldIn();
  }

  // Like Operator::FindByName, restricted to binary operators.
  ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional FindByName(
      absl::string_view input);

  // Like Operator::FindByDisplayName, restricted to binary operators.
  ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional
  FindByDisplayName(absl::string_view input);

  BinaryOperator() = delete;
  BinaryOperator(const BinaryOperator&) = default;
  BinaryOperator(BinaryOperator&&) = default;
  BinaryOperator& operator=(const BinaryOperator&) = default;
  BinaryOperator& operator=(BinaryOperator&&) = default;

  // Support for explicit casting of Operator to BinaryOperator.
  // `Operator::arity()` must return `Arity::kBinary`, or this will crash.
  explicit BinaryOperator(Operator op);

  constexpr BinaryOperatorId id() const {
    return static_cast(data_->id);
  }

  // Returns the name of the operator. This is the managed representation of the
  // operator, for example "_&&_".
  constexpr absl::string_view name() const { return data_->name; }

  // Returns the source text representation of the operator. This is the
  // unmanaged text representation of the operator, for example "&&".
  //
  // Note that this will be empty for Index().
  constexpr absl::string_view display_name() const {
    return data_->display_name;
  }

  constexpr int precedence() const { return data_->precedence; }

  constexpr Arity arity() const {
    ABSL_ASSERT(data_->arity == 2);
    return Arity::kBinary;
  }

  constexpr operator Operator() const {  // NOLINT(google-explicit-constructor)
    return Operator(data_);
  }

 private:
  friend class Operator;

  constexpr explicit BinaryOperator(const base_internal::OperatorData* data)
      : data_(data) {}

  const base_internal::OperatorData* data_;
};

constexpr bool operator==(const BinaryOperator& lhs,
                          const BinaryOperator& rhs) {
  return lhs.id() == rhs.id();
}

constexpr bool operator==(BinaryOperatorId lhs, const BinaryOperator& rhs) {
  return lhs == rhs.id();
}

constexpr bool operator==(const BinaryOperator& lhs, BinaryOperatorId rhs) {
  return operator==(rhs, lhs);
}

constexpr bool operator!=(const BinaryOperator& lhs,
                          const BinaryOperator& rhs) {
  return !operator==(lhs, rhs);
}

constexpr bool operator!=(BinaryOperatorId lhs, const BinaryOperator& rhs) {
  return !operator==(lhs, rhs);
}

constexpr bool operator!=(const BinaryOperator& lhs, BinaryOperatorId rhs) {
  return !operator==(lhs, rhs);
}

// Hashes by id, consistent with the equality operators above.
template
H AbslHashValue(H state, const BinaryOperator& op) {
  return H::combine(std::move(state), static_cast(op.id()));
}

// An Operator statically known to take three operands. It converts implicitly
// to Operator. Going the other way requires the explicit constructor, which
// checks the arity.
class TernaryOperator final {
 public:
  ABSL_ATTRIBUTE_PURE_FUNCTION static TernaryOperator Conditional() {
    return Operator::Conditional();
  }

  // Like Operator::FindByName, restricted to ternary operators.
  ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional
  FindByName(absl::string_view input);

  // Like Operator::FindByDisplayName, restricted to ternary operators.
  ABSL_ATTRIBUTE_PURE_FUNCTION static absl::optional
  FindByDisplayName(absl::string_view input);

  TernaryOperator() = delete;
  TernaryOperator(const TernaryOperator&) = default;
  TernaryOperator(TernaryOperator&&) = default;
  TernaryOperator& operator=(const TernaryOperator&) = default;
  TernaryOperator& operator=(TernaryOperator&&) = default;

  // Support for explicit casting of Operator to TernaryOperator.
  // `Operator::arity()` must return `Arity::kTernary`, or this will crash.
  explicit TernaryOperator(Operator op);

  constexpr TernaryOperatorId id() const {
    return static_cast(data_->id);
  }

  // Returns the name of the operator. This is the managed representation of the
  // operator, for example "_?_:_".
  constexpr absl::string_view name() const { return data_->name; }

  // Returns the source text representation of the operator, if any.
  //
  // Note that this will be empty for Conditional().
  constexpr absl::string_view display_name() const {
    return data_->display_name;
  }

  constexpr int precedence() const { return data_->precedence; }

  constexpr Arity arity() const {
    ABSL_ASSERT(data_->arity == 3);
    return Arity::kTernary;
  }

  constexpr operator Operator() const {  // NOLINT(google-explicit-constructor)
    return Operator(data_);
  }

 private:
  friend class Operator;

  constexpr explicit TernaryOperator(const base_internal::OperatorData* data)
      : data_(data) {}

  const base_internal::OperatorData* data_;
};

constexpr bool operator==(const TernaryOperator& lhs,
                          const TernaryOperator& rhs) {
  return lhs.id() == rhs.id();
}

constexpr bool operator==(TernaryOperatorId lhs, const TernaryOperator& rhs) {
  return lhs == rhs.id();
}

constexpr bool operator==(const TernaryOperator& lhs, TernaryOperatorId rhs) {
  return operator==(rhs, lhs);
}

constexpr bool operator!=(const TernaryOperator& lhs,
                          const TernaryOperator& rhs) {
  return !operator==(lhs, rhs);
}

constexpr bool operator!=(TernaryOperatorId lhs, const TernaryOperator& rhs) {
  return !operator==(lhs, rhs);
}

constexpr bool operator!=(const TernaryOperator& lhs, TernaryOperatorId rhs) {
  return !operator==(lhs, rhs);
}

// Hashes by id, consistent with the equality operators above.
template
H AbslHashValue(H state, const TernaryOperator& op) {
  return H::combine(std::move(state), static_cast(op.id()));
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_BASE_OPERATORS_H_

================================================
FILE: base/operators_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include "base/operators.h" #include #include "absl/hash/hash_testing.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "base/internal/operators.h" #include "internal/testing.h" namespace cel { namespace { using ::testing::Eq; using ::testing::Optional; template void TestOperator(Op op, OpId id, absl::string_view name, absl::string_view display_name, int precedence, Arity arity) { EXPECT_EQ(op.id(), id); EXPECT_EQ(Operator(op).id(), static_cast(id)); EXPECT_EQ(op.name(), name); EXPECT_EQ(op.display_name(), display_name); EXPECT_EQ(op.precedence(), precedence); EXPECT_EQ(op.arity(), arity); EXPECT_EQ(Operator(op).arity(), arity); EXPECT_EQ(Op(Operator(op)), op); } void TestUnaryOperator(UnaryOperator op, UnaryOperatorId id, absl::string_view name, absl::string_view display_name, int precedence) { TestOperator(op, id, name, display_name, precedence, Arity::kUnary); } void TestBinaryOperator(BinaryOperator op, BinaryOperatorId id, absl::string_view name, absl::string_view display_name, int precedence) { TestOperator(op, id, name, display_name, precedence, Arity::kBinary); } void TestTernaryOperator(TernaryOperator op, TernaryOperatorId id, absl::string_view name, absl::string_view display_name, int precedence) { TestOperator(op, id, name, display_name, precedence, Arity::kTernary); } TEST(Operator, TypeTraits) { EXPECT_FALSE(std::is_default_constructible_v); EXPECT_TRUE(std::is_copy_constructible_v); EXPECT_TRUE(std::is_move_constructible_v); EXPECT_TRUE(std::is_copy_assignable_v); EXPECT_TRUE(std::is_move_assignable_v); EXPECT_FALSE((std::is_convertible_v)); EXPECT_FALSE((std::is_convertible_v)); EXPECT_FALSE((std::is_convertible_v)); } TEST(UnaryOperator, TypeTraits) { EXPECT_FALSE(std::is_default_constructible_v); EXPECT_TRUE(std::is_copy_constructible_v); EXPECT_TRUE(std::is_move_constructible_v); 
EXPECT_TRUE(std::is_copy_assignable_v); EXPECT_TRUE(std::is_move_assignable_v); EXPECT_TRUE((std::is_convertible_v)); } TEST(BinaryOperator, TypeTraits) { EXPECT_FALSE(std::is_default_constructible_v); EXPECT_TRUE(std::is_copy_constructible_v); EXPECT_TRUE(std::is_move_constructible_v); EXPECT_TRUE(std::is_copy_assignable_v); EXPECT_TRUE(std::is_move_assignable_v); EXPECT_TRUE((std::is_convertible_v)); } TEST(TernaryOperator, TypeTraits) { EXPECT_FALSE(std::is_default_constructible_v); EXPECT_TRUE(std::is_copy_constructible_v); EXPECT_TRUE(std::is_move_constructible_v); EXPECT_TRUE(std::is_copy_assignable_v); EXPECT_TRUE(std::is_move_assignable_v); EXPECT_TRUE((std::is_convertible_v)); } #define CEL_UNARY_OPERATOR(id, symbol, name, precedence, arity) \ TEST(UnaryOperator, id) { \ TestUnaryOperator(UnaryOperator::id(), UnaryOperatorId::k##id, name, \ symbol, precedence); \ } CEL_INTERNAL_UNARY_OPERATORS_ENUM(CEL_UNARY_OPERATOR) #undef CEL_UNARY_OPERATOR #define CEL_BINARY_OPERATOR(id, symbol, name, precedence, arity) \ TEST(BinaryOperator, id) { \ TestBinaryOperator(BinaryOperator::id(), BinaryOperatorId::k##id, name, \ symbol, precedence); \ } CEL_INTERNAL_BINARY_OPERATORS_ENUM(CEL_BINARY_OPERATOR) #undef CEL_BINARY_OPERATOR #define CEL_TERNARY_OPERATOR(id, symbol, name, precedence, arity) \ TEST(TernaryOperator, id) { \ TestTernaryOperator(TernaryOperator::id(), TernaryOperatorId::k##id, name, \ symbol, precedence); \ } CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_TERNARY_OPERATOR) #undef CEL_TERNARY_OPERATOR TEST(Operator, FindByName) { EXPECT_THAT(Operator::FindByName("@in"), Optional(Eq(Operator::In()))); EXPECT_THAT(Operator::FindByName("_in_"), Optional(Eq(Operator::OldIn()))); EXPECT_THAT(Operator::FindByName("in"), Eq(absl::nullopt)); EXPECT_THAT(Operator::FindByName(""), Eq(absl::nullopt)); } TEST(Operator, FindByDisplayName) { EXPECT_THAT(Operator::FindByDisplayName("-"), Optional(Eq(Operator::Subtract()))); EXPECT_THAT(Operator::FindByDisplayName("@in"), 
Eq(absl::nullopt)); EXPECT_THAT(Operator::FindByDisplayName(""), Eq(absl::nullopt)); } TEST(UnaryOperator, FindByName) { EXPECT_THAT(UnaryOperator::FindByName("-_"), Optional(Eq(Operator::Negate()))); EXPECT_THAT(UnaryOperator::FindByName("_-_"), Eq(absl::nullopt)); EXPECT_THAT(UnaryOperator::FindByName(""), Eq(absl::nullopt)); } TEST(UnaryOperator, FindByDisplayName) { EXPECT_THAT(UnaryOperator::FindByDisplayName("-"), Optional(Eq(Operator::Negate()))); EXPECT_THAT(UnaryOperator::FindByDisplayName("&&"), Eq(absl::nullopt)); EXPECT_THAT(UnaryOperator::FindByDisplayName(""), Eq(absl::nullopt)); } TEST(BinaryOperator, FindByName) { EXPECT_THAT(BinaryOperator::FindByName("_-_"), Optional(Eq(Operator::Subtract()))); EXPECT_THAT(BinaryOperator::FindByName("-_"), Eq(absl::nullopt)); EXPECT_THAT(BinaryOperator::FindByName(""), Eq(absl::nullopt)); } TEST(BinaryOperator, FindByDisplayName) { EXPECT_THAT(BinaryOperator::FindByDisplayName("-"), Optional(Eq(Operator::Subtract()))); EXPECT_THAT(BinaryOperator::FindByDisplayName("!"), Eq(absl::nullopt)); EXPECT_THAT(BinaryOperator::FindByDisplayName(""), Eq(absl::nullopt)); } TEST(TernaryOperator, FindByName) { EXPECT_THAT(TernaryOperator::FindByName("_?_:_"), Optional(Eq(TernaryOperator::Conditional()))); EXPECT_THAT(TernaryOperator::FindByName("-_"), Eq(absl::nullopt)); EXPECT_THAT(TernaryOperator::FindByName(""), Eq(absl::nullopt)); } TEST(TernaryOperator, FindByDisplayName) { EXPECT_THAT(TernaryOperator::FindByDisplayName(""), Eq(absl::nullopt)); EXPECT_THAT(TernaryOperator::FindByDisplayName("!"), Eq(absl::nullopt)); } TEST(Operator, SupportsAbslHash) { #define CEL_OPERATOR(id, symbol, name, precedence, arity) \ Operator(Operator::id()), EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( {CEL_INTERNAL_OPERATORS_ENUM(CEL_OPERATOR)})); #undef CEL_OPERATOR } TEST(UnaryOperator, SupportsAbslHash) { #define CEL_UNARY_OPERATOR(id, symbol, name, precedence, arity) \ UnaryOperator::id(), 
EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( {CEL_INTERNAL_UNARY_OPERATORS_ENUM(CEL_UNARY_OPERATOR)})); #undef CEL_UNARY_OPERATOR } TEST(BinaryOperator, SupportsAbslHash) { #define CEL_BINARY_OPERATOR(id, symbol, name, precedence, arity) \ BinaryOperator::id(), EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( {CEL_INTERNAL_BINARY_OPERATORS_ENUM(CEL_BINARY_OPERATOR)})); #undef CEL_BINARY_OPERATOR } } // namespace } // namespace cel ================================================ FILE: base/type_provider.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_BASE_TYPE_PROVIDER_H_ #define THIRD_PARTY_CEL_CPP_BASE_TYPE_PROVIDER_H_ #include "common/type_reflector.h" // IWYU pragma: export namespace cel { using TypeProvider = TypeReflector; } #endif // THIRD_PARTY_CEL_CPP_BASE_TYPE_PROVIDER_H_ ================================================ FILE: bazel/BUILD ================================================ load("@rules_cc//cc:cc_binary.bzl", "cc_binary") load("@rules_java//java:defs.bzl", "java_binary") java_binary( name = "antlr4_tool", main_class = "org.antlr.v4.Tool", runtime_deps = ["@antlr4_jar//jar"], ) package(default_visibility = ["//visibility:public"]) exports_files( srcs = [ "antlr.patch", ], visibility = ["//:__subpackages__"], ) cc_binary( name = "cel_cc_embed", srcs = ["cel_cc_embed.cc"], visibility = ["//:__subpackages__"], deps = [ "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:initialize", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) cc_binary( name = "cat_param_file", srcs = ["cat_param_file.cc"], visibility = ["//:__subpackages__"], deps = [ "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/log:initialize", ], ) ================================================ FILE: bazel/antlr.bzl ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Generate C++ parser and lexer from a grammar file."""

load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc/common:cc_common.bzl", "cc_common")
load("@rules_cc//cc/common:cc_info.bzl", "CcInfo")

def antlr_cc_library(name, src, package):
    """Creates a C++ lexer and parser from a source grammar.

    Defines two targets: `<name>_grammar` (the generated sources) and
    `<name>_cc_parser` (a cc_library compiling them against the ANTLR runtime).

    Args:
      name: Base name for the lexer and the parser rules.
      src: source ANTLR grammar file
      package: The namespace for the generated code
    """
    generated = name + "_grammar"
    antlr_library(
        name = generated,
        src = src,
        package = package,
        shell = select(
            {
                "@platforms//os:windows": "PowerShell.exe",
                "//conditions:default": "bash",
            },
        ),
        genfiles_prefixed = select(
            {
                "@platforms//os:windows": False,
                "//conditions:default": True,
            },
        ),
    )
    cc_library(
        name = name + "_cc_parser",
        # `generated` supplies the .h/.cpp files through DefaultInfo (srcs) and
        # the header compilation context through CcInfo (deps).
        srcs = [generated],
        defines = [
            "ANTLR4CPP_STATIC",
        ],
        deps = [
            generated,
            "@antlr4-cpp-runtime//:antlr4-cpp-runtime",
        ],
        linkstatic = 1,
    )

def _antlr_library(ctx):
    """Runs ANTLR on `src`, then copies each generated file to a declared output."""
    output = ctx.actions.declare_directory(ctx.attr.name)

    antlr_args = ctx.actions.args()
    antlr_args.add("-Dlanguage=Cpp")
    antlr_args.add("-no-listener")
    antlr_args.add("-visitor")
    antlr_args.add("-o", output.path)
    antlr_args.add("-package", ctx.attr.package)
    antlr_args.add(ctx.file.src)

    # Strip ".g4" extension.
    basename = ctx.file.src.basename[:-3]

    suffixes = ["Lexer", "Parser", "BaseVisitor", "Visitor"]

    ctx.actions.run(
        mnemonic = "GenAntlr",
        arguments = [antlr_args],
        inputs = [ctx.file.src],
        outputs = [output],
        executable = ctx.executable._tool,
        progress_message = "Processing ANTLR grammar. -o " + output.path,
    )

    files = []
    headers = []
    for suffix in suffixes:
        header = ctx.actions.declare_file(basename + suffix + ".h")
        source = ctx.actions.declare_file(basename + suffix + ".cpp")

        # When genfiles_prefixed is set, the generated files are expected under
        # the grammar's path (minus ".g4") inside the output directory;
        # otherwise they sit directly under it.
        prefix = ctx.file.src.path[:-3] if ctx.attr.genfiles_prefixed else basename
        generated = output.path + "/" + prefix + suffix
        executable = ctx.attr.shell

        # Extract the individual files from the tree artifact with the
        # configured shell's `cp`.
        ctx.actions.run(
            mnemonic = "CopyHeader" + suffix,
            inputs = [output],
            outputs = [header],
            executable = executable,
            arguments = [
                "-c",
                'cp "{generated}" "{out}"'.format(generated = generated + ".h", out = header.path),
            ],
        )
        ctx.actions.run(
            mnemonic = "CopySource" + suffix,
            inputs = [output],
            outputs = [source],
            executable = executable,
            arguments = [
                "-c",
                'cp "{generated}" "{out}"'.format(generated = generated + ".cpp", out = source.path),
            ],
        )

        files.append(header)
        files.append(source)
        headers.append(header)

    # Only the generated headers belong in the compilation context. The .cpp
    # files reach the compiler through DefaultInfo, which antlr_cc_library
    # feeds to cc_library `srcs`.
    compilation_context = cc_common.create_compilation_context(headers = depset(headers))
    return [DefaultInfo(files = depset(files)), CcInfo(compilation_context = compilation_context)]

antlr_library = rule(
    implementation = _antlr_library,
    attrs = {
        "src": attr.label(allow_single_file = [".g4"], mandatory = True),
        "package": attr.string(),
        "_tool": attr.label(
            executable = True,
            cfg = "exec",  # buildifier: disable=attr-cfg
            default = Label("//bazel:antlr4_tool"),
        ),
        # Shell used to run the copy commands (bash or PowerShell.exe).
        "shell": attr.string(
            mandatory = True,
        ),
        "genfiles_prefixed": attr.bool(
            mandatory = True,
        ),
    },
)

================================================
FILE: bazel/antlr.patch
================================================
--- BUILD.bazel
+++ BUILD.bazel
@@ -17,21 +17,21 @@
 cc_library(
     name = "antlr4-cpp-runtime",
     srcs = glob(["runtime/src/**/*.cpp"]),
     hdrs = ["runtime/src/antlr4-runtime.h"],
     copts = ["-fexceptions"],
-    defines = ["ANTLR4CPP_USING_ABSEIL"],
+    defines = ["ANTLR4CPP_USING_ABSEIL", "ANTLR4CPP_STATIC"],
     features = ["-use_header_modules"],
     includes = ["runtime/src"],
     textual_hdrs = glob(
         ["runtime/src/**/*.h"],
         exclude = ["runtime/src/antlr4-runtime.h"],
     ),
     visibility = ["//visibility:public"],
     deps = [
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/synchronization",
     ],
 )
--- VERSION
+++ /dev/null
@@ -1,1 +1,0 @@
-4.13.2

================================================
FILE: bazel/cat_param_file.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstddef>
#include <fstream>
#include <iostream>
#include <string>

#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/log/initialize.h"

// Read a bazel param file and concatenate the inputs.
// The param file is line delimited with each line a file to concat.
int main(int argc, char** argv) { absl::InitializeLog(); if (argc != 3) { std::cerr << "usage: cat_param_file " << std::endl; std::cerr << "args " << argc << std::endl; return 2; } const char* param_file = argv[1]; const char* out_file = argv[2]; std::ifstream ifs(param_file, std::ios::binary); std::ofstream ofs(out_file, std::ios::binary); ABSL_QCHECK(ifs.good()) << "failed to open param file " << param_file; ABSL_QCHECK(ofs.good()) << "failed to open out file " << out_file; for (std::string line; std::getline(ifs, line);) { std::ifstream in(line, std::ios::binary); if (!in.good()) { ABSL_LOG(ERROR) << "failed to open input file " << line; continue; } constexpr size_t kBufSize = 256; char buf[kBufSize]; while (true) { in.read(buf, kBufSize); size_t read = in.gcount(); if (read == 0) { break; } ofs.write(buf, read); } } ofs.flush(); return 0; } ================================================ FILE: bazel/cel_cc_embed.bzl ================================================ # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Provides the `cel_cc_embed` build rule. 
""" def _cel_cc_embed(ctx): output = ctx.actions.declare_file(ctx.attr.name + ".inc") args = ctx.actions.args() src = ctx.file.src args.add("--in", src) args.add("--out", output.path) ctx.actions.run( mnemonic = "GenerateEmbedTextualHeader", outputs = [output], inputs = [src], progress_message = "generating embed textual header", executable = ctx.executable.gen_tool, arguments = [args], ) return DefaultInfo( files = depset([output]), ) cel_cc_embed = rule( implementation = _cel_cc_embed, attrs = { "src": attr.label(allow_single_file = True, mandatory = True), "gen_tool": attr.label( executable = True, cfg = "exec", allow_files = True, default = Label("//bazel:cel_cc_embed"), ), }, ) ================================================ FILE: bazel/cel_cc_embed.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include #include #include #include #include "absl/flags/flag.h" #include "absl/flags/parse.h" #include "absl/log/absl_check.h" #include "absl/log/initialize.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" ABSL_FLAG(std::string, in, "", ""); ABSL_FLAG(std::string, out, "", ""); namespace { std::vector ReadFile(const std::string& path) { ABSL_CHECK(!path.empty()) << "--in is required"; std::ifstream file(path, std::ifstream::binary); ABSL_CHECK(file.is_open()) << path; file.seekg(0, file.end); ABSL_CHECK(file.good()); size_t size = static_cast(file.tellg()); file.seekg(0, file.beg); ABSL_CHECK(file.good()); std::vector buffer; buffer.resize(size); file.read(reinterpret_cast(buffer.data()), size); ABSL_CHECK(file.good()); return buffer; } void WriteFile(const std::string& path, absl::Span data) { ABSL_CHECK(!path.empty()) << "--out is required"; std::ofstream file(path); ABSL_CHECK(file.is_open()) << path; file.write(data.data(), data.size()); ABSL_CHECK(file.good()); file.flush(); ABSL_CHECK(file.good()); } } // namespace int main(int argc, char** argv) { { auto args = absl::ParseCommandLine(argc, argv); ABSL_CHECK(args.empty() || args.size() == 1) << "unexpected positional args: " << absl::StrJoin(args, ", "); } absl::InitializeLog(); auto in_buffer = ReadFile(absl::GetFlag(FLAGS_in)); std::string out_buffer; out_buffer.reserve(in_buffer.size() * 6); for (const auto& in_byte : in_buffer) { absl::StrAppend(&out_buffer, "0x", absl::Hex(in_byte, absl::PadSpec::kZeroPad2), ", "); } if (!in_buffer.empty()) { // Replace last space with newline. 
out_buffer.back() = '\n'; } WriteFile(absl::GetFlag(FLAGS_out), out_buffer); return EXIT_SUCCESS; } ================================================ FILE: bazel/cel_proto_transitive_descriptor_set.bzl ================================================ # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Provides the `cel_proto_transitive_descriptor_set` build rule. """ load("@com_google_protobuf//bazel/common:proto_info.bzl", "ProtoInfo") def _cel_proto_transitive_descriptor_set(ctx): output = ctx.actions.declare_file(ctx.attr.name + ".binarypb") transitive_descriptor_sets = depset(transitive = [dep[ProtoInfo].transitive_descriptor_sets for dep in ctx.attr.deps]) args = ctx.actions.args() args.use_param_file(param_file_arg = "%s", use_always = True) args.add_all(transitive_descriptor_sets) ctx.actions.run( mnemonic = "CelProtoTransitiveDescriptorSet", outputs = [output], inputs = transitive_descriptor_sets, progress_message = "Joining descriptors.", executable = ctx.executable.cat_tool, arguments = [args] + [output.path], ) return DefaultInfo( files = depset([output]), runfiles = ctx.runfiles(files = [output]), ) cel_proto_transitive_descriptor_set = rule( attrs = { "deps": attr.label_list(providers = [[ProtoInfo]]), "cat_tool": attr.label( executable = True, cfg = "exec", allow_files = True, default = Label("//bazel:cat_param_file"), ), }, outputs = { "out": "%{name}.binarypb", }, implementation = _cel_proto_transitive_descriptor_set, ) 
================================================ FILE: bazel/deps.bzl ================================================ """ Legacy workspace dependencies of cel-cpp. Dependencies are now managed by MODULE.bazel. The values here are not updated, but this file is retained for clients that referenced it directly. """ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_jar") def base_deps(): """Base evaluator and test dependencies.""" # Abseil LTS 20240722.0 ABSL_SHA1 = "4447c7562e3bc702ade25105912dce503f0c4010" ABSL_SHA256 = "d8342ad77aa9e16103c486b615460c24a695a1f04cdb760eb02fef780df99759" http_archive( name = "com_google_absl", urls = ["https://github.com/abseil/abseil-cpp/archive/" + ABSL_SHA1 + ".zip"], strip_prefix = "abseil-cpp-" + ABSL_SHA1, sha256 = ABSL_SHA256, ) # v1.15.2 GOOGLETEST_SHA1 = "b514bdc898e2951020cbdca1304b75f5950d1f59" GOOGLETEST_SHA256 = "8c0ceafa3ea24bf78e3519b7846d99e76c45899aa4dac4d64e7dd62e495de9fd" http_archive( name = "com_google_googletest", urls = ["https://github.com/google/googletest/archive/" + GOOGLETEST_SHA1 + ".zip"], strip_prefix = "googletest-" + GOOGLETEST_SHA1, sha256 = GOOGLETEST_SHA256, ) # v1.6.0 BENCHMARK_SHA1 = "f91b6b42b1b9854772a90ae9501464a161707d1e" BENCHMARK_SHA256 = "00bd0837db9266c758a087cdf0831a0d3e337c6bb9e3fad75d2be4f9bf480d95" http_archive( name = "com_github_google_benchmark", urls = ["https://github.com/google/benchmark/archive/" + BENCHMARK_SHA1 + ".zip"], strip_prefix = "benchmark-" + BENCHMARK_SHA1, sha256 = BENCHMARK_SHA256, ) # 2024-02-01 RE2_SHA1 = "9665465b69ab699279ef9fb9454559d90fed1d76" RE2_SHA256 = "dcd82922c7a1d3b7c2a147c045585a9f76066f9c0269a06b857eccbbf6f96dba" http_archive( name = "com_googlesource_code_re2", urls = ["https://github.com/google/re2/archive/" + RE2_SHA1 + ".zip"], strip_prefix = "re2-" + RE2_SHA1, sha256 = RE2_SHA256, ) # v28.0 PROTOBUF_SHA1 = "439c42c735ae1efed57ab7771986f2a3c0b99319" PROTOBUF_SHA256 = 
"495b76871df8d102e5c539f9d43f990f5ca53ac183702f5ed90070ba8c8759d1" http_archive( name = "com_google_protobuf", sha256 = PROTOBUF_SHA256, strip_prefix = "protobuf-" + PROTOBUF_SHA1, urls = ["https://github.com/protocolbuffers/protobuf/archive/" + PROTOBUF_SHA1 + ".zip"], ) GOOGLEAPIS_GIT_SHA = "6eb56cdf5f54f70d0dbfce051add28a35c1203ce" # June 26, 2024 GOOGLEAPIS_SHA = "6321a7eac9e5280e7abca07ddf2cab9179cbd49a6828c26f4c7c73d5a45f39ad" http_archive( name = "com_google_googleapis", sha256 = GOOGLEAPIS_SHA, strip_prefix = "googleapis-" + GOOGLEAPIS_GIT_SHA, urls = ["https://github.com/googleapis/googleapis/archive/" + GOOGLEAPIS_GIT_SHA + ".tar.gz"], ) http_archive( name = "rules_cc", urls = ["https://github.com/bazelbuild/rules_cc/releases/download/0.0.10-rc1/rules_cc-0.0.10-rc1.tar.gz"], sha256 = "d75a040c32954da0d308d3f2ea2ba735490f49b3a7aa3e4b40259ca4b814f825", ) http_archive( name = "rules_proto", sha256 = "6fb6767d1bef535310547e03247f7518b03487740c11b6c6adb7952033fe1295", strip_prefix = "rules_proto-6.0.2", url = "https://github.com/bazelbuild/rules_proto/releases/download/6.0.2/rules_proto-6.0.2.tar.gz", ) def parser_deps(): """ANTLR dependency for the parser.""" # Sept 4, 2023 ANTLR4_VERSION = "4.13.1" http_archive( name = "antlr4_runtimes", build_file_content = """ package(default_visibility = ["//visibility:public"]) cc_library( name = "cpp", srcs = glob(["runtime/Cpp/runtime/src/**/*.cpp"]), hdrs = glob(["runtime/Cpp/runtime/src/**/*.h"]), defines = ["ANTLR4CPP_USING_ABSEIL"], includes = ["runtime/Cpp/runtime/src"], deps = [ "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/synchronization", ], ) """, sha256 = "365ff6aec0b1612fb964a763ca73748d80e0b3379cbdd9f82d86333eb8ae4638", strip_prefix = "antlr4-" + ANTLR4_VERSION, urls = ["https://github.com/antlr/antlr4/archive/refs/tags/" + ANTLR4_VERSION + ".zip"], ) 
http_jar( name = "antlr4_jar", urls = ["https://www.antlr.org/download/antlr-" + ANTLR4_VERSION + "-complete.jar"], sha256 = "bc13a9c57a8dd7d5196888211e5ede657cb64a3ce968608697e4f668251a8487", ) def flatbuffers_deps(): """FlatBuffers support.""" FLAT_BUFFERS_SHA = "a83caf5910644ba1c421c002ef68e42f21c15f9f" http_archive( name = "com_github_google_flatbuffers", sha256 = "b8efbc25721e76780752bad775a97c3f77a0250271e2db37fc747b20e8b0f24a", strip_prefix = "flatbuffers-" + FLAT_BUFFERS_SHA, url = "https://github.com/google/flatbuffers/archive/" + FLAT_BUFFERS_SHA + ".tar.gz", ) def cel_spec_deps(): """CEL Spec conformance testing.""" http_archive( name = "io_bazel_rules_go", sha256 = "b2038e2de2cace18f032249cb4bb0048abf583a36369fa98f687af1b3f880b26", urls = [ "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.48.1/rules_go-v0.48.1.zip", "https://github.com/bazelbuild/rules_go/releases/download/v0.48.1/rules_go-v0.48.1.zip", ], ) http_archive( name = "rules_python", sha256 = "e3f1cc7a04d9b09635afb3130731ed82b5f58eadc8233d4efb59944d92ffc06f", strip_prefix = "rules_python-0.33.2", url = "https://github.com/bazelbuild/rules_python/releases/download/0.33.2/rules_python-0.33.2.tar.gz", ) CEL_SPEC_GIT_SHA = "afa18f9bd5a83f5960ca06c1f9faea406ab34ccc" # Dec 2, 2024 http_archive( name = "com_google_cel_spec", sha256 = "19b4084ba33cc8da7a640d999e46731efbec585ad2995951dc61a7af24f059cb", strip_prefix = "cel-spec-" + CEL_SPEC_GIT_SHA, urls = ["https://github.com/google/cel-spec/archive/" + CEL_SPEC_GIT_SHA + ".zip"], ) _ICU4C_VERSION_MAJOR = "76" _ICU4C_VERSION_MINOR = "1" _ICU4C_BUILD = """ load("@rules_foreign_cc//foreign_cc:configure.bzl", "configure_make") filegroup( name = "all", srcs = glob(["**"]), visibility = ["//visibility:private"], ) config_setting( name = "dbg", values = {{ "compilation_mode": "dbg", }}, visibility = ["//visibility:private"], ) configure_make( name = "icu4c", configure_command = "source/configure", configure_in_place = True, 
configure_options = [ "--enable-shared", "--enable-static", "--disable-extras", "--disable-icuio", "--disable-layoutex", "--disable-icu-config", ] + select({{ ":dbg": ["--enable-debug"], "//conditions:default": [], }}), lib_source = ":all", out_shared_libs = [ "libicudata.so", "libicudata.so.{version_major}", "libicudata.so.{version_major}.{version_minor}", "libicui18n.so", "libicui18n.so.{version_major}", "libicui18n.so.{version_major}.{version_minor}", "libicutu.so", "libicutu.so.{version_major}", "libicutu.so.{version_major}.{version_minor}", "libicuuc.so", "libicuuc.so.{version_major}", "libicuuc.so.{version_major}.{version_minor}", ], out_static_libs = [ "libicudata.a", "libicui18n.a", "libicutu.a", "libicuuc.a", ], args = ["-j 8"], visibility = ["//visibility:public"], ) """.format(version_major = _ICU4C_VERSION_MAJOR, version_minor = _ICU4C_VERSION_MINOR) def cel_cpp_extensions_deps(): http_archive( name = "rules_foreign_cc", sha256 = "8e5605dc2d16a4229cb8fbe398514b10528553ed4f5f7737b663fdd92f48e1c2", strip_prefix = "rules_foreign_cc-0.13.0", url = "https://github.com/bazel-contrib/rules_foreign_cc/releases/download/0.13.0/rules_foreign_cc-0.13.0.tar.gz", ) http_archive( name = "icu4c", sha256 = "dfacb46bfe4747410472ce3e1144bf28a102feeaa4e3875bac9b4c6cf30f4f3e", url = "https://github.com/unicode-org/icu/releases/download/release-{version_major}-{version_minor}/icu4c-{version_major}_{version_minor}-src.tgz".format(version_major = _ICU4C_VERSION_MAJOR, version_minor = _ICU4C_VERSION_MINOR), strip_prefix = "icu", patch_cmds = [ "rm -f source/common/BUILD.bazel", "rm -f source/i18n/BUILD.bazel", "rm -f source/stubdata/BUILD.bazel", "rm -f source/tools/gennorm2/BUILD.bazel", "rm -f source/tools/toolutil/BUILD.bazel", "rm -f source/tools/unicode/c/genprops/BUILD.bazel", "rm -f source/tools/unicode/c/genuca/BUILD.bazel", "rm -f source/vendor/double-conversion/upstream/WORKSPACE", ], build_file_content = _ICU4C_BUILD, ) def cel_cpp_deps(): """All core dependencies 
of cel-cpp.""" base_deps() parser_deps() flatbuffers_deps() cel_spec_deps() ================================================ FILE: checker/BUILD ================================================ load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. package(default_visibility = ["//visibility:public"]) cc_library( name = "checker_options", hdrs = ["checker_options.h"], ) cc_library( name = "type_check_issue", srcs = ["type_check_issue.cc"], hdrs = ["type_check_issue.h"], deps = [ "//common:source", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", ], ) cc_test( name = "type_check_issue_test", srcs = ["type_check_issue_test.cc"], deps = [ ":type_check_issue", "//common:source", "//internal:testing", ], ) cc_library( name = "validation_result", srcs = ["validation_result.cc"], hdrs = ["validation_result.h"], deps = [ ":type_check_issue", "//common:ast", "//common:source", "//common:type", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) cc_test( name = "validation_result_test", srcs = ["validation_result_test.cc"], deps = [ ":type_check_issue", ":validation_result", "//common:ast", "//common:source", "//internal:testing", 
"@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", ], ) cc_library( name = "type_checker", srcs = ["type_checker.cc"], hdrs = ["type_checker.h"], deps = [ ":validation_result", "//common:ast", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "type_checker_builder", hdrs = ["type_checker_builder.h"], deps = [ ":checker_options", ":type_checker", "//common:container", "//common:decl", "//common:type", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "type_checker_builder_factory", srcs = ["type_checker_builder_factory.cc"], hdrs = ["type_checker_builder_factory.h"], deps = [ ":checker_options", ":type_checker_builder", "//checker/internal:type_checker_impl", "//internal:noop_delete", "//internal:status_macros", "//internal:well_known_types", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "type_checker_builder_factory_test", srcs = ["type_checker_builder_factory_test.cc"], deps = [ ":checker_options", ":optional", ":standard_library", ":type_checker", ":type_checker_builder", ":type_checker_builder_factory", ":validation_result", "//checker/internal:test_ast_helpers", "//common:ast", "//common:decl", "//common:type", "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings:string_view", ], ) cc_library( name = "standard_library", srcs = ["standard_library.cc"], hdrs = ["standard_library.h"], deps = [ ":type_checker_builder", "//checker/internal:builtins_arena",
"//common:constant", "//common:decl", "//common:standard_definitions", "//common:type", "//internal:status_macros", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/status", ], ) cc_test( name = "standard_library_test", srcs = ["standard_library_test.cc"], deps = [ ":checker_options", ":standard_library", ":type_checker", ":type_checker_builder", ":type_checker_builder_factory", ":validation_result", "//checker/internal:test_ast_helpers", "//common:ast", "//common:constant", "//common:decl", "//common:type", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "optional", srcs = ["optional.cc"], hdrs = ["optional.h"], deps = [ ":type_checker_builder", "//base:builtins", "//checker/internal:builtins_arena", "//common:decl", "//common:type", "//internal:status_macros", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/status", ], ) cc_test( name = "optional_test", srcs = ["optional_test.cc"], deps = [ ":checker_options", ":optional", ":standard_library", ":type_check_issue", ":type_checker", ":type_checker_builder", ":type_checker_builder_factory", "//checker/internal:test_ast_helpers", "//common:ast", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", ], ) cc_library( name = "type_checker_subset_factory", srcs = ["type_checker_subset_factory.cc"], hdrs = ["type_checker_subset_factory.h"], deps = [ ":type_checker_builder", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", ], ) cc_test( name = "type_checker_subset_factory_test", srcs = ["type_checker_subset_factory_test.cc"], deps = [ ":type_checker_subset_factory", ":validation_result", "//common:standard_definitions", "//compiler",
"//compiler:compiler_factory", "//compiler:standard_library", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings:string_view", ], ) ================================================ FILE: checker/checker_options.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_CHECKER_CHECKER_OPTIONS_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_CHECKER_OPTIONS_H_ namespace cel { // Options for enabling core type checker features. struct CheckerOptions { // Enable overloads for numeric comparisons across types. // For example, 1.0 < 2 will resolve to lt_double_int. // // By default, this is disabled and expressions must explicitly cast to dyn or // the same type to compare. bool enable_cross_numeric_comparisons = false; // Enable legacy behavior for null assignment. // // Historically, CEL has allowed null to be assigned to structs, abstract // types, durations, timestamps, and any types. This is inconsistent with // CEL's usual interpretation of null as a literal JSON null. // // TODO(uncreated-issue/75): Need a concrete plan for updating existing CEL // expressions that depend on the old behavior. bool enable_legacy_null_assignment = true; // Enable updating parsed struct type names to the fully qualified type name // when resolved.
// // Enabled by default, but can be disabled to preserve the original type name // as parsed. bool update_struct_type_names = true; // Temporary flag to enable type parameter name validation. // // When enabled, the TypeCheckerBuilder will validate that type parameter // names are simple identifiers when declared. bool enable_type_parameter_name_validation = true; // Well-known types defined by protobuf are treated specially in CEL, and // generally don't behave like other messages as runtime values. When used as // context declarations, this introduces some ambiguity about the intended // types of the field declarations, so it is disallowed by default. // // When enabled, the well-known types are treated like a normal message type // for the purposes of declaring context bindings (i.e. no unpacking or // adapting), and use the Descriptor that is assumed by CEL. // // E.g. for google.protobuf.Any, the type checker will add a context binding // with `type_url: string` and `value: bytes` as top level variables. bool allow_well_known_type_context_declarations = false; // Maximum number (inclusive) of expression nodes to check for an input // expression. // // If exceeded, the checker should return a status with code InvalidArgument. int max_expression_node_count = 100000; // Maximum number (inclusive) of error-level issues to tolerate for an input // ast. // // If exceeded, the checker will stop processing the ast and return // the current set of issues. int max_error_issues = 20; // Maximum amount of nesting allowed for type declarations in function // signatures and variable declarations. // // If exceeded, the TypeCheckerBuilder will report an error when adding the // declaration. // // For untrusted declarations, the caller should set a lower limit to mitigate // expressions that compound nesting e.g.
// type5(T)->type(type(type(type(type(T)))))); type5(type5(T)) -> type10(T) int max_type_decl_nesting = 13; // If true, the checker will include the resolved function name in the // reference map for the function call expr. // // If false, the function name will be empty and implied by the overload id // set. This matches the behavior in cel-go and cel-java. // // Temporary flag to allow rolling out the change. No functional changes to // evaluation behavior in either mode. bool enable_function_name_in_reference = true; // If true, the checker will use the proto json field names for protobuf // messages. Unlike protojson parsers, it will not accept the standard proto // field names as valid json field names. // // Note: The checked AST will contain the json field names and an extension // tag, but will require runtime support for resolving the json field names. bool use_json_field_names = false; }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_CHECKER_CHECKER_OPTIONS_H_ ================================================ FILE: checker/internal/BUILD ================================================ # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") package( # Implementation details for the checker library.
default_visibility = ["//visibility:public"], ) cc_library( name = "test_ast_helpers", testonly = 1, srcs = ["test_ast_helpers.cc"], hdrs = ["test_ast_helpers.h"], deps = [ "//common:ast", "//internal:status_macros", "//parser", "//parser:options", "//parser:parser_interface", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], ) cc_test( name = "test_ast_helpers_test", srcs = ["test_ast_helpers_test.cc"], deps = [ ":test_ast_helpers", "//common:ast", "//internal:testing", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", ], ) cc_library( name = "builtins_arena", srcs = ["builtins_arena.cc"], hdrs = ["builtins_arena.h"], deps = [ "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "type_check_env", srcs = ["type_check_env.cc"], hdrs = ["type_check_env.h"], deps = [ ":descriptor_pool_type_introspector", "//common:constant", "//common:container", "//common:decl", "//common:type", "//internal:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "namespace_generator", srcs = ["namespace_generator.cc"], hdrs = ["namespace_generator.h"], deps = [ "//common:container", "//internal:lexis", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", ], ) cc_test( name = "namespace_generator_test", srcs = ["namespace_generator_test.cc"], deps = [ ":namespace_generator", "//common:container", "//internal:testing", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", ], ) cc_library( name = "type_checker_impl", srcs = [ "type_checker_builder_impl.cc", "type_checker_impl.cc", ], hdrs = [ "type_checker_builder_impl.h", "type_checker_impl.h", ], deps = [ ":format_type_name", ":namespace_generator", ":type_check_env", ":type_inference_context", "//checker:checker_options", "//checker:type_check_issue", "//checker:type_checker", "//checker:type_checker_builder", "//checker:validation_result", "//common:ast", "//common:ast_rewrite", "//common:ast_traverse", "//common:ast_visitor", "//common:ast_visitor_base", "//common:constant", "//common:container", "//common:decl", "//common:expr", "//common:type", "//common:type_kind", "//internal:lexis", "//internal:status_macros", "//parser:macro", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "type_checker_impl_test", srcs = ["type_checker_impl_test.cc"], deps = [ ":test_ast_helpers", ":type_check_env", ":type_checker_impl", "//checker:checker_options", "//checker:type_check_issue", "//checker:type_checker_builder", "//checker:validation_result", "//common:ast", "//common:container", "//common:decl",
"//common:expr", "//common:source", "//common:type", "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", "//testutil:baseline_tests", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "type_checker_builder_impl_test", srcs = ["type_checker_builder_impl_test.cc"], deps = [ ":test_ast_helpers", ":type_checker_impl", "//checker:checker_options", "//checker:type_checker", "//checker:type_checker_builder", "//checker:validation_result", "//common:ast", "//common:decl", "//common:type", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "type_inference_context", srcs = ["type_inference_context.cc"], hdrs = ["type_inference_context.h"], deps = [ ":format_type_name", "//common:decl", "//common:type", "//common:type_kind", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "type_inference_context_test", srcs = ["type_inference_context_test.cc"], deps = [
":type_inference_context", "//common:decl", "//common:type", "//common:type_kind", "//internal:testing", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "format_type_name", srcs = ["format_type_name.cc"], hdrs = ["format_type_name.h"], deps = [ "//common:type", "//common:type_kind", "@com_google_absl//absl/strings", ], ) cc_library( name = "descriptor_pool_type_introspector", srcs = ["descriptor_pool_type_introspector.cc"], hdrs = ["descriptor_pool_type_introspector.h"], deps = [ "//common:type", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "descriptor_pool_type_introspector_test", srcs = ["descriptor_pool_type_introspector_test.cc"], deps = [ ":descriptor_pool_type_introspector", "//common:type", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:optional", ], ) ================================================ FILE: checker/internal/builtins_arena.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "checker/internal/builtins_arena.h" #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "google/protobuf/arena.h" namespace cel::checker_internal { google::protobuf::Arena* absl_nonnull BuiltinsArena() { // Function-local static: created on the first call and, being wrapped in // NoDestructor, never torn down, so anything allocated on it stays valid for // the rest of the process. static absl::NoDestructor kSharedBuiltinsArena; return kSharedBuiltinsArena.get(); } } // namespace cel::checker_internal ================================================ FILE: checker/internal/builtins_arena.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_BUILTINS_ARENA_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_BUILTINS_ARENA_H_ #include "absl/base/nullability.h" #include "google/protobuf/arena.h" namespace cel::checker_internal { // Returns the process-wide arena that owns builtin type objects shared by // every type checker instance. The arena is never destroyed.
google::protobuf::Arena* absl_nonnull BuiltinsArena(); } // namespace cel::checker_internal #endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_BUILTINS_ARENA_H_ ================================================ FILE: checker/internal/descriptor_pool_type_introspector.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "checker/internal/descriptor_pool_type_introspector.h" #include #include #include #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "common/type.h" #include "common/type_introspector.h" #include "google/protobuf/descriptor.h" namespace cel::checker_internal { namespace { // Standard implementation for field lookups. // Avoids building a FieldTable and just checks the DescriptorPool directly.
absl::StatusOr> FindStructTypeFieldByNameDirectly( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, absl::string_view type, absl::string_view name) { const google::protobuf::Descriptor* absl_nullable descriptor = descriptor_pool->FindMessageTypeByName(type); if (descriptor == nullptr) { return absl::nullopt; } const google::protobuf::FieldDescriptor* absl_nullable field = descriptor->FindFieldByName(name); if (field != nullptr) { return StructTypeField(MessageTypeField(field)); } // Fall back to extensions of this message, matched by printable name. field = descriptor_pool->FindExtensionByPrintableName(descriptor, name); if (field != nullptr) { return StructTypeField(MessageTypeField(field)); } return absl::nullopt; } // Standard implementation for listing fields. // Avoids building a FieldTable and just checks the DescriptorPool directly. // Only regular (non-extension) fields are listed, mirroring the JSON-name // path, which lists FieldTable::non_extensions. The pool is therefore not // queried for extensions here. absl::StatusOr< absl::optional>> ListStructTypeFieldsDirectly( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, absl::string_view type) { const google::protobuf::Descriptor* absl_nullable descriptor = descriptor_pool->FindMessageTypeByName(type); if (descriptor == nullptr) { return absl::nullopt; } const int field_count = descriptor->field_count(); std::vector fields; fields.reserve(field_count); for (int i = 0; i < field_count; ++i) { const google::protobuf::FieldDescriptor* field = descriptor->field(i); fields.push_back({field->name(), StructTypeField(MessageTypeField(field))}); } return fields; } } // namespace using Field = DescriptorPoolTypeIntrospector::Field; absl::StatusOr> DescriptorPoolTypeIntrospector::FindTypeImpl(absl::string_view name) const { const google::protobuf::Descriptor* absl_nullable descriptor = descriptor_pool_->FindMessageTypeByName(name); if (descriptor != nullptr) { return Type::Message(descriptor); } const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = descriptor_pool_->FindEnumTypeByName(name); if (enum_descriptor != nullptr) {
return Type::Enum(enum_descriptor); } return absl::nullopt; } absl::StatusOr> DescriptorPoolTypeIntrospector::FindEnumConstantImpl( absl::string_view type, absl::string_view value) const { const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = descriptor_pool_->FindEnumTypeByName(type); if (enum_descriptor != nullptr) { const google::protobuf::EnumValueDescriptor* absl_nullable enum_value_descriptor = enum_descriptor->FindValueByName(value); if (enum_value_descriptor == nullptr) { return absl::nullopt; } return EnumConstant{ .type = Type::Enum(enum_descriptor), .type_full_name = enum_descriptor->full_name(), .value_name = enum_value_descriptor->name(), .number = enum_value_descriptor->number(), }; } return absl::nullopt; } absl::StatusOr> DescriptorPoolTypeIntrospector::FindStructTypeFieldByNameImpl( absl::string_view type, absl::string_view name) const { if (!use_json_name_) { return FindStructTypeFieldByNameDirectly(descriptor_pool_, type, name); } const FieldTable* field_table = GetFieldTable(type); if (field_table == nullptr) { return absl::nullopt; } // JSON-name mode: only JSON field names and extension full names resolve; // proto field names are deliberately rejected (field_name_map is not // consulted here). if (auto it = field_table->json_name_map.find(name); it != field_table->json_name_map.end()) { return field_table->fields[it->second].field; } if (auto it = field_table->extension_name_map.find(name); it != field_table->extension_name_map.end()) { return field_table->fields[it->second].field; } return absl::nullopt; } absl::StatusOr< absl::optional>> DescriptorPoolTypeIntrospector::ListFieldsForStructTypeImpl( absl::string_view type) const { if (!use_json_name_) { return ListStructTypeFieldsDirectly(descriptor_pool_, type); } const FieldTable* field_table = GetFieldTable(type); if (field_table == nullptr) { return absl::nullopt; } // Extensions are intentionally omitted from listings, as in the direct path. std::vector fields; fields.reserve(field_table->non_extensions.size()); for (const auto& field : field_table->non_extensions) { fields.push_back({field.json_name, field.field}); } return fields; } const DescriptorPoolTypeIntrospector::FieldTable*
DescriptorPoolTypeIntrospector::GetFieldTable( absl::string_view type_name) const { absl::MutexLock lock(mu_); if (auto it = field_tables_.find(type_name); it != field_tables_.end()) { return it->second.get(); } // Well-known types never get a FieldTable, so in JSON-name mode their fields // are not resolvable, unlike the direct lookup path. NOTE(review): confirm // this asymmetry is intended. if (cel::IsWellKnownMessageType(type_name)) { return nullptr; } const google::protobuf::Descriptor* absl_nullable descriptor = descriptor_pool_->FindMessageTypeByName(type_name); if (descriptor == nullptr) { return nullptr; } // Key the cache by the descriptor-owned name so the string_view key outlives // the caller's argument. absl::string_view stable_type_name = descriptor->full_name(); ABSL_DCHECK(stable_type_name == type_name); std::unique_ptr field_table = CreateFieldTable(descriptor); const FieldTable* field_table_ptr = field_table.get(); field_tables_[stable_type_name] = std::move(field_table); return field_table_ptr; } std::unique_ptr DescriptorPoolTypeIntrospector::CreateFieldTable( const google::protobuf::Descriptor* absl_nonnull descriptor) const { ABSL_DCHECK(!IsWellKnownMessageType(descriptor)); std::vector fields; absl::flat_hash_map json_name_map; absl::flat_hash_map field_name_map; absl::flat_hash_map extension_name_map; std::vector extensions; descriptor_pool_->FindAllExtensions(descriptor, &extensions); fields.reserve(descriptor->field_count() + extensions.size()); for (int i = 0; i < descriptor->field_count(); i++) { const google::protobuf::FieldDescriptor* field = descriptor->field(i); fields.push_back(Field{ .field = StructTypeField(MessageTypeField(field)), .json_name = field->json_name(), .is_extension = false, }); field_name_map[field->name()] = fields.size() - 1; if (use_json_name_ && !field->json_name().empty()) { json_name_map[field->json_name()] = fields.size() - 1; } } // Regular fields occupy [0, non_extension_count); extensions follow. const size_t non_extension_count = fields.size(); for (const google::protobuf::FieldDescriptor* extension : extensions) { fields.push_back(Field{ .field = StructTypeField(MessageTypeField(extension)), .json_name = "", .is_extension = true, }); extension_name_map[extension->full_name()] = fields.size() - 1; } const size_t extension_count = fields.size() - non_extension_count; auto
result = std::make_unique(); result->descriptor = descriptor; result->fields = std::move(fields); result->non_extensions = absl::MakeConstSpan(result->fields).subspan(0, non_extension_count); result->extensions = absl::MakeConstSpan(result->fields) .subspan(non_extension_count, extension_count); result->json_name_map = std::move(json_name_map); result->field_name_map = std::move(field_name_map); result->extension_name_map = std::move(extension_name_map); return result; } } // namespace cel::checker_internal ================================================ FILE: checker/internal/descriptor_pool_type_introspector.h ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_DESCRIPTOR_POOL_TYPE_INTROSPECTOR_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_DESCRIPTOR_POOL_TYPE_INTROSPECTOR_H_ #include #include #include "absl/base/nullability.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "common/type.h" #include "common/type_introspector.h" #include "google/protobuf/descriptor.h" namespace cel::checker_internal { // Implementation of `TypeIntrospector` that uses a `google::protobuf::DescriptorPool`.
// // This is used by the type checker to resolve protobuf types and their fields // and apply any options like using JSON names. // // Neither copyable nor movable. Should be managed by a TypeCheckEnv. class DescriptorPoolTypeIntrospector : public TypeIntrospector { public: // One entry of a FieldTable: the resolved field, its JSON name (left empty // for extensions), and whether it is an extension. struct Field { StructTypeField field; absl::string_view json_name; bool is_extension = false; }; DescriptorPoolTypeIntrospector() = delete; explicit DescriptorPoolTypeIntrospector( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) : descriptor_pool_(descriptor_pool) {} DescriptorPoolTypeIntrospector(const DescriptorPoolTypeIntrospector&) = delete; DescriptorPoolTypeIntrospector& operator=( const DescriptorPoolTypeIntrospector&) = delete; DescriptorPoolTypeIntrospector(DescriptorPoolTypeIntrospector&&) = delete; DescriptorPoolTypeIntrospector& operator=(DescriptorPoolTypeIntrospector&&) = delete; // Selects the struct-field lookup mode. When true, lookups and listings go // through lazily built FieldTables keyed by JSON field names; when false, the // DescriptorPool is queried directly by proto field name. NOTE(review): // use_json_name_ is not guarded by mu_, so this presumably must be called // before the introspector is used concurrently -- confirm. void set_use_json_name(bool use_json_name) { use_json_name_ = use_json_name; } bool use_json_name() const { return use_json_name_; } private: // Per-message index built by CreateFieldTable. non_extensions and extensions // are spans into `fields`, and the map values are indices into `fields`. struct FieldTable { const google::protobuf::Descriptor* absl_nonnull descriptor; std::vector fields; absl::Span non_extensions; absl::Span extensions; absl::flat_hash_map json_name_map; absl::flat_hash_map field_name_map; absl::flat_hash_map extension_name_map; }; absl::StatusOr> FindTypeImpl( absl::string_view name) const final; absl::StatusOr> FindEnumConstantImpl( absl::string_view type, absl::string_view value) const final; absl::StatusOr> FindStructTypeFieldByNameImpl( absl::string_view type, absl::string_view name) const final; absl::StatusOr>> ListFieldsForStructTypeImpl(absl::string_view type) const final; // Builds the FieldTable for a (non-well-known) message type. std::unique_ptr CreateFieldTable( const google::protobuf::Descriptor* absl_nonnull descriptor) const; // Returns the cached or newly built FieldTable for type_name, or nullptr if // the type is a well-known type or is not in the pool. Acquires mu_. const FieldTable* GetFieldTable(absl::string_view type_name) const; // Cached map of type to field table. // Keys view the descriptor-owned full_name() strings.
mutable absl::flat_hash_map> field_tables_ ABSL_GUARDED_BY(mu_); mutable absl::Mutex mu_; bool use_json_name_ = false; // Not owned by this object (no destructor releases it); presumably must // outlive the introspector. const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_; }; } // namespace cel::checker_internal #endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_DESCRIPTOR_POOL_TYPE_INTROSPECTOR_H_ ================================================ FILE: checker/internal/descriptor_pool_type_introspector_test.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "checker/internal/descriptor_pool_type_introspector.h" #include #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "common/type.h" #include "common/type_introspector.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" namespace cel::checker_internal { namespace { using ::absl_testing::IsOkAndHolds; using ::testing::AllOf; using ::testing::Contains; using ::testing::Eq; using ::testing::Not; using ::testing::Optional; using ::testing::Property; using ::testing::SizeIs; using ::testing::Truly; TEST(DescriptorPoolTypeIntrospectorTest, FindType) { DescriptorPoolTypeIntrospector introspector( internal::GetTestingDescriptorPool()); EXPECT_THAT(introspector.FindType("cel.expr.conformance.proto3.TestAllTypes"), IsOkAndHolds(Optional(Property(&Type::IsMessage, true)))); EXPECT_THAT(introspector.FindType( "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"), IsOkAndHolds(Optional(Property(&Type::IsEnum, true)))); EXPECT_THAT(introspector.FindType("non.existent.Type"), IsOkAndHolds(Eq(absl::nullopt))); } TEST(DescriptorPoolTypeIntrospectorTest, FindEnumConstant) { DescriptorPoolTypeIntrospector introspector( internal::GetTestingDescriptorPool()); auto result = introspector.FindEnumConstant( "cel.expr.conformance.proto3.TestAllTypes.NestedEnum", "FOO"); ASSERT_THAT(result, IsOkAndHolds(Optional(AllOf( Truly([](const TypeIntrospector::EnumConstant& v) { return v.value_name == "FOO" && v.number == 0; }))))); } TEST(DescriptorPoolTypeIntrospectorTest, FindStructTypeFieldByName) { DescriptorPoolTypeIntrospector introspector( internal::GetTestingDescriptorPool()); // Configure the lookup mode before querying; setting it after the lookup // would not affect `field`. introspector.set_use_json_name(false); auto field = introspector.FindStructTypeFieldByName( "cel.expr.conformance.proto3.TestAllTypes", "single_int64"); ASSERT_THAT(field, IsOkAndHolds(Optional(Property(&StructTypeField::GetType, Property(&Type::IsInt, true))))); } TEST(DescriptorPoolTypeIntrospectorTest,
FindStructTypeFieldByNameJsonNameIgnored) { DescriptorPoolTypeIntrospector introspector( internal::GetTestingDescriptorPool()); introspector.set_use_json_name(false); auto field = introspector.FindStructTypeFieldByName( "cel.expr.conformance.proto3.TestAllTypes", "singleInt64"); EXPECT_THAT(field, IsOkAndHolds(Eq(absl::nullopt))); } TEST(DescriptorPoolTypeIntrospectorTest, FindExtension) { DescriptorPoolTypeIntrospector introspector( internal::GetTestingDescriptorPool()); auto field = introspector.FindStructTypeFieldByName( "cel.expr.conformance.proto2.TestAllTypes", "cel.expr.conformance.proto2.int32_ext"); ASSERT_THAT(field, IsOkAndHolds(Optional(Property(&StructTypeField::GetType, Property(&Type::IsInt, true))))); } // In JSON-name mode, extensions are still resolved by their full name through // the FieldTable's extension_name_map. TEST(DescriptorPoolTypeIntrospectorTest, FindExtensionWithJsonNameOpt) { DescriptorPoolTypeIntrospector introspector( internal::GetTestingDescriptorPool()); introspector.set_use_json_name(true); auto field = introspector.FindStructTypeFieldByName( "cel.expr.conformance.proto2.TestAllTypes", "cel.expr.conformance.proto2.int32_ext"); ASSERT_THAT(field, IsOkAndHolds(Optional(Property(&StructTypeField::GetType, Property(&Type::IsInt, true))))); } TEST(DescriptorPoolTypeIntrospectorTest, FindStructTypeFieldByNameWithJsonOpt) { DescriptorPoolTypeIntrospector introspector( internal::GetTestingDescriptorPool()); introspector.set_use_json_name(true); auto field = introspector.FindStructTypeFieldByName( "cel.expr.conformance.proto3.TestAllTypes", "single_int64"); ASSERT_THAT(field, IsOkAndHolds(Eq(absl::nullopt))); } TEST(DescriptorPoolTypeIntrospectorTest, FindStructTypeFieldByNameWithJsonNameOpt) { DescriptorPoolTypeIntrospector introspector( internal::GetTestingDescriptorPool()); introspector.set_use_json_name(true); absl::StatusOr> field = introspector.FindStructTypeFieldByName( "cel.expr.conformance.proto3.TestAllTypes", "singleInt64"); ASSERT_THAT(field, IsOkAndHolds(Optional(Property(&StructTypeField::GetType, Property(&Type::IsInt, true))))); } MATCHER_P(FieldListingIs, field_name, "") { return arg.name == field_name; } TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructType) { DescriptorPoolTypeIntrospector introspector( internal::GetTestingDescriptorPool()); absl::StatusOr< absl::optional>> fields = introspector.ListFieldsForStructType( "cel.expr.conformance.proto3.TestAllTypes"); ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(260)))); EXPECT_THAT(*fields,
Optional(Contains(FieldListingIs("single_int64")))); } TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructTypeExtensions) { DescriptorPoolTypeIntrospector introspector( internal::GetTestingDescriptorPool()); auto fields = introspector.ListFieldsForStructType( "cel.expr.conformance.proto2.TestAllTypes"); ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(259)))); EXPECT_THAT(**fields, Contains(FieldListingIs("single_int64"))); EXPECT_THAT( **fields, Not(Contains(FieldListingIs("cel.expr.conformance.proto2.int32_ext")))); } TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructTypeWithJsonNameOpt) { DescriptorPoolTypeIntrospector introspector( internal::GetTestingDescriptorPool()); introspector.set_use_json_name(true); auto fields = introspector.ListFieldsForStructType( "cel.expr.conformance.proto3.TestAllTypes"); ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(260)))); EXPECT_THAT(**fields, Contains(FieldListingIs("singleInt64"))); EXPECT_THAT(**fields, Not(Contains(FieldListingIs("single_int64")))); } TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructTypeNotFound) { DescriptorPoolTypeIntrospector introspector( internal::GetTestingDescriptorPool()); auto fields = introspector.ListFieldsForStructType( "cel.expr.conformance.proto3.SomeOtherType"); EXPECT_THAT(fields, IsOkAndHolds(Eq(absl::nullopt))); } } // namespace } // namespace cel::checker_internal ================================================ FILE: checker/internal/format_type_name.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "checker/internal/format_type_name.h"

// NOTE(review): in this copy the bracketed system-header names and template
// arguments appear to have been stripped; restore them from upstream.
#include 
#include 

#include "absl/strings/str_cat.h"
#include "common/type.h"
#include "common/type_kind.h"

namespace cel::checker_internal {
namespace {

// One pending unit of formatting work.
//
// `type` is the type being formatted. `offset` is 0 when nothing has been
// emitted for it yet. A non-zero `offset` means the opening text plus
// `offset` child type(s) have already been emitted, and the next step
// resumes after them.
struct FormatImplRecord {
  Type type;
  int offset;
};

// Parameterized types can be arbitrarily nested, so we use a vector as
// a stack to avoid overflow. Practically, we don't expect nesting
// to ever be very deep, but fuzzers and pathological inputs can easily
// trigger stack overflow with a recursive implementation.
//
// Performs one formatting step for `cur` at `offset`, appending text to
// `out`. Leaf kinds are written in full. Parameterized kinds (list, map,
// opaque, type) push a continuation record for `cur` and then a record for
// the next child. The caller pops `stack` from the back (LIFO), so the child
// is fully formatted before the continuation resumes.
void FormatImpl(const Type& cur, int offset, std::vector& stack,
                std::string* out) {
  switch (cur.kind()) {
    case TypeKind::kDyn:
      absl::StrAppend(out, "dyn");
      return;
    case TypeKind::kAny:
      absl::StrAppend(out, "any");
      return;
    case TypeKind::kBool:
      absl::StrAppend(out, "bool");
      return;
    case TypeKind::kBoolWrapper:
      absl::StrAppend(out, "wrapper(bool)");
      return;
    case TypeKind::kBytes:
      absl::StrAppend(out, "bytes");
      return;
    case TypeKind::kBytesWrapper:
      absl::StrAppend(out, "wrapper(bytes)");
      return;
    case TypeKind::kDouble:
      absl::StrAppend(out, "double");
      return;
    case TypeKind::kDoubleWrapper:
      absl::StrAppend(out, "wrapper(double)");
      return;
    case TypeKind::kDuration:
      absl::StrAppend(out, "google.protobuf.Duration");
      return;
    case TypeKind::kEnum:
      // Enum types are presented as plain `int`.
      absl::StrAppend(out, "int");
      return;
    case TypeKind::kInt:
      absl::StrAppend(out, "int");
      return;
    case TypeKind::kIntWrapper:
      absl::StrAppend(out, "wrapper(int)");
      return;
    case TypeKind::kList:
      // offset 0: open and schedule the element; offset 1: close.
      if (offset == 0) {
        absl::StrAppend(out, "list(");
        stack.push_back({cur, 1});
        stack.push_back({cur.AsList()->GetElement(), 0});
      } else {
        absl::StrAppend(out, ")");
      }
      return;
    case TypeKind::kMap:
      // offset 0: open and schedule the key; offset 1: separator and
      // schedule the value; offset 2: close.
      if (offset == 0) {
        absl::StrAppend(out, "map(");
        stack.push_back({cur, 1});
        stack.push_back({cur.AsMap()->GetKey(), 0});
        return;
      }
      if (offset == 1) {
        absl::StrAppend(out, ", ");
        stack.push_back({cur, 2});
        stack.push_back({cur.AsMap()->GetValue(), 0});
        return;
      }
      absl::StrAppend(out, ")");
      return;
    case TypeKind::kNull:
      absl::StrAppend(out, "null_type");
      return;
    case TypeKind::kOpaque: {
      // `offset` counts parameters already emitted. A parameterless opaque
      // type is printed as its bare name with no parentheses.
      OpaqueType opaque = *cur.AsOpaque();
      if (offset == 0) {
        absl::StrAppend(out, cur.AsOpaque()->name());
        if (!opaque.GetParameters().empty()) {
          absl::StrAppend(out, "(");
          stack.push_back({cur, 1});
          stack.push_back({cur.AsOpaque()->GetParameters()[0], 0});
        }
        return;
      }
      // All parameters emitted: close the parameter list.
      if (offset >= opaque.GetParameters().size()) {
        absl::StrAppend(out, ")");
        return;
      }
      absl::StrAppend(out, ", ");
      stack.push_back({cur, offset + 1});
      stack.push_back({cur.AsOpaque()->GetParameters()[offset], 0});
      return;
    }
    case TypeKind::kString:
      absl::StrAppend(out, "string");
      return;
    case TypeKind::kStringWrapper:
      absl::StrAppend(out, "wrapper(string)");
      return;
    case TypeKind::kStruct:
      absl::StrAppend(out, cur.AsStruct()->name());
      return;
    case TypeKind::kTimestamp:
      absl::StrAppend(out, "google.protobuf.Timestamp");
      return;
    case TypeKind::kType: {
      // Only the first parameter (if any) is formatted; offset 1 closes.
      TypeType type_type = *cur.AsType();
      if (offset == 0) {
        absl::StrAppend(out, type_type.name());
        if (!type_type.GetParameters().empty()) {
          absl::StrAppend(out, "(");
          stack.push_back({cur, 1});
          stack.push_back({cur.AsType()->GetParameters()[0], 0});
        }
        return;
      }
      absl::StrAppend(out, ")");
      return;
    }
    case TypeKind::kTypeParam:
      absl::StrAppend(out, cur.AsTypeParam()->name());
      return;
    case TypeKind::kUint:
      absl::StrAppend(out, "uint");
      return;
    case TypeKind::kUintWrapper:
      absl::StrAppend(out, "wrapper(uint)");
      return;
    case TypeKind::kUnknown:
      absl::StrAppend(out, "*unknown*");
      return;
    case TypeKind::kError:
    case TypeKind::kFunction:
    default:
      // Error, function and any unhandled kinds share one placeholder.
      absl::StrAppend(out, "*error*");
      return;
  }
}

}  // namespace

// Formats `type` iteratively. The work stack is seeded with `type` at
// offset 0 and drained through FormatImpl until no work remains.
std::string FormatTypeName(const Type& type) {
  std::vector stack;
  std::string out;
  stack.push_back({type, 0});
  while (!stack.empty()) {
    // Copy the record out by value: pop_back destroys the element, so a
    // reference into `stack` would dangle.
    // NOTE(review): this binding shadows the `type` parameter.
    auto [type, offset] = stack.back();
    stack.pop_back();
    FormatImpl(type, offset, stack, &out);
  }
  return out;
}

}  // namespace cel::checker_internal

================================================
FILE: checker/internal/format_type_name.h
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_FORMAT_TYPE_NAME_H_
#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_FORMAT_TYPE_NAME_H_

#include 

#include "common/type.h"

namespace cel::checker_internal {

// Format the type name for presentation in error messages. Matches the
// formatting used in github.com/cel-spec.
std::string FormatTypeName(const Type& type);

}  // namespace cel::checker_internal

#endif  // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_FORMAT_TYPE_NAME_H_

================================================
FILE: checker/internal/format_type_name_test.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "checker/internal/format_type_name.h"

#include "common/type.h"
#include "internal/testing.h"
#include "cel/expr/conformance/proto2/test_all_types.pb.h"
#include "google/protobuf/arena.h"

namespace cel::checker_internal {
namespace {

using ::cel::expr::conformance::proto2::GlobalEnum_descriptor;
using ::cel::expr::conformance::proto2::TestAllTypes;
using ::testing::MatchesRegex;

TEST(FormatTypeNameTest, PrimitiveTypes) {
  EXPECT_EQ(FormatTypeName(IntType()), "int");
  EXPECT_EQ(FormatTypeName(UintType()), "uint");
  EXPECT_EQ(FormatTypeName(DoubleType()), "double");
  EXPECT_EQ(FormatTypeName(StringType()), "string");
  EXPECT_EQ(FormatTypeName(BytesType()), "bytes");
  EXPECT_EQ(FormatTypeName(BoolType()), "bool");
  EXPECT_EQ(FormatTypeName(NullType()), "null_type");
  EXPECT_EQ(FormatTypeName(DynType()), "dyn");
}

TEST(FormatTypeNameTest, SpecialTypes) {
  EXPECT_EQ(FormatTypeName(ErrorType()), "*error*");
  EXPECT_EQ(FormatTypeName(UnknownType()), "*unknown*");
  EXPECT_EQ(FormatTypeName(FunctionType()), "*error*");
}

TEST(FormatTypeNameTest, WellKnownTypes) {
  EXPECT_EQ(FormatTypeName(AnyType()), "any");
  EXPECT_EQ(FormatTypeName(DurationType()), "google.protobuf.Duration");
  EXPECT_EQ(FormatTypeName(TimestampType()), "google.protobuf.Timestamp");
}

TEST(FormatTypeNameTest, Wrappers) {
  EXPECT_EQ(FormatTypeName(IntWrapperType()), "wrapper(int)");
  EXPECT_EQ(FormatTypeName(UintWrapperType()), "wrapper(uint)");
  EXPECT_EQ(FormatTypeName(DoubleWrapperType()), "wrapper(double)");
  EXPECT_EQ(FormatTypeName(StringWrapperType()), "wrapper(string)");
  EXPECT_EQ(FormatTypeName(BytesWrapperType()), "wrapper(bytes)");
  EXPECT_EQ(FormatTypeName(BoolWrapperType()), "wrapper(bool)");
}

// Message types print their full name; enum types print as `int`.
TEST(FormatTypeNameTest, ProtobufTypes) {
  EXPECT_EQ(FormatTypeName(MessageType(TestAllTypes::descriptor())),
            "cel.expr.conformance.proto2.TestAllTypes");
  EXPECT_EQ(FormatTypeName(EnumType(GlobalEnum_descriptor())), "int");
}

TEST(FormatTypeNameTest, Type) {
  google::protobuf::Arena arena;
  EXPECT_EQ(FormatTypeName(TypeType()), "type");
  EXPECT_EQ(FormatTypeName(TypeType(&arena, IntType())), "type(int)");
  EXPECT_EQ(FormatTypeName(TypeType(&arena, TypeType(&arena, IntType()))),
            "type(type(int))");
  EXPECT_EQ(FormatTypeName(TypeType(&arena, TypeParamType("T"))), "type(T)");
}

TEST(FormatTypeNameTest, List) {
  google::protobuf::Arena arena;
  EXPECT_EQ(FormatTypeName(ListType()), "list(dyn)");
  EXPECT_EQ(FormatTypeName(ListType(&arena, IntType())), "list(int)");
  EXPECT_EQ(FormatTypeName(ListType(&arena, ListType(&arena, IntType()))),
            "list(list(int))");
}

TEST(FormatTypeNameTest, Map) {
  google::protobuf::Arena arena;
  EXPECT_EQ(FormatTypeName(MapType()), "map(dyn, dyn)");
  EXPECT_EQ(FormatTypeName(MapType(&arena, IntType(), IntType())),
            "map(int, int)");
  EXPECT_EQ(FormatTypeName(MapType(&arena, IntType(),
                                   MapType(&arena, IntType(), IntType()))),
            "map(int, map(int, int))");
}

// Covers a parameterless opaque type (no parentheses) and multi-parameter
// opaque types nested inside one another.
TEST(FormatTypeNameTest, Opaque) {
  google::protobuf::Arena arena;
  EXPECT_EQ(FormatTypeName(OpaqueType(&arena, "opaque", {})), "opaque");

  Type two_tuple_type = OpaqueType(&arena, "tuple", {IntType(), IntType()});
  Type three_tuple_type = OpaqueType(
      &arena, "tuple", {two_tuple_type, two_tuple_type, two_tuple_type});

  EXPECT_EQ(FormatTypeName(three_tuple_type),
            "tuple(tuple(int, int), tuple(int, int), tuple(int, int))");
}

// NOTE(review): the reason for excluding Apple platforms is not visible
// here; confirm before changing the guard.
#ifndef __APPLE__
// Builds a type nested 1000 levels deep to exercise the non-recursive
// formatter.
TEST(FormatTypeNameTest, ArbitraryNesting) {
  google::protobuf::Arena arena;
  Type type = IntType();
  for (int i = 0; i < 1000; ++i) {
    type = OpaqueType(&arena, "ptype", {type});
  }

  EXPECT_THAT(FormatTypeName(type),
              MatchesRegex(R"(^(ptype\(){1000}int(\)){1000})"));
}
#endif

}  // namespace
}  // namespace cel::checker_internal

================================================
FILE: checker/internal/namespace_generator.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "checker/internal/namespace_generator.h"

// NOTE(review): in this copy the bracketed system-header names and template
// arguments appear to have been stripped; restore them from upstream.
#include 
#include 
#include 
#include 

#include "absl/functional/function_ref.h"
#include "absl/log/absl_check.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "common/container.h"
#include "internal/lexis.h"

namespace cel::checker_internal {
namespace {

// Emits every split of `partly_qualified_name` into a qualified name followed
// by field selections, longest qualified name first.
//
// Each candidate is the first `count` segments joined with '.', with a
// leading '.' stripped. When `prefix` is non-empty the candidate is
// "<prefix>.<joined>".
//
// The int passed to `callback` is the index, in the caller's full name, of
// the last segment folded into the candidate. When `prefix_is_alias` is
// true, the alias stands in for segment 0 (the caller passed the name
// without it), so indices are shifted by one. After all splits, the alias
// itself is reported with index 0.
//
// Returns false as soon as `callback` returns false; otherwise true (or the
// result of the final alias callback).
bool FieldSelectInterpretationCandidatesImpl(
    absl::string_view prefix,
    absl::Span partly_qualified_name,
    bool prefix_is_alias,
    absl::FunctionRef callback) {
  for (int i = 0; i < partly_qualified_name.size(); ++i) {
    // Backing storage for `candidate` when a prefix is prepended.
    std::string buf;
    int count = partly_qualified_name.size() - i;
    auto end_idx = count - (prefix_is_alias ? 0 : 1);
    auto ident = absl::StrJoin(partly_qualified_name.subspan(0, count), ".");
    absl::string_view candidate = ident;
    // A root-relative first segment (".foo") carries its dot into the join.
    if (absl::StartsWith(candidate, ".")) {
      candidate = candidate.substr(1);
    }
    if (!prefix.empty()) {
      buf = absl::StrCat(prefix, ".", candidate);
      candidate = buf;
    }
    if (!callback(candidate, end_idx)) {
      return false;
    }
  }
  if (prefix_is_alias) {
    return callback(prefix, 0);
  }
  return true;
}

// Non-alias form: `prefix` is a container namespace (or empty).
bool FieldSelectInterpretationCandidates(
    absl::string_view prefix,
    absl::Span partly_qualified_name,
    absl::FunctionRef callback) {
  return FieldSelectInterpretationCandidatesImpl(
      prefix, partly_qualified_name, /*prefix_is_alias=*/false, callback);
}

// Alias form: `prefix` is the resolved alias that replaced the first segment.
bool FieldSelectInterpretationCandidatesWithAlias(
    absl::string_view prefix,
    absl::Span partly_qualified_name,
    absl::FunctionRef callback) {
  return FieldSelectInterpretationCandidatesImpl(
      prefix, partly_qualified_name, /*prefix_is_alias=*/true, callback);
}

}  // namespace

// Builds the cumulative container prefixes ("a", "a.b", "a.b.c") and
// reverses them so the most qualified prefix is tried first.
absl::StatusOr NamespaceGenerator::Create(
    const ExpressionContainer& expression_container) {
  std::vector candidates;
  absl::string_view container = expression_container.container();
  if (container.empty()) {
    return NamespaceGenerator(&expression_container, std::move(candidates));
  }

  std::string prefix;
  for (auto segment : absl::StrSplit(container, '.')) {
    // Assumes the ExpressionContainer has already validated the container
    // and aliases.
    ABSL_DCHECK(internal::LexisIsIdentifier(segment));
    if (prefix.empty()) {
      prefix = segment;
    } else {
      absl::StrAppend(&prefix, ".", segment);
    }
    candidates.push_back(prefix);
  }
  std::reverse(candidates.begin(), candidates.end());
  return NamespaceGenerator(&expression_container, std::move(candidates));
}

void NamespaceGenerator::GenerateCandidates(
    absl::string_view simple_name,
    absl::FunctionRef callback) const {
  // Special case for root-relative names. Aliases still apply first.
  bool is_root_relative = absl::StartsWith(simple_name, ".");
  if (is_root_relative) {
    simple_name = simple_name.substr(1);
  }

  // The name is unqualified, but may include a namespace (struct creation).
  // This is just a quirk of the parser.
  if (auto dot_pos = simple_name.find('.'); dot_pos != absl::string_view::npos) {
    absl::string_view first_segment = simple_name.substr(0, dot_pos);
    absl::string_view rest = simple_name.substr(dot_pos + 1);
    if (auto resolved_alias = expression_container_->FindAlias(first_segment);
        !resolved_alias.empty()) {
      // An alias match is the only candidate produced.
      callback(absl::StrCat(resolved_alias, ".", rest));
      return;
    }
  } else {
    if (auto resolved_alias = expression_container_->FindAlias(simple_name);
        !resolved_alias.empty()) {
      callback(resolved_alias);
      return;
    }
  }

  // Root-relative names skip the container prefixes entirely.
  if (is_root_relative) {
    callback(simple_name);
    return;
  }

  // Most-qualified prefix first, then the bare name.
  for (const auto& prefix : candidates_) {
    std::string candidate = absl::StrCat(prefix, ".", simple_name);
    if (!callback(candidate)) {
      return;
    }
  }
  callback(simple_name);
}

void NamespaceGenerator::GenerateCandidates(
    absl::Span partly_qualified_name,
    absl::FunctionRef callback) const {
  if (partly_qualified_name.empty()) {
    return;
  }
  // Special case for root-relative names. Aliases still apply first.
  absl::string_view first_segment = partly_qualified_name[0];
  bool is_root_relative = absl::StartsWith(first_segment, ".");
  if (is_root_relative) {
    first_segment = first_segment.substr(1);
  }
  if (auto resolved_alias = expression_container_->FindAlias(first_segment);
      !resolved_alias.empty()) {
    FieldSelectInterpretationCandidatesWithAlias(
        resolved_alias, partly_qualified_name.subspan(1), callback);
    // If the alias matches, we don't check the container even if name
    // resolution fails.
    return;
  }

  if (is_root_relative) {
    FieldSelectInterpretationCandidates("", partly_qualified_name, callback);
    return;
  }

  // Each container prefix, most qualified first; stop if the callback asks.
  for (const auto& prefix : candidates_) {
    if (!FieldSelectInterpretationCandidates(prefix, partly_qualified_name,
                                             callback)) {
      return;
    }
  }
  FieldSelectInterpretationCandidates("", partly_qualified_name, callback);
}

}  // namespace cel::checker_internal

================================================
FILE: checker/internal/namespace_generator.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_NAMESPACE_GENERATOR_H_
#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_NAMESPACE_GENERATOR_H_

#include 
#include 
#include 

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/functional/function_ref.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "common/container.h"

namespace cel::checker_internal {

// Utility class for generating namespace qualified candidates for reference
// resolution.
//
// This class is expected to be scoped to a single type checking operation and
// borrows the ExpressionContainer from the TypeCheckEnv.
class NamespaceGenerator {
 public:
  // Precomputes the container's namespace prefixes. The generator keeps a
  // pointer to `expression_container`, which must outlive it (and any copies).
  static absl::StatusOr Create(
      const ExpressionContainer& expression_container
          ABSL_ATTRIBUTE_LIFETIME_BOUND);

  // Copyable and movable.
  NamespaceGenerator(const NamespaceGenerator&) = default;
  NamespaceGenerator& operator=(const NamespaceGenerator&) = default;
  NamespaceGenerator(NamespaceGenerator&&) = default;
  NamespaceGenerator& operator=(NamespaceGenerator&&) = default;

  // For the simple case of an unqualified name, generate all qualified
  // candidates and pass them to the provided callback. The callback may return
  // false to terminate early.
  //
  // The supplied string_view is only valid for the duration of the callback
  // invocation: the callback must handle copying the underlying string if the
  // value needs to be persisted.
  //
  // A leading '.' marks the name as root-relative: container prefixes are
  // skipped, but aliases are still applied.
  //
  // Example:
  // For container (com.google)
  // and unqualified name foo
  //
  // com.google.foo, com.foo, foo
  //
  // If aliases are present, they override the normal container resolution.
  //
  // Example:
  // container (com.google)
  // alias (foo = com.example)
  // unqualified name foo
  //
  // com.example
  void GenerateCandidates(
      absl::string_view simple_name,
      absl::FunctionRef callback) const;

  // For a partially qualified name, generate all the qualified candidates in
  // order of resolution precedence and pass them to the provided callback. The
  // callback may return false to terminate early.
  //
  // The supplied string_view is only valid for the duration of the callback
  // invocation: the callback must handle copying the underlying string if the
  // value needs to be persisted.
  //
  // The int passed to the callback is the index of the last segment of
  // `partly_qualified_name` consumed by the candidate. Any later segments are
  // field selections on it.
  //
  // Example:
  // For container (com.google)
  // and partially qualified name Foo.bar
  //
  // (com.google.Foo.bar),
  // (com.google.Foo).bar,
  // (com.Foo.bar),
  // (com.Foo).bar,
  // (Foo.bar),
  // (Foo).bar,
  //
  // If aliases are present, they override the normal container resolution.
  //
  // Example:
  // container (com.google)
  // alias (Foo = com.example.Foo)
  // partially qualified name Foo.bar
  //
  // (com.example.Foo.bar),
  // (com.example.Foo).bar,
  void GenerateCandidates(
      absl::Span partly_qualified_name,
      absl::FunctionRef callback) const;

 private:
  explicit NamespaceGenerator(
      const ExpressionContainer* absl_nonnull expression_container,
      std::vector candidates)
      : candidates_(std::move(candidates)),
        expression_container_(expression_container) {}

  // list of prefixes ordered from most qualified to least.
  std::vector candidates_;
  // Borrowed; owned by the caller of Create().
  const ExpressionContainer* absl_nonnull expression_container_;
};

}  // namespace cel::checker_internal

#endif  // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_NAMESPACE_GENERATOR_H_

================================================
FILE: checker/internal/namespace_generator_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "checker/internal/namespace_generator.h"

// NOTE(review): in this copy the bracketed system-header names and template
// arguments appear to have been stripped; restore them from upstream.
#include 
#include 
#include 

#include "absl/strings/string_view.h"
#include "common/container.h"
#include "internal/testing.h"

namespace cel::checker_internal {
namespace {

using ::absl_testing::IsOk;
using ::testing::ElementsAre;
using ::testing::Pair;

// With no container, only the bare name is generated.
TEST(NamespaceGeneratorTest, EmptyContainer) {
  ExpressionContainer container;
  ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container));
  std::vector candidates;
  generator.GenerateCandidates("foo", [&](absl::string_view candidate) {
    candidates.push_back(std::string(candidate));
    return true;
  });
  EXPECT_THAT(candidates, ElementsAre("foo"));
}

// Candidates run from the most to the least qualified prefix.
TEST(NamespaceGeneratorTest, MultipleSegments) {
  ExpressionContainer container;
  ASSERT_THAT(container.SetContainer("com.example"), IsOk());
  ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container));
  std::vector candidates;
  generator.GenerateCandidates("foo", [&](absl::string_view candidate) {
    candidates.push_back(std::string(candidate));
    return true;
  });
  EXPECT_THAT(candidates, ElementsAre("com.example.foo", "com.foo", "foo"));
}

// A leading '.' bypasses the container prefixes.
TEST(NamespaceGeneratorTest, MultipleSegmentsRootNamespace) {
  ExpressionContainer container;
  ASSERT_THAT(container.SetContainer("com.example"), IsOk());
  ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container));
  std::vector candidates;
  generator.GenerateCandidates(".foo", [&](absl::string_view candidate) {
    candidates.push_back(std::string(candidate));
    return true;
  });
  EXPECT_THAT(candidates, ElementsAre("foo"));
}

// Each prefix yields every qualified-name / field-select split, paired with
// the index of the last segment consumed by the name.
TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretation) {
  ExpressionContainer container;
  ASSERT_THAT(container.SetContainer("com.example"), IsOk());
  ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container));
  std::vector qualified_ident = {"foo", "Bar"};
  std::vector> candidates;
  generator.GenerateCandidates(
      qualified_ident, [&](absl::string_view candidate, int segment_index) {
        candidates.push_back(std::pair(std::string(candidate), segment_index));
        return true;
      });
  EXPECT_THAT(
      candidates,
      ElementsAre(Pair("com.example.foo.Bar", 1), Pair("com.example.foo", 0),
                  Pair("com.foo.Bar", 1), Pair("com.foo", 0),
                  Pair("foo.Bar", 1), Pair("foo", 0)));
}

// An alias on the first segment replaces container resolution entirely.
TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretationAliasMatch) {
  ExpressionContainer container;
  ASSERT_THAT(container.SetContainer("com.example"), IsOk());
  ASSERT_THAT(container.AddAlias("foo", "bar.baz"), IsOk());
  ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container));
  std::vector qualified_ident = {"foo", "Bar"};
  std::vector> candidates;
  generator.GenerateCandidates(
      qualified_ident, [&](absl::string_view candidate, int segment_index) {
        candidates.push_back(std::pair(std::string(candidate), segment_index));
        return true;
      });
  EXPECT_THAT(candidates,
              ElementsAre(Pair("bar.baz.Bar", 1), Pair("bar.baz", 0)));
}

TEST(NamespaceGeneratorTest,
     MultipleSegmentsSelectInterpretationAliasNoMatch) {
  ExpressionContainer container;
  ASSERT_THAT(container.SetContainer("com.example"), IsOk());
  ASSERT_THAT(container.AddAbbreviation("foo.Bar"), IsOk());
  ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container));
  // No match on the alias (Bar) since it's not the first segment.
  std::vector qualified_ident = {"foo", "Bar"};
  std::vector> candidates;
  generator.GenerateCandidates(
      qualified_ident, [&](absl::string_view candidate, int segment_index) {
        candidates.push_back(std::pair(std::string(candidate), segment_index));
        return true;
      });
  EXPECT_THAT(
      candidates,
      ElementsAre(Pair("com.example.foo.Bar", 1), Pair("com.example.foo", 0),
                  Pair("com.foo.Bar", 1), Pair("com.foo", 0),
                  Pair("foo.Bar", 1), Pair("foo", 0)));
}

// Root-relative qualified names skip the container prefixes.
TEST(NamespaceGeneratorTest,
     MultipleSegmentsSelectInterpretationRootNamespace) {
  ExpressionContainer container;
  ASSERT_THAT(container.SetContainer("com.example"), IsOk());
  ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container));
  std::vector qualified_ident = {".foo", "Bar"};
  std::vector> candidates;
  generator.GenerateCandidates(
      qualified_ident, [&](absl::string_view candidate, int segment_index) {
        candidates.push_back(std::pair(std::string(candidate), segment_index));
        return true;
      });
  EXPECT_THAT(candidates, ElementsAre(Pair("foo.Bar", 1), Pair("foo", 0)));
}

}  // namespace
}  // namespace cel::checker_internal

================================================
FILE: checker/internal/test_ast_helpers.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "checker/internal/test_ast_helpers.h"

// NOTE(review): in this copy the bracketed system-header names and template
// arguments appear to have been stripped; restore them from upstream.
#include 

#include "absl/log/absl_check.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "common/ast.h"
#include "internal/status_macros.h"
#include "parser/options.h"
#include "parser/parser.h"
#include "parser/parser_interface.h"

namespace cel::checker_internal {

// Parses `expression` into an AST for tests.
//
// A single parser with optional syntax enabled is built on first use and
// reused. It is released from its owning pointer and never deleted, so it
// stays valid for the rest of the process. Building it is CHECKed. Errors
// from creating the source or from parsing are returned as a status.
absl::StatusOr> MakeTestParsedAst(
    absl::string_view expression) {
  static const cel::Parser* parser = []() {
    cel::ParserOptions options = {.enable_optional_syntax = true};
    auto parser = NewParserBuilder(options)->Build();
    ABSL_CHECK_OK(parser);
    return parser->release();
  }();

  // The expression text doubles as the source description.
  CEL_ASSIGN_OR_RETURN(
      auto source,
      cel::NewSource(expression, /*description=*/std::string(expression)));
  return parser->Parse(*source);
}

}  // namespace cel::checker_internal

================================================
FILE: checker/internal/test_ast_helpers.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TESTING_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TESTING_H_ #include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "common/ast.h" namespace cel::checker_internal { absl::StatusOr> MakeTestParsedAst( absl::string_view expression); } // namespace cel::checker_internal #endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TESTING_H_ ================================================ FILE: checker/internal/test_ast_helpers_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "checker/internal/test_ast_helpers.h"

// NOTE(review): in this copy the bracketed system-header names and template
// arguments appear to have been stripped; restore them from upstream.
#include 

#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "common/ast.h"
#include "internal/testing.h"

namespace cel::checker_internal {
namespace {

using ::absl_testing::StatusIs;

// A literal expression parses to a constant root expression.
TEST(MakeTestParsedAstTest, Works) {
  ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, MakeTestParsedAst("123"));
  EXPECT_TRUE(ast->root_expr().has_const_expr());
}

// Syntax errors surface as kInvalidArgument instead of crashing.
TEST(MakeTestParsedAstTest, ForwardsParseError) {
  EXPECT_THAT(MakeTestParsedAst("%123"),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

}  // namespace
}  // namespace cel::checker_internal

================================================
FILE: checker/internal/type_check_env.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "checker/internal/type_check_env.h"

// NOTE(review): in this copy the bracketed system-header names and template
// arguments appear to have been stripped; restore them from upstream.
#include 
#include 

#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/constant.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/type_introspector.h"
#include "internal/status_macros.h"
#include "google/protobuf/arena.h"

namespace cel::checker_internal {

// Returns the declaration of variable `name` in this environment, or nullptr.
// The pointer refers into `variables_`.
const VariableDecl* absl_nullable TypeCheckEnv::LookupVariable(
    absl::string_view name) const {
  if (auto it = variables_.find(name); it != variables_.end()) {
    return &it->second;
  }
  return nullptr;
}

// Returns the declaration of function `name`, or nullptr if none exists.
const FunctionDecl* absl_nullable TypeCheckEnv::LookupFunction(
    absl::string_view name) const {
  if (auto it = functions_.find(name); it != functions_.end()) {
    return &it->second;
  }
  return nullptr;
}

// Asks each type provider in order and returns the first match, or nullopt.
// A provider error is returned immediately via CEL_ASSIGN_OR_RETURN.
absl::StatusOr> TypeCheckEnv::LookupTypeName(
    absl::string_view name) const {
  for (auto iter = type_providers_.begin(); iter != type_providers_.end();
       ++iter) {
    CEL_ASSIGN_OR_RETURN(auto type, (*iter)->FindType(name));
    if (type.has_value()) {
      return type;
    }
  }
  return absl::nullopt;
}

// Resolves enum value `value` of enum `type` to a variable declaration named
// "<type_full_name>.<value_name>", with the enum's number as its constant
// value. The first provider with a match wins.
absl::StatusOr> TypeCheckEnv::LookupEnumConstant(
    absl::string_view type, absl::string_view value) const {
  for (auto iter = type_providers_.begin(); iter != type_providers_.end();
       ++iter) {
    CEL_ASSIGN_OR_RETURN(auto enum_constant,
                         (*iter)->FindEnumConstant(type, value));
    if (enum_constant.has_value()) {
      auto decl = MakeVariableDecl(absl::StrCat(enum_constant->type_full_name,
                                                ".", enum_constant->value_name),
                                   enum_constant->type);
      decl.set_value(Constant(static_cast(enum_constant->number)));
      return decl;
    }
  }
  return absl::nullopt;
}

// Resolves `name` as a type identifier, yielding a declaration whose type is
// `TypeType(arena, <type>)`. If `name` is not a type, it is retried as an
// enum constant by splitting it at the last '.' into an enum name and a
// value name.
absl::StatusOr> TypeCheckEnv::LookupTypeConstant(
    google::protobuf::Arena* absl_nonnull arena, absl::string_view name) const {
  CEL_ASSIGN_OR_RETURN(absl::optional type, LookupTypeName(name));
  if (type.has_value()) {
    return MakeVariableDecl(type->name(), TypeType(arena, *type));
  }

  if (name.find('.') != name.npos) {
    size_t last_dot = name.rfind('.');
    absl::string_view enum_name_candidate = name.substr(0, last_dot);
    absl::string_view value_name_candidate = name.substr(last_dot + 1);
    return LookupEnumConstant(enum_name_candidate, value_name_candidate);
  }

  return absl::nullopt;
}

absl::StatusOr> TypeCheckEnv::LookupStructField(
    absl::string_view type_name, absl::string_view field_name) const {
  // Check the type providers in registration order.
  // Note: this doesn't allow for shadowing a type with a subset type of the
  // same name -- the later type provider will still be considered when
  // checking field accesses.
  for (auto iter = type_providers_.begin(); iter != type_providers_.end();
       ++iter) {
    CEL_ASSIGN_OR_RETURN(
        auto field, (*iter)->FindStructTypeFieldByName(type_name, field_name));
    if (field.has_value()) {
      return field;
    }
  }
  return absl::nullopt;
}

// Searches this scope, then each enclosing scope outward, returning the
// innermost declaration of `name` or nullptr.
const VariableDecl* absl_nullable VariableScope::LookupLocalVariable(
    absl::string_view name) const {
  const VariableScope* scope = this;
  while (scope != nullptr) {
    if (auto it = scope->variables_.find(name); it != scope->variables_.end()) {
      return &it->second;
    }
    scope = scope->parent_;
  }
  return nullptr;
}

}  // namespace cel::checker_internal

================================================
FILE: checker/internal/type_check_env.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECK_ENV_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECK_ENV_H_ #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "checker/internal/descriptor_pool_type_introspector.h" #include "common/constant.h" #include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel::checker_internal { class TypeCheckEnv; // Helper class for managing nested scopes and the local variables they // implicitly declare. // // Nested scopes have a lifetime dependency on any parent scopes and should // generally be managed by unique_ptrs. class VariableScope { public: explicit VariableScope() : parent_(nullptr) {} VariableScope(const VariableScope&) = delete; VariableScope& operator=(const VariableScope&) = delete; VariableScope(VariableScope&&) = default; VariableScope& operator=(VariableScope&&) = default; bool InsertVariableIfAbsent(VariableDecl decl) { return variables_.insert({decl.name(), std::move(decl)}).second; } std::unique_ptr MakeNestedScope() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return absl::WrapUnique(new VariableScope(this)); } const VariableDecl* absl_nullable LookupLocalVariable( absl::string_view name) const; private: explicit VariableScope( const VariableScope* parent ABSL_ATTRIBUTE_LIFETIME_BOUND) : parent_(parent) {} const VariableScope* absl_nullable parent_; absl::flat_hash_map variables_; }; // Class managing the state of the type check environment. // // Maintains lookup maps for variables and functions and the set of type // providers. 
// // This class is thread-compatible. class TypeCheckEnv { private: using VariableDeclPtr = const VariableDecl* absl_nonnull; using FunctionDeclPtr = const FunctionDecl* absl_nonnull; public: explicit TypeCheckEnv( absl_nonnull std::shared_ptr descriptor_pool) : descriptor_pool_(std::move(descriptor_pool)), proto_type_introspector_( std::make_shared( descriptor_pool_.get())) { type_providers_.push_back( std::make_shared()); type_providers_.push_back(proto_type_introspector_); } TypeCheckEnv(const TypeCheckEnv&) = default; TypeCheckEnv& operator=(const TypeCheckEnv&) = default; TypeCheckEnv(TypeCheckEnv&&) = default; TypeCheckEnv& operator=(TypeCheckEnv&&) = default; const ExpressionContainer& container() const { return container_; } void set_container(ExpressionContainer container) { container_ = std::move(container); } const DescriptorPoolTypeIntrospector& proto_type_introspector() const { return *proto_type_introspector_; } DescriptorPoolTypeIntrospector& proto_type_introspector() { return *proto_type_introspector_; } void set_expected_type(const Type& type) { expected_type_ = std::move(type); } const absl::optional& expected_type() const { return expected_type_; } absl::Span> type_providers() const { return type_providers_; } void AddTypeProvider(std::unique_ptr provider) { type_providers_.push_back(std::move(provider)); } void AddTypeProvider(std::shared_ptr provider) { type_providers_.push_back(std::move(provider)); } const absl::flat_hash_map& variables() const { return variables_; } // Inserts a variable declaration into the environment of the current scope if // is is not already present. Parent scopes are not searched. // // Returns true if the variable was inserted, false otherwise. bool InsertVariableIfAbsent(VariableDecl decl) { return variables_.insert({decl.name(), std::move(decl)}).second; } // Inserts a variable declaration into the environment of the current scope. // Parent scopes are not searched. 
void InsertOrReplaceVariable(VariableDecl decl) { variables_[decl.name()] = std::move(decl); } const absl::flat_hash_map& functions() const { return functions_; } // Inserts a function declaration into the environment of the current scope if // is is not already present. Parent scopes are not searched (allowing for // shadowing). // // Returns true if the decl was inserted, false otherwise. bool InsertFunctionIfAbsent(FunctionDecl decl) { return functions_.insert({decl.name(), std::move(decl)}).second; } void InsertOrReplaceFunction(FunctionDecl decl) { functions_[decl.name()] = std::move(decl); } // Returns the declaration for the given name if it is found in the current // or any parent scope. // Note: the returned declaration ptr is only valid as long as no changes are // made to the environment. const VariableDecl* absl_nullable LookupVariable( absl::string_view name) const; const FunctionDecl* absl_nullable LookupFunction( absl::string_view name) const; absl::StatusOr> LookupTypeName( absl::string_view name) const; absl::StatusOr> LookupStructField( absl::string_view type_name, absl::string_view field_name) const; absl::StatusOr> LookupTypeConstant( google::protobuf::Arena* absl_nonnull arena, absl::string_view type_name) const; const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const { return descriptor_pool_.get(); } // Used to keep an arena alive if one was needed to allocate types. // // Expected to be called exactly once if at all. void set_arena(std::shared_ptr arena) { ABSL_DCHECK(arena_ == nullptr || arena == arena_); arena_ = std::move(arena); } // Returns the arena if one was set, nullptr otherwise. std::shared_ptr arena() const { return arena_; } private: absl::StatusOr> LookupEnumConstant( absl::string_view type, absl::string_view value) const; absl_nonnull std::shared_ptr descriptor_pool_; // If set, an arena was needed to allocate types in the environment. 
// // The TypeCheckEnv does not otherwise use the arena, though it may be used by // derived TypeCheckerBuilders. absl_nullable std::shared_ptr arena_; ExpressionContainer container_; // Used to resolve fields on message types. std::shared_ptr proto_type_introspector_; // Maps fully qualified names to declarations. absl::flat_hash_map variables_; absl::flat_hash_map functions_; // Type providers for custom types. std::vector> type_providers_; absl::optional expected_type_; }; } // namespace cel::checker_internal #endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECK_ENV_H_ ================================================ FILE: checker/internal/type_checker_builder_impl.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "checker/internal/type_checker_builder_impl.h"

#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/no_destructor.h"
#include "absl/base/nullability.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "checker/internal/type_check_env.h"
#include "checker/internal/type_checker_impl.h"
#include "checker/type_checker.h"
#include "checker/type_checker_builder.h"
#include "common/container.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/type_introspector.h"
#include "common/type_kind.h"
#include "internal/lexis.h"
#include "internal/status_macros.h"
#include "parser/macro.h"
#include "google/protobuf/descriptor.h"

namespace cel::checker_internal {
namespace {

// Standard macros keyed by the function name they are invoked under. Used by
// CheckStdMacroOverlap to reject function overloads whose receiver style and
// arity coincide with one of these macros.
const absl::flat_hash_map<std::string, std::vector<Macro>>& GetStdMacros() {
  static const absl::NoDestructor<
      absl::flat_hash_map<std::string, std::vector<Macro>>>
      kStdMacros({
          {"has", {HasMacro()}},
          {"all", {AllMacro()}},
          {"exists", {ExistsMacro()}},
          {"exists_one", {ExistsOneMacro()}},
          {"filter", {FilterMacro()}},
          {"map", {Map2Macro(), Map3Macro()}},
          {"optMap", {OptMapMacro()}},
          {"optFlatMap", {OptFlatMapMacro()}},
      });
  return *kStdMacros;
}

// Fails with InvalidArgument if any overload of `decl` has the same name,
// receiver style, and argument count as a standard macro.
absl::Status CheckStdMacroOverlap(const FunctionDecl& decl) {
  const auto& std_macros = GetStdMacros();
  auto it = std_macros.find(decl.name());
  if (it == std_macros.end()) {
    return absl::OkStatus();
  }
  for (const auto& macro : it->second) {
    const bool macro_member = macro.is_receiver_style();
    // The receiver of a receiver-style macro is counted as an argument so the
    // total can be compared against ovl.args().size() below.
    const size_t macro_arg_count =
        macro.argument_count() + (macro_member ? 1 : 0);
    for (const auto& ovl : decl.overloads()) {
      if (ovl.member() == macro_member &&
          ovl.args().size() == macro_arg_count) {
        return absl::InvalidArgumentError(absl::StrCat(
            "overload for name '", macro.function(), "' with ",
            macro_arg_count, " argument(s) overlaps with predefined macro"));
      }
    }
  }
  return absl::OkStatus();
}

// Declares one context variable `name` of `type` in `env`. Enum-typed fields
// are exposed as int. Fails with AlreadyExists if `name` is already declared;
// `descriptor` is only used to name the context message in that error.
absl::Status InsertContextVariable(
    absl::string_view name, Type type,
    const google::protobuf::Descriptor* absl_nonnull descriptor,
    TypeCheckEnv& env) {
  if (type.IsEnum()) {
    type = IntType();
  }
  if (!env.InsertVariableIfAbsent(MakeVariableDecl(name, type))) {
    return absl::AlreadyExistsError(
        absl::StrCat("variable '", name,
                     "' declared multiple times (from context declaration: '",
                     descriptor->full_name(), "')"));
  }
  return absl::OkStatus();
}

// Declares one variable per field of a well-known message type, reading the
// fields directly from the descriptor. Uses the JSON field name when
// `use_json_name` is set.
absl::Status AddWellKnownContextDeclarationVariables(
    const google::protobuf::Descriptor* absl_nonnull descriptor,
    TypeCheckEnv& env, bool use_json_name) {
  for (int i = 0; i < descriptor->field_count(); ++i) {
    const google::protobuf::FieldDescriptor* field = descriptor->field(i);
    absl::string_view name = field->name();
    if (use_json_name) {
      name = field->json_name();
    }
    CEL_RETURN_IF_ERROR(InsertContextVariable(
        name, MessageTypeField(field).GetType(), descriptor, env));
  }
  return absl::OkStatus();
}

// Declares one variable per field of the context message `descriptor`.
// Non-well-known types are enumerated through the env's proto type
// introspector so field naming follows its configuration.
absl::Status AddContextDeclarationVariables(
    const google::protobuf::Descriptor* absl_nonnull descriptor,
    TypeCheckEnv& env) {
  const bool use_json_name = env.proto_type_introspector().use_json_name();
  if (IsWellKnownMessageType(descriptor)) {
    return AddWellKnownContextDeclarationVariables(descriptor, env,
                                                   use_json_name);
  }

  CEL_ASSIGN_OR_RETURN(auto fields,
                       env.proto_type_introspector().ListFieldsForStructType(
                           descriptor->full_name()));
  if (!fields.has_value()) {
    return absl::InternalError(absl::StrCat("context declaration '",
                                            descriptor->full_name(),
                                            "' not found, but was expected"));
  }
  for (const auto& field_entry : *fields) {
    CEL_RETURN_IF_ERROR(InsertContextVariable(
        field_entry.name, field_entry.field.GetType(), descriptor, env));
  }
  return absl::OkStatus();
}

// Returns `existing_decl` extended with every overload of `new_decl`. Fails if
// the names differ or if adding any overload fails.
absl::StatusOr<FunctionDecl> MergeFunctionDecls(
    const FunctionDecl& existing_decl, const FunctionDecl& new_decl) {
  if (existing_decl.name() != new_decl.name()) {
    return absl::InternalError(
        "Attempted to merge function decls with different names");
  }

  FunctionDecl merged_decl = existing_decl;
  for (const auto& ovl : new_decl.overloads()) {
    // We do not tolerate signature collisions, even if they are exact matches.
    CEL_RETURN_IF_ERROR(merged_decl.AddOverload(ovl));
  }

  return merged_decl;
}

// Keeps only the overloads of `decl` that `subset` includes. Returns nullopt
// when no overload survives.
absl::optional<FunctionDecl> FilterDecl(FunctionDecl decl,
                                        const TypeCheckerSubset& subset) {
  FunctionDecl filtered;
  std::string name = decl.release_name();
  auto overloads = decl.release_overloads();
  // Non-const binding so each kept overload is moved rather than copied.
  for (auto& ovl : overloads) {
    if (!subset.should_include_overload(name, ovl.id())) {
      continue;
    }
    absl::Status s = filtered.AddOverload(std::move(ovl));
    if (!s.ok()) {
      // Should not be possible to construct the original decl in a way that
      // would cause this to fail.
      ABSL_LOG(DFATAL) << "failed to add overload to filtered decl: " << s;
    }
  }

  if (filtered.overloads().empty()) {
    return absl::nullopt;
  }

  filtered.set_name(std::move(name));
  return filtered;
}

// Recursively validates `t`: enforces the nesting limit (`remaining_depth`
// counts down from `depth_limit`), optionally checks that type parameter names
// are identifiers, and rejects empty message types.
absl::Status ValidateType(const Type& t, bool check_type_param_name,
                          int depth_limit, int remaining_depth) {
  if (remaining_depth-- <= 0) {
    return absl::InvalidArgumentError(
        absl::StrCat("type nesting limit of ", depth_limit, " exceeded"));
  }
  switch (t.kind()) {
    case TypeKind::kTypeParam: {
      if (!check_type_param_name) {
        return absl::OkStatus();
      }
      const TypeParamType& type_param = t.GetTypeParam();
      if (!internal::LexisIsIdentifier(type_param.name())) {
        return absl::InvalidArgumentError(
            absl::StrCat("type parameter name '", type_param.name(),
                         "' is not a valid identifier"));
      }
      return absl::OkStatus();
    }
    case TypeKind::kList: {
      Type element_type = t.AsList()->GetElement();
      return ValidateType(element_type, check_type_param_name, depth_limit,
                          remaining_depth);
    }
    case TypeKind::kMap: {
      Type key_type = t.AsMap()->GetKey();
      Type value_type = t.AsMap()->GetValue();
      CEL_RETURN_IF_ERROR(ValidateType(key_type, check_type_param_name,
                                       depth_limit, remaining_depth));
      return ValidateType(value_type, check_type_param_name, depth_limit,
                          remaining_depth);
    }
    case TypeKind::kStruct: {
      auto message_type = t.AsMessage();
      if (message_type.has_value() && !static_cast<bool>(*message_type)) {
        return absl::InvalidArgumentError(
            "an empty message type cannot be used in a type declaration");
      }
      return absl::OkStatus();
    }
    case TypeKind::kOpaque: {
      for (const Type& type_param : t.AsOpaque()->GetParameters()) {
        CEL_RETURN_IF_ERROR(ValidateType(type_param, check_type_param_name,
                                         depth_limit, remaining_depth));
      }
      return absl::OkStatus();
    }
    case TypeKind::kType: {
      for (const Type& type_param : t.AsType()->GetParameters()) {
        CEL_RETURN_IF_ERROR(ValidateType(type_param, check_type_param_name,
                                         depth_limit, remaining_depth));
      }
      return absl::OkStatus();
    }
    default:
      break;
  }

  return absl::OkStatus();
}

// Validates a function declaration: no overlap with standard macros, and every
// overload's result and argument types pass ValidateType.
absl::Status ValidateFunctionDecl(const FunctionDecl& decl,
                                  bool check_type_param_name, int depth_limit) {
  CEL_RETURN_IF_ERROR(CheckStdMacroOverlap(decl));
  for (const auto& ovl : decl.overloads()) {
    CEL_RETURN_IF_ERROR(ValidateType(ovl.result(), check_type_param_name,
                                     depth_limit, depth_limit));
    for (const auto& arg : ovl.args()) {
      CEL_RETURN_IF_ERROR(
          ValidateType(arg, check_type_param_name, depth_limit, depth_limit));
    }
  }
  return absl::OkStatus();
}

// Validates the declared type of a variable with ValidateType.
absl::Status ValidateVariableDecl(const VariableDecl& decl,
                                  bool check_type_param_name, int depth_limit) {
  return ValidateType(decl.type(), check_type_param_name, depth_limit,
                      depth_limit);
}

}  // namespace

// Runs `library.configure` with `config` as the recording target; the target
// is reset to default_config_ on every exit path.
absl::Status TypeCheckerBuilderImpl::BuildLibraryConfig(
    const CheckerLibrary& library,
    TypeCheckerBuilderImpl::ConfigRecord* config) {
  target_config_ = config;
  absl::Cleanup reset([this] { target_config_ = &default_config_; });

  return library.configure(*this);
}

// Replays `config` into `env`, in this order: type providers, functions
// (filtered through `subset` when non-null), context-declaration variables,
// then explicitly added variables. `config` is owned, so its contents are
// moved into `env`.
absl::Status TypeCheckerBuilderImpl::ApplyConfig(
    TypeCheckerBuilderImpl::ConfigRecord config,
    const TypeCheckerSubset* subset, TypeCheckEnv& env) {
  using FunctionDeclRecord = TypeCheckerBuilderImpl::FunctionDeclRecord;

  for (auto& type_provider : config.type_providers) {
    env.AddTypeProvider(std::move(type_provider));
  }

  for (FunctionDeclRecord& fn : config.functions) {
    FunctionDecl decl = std::move(fn.decl);
    if (subset != nullptr) {
      absl::optional<FunctionDecl> filtered =
          FilterDecl(std::move(decl), *subset);
      if (!filtered.has_value()) {
        continue;
      }
      decl = std::move(*filtered);
    }

    switch (fn.add_semantic) {
      case AddSemantic::kInsertIfAbsent: {
        // Copy the name first: `decl` is moved into the env below.
        std::string name = decl.name();
        if (!env.InsertFunctionIfAbsent(std::move(decl))) {
          return absl::AlreadyExistsError(
              absl::StrCat("function '", name, "' declared multiple times"));
        }
        break;
      }
      case AddSemantic::kTryMerge: {
        const FunctionDecl* existing_decl = env.LookupFunction(decl.name());
        FunctionDecl to_add = std::move(decl);
        if (existing_decl != nullptr) {
          CEL_ASSIGN_OR_RETURN(to_add,
                               MergeFunctionDecls(*existing_decl, to_add));
        }
        env.InsertOrReplaceFunction(std::move(to_add));
        break;
      }
      default:
        return absl::InternalError(absl::StrCat(
            "unsupported function add semantic: ", fn.add_semantic));
    }
  }

  for (const google::protobuf::Descriptor* context_type : config.context_types) {
    CEL_RETURN_IF_ERROR(AddContextDeclarationVariables(context_type, env));
  }

  for (VariableDeclRecord& var : config.variables) {
    switch (var.add_semantic) {
      case AddSemantic::kInsertIfAbsent: {
        // Copied (not moved) because the error message needs the name.
        if (!env.InsertVariableIfAbsent(var.decl)) {
          return absl::AlreadyExistsError(absl::StrCat(
              "variable '", var.decl.name(), "' declared multiple times"));
        }
        break;
      }
      case AddSemantic::kInsertOrReplace: {
        env.InsertOrReplaceVariable(std::move(var.decl));
        break;
      }
      default:
        return absl::InternalError(absl::StrCat(
            "unsupported variable add semantic: ", var.add_semantic));
    }
  }

  return absl::OkStatus();
}

// Builds a checker from a fresh copy of the template environment, so repeated
// calls do not see each other's configuration.
absl::StatusOr<std::unique_ptr<TypeChecker>> TypeCheckerBuilderImpl::Build() {
  TypeCheckEnv env(template_env_);
  CEL_RETURN_IF_ERROR(ConfigureTypeCheckEnv(env));
  return std::make_unique<TypeCheckerImpl>(std::move(env), options_);
}
absl::Status TypeCheckerBuilderImpl::ConfigureTypeCheckEnv(TypeCheckEnv& env) { if (expression_container_.has_value()) { env.set_container(*expression_container_); } if (expected_type_.has_value()) { env.set_expected_type(*expected_type_); } ConfigRecord anonymous_config; std::vector configs; for (const auto& library : libraries_) { ConfigRecord* config = &anonymous_config; if (!library.id.empty()) { configs.emplace_back(); config = &configs.back(); config->id = library.id; } CEL_RETURN_IF_ERROR(BuildLibraryConfig(library, config)); } env.proto_type_introspector().set_use_json_name( options_.use_json_field_names); for (const ConfigRecord& config : configs) { TypeCheckerSubset* subset = nullptr; if (!config.id.empty()) { auto it = subsets_.find(config.id); if (it != subsets_.end()) { subset = &it->second; } } CEL_RETURN_IF_ERROR(ApplyConfig(std::move(config), subset, env)); } CEL_RETURN_IF_ERROR(ApplyConfig(std::move(anonymous_config), /*subset=*/nullptr, env)); CEL_RETURN_IF_ERROR(ApplyConfig(default_config_, /*subset=*/nullptr, env)); if (type_arena_ != nullptr) { env.set_arena(type_arena_); } return absl::OkStatus(); } absl::Status TypeCheckerBuilderImpl::AddLibrary(CheckerLibrary library) { if (!library.id.empty() && !library_ids_.insert(library.id).second) { return absl::AlreadyExistsError( absl::StrCat("library '", library.id, "' already exists")); } if (!library.configure) { return absl::OkStatus(); } libraries_.push_back(std::move(library)); return absl::OkStatus(); } absl::Status TypeCheckerBuilderImpl::AddLibrarySubset( TypeCheckerSubset subset) { if (subset.library_id.empty()) { return absl::InvalidArgumentError( "library_id must not be empty for subset"); } std::string id = subset.library_id; if (!subsets_.insert({id, std::move(subset)}).second) { return absl::AlreadyExistsError( absl::StrCat("library subset for '", id, "' already exists")); } return absl::OkStatus(); } absl::Status TypeCheckerBuilderImpl::AddVariable(const VariableDecl& decl) { 
CEL_RETURN_IF_ERROR( ValidateVariableDecl(decl, options_.enable_type_parameter_name_validation, options_.max_type_decl_nesting)); target_config_->variables.push_back({decl, AddSemantic::kInsertIfAbsent}); return absl::OkStatus(); } absl::Status TypeCheckerBuilderImpl::AddOrReplaceVariable( const VariableDecl& decl) { CEL_RETURN_IF_ERROR( ValidateVariableDecl(decl, options_.enable_type_parameter_name_validation, options_.max_type_decl_nesting)); target_config_->variables.push_back({decl, AddSemantic::kInsertOrReplace}); return absl::OkStatus(); } absl::Status TypeCheckerBuilderImpl::AddContextDeclaration( absl::string_view type) { const google::protobuf::Descriptor* desc = template_env_.descriptor_pool()->FindMessageTypeByName(type); if (desc == nullptr) { return absl::NotFoundError( absl::StrCat("context declaration '", type, "' not found")); } if (IsWellKnownMessageType(desc) && !options_.allow_well_known_type_context_declarations) { return absl::InvalidArgumentError( absl::StrCat("context declaration '", type, "' is not a struct")); } for (const auto* context_type : target_config_->context_types) { if (context_type->full_name() == desc->full_name()) { return absl::AlreadyExistsError( absl::StrCat("context declaration '", type, "' already exists")); } } target_config_->context_types.push_back(desc); return absl::OkStatus(); } absl::Status TypeCheckerBuilderImpl::AddFunction(const FunctionDecl& decl) { CEL_RETURN_IF_ERROR( ValidateFunctionDecl(decl, options_.enable_type_parameter_name_validation, options_.max_type_decl_nesting)); target_config_->functions.push_back( {std::move(decl), AddSemantic::kInsertIfAbsent}); return absl::OkStatus(); } absl::Status TypeCheckerBuilderImpl::MergeFunction(const FunctionDecl& decl) { CEL_RETURN_IF_ERROR( ValidateFunctionDecl(decl, options_.enable_type_parameter_name_validation, options_.max_type_decl_nesting)); target_config_->functions.push_back( {std::move(decl), AddSemantic::kTryMerge}); return absl::OkStatus(); } void 
TypeCheckerBuilderImpl::AddTypeProvider( std::unique_ptr provider) { target_config_->type_providers.push_back(std::move(provider)); } void TypeCheckerBuilderImpl::set_container(absl::string_view container) { if (!expression_container_.has_value()) { expression_container_.emplace(); } expression_container_->SetContainer(container).IgnoreError(); } void TypeCheckerBuilderImpl::SetExpressionContainer( ExpressionContainer container) { expression_container_ = std::move(container); } void TypeCheckerBuilderImpl::SetExpectedType(const Type& type) { expected_type_ = type; } } // namespace cel::checker_internal ================================================ FILE: checker/internal/type_checker_builder_impl.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECKER_BUILDER_IMPL_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECKER_BUILDER_IMPL_H_ #include #include #include #include #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "checker/checker_options.h" #include "checker/internal/type_check_env.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" #include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel::checker_internal { // Builder for TypeChecker instances. class TypeCheckerBuilderImpl : public TypeCheckerBuilder { public: TypeCheckerBuilderImpl( absl_nonnull std::shared_ptr descriptor_pool, const CheckerOptions& options) : options_(options), target_config_(&default_config_), template_env_(std::move(descriptor_pool)) {} // Constructor for building an extended TypeChecker. explicit TypeCheckerBuilderImpl(const CheckerOptions& options, const TypeCheckEnv& template_env) : options_(options), target_config_(&default_config_), template_env_(template_env) { if (auto arena = template_env_.arena(); arena != nullptr) { type_arena_ = std::move(arena); } } // Move only. 
TypeCheckerBuilderImpl(const TypeCheckerBuilderImpl&) = delete; TypeCheckerBuilderImpl(TypeCheckerBuilderImpl&&) = default; TypeCheckerBuilderImpl& operator=(const TypeCheckerBuilderImpl&) = delete; TypeCheckerBuilderImpl& operator=(TypeCheckerBuilderImpl&&) = default; absl::StatusOr> Build() override; absl::Status AddLibrary(CheckerLibrary library) override; absl::Status AddLibrarySubset(TypeCheckerSubset subset) override; absl::Status AddVariable(const VariableDecl& decl) override; absl::Status AddOrReplaceVariable(const VariableDecl& decl) override; absl::Status AddContextDeclaration(absl::string_view type) override; absl::Status AddFunction(const FunctionDecl& decl) override; absl::Status MergeFunction(const FunctionDecl& decl) override; void SetExpectedType(const Type& type) override; void AddTypeProvider(std::unique_ptr provider) override; void set_container(absl::string_view container) override; void SetExpressionContainer( ExpressionContainer expression_container) override; const CheckerOptions& options() const override { return options_; } google::protobuf::Arena* absl_nonnull arena() override { if (type_arena_ == nullptr) { type_arena_ = std::make_shared(); } return type_arena_.get(); } const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const override { return template_env_.descriptor_pool(); } private: // Sematic for adding a possibly duplicated declaration. enum class AddSemantic { kInsertIfAbsent, kInsertOrReplace, // Attempts to merge with any existing overloads for the same function. // Will fail if any of the IDs or signatures collide. kTryMerge, }; struct VariableDeclRecord { VariableDecl decl; AddSemantic add_semantic; }; struct FunctionDeclRecord { FunctionDecl decl; AddSemantic add_semantic; }; // A record of configuration calls. // Used to replay the configuration in calls to Build(). 
struct ConfigRecord { std::string id = ""; std::vector variables; std::vector functions; std::vector> type_providers; std::vector context_types; }; absl::Status BuildLibraryConfig(const CheckerLibrary& library, ConfigRecord* absl_nonnull config); absl::Status ApplyConfig(ConfigRecord config, const TypeCheckerSubset* subset, TypeCheckEnv& env); absl::Status ConfigureTypeCheckEnv(TypeCheckEnv& env); CheckerOptions options_; // Default target for configuration changes. Used for direct calls to // AddVariable, AddFunction, etc. ConfigRecord default_config_; // Active target for configuration changes. // This is used to track which library the change is made on behalf of. ConfigRecord* absl_nonnull target_config_; TypeCheckEnv template_env_; std::shared_ptr type_arena_; std::vector libraries_; absl::flat_hash_map subsets_; absl::flat_hash_set library_ids_; absl::optional expression_container_; absl::optional expected_type_; }; } // namespace cel::checker_internal #endif // THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_H_ ================================================ FILE: checker/internal/type_checker_builder_impl_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "checker/internal/type_checker_builder_impl.h"

#include
#include
#include

#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "checker/checker_options.h"
#include "checker/internal/test_ast_helpers.h"
#include "checker/type_checker.h"
#include "checker/type_checker_builder.h"
#include "checker/validation_result.h"
#include "common/ast.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/type_introspector.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "google/protobuf/arena.h"

namespace cel::checker_internal {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;

// One parameterized case: an expression naming a field of the context message
// and the AST type the checker is expected to report for it.
struct ContextDeclsTestCase {
  std::string expr;
  TypeSpec expected_type;
};

class ContextDeclsFieldsDefinedTest
    : public testing::TestWithParam {};

// With proto3 TestAllTypes registered as the context declaration, each field
// must be referenceable as a top-level identifier and check to the expected
// (flattened) type.
TEST_P(ContextDeclsFieldsDefinedTest, ContextDeclsFieldsDefined) {
  TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
                                 {});
  ASSERT_THAT(
      builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"),
      IsOk());

  ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker,
                       builder.Build());

  ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(GetParam().expr));

  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       type_checker->Check(std::move(ast)));
  ASSERT_TRUE(result.IsValid());

  EXPECT_EQ(result.GetAst()->GetReturnType(), GetParam().expected_type);
}

// Note the narrowing/widening expectations encoded below, e.g. `single_uint32`
// checks as uint64 and `standalone_enum` checks as int64.
INSTANTIATE_TEST_SUITE_P(
    TestAllTypes, ContextDeclsFieldsDefinedTest,
    testing::Values(
        ContextDeclsTestCase{"single_int64", TypeSpec(PrimitiveType::kInt64)},
        ContextDeclsTestCase{"single_uint32",
                             TypeSpec(PrimitiveType::kUint64)},
        ContextDeclsTestCase{"single_double",
                             TypeSpec(PrimitiveType::kDouble)},
        ContextDeclsTestCase{"single_string",
                             TypeSpec(PrimitiveType::kString)},
        ContextDeclsTestCase{"single_any", TypeSpec(WellKnownTypeSpec::kAny)},
        ContextDeclsTestCase{"single_duration",
                             TypeSpec(WellKnownTypeSpec::kDuration)},
        ContextDeclsTestCase{
            "single_bool_wrapper",
            TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool))},
        ContextDeclsTestCase{
            "list_value",
            TypeSpec(ListTypeSpec(std::make_unique(DynTypeSpec())))},
        ContextDeclsTestCase{
            "standalone_message",
            TypeSpec(MessageTypeSpec(
                "cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))},
        ContextDeclsTestCase{"standalone_enum",
                             TypeSpec(PrimitiveType::kInt64)},
        ContextDeclsTestCase{"repeated_bytes",
                             TypeSpec(ListTypeSpec(std::make_unique(
                                 PrimitiveType::kBytes)))},
        ContextDeclsTestCase{
            "repeated_nested_message",
            TypeSpec(ListTypeSpec(std::make_unique(MessageTypeSpec(
                "cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))))},
        ContextDeclsTestCase{
            "map_int32_timestamp",
            TypeSpec(MapTypeSpec(
                std::make_unique(PrimitiveType::kInt64),
                std::make_unique(WellKnownTypeSpec::kTimestamp)))},
        ContextDeclsTestCase{
            "single_struct",
            TypeSpec(
                MapTypeSpec(std::make_unique(PrimitiveType::kString),
                            std::make_unique(DynTypeSpec())))}));

// Registering the same context declaration twice is rejected immediately.
TEST(ContextDeclsTest, ErrorOnDuplicateContextDeclaration) {
  TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
                                 {});
  ASSERT_THAT(
      builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"),
      IsOk());
  EXPECT_THAT(
      builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"),
      StatusIs(absl::StatusCode::kAlreadyExists,
               "context declaration 'cel.expr.conformance.proto3.TestAllTypes' "
               "already exists"));
}

// A context declaration naming a type unknown to the descriptor pool fails
// with kNotFound.
TEST(ContextDeclsTest, ErrorOnContextDeclarationNotFound) {
  TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
                                 {});
  EXPECT_THAT(
      builder.AddContextDeclaration("com.example.UnknownType"),
      StatusIs(absl::StatusCode::kNotFound,
               "context declaration 'com.example.UnknownType' not found"));
}

// Well-known message types such as google.protobuf.Timestamp are not accepted
// as context declarations.
TEST(ContextDeclsTest, ErrorOnNonStructMessageType) {
  TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
                                 {});
  EXPECT_THAT(
      builder.AddContextDeclaration("google.protobuf.Timestamp"),
      StatusIs(
          absl::StatusCode::kInvalidArgument,
          "context declaration 'google.protobuf.Timestamp' is not a struct"));
}

// A struct type that is only known to a custom TypeIntrospector (not the
// descriptor pool) is reported as not found for context declarations.
TEST(ContextDeclsTest, CustomStructNotSupported) {
  TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
                                 {});
  class MyTypeProvider : public cel::TypeIntrospector {
   public:
    absl::StatusOr> FindTypeImpl(
        absl::string_view name) const override {
      if (name == "com.example.MyStruct") {
        return common_internal::MakeBasicStructType("com.example.MyStruct");
      }
      return absl::nullopt;
    }
  };

  builder.AddTypeProvider(std::make_unique());

  EXPECT_THAT(builder.AddContextDeclaration("com.example.MyStruct"),
              StatusIs(absl::StatusCode::kNotFound,
                       "context declaration 'com.example.MyStruct' not found"));
}

// Two context messages that share a field name conflict; the conflict is only
// detected at Build() time.
TEST(ContextDeclsTest, ErrorOnOverlappingContextDeclaration) {
  TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
                                 {});
  ASSERT_THAT(
      builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"),
      IsOk());
  // We resolve the context declaration variables at the Build() call, so the
  // error surfaces then.
  ASSERT_THAT(
      builder.AddContextDeclaration("cel.expr.conformance.proto2.TestAllTypes"),
      IsOk());

  EXPECT_THAT(
      builder.Build(),
      StatusIs(absl::StatusCode::kAlreadyExists,
               "variable 'single_int32' declared multiple times (from context "
               "declaration: 'cel.expr.conformance.proto2.TestAllTypes')"));
}

// A plain AddVariable that collides with a context-declaration field also
// fails at Build() time.
TEST(ContextDeclsTest, ErrorOnOverlappingVariableDeclaration) {
  TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
                                 {});
  ASSERT_THAT(
      builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"),
      IsOk());
  ASSERT_THAT(builder.AddVariable(MakeVariableDecl("single_int64", IntType())),
              IsOk());

  EXPECT_THAT(builder.Build(),
              StatusIs(absl::StatusCode::kAlreadyExists,
                       "variable 'single_int64' declared multiple times"));
}

// With enable_type_parameter_name_validation off, malformed type parameter
// names are accepted.
TEST(TypeCheckerBuilderImplTest,
     InvalidTypeParamNameVariableValidationDisabled) {
  CheckerOptions options;
  options.enable_type_parameter_name_validation = false;
  TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
                                 options);

  ASSERT_THAT(builder.AddVariable(MakeVariableDecl("x", TypeParamType(""))),
              IsOk());
  ASSERT_THAT(builder.AddOrReplaceVariable(
                  MakeVariableDecl("x", TypeParamType("T% foo"))),
              IsOk());
}

// A default-constructed (empty) MessageType is rejected in a declaration.
TEST(TypeCheckerBuilderImplTest, ErrorOnUnspecifiedMessageType) {
  CheckerOptions options;
  options.enable_type_parameter_name_validation = true;
  TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
                                 options);

  ASSERT_THAT(
      builder.AddVariable(MakeVariableDecl("x", MessageType())),
      StatusIs(absl::StatusCode::kInvalidArgument,
               "an empty message type cannot be used in a type declaration"));
}

// With validation on, both AddVariable and AddOrReplaceVariable reject
// type parameter names that are not valid identifiers.
TEST(TypeCheckerBuilderImplTest, ErrorOnInvalidTypeParamNameVariable) {
  CheckerOptions options;
  options.enable_type_parameter_name_validation = true;
  TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
                                 options);

  ASSERT_THAT(builder.AddVariable(MakeVariableDecl("x", TypeParamType(""))),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "type parameter name '' is not a valid identifier"));
  ASSERT_THAT(
      builder.AddOrReplaceVariable(
          MakeVariableDecl("x", TypeParamType("T% foo"))),
      StatusIs(absl::StatusCode::kInvalidArgument,
               "type parameter name 'T% foo' is not a valid identifier"));
}

// max_type_decl_nesting = 2 allows type(T) but rejects type(type(T)) for
// variables.
TEST(TypeCheckerBuilderImplTest, ErrorOnTooDeepTypeNestingVariable) {
  CheckerOptions options;
  options.max_type_decl_nesting = 2;
  google::protobuf::Arena arena;
  TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
                                 options);

  ASSERT_THAT(builder.AddVariable(
                  MakeVariableDecl("x", TypeType(&arena, TypeParamType("T")))),
              IsOk());
  ASSERT_THAT(
      builder.AddOrReplaceVariable(MakeVariableDecl(
          "x", TypeType(&arena, TypeType(&arena, TypeParamType("T% foo"))))),
      StatusIs(absl::StatusCode::kInvalidArgument,
               "type nesting limit of 2 exceeded"));
}

// Type parameter name validation also applies to function overload
// signatures.
TEST(TypeCheckerBuilderImplTest, ErrorOnInvalidTypeParamNameFunction) {
  CheckerOptions options;
  options.enable_type_parameter_name_validation = true;
  TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
                                 options);

  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(
      auto fn_decl,
      MakeFunctionDecl(
          "type2",
          MakeOverloadDecl("type2", TypeType(&arena, TypeParamType("")),
                           TypeParamType(""))));

  ASSERT_THAT(builder.AddFunction(fn_decl),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "type parameter name '' is not a valid identifier"));
}

// The nesting limit is enforced on overloads merged via MergeFunction as well;
// list(list(int)) exceeds a limit of 2.
TEST(TypeCheckerBuilderImplTest, ErrorOnTooDeepTypeNestingFunction) {
  CheckerOptions options;
  options.max_type_decl_nesting = 2;
  google::protobuf::Arena arena;
  TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
                                 options);
  ASSERT_OK_AND_ASSIGN(
      auto fn_decl,
      MakeFunctionDecl(
          "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType())));
  ASSERT_THAT(builder.AddFunction(fn_decl), IsOk());

  Type list_type = ListType(&arena, ListType(&arena, IntType()));
  ASSERT_OK_AND_ASSIGN(
      fn_decl,
      MakeFunctionDecl("add", MakeOverloadDecl("add_list_list_int", list_type,
                                               list_type, list_type)));
  ASSERT_THAT(builder.MergeFunction(fn_decl),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "type nesting limit of 2 exceeded"));
}

// AddOrReplaceVariable overrides the type of a field contributed by the
// context declaration (int64 -> string here).
TEST(TypeCheckerBuilderImplTest, ReplaceVariable) {
  TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
                                 {});
  ASSERT_THAT(
      builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"),
      IsOk());
  ASSERT_THAT(builder.AddOrReplaceVariable(
                  MakeVariableDecl("single_int64", StringType())),
              IsOk());

  ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker,
                       builder.Build());

  ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("single_int64"));

  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       type_checker->Check(std::move(ast)));
  ASSERT_TRUE(result.IsValid());

  const auto& checked_ast = *result.GetAst();

  EXPECT_EQ(checked_ast.GetReturnType(), TypeSpec(PrimitiveType::kString));
}

// A library allocates a ListType on builder.arena() during configuration; the
// built checker must still work after the builder is destroyed
// (builder.reset()), so the arena-backed type has to outlive the builder.
TEST(TypeCheckerBuilderImplTest, LazyArenaInitialization) {
  auto builder = std::make_unique(
      internal::GetSharedTestingDescriptorPool(), CheckerOptions{});

  ASSERT_THAT(builder->AddLibrary(CheckerLibrary{
                  .id = "test_lib",
                  .configure = [](TypeCheckerBuilder& builder) -> absl::Status {
                    auto l = ListType(builder.arena(), IntType());
                    return builder.AddVariable(MakeVariableDecl("foo", l));
                  },
              }),
              IsOk());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker,
                       builder->Build());
  builder.reset();

  ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo"));

  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       type_checker->Check(std::move(ast)));
  ASSERT_TRUE(result.IsValid());

  const auto& checked_ast = *result.GetAst();

  EXPECT_EQ(checked_ast.GetReturnType(),
            TypeSpec(ListTypeSpec(
                std::make_unique(PrimitiveType::kInt64))));
}

}  // namespace
}  // namespace cel::checker_internal

================================================
FILE: checker/internal/type_checker_impl.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "checker/internal/type_checker_impl.h"

#include
#include
#include
#include
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "checker/checker_options.h"
#include "checker/internal/format_type_name.h"
#include "checker/internal/namespace_generator.h"
#include "checker/internal/type_check_env.h"
#include "checker/internal/type_checker_builder_impl.h"
#include "checker/internal/type_inference_context.h"
#include "checker/type_check_issue.h"
#include "checker/type_checker_builder.h"
#include "checker/validation_result.h"
#include "common/ast.h"
#include "common/ast_rewrite.h"
#include "common/ast_traverse.h"
#include "common/ast_visitor.h"
#include "common/ast_visitor_base.h"
#include "common/constant.h"
#include "common/decl.h"
#include "common/expr.h"
#include "common/type.h"
#include "common/type_kind.h"
#include "internal/status_macros.h"
#include "google/protobuf/arena.h"

namespace cel::checker_internal {
namespace {

// The AST-side type representation produced by the Flatten* helpers below.
using AstType = cel::TypeSpec;
using Severity = TypeCheckIssue::Severity;

// Function name of the optional field-selection call (`_?._`); the resolver
// special-cases calls with this name (see ResolveVisitor::PostVisitCall).
constexpr const char kOptionalSelect[] = "_?._";

// Joins qualifier segments with '.' to form a dotted candidate name,
// e.g. {"a", "b"} -> "a.b".
std::string FormatCandidate(absl::Span qualifiers) {
  return absl::StrJoin(qualifiers, ".");
}

// Flatten the type to the AST type representation to remove any lifecycle
// dependency between the type check environment and the AST.
//
// TODO(uncreated-issue/72): It may be better to do this at the point of serialization
// in the future, but requires corresponding change for the runtime to correctly
// rehydrate the serialized Ast.
absl::StatusOr FlattenType(const Type& type);

// Flattens an opaque type to an AbstractType with the same name, flattening
// each type parameter recursively. Propagates the first parameter failure.
absl::StatusOr FlattenAbstractType(const OpaqueType& type) {
  std::vector parameter_types;
  parameter_types.reserve(type.GetParameters().size());
  for (const auto& param : type.GetParameters()) {
    CEL_ASSIGN_OR_RETURN(auto param_type, FlattenType(param));
    parameter_types.push_back(std::move(param_type));
  }

  return AstType(
      AbstractType(std::string(type.name()), std::move(parameter_types)));
}

// Flattens map(K, V) by flattening the key and value types.
absl::StatusOr FlattenMapType(const MapType& type) {
  CEL_ASSIGN_OR_RETURN(auto key, FlattenType(type.key()));
  CEL_ASSIGN_OR_RETURN(auto value, FlattenType(type.value()));

  return AstType(MapTypeSpec(std::make_unique(std::move(key)),
                             std::make_unique(std::move(value))));
}

// Flattens list(E) by flattening the element type.
absl::StatusOr FlattenListType(const ListType& type) {
  CEL_ASSIGN_OR_RETURN(auto elem, FlattenType(type.element()));

  return AstType(ListTypeSpec(std::make_unique(std::move(elem))));
}

// Flattens a struct type to a message type referenced by its full name.
absl::StatusOr FlattenMessageType(const StructType& type) {
  return AstType(MessageTypeSpec(std::string(type.name())));
}

// Flattens type(T). A TypeType with no parameter flattens to a bare type
// spec; more than one parameter is reported as an InternalError.
absl::StatusOr FlattenTypeType(const TypeType& type) {
  if (type.GetParameters().size() > 1) {
    return absl::InternalError(
        absl::StrCat("Unsupported type: ", type.DebugString()));
  }

  if (type.GetParameters().empty()) {
    return AstType(std::make_unique());
  }

  CEL_ASSIGN_OR_RETURN(auto param, FlattenType(type.GetParameters()[0]));

  return AstType(std::make_unique(std::move(param)));
}

// Maps a checker Type to its AST TypeSpec. Enums flatten to int64, wrapper
// kinds to PrimitiveTypeWrapper, and kinds not listed fail with an
// InternalError.
absl::StatusOr FlattenType(const Type& type) {
  switch (type.kind()) {
    case TypeKind::kDyn:
      return AstType(DynTypeSpec());
    case TypeKind::kError:
      return AstType(ErrorTypeSpec());
    case TypeKind::kNull:
      return AstType(NullTypeSpec());
    case TypeKind::kBool:
      return AstType(PrimitiveType::kBool);
    case TypeKind::kInt:
      return AstType(PrimitiveType::kInt64);
    case TypeKind::kEnum:
      return AstType(PrimitiveType::kInt64);
    case TypeKind::kUint:
      return AstType(PrimitiveType::kUint64);
    case TypeKind::kDouble:
      return AstType(PrimitiveType::kDouble);
    case TypeKind::kString:
      return AstType(PrimitiveType::kString);
    case TypeKind::kBytes:
      return AstType(PrimitiveType::kBytes);
    case TypeKind::kDuration:
      return AstType(WellKnownTypeSpec::kDuration);
    case TypeKind::kTimestamp:
      return AstType(WellKnownTypeSpec::kTimestamp);
    case TypeKind::kStruct:
      return FlattenMessageType(type.GetStruct());
    case TypeKind::kList:
      return FlattenListType(type.GetList());
    case TypeKind::kMap:
      return FlattenMapType(type.GetMap());
    case TypeKind::kOpaque:
      return FlattenAbstractType(type.GetOpaque());
    case TypeKind::kBoolWrapper:
      return AstType(PrimitiveTypeWrapper(PrimitiveType::kBool));
    case TypeKind::kIntWrapper:
      return AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64));
    case TypeKind::kUintWrapper:
      return AstType(PrimitiveTypeWrapper(PrimitiveType::kUint64));
    case TypeKind::kDoubleWrapper:
      return AstType(PrimitiveTypeWrapper(PrimitiveType::kDouble));
    case TypeKind::kStringWrapper:
      return AstType(PrimitiveTypeWrapper(PrimitiveType::kString));
    case TypeKind::kBytesWrapper:
      return AstType(PrimitiveTypeWrapper(PrimitiveType::kBytes));
    case TypeKind::kTypeParam:
      // Convert any remaining free type params to dyn.
return AstType(DynTypeSpec());
    case TypeKind::kType:
      return FlattenTypeType(type.GetType());
    case TypeKind::kAny:
      return AstType(WellKnownTypeSpec::kAny);
    default:
      // Every other kind has no AST representation here.
      return absl::InternalError(
          absl::StrCat("unsupported type encountered making AST serializable: ",
                       type.DebugString()));
  }
}

// AST visitor that resolves identifiers, field selections, struct creations
// and function calls against the TypeCheckEnv while deducing a Type for each
// visited subexpression.
//
// Results are accumulated in maps keyed by Expr pointer (functions_,
// attributes_, struct_types_, types_). Problems found in the expression are
// appended to the caller-owned `issues` vector; internal failures are
// accumulated in status_.
class ResolveVisitor : public AstVisitorBase {
 public:
  // A resolved function reference. `namespace_rewrite` marks calls whose
  // receiver was reinterpreted as a namespace qualifier.
  struct FunctionResolution {
    const FunctionDecl* decl;
    bool namespace_rewrite;
  };

  // A resolved variable reference for an identifier or select chain.
  struct AttributeResolution {
    const VariableDecl* decl;
    bool requires_disambiguation;
  };

  // `env`, `ast`, `inference_context`, `issues` and `arena` are held by
  // pointer and must outlive the visitor.
  ResolveVisitor(NamespaceGenerator namespace_generator,
                 const TypeCheckEnv& env, const Ast& ast,
                 TypeInferenceContext& inference_context,
                 std::vector& issues,
                 google::protobuf::Arena* absl_nonnull arena)
      : namespace_generator_(std::move(namespace_generator)),
        env_(&env),
        inference_context_(&inference_context),
        issues_(&issues),
        ast_(&ast),
        root_scope_(),
        arena_(arena),
        current_scope_(&root_scope_) {}

  // Maintain the stack of ancestors of the expression being visited; used by
  // PostVisitIdent to walk up select chains.
  void PreVisitExpr(const Expr& expr) override { expr_stack_.push_back(&expr); }

  void PostVisitExpr(const Expr& expr) override {
    if (expr_stack_.empty()) {
      return;
    }
    expr_stack_.pop_back();
  }

  void PostVisitConst(const Expr& expr, const Constant& constant) override;

  void PreVisitComprehension(const Expr& expr,
                             const ComprehensionExpr& comprehension) override;

  void PostVisitComprehension(const Expr& expr,
                              const ComprehensionExpr& comprehension) override;

  void PostVisitMap(const Expr& expr, const MapExpr& map) override;

  void PostVisitList(const Expr& expr, const ListExpr& list) override;

  void PreVisitComprehensionSubexpression(
      const Expr& expr, const ComprehensionExpr& comprehension,
      ComprehensionArg comprehension_arg) override;

  void PostVisitComprehensionSubexpression(
      const Expr& expr, const ComprehensionExpr& comprehension,
      ComprehensionArg comprehension_arg) override;

  void PostVisitIdent(const Expr& expr, const IdentExpr& ident) override;

  void PostVisitSelect(const Expr& expr, const SelectExpr& select) override;

  void PostVisitCall(const Expr& expr, const CallExpr& call) override;

  void PostVisitStruct(const Expr& expr,
                       const StructExpr& create_struct) override;

  // Accessors for resolved values.
  const absl::flat_hash_map& functions()
      const {
    return functions_;
  }

  const absl::flat_hash_map& attributes()
      const {
    return attributes_;
  }

  const absl::flat_hash_map& struct_types() const {
    return struct_types_;
  }

  const absl::flat_hash_map& types() const { return types_; }

  // First internal (non-user) failure encountered, if any.
  const absl::Status& status() const { return status_; }

  // Number of issues reported with error severity (warnings excluded).
  int error_count() const { return error_count_; }

  // Reports a type mismatch if the deduced type of `expr` is not assignable
  // to `expected_type`.
  void AssertExpectedType(const Expr& expr, const Type& expected_type) {
    Type observed = GetDeducedType(&expr);
    if (!inference_context_->IsAssignable(observed, expected_type)) {
      ReportTypeMismatch(expr.id(), expected_type, observed);
    }
  }

 private:
  // Scopes created for one comprehension. The iteration scope is nested
  // inside the accumulator scope, which is nested inside `parent`.
  struct ComprehensionScope {
    const Expr* comprehension_expr;
    const VariableScope* parent;
    VariableScope* accu_scope;
    VariableScope* iter_scope;
  };

  struct FunctionOverloadMatch {
    // Overall result type.
    // If resolution is incomplete, this will be DynType.
    Type result_type;
    // A new declaration with the narrowed overload candidates.
    // Owned by the Check call scoped arena.
    const FunctionDecl* decl;
  };

  void ResolveSimpleIdentifier(const Expr& expr, absl::string_view name);

  void ResolveQualifiedIdentifier(const Expr& expr,
                                  absl::Span qualifiers);

  // Resolves the function call shape (i.e. the number of arguments and call
  // style) for the given function call.
  const FunctionDecl* ResolveFunctionCallShape(const Expr& expr,
                                               absl::string_view function_name,
                                               int arg_count, bool is_receiver);

  // Resolves a global identifier (i.e. declared in the CEL environment).
  const VariableDecl* absl_nullable LookupGlobalIdentifier(
      absl::string_view name);

  // Resolves a local identifier (i.e. a bind or comprehension var).
  const VariableDecl* absl_nullable LookupLocalIdentifier(
      absl::string_view name);

  // Resolves the applicable function overloads for the given function call.
  //
  // If found, assigns a new function decl with the resolved overloads.
  void ResolveFunctionOverloads(const Expr& expr, const FunctionDecl& decl,
                                int arg_count, bool is_receiver,
                                bool is_namespaced);

  void ResolveSelectOperation(const Expr& expr, absl::string_view field,
                              const Expr& operand);

  // Records an issue; only error-severity issues count toward error_count_.
  void ReportIssue(TypeCheckIssue issue) {
    if (issue.severity() == Severity::kError) {
      error_count_++;
    }
    issues_->push_back(std::move(issue));
  }

  void ReportMissingReference(const Expr& expr, absl::string_view name) {
    ReportIssue(TypeCheckIssue::CreateError(
        ast_->ComputeSourceLocation(expr.id()),
        absl::StrCat("undeclared reference to '", name, "' (in container '",
                     env_->container().container(), "')")));
  }

  void ReportUndefinedField(int64_t expr_id, absl::string_view field_name,
                            absl::string_view struct_name) {
    ReportIssue(TypeCheckIssue::CreateError(
        ast_->ComputeSourceLocation(expr_id),
        absl::StrCat("undefined field '", field_name, "' not found in struct '",
                     struct_name, "'")));
  }

  // Types are finalized through the inference context before formatting so
  // the message shows the currently inferred types.
  void ReportTypeMismatch(int64_t expr_id, const Type& expected,
                          const Type& actual) {
    ReportIssue(TypeCheckIssue::CreateError(
        ast_->ComputeSourceLocation(expr_id),
        absl::StrCat(
            "expected type '",
            FormatTypeName(inference_context_->FinalizeType(expected)),
            "' but found '",
            FormatTypeName(inference_context_->FinalizeType(actual)), "'")));
  }

  // Checks each field initializer of a struct creation against the declared
  // field type of `resolved_name`. Unknown fields and type mismatches are
  // reported as issues; only a failing field lookup returns a non-OK status.
  // Optional initializers (`?field: value`) are checked against
  // optional(field_type).
  absl::Status CheckFieldAssignments(const Expr& expr,
                                     const StructExpr& create_struct,
                                     Type struct_type,
                                     absl::string_view resolved_name) {
    for (const auto& field : create_struct.fields()) {
      const Expr* value = &field.value();
      Type value_type = GetDeducedType(value);

      // Lookup message type by name to support WellKnownType creation.
      CEL_ASSIGN_OR_RETURN(
          absl::optional field_info,
          env_->LookupStructField(resolved_name, field.name()));
      if (!field_info.has_value()) {
        ReportUndefinedField(field.id(), field.name(), resolved_name);
        continue;
      }
      Type field_type = field_info->GetType();
      if (field.optional()) {
        field_type = OptionalType(arena_, field_type);
      }
      if (!inference_context_->IsAssignable(value_type, field_type)) {
        ReportIssue(TypeCheckIssue::CreateError(
            ast_->ComputeSourceLocation(field.id()),
            absl::StrCat(
                "expected type of field '", field_info->name(), "' is '",
                FormatTypeName(inference_context_->FinalizeType(field_type)),
                "' but provided type is '",
                FormatTypeName(inference_context_->FinalizeType(value_type)),
                "'")));
        continue;
      }
    }

    return absl::OkStatus();
  }

  absl::optional CheckFieldType(int64_t expr_id, const Type& operand_type,
                                absl::string_view field_name);

  void HandleOptSelect(const Expr& expr);

  // Get the assigned type of the given subexpression. Should only be called if
  // the given subexpression is expected to have already been checked.
  //
  // If unknown, returns DynType as a placeholder and reports an error.
  // Whether or not the subexpression is valid for the checker configuration,
  // the type checker should have assigned a type (possibly ErrorType). If there
  // is no assigned type, the type checker failed to handle the subexpression
  // and should not attempt to continue type checking.
  Type GetDeducedType(const Expr* expr) {
    auto iter = types_.find(expr);
    if (iter != types_.end()) {
      return iter->second;
    }
    status_.Update(absl::InvalidArgumentError(
        absl::StrCat("Could not deduce type for expression id: ", expr->id())));
    return DynType();
  }

  NamespaceGenerator namespace_generator_;
  const TypeCheckEnv* absl_nonnull env_;
  TypeInferenceContext* absl_nonnull inference_context_;
  std::vector* absl_nonnull issues_;
  const Ast* absl_nonnull ast_;
  VariableScope root_scope_;
  google::protobuf::Arena* absl_nonnull arena_;

  // state tracking for the traversal.
  const VariableScope* current_scope_;
  std::vector expr_stack_;
  absl::flat_hash_map>
      maybe_namespaced_functions_;
  // Select operations that need to be resolved outside of the traversal.
  // These are handled separately to disambiguate between namespaces and field
  // accesses.
  absl::flat_hash_set deferred_select_operations_;
  // Owns the comprehension scopes so they live as long as the visitor.
  std::vector> comprehension_vars_;
  // Stack of scopes for the comprehensions currently being visited.
  std::vector comprehension_scopes_;
  absl::Status status_;
  int error_count_ = 0;

  // References that were resolved and may require AST rewrites.
  absl::flat_hash_map functions_;
  absl::flat_hash_map attributes_;
  absl::flat_hash_map struct_types_;

  absl::flat_hash_map types_;
};

// Resolves an identifier. A lone identifier is resolved directly. Otherwise
// the enclosing select chain is collected as qualifiers (those selects are
// marked deferred) so the longest matching qualified name can be resolved. If
// the chain is the receiver of a call, resolution is postponed to
// PostVisitCall to allow namespaced-function disambiguation.
void ResolveVisitor::PostVisitIdent(const Expr& expr, const IdentExpr& ident) {
  if (expr_stack_.size() == 1) {
    ResolveSimpleIdentifier(expr, ident.name());
    return;
  }

  // Walk up the stack to find the qualifiers.
  //
  // If the identifier is the target of a receiver call, then note
  // the function so we can disambiguate namespaced functions later.
  int stack_pos = expr_stack_.size() - 1;
  std::vector qualifiers;
  qualifiers.push_back(ident.name());
  const Expr* receiver_call = nullptr;
  const Expr* root_candidate = expr_stack_[stack_pos];

  // Try to identify the root of the select chain, possibly as the receiver of
  // a function call.
while (stack_pos > 0) { --stack_pos; const Expr* parent = expr_stack_[stack_pos]; if (parent->has_call_expr() && (&parent->call_expr().target() == root_candidate)) { receiver_call = parent; break; } else if (!parent->has_select_expr()) { break; } qualifiers.push_back(parent->select_expr().field()); deferred_select_operations_.insert(parent); root_candidate = parent; if (parent->select_expr().test_only()) { break; } } if (receiver_call == nullptr) { ResolveQualifiedIdentifier(*root_candidate, qualifiers); } else { maybe_namespaced_functions_[receiver_call] = std::move(qualifiers); } } void ResolveVisitor::PostVisitConst(const Expr& expr, const Constant& constant) { switch (constant.kind().index()) { case ConstantKindIndexOf(): types_[&expr] = NullType(); break; case ConstantKindIndexOf(): types_[&expr] = BoolType(); break; case ConstantKindIndexOf(): types_[&expr] = IntType(); break; case ConstantKindIndexOf(): types_[&expr] = UintType(); break; case ConstantKindIndexOf(): types_[&expr] = DoubleType(); break; case ConstantKindIndexOf(): types_[&expr] = BytesType(); break; case ConstantKindIndexOf(): types_[&expr] = StringType(); break; case ConstantKindIndexOf(): types_[&expr] = DurationType(); break; case ConstantKindIndexOf(): types_[&expr] = TimestampType(); break; default: ReportIssue(TypeCheckIssue::CreateError( ast_->ComputeSourceLocation(expr.id()), absl::StrCat("unsupported constant type: ", constant.kind().index()))); types_[&expr] = ErrorType(); break; } } bool IsSupportedKeyType(const Type& type) { switch (type.kind()) { case TypeKind::kBool: case TypeKind::kInt: case TypeKind::kUint: case TypeKind::kString: case TypeKind::kDyn: return true; default: return false; } } void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { // Roughly follows map type inferencing behavior in Go. 
// // We try to infer the type of the map if all of the keys or values are // homogeneously typed, otherwise assume the type parameter is dyn (defer to // runtime for enforcing type compatibility). // // TODO(uncreated-issue/72): Widening behavior is not well documented for map / list // construction in the spec and is a bit inconsistent between implementations. // // In the future, we should probably default enforce homogeneously // typed maps unless tagged as JSON (and the values are assignable to // the JSON value union type). Type overall_key_type = inference_context_->InstantiateTypeParams(TypeParamType("K")); Type overall_value_type = inference_context_->InstantiateTypeParams(TypeParamType("V")); auto assignability_context = inference_context_->CreateAssignabilityContext(); for (const auto& entry : map.entries()) { const Expr* key = &entry.key(); Type key_type = GetDeducedType(key); if (!IsSupportedKeyType(key_type)) { // The Go type checker implementation can allow any type as a map key, but // per the spec this should be limited to the types listed in // IsSupportedKeyType. // // To match the Go implementation, we just warn here, but in the future // we should consider making this an error. 
ReportIssue(TypeCheckIssue( Severity::kWarning, ast_->ComputeSourceLocation(key->id()), absl::StrCat( "unsupported map key type: ", FormatTypeName(inference_context_->FinalizeType(key_type))))); } if (!assignability_context.IsAssignable(key_type, overall_key_type)) { overall_key_type = DynType(); } } if (!overall_key_type.IsDyn()) { assignability_context.UpdateInferredTypeAssignments(); } assignability_context.Reset(); for (const auto& entry : map.entries()) { const Expr* value = &entry.value(); Type value_type = GetDeducedType(value); if (entry.optional()) { if (value_type.IsOptional()) { value_type = value_type.GetOptional().GetParameter(); } else { ReportTypeMismatch(entry.value().id(), OptionalType(arena_, value_type), value_type); continue; } } if (!inference_context_->IsAssignable(value_type, overall_value_type)) { overall_value_type = DynType(); } } if (!overall_value_type.IsDyn()) { assignability_context.UpdateInferredTypeAssignments(); } types_[&expr] = inference_context_->FullySubstitute( MapType(arena_, overall_key_type, overall_value_type)); } void ResolveVisitor::PostVisitList(const Expr& expr, const ListExpr& list) { // Follows list type inferencing behavior in Go (see map comments above). 
Type overall_elem_type =
      inference_context_->InstantiateTypeParams(TypeParamType("E"));
  // Element types are unified in a separate assignability context; the
  // resulting assignments are only committed when the element type did not
  // widen to dyn.
  auto assignability_context = inference_context_->CreateAssignabilityContext();
  for (const auto& element : list.elements()) {
    const Expr* value = &element.expr();
    Type value_type = GetDeducedType(value);
    if (element.optional()) {
      // `?elem` entries must supply an optional; unify its payload.
      if (value_type.IsOptional()) {
        value_type = value_type.GetOptional().GetParameter();
      } else {
        ReportTypeMismatch(element.expr().id(),
                           OptionalType(arena_, value_type), value_type);
        continue;
      }
    }

    if (!assignability_context.IsAssignable(value_type, overall_elem_type)) {
      overall_elem_type = DynType();
    }
  }

  if (!overall_elem_type.IsDyn()) {
    assignability_context.UpdateInferredTypeAssignments();
  }

  types_[&expr] =
      inference_context_->FullySubstitute(ListType(arena_, overall_elem_type));
}

// Resolves the type named by a struct creation expression against the
// candidate names produced by namespace_generator_. The first candidate that
// resolves is used; the callback returns false there, which presumably stops
// candidate generation.
//
// - Lookup failures are propagated to status_.
// - Unresolved names are reported as missing references.
// - Non-struct, non-well-known types are rejected.
// Otherwise the type is recorded and each field assignment is checked.
void ResolveVisitor::PostVisitStruct(const Expr& expr,
                                     const StructExpr& create_struct) {
  absl::Status status;
  std::string resolved_name;
  Type resolved_type;
  namespace_generator_.GenerateCandidates(
      create_struct.name(), [&](const absl::string_view name) {
        auto type = env_->LookupTypeName(name);
        if (!type.ok()) {
          status.Update(type.status());
          return false;
        } else if (type->has_value()) {
          resolved_name = name;
          resolved_type = **type;
          return false;
        }
        return true;
      });

  if (!status.ok()) {
    status_.Update(status);
    return;
  }

  if (resolved_name.empty()) {
    ReportMissingReference(expr, create_struct.name());
    types_[&expr] = ErrorType();
    return;
  }

  if (resolved_type.kind() != TypeKind::kStruct &&
      !IsWellKnownMessageType(resolved_name)) {
    ReportIssue(TypeCheckIssue::CreateError(
        ast_->ComputeSourceLocation(expr.id()),
        absl::StrCat("type '", resolved_name,
                     "' does not support message creation")));
    types_[&expr] = ErrorType();
    return;
  }

  types_[&expr] = resolved_type;
  struct_types_[&expr] = resolved_name;
  status_.Update(
      CheckFieldAssignments(expr, create_struct, resolved_type, resolved_name));
}

// Resolves a function call.
//
// - Optional selection (`_?._`) is delegated to HandleOptSelect.
// - If PostVisitIdent recorded this call's receiver as a possible namespace
//   qualifier, the qualified function name is tried first (as a
//   non-receiver call). Failing that, the receiver is resolved as an
//   attribute.
// - Otherwise the call is resolved normally; the receiver, if any, counts as
//   an extra argument.
void ResolveVisitor::PostVisitCall(const Expr& expr, const CallExpr& call) {
  if (call.function() == kOptionalSelect) {
    HandleOptSelect(expr);
    return;
  }
  // Handle disambiguation of namespaced functions.
  if (auto iter = maybe_namespaced_functions_.find(&expr);
      iter != maybe_namespaced_functions_.end()) {
    std::string namespaced_name =
        absl::StrCat(FormatCandidate(iter->second), ".", call.function());
    const FunctionDecl* decl =
        ResolveFunctionCallShape(expr, namespaced_name, call.args().size(),
                                 /* is_receiver= */ false);
    if (decl != nullptr) {
      ResolveFunctionOverloads(expr, *decl, call.args().size(),
                               /* is_receiver= */ false,
                               /* is_namespaced= */ true);
      return;
    }
    // Else, resolve the target as an attribute (deferred earlier), then
    // resolve the function call normally.
    ResolveQualifiedIdentifier(call.target(), iter->second);
  }

  int arg_count = call.args().size();
  if (call.has_target()) {
    ++arg_count;
  }

  const FunctionDecl* decl = ResolveFunctionCallShape(
      expr, call.function(), arg_count, call.has_target());

  if (decl == nullptr) {
    ReportMissingReference(expr, call.function());
    types_[&expr] = ErrorType();
    return;
  }

  ResolveFunctionOverloads(expr, *decl, arg_count, call.has_target(),
                           /* is_namespaced= */ false);
}

// Creates the comprehension's scopes before its subexpressions are visited:
// an accumulator scope nested in the current scope, and an iteration scope
// nested in the accumulator scope.
void ResolveVisitor::PreVisitComprehension(
    const Expr& expr, const ComprehensionExpr& comprehension) {
  std::unique_ptr accu_scope = current_scope_->MakeNestedScope();
  auto* accu_scope_ptr = accu_scope.get();

  std::unique_ptr iter_scope = accu_scope->MakeNestedScope();
  auto* iter_scope_ptr = iter_scope.get();

  // Keep the temporary decls alive as long as the visitor.
  comprehension_vars_.push_back(std::move(accu_scope));
  comprehension_vars_.push_back(std::move(iter_scope));

  comprehension_scopes_.push_back(
      {&expr, current_scope_, accu_scope_ptr, iter_scope_ptr});
}

// Pops the comprehension's scopes; the comprehension's type is the
// substituted type of its result expression.
void ResolveVisitor::PostVisitComprehension(
    const Expr& expr, const ComprehensionExpr& comprehension) {
  comprehension_scopes_.pop_back();
  types_[&expr] = inference_context_->FullySubstitute(
      GetDeducedType(&comprehension.result()));
}

// Selects the scope for the comprehension subexpression about to be visited:
// - loop condition and result: the accumulator scope;
// - loop step: the iteration scope;
// - iter range and accu init (default): the enclosing scope.
void ResolveVisitor::PreVisitComprehensionSubexpression(
    const Expr& expr, const ComprehensionExpr& comprehension,
    ComprehensionArg comprehension_arg) {
  if (comprehension_scopes_.empty()) {
    status_.Update(absl::InternalError(
        "Comprehension scope stack is empty in comprehension"));
    return;
  }
  auto& scope = comprehension_scopes_.back();
  if (scope.comprehension_expr != &expr) {
    status_.Update(absl::InternalError("Comprehension scope stack broken"));
    return;
  }

  switch (comprehension_arg) {
    case ComprehensionArg::LOOP_CONDITION:
      current_scope_ = scope.accu_scope;
      break;
    case ComprehensionArg::LOOP_STEP:
      current_scope_ = scope.iter_scope;
      break;
    case ComprehensionArg::RESULT:
      current_scope_ = scope.accu_scope;
      break;
    default:
      current_scope_ = scope.parent;
      break;
  }
}

// Restores the enclosing scope after a comprehension subexpression.
//
// Once the accumulator initializer has been typed, declares the accumulator
// variable with that type. Once the iteration range has been typed, declares
// the iteration variable(s):
// - for a list range: element type (v2: index as int, then element);
// - for a map range: key type (v2: key, then value);
// - for a dyn range: dyn;
// - any other range type is reported as an error.
void ResolveVisitor::PostVisitComprehensionSubexpression(
    const Expr& expr, const ComprehensionExpr& comprehension,
    ComprehensionArg comprehension_arg) {
  if (comprehension_scopes_.empty()) {
    status_.Update(absl::InternalError(
        "Comprehension scope stack is empty in comprehension"));
    return;
  }
  auto& scope = comprehension_scopes_.back();
  if (scope.comprehension_expr != &expr) {
    status_.Update(absl::InternalError("Comprehension scope stack broken"));
    return;
  }
  current_scope_ = scope.parent;

  // Setting the type depends on the order the visitor is called -- the visitor
  // guarantees iter range and accu init are visited before subexpressions where
  // the corresponding variables can be referenced.
  switch (comprehension_arg) {
    case ComprehensionArg::ACCU_INIT:
      scope.accu_scope->InsertVariableIfAbsent(
          MakeVariableDecl(comprehension.accu_var(),
                           GetDeducedType(&comprehension.accu_init())));
      break;
    case ComprehensionArg::ITER_RANGE: {
      Type range_type = GetDeducedType(&comprehension.iter_range());
      Type iter_type = DynType();   // iter_var for non comprehensions v2.
      Type iter_type1 = DynType();  // iter_var for comprehensions v2.
      Type iter_type2 = DynType();  // iter_var2 for comprehensions v2.
      switch (range_type.kind()) {
        case TypeKind::kList:
          iter_type1 = IntType();
          iter_type = iter_type2 = range_type.GetList().element();
          break;
        case TypeKind::kMap:
          iter_type = iter_type1 = range_type.GetMap().key();
          iter_type2 = range_type.GetMap().value();
          break;
        case TypeKind::kDyn:
          break;
        default:
          ReportIssue(TypeCheckIssue::CreateError(
              ast_->ComputeSourceLocation(comprehension.iter_range().id()),
              absl::StrCat(
                  "expression of type '",
                  FormatTypeName(inference_context_->FinalizeType(range_type)),
                  "' cannot be the range of a comprehension (must be "
                  "list, map, or dynamic)")));
          break;
      }
      if (comprehension.iter_var2().empty()) {
        scope.iter_scope->InsertVariableIfAbsent(
            MakeVariableDecl(comprehension.iter_var(), iter_type));
      } else {
        scope.iter_scope->InsertVariableIfAbsent(
            MakeVariableDecl(comprehension.iter_var(), iter_type1));
        scope.iter_scope->InsertVariableIfAbsent(
            MakeVariableDecl(comprehension.iter_var2(), iter_type2));
      }
      break;
    }
    default:
      break;
  }
}

// Resolves a field selection, unless PostVisitIdent already claimed it as
// part of a qualified identifier (deferred_select_operations_).
void ResolveVisitor::PostVisitSelect(const Expr& expr,
                                     const SelectExpr& select) {
  if (!deferred_select_operations_.contains(&expr)) {
    ResolveSelectOperation(expr, select.field(), select.operand());
  }
}

const FunctionDecl* ResolveVisitor::ResolveFunctionCallShape(
    const Expr& expr, absl::string_view function_name, int arg_count,
    bool is_receiver) {
  const FunctionDecl* decl = nullptr;
  namespace_generator_.GenerateCandidates(
      function_name, [&, this](absl::string_view candidate) -> bool {
        decl = env_->LookupFunction(candidate);
if (decl == nullptr) { return true; } for (const auto& ovl : decl->overloads()) { if (ovl.member() == is_receiver && ovl.args().size() == arg_count) { return false; } } // Name match, but no matching overloads. decl = nullptr; return true; }); return decl; } void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, const FunctionDecl& decl, int arg_count, bool is_receiver, bool is_namespaced) { std::vector arg_types; arg_types.reserve(arg_count); if (is_receiver) { arg_types.push_back(GetDeducedType(&expr.call_expr().target())); } for (int i = 0; i < expr.call_expr().args().size(); ++i) { arg_types.push_back(GetDeducedType(&expr.call_expr().args()[i])); } absl::optional resolution = inference_context_->ResolveOverload(decl, arg_types, is_receiver); if (!resolution.has_value()) { ReportIssue(TypeCheckIssue::CreateError( ast_->ComputeSourceLocation(expr.id()), absl::StrCat("found no matching overload for '", decl.name(), "' applied to '(", absl::StrJoin(arg_types, ", ", [](std::string* out, const Type& type) { out->append(FormatTypeName(type)); }), ")'"))); types_[&expr] = ErrorType(); return; } auto* result_decl = google::protobuf::Arena::Create(arena_); result_decl->set_name(decl.name()); for (const auto& ovl : resolution->overloads) { absl::Status s = result_decl->AddOverload(ovl); if (!s.ok()) { // Overloads should be filtered list from the original declaration, // so a status means an invariant was broken. status_.Update(absl::InternalError(absl::StrCat( "failed to add overload to resolved function declaration: ", s))); } } functions_[&expr] = {result_decl, is_namespaced}; types_[&expr] = resolution->result_type; } const VariableDecl* absl_nullable ResolveVisitor::LookupLocalIdentifier( absl::string_view name) { // Note: if we see a leading dot, this shouldn't resolve to a local variable, // but we need to check whether we need to disambiguate against a global in // the reference map. 
if (absl::StartsWith(name, ".")) { name = name.substr(1); } return current_scope_->LookupLocalVariable(name); } const VariableDecl* absl_nullable ResolveVisitor::LookupGlobalIdentifier( absl::string_view name) { if (const VariableDecl* decl = env_->LookupVariable(name); decl != nullptr) { return decl; } absl::StatusOr> constant = env_->LookupTypeConstant(arena_, name); if (!constant.ok()) { status_.Update(constant.status()); return nullptr; } if (constant->has_value()) { if (constant->value().type().kind() == TypeKind::kEnum) { // Treat enum constant as just an int after resolving the reference. // This preserves existing behavior in the other type checkers. constant->value().set_type(IntType()); } return google::protobuf::Arena::Create( arena_, std::move(constant).value().value()); } return nullptr; } void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr, absl::string_view name) { // Local variables (comprehension, bind) are simple identifiers so we can // skip generating the different namespace-qualified candidates. const VariableDecl* local_decl = LookupLocalIdentifier(name); if (local_decl != nullptr && !absl::StartsWith(name, ".")) { attributes_[&expr] = {local_decl, false}; types_[&expr] = inference_context_->InstantiateTypeParams(local_decl->type()); return; } const VariableDecl* decl = nullptr; namespace_generator_.GenerateCandidates( name, [&decl, this](absl::string_view candidate) { decl = LookupGlobalIdentifier(candidate); // continue searching. 
return decl == nullptr; }); if (decl != nullptr) { attributes_[&expr] = {decl, /* requires_disambiguation= */ local_decl != nullptr}; types_[&expr] = inference_context_->InstantiateTypeParams(decl->type()); return; } ReportMissingReference(expr, name); types_[&expr] = ErrorType(); } void ResolveVisitor::ResolveQualifiedIdentifier( const Expr& expr, absl::Span qualifiers) { if (qualifiers.size() == 1) { ResolveSimpleIdentifier(expr, qualifiers[0]); return; } // Local variables (comprehension, bind) are simple identifiers so we can // skip generating the different namespace-qualified candidates. const VariableDecl* local_decl = LookupLocalIdentifier(qualifiers[0]); const VariableDecl* decl = nullptr; int matched_segment_index = -1; if (local_decl != nullptr && !absl::StartsWith(qualifiers[0], ".")) { decl = local_decl; matched_segment_index = 0; } else { namespace_generator_.GenerateCandidates( qualifiers, [&decl, &matched_segment_index, this]( absl::string_view candidate, int segment_index) { decl = LookupGlobalIdentifier(candidate); if (decl != nullptr) { matched_segment_index = segment_index; return false; } return true; }); } if (decl == nullptr) { ReportMissingReference(expr, FormatCandidate(qualifiers)); types_[&expr] = ErrorType(); return; } const int num_select_opts = qualifiers.size() - matched_segment_index - 1; const Expr* root = &expr; std::vector select_opts; select_opts.reserve(num_select_opts); for (int i = 0; i < num_select_opts; ++i) { select_opts.push_back(root); root = &root->select_expr().operand(); } attributes_[root] = {decl, /* requires_disambiguation= */ decl != local_decl && local_decl != nullptr}; types_[root] = inference_context_->InstantiateTypeParams(decl->type()); // fix-up select operations that were deferred. 
for (auto iter = select_opts.rbegin(); iter != select_opts.rend(); ++iter) { ResolveSelectOperation(**iter, (*iter)->select_expr().field(), (*iter)->select_expr().operand()); } } absl::optional ResolveVisitor::CheckFieldType(int64_t id, const Type& operand_type, absl::string_view field) { if (operand_type.kind() == TypeKind::kDyn || operand_type.kind() == TypeKind::kAny) { return DynType(); } switch (operand_type.kind()) { case TypeKind::kStruct: { StructType struct_type = operand_type.GetStruct(); auto field_info = env_->LookupStructField(struct_type.name(), field); if (!field_info.ok()) { status_.Update(field_info.status()); return absl::nullopt; } if (!field_info->has_value()) { ReportUndefinedField(id, field, struct_type.name()); return absl::nullopt; } auto type = field_info->value().GetType(); if (type.kind() == TypeKind::kEnum) { // Treat enum as just an int. return IntType(); } return type; } case TypeKind::kMap: { MapType map_type = operand_type.GetMap(); return map_type.GetValue(); } case TypeKind::kTypeParam: { // If the operand is a free type variable, bind it to dyn to prevent // an alternative type from being inferred. if (inference_context_->IsAssignable(DynType(), operand_type)) { return DynType(); } break; } default: break; } ReportIssue(TypeCheckIssue::CreateError( ast_->ComputeSourceLocation(id), absl::StrCat( "expression of type '", FormatTypeName(inference_context_->FinalizeType(operand_type)), "' cannot be the operand of a select operation"))); return absl::nullopt; } void ResolveVisitor::ResolveSelectOperation(const Expr& expr, absl::string_view field, const Expr& operand) { const Type& operand_type = GetDeducedType(&operand); absl::optional result_type; int64_t id = expr.id(); // Support short-hand optional chaining. 
if (operand_type.IsOptional()) { auto optional_type = operand_type.GetOptional(); Type held_type = optional_type.GetParameter(); result_type = CheckFieldType(id, held_type, field); if (result_type.has_value()) { result_type = OptionalType(arena_, *result_type); } } else { result_type = CheckFieldType(id, operand_type, field); } if (!result_type.has_value()) { types_[&expr] = ErrorType(); return; } if (expr.select_expr().test_only()) { types_[&expr] = BoolType(); } else { types_[&expr] = *result_type; } } void ResolveVisitor::HandleOptSelect(const Expr& expr) { if (expr.call_expr().function() != kOptionalSelect || expr.call_expr().args().size() != 2) { status_.Update( absl::InvalidArgumentError("Malformed optional select expression.")); return; } const Expr* operand = &expr.call_expr().args().at(0); const Expr* field = &expr.call_expr().args().at(1); if (!field->has_const_expr() || !field->const_expr().has_string_value()) { status_.Update( absl::InvalidArgumentError("Malformed optional select expression.")); return; } Type operand_type = GetDeducedType(operand); if (operand_type.IsOptional()) { operand_type = operand_type.GetOptional().GetParameter(); } absl::optional field_type = CheckFieldType( expr.id(), operand_type, field->const_expr().string_value()); if (!field_type.has_value()) { types_[&expr] = ErrorType(); return; } const FunctionDecl* select_decl = env_->LookupFunction(kOptionalSelect); types_[&expr] = OptionalType(arena_, field_type.value()); // Remove the type annotation for the field now that we've validated it as // a valid field access instead of a string literal. 
types_.erase(field); if (select_decl != nullptr) { functions_[&expr] = FunctionResolution{select_decl, /*.namespace_rewrite=*/false}; } } class ResolveRewriter : public AstRewriterBase { public: explicit ResolveRewriter(const ResolveVisitor& visitor, const TypeInferenceContext& inference_context, const CheckerOptions& options, Ast::ReferenceMap& references, Ast::TypeMap& types, ValidationResult::TypeMap& resolved_types) : visitor_(visitor), inference_context_(inference_context), reference_map_(references), type_map_(types), resolved_types_(resolved_types), options_(options) {} bool PostVisitRewrite(Expr& expr) override { bool rewritten = false; if (auto iter = visitor_.attributes().find(&expr); iter != visitor_.attributes().end()) { const VariableDecl* decl = iter->second.decl; auto& ast_ref = reference_map_[expr.id()]; std::string name = decl->name(); if (iter->second.requires_disambiguation && !absl::StartsWith(name, ".")) { name = absl::StrCat(".", name); } ast_ref.set_name(name); if (decl->has_value()) { ast_ref.set_value(decl->value()); } expr.mutable_ident_expr().set_name(std::move(name)); rewritten = true; } else if (auto iter = visitor_.functions().find(&expr); iter != visitor_.functions().end()) { const FunctionDecl* decl = iter->second.decl; const bool needs_rewrite = iter->second.namespace_rewrite; auto& ast_ref = reference_map_[expr.id()]; if (options_.enable_function_name_in_reference) { ast_ref.set_name(decl->name()); } for (const auto& overload : decl->overloads()) { ast_ref.mutable_overload_id().push_back(overload.id()); } expr.mutable_call_expr().set_function(decl->name()); if (needs_rewrite && expr.call_expr().has_target()) { expr.mutable_call_expr().set_target(nullptr); } rewritten = true; } else if (auto iter = visitor_.struct_types().find(&expr); iter != visitor_.struct_types().end()) { auto& ast_ref = reference_map_[expr.id()]; ast_ref.set_name(iter->second); if (expr.has_struct_expr() && options_.update_struct_type_names) { 
expr.mutable_struct_expr().set_name(iter->second); } rewritten = true; } if (auto iter = visitor_.types().find(&expr); iter != visitor_.types().end()) { auto flattened_type = FlattenType(inference_context_.FinalizeType(iter->second)); if (!flattened_type.ok()) { status_.Update(flattened_type.status()); return rewritten; } type_map_[expr.id()] = *std::move(flattened_type); resolved_types_[expr.id()] = iter->second; rewritten = true; } return rewritten; } const absl::Status& status() const { return status_; } private: absl::Status status_; const ResolveVisitor& visitor_; const TypeInferenceContext& inference_context_; Ast::ReferenceMap& reference_map_; Ast::TypeMap& type_map_; ValidationResult::TypeMap& resolved_types_; const CheckerOptions& options_; }; } // namespace absl::StatusOr TypeCheckerImpl::CheckImpl( std::unique_ptr ast, google::protobuf::Arena* arena) const { std::optional type_arena; if (arena == nullptr) { type_arena.emplace(); arena = &(*type_arena); } std::vector issues; CEL_ASSIGN_OR_RETURN(auto generator, NamespaceGenerator::Create(env_.container())); TypeInferenceContext type_inference_context( arena, options_.enable_legacy_null_assignment); ResolveVisitor visitor(std::move(generator), env_, *ast, type_inference_context, issues, arena); TraversalOptions opts; opts.use_comprehension_callbacks = true; bool error_limit_reached = false; auto traversal = AstTraversal::Create(ast->root_expr(), opts); for (int step = 0; step < options_.max_expression_node_count * 2; ++step) { bool has_next = traversal.Step(visitor); if (!visitor.status().ok()) { return visitor.status(); } if (visitor.error_count() > options_.max_error_issues) { error_limit_reached = true; break; } if (!has_next) { break; } } if (!traversal.IsDone() && !error_limit_reached) { return absl::InvalidArgumentError( absl::StrCat("Maximum expression node count exceeded: ", options_.max_expression_node_count)); } if (error_limit_reached) { issues.push_back(TypeCheckIssue::CreateError( {}, 
absl::StrCat("maximum number of ERROR issues exceeded: ", options_.max_error_issues))); } else if (env_.expected_type().has_value()) { visitor.AssertExpectedType(ast->root_expr(), *env_.expected_type()); } // If any issues are errors, return without an AST. for (const auto& issue : issues) { if (issue.severity() == Severity::kError) { return ValidationResult(std::move(issues)); } } // Apply updates as needed. // Happens in a second pass to simplify validating that pointers haven't // been invalidated by other updates. ValidationResult::TypeMap resolved_types; ResolveRewriter rewriter(visitor, type_inference_context, options_, ast->mutable_reference_map(), ast->mutable_type_map(), resolved_types); AstRewrite(ast->mutable_root_expr(), rewriter); CEL_RETURN_IF_ERROR(rewriter.status()); ast->set_is_checked(true); if (options_.use_json_field_names) { ast->mutable_source_info().mutable_extensions().push_back( cel::ExtensionSpec("json_name", std::make_unique(1, 1), {cel::ExtensionSpec::Component::kRuntime})); } auto result = ValidationResult(std::move(ast), std::move(issues)); if (!type_arena.has_value()) { // cel::Type values will expire after this function returns when the local // arena is destructed. Only set the resolved type map if we're using the // caller's arena. result.SetResolvedTypeMap(std::move(resolved_types)); } return result; } std::unique_ptr TypeCheckerImpl::ToBuilder() const { return std::make_unique(options_, env_); } } // namespace cel::checker_internal ================================================ FILE: checker/internal/type_checker_impl.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECKER_IMPL_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECKER_IMPL_H_ #include #include #include "absl/status/statusor.h" #include "checker/checker_options.h" #include "checker/internal/type_check_env.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" #include "google/protobuf/arena.h" namespace cel::checker_internal { // Implementation of the TypeChecker interface. // // See cel::TypeCheckerBuilder for constructing instances. class TypeCheckerImpl : public TypeChecker { public: explicit TypeCheckerImpl(TypeCheckEnv env, CheckerOptions options = {}) : env_(std::move(env)), options_(options) {} TypeCheckerImpl(const TypeCheckerImpl&) = delete; TypeCheckerImpl& operator=(const TypeCheckerImpl&) = delete; TypeCheckerImpl(TypeCheckerImpl&&) = delete; TypeCheckerImpl& operator=(TypeCheckerImpl&&) = delete; absl::StatusOr CheckImpl( std::unique_ptr ast, google::protobuf::Arena* arena) const override; std::unique_ptr ToBuilder() const override; private: TypeCheckEnv env_; google::protobuf::Arena type_arena_; CheckerOptions options_; }; } // namespace cel::checker_internal #endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_CHECKER_IMPL_H_ ================================================ FILE: checker/internal/type_checker_impl_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file 
except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "checker/internal/type_checker_impl.h" #include #include #include #include #include #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_set.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "checker/checker_options.h" #include "checker/internal/test_ast_helpers.h" #include "checker/internal/type_check_env.h" #include "checker/type_check_issue.h" #include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" #include "common/container.h" #include "common/decl.h" #include "common/expr.h" #include "common/source.h" #include "common/type.h" #include "common/type_introspector.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "testutil/baseline_tests.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/message.h" namespace cel { namespace checker_internal { namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::Reference; using ::cel::expr::conformance::proto3::TestAllTypes; using ::cel::internal::GetSharedTestingDescriptorPool; using ::testing::_; using ::testing::Contains; using ::testing::ElementsAre; using ::testing::Eq; 
using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::Not; using ::testing::Pair; using ::testing::Property; using ::testing::SizeIs; using AstType = cel::TypeSpec; using Severity = TypeCheckIssue::Severity; namespace testpb3 = ::cel::expr::conformance::proto3; namespace testpb2 = ::cel::expr::conformance::proto2; std::string SevString(Severity severity) { switch (severity) { case Severity::kDeprecated: return "Deprecated"; case Severity::kError: return "Error"; case Severity::kWarning: return "Warning"; case Severity::kInformation: return "Information"; } } } // namespace } // namespace checker_internal template void AbslStringify(Sink& sink, const TypeCheckIssue& issue) { absl::Format(&sink, "TypeCheckIssue(%s): %s", checker_internal::SevString(issue.severity()), issue.message()); } namespace checker_internal { namespace { google::protobuf::Arena* absl_nonnull TestTypeArena() { static absl::NoDestructor kArena; return &(*kArena); } FunctionDecl MakeIdentFunction() { auto decl = MakeFunctionDecl( "identity", MakeOverloadDecl("identity", TypeParamType("A"), TypeParamType("A"))); ABSL_CHECK_OK(decl.status()); return decl.value(); } MATCHER_P2(IsIssueWithSubstring, severity, substring, "") { const TypeCheckIssue& issue = arg; if (issue.severity() == severity && absl::StrContains(issue.message(), substring)) { return true; } *result_listener << "expected: " << SevString(severity) << " " << substring << "\nactual: " << SevString(issue.severity()) << " " << issue.message(); return false; } MATCHER_P(IsVariableReference, var_name, "") { const Reference& reference = arg; if (reference.name() == var_name) { return true; } *result_listener << "expected: " << var_name << "\nactual: " << reference.name(); return false; } MATCHER_P2(IsFunctionReference, fn_name, overloads, "") { const Reference& reference = arg; absl::flat_hash_set got_overload_set( reference.overload_id().begin(), reference.overload_id().end()); absl::flat_hash_set 
want_overload_set(overloads.begin(), overloads.end()); if (got_overload_set != want_overload_set) { *result_listener << "reference to " << fn_name << "\n" << "expected overload_ids: " << absl::StrJoin(want_overload_set, ",") << "\nactual: " << absl::StrJoin(got_overload_set, ","); } return got_overload_set == want_overload_set; } absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena, TypeCheckEnv& env) { Type list_of_a = ListType(arena, TypeParamType("A")); FunctionDecl add_op; add_op.set_name("_+_"); CEL_RETURN_IF_ERROR(add_op.AddOverload( MakeOverloadDecl("add_int_int", IntType(), IntType(), IntType()))); CEL_RETURN_IF_ERROR(add_op.AddOverload( MakeOverloadDecl("add_uint_uint", UintType(), UintType(), UintType()))); CEL_RETURN_IF_ERROR(add_op.AddOverload(MakeOverloadDecl( "add_double_double", DoubleType(), DoubleType(), DoubleType()))); CEL_RETURN_IF_ERROR(add_op.AddOverload( MakeOverloadDecl("add_list", list_of_a, list_of_a, list_of_a))); FunctionDecl not_op; not_op.set_name("!_"); CEL_RETURN_IF_ERROR(not_op.AddOverload( MakeOverloadDecl("logical_not", /*return_type=*/BoolType{}, BoolType{}))); FunctionDecl not_strictly_false; not_strictly_false.set_name("@not_strictly_false"); CEL_RETURN_IF_ERROR(not_strictly_false.AddOverload( MakeOverloadDecl("not_strictly_false", /*return_type=*/BoolType{}, DynType{}))); FunctionDecl mult_op; mult_op.set_name("_*_"); CEL_RETURN_IF_ERROR(mult_op.AddOverload( MakeOverloadDecl("mult_int_int", /*return_type=*/IntType(), IntType(), IntType()))); FunctionDecl or_op; or_op.set_name("_||_"); CEL_RETURN_IF_ERROR(or_op.AddOverload( MakeOverloadDecl("logical_or", /*return_type=*/BoolType{}, BoolType{}, BoolType{}))); FunctionDecl and_op; and_op.set_name("_&&_"); CEL_RETURN_IF_ERROR(and_op.AddOverload( MakeOverloadDecl("logical_and", /*return_type=*/BoolType{}, BoolType{}, BoolType{}))); FunctionDecl lt_op; lt_op.set_name("_<_"); CEL_RETURN_IF_ERROR(lt_op.AddOverload( MakeOverloadDecl("lt_int_int", 
/*return_type=*/BoolType{}, IntType(), IntType()))); FunctionDecl gt_op; gt_op.set_name("_>_"); CEL_RETURN_IF_ERROR(gt_op.AddOverload( MakeOverloadDecl("gt_int_int", /*return_type=*/BoolType{}, IntType(), IntType()))); FunctionDecl eq_op; eq_op.set_name("_==_"); CEL_RETURN_IF_ERROR(eq_op.AddOverload(MakeOverloadDecl( "equals", /*return_type=*/BoolType{}, TypeParamType("A"), TypeParamType("A")))); FunctionDecl ne_op; ne_op.set_name("_!=_"); CEL_RETURN_IF_ERROR(ne_op.AddOverload(MakeOverloadDecl( "not_equals", /*return_type=*/BoolType{}, TypeParamType("A"), TypeParamType("A")))); FunctionDecl ternary_op; ternary_op.set_name("_?_:_"); CEL_RETURN_IF_ERROR(ternary_op.AddOverload(MakeOverloadDecl( "conditional", /*return_type=*/ TypeParamType("A"), BoolType{}, TypeParamType("A"), TypeParamType("A")))); FunctionDecl index_op; index_op.set_name("_[_]"); CEL_RETURN_IF_ERROR(index_op.AddOverload(MakeOverloadDecl( "index", /*return_type=*/ TypeParamType("A"), ListType(arena, TypeParamType("A")), IntType()))); FunctionDecl to_int; to_int.set_name("int"); CEL_RETURN_IF_ERROR(to_int.AddOverload( MakeOverloadDecl("to_int", /*return_type=*/IntType(), DynType()))); FunctionDecl to_duration; to_duration.set_name("duration"); CEL_RETURN_IF_ERROR(to_duration.AddOverload( MakeOverloadDecl("to_duration", /*return_type=*/DurationType(), StringType()))); FunctionDecl to_timestamp; to_timestamp.set_name("timestamp"); CEL_RETURN_IF_ERROR(to_timestamp.AddOverload( MakeOverloadDecl("to_timestamp", /*return_type=*/TimestampType(), IntType()))); FunctionDecl to_dyn; to_dyn.set_name("dyn"); CEL_RETURN_IF_ERROR(to_dyn.AddOverload( MakeOverloadDecl("to_dyn", /*return_type=*/DynType(), TypeParamType("A")))); FunctionDecl to_type; to_type.set_name("type"); CEL_RETURN_IF_ERROR(to_type.AddOverload( MakeOverloadDecl("to_type", /*return_type=*/TypeType(arena, TypeParamType("A")), TypeParamType("A")))); env.InsertFunctionIfAbsent(std::move(not_op)); 
env.InsertFunctionIfAbsent(std::move(not_strictly_false)); env.InsertFunctionIfAbsent(std::move(add_op)); env.InsertFunctionIfAbsent(std::move(mult_op)); env.InsertFunctionIfAbsent(std::move(or_op)); env.InsertFunctionIfAbsent(std::move(and_op)); env.InsertFunctionIfAbsent(std::move(lt_op)); env.InsertFunctionIfAbsent(std::move(gt_op)); env.InsertFunctionIfAbsent(std::move(to_int)); env.InsertFunctionIfAbsent(std::move(eq_op)); env.InsertFunctionIfAbsent(std::move(ne_op)); env.InsertFunctionIfAbsent(std::move(ternary_op)); env.InsertFunctionIfAbsent(std::move(index_op)); env.InsertFunctionIfAbsent(std::move(to_dyn)); env.InsertFunctionIfAbsent(std::move(to_type)); env.InsertFunctionIfAbsent(std::move(to_duration)); env.InsertFunctionIfAbsent(std::move(to_timestamp)); return absl::OkStatus(); } TEST(TypeCheckerImplTest, SmokeTest) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("1 + 2")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); } TEST(TypeCheckerImplTest, SimpleIdentsResolved) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); env.InsertVariableIfAbsent(MakeVariableDecl("y", IntType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x + y")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); } TEST(TypeCheckerImplTest, ReportMissingIdentDecl) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); 
env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x + y")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); EXPECT_THAT(result.GetIssues(), ElementsAre(IsIssueWithSubstring(Severity::kError, "undeclared reference to 'y'"))); } TEST(TypeCheckerImplTest, ErrorLimitInclusive) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); CheckerOptions options; options.max_error_issues = 1; TypeCheckerImpl impl(std::move(env), options); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("1 + y")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); EXPECT_THAT(result.GetIssues(), ElementsAre(IsIssueWithSubstring(Severity::kError, "undeclared reference to 'y'"))); ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("x + y + z")); ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); EXPECT_THAT( result.GetIssues(), ElementsAre( IsIssueWithSubstring(Severity::kError, "undeclared reference to 'x'"), IsIssueWithSubstring(Severity::kError, "undeclared reference to 'y'"), IsIssueWithSubstring(Severity::kError, "maximum number of ERROR issues exceeded: 1"))); } MATCHER_P3(IsIssueWithLocation, line, column, message, "") { const TypeCheckIssue& issue = arg; if (issue.location().line == line && issue.location().column == column && absl::StrContains(issue.message(), message)) { return true; } return false; } TEST(TypeCheckerImplTest, LocationCalculation) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto source, NewSource("a ||\n" "b ||\n" " c ||\n" 
" d")); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(source->content().ToString())); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); EXPECT_THAT( result.GetIssues(), ElementsAre(IsIssueWithLocation(1, 0, "undeclared reference to 'a'"), IsIssueWithLocation(2, 0, "undeclared reference to 'b'"), IsIssueWithLocation(3, 1, "undeclared reference to 'c'"), IsIssueWithLocation(4, 1, "undeclared reference to 'd'"))) << absl::StrJoin(result.GetIssues(), "\n", [&](std::string* out, const TypeCheckIssue& issue) { absl::StrAppend(out, issue.ToDisplayString(*source)); }); } TEST(TypeCheckerImplTest, QualifiedIdentsResolved) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x.y", IntType())); env.InsertVariableIfAbsent(MakeVariableDecl("x.z", IntType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x.y + x.z")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); } TEST(TypeCheckerImplTest, ReportMissingQualifiedIdentDecl) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("y.x")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); EXPECT_THAT(result.GetIssues(), ElementsAre(IsIssueWithSubstring( Severity::kError, "undeclared reference to 'y.x'"))); } TEST(TypeCheckerImplTest, ResolveMostQualfiedIdent) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); 
env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); env.InsertVariableIfAbsent(MakeVariableDecl("x.y", MapType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x.y.z")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_THAT(checked_ast->reference_map(), Contains(Pair(_, IsVariableReference("x.y")))); } TEST(TypeCheckerImplTest, MemberFunctionCallResolved) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); env.InsertVariableIfAbsent(MakeVariableDecl("y", IntType())); FunctionDecl foo; foo.set_name("foo"); ASSERT_THAT(foo.AddOverload(MakeMemberOverloadDecl("int_foo_int", /*return_type=*/IntType(), IntType(), IntType())), IsOk()); env.InsertFunctionIfAbsent(std::move(foo)); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x.foo(y)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); } TEST(TypeCheckerImplTest, MemberFunctionCallNotDeclared) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); env.InsertVariableIfAbsent(MakeVariableDecl("y", IntType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x.foo(y)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); EXPECT_THAT(result.GetIssues(), ElementsAre(IsIssueWithSubstring( Severity::kError, "undeclared reference to 'foo'"))); } TEST(TypeCheckerImplTest, FunctionShapeMismatch) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); // foo(int, int) -> int ASSERT_OK_AND_ASSIGN( auto foo, MakeFunctionDecl("foo", MakeOverloadDecl("foo_int_int", IntType(), IntType(), IntType()))); env.InsertFunctionIfAbsent(foo); 
TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo(1, 2, 3)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); EXPECT_THAT(result.GetIssues(), ElementsAre(IsIssueWithSubstring( Severity::kError, "undeclared reference to 'foo'"))); } TEST(TypeCheckerImplTest, NamespaceFunctionCallResolved) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); // Variables env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); env.InsertVariableIfAbsent(MakeVariableDecl("y", IntType())); // add x.foo as a namespaced function. FunctionDecl foo; foo.set_name("x.foo"); ASSERT_THAT( foo.AddOverload(MakeOverloadDecl("x_foo_int", /*return_type=*/IntType(), IntType())), IsOk()); env.InsertFunctionIfAbsent(std::move(foo)); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x.foo(y)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_TRUE(checked_ast->root_expr().has_call_expr()) << absl::StrCat("kind: ", checked_ast->root_expr().kind().index()); EXPECT_EQ(checked_ast->root_expr().call_expr().function(), "x.foo"); EXPECT_FALSE(checked_ast->root_expr().call_expr().has_target()); } TEST(TypeCheckerImplTest, NamespacedFunctionSkipsFieldCheck) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); // Variables env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); // add x.foo as a namespaced function. 
FunctionDecl foo; foo.set_name("x.y.foo"); ASSERT_THAT( foo.AddOverload(MakeOverloadDecl("x_y_foo_int", /*return_type=*/IntType(), IntType())), IsOk()); env.InsertFunctionIfAbsent(std::move(foo)); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x.y.foo(x)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_TRUE(checked_ast->root_expr().has_call_expr()) << absl::StrCat("kind: ", checked_ast->root_expr().kind().index()); EXPECT_EQ(checked_ast->root_expr().call_expr().function(), "x.y.foo"); EXPECT_FALSE(checked_ast->root_expr().call_expr().has_target()); } TEST(TypeCheckerImplTest, NamespacedFunctionWithAbbreviation) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); // Variables env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); FunctionDecl foo; foo.set_name("x.y.foo"); ASSERT_THAT( foo.AddOverload(MakeOverloadDecl("x_y_foo_int", /*return_type=*/IntType(), IntType())), IsOk()); env.InsertFunctionIfAbsent(std::move(foo)); env.set_container(*MakeExpressionContainer("", "x.y.foo")); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo(x)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_TRUE(checked_ast->root_expr().has_call_expr()) << absl::StrCat("kind: ", checked_ast->root_expr().kind().index()); EXPECT_EQ(checked_ast->root_expr().call_expr().function(), "x.y.foo"); EXPECT_FALSE(checked_ast->root_expr().call_expr().has_target()); } TEST(TypeCheckerImplTest, MixedListTypeToDyn) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl 
impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("[1, 'a']")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); EXPECT_TRUE( result.GetAst()->type_map().at(1).list_type().elem_type().has_dyn()); } TEST(TypeCheckerImplTest, FreeListTypeToDyn) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("[]")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); EXPECT_TRUE( result.GetAst()->type_map().at(1).list_type().elem_type().has_dyn()); } TEST(TypeCheckerImplTest, FreeMapValueTypeToDyn) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{}.field")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); auto root_id = result.GetAst()->root_expr().id(); EXPECT_TRUE(result.GetAst()->type_map().at(root_id).has_dyn()); } TEST(TypeCheckerImplTest, FreeMapTypeToDyn) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{}")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_TRUE(checked_ast->type_map().at(1).map_type().key_type().has_dyn()); 
EXPECT_TRUE(checked_ast->type_map().at(1).map_type().value_type().has_dyn()); } TEST(TypeCheckerImplTest, MapTypeWithMixedKeys) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{'a': 1, 2: 3}")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); const auto* checked_ast = result.GetAst(); EXPECT_TRUE(checked_ast->type_map().at(1).map_type().key_type().has_dyn()); EXPECT_EQ(checked_ast->type_map().at(1).map_type().value_type().primitive(), PrimitiveType::kInt64); } TEST(TypeCheckerImplTest, MapTypeUnsupportedKeyWarns) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{{}: 'a'}")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), ElementsAre(IsIssueWithSubstring(Severity::kWarning, "unsupported map key type:"))); } TEST(TypeCheckerImplTest, MapTypeWithMixedValues) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{'a': 1, 'b': '2'}")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_EQ(checked_ast->type_map().at(1).map_type().key_type().primitive(), PrimitiveType::kString); EXPECT_TRUE(checked_ast->type_map().at(1).map_type().value_type().has_dyn()); } TEST(TypeCheckerImplTest, 
ComprehensionVariablesResolved) {
  TypeCheckEnv env(GetSharedTestingDescriptorPool());
  google::protobuf::Arena arena;
  ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
  TypeCheckerImpl impl(std::move(env));
  ASSERT_OK_AND_ASSIGN(auto ast,
                       MakeTestParsedAst("[1, 2, 3].exists(x, x * x > 10)"));
  ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));
  EXPECT_TRUE(result.IsValid());
  EXPECT_THAT(result.GetIssues(), IsEmpty());
}

// A comprehension over a map literal declares its iteration variable for use
// inside the predicate.
TEST(TypeCheckerImplTest, MapComprehensionVariablesResolved) {
  TypeCheckEnv env(GetSharedTestingDescriptorPool());
  google::protobuf::Arena arena;
  ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
  TypeCheckerImpl checker(std::move(env));

  ASSERT_OK_AND_ASSIGN(auto parsed,
                       MakeTestParsedAst("{1: 3, 2: 4}.exists(x, x == 2)"));
  ASSERT_OK_AND_ASSIGN(ValidationResult validation,
                       checker.Check(std::move(parsed)));
  EXPECT_TRUE(validation.IsValid());
  EXPECT_THAT(validation.GetIssues(), IsEmpty());
}

// The inner comprehension can see both its own variable (y) and the outer
// one (x).
TEST(TypeCheckerImplTest, NestedComprehensions) {
  TypeCheckEnv env(GetSharedTestingDescriptorPool());
  google::protobuf::Arena arena;
  ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
  TypeCheckerImpl checker(std::move(env));

  ASSERT_OK_AND_ASSIGN(
      auto parsed,
      MakeTestParsedAst("[1, 2].all(x, ['1', '2'].exists(y, int(y) == x))"));
  ASSERT_OK_AND_ASSIGN(ValidationResult validation,
                       checker.Check(std::move(parsed)));
  EXPECT_TRUE(validation.IsValid());
  EXPECT_THAT(validation.GetIssues(), IsEmpty());
}

// Inside container "com", a bare `x` could also name the declared variable
// `com.x`. The comprehension variable `x` must take priority, so no reference
// to `com.x` is recorded.
TEST(TypeCheckerImplTest, ComprehensionVarsShadowNamespacePriorityRules) {
  TypeCheckEnv env(GetSharedTestingDescriptorPool());
  env.set_container(*MakeExpressionContainer("com"));
  google::protobuf::Arena arena;
  ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
  env.InsertVariableIfAbsent(MakeVariableDecl("com.x", IntType()));
  TypeCheckerImpl checker(std::move(env));

  ASSERT_OK_AND_ASSIGN(auto parsed,
                       MakeTestParsedAst("['1', '2'].exists(x, x == '2')"));
  ASSERT_OK_AND_ASSIGN(ValidationResult validation,
                       checker.Check(std::move(parsed)));
  EXPECT_TRUE(validation.IsValid());
  EXPECT_THAT(validation.GetIssues(), IsEmpty());

  ASSERT_OK_AND_ASSIGN(auto checked, validation.ReleaseAst());
  EXPECT_THAT(checked->reference_map(),
              Not(Contains(Pair(_, IsVariableReference("com.x")))));
}

// `x.y` is declared at the top level, but inside the comprehension `x` is the
// iteration variable. `x.y` is therefore a field selection on it, not a
// reference to the declared `x.y`.
TEST(TypeCheckerImplTest, ComprehensionVarsShadowsQualifiedIdent) {
  TypeCheckEnv env(GetSharedTestingDescriptorPool());
  google::protobuf::Arena arena;
  ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
  env.InsertVariableIfAbsent(MakeVariableDecl("x.y", IntType()));
  TypeCheckerImpl checker(std::move(env));

  ASSERT_OK_AND_ASSIGN(auto parsed,
                       MakeTestParsedAst("[{'y': '2'}].all(x, x.y == '2')"));
  ASSERT_OK_AND_ASSIGN(ValidationResult validation,
                       checker.Check(std::move(parsed)));
  EXPECT_TRUE(validation.IsValid());
  EXPECT_THAT(validation.GetIssues(), IsEmpty());

  ASSERT_OK_AND_ASSIGN(auto checked, validation.ReleaseAst());
  EXPECT_THAT(checked->reference_map(),
              Not(Contains(Pair(_, IsVariableReference("x.y")))));
}

// The comprehension variable `x` shadows the declared `x.y`. Here `x` is an
// int, so selecting `.y` from it is a type error.
TEST(TypeCheckerImplTest, ComprehensionVarsShadowsQualifiedIdentTypeError) {
  TypeCheckEnv env(GetSharedTestingDescriptorPool());
  google::protobuf::Arena arena;
  ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
  env.InsertVariableIfAbsent(MakeVariableDecl("x.y", IntType()));
  TypeCheckerImpl checker(std::move(env));

  ASSERT_OK_AND_ASSIGN(auto parsed, MakeTestParsedAst("[0].all(x, x.y == 0)"));
  ASSERT_OK_AND_ASSIGN(ValidationResult validation,
                       checker.Check(std::move(parsed)));
  EXPECT_FALSE(validation.IsValid());
  EXPECT_THAT(
      validation.FormatError(),
      HasSubstr("type 'int' cannot be the operand of a select operation"));
}

// A leading '.' bypasses the comprehension variable and refers to the declared
// `x.y`.
TEST(TypeCheckerImplTest, ComprehensionVarsDisamgiguatesQualifiedIdent) {
  TypeCheckEnv env(GetSharedTestingDescriptorPool());
  google::protobuf::Arena arena;
  ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
  env.InsertVariableIfAbsent(MakeVariableDecl("x.y", IntType()));
  TypeCheckerImpl checker(std::move(env));

  ASSERT_OK_AND_ASSIGN(auto parsed,
                       MakeTestParsedAst("[{'y': 0}].all(x, .x.y == 2)"));
  ASSERT_OK_AND_ASSIGN(ValidationResult validation,
                       checker.Check(std::move(parsed)));
  EXPECT_TRUE(validation.IsValid());
  EXPECT_THAT(validation.GetIssues(), IsEmpty());

  ASSERT_OK_AND_ASSIGN(auto checked, validation.ReleaseAst());
  EXPECT_THAT(checked->reference_map(),
              Contains(Pair(_, IsVariableReference(".x.y"))));
}

// `.x.y` names the declared string variable. `x.y` selects an int from the
// comprehension's map element. Comparing the two therefore has no matching
// overload.
TEST(TypeCheckerImplTest, ComprehensionVarsDisamgiguatesQualifiedIdentMixed) {
  TypeCheckEnv env(GetSharedTestingDescriptorPool());
  google::protobuf::Arena arena;
  ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
  env.InsertVariableIfAbsent(MakeVariableDecl("x.y", StringType()));
  TypeCheckerImpl checker(std::move(env));

  ASSERT_OK_AND_ASSIGN(auto parsed,
                       MakeTestParsedAst("[{'y': 0}].all(x, .x.y != x.y)"));
  ASSERT_OK_AND_ASSIGN(ValidationResult validation,
                       checker.Check(std::move(parsed)));
  EXPECT_FALSE(validation.IsValid());
  EXPECT_THAT(
      validation.FormatError(),
      HasSubstr("no matching overload for '_!=_' applied to '(string, int)'"));
}

// A leading '.' on a simple identifier refers to the declared `x`, not to the
// comprehension variable.
TEST(TypeCheckerImplTest, ComprehensionVarsDisamgiguatesIdent) {
  TypeCheckEnv env(GetSharedTestingDescriptorPool());
  google::protobuf::Arena arena;
  ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
  env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType()));
  TypeCheckerImpl checker(std::move(env));

  ASSERT_OK_AND_ASSIGN(auto parsed,
                       MakeTestParsedAst("['foo'].all(x, .x == 2)"));
  ASSERT_OK_AND_ASSIGN(ValidationResult validation,
                       checker.Check(std::move(parsed)));
  EXPECT_TRUE(validation.IsValid());
  EXPECT_THAT(validation.GetIssues(), IsEmpty());

  ASSERT_OK_AND_ASSIGN(auto checked, validation.ReleaseAst());
  EXPECT_THAT(checked->reference_map(),
              Contains(Pair(_, IsVariableReference(".x"))));
}

TEST(TypeCheckerImplTest, ComprehensionVarsCyclicParamAssignability) {
  TypeCheckEnv env(GetSharedTestingDescriptorPool());
  google::protobuf::Arena arena;
  ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());
  TypeCheckerImpl impl(std::move(env));

  // This is valid because the list construction in the transform will resolve
ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("[].map(c, [ c, [c] ])")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); // Remainder are conceptually the same, but confirm generality. ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ c, [[c]] ])")); ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ [c], [[c]] ])")); ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ c, c ])")); ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ [c], c ])")); ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ [[c]], c ])")); ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("[].map(c, [ c, type(c) ])")); ASSERT_OK_AND_ASSIGN(result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); } struct PrimitiveLiteralsTestCase { std::string expr; PrimitiveType expected_type; }; class PrimitiveLiteralsTest : public testing::TestWithParam {}; TEST_P(PrimitiveLiteralsTest, LiteralsTypeInferred) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); const PrimitiveLiteralsTestCase& test_case = GetParam(); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_EQ(checked_ast->mutable_type_map()[1].primitive(), test_case.expected_type); } INSTANTIATE_TEST_SUITE_P(PrimitiveLiteralsTests, 
PrimitiveLiteralsTest, ::testing::Values( PrimitiveLiteralsTestCase{ .expr = "1", .expected_type = PrimitiveType::kInt64, }, PrimitiveLiteralsTestCase{ .expr = "1.0", .expected_type = PrimitiveType::kDouble, }, PrimitiveLiteralsTestCase{ .expr = "1u", .expected_type = PrimitiveType::kUint64, }, PrimitiveLiteralsTestCase{ .expr = "'string'", .expected_type = PrimitiveType::kString, }, PrimitiveLiteralsTestCase{ .expr = "b'bytes'", .expected_type = PrimitiveType::kBytes, }, PrimitiveLiteralsTestCase{ .expr = "false", .expected_type = PrimitiveType::kBool, })); struct AstTypeConversionTestCase { Type decl_type; TypeSpec expected_type; }; class AstTypeConversionTest : public testing::TestWithParam {}; TEST_P(AstTypeConversionTest, TypeConversion) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); ASSERT_TRUE( env.InsertVariableIfAbsent(MakeVariableDecl("x", GetParam().decl_type))); const AstTypeConversionTestCase& test_case = GetParam(); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_EQ(checked_ast->mutable_type_map()[1], test_case.expected_type) << GetParam().decl_type.DebugString(); } INSTANTIATE_TEST_SUITE_P( Primitives, AstTypeConversionTest, ::testing::Values( AstTypeConversionTestCase{ .decl_type = NullType(), .expected_type = AstType(NullTypeSpec()), }, AstTypeConversionTestCase{ .decl_type = DynType(), .expected_type = AstType(DynTypeSpec()), }, AstTypeConversionTestCase{ .decl_type = BoolType(), .expected_type = AstType(PrimitiveType::kBool), }, AstTypeConversionTestCase{ .decl_type = IntType(), .expected_type = AstType(PrimitiveType::kInt64), }, AstTypeConversionTestCase{ .decl_type = UintType(), .expected_type = AstType(PrimitiveType::kUint64), }, AstTypeConversionTestCase{ .decl_type = DoubleType(), .expected_type = 
AstType(PrimitiveType::kDouble), }, AstTypeConversionTestCase{ .decl_type = StringType(), .expected_type = AstType(PrimitiveType::kString), }, AstTypeConversionTestCase{ .decl_type = BytesType(), .expected_type = AstType(PrimitiveType::kBytes), }, AstTypeConversionTestCase{ .decl_type = TimestampType(), .expected_type = AstType(WellKnownTypeSpec::kTimestamp), }, AstTypeConversionTestCase{ .decl_type = DurationType(), .expected_type = AstType(WellKnownTypeSpec::kDuration), })); INSTANTIATE_TEST_SUITE_P( Wrappers, AstTypeConversionTest, ::testing::Values( AstTypeConversionTestCase{ .decl_type = IntWrapperType(), .expected_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)), }, AstTypeConversionTestCase{ .decl_type = UintWrapperType(), .expected_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kUint64)), }, AstTypeConversionTestCase{ .decl_type = DoubleWrapperType(), .expected_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kDouble)), }, AstTypeConversionTestCase{ .decl_type = BoolWrapperType(), .expected_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kBool)), }, AstTypeConversionTestCase{ .decl_type = StringWrapperType(), .expected_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kString)), }, AstTypeConversionTestCase{ .decl_type = BytesWrapperType(), .expected_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kBytes)), })); INSTANTIATE_TEST_SUITE_P( ComplexTypes, AstTypeConversionTest, ::testing::Values( AstTypeConversionTestCase{ .decl_type = ListType(TestTypeArena(), IntType()), .expected_type = AstType( ListTypeSpec(std::make_unique(PrimitiveType::kInt64))), }, AstTypeConversionTestCase{ .decl_type = MapType(TestTypeArena(), IntType(), IntType()), .expected_type = AstType( MapTypeSpec(std::make_unique(PrimitiveType::kInt64), std::make_unique(PrimitiveType::kInt64))), }, AstTypeConversionTestCase{ .decl_type = TypeType(TestTypeArena(), IntType()), .expected_type = AstType(std::make_unique(PrimitiveType::kInt64)), }, AstTypeConversionTestCase{ 
.decl_type = OpaqueType(TestTypeArena(), "tuple", {IntType(), IntType()}), .expected_type = AstType( AbstractType("tuple", {AstType(PrimitiveType::kInt64), AstType(PrimitiveType::kInt64)})), }, AstTypeConversionTestCase{ .decl_type = StructType(MessageType(TestAllTypes::descriptor())), .expected_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes"))})); TEST(TypeCheckerImplTest, NullLiteral) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("null")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_TRUE(checked_ast->mutable_type_map()[1].has_null()); } TEST(TypeCheckerImplTest, ExpressionLimitInclusive) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); CheckerOptions options; options.max_expression_node_count = 2; TypeCheckerImpl impl(std::move(env), options); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{}.foo")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("{}.foo.bar")); EXPECT_THAT(impl.Check(std::move(ast)), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("expression node count exceeded: 2"))); } TEST(TypeCheckerImplTest, ComprehensionUnsupportedRange) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("y", IntType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("'abc'.all(x, y == 2)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); EXPECT_THAT(result.GetIssues(), Contains(IsIssueWithSubstring( Severity::kError, "expression of type 'string' cannot be " "the range of a 
comprehension"))); } TEST(TypeCheckerImplTest, ComprehensionDynRange) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("range", DynType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("range.all(x, x == 2)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); } TEST(TypeCheckerImplTest, BasicOvlResolution) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", DoubleType())); env.InsertVariableIfAbsent(MakeVariableDecl("y", DoubleType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x + y")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); // Assumes parser numbering: + should always be id 2. ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_THAT(checked_ast->mutable_reference_map()[2], IsFunctionReference( "_+_", std::vector{"add_double_double"})); } TEST(TypeCheckerImplTest, OvlResolutionMultipleOverloads) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", DoubleType())); env.InsertVariableIfAbsent(MakeVariableDecl("y", DoubleType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("dyn(x) + dyn(y)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); // Assumes parser numbering: + should always be id 3. 
ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_THAT(checked_ast->mutable_reference_map()[3], IsFunctionReference("_+_", std::vector{ "add_double_double", "add_int_int", "add_list", "add_uint_uint"})); } TEST(TypeCheckerImplTest, BasicFunctionResultTypeResolution) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", DoubleType())); env.InsertVariableIfAbsent(MakeVariableDecl("y", DoubleType())); env.InsertVariableIfAbsent(MakeVariableDecl("z", DoubleType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x + y + z")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); // Assumes parser numbering: + should always be id 2 and 4. ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_THAT(checked_ast->mutable_reference_map()[2], IsFunctionReference( "_+_", std::vector{"add_double_double"})); EXPECT_THAT(checked_ast->mutable_reference_map()[4], IsFunctionReference( "_+_", std::vector{"add_double_double"})); int64_t root_id = checked_ast->root_expr().id(); EXPECT_EQ(checked_ast->mutable_type_map()[root_id].primitive(), PrimitiveType::kDouble); } TEST(TypeCheckerImplTest, BasicOvlResolutionNoMatch) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); env.InsertVariableIfAbsent(MakeVariableDecl("y", StringType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x + y")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); EXPECT_THAT(result.GetIssues(), Contains(IsIssueWithSubstring(Severity::kError, "no matching 
overload for '_+_'" " applied to '(int, string)'"))); } TEST(TypeCheckerImplTest, ParmeterizedOvlResolutionMatch) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); env.InsertVariableIfAbsent(MakeVariableDecl("y", StringType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("([x] + []) == [x]")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); } TEST(TypeCheckerImplTest, AliasedTypeVarSameType) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("[].exists(x, x == 10 || x == '10')")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); EXPECT_THAT( result.GetIssues(), ElementsAre(IsIssueWithSubstring( Severity::kError, "no matching overload for '_==_' applied to"))); } TEST(TypeCheckerImplTest, TypeVarRange) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertFunctionIfAbsent(MakeIdentFunction()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("identity([]).exists(x, x == 10 )")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_TRUE(result.IsValid()) << absl::StrJoin(result.GetIssues(), "\n"); } TEST(TypeCheckerImplTest, WellKnownTypeCreation) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); env.AddTypeProvider(std::make_unique()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN( auto ast, MakeTestParsedAst("google.protobuf.Int32Value{value: 10}")); ASSERT_OK_AND_ASSIGN(ValidationResult result, 
impl.Check(std::move(ast)));
  ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst());

  EXPECT_THAT(
      checked_ast->type_map(),
      Contains(Pair(checked_ast->root_expr().id(),
                    Eq(AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64))))));
  EXPECT_THAT(
      checked_ast->reference_map(),
      Contains(Pair(checked_ast->root_expr().id(),
                    Property(&Reference::name, "google.protobuf.Int32Value"))));
}

// Creating google.protobuf.Struct with an empty map literal for `fields` is
// expected to type that map literal as map(string, dyn).
TEST(TypeCheckerImplTest, TypeInferredFromStructCreation) {
  TypeCheckEnv env(GetSharedTestingDescriptorPool());
  env.AddTypeProvider(std::make_unique());
  TypeCheckerImpl impl(std::move(env));
  ASSERT_OK_AND_ASSIGN(auto ast,
                       MakeTestParsedAst("google.protobuf.Struct{fields: {}}"));
  ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));
  ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst());
  // Id of the map literal supplied as the value of the first struct field.
  int64_t map_expr_id =
      checked_ast->root_expr().struct_expr().fields().at(0).value().id();
  ASSERT_NE(map_expr_id, 0);
  EXPECT_THAT(
      checked_ast->type_map(),
      Contains(Pair(map_expr_id,
                    Eq(AstType(MapTypeSpec(
                        std::make_unique(PrimitiveType::kString),
                        std::make_unique(DynTypeSpec())))))));
}

// With an expected type of map(string, string), the otherwise unconstrained
// empty map literal is typed as map(string, string) at the root.
TEST(TypeCheckerImplTest, ExpectedTypeMatches) {
  google::protobuf::Arena arena;
  TypeCheckEnv env(GetSharedTestingDescriptorPool());
  env.set_expected_type(MapType(&arena, StringType(), StringType()));

  TypeCheckerImpl impl(std::move(env));
  ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{}"));
  ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst());

  EXPECT_THAT(
      checked_ast->type_map(),
      Contains(Pair(checked_ast->root_expr().id(),
                    Eq(AstType(MapTypeSpec(
                        std::make_unique(PrimitiveType::kString),
                        std::make_unique(PrimitiveType::kString)))))));
}

// A map(string, int) result conflicts with the expected map(string, string)
// and is reported as an error.
TEST(TypeCheckerImplTest, ExpectedTypeDoesntMatch) {
  google::protobuf::Arena arena;
  TypeCheckEnv env(GetSharedTestingDescriptorPool());
  env.set_expected_type(MapType(&arena, StringType(), StringType()));

  TypeCheckerImpl impl(std::move(env));
  ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{'abc': 123}"));
  ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));

  EXPECT_FALSE(result.IsValid());
  EXPECT_THAT(
      result.GetIssues(),
      Contains(IsIssueWithSubstring(
          Severity::kError,
          "expected type 'map(string, string)' but found 'map(string, int)'")));
}

// A builder obtained from an existing checker can add declarations. The
// checker it builds sees them.
TEST(TypeCheckerImplTest, ToBuilder) {
  TypeCheckEnv env(GetSharedTestingDescriptorPool());
  TypeCheckerImpl impl(std::move(env));
  auto builder = impl.ToBuilder();
  ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk());
  ASSERT_OK_AND_ASSIGN(auto new_checker, builder->Build());

  ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x"));
  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       new_checker->Check(std::move(ast)));
  EXPECT_TRUE(result.IsValid());
}

// The list type of `my_list` is allocated on a shared arena handed to the env.
// The original checker and the test's own arena handle are both released
// before Build(). The derived checker must still check "my_list", which
// presumably relies on the builder sharing ownership of the arena. The order
// of the reset() calls relative to Build() is what this test exercises.
TEST(TypeCheckerImplTest, ToBuilderPropagatesArena) {
  auto arena = std::make_shared();
  TypeCheckEnv env(GetSharedTestingDescriptorPool());
  env.set_arena(arena);
  Type list_type = ListType(arena.get(), IntType());
  ASSERT_TRUE(
      env.InsertVariableIfAbsent(MakeVariableDecl("my_list", list_type)));
  auto base_checker = std::make_unique(std::move(env));
  std::unique_ptr builder = base_checker->ToBuilder();
  base_checker.reset();
  arena.reset();

  ASSERT_OK_AND_ASSIGN(auto derived_checker, builder->Build());
  ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("my_list"));
  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       derived_checker->Check(std::move(ast)));
  EXPECT_TRUE(result.IsValid());
}

// The recorded source position of expression id 1 is overwritten with a
// negative offset. The issue is then expected to render with a -1:-1 location
// instead of a real line and column.
TEST(TypeCheckerImplTest, BadSourcePosition) {
  TypeCheckEnv env(GetSharedTestingDescriptorPool());
  TypeCheckerImpl impl(std::move(env));
  ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo"));
  ast->mutable_source_info().mutable_positions()[1] = -42;
  ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));
  ASSERT_OK_AND_ASSIGN(auto source, NewSource("foo"));

  EXPECT_FALSE(result.IsValid());
  ASSERT_THAT(result.GetIssues(), SizeIs(1));

  EXPECT_EQ(
      result.GetIssues()[0].ToDisplayString(*source),
"ERROR: :-1:-1: undeclared reference to 'foo' (in container '')"); } // Check that the TypeChecker will fail if no type is deduced for a // subexpression. This is meant to be a guard against failing to account for new // types of expressions in the type checker logic. TEST(TypeCheckerImplTest, FailsIfNoTypeDeduced) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("a", BoolType())); env.InsertVariableIfAbsent(MakeVariableDecl("b", BoolType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("a || b")); // Assume that an unspecified expr kind is not deducible. Expr unspecified_expr; unspecified_expr.set_id(3); ast->mutable_root_expr().mutable_call_expr().mutable_args()[1] = std::move(unspecified_expr); ASSERT_THAT(impl.Check(std::move(ast)), StatusIs(absl::StatusCode::kInvalidArgument, "Could not deduce type for expression id: 3")); } TEST(TypeCheckerImplTest, BadLineOffsets) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto source, NewSource("\nfoo")); { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("\nfoo")); ast->mutable_source_info().mutable_line_offsets()[1] = 1; ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); ASSERT_THAT(result.GetIssues(), SizeIs(1)); EXPECT_EQ(result.GetIssues()[0].ToDisplayString(*source), "ERROR: :-1:-1: undeclared reference to 'foo' (in " "container '')"); } { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("\nfoo")); ast->mutable_source_info().mutable_line_offsets().clear(); ast->mutable_source_info().mutable_line_offsets().push_back(-1); ast->mutable_source_info().mutable_line_offsets().push_back(2); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); 
ASSERT_THAT(result.GetIssues(), SizeIs(1)); EXPECT_EQ(result.GetIssues()[0].ToDisplayString(*source), "ERROR: :-1:-1: undeclared reference to 'foo' (in " "container '')"); } } TEST(TypeCheckerImplTest, ContainerLookupForMessageCreation) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); env.set_container(*MakeExpressionContainer("google.protobuf")); env.AddTypeProvider(std::make_unique()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("Int32Value{value: 10}")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); EXPECT_THAT( checked_ast->type_map(), Contains(Pair(checked_ast->root_expr().id(), Eq(AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)))))); EXPECT_THAT( checked_ast->reference_map(), Contains(Pair(checked_ast->root_expr().id(), Property(&Reference::name, "google.protobuf.Int32Value")))); } TEST(TypeCheckerImplTest, ContainerLookupForMessageCreationNoRewrite) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); env.set_container(*MakeExpressionContainer("google.protobuf")); env.AddTypeProvider(std::make_unique()); CheckerOptions options; options.update_struct_type_names = false; TypeCheckerImpl impl(std::move(env), options); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("Int32Value{value: 10}")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); EXPECT_THAT( checked_ast->type_map(), Contains(Pair(checked_ast->root_expr().id(), Eq(AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)))))); EXPECT_THAT( checked_ast->reference_map(), Contains(Pair(checked_ast->root_expr().id(), Property(&Reference::name, "google.protobuf.Int32Value")))); EXPECT_THAT(checked_ast->root_expr().struct_expr(), Property(&StructExpr::name, "Int32Value")); } TEST(TypeCheckerImplTest, EnumValueCopiedToReferenceMap) { TypeCheckEnv 
env(GetSharedTestingDescriptorPool()); env.set_container(*MakeExpressionContainer("cel.expr.conformance.proto3")); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("TestAllTypes.NestedEnum.BAZ")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); auto ref_iter = checked_ast->reference_map().find(checked_ast->root_expr().id()); ASSERT_NE(ref_iter, checked_ast->reference_map().end()); EXPECT_EQ(ref_iter->second.name(), "cel.expr.conformance.proto3.TestAllTypes.NestedEnum.BAZ"); EXPECT_EQ(ref_iter->second.value().int_value(), 2); } struct CheckedExprTestCase { std::string expr; TypeSpec expected_result_type; std::string error_substring; }; class WktCreationTest : public testing::TestWithParam {}; TEST_P(WktCreationTest, MessageCreation) { google::protobuf::Arena arena; const CheckedExprTestCase& test_case = GetParam(); TypeCheckEnv env(GetSharedTestingDescriptorPool()); env.AddTypeProvider(std::make_unique()); env.set_container(*MakeExpressionContainer("google.protobuf")); ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); if (!test_case.error_substring.empty()) { EXPECT_THAT(result.GetIssues(), Contains(IsIssueWithSubstring(Severity::kError, test_case.error_substring))); return; } ASSERT_TRUE(result.IsValid()) << absl::StrJoin(result.GetIssues(), "\n", [](std::string* out, const TypeCheckIssue& issue) { absl::StrAppend(out, issue.message()); }); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); EXPECT_THAT(checked_ast->type_map(), Contains(Pair(checked_ast->root_expr().id(), Eq(test_case.expected_result_type)))); } INSTANTIATE_TEST_SUITE_P( WellKnownTypes, WktCreationTest, ::testing::Values( CheckedExprTestCase{ .expr = 
"google.protobuf.Int32Value{value: 10}", .expected_result_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)), }, CheckedExprTestCase{ .expr = ".google.protobuf.Int32Value{value: 10}", .expected_result_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)), }, CheckedExprTestCase{ .expr = "Int32Value{value: 10}", .expected_result_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)), }, CheckedExprTestCase{ .expr = "google.protobuf.Int32Value{value: '10'}", .expected_result_type = AstType(), .error_substring = "expected type of field 'value' is 'int' but " "provided type is 'string'"}, CheckedExprTestCase{ .expr = "google.protobuf.Int32Value{not_a_field: '10'}", .expected_result_type = AstType(), .error_substring = "undefined field 'not_a_field' not found in " "struct 'google.protobuf.Int32Value'"}, CheckedExprTestCase{ .expr = "NotAType{not_a_field: '10'}", .expected_result_type = AstType(), .error_substring = "undeclared reference to 'NotAType' (in container " "'google.protobuf')"}, CheckedExprTestCase{ .expr = ".protobuf.Int32Value{value: 10}", .expected_result_type = AstType(), .error_substring = "undeclared reference to '.protobuf.Int32Value' (in container " "'google.protobuf')"}, CheckedExprTestCase{ .expr = "Int32Value{value: 10}.value", .expected_result_type = AstType(), .error_substring = "expression of type 'wrapper(int)' cannot be the " "operand of a select operation"}, CheckedExprTestCase{ .expr = "Int64Value{value: 10}", .expected_result_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)), }, CheckedExprTestCase{ .expr = "BoolValue{value: true}", .expected_result_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kBool)), }, CheckedExprTestCase{ .expr = "UInt64Value{value: 10u}", .expected_result_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kUint64)), }, CheckedExprTestCase{ .expr = "UInt32Value{value: 10u}", .expected_result_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kUint64)), }, CheckedExprTestCase{ .expr = 
"FloatValue{value: 1.25}", .expected_result_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kDouble)), }, CheckedExprTestCase{ .expr = "DoubleValue{value: 1.25}", .expected_result_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kDouble)), }, CheckedExprTestCase{ .expr = "StringValue{value: 'test'}", .expected_result_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kString)), }, CheckedExprTestCase{ .expr = "BytesValue{value: b'test'}", .expected_result_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kBytes)), }, CheckedExprTestCase{ .expr = "Duration{seconds: 10, nanos: 11}", .expected_result_type = AstType(WellKnownTypeSpec::kDuration), }, CheckedExprTestCase{ .expr = "Timestamp{seconds: 10, nanos: 11}", .expected_result_type = AstType(WellKnownTypeSpec::kTimestamp), }, CheckedExprTestCase{ .expr = "Struct{fields: {'key': 'value'}}", .expected_result_type = AstType( MapTypeSpec(std::make_unique(PrimitiveType::kString), std::make_unique(DynTypeSpec()))), }, CheckedExprTestCase{ .expr = "ListValue{values: [1, 2, 3]}", .expected_result_type = AstType(ListTypeSpec(std::make_unique(DynTypeSpec()))), }, CheckedExprTestCase{ .expr = R"cel( Any{ type_url:'type.googleapis.com/google.protobuf.Int32Value', value: b'' })cel", .expected_result_type = AstType(WellKnownTypeSpec::kAny), }, CheckedExprTestCase{ .expr = "Int64Value{value: 10} + 1", .expected_result_type = AstType(PrimitiveType::kInt64), }, CheckedExprTestCase{ .expr = "BoolValue{value: false} || true", .expected_result_type = AstType(PrimitiveType::kBool), })); TEST(AliasTest, ImportVariable) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); ASSERT_OK_AND_ASSIGN(ExpressionContainer container, MakeExpressionContainer("cel.expr.conformance", "com.example.TestVariable1", "com.example.TestVariable2")); env.set_container(std::move(container)); google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); ASSERT_TRUE(env.InsertVariableIfAbsent( 
MakeVariableDecl("com.example.TestVariable1", MessageType(testpb3::TestAllTypes::descriptor())))); ASSERT_TRUE(env.InsertVariableIfAbsent( MakeVariableDecl("com.example.TestVariable2", MessageType(testpb2::TestAllTypes::descriptor())))); ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN( auto ast, MakeTestParsedAst( "TestVariable1.single_int64 == TestVariable2.single_int64")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()) << result.FormatError(); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); ASSERT_TRUE(checked_ast->root_expr().has_call_expr()); ASSERT_EQ(checked_ast->root_expr().call_expr().function(), "_==_"); ASSERT_THAT(checked_ast->root_expr().call_expr().args(), SizeIs(2)); ASSERT_EQ(checked_ast->root_expr() .call_expr() .args()[0] .select_expr() .operand() .ident_expr() .name(), "com.example.TestVariable1"); ASSERT_EQ(checked_ast->root_expr() .call_expr() .args()[1] .select_expr() .operand() .ident_expr() .name(), "com.example.TestVariable2"); } TEST(AliasTest, AliasToContainerResolvesMessage) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); ExpressionContainer container; ASSERT_THAT(container.AddAlias("pb3", "cel.expr.conformance.proto3"), IsOk()); env.set_container(std::move(container)); google::protobuf::LinkMessageReflection(); ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("pb3.TestAllTypes{single_int64: 10}")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()) << result.FormatError(); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); EXPECT_THAT( checked_ast->type_map(), Contains(Pair(checked_ast->root_expr().id(), Eq(AstType(MessageTypeSpec( "cel.expr.conformance.proto3.TestAllTypes")))))); 
EXPECT_THAT( checked_ast->reference_map(), Contains(Pair(checked_ast->root_expr().id(), Property(&Reference::name, "cel.expr.conformance.proto3.TestAllTypes")))); EXPECT_EQ(checked_ast->root_expr().struct_expr().name(), "cel.expr.conformance.proto3.TestAllTypes"); } TEST(AliasTest, AliasSimpleName) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); ExpressionContainer container; ASSERT_THAT(container.AddAlias("foo", "bar"), IsOk()); env.set_container(std::move(container)); google::protobuf::LinkMessageReflection(); ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertOrReplaceVariable(MakeVariableDecl("bar", IntType())); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()) << result.FormatError(); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); EXPECT_EQ(checked_ast->root_expr().ident_expr().name(), "bar"); } TEST(AliasTest, AliasPreventsContainerResolution) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); ASSERT_OK_AND_ASSIGN(ExpressionContainer container, MakeExpressionContainer("cel.expr")); ASSERT_THAT(container.AddAlias("pb3", "cel.expr.conformance.proto3"), IsOk()); env.set_container(std::move(container)); ASSERT_TRUE(env.InsertVariableIfAbsent( MakeVariableDecl("cel.expr.pb3.FooVariable", IntType()))); ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("FooVariable")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); EXPECT_THAT( result.GetIssues(), Contains(IsIssueWithSubstring( Severity::kError, "undeclared reference to 'FooVariable'"))); } { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("pb3.FooVariable")); ASSERT_OK_AND_ASSIGN(ValidationResult 
result, impl.Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); EXPECT_THAT( result.GetIssues(), Contains(IsIssueWithSubstring( Severity::kError, "undeclared reference to 'pb3.FooVariable'"))); } { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("expr.pb3.FooVariable")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()) << result.FormatError(); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); EXPECT_EQ(checked_ast->root_expr().ident_expr().name(), "cel.expr.pb3.FooVariable"); } } TEST(AliasTest, AliasPreventsDisambiguation) { // Copying behavior from cel-go and cel-java. google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); ExpressionContainer container; ASSERT_THAT(container.AddAlias("pb3", "cel.expr.conformance.proto3"), IsOk()); env.set_container(std::move(container)); env.InsertOrReplaceVariable(MakeVariableDecl("pb3.Foo", IntType())); ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); { ASSERT_OK_AND_ASSIGN( auto ast, MakeTestParsedAst("pb3.TestAllTypes{single_int64: 10}")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()) << result.FormatError(); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); EXPECT_EQ(checked_ast->root_expr().struct_expr().name(), "cel.expr.conformance.proto3.TestAllTypes"); } { ASSERT_OK_AND_ASSIGN( auto ast, MakeTestParsedAst(".pb3.TestAllTypes{single_int64: 10}")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_TRUE(result.IsValid()) << result.FormatError(); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); EXPECT_EQ(checked_ast->root_expr().struct_expr().name(), "cel.expr.conformance.proto3.TestAllTypes"); } { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("pb3.Foo")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); 
ASSERT_FALSE(result.IsValid()); EXPECT_THAT(result.GetIssues(), Contains(IsIssueWithSubstring( Severity::kError, "undeclared reference to 'pb3.Foo'"))); } { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(".pb3.Foo")); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_FALSE(result.IsValid()); EXPECT_THAT(result.GetIssues(), Contains(IsIssueWithSubstring( Severity::kError, "undeclared reference to '.pb3.Foo'"))); } } class GenericMessagesTest : public testing::TestWithParam { }; TEST_P(GenericMessagesTest, TypeChecksProto3Imports) { const CheckedExprTestCase& test_case = GetParam(); google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); env.set_container(*MakeExpressionContainer( "", "cel.expr.conformance.proto3.TestAllTypes", "cel.expr.conformance.proto3.NestedTestAllTypes")); google::protobuf::LinkMessageReflection(); ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( "test_msg", MessageType(testpb3::TestAllTypes::descriptor())))); ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); if (!test_case.error_substring.empty()) { EXPECT_THAT(result.GetIssues(), Contains(IsIssueWithSubstring(Severity::kError, test_case.error_substring))); return; } ASSERT_TRUE(result.IsValid()) << result.FormatError(); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); EXPECT_THAT(checked_ast->type_map(), Contains(Pair(checked_ast->root_expr().id(), Eq(test_case.expected_result_type)))) << cel::test::FormatBaselineAst(*checked_ast); } TEST_P(GenericMessagesTest, TypeChecksProto3Container) { const CheckedExprTestCase& test_case = GetParam(); google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); env.set_container(*MakeExpressionContainer("cel.expr.conformance.proto3")); 
google::protobuf::LinkMessageReflection(); ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( "test_msg", MessageType(testpb3::TestAllTypes::descriptor())))); ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); if (!test_case.error_substring.empty()) { EXPECT_THAT(result.GetIssues(), Contains(IsIssueWithSubstring(Severity::kError, test_case.error_substring))); return; } ASSERT_TRUE(result.IsValid()) << result.FormatError(); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); EXPECT_THAT(checked_ast->type_map(), Contains(Pair(checked_ast->root_expr().id(), Eq(test_case.expected_result_type)))) << cel::test::FormatBaselineAst(*checked_ast); } INSTANTIATE_TEST_SUITE_P( TestAllTypesCreation, GenericMessagesTest, ::testing::Values( CheckedExprTestCase{ .expr = "TestAllTypes{not_a_field: 10}", .expected_result_type = AstType(), .error_substring = "undefined field 'not_a_field' not found in " "struct 'cel.expr.conformance.proto3.TestAllTypes'"}, CheckedExprTestCase{ .expr = "TestAllTypes{single_int64: 10}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_int64: 'string'}", .expected_result_type = AstType(), .error_substring = "expected type of field 'single_int64' is 'int' but " "provided type is 'string'"}, CheckedExprTestCase{ .expr = "TestAllTypes{single_int32: 10}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_uint64: 10u}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_uint32: 10u}", .expected_result_type = AstType( 
MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_sint64: 10}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_sint32: 10}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_fixed64: 10u}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_fixed32: 10u}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_sfixed64: 10}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_sfixed32: 10}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_double: 1.25}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_float: 1.25}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_string: 'string'}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_bool: true}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_bytes: b'string'}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, // Well-known CheckedExprTestCase{ .expr = "TestAllTypes{single_any: TestAllTypes{single_int64: 10}}", .expected_result_type = 
AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_any: 1}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_any: 'string'}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_any: ['string']}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{repeated_nested_message: " "[TestAllTypes.NestedMessage{bb: 42}]}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_duration: duration('1s')}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_timestamp: timestamp(0)}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_struct: {}}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_struct: {'key': 'value'}}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_struct: {1: 2}}", .expected_result_type = AstType(), .error_substring = "expected type of field 'single_struct' is " "'map(string, dyn)' but " "provided type is 'map(int, int)'"}, CheckedExprTestCase{ .expr = "TestAllTypes{list_value: [1, 2, 3]}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{list_value: []}", .expected_result_type = AstType( 
MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{list_value: 1}", .expected_result_type = AstType(), .error_substring = "expected type of field 'list_value' is 'list(dyn)' but " "provided type is 'int'"}, CheckedExprTestCase{ .expr = "TestAllTypes{single_int64_wrapper: 1}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_int64_wrapper: null}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_value: null}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_value: 1.0}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_value: 'string'}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_value: {'string': 'string'}}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_value: ['string']}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{repeated_int64: [1, 2, 3]}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{repeated_int64: ['string']}", .expected_result_type = AstType(), .error_substring = "expected type of field 'repeated_int64' is 'list(int)'"}, CheckedExprTestCase{ .expr = "TestAllTypes{map_string_int64: ['string']}", .expected_result_type = AstType(), .error_substring = "expected type of field 'map_string_int64' is " "'map(string, int)'"}, 
CheckedExprTestCase{ .expr = "TestAllTypes{map_string_int64: {'string': 1}}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_nested_enum: 1}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_nested_enum: TestAllTypes.NestedEnum.BAR}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes.NestedEnum.BAR", .expected_result_type = AstType(PrimitiveType::kInt64), }, CheckedExprTestCase{ .expr = "TestAllTypes", .expected_result_type = AstType(std::make_unique( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes"))), }, CheckedExprTestCase{ .expr = "TestAllTypes == type(TestAllTypes{})", .expected_result_type = AstType(PrimitiveType::kBool), }, // Special case for the NullValue enum. CheckedExprTestCase{ .expr = "TestAllTypes{null_value: 0}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, // Legacy nullability behaviors. 
CheckedExprTestCase{ .expr = "TestAllTypes{single_duration: null}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_timestamp: null}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{single_nested_message: null}", .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, CheckedExprTestCase{ .expr = "TestAllTypes{}.single_duration == null", .expected_result_type = AstType(PrimitiveType::kBool), }, CheckedExprTestCase{ .expr = "TestAllTypes{}.single_timestamp == null", .expected_result_type = AstType(PrimitiveType::kBool), }, CheckedExprTestCase{ .expr = "TestAllTypes{}.single_nested_message == null", .expected_result_type = AstType(PrimitiveType::kBool), })); INSTANTIATE_TEST_SUITE_P( TestAllTypesFieldSelection, GenericMessagesTest, ::testing::Values( CheckedExprTestCase{ .expr = "test_msg.not_a_field", .expected_result_type = AstType(), .error_substring = "undefined field 'not_a_field' not found in " "struct 'cel.expr.conformance.proto3.TestAllTypes'"}, CheckedExprTestCase{ .expr = "test_msg.single_int64", .expected_result_type = AstType(PrimitiveType::kInt64), }, CheckedExprTestCase{ .expr = "test_msg.single_nested_enum", .expected_result_type = AstType(PrimitiveType::kInt64), }, CheckedExprTestCase{ .expr = "test_msg.single_nested_enum == 1", .expected_result_type = AstType(PrimitiveType::kBool), }, CheckedExprTestCase{ .expr = "test_msg.single_nested_enum == TestAllTypes.NestedEnum.BAR", .expected_result_type = AstType(PrimitiveType::kBool), }, CheckedExprTestCase{ .expr = "has(test_msg.not_a_field)", .expected_result_type = AstType(), .error_substring = "undefined field 'not_a_field' not found in " "struct 'cel.expr.conformance.proto3.TestAllTypes'"}, CheckedExprTestCase{ .expr = "has(test_msg.single_int64)", .expected_result_type 
= AstType(PrimitiveType::kBool), }, CheckedExprTestCase{ .expr = "test_msg.single_int32", .expected_result_type = AstType(PrimitiveType::kInt64), }, CheckedExprTestCase{ .expr = "test_msg.single_uint64", .expected_result_type = AstType(PrimitiveType::kUint64), }, CheckedExprTestCase{ .expr = "test_msg.single_uint32", .expected_result_type = AstType(PrimitiveType::kUint64), }, CheckedExprTestCase{ .expr = "test_msg.single_sint64", .expected_result_type = AstType(PrimitiveType::kInt64), }, CheckedExprTestCase{ .expr = "test_msg.single_sint32", .expected_result_type = AstType(PrimitiveType::kInt64), }, CheckedExprTestCase{ .expr = "test_msg.single_fixed64", .expected_result_type = AstType(PrimitiveType::kUint64), }, CheckedExprTestCase{ .expr = "test_msg.single_fixed32", .expected_result_type = AstType(PrimitiveType::kUint64), }, CheckedExprTestCase{ .expr = "test_msg.single_sfixed64", .expected_result_type = AstType(PrimitiveType::kInt64), }, CheckedExprTestCase{ .expr = "test_msg.single_sfixed32", .expected_result_type = AstType(PrimitiveType::kInt64), }, CheckedExprTestCase{ .expr = "test_msg.single_float", .expected_result_type = AstType(PrimitiveType::kDouble), }, CheckedExprTestCase{ .expr = "test_msg.single_double", .expected_result_type = AstType(PrimitiveType::kDouble), }, CheckedExprTestCase{ .expr = "test_msg.single_string", .expected_result_type = AstType(PrimitiveType::kString), }, CheckedExprTestCase{ .expr = "test_msg.single_bool", .expected_result_type = AstType(PrimitiveType::kBool), }, CheckedExprTestCase{ .expr = "test_msg.single_bytes", .expected_result_type = AstType(PrimitiveType::kBytes), }, // Basic tests for containers. This is covered in more detail in // conformance tests and the type provider implementation. 
CheckedExprTestCase{ .expr = "test_msg.repeated_int32", .expected_result_type = AstType( ListTypeSpec(std::make_unique(PrimitiveType::kInt64))), }, CheckedExprTestCase{ .expr = "test_msg.repeated_string", .expected_result_type = AstType(ListTypeSpec( std::make_unique(PrimitiveType::kString))), }, CheckedExprTestCase{ .expr = "test_msg.map_bool_bool", .expected_result_type = AstType( MapTypeSpec(std::make_unique(PrimitiveType::kBool), std::make_unique(PrimitiveType::kBool))), }, // Note: The Go type checker permits this so C++ does as well. Some // test cases expect that field selection on a map is always allowed, // even if a specific, non-string key type is known. CheckedExprTestCase{ .expr = "test_msg.map_bool_bool.field_like_key", .expected_result_type = AstType(PrimitiveType::kBool), }, CheckedExprTestCase{ .expr = "test_msg.map_string_int64", .expected_result_type = AstType( MapTypeSpec(std::make_unique(PrimitiveType::kString), std::make_unique(PrimitiveType::kInt64))), }, CheckedExprTestCase{ .expr = "test_msg.map_string_int64.field_like_key", .expected_result_type = AstType(PrimitiveType::kInt64), }, // Well-known CheckedExprTestCase{ .expr = "test_msg.single_duration", .expected_result_type = AstType(WellKnownTypeSpec::kDuration), }, CheckedExprTestCase{ .expr = "test_msg.single_timestamp", .expected_result_type = AstType(WellKnownTypeSpec::kTimestamp), }, CheckedExprTestCase{ .expr = "test_msg.single_any", .expected_result_type = AstType(WellKnownTypeSpec::kAny), }, CheckedExprTestCase{ .expr = "test_msg.single_int64_wrapper", .expected_result_type = AstType(PrimitiveTypeWrapper(PrimitiveType::kInt64)), }, CheckedExprTestCase{ .expr = "test_msg.single_struct", .expected_result_type = AstType( MapTypeSpec(std::make_unique(PrimitiveType::kString), std::make_unique(DynTypeSpec()))), }, CheckedExprTestCase{ .expr = "test_msg.list_value", .expected_result_type = AstType(ListTypeSpec(std::make_unique(DynTypeSpec()))), }, CheckedExprTestCase{ .expr = 
"test_msg.list_value", .expected_result_type = AstType(ListTypeSpec(std::make_unique(DynTypeSpec()))), }, // Basic tests for nested messages. CheckedExprTestCase{ .expr = "NestedTestAllTypes{}.child.child.payload.single_int64", .expected_result_type = AstType(PrimitiveType::kInt64), }, CheckedExprTestCase{ .expr = "test_msg.single_struct.field.nested_field", .expected_result_type = AstType(DynTypeSpec()), }, CheckedExprTestCase{ .expr = "{}.field.nested_field", .expected_result_type = AstType(DynTypeSpec()), })); INSTANTIATE_TEST_SUITE_P( TypeInferences, GenericMessagesTest, ::testing::Values( CheckedExprTestCase{.expr = "[1, test_msg.single_int64_wrapper]", .expected_result_type = AstType(ListTypeSpec( std::make_unique(PrimitiveTypeWrapper( PrimitiveType::kInt64))))}, CheckedExprTestCase{.expr = "[1, 2, test_msg.single_int64_wrapper]", .expected_result_type = AstType(ListTypeSpec( std::make_unique(PrimitiveTypeWrapper( PrimitiveType::kInt64))))}, CheckedExprTestCase{.expr = "[test_msg.single_int64_wrapper, 1]", .expected_result_type = AstType(ListTypeSpec( std::make_unique(PrimitiveTypeWrapper( PrimitiveType::kInt64))))}, CheckedExprTestCase{ .expr = "[1, 2, test_msg.single_int64_wrapper, dyn(1)]", .expected_result_type = AstType( ListTypeSpec(std::make_unique(DynTypeSpec())))}, CheckedExprTestCase{.expr = "[null, test_msg][0]", .expected_result_type = AstType(MessageTypeSpec( "cel.expr.conformance.proto3.TestAllTypes"))}, CheckedExprTestCase{ .expr = "[{'k': dyn(1)}, {dyn('k'): 1}][0]", // Ambiguous type resolution, but we prefer the first option. 
.expected_result_type = AstType( MapTypeSpec(std::make_unique(PrimitiveType::kString), std::make_unique(DynTypeSpec())))}, CheckedExprTestCase{ .expr = "[{'k': 1}, {dyn('k'): 1}][0]", .expected_result_type = AstType( MapTypeSpec(std::make_unique(DynTypeSpec()), std::make_unique(PrimitiveType::kInt64)))}, CheckedExprTestCase{ .expr = "[{dyn('k'): 1}, {'k': 1}][0]", .expected_result_type = AstType( MapTypeSpec(std::make_unique(DynTypeSpec()), std::make_unique(PrimitiveType::kInt64)))}, CheckedExprTestCase{ .expr = "[{'k': 1}, {'k': dyn(1)}][0]", .expected_result_type = AstType( MapTypeSpec(std::make_unique(PrimitiveType::kString), std::make_unique(DynTypeSpec())))}, CheckedExprTestCase{.expr = "[{'k': 1}, {dyn('k'): dyn(1)}][0]", .expected_result_type = AstType(MapTypeSpec( std::make_unique(DynTypeSpec()), std::make_unique(DynTypeSpec())))}, CheckedExprTestCase{ .expr = "[{'k': 1.0}, {dyn('k'): test_msg.single_int64_wrapper}][0]", .expected_result_type = AstType(DynTypeSpec())}, CheckedExprTestCase{ .expr = "test_msg.single_int64", .expected_result_type = AstType(PrimitiveType::kInt64), }, CheckedExprTestCase{ .expr = "[[1], {1: 2u}][0]", .expected_result_type = AstType(DynTypeSpec()), }, CheckedExprTestCase{ .expr = "[{1: 2u}, [1]][0]", .expected_result_type = AstType(DynTypeSpec()), }, CheckedExprTestCase{ .expr = "[test_msg.single_int64_wrapper," " test_msg.single_string_wrapper][0]", .expected_result_type = AstType(DynTypeSpec()), })); class StrictNullAssignmentTest : public testing::TestWithParam {}; TEST_P(StrictNullAssignmentTest, TypeChecksProto3) { const CheckedExprTestCase& test_case = GetParam(); google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); env.set_container(*MakeExpressionContainer("cel.expr.conformance.proto3")); google::protobuf::LinkMessageReflection(); ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( "test_msg", MessageType(testpb3::TestAllTypes::descriptor())))); 
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); CheckerOptions options; options.enable_legacy_null_assignment = false; TypeCheckerImpl impl(std::move(env), options); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); if (!test_case.error_substring.empty()) { EXPECT_THAT(result.GetIssues(), Contains(IsIssueWithSubstring(Severity::kError, test_case.error_substring))); return; } ASSERT_TRUE(result.IsValid()) << absl::StrJoin(result.GetIssues(), "\n", [](std::string* out, const TypeCheckIssue& issue) { absl::StrAppend(out, issue.message()); }); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); EXPECT_THAT(checked_ast->type_map(), Contains(Pair(checked_ast->root_expr().id(), Eq(test_case.expected_result_type)))); } INSTANTIATE_TEST_SUITE_P( TestStrictNullAssignment, StrictNullAssignmentTest, ::testing::Values( // Legacy nullability behaviors rejected. CheckedExprTestCase{ .expr = "TestAllTypes{single_duration: null}", .expected_result_type = AstType(), .error_substring = "'single_duration' is 'google.protobuf.Duration' but provided " "type is 'null_type'"}, CheckedExprTestCase{ .expr = "TestAllTypes{single_timestamp: null}", .expected_result_type = AstType(), .error_substring = "'single_timestamp' is 'google.protobuf.Timestamp' but " "provided type is 'null_type'"}, CheckedExprTestCase{ .expr = "TestAllTypes{single_nested_message: null}", .expected_result_type = AstType(), // Debug string includes descriptor address. 
.error_substring = "but provided type is 'null_type'"}, CheckedExprTestCase{ .expr = "TestAllTypes{}.single_duration == null", .expected_result_type = AstType(), .error_substring = "no matching overload for '_==_'", }, CheckedExprTestCase{ .expr = "TestAllTypes{}.single_timestamp == null", .expected_result_type = AstType(), .error_substring = "no matching overload for '_==_'"}, CheckedExprTestCase{ .expr = "TestAllTypes{}.single_nested_message == null", .expected_result_type = AstType(), .error_substring = "no matching overload for '_==_'", })); } // namespace } // namespace checker_internal } // namespace cel ================================================ FILE: checker/internal/type_inference_context.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "checker/internal/type_inference_context.h"

#include <cstddef>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "checker/internal/format_type_name.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/type_kind.h"

namespace cel::checker_internal {
namespace {

// Returns true for the "wildcard" kinds: any, dyn and error.
bool IsWildCardType(Type type) {
  const TypeKind kind = type.kind();
  return kind == TypeKind::kAny || kind == TypeKind::kDyn ||
         kind == TypeKind::kError;
}

// Returns true if `type` is nullable under the legacy rules.
//
// Structs, duration, timestamp, any and opaque (abstract) types were
// historically considered nullable. That is inconsistent with CEL's usual
// reading of null as the literal JSON null.
//
// TODO(uncreated-issue/74): Need a concrete plan for updating existing CEL expressions
// that depend on the old behavior.
bool IsLegacyNullable(Type type) {
  const TypeKind kind = type.kind();
  return kind == TypeKind::kStruct || kind == TypeKind::kDuration ||
         kind == TypeKind::kTimestamp || kind == TypeKind::kAny ||
         kind == TypeKind::kOpaque;
}

// Returns true if `name` is an internally generated type variable name.
// Such names start with "T%", so they cannot collide with user-declared
// type parameter names.
bool IsTypeVar(absl::string_view name) {
  constexpr absl::string_view kTypeVarPrefix = "T%";
  return absl::StartsWith(name, kTypeVarPrefix);
}

// Returns true if `t` admits values of more than one concrete type: any and
// dyn, plus the wrapper types, which admit their primitive and null.
bool IsUnionType(Type t) {
  const TypeKind kind = t.kind();
  if (kind == TypeKind::kAny || kind == TypeKind::kDyn) {
    return true;
  }
  return kind == TypeKind::kBoolWrapper || kind == TypeKind::kBytesWrapper ||
         kind == TypeKind::kDoubleWrapper || kind == TypeKind::kIntWrapper ||
         kind == TypeKind::kStringWrapper || kind == TypeKind::kUintWrapper;
}

// Returns true if `a` is a subset of `b`.
// (b is more general than a and admits a).
bool IsSubsetOf(Type a, Type b) { switch (b.kind()) { case TypeKind::kAny: return true; case TypeKind::kBoolWrapper: return a.IsBool() || a.IsNull(); case TypeKind::kBytesWrapper: return a.IsBytes() || a.IsNull(); case TypeKind::kDoubleWrapper: return a.IsDouble() || a.IsNull(); case TypeKind::kDyn: return true; case TypeKind::kIntWrapper: return a.IsInt() || a.IsNull(); case TypeKind::kStringWrapper: return a.IsString() || a.IsNull(); case TypeKind::kUintWrapper: return a.IsUint() || a.IsNull(); default: return false; } } struct FunctionOverloadInstance { Type result_type; std::vector param_types; }; FunctionOverloadInstance InstantiateFunctionOverload( TypeInferenceContext& inference_context, const OverloadDecl& ovl) { FunctionOverloadInstance result; result.param_types.reserve(ovl.args().size()); TypeInferenceContext::InstanceMap substitutions; result.result_type = inference_context.InstantiateTypeParams(ovl.result(), substitutions); for (int i = 0; i < ovl.args().size(); ++i) { result.param_types.push_back( inference_context.InstantiateTypeParams(ovl.args()[i], substitutions)); } return result; } // Converts a wrapper type to its corresponding primitive type. // Returns nullopt if the type is not a wrapper type. absl::optional WrapperToPrimitive(const Type& t) { switch (t.kind()) { case TypeKind::kBoolWrapper: return BoolType(); case TypeKind::kBytesWrapper: return BytesType(); case TypeKind::kDoubleWrapper: return DoubleType(); case TypeKind::kStringWrapper: return StringType(); case TypeKind::kIntWrapper: return IntType(); case TypeKind::kUintWrapper: return UintType(); default: return absl::nullopt; } } } // namespace Type TypeInferenceContext::InstantiateTypeParams(const Type& type) { InstanceMap substitutions; return InstantiateTypeParams(type, substitutions); } Type TypeInferenceContext::InstantiateTypeParams( const Type& type, absl::flat_hash_map& substitutions) { switch (type.kind()) { // Unparameterized types -- just forward. 
case TypeKind::kAny: case TypeKind::kBool: case TypeKind::kBoolWrapper: case TypeKind::kBytes: case TypeKind::kBytesWrapper: case TypeKind::kDouble: case TypeKind::kDoubleWrapper: case TypeKind::kDuration: case TypeKind::kDyn: case TypeKind::kError: case TypeKind::kInt: case TypeKind::kNull: case TypeKind::kString: case TypeKind::kStringWrapper: case TypeKind::kStruct: case TypeKind::kTimestamp: case TypeKind::kUint: case TypeKind::kIntWrapper: case TypeKind::kUintWrapper: return type; case TypeKind::kTypeParam: { absl::string_view name = type.AsTypeParam()->name(); if (IsTypeVar(name)) { // Already instantiated (e.g. list comprehension variable). return type; } if (auto it = substitutions.find(name); it != substitutions.end()) { return TypeParamType(it->second); } absl::string_view substitution = NewTypeVar(name); substitutions[type.AsTypeParam()->name()] = substitution; return TypeParamType(substitution); } case TypeKind::kType: { auto type_type = type.AsType(); auto parameters = type_type->GetParameters(); if (parameters.size() == 1) { Type param = InstantiateTypeParams(parameters[0], substitutions); return TypeType(arena_, param); } else if (parameters.size() > 1) { return ErrorType(); } else { // generic type return type; } } case TypeKind::kList: { Type elem = InstantiateTypeParams(type.AsList()->element(), substitutions); return ListType(arena_, elem); } case TypeKind::kMap: { Type key = InstantiateTypeParams(type.AsMap()->key(), substitutions); Type value = InstantiateTypeParams(type.AsMap()->value(), substitutions); return MapType(arena_, key, value); } case TypeKind::kOpaque: { auto opaque_type = type.AsOpaque(); auto parameters = opaque_type->GetParameters(); std::vector param_instances; param_instances.reserve(parameters.size()); for (int i = 0; i < parameters.size(); ++i) { param_instances.push_back( InstantiateTypeParams(parameters[i], substitutions)); } return OpaqueType(arena_, type.AsOpaque()->name(), param_instances); } default: return 
ErrorType(); } } bool TypeInferenceContext::IsAssignable(const Type& from, const Type& to) { SubstitutionMap prospective_substitutions; bool result = IsAssignableInternal(from, to, prospective_substitutions); if (result) { UpdateTypeParameterBindings(prospective_substitutions); } return result; } bool TypeInferenceContext::IsAssignableInternal( const Type& from, const Type& to, SubstitutionMap& prospective_substitutions) { Type to_subs = Substitute(to, prospective_substitutions); Type from_subs = Substitute(from, prospective_substitutions); // Types always assignable to themselves. // Remainder is checking for assignability across different types. if (to_subs == from_subs) { return true; } // Resolve free type parameters. if (to_subs.kind() == TypeKind::kTypeParam || from_subs.kind() == TypeKind::kTypeParam) { return IsAssignableWithConstraints(from_subs, to_subs, prospective_substitutions); } // Maybe widen a prospective type binding if another potential binding is // more general and admits the previous binding. if ( // Checking assignability to a specific type var // that has a prospective type assignment. to.kind() == TypeKind::kTypeParam && prospective_substitutions.contains(to.GetTypeParam().name())) { SubstitutionMap prospective_subs_cpy = prospective_substitutions; if (CompareGenerality(from_subs, to_subs, prospective_subs_cpy) == RelativeGenerality::kMoreGeneral) { if (IsAssignableInternal(to_subs, from_subs, prospective_subs_cpy) && !OccursWithin(to.GetTypeParam().name(), from_subs, prospective_subs_cpy)) { prospective_subs_cpy[to.GetTypeParam().name()] = from_subs; prospective_substitutions = std::move(prospective_subs_cpy); return true; // otherwise, continue with normal assignability check. } } } // Type is as concrete as it can be under current substitutions. 
if (absl::optional wrapped_type = WrapperToPrimitive(to_subs); wrapped_type.has_value()) { return from_subs.IsNull() || IsAssignableInternal(*wrapped_type, from_subs, prospective_substitutions); } // Wrapper types are assignable to their corresponding primitive type ( // somewhat similar to auto unboxing). This is a bit odd with CEL's null_type, // but there isn't a dedicated syntax for narrowing from the nullable. if (auto from_wrapper = WrapperToPrimitive(from_subs); from_wrapper.has_value()) { return IsAssignableInternal(*from_wrapper, to_subs, prospective_substitutions); } if (enable_legacy_null_assignment_) { if (from_subs.IsNull() && IsLegacyNullable(to_subs)) { return true; } if (to_subs.IsNull() && IsLegacyNullable(from_subs)) { return true; } } if (from_subs.kind() == TypeKind::kType && to_subs.kind() == TypeKind::kType) { // Types are always assignable to themselves (even if differently // parameterized). return true; } if (to_subs.kind() == TypeKind::kEnum && from_subs.kind() == TypeKind::kInt) { return true; } if (from_subs.kind() == TypeKind::kEnum && to_subs.kind() == TypeKind::kInt) { return true; } if (IsWildCardType(from_subs) || IsWildCardType(to_subs)) { return true; } if (to_subs.kind() != from_subs.kind() || to_subs.name() != from_subs.name()) { return false; } // Recurse for the type parameters. 
auto to_params = to_subs.GetParameters(); auto from_params = from_subs.GetParameters(); const auto params_size = to_params.size(); if (params_size != from_params.size()) { return false; } for (size_t i = 0; i < params_size; ++i) { if (!IsAssignableInternal(from_params[i], to_params[i], prospective_substitutions)) { return false; } } return true; } Type TypeInferenceContext::Substitute( const Type& type, const SubstitutionMap& substitutions) const { Type subs = type; while (subs.kind() == TypeKind::kTypeParam) { TypeParamType t = subs.GetTypeParam(); if (auto it = substitutions.find(t.name()); it != substitutions.end()) { subs = it->second; continue; } if (auto it = type_parameter_bindings_.find(t.name()); it != type_parameter_bindings_.end()) { if (it->second.type.has_value()) { subs = *it->second.type; continue; } } break; } return subs; } TypeInferenceContext::RelativeGenerality TypeInferenceContext::CompareGenerality( const Type& from, const Type& to, const SubstitutionMap& prospective_substitutions) const { Type from_subs = Substitute(from, prospective_substitutions); Type to_subs = Substitute(to, prospective_substitutions); if (from_subs == to_subs) { return RelativeGenerality::kEquivalent; } if (IsUnionType(from_subs) && IsSubsetOf(to_subs, from_subs)) { return RelativeGenerality::kMoreGeneral; } if (IsUnionType(to_subs)) { return RelativeGenerality::kLessGeneral; } if (enable_legacy_null_assignment_ && IsLegacyNullable(from_subs) && to_subs.IsNull()) { return RelativeGenerality::kMoreGeneral; } // Not a polytype. Check if it is a parameterized type and all parameters are // equivalent and at least one is more general. 
if (from_subs.IsList() && to_subs.IsList()) { return CompareGenerality(from_subs.AsList()->GetElement(), to_subs.AsList()->GetElement(), prospective_substitutions); } if (from_subs.IsMap() && to_subs.IsMap()) { RelativeGenerality key_generality = CompareGenerality(from_subs.AsMap()->GetKey(), to_subs.AsMap()->GetKey(), prospective_substitutions); RelativeGenerality value_generality = CompareGenerality( from_subs.AsMap()->GetValue(), to_subs.AsMap()->GetValue(), prospective_substitutions); if (key_generality == RelativeGenerality::kLessGeneral || value_generality == RelativeGenerality::kLessGeneral) { return RelativeGenerality::kLessGeneral; } if (key_generality == RelativeGenerality::kMoreGeneral || value_generality == RelativeGenerality::kMoreGeneral) { return RelativeGenerality::kMoreGeneral; } return RelativeGenerality::kEquivalent; } if (from_subs.IsOpaque() && to_subs.IsOpaque() && from_subs.AsOpaque()->name() == to_subs.AsOpaque()->name() && from_subs.AsOpaque()->GetParameters().size() == to_subs.AsOpaque()->GetParameters().size()) { RelativeGenerality max_generality = RelativeGenerality::kEquivalent; for (int i = 0; i < from_subs.AsOpaque()->GetParameters().size(); ++i) { RelativeGenerality generality = CompareGenerality( from_subs.AsOpaque()->GetParameters()[i], to_subs.AsOpaque()->GetParameters()[i], prospective_substitutions); if (generality == RelativeGenerality::kLessGeneral) { return RelativeGenerality::kLessGeneral; } if (generality == RelativeGenerality::kMoreGeneral) { max_generality = RelativeGenerality::kMoreGeneral; } } return max_generality; } // Default not comparable. Since we ruled out polytypes, they should be // equivalent for the purposes of deciding the most general eligible // substitution. 
return RelativeGenerality::kEquivalent; } bool TypeInferenceContext::OccursWithin( absl::string_view var_name, const Type& type, const SubstitutionMap& substitutions) const { // This is difficult to trigger in normal CEL expressions, but may // happen with comprehensions where we can potentially reference a variable // with a free type var in different ways. // // This check guarantees that we don't introduce a recursive type definition // (a cycle in the substitution map). // // We can't reuse Substitute here because it does the pointer chasing and // might hide a cycle. // // E.g. // T2 in T3 when // T3 -> T2 -> null_type; Type substitution = type; while (substitution.kind() == TypeKind::kTypeParam) { absl::string_view param_name = substitution.AsTypeParam()->name(); if (param_name == var_name) { return true; } if (auto it = substitutions.find(param_name); it != substitutions.end()) { substitution = it->second; continue; } if (auto it = type_parameter_bindings_.find(param_name); it != type_parameter_bindings_.end() && it->second.type.has_value()) { substitution = it->second.type.value(); continue; } // Type parameter is free. return false; } for (const auto& param : substitution.GetParameters()) { if (OccursWithin(var_name, param, substitutions)) { return true; } } return false; } bool TypeInferenceContext::IsAssignableWithConstraints( const Type& from, const Type& to, SubstitutionMap& prospective_substitutions) { if (to.kind() == TypeKind::kTypeParam && from.kind() == TypeKind::kTypeParam) { if (to.AsTypeParam()->name() != from.AsTypeParam()->name()) { // Simple case, bind from to 'to' if both are free. 
prospective_substitutions[from.AsTypeParam()->name()] = to; } return true; } if (to.kind() == TypeKind::kTypeParam) { absl::string_view name = to.AsTypeParam()->name(); if (!OccursWithin(name, from, prospective_substitutions)) { prospective_substitutions[name] = from; return true; } } if (from.kind() == TypeKind::kTypeParam) { absl::string_view name = from.AsTypeParam()->name(); if (!OccursWithin(name, to, prospective_substitutions)) { prospective_substitutions[name] = to; return true; } } // If either types are wild cards but we weren't able to specialize, // assume assignable and continue. if (IsWildCardType(from) || IsWildCardType(to)) { return true; } return false; } absl::optional TypeInferenceContext::ResolveOverload(const FunctionDecl& decl, absl::Span argument_types, bool is_receiver) { absl::optional result_type; std::vector matching_overloads; for (const auto& ovl : decl.overloads()) { if (ovl.member() != is_receiver || argument_types.size() != ovl.args().size()) { continue; } auto call_type_instance = InstantiateFunctionOverload(*this, ovl); ABSL_DCHECK_EQ(argument_types.size(), call_type_instance.param_types.size()); bool is_match = true; AssignabilityContext assignability_context = CreateAssignabilityContext(); for (int i = 0; i < argument_types.size(); ++i) { if (!assignability_context.IsAssignable( argument_types[i], call_type_instance.param_types[i])) { is_match = false; break; } } if (is_match) { matching_overloads.push_back(ovl); assignability_context.UpdateInferredTypeAssignments(); if (!result_type.has_value()) { result_type = call_type_instance.result_type; } else { if (!TypeEquivalent(*result_type, call_type_instance.result_type)) { result_type = DynType(); } } } } if (!result_type.has_value() || matching_overloads.empty()) { return absl::nullopt; } return OverloadResolution{ .result_type = FullySubstitute(*result_type, /*free_to_dyn=*/false), .overloads = std::move(matching_overloads), }; } void 
TypeInferenceContext::UpdateTypeParameterBindings( const SubstitutionMap& prospective_substitutions) { if (prospective_substitutions.empty()) { return; } for (auto iter = prospective_substitutions.begin(); iter != prospective_substitutions.end(); ++iter) { if (auto binding_iter = type_parameter_bindings_.find(iter->first); binding_iter != type_parameter_bindings_.end()) { binding_iter->second.type = iter->second; } else { ABSL_LOG(WARNING) << "Uninstantiated type parameter: " << iter->first; } } } bool TypeInferenceContext::TypeEquivalent(const Type& a, const Type& b) { return a == b; } Type TypeInferenceContext::FullySubstitute(const Type& type, bool free_to_dyn) const { switch (type.kind()) { case TypeKind::kTypeParam: { Type subs = Substitute(type, {}); if (subs.kind() == TypeKind::kTypeParam) { if (free_to_dyn) { return DynType(); } return subs; } return FullySubstitute(subs, free_to_dyn); } case TypeKind::kType: { if (type.AsType()->GetParameters().empty()) { return type; } Type param = FullySubstitute(type.AsType()->GetType(), free_to_dyn); return TypeType(arena_, param); } case TypeKind::kList: { Type elem = FullySubstitute(type.AsList()->GetElement(), free_to_dyn); return ListType(arena_, elem); } case TypeKind::kMap: { Type key = FullySubstitute(type.AsMap()->GetKey(), free_to_dyn); Type value = FullySubstitute(type.AsMap()->GetValue(), free_to_dyn); return MapType(arena_, key, value); } case TypeKind::kOpaque: { std::vector types; for (const auto& param : type.AsOpaque()->GetParameters()) { types.push_back(FullySubstitute(param, free_to_dyn)); } return OpaqueType(arena_, type.AsOpaque()->name(), types); } default: return type; } } bool TypeInferenceContext::AssignabilityContext::IsAssignable(const Type& from, const Type& to) { return inference_context_.IsAssignableInternal(from, to, prospective_substitutions_); } std::string TypeInferenceContext::DebugString() const { return absl::StrCat( "type_parameter_bindings: ", absl::StrJoin( 
type_parameter_bindings_, "\n ", [](std::string* out, const auto& binding) { absl::StrAppend( out, binding.first, " (", binding.second.name, ") -> ", checker_internal::FormatTypeName( binding.second.type.value_or(Type(TypeParamType("none"))))); })); } void TypeInferenceContext::AssignabilityContext:: UpdateInferredTypeAssignments() { inference_context_.UpdateTypeParameterBindings(prospective_substitutions_); prospective_substitutions_.clear(); } void TypeInferenceContext::AssignabilityContext::Reset() { prospective_substitutions_.clear(); } } // namespace cel::checker_internal ================================================ FILE: checker/internal/type_inference_context.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_INFERENCE_CONTEXT_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_INFERENCE_CONTEXT_H_ #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_map.h" #include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "common/decl.h" #include "common/type.h" #include "google/protobuf/arena.h" namespace cel::checker_internal { // Class manages context for type inferences in the type checker. 
// TODO(uncreated-issue/72): for now, just checks assignability for concrete types.
// Support for finding substitutions of type parameters will be added in a
// follow-up CL.
// NOTE(review): the TODO above looks stale -- this class already records
// type-parameter bindings (SubstitutionMap, type_parameter_bindings_) and
// exposes substitution helpers (FullySubstitute, FinalizeType). Confirm and
// remove it.
//
// Holds the type-variable state used during type inference: instantiates
// generic type parameters as fresh type variables, checks assignability,
// resolves function overloads, and resolves any remaining type variables to
// a concrete type or dyn (FinalizeType).
class TypeInferenceContext {
 public:
  // Convenience alias for an instance map for type parameters mapped to type
  // vars in a given context.
  //
  // This should be treated as opaque, the client should not manually modify.
  using InstanceMap = absl::flat_hash_map;

  // Result of a successful ResolveOverload call.
  struct OverloadResolution {
    // Result type of the call. It may still contain type variables; callers
    // pass it through FinalizeType to obtain a concrete type.
    Type result_type;
    // The applicable overloads of the resolved function.
    std::vector overloads;
  };

 private:
  // Alias for a map from type var name to the type it is bound to.
  //
  // Used for prospective substitutions during type inference to make progress
  // without affecting final assigned types.
  using SubstitutionMap = absl::flat_hash_map;

 public:
  // Helper class for managing several dependent type assignability checks.
  //
  // Note: while allowed, updating multiple AssignabilityContexts concurrently
  // can lead to inconsistencies in the final type bindings.
  //
  // Instances can only be created through
  // TypeInferenceContext::CreateAssignabilityContext(), hold a reference to
  // the parent context, and are neither copyable nor movable.
  class AssignabilityContext {
   public:
    // Checks if `from` is assignable to `to`. The check uses the parent
    // inference context's current type substitutions plus any prospective
    // substitutions already accumulated by this AssignabilityContext.
    //
    // A successful check may record new prospective bindings here. They do not
    // reach the parent context until UpdateInferredTypeAssignments() is
    // called.
    bool IsAssignable(const Type& from, const Type& to);

    // Applies any prospective type assignments to the parent inference context.
    //
    // This should only be called after all assignability checks have completed.
    //
    // Leaves the AssignabilityContext in the starting state (i.e. no
    // prospective substitutions).
    void UpdateInferredTypeAssignments();

    // Returns the AssignabilityContext to the starting state (i.e. no
    // prospective substitutions), discarding any bindings recorded by earlier
    // IsAssignable calls.
    void Reset();

   private:
    explicit AssignabilityContext(TypeInferenceContext& inference_context)
        : inference_context_(inference_context) {}

    AssignabilityContext(const AssignabilityContext&) = delete;
    AssignabilityContext& operator=(const AssignabilityContext&) = delete;
    AssignabilityContext(AssignabilityContext&&) = delete;
    AssignabilityContext& operator=(AssignabilityContext&&) = delete;

    // Grants TypeInferenceContext access to the private constructor.
    friend class TypeInferenceContext;

    // Parent context (a reference member, so it must outlive this object).
    TypeInferenceContext& inference_context_;
    // Bindings recorded by IsAssignable but not yet applied to the parent.
    SubstitutionMap prospective_substitutions_;
  };

  // `arena` is stored as a raw pointer and is never deleted by this class.
  // NOTE(review): presumably it backs types created during inference, so it
  // must outlive this context -- confirm in the .cc file.
  //
  // `enable_legacy_null_assignment` defaults to true.
  // NOTE(review): its exact effect is implemented in the .cc file; the name
  // suggests it governs legacy null-assignability rules -- confirm there.
  explicit TypeInferenceContext(google::protobuf::Arena* arena,
                                bool enable_legacy_null_assignment = true)
      : arena_(arena),
        enable_legacy_null_assignment_(enable_legacy_null_assignment) {}

  // Creates a new AssignabilityContext for the current inference context.
  //
  // This is intended for managing several dependent type assignability checks
  // that should only be added to the final type bindings if all checks succeed.
  //
  // Note: while allowed, updating multiple AssignabilityContexts concurrently
  // can lead to inconsistencies in the final type bindings.
  //
  // The returned object references *this, hence ABSL_ATTRIBUTE_LIFETIME_BOUND.
  // Returning it by value relies on C++17 guaranteed copy elision, because
  // AssignabilityContext is neither copyable nor movable.
  AssignabilityContext CreateAssignabilityContext()
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return AssignabilityContext(*this);
  }

  // Resolves any remaining type parameters in the given type to a concrete
  // type or dyn.
  //
  // Equivalent to FullySubstitute(type, /*free_to_dyn=*/true).
  Type FinalizeType(const Type& type) const {
    return FullySubstitute(type, /*free_to_dyn=*/true);
  }

  // Recursively applies any substitutions to the given type.
  //
  // With `free_to_dyn` set, type variables that are still unbound become dyn.
  // This is the mode FinalizeType uses.
  Type FullySubstitute(const Type& type, bool free_to_dyn = false) const;

  // Replaces any generic type parameters in the given type with specific type
  // variables. Internally, type variables are just a unique string parameter
  // name.
  //
  // Within one call, repeated occurrences of the same parameter name map to the
  // same type variable. Separate calls allocate fresh variables (e.g. "T%1",
  // then "T%2").
  Type InstantiateTypeParams(const Type& type);

  // Overload for function overload types that need coordination across
  // multiple function parameters.
  //
  // Parameters with the same name that are instantiated through the same
  // `substitutions` map share one type variable across calls.
  Type InstantiateTypeParams(const Type& type, InstanceMap& substitutions);

  // Resolves the applicable overloads for the given function call given the
  // inferred argument types.
  //
  // If found, returns the result type and the list of applicable overloads.
  // Returns nullopt when no overload accepts `argument_types`. On success, the
  // type-variable bindings inferred for the call are kept in this context and
  // affect later calls.
  // NOTE(review): `is_receiver` presumably selects receiver-style (member)
  // overloads -- confirm in the .cc file.
  absl::optional ResolveOverload(
      const FunctionDecl& decl, absl::Span argument_types,
      bool is_receiver);

  // Checks if `from` is assignable to `to`.
  bool IsAssignable(const Type& from, const Type& to);

  // Returns a human-readable dump of the current type parameter bindings for
  // debugging, e.g. "type_parameter_bindings: T%1 (A) -> int".
  std::string DebugString() const;

 private:
  // State of one type variable.
  struct TypeVar {
    // Bound type, or nullopt while the variable is still free.
    absl::optional type;
    // The `name` passed to NewTypeVar (empty by default).
    // NOTE(review): this is a non-owning view; the referenced characters must
    // stay alive for as long as this context uses the entry.
    absl::string_view name;
  };

  // Relative generality between two types.
  enum class RelativeGenerality {
    kMoreGeneral,
    // Note: kLessGeneral does not imply it is definitely more specific, only
    // that we cannot determine if equivalent or more general.
    kLessGeneral,
    kEquivalent,
  };

  // Allocates a new, free type variable and records `name` with it.
  //
  // The variable is named "T%<n>", where <n> is the pre-incremented
  // next_type_parameter_id_, so the first variable is "T%1". Returns a view of
  // the map key. The view stays valid because node_hash_map keeps node
  // addresses stable.
  absl::string_view NewTypeVar(absl::string_view name = "") {
    next_type_parameter_id_++;
    auto inserted = type_parameter_bindings_.insert(
        {absl::StrCat("T%", next_type_parameter_id_), {absl::nullopt, name}});
    // The id counter only ever increases, so the key should always be new.
    ABSL_DCHECK(inserted.second);
    return inserted.first->first;
  }

  // Returns true if the two types are equivalent with the current type
  // substitutions.
  bool TypeEquivalent(const Type& a, const Type& b);

  // Returns true if `from` is assignable to `to` with the current type
  // substitutions and any additional prospective substitutions.
  //
  // `prospective_substitutions` is a map from type var name to the type it
  // should be bound to in the current context, augmenting any existing
  // substitutions.
  //
  // If the types are not assignable, returns false and leaves
  // `prospective_substitutions` unmodified.
  //
  // If the types are assignable, returns true and updates
  // `prospective_substitutions` with any new type parameter bindings.
  bool IsAssignableInternal(const Type& from, const Type& to,
                            SubstitutionMap& prospective_substitutions);

  // NOTE(review): not documented in this header. Its signature mirrors
  // IsAssignableInternal; see the .cc file for which extra constraints it
  // applies.
  bool IsAssignableWithConstraints(const Type& from, const Type& to,
                                   SubstitutionMap& prospective_substitutions);

  // Relative generality of `from` as compared to `to` with the current type
  // substitutions and any additional prospective substitutions.
  //
  // Generality is only defined as a partial ordering. Some types are
  // incomparable. However we only need to know if a type is definitely more
  // general or not.
  RelativeGenerality CompareGenerality(
      const Type& from, const Type& to,
      const SubstitutionMap& prospective_substitutions) const;

  // NOTE(review): presumably applies `substitutions` to `type` (compare
  // FullySubstitute) -- confirm in the .cc file.
  Type Substitute(const Type& type, const SubstitutionMap& substitutions) const;

  // NOTE(review): presumably reports whether type variable `var_name` occurs
  // within `type` once `substitutions` are applied (an occurs check) --
  // confirm in the .cc file.
  bool OccursWithin(absl::string_view var_name, const Type& type,
                    const SubstitutionMap& substitutions) const;

  // NOTE(review): presumably commits `prospective_substitutions` into
  // type_parameter_bindings_ -- confirm in the .cc file.
  void UpdateTypeParameterBindings(
      const SubstitutionMap& prospective_substitutions);

  // Map from type var parameter name to the type it is bound to.
  //
  // Type var parameters are formatted as "T%<n>" (e.g. "T%1"; see NewTypeVar)
  // to avoid collisions with provided type parameter names.
  //
  // node_hash_map is used to preserve pointer stability for use with
  // TypeParamType.
  //
  // Type parameter instances should be resolved to a concrete type during type
  // checking to remove the lifecycle dependency on the inference context
  // instance.
  //
  // nullopt signifies a free type variable.
  absl::node_hash_map type_parameter_bindings_;
  // Id of the most recently allocated type variable; 0 means none yet.
  int64_t next_type_parameter_id_ = 0;
  // Not owned; see the constructor.
  google::protobuf::Arena* arena_;
  bool enable_legacy_null_assignment_;
};

}  // namespace cel::checker_internal

#endif  // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_TYPE_INFERENCE_CONTEXT_H_
================================================
FILE: checker/internal/type_inference_context_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "checker/internal/type_inference_context.h"

// NOTE(review): the standard-library header names on the next two lines are
// missing from this copy of the file.
#include
#include

#include "absl/log/absl_check.h"
#include "absl/types/optional.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/type_kind.h"
#include "internal/testing.h"
#include "google/protobuf/arena.h"

namespace cel::checker_internal {
namespace {

using ::testing::ElementsAre;
using ::testing::IsEmpty;
using ::testing::SafeMatcherCast;
using ::testing::SizeIs;

// Matches a Type of kind kTypeParam whose name equals `param`,
// e.g. IsTypeParam("T%1").
MATCHER_P(IsTypeParam, param, "") {
  const Type& got = arg;
  if (got.kind() != TypeKind::kTypeParam) {
    return false;
  }
  TypeParamType type = got.GetTypeParam();
  return type.name() == param;
}

// Matches a list Type whose element type satisfies `elems_matcher`.
MATCHER_P(IsListType, elems_matcher, "") {
  const Type& got = arg;
  if (got.kind() != TypeKind::kList) {
    return false;
  }
  ListType type = got.GetList();
  Type elem = type.element();
  return SafeMatcherCast(elems_matcher)
      .MatchAndExplain(elem, result_listener);
}

// Matches a map Type whose key and value types satisfy `key_matcher` and
// `value_matcher` respectively.
MATCHER_P2(IsMapType, key_matcher, value_matcher, "") {
  const Type& got = arg;
  if (got.kind() != TypeKind::kMap) {
    return false;
  }
  MapType type = got.GetMap();
  Type key = type.key();
  Type value = type.value();
  return SafeMatcherCast(key_matcher)
             .MatchAndExplain(key, result_listener) &&
         SafeMatcherCast(value_matcher)
             .MatchAndExplain(value, result_listener);
}

// Matches a Type whose kind() equals `kind`. On mismatch, the explanation
// shows both the actual and the expected kind names.
MATCHER_P(IsTypeKind, kind, "") {
  const Type& got = arg;
  TypeKind want_kind = kind;
  if (got.kind() == want_kind) {
    return true;
  }
  *result_listener << "got: " << TypeKindToString(got.kind());
  *result_listener << "\n";
  *result_listener << "wanted: " << TypeKindToString(want_kind);
  return false;
}

// Matches a `type` Type with exactly one parameter that satisfies `matcher`.
// NOTE(review): the size check reads `type_type.GetParameters()`, but the match
// indexes `got.GetParameters()`. These presumably name the same parameter
// list; using `type_type` in both places would be clearer.
MATCHER_P(IsTypeType, matcher, "") {
  const Type& got = arg;
  if (got.kind() != TypeKind::kType) {
    return false;
  }
  TypeType type_type = got.GetType();
  if (type_type.GetParameters().size() != 1) {
    return false;
  }
  return SafeMatcherCast(matcher).MatchAndExplain(got.GetParameters()[0],
                                                        result_listener);
}

// Separate calls allocate fresh type variables, even for the same parameter
// name.
TEST(TypeInferenceContextTest, InstantiateTypeParams) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);

  Type type = context.InstantiateTypeParams(TypeParamType("MyType"));
  EXPECT_THAT(type, IsTypeParam("T%1"));
  Type type2 = context.InstantiateTypeParams(TypeParamType("MyType"));
  EXPECT_THAT(type2, IsTypeParam("T%2"));
}

// Sharing one InstanceMap across calls reuses the same type variable for the
// same parameter name.
TEST(TypeInferenceContextTest, InstantiateTypeParamsWithSubstitutions) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  TypeInferenceContext::InstanceMap instance_map;

  Type type =
      context.InstantiateTypeParams(TypeParamType("MyType"), instance_map);
  EXPECT_THAT(type, IsTypeParam("T%1"));
  Type type2 =
      context.InstantiateTypeParams(TypeParamType("MyType"), instance_map);
  EXPECT_THAT(type2, IsTypeParam("T%1"));
}

// A type without type parameters is left as-is.
TEST(TypeInferenceContextTest, InstantiateTypeParamsUnparameterized) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);

  Type type = context.InstantiateTypeParams(IntType());
  EXPECT_TRUE(type.IsInt());
}

// A type parameter used as a list element is instantiated.
TEST(TypeInferenceContextTest, InstantiateTypeParamsList) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  Type list_type = ListType(&arena, TypeParamType("MyType"));

  Type type = context.InstantiateTypeParams(list_type);
  EXPECT_THAT(type, IsListType(IsTypeParam("T%1")));
}

// A list of a concrete type keeps its element type.
TEST(TypeInferenceContextTest, InstantiateTypeParamsListPrimitive) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  Type list_type = ListType(&arena, IntType());

  Type type = context.InstantiateTypeParams(list_type);
  EXPECT_THAT(type, IsListType(IsTypeKind(TypeKind::kInt)));
}

// Distinct map key/value parameters get distinct type variables.
TEST(TypeInferenceContextTest, InstantiateTypeParamsMap) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  Type map_type = MapType(&arena, TypeParamType("K"), TypeParamType("V"));

  Type type = context.InstantiateTypeParams(map_type);
  EXPECT_THAT(type, IsMapType(IsTypeParam("T%1"), IsTypeParam("T%2")));
}

// The same parameter used twice within one type maps to one type variable.
TEST(TypeInferenceContextTest, InstantiateTypeParamsMapSameParam) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  Type map_type = MapType(&arena, TypeParamType("E"), TypeParamType("E"));

  Type type = context.InstantiateTypeParams(map_type);
  EXPECT_THAT(type, IsMapType(IsTypeParam("T%1"), IsTypeParam("T%1")));
}

// Concrete map key/value types are preserved.
TEST(TypeInferenceContextTest, InstantiateTypeParamsMapPrimitive) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  Type map_type = MapType(&arena, StringType(), IntType());

  Type type = context.InstantiateTypeParams(map_type);
  EXPECT_THAT(type, IsMapType(IsTypeKind(TypeKind::kString),
                              IsTypeKind(TypeKind::kInt)));
}

// The parameter of a parameterized `type` is instantiated.
TEST(TypeInferenceContextTest, InstantiateTypeParamsType) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  Type type_type = TypeType(&arena, TypeParamType("T"));

  Type type = context.InstantiateTypeParams(type_type);
  EXPECT_THAT(type, IsTypeType(IsTypeParam("T%1")));
}

// An unparameterized `type` stays unparameterized.
TEST(TypeInferenceContextTest, InstantiateTypeParamsTypeEmpty) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  Type type_type = TypeType();

  Type type = context.InstantiateTypeParams(type_type);
  EXPECT_THAT(type, IsTypeKind(TypeKind::kType));
  EXPECT_THAT(type.AsType()->GetParameters(), IsEmpty());
}

// Opaque type parameters are instantiated positionally. A repeated name ("T")
// shares one variable, and concrete parameters are kept.
TEST(TypeInferenceContextTest, InstantiateTypeParamsOpaque) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  std::vector parameters = {TypeParamType("T"), IntType(),
                                  TypeParamType("U"), TypeParamType("T")};

  Type type_type = OpaqueType(&arena, "MyTuple", parameters);

  Type type = context.InstantiateTypeParams(type_type);
  ASSERT_THAT(type, IsTypeKind(TypeKind::kOpaque));
  EXPECT_EQ(type.AsOpaque()->name(), "MyTuple");
  EXPECT_THAT(type.AsOpaque()->GetParameters(),
              ElementsAre(IsTypeParam("T%1"), IsTypeKind(TypeKind::kInt),
                          IsTypeParam("T%2"), IsTypeParam("T%1")));
}

// TODO(uncreated-issue/72): Does not consider any substitutions based on type
// inferences yet.
// NOTE(review): this TODO may be stale. Other tests in this file (e.g.
// InferencesAccumulate) already exercise substitution-based inference.
// An instantiated opaque type is assignable to itself.
TEST(TypeInferenceContextTest, OpaqueTypeAssignable) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  std::vector parameters = {TypeParamType("T"), IntType()};

  Type type_type = OpaqueType(&arena, "MyTuple", parameters);

  Type type = context.InstantiateTypeParams(type_type);
  ASSERT_THAT(type, IsTypeKind(TypeKind::kOpaque));
  EXPECT_TRUE(context.IsAssignable(type, type));
}

// With default constructor arguments, both the wrapped primitive (string) and
// null are assignable to the string wrapper type.
TEST(TypeInferenceContextTest, WrapperTypeAssignable) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);

  EXPECT_TRUE(context.IsAssignable(StringType(), StringWrapperType()));
  EXPECT_TRUE(context.IsAssignable(NullType(), StringWrapperType()));
}

// int is not assignable to the string wrapper type.
TEST(TypeInferenceContextTest, MismatchedTypeNotAssignable) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);

  EXPECT_FALSE(context.IsAssignable(IntType(), StringWrapperType()));
}

// Concrete arguments select exactly one matching overload.
TEST(TypeInferenceContextTest, OverloadResolution) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  ASSERT_OK_AND_ASSIGN(
      auto decl,
      MakeFunctionDecl(
          "foo",
          MakeOverloadDecl("foo_int_int", IntType(), IntType(), IntType()),
          MakeOverloadDecl("foo_double_double", DoubleType(), DoubleType(),
                           DoubleType())));

  auto resolution = context.ResolveOverload(decl, {IntType(), IntType()},
                                            /*is_receiver=*/false);
  ASSERT_TRUE(resolution.has_value());
  EXPECT_THAT(resolution->result_type, IsTypeKind(TypeKind::kInt));
  EXPECT_THAT(resolution->overloads, SizeIs(1));
}

// dyn arguments match every overload; the differing result types resolve to
// dyn.
TEST(TypeInferenceContextTest, MultipleOverloadsResultTypeDyn) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  ASSERT_OK_AND_ASSIGN(
      auto decl,
      MakeFunctionDecl(
          "foo",
          MakeOverloadDecl("foo_int_int", IntType(), IntType(), IntType()),
          MakeOverloadDecl("foo_double_double", DoubleType(), DoubleType(),
                           DoubleType())));

  auto resolution = context.ResolveOverload(decl, {DynType(), DynType()},
                                            /*is_receiver=*/false);
  ASSERT_TRUE(resolution.has_value());
  EXPECT_THAT(resolution->result_type, IsTypeKind(TypeKind::kDyn));
  EXPECT_THAT(resolution->overloads, SizeIs(2));
}

// Matches an OverloadDecl whose id() equals `name`.
MATCHER_P(IsOverloadDecl, name, "") {
  const OverloadDecl& got = arg;
  return got.id() == name;
}

// int arguments select add_int, which has an int result.
TEST(TypeInferenceContextTest, ResolveOverloadBasic) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  ASSERT_OK_AND_ASSIGN(
      FunctionDecl decl,
      MakeFunctionDecl(
          "_+_", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()),
          MakeOverloadDecl("add_double", DoubleType(), DoubleType(),
                           DoubleType())));

  absl::optional resolution =
      context.ResolveOverload(decl, {IntType(), IntType()}, false);
  ASSERT_TRUE(resolution.has_value());
  EXPECT_THAT(resolution->result_type, IsTypeKind(TypeKind::kInt));
  EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("add_int")));
}

// Mixed int/double arguments match neither overload.
TEST(TypeInferenceContextTest, ResolveOverloadFails) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  ASSERT_OK_AND_ASSIGN(
      FunctionDecl decl,
      MakeFunctionDecl(
          "_+_", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()),
          MakeOverloadDecl("add_double", DoubleType(), DoubleType(),
                           DoubleType())));

  absl::optional resolution =
      context.ResolveOverload(decl, {IntType(), DoubleType()}, false);
  ASSERT_FALSE(resolution.has_value());
}

// A single type parameter A cannot bind to both int and double.
TEST(TypeInferenceContextTest, ResolveOverloadWithParamsNoMatch) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  ASSERT_OK_AND_ASSIGN(
      FunctionDecl decl,
      MakeFunctionDecl(
          "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"),
                                   TypeParamType("A"))));

  absl::optional resolution =
      context.ResolveOverload(decl, {IntType(), DoubleType()}, false);
  ASSERT_FALSE(resolution.has_value());
}

// Uninstantiated list(A) arguments are accepted for equals(A, A).
TEST(TypeInferenceContextTest, ResolveOverloadWithMixedParamsMatch) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  Type list_of_a = ListType(&arena, TypeParamType("A"));
  ASSERT_OK_AND_ASSIGN(
      FunctionDecl decl,
      MakeFunctionDecl(
          "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"),
                                   TypeParamType("A"))));

  absl::optional resolution =
      context.ResolveOverload(decl, {list_of_a, list_of_a}, false);
  ASSERT_TRUE(resolution.has_value()) << context.DebugString();
}

// list(A) and list(int) unify for equals(A, A).
TEST(TypeInferenceContextTest, ResolveOverloadWithMixedParamsMatch2) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  Type list_of_a = ListType(&arena, TypeParamType("A"));
  Type list_of_int = ListType(&arena, IntType());
  ASSERT_OK_AND_ASSIGN(
      FunctionDecl decl,
      MakeFunctionDecl(
          "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"),
                                   TypeParamType("A"))));

  absl::optional resolution =
      context.ResolveOverload(decl, {list_of_a, list_of_int}, false);
  ASSERT_TRUE(resolution.has_value()) << context.DebugString();
  EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("equals")));
}

// equals(A, A) with two ints resolves and returns bool.
TEST(TypeInferenceContextTest, ResolveOverloadWithParamsMatches) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  ASSERT_OK_AND_ASSIGN(
      FunctionDecl decl,
      MakeFunctionDecl(
          "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"),
                                   TypeParamType("A"))));

  absl::optional resolution =
      context.ResolveOverload(decl, {IntType(), IntType()}, false);
  ASSERT_TRUE(resolution.has_value());
  EXPECT_TRUE(resolution->result_type.IsBool());
  EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("equals")));
}

// An instantiated list type variable unifies with list(int) in either argument
// order. The finalized result is list(int).
TEST(TypeInferenceContextTest, ResolveOverloadWithNestedParamsMatch) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  Type list_of_a = ListType(&arena, TypeParamType("A"));
  ASSERT_OK_AND_ASSIGN(
      FunctionDecl decl,
      MakeFunctionDecl("_+_", MakeOverloadDecl("add_list", list_of_a,
                                               list_of_a, list_of_a)));

  Type list_of_a_instance = context.InstantiateTypeParams(list_of_a);

  absl::optional resolution =
      context.ResolveOverload(
          decl, {list_of_a_instance, ListType(&arena, IntType())}, false);
  ASSERT_TRUE(resolution.has_value());
  EXPECT_TRUE(resolution->result_type.IsList());

  EXPECT_THAT(
      context.FinalizeType(resolution->result_type).AsList()->GetElement(),
      IsTypeKind(TypeKind::kInt))
      << context.DebugString();

  EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("add_list")));

  absl::optional resolution2 =
      context.ResolveOverload(
          decl, {ListType(&arena, IntType()), list_of_a_instance}, false);
  ASSERT_TRUE(resolution2.has_value());
  EXPECT_TRUE(resolution2->result_type.IsList());

  EXPECT_THAT(
      context.FinalizeType(resolution2->result_type).AsList()->GetElement(),
      IsTypeKind(TypeKind::kInt))
      << context.DebugString();

  EXPECT_THAT(resolution2->overloads, ElementsAre(IsOverloadDecl("add_list")));
}

// list(A) does not unify with a bare int.
TEST(TypeInferenceContextTest, ResolveOverloadWithNestedParamsNoMatch) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  Type list_of_a = ListType(&arena, TypeParamType("A"));
  ASSERT_OK_AND_ASSIGN(
      FunctionDecl decl,
      MakeFunctionDecl("_+_", MakeOverloadDecl("add_list", list_of_a,
                                               list_of_a, list_of_a)));

  Type list_of_a_instance = context.InstantiateTypeParams(list_of_a);

  absl::optional resolution =
      context.ResolveOverload(decl, {list_of_a_instance, IntType()}, false);
  EXPECT_FALSE(resolution.has_value());
}

// Bindings from one resolution carry into the next. The list result of the
// first call later unifies with list(int), and finalizes to list(int).
TEST(TypeInferenceContextTest, InferencesAccumulate) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  Type list_of_a = ListType(&arena, TypeParamType("A"));
  ASSERT_OK_AND_ASSIGN(
      FunctionDecl decl,
      MakeFunctionDecl("_+_", MakeOverloadDecl("add_list", list_of_a,
                                               list_of_a, list_of_a)));

  Type list_of_a_instance = context.InstantiateTypeParams(list_of_a);

  absl::optional resolution1 =
      context.ResolveOverload(decl, {list_of_a_instance, list_of_a_instance},
                              false);
  ASSERT_TRUE(resolution1.has_value());
  EXPECT_TRUE(resolution1->result_type.IsList());

  absl::optional resolution2 =
      context.ResolveOverload(
          decl, {resolution1->result_type, ListType(&arena, IntType())},
          false);
  ASSERT_TRUE(resolution2.has_value());
  EXPECT_TRUE(resolution2->result_type.IsList());

  EXPECT_THAT(
      context.FinalizeType(resolution2->result_type).AsList()->GetElement(),
      IsTypeKind(TypeKind::kInt));

  EXPECT_THAT(resolution2->overloads,
              ElementsAre(IsOverloadDecl("add_list")));
}

// DebugString lists each bound type variable with its originating parameter
// name and bound type.
TEST(TypeInferenceContextTest, DebugString) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  Type list_of_a = ListType(&arena, TypeParamType("A"));
  Type list_of_int = ListType(&arena, IntType());
  ASSERT_OK_AND_ASSIGN(
      FunctionDecl decl,
      MakeFunctionDecl("_+_", MakeOverloadDecl("add_list", list_of_a,
                                               list_of_a, list_of_a)));

  absl::optional resolution =
      context.ResolveOverload(decl, {list_of_int, list_of_int}, false);
  ASSERT_TRUE(resolution.has_value());
  EXPECT_TRUE(resolution->result_type.IsList());

  EXPECT_EQ(context.DebugString(), "type_parameter_bindings: T%1 (A) -> int");
}

// Pairs a wrapper type with the primitive type it wraps. The {DynType(),
// IntType()} case checks the same behavior with dyn in the wrapper slot.
struct TypeInferenceContextWrapperTypesTestCase {
  Type wrapper_type;
  Type wrapped_primitive_type;
};

// Fixture providing a fresh inference context and a ternary declaration
// `_?_:_` with overload (bool, A, A) -> A.
class TypeInferenceContextWrapperTypesTest
    : public ::testing::TestWithParam<
          TypeInferenceContextWrapperTypesTestCase> {
 public:
  // NOTE(review): context_ is initialized both here and by its default member
  // initializer below. The mem-initializer takes precedence, so `{&arena_}` is
  // redundant. arena_ is declared before context_, so it is constructed first.
  TypeInferenceContextWrapperTypesTest() : context_(&arena_) {
    auto decl = MakeFunctionDecl(
        "_?_:_",
        MakeOverloadDecl("ternary",
                         /*result_type=*/TypeParamType("A"), BoolType(),
                         TypeParamType("A"), TypeParamType("A")));
    ABSL_CHECK_OK(decl.status());
    ternary_decl_ = *std::move(decl);
  }

 protected:
  google::protobuf::Arena arena_;
  TypeInferenceContext context_{&arena_};
  FunctionDecl ternary_decl_;
};

// (wrapper, primitive) resolves to the wrapper type.
TEST_P(TypeInferenceContextWrapperTypesTest, ResolvePrimitiveArg) {
  const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam();

  absl::optional resolution =
      context_.ResolveOverload(ternary_decl_,
                               {BoolType(), test_case.wrapper_type,
                                test_case.wrapped_primitive_type},
                               false);
  ASSERT_TRUE(resolution.has_value());
  EXPECT_THAT(context_.FinalizeType(resolution->result_type),
              IsTypeKind(test_case.wrapper_type.kind()))
      << context_.DebugString();

  EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("ternary")));
}

// (wrapper, wrapper) resolves to the wrapper type.
TEST_P(TypeInferenceContextWrapperTypesTest, ResolveWrapperArg) {
  const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam();

  absl::optional resolution =
      context_.ResolveOverload(
          ternary_decl_,
          {BoolType(), test_case.wrapper_type, test_case.wrapper_type}, false);
  ASSERT_TRUE(resolution.has_value());
  EXPECT_THAT(context_.FinalizeType(resolution->result_type),
              IsTypeKind(test_case.wrapper_type.kind()))
      << context_.DebugString();

  EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("ternary")));
}

// (wrapper, null) resolves to the wrapper type.
TEST_P(TypeInferenceContextWrapperTypesTest, ResolveNullArg) {
  const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam();

  absl::optional resolution =
      context_.ResolveOverload(ternary_decl_,
                               {BoolType(), test_case.wrapper_type, NullType()},
                               false);
  ASSERT_TRUE(resolution.has_value());
  EXPECT_THAT(context_.FinalizeType(resolution->result_type),
              IsTypeKind(test_case.wrapper_type.kind()))
      << context_.DebugString();

  EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("ternary")));
}

// (null, wrapper): null seen first still widens to the wrapper type.
TEST_P(TypeInferenceContextWrapperTypesTest, NullWidens) {
  const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam();

  absl::optional resolution =
      context_.ResolveOverload(ternary_decl_,
                               {BoolType(), NullType(), test_case.wrapper_type},
                               false);
  ASSERT_TRUE(resolution.has_value());
  EXPECT_THAT(context_.FinalizeType(resolution->result_type),
              IsTypeKind(test_case.wrapper_type.kind()))
      << context_.DebugString();

  EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("ternary")));
}

// (primitive, wrapper): the primitive seen first still widens to the wrapper.
TEST_P(TypeInferenceContextWrapperTypesTest, PrimitiveWidens) {
  const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam();

  absl::optional resolution =
      context_.ResolveOverload(ternary_decl_,
                               {BoolType(), test_case.wrapped_primitive_type,
                                test_case.wrapper_type},
                               false);
  ASSERT_TRUE(resolution.has_value());
  EXPECT_THAT(context_.FinalizeType(resolution->result_type),
              IsTypeKind(test_case.wrapper_type.kind()))
      << context_.DebugString();

  EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("ternary")));
}

// One case per wrapper kind, plus the dyn/int case.
INSTANTIATE_TEST_SUITE_P(
    Types, TypeInferenceContextWrapperTypesTest,
    ::testing::Values(
        TypeInferenceContextWrapperTypesTestCase{IntWrapperType(), IntType()},
        TypeInferenceContextWrapperTypesTestCase{UintWrapperType(),
                                                 UintType()},
        TypeInferenceContextWrapperTypesTestCase{DoubleWrapperType(),
                                                 DoubleType()},
        TypeInferenceContextWrapperTypesTestCase{StringWrapperType(),
                                                 StringType()},
        TypeInferenceContextWrapperTypesTestCase{BytesWrapperType(),
                                                 BytesType()},
        TypeInferenceContextWrapperTypesTestCase{BoolWrapperType(),
                                                 BoolType()},
        TypeInferenceContextWrapperTypesTestCase{DynType(), IntType()}));

// null and the int wrapper promote to the int wrapper type.
TEST(TypeInferenceContextTest, ResolveOverloadWithUnionTypePromotion) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  ASSERT_OK_AND_ASSIGN(
      FunctionDecl decl,
      MakeFunctionDecl(
          "_?_:_",
          MakeOverloadDecl("ternary",
                           /*result_type=*/TypeParamType("A"), BoolType(),
                           TypeParamType("A"), TypeParamType("A"))));

  absl::optional resolution =
      context.ResolveOverload(decl, {BoolType(), NullType(), IntWrapperType()},
                              false);
  ASSERT_TRUE(resolution.has_value());
  EXPECT_THAT(context.FinalizeType(resolution->result_type),
              IsTypeKind(TypeKind::kIntWrapper))
      << context.DebugString();

  EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("ternary")));
}

// TypeType has special handling (differently-parameterized type-types are
// always assignable for the sake of comparisons).
// type(string) finalizes to a `type` whose single parameter is string.
TEST(TypeInferenceContextTest, ResolveOverloadWithTypeType) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  ASSERT_OK_AND_ASSIGN(
      FunctionDecl decl,
      MakeFunctionDecl("type", MakeOverloadDecl("to_type",
                                                /*result_type=*/
                                                TypeType(&arena,
                                                         TypeParamType("A")),
                                                TypeParamType("A"))));

  absl::optional resolution =
      context.ResolveOverload(decl, {StringType()}, false);
  ASSERT_TRUE(resolution.has_value());
  auto result_type = context.FinalizeType(resolution->result_type);
  ASSERT_THAT(result_type, IsTypeKind(TypeKind::kType));
  EXPECT_THAT(result_type.AsType()->GetParameters(),
              ElementsAre(IsTypeKind(TypeKind::kString)));
  EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("to_type")));
}

// Comparing type(string) with type(int) through equals(A, A) succeeds. Each
// side still finalizes to its own parameter (string and int respectively).
TEST(TypeInferenceContextTest, ResolveOverloadWithInferredTypeType) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);
  ASSERT_OK_AND_ASSIGN(
      FunctionDecl to_type_decl,
      MakeFunctionDecl("type", MakeOverloadDecl("to_type",
                                                /*result_type=*/
                                                TypeType(&arena,
                                                         TypeParamType("A")),
                                                TypeParamType("A"))));

  ASSERT_OK_AND_ASSIGN(
      FunctionDecl equals_decl,
      MakeFunctionDecl("_==_", MakeOverloadDecl("equals",
                                                /*result_type=*/
                                                BoolType(), TypeParamType("A"),
                                                TypeParamType("A"))));

  absl::optional resolution =
      context.ResolveOverload(to_type_decl, {StringType()}, false);
  ASSERT_TRUE(resolution.has_value());
  auto lhs_result_type = resolution->result_type;
  ASSERT_THAT(lhs_result_type, IsTypeKind(TypeKind::kType));

  resolution = context.ResolveOverload(to_type_decl, {IntType()}, false);
  ASSERT_TRUE(resolution.has_value());
  auto rhs_result_type = resolution->result_type;
  ASSERT_THAT(rhs_result_type, IsTypeKind(TypeKind::kType));

  resolution = context.ResolveOverload(
      equals_decl, {rhs_result_type, lhs_result_type}, false);
  ASSERT_TRUE(resolution.has_value());
  auto result_type = context.FinalizeType(resolution->result_type);
  ASSERT_THAT(result_type, IsTypeKind(TypeKind::kBool));

  auto inferred_lhs = context.FinalizeType(lhs_result_type);
  auto inferred_rhs = context.FinalizeType(rhs_result_type);

  ASSERT_THAT(inferred_rhs, IsTypeKind(TypeKind::kType));
  ASSERT_THAT(inferred_lhs, IsTypeKind(TypeKind::kType));

  ASSERT_THAT(inferred_lhs.AsType()->GetParameters(),
              ElementsAre(IsTypeKind(TypeKind::kString)));
  ASSERT_THAT(inferred_rhs.AsType()->GetParameters(),
              ElementsAre(IsTypeKind(TypeKind::kInt)));
}

// The bindings accumulated by an AssignabilityContext (int, then int wrapper)
// widen to the int wrapper and are applied by UpdateInferredTypeAssignments().
TEST(TypeInferenceContextTest, AssignabilityContext) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);

  Type list_of_a = ListType(&arena, TypeParamType("A"));

  Type list_of_a_instance = context.InstantiateTypeParams(list_of_a);

  {
    auto assignability_context = context.CreateAssignabilityContext();
    EXPECT_TRUE(assignability_context.IsAssignable(
        IntType(), list_of_a_instance.AsList()->GetElement()));
    EXPECT_TRUE(assignability_context.IsAssignable(
        IntType(), list_of_a_instance.AsList()->GetElement()));
    EXPECT_TRUE(assignability_context.IsAssignable(
        IntWrapperType(), list_of_a_instance.AsList()->GetElement()));

    assignability_context.UpdateInferredTypeAssignments();
  }
  Type resolved_type = context.FinalizeType(list_of_a_instance);
  ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList));
  EXPECT_THAT(resolved_type.AsList()->GetElement(),
              IsTypeKind(TypeKind::kIntWrapper));
}

// optional(int) followed by optional(dyn) resolves to optional(dyn).
TEST(TypeInferenceContextTest, AssignabilityContextAbstractType) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);

  Type list_of_a = ListType(&arena, TypeParamType("A"));

  Type list_of_a_instance = context.InstantiateTypeParams(list_of_a);

  {
    auto assignability_context = context.CreateAssignabilityContext();
    EXPECT_TRUE(assignability_context.IsAssignable(
        OptionalType(&arena, IntType()),
        list_of_a_instance.AsList()->GetElement()));
    EXPECT_TRUE(assignability_context.IsAssignable(
        OptionalType(&arena, DynType()),
        list_of_a_instance.AsList()->GetElement()));

    assignability_context.UpdateInferredTypeAssignments();
  }
  Type resolved_type = context.FinalizeType(list_of_a_instance);
  ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList));
  ASSERT_THAT(resolved_type.AsList()->GetElement(),
              IsTypeKind(TypeKind::kOpaque));
  // NOTE(review): EXPECT_EQ would state this plain equality more directly.
  EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->name(),
              "optional_type");
  EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->GetParameters(),
              ElementsAre(IsTypeKind(TypeKind::kDyn)));
}

// optional(int) followed by optional(int wrapper) resolves to
// optional(int wrapper).
TEST(TypeInferenceContextTest, AssignabilityContextAbstractTypeWrapper) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);

  Type list_of_a = ListType(&arena, TypeParamType("A"));

  Type list_of_a_instance = context.InstantiateTypeParams(list_of_a);

  {
    auto assignability_context = context.CreateAssignabilityContext();
    EXPECT_TRUE(assignability_context.IsAssignable(
        OptionalType(&arena, IntType()),
        list_of_a_instance.AsList()->GetElement()));
    EXPECT_TRUE(assignability_context.IsAssignable(
        OptionalType(&arena, IntWrapperType()),
        list_of_a_instance.AsList()->GetElement()));

    assignability_context.UpdateInferredTypeAssignments();
  }
  Type resolved_type = context.FinalizeType(list_of_a_instance);
  ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList));
  ASSERT_THAT(resolved_type.AsList()->GetElement(),
              IsTypeKind(TypeKind::kOpaque));
  EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->name(),
              "optional_type");
  EXPECT_THAT(resolved_type.AsList()->GetElement().AsOpaque()->GetParameters(),
              ElementsAre(IsTypeKind(TypeKind::kIntWrapper)));
}

// Without UpdateInferredTypeAssignments(), the recorded bindings never reach
// the parent context, so the free type variable finalizes to dyn.
TEST(TypeInferenceContextTest, AssignabilityContextNotApplied) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);

  Type list_of_a = ListType(&arena, TypeParamType("A"));

  Type list_of_a_instance = context.InstantiateTypeParams(list_of_a);

  {
    auto assignability_context = context.CreateAssignabilityContext();
    EXPECT_TRUE(assignability_context.IsAssignable(
        IntType(), list_of_a_instance.AsList()->GetElement()));
    EXPECT_TRUE(assignability_context.IsAssignable(
        IntType(), list_of_a_instance.AsList()->GetElement()));
    EXPECT_TRUE(assignability_context.IsAssignable(
        IntWrapperType(), list_of_a_instance.AsList()->GetElement()));
  }
  Type resolved_type = context.FinalizeType(list_of_a_instance);
  ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList));
  EXPECT_THAT(resolved_type.AsList()->GetElement(),
              IsTypeKind(TypeKind::kDyn));
}

// Reset() discards the earlier int binding, so only the later double binding
// is applied.
TEST(TypeInferenceContextTest, AssignabilityContextReset) {
  google::protobuf::Arena arena;
  TypeInferenceContext context(&arena);

  Type list_of_a = ListType(&arena, TypeParamType("A"));

  Type list_of_a_instance = context.InstantiateTypeParams(list_of_a);

  {
    auto assignability_context = context.CreateAssignabilityContext();
    EXPECT_TRUE(assignability_context.IsAssignable(
        IntType(), list_of_a_instance.AsList()->GetElement()));
    assignability_context.Reset();
    EXPECT_TRUE(assignability_context.IsAssignable(
        DoubleType(), list_of_a_instance.AsList()->GetElement()));
    assignability_context.UpdateInferredTypeAssignments();
  }
  Type resolved_type = context.FinalizeType(list_of_a_instance);
  ASSERT_THAT(resolved_type, IsTypeKind(TypeKind::kList));
  EXPECT_THAT(resolved_type.AsList()->GetElement(),
              IsTypeKind(TypeKind::kDouble));
}

}  // namespace
}  // namespace cel::checker_internal
================================================
FILE: checker/optional.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "checker/optional.h"

#include <utility>

#include "absl/base/no_destructor.h"
#include "absl/status/status.h"
#include "base/builtins.h"
#include "checker/internal/builtins_arena.h"
#include "checker/type_checker_builder.h"
#include "common/decl.h"
#include "common/type.h"
#include "internal/status_macros.h"

namespace cel {
namespace {

// First extension version that declares the list `first()` / `last()` member
// functions. Earlier versions (0 and 1) register everything else only.
constexpr int kListFirstLastVersion = 2;

// The helpers below build the parameterized types used by the optional
// declarations. Each value is constructed once, in the checker's builtins
// arena, and is intentionally never destroyed (absl::NoDestructor).

// optional_type(V)
Type OptionalOfV() {
  static const absl::NoDestructor<OptionalType> kInstance(
      checker_internal::BuiltinsArena(), TypeParamType("V"));
  return *kInstance;
}

// type(optional_type(V))
Type TypeOfOptionalOfV() {
  static const absl::NoDestructor<TypeType> kInstance(
      checker_internal::BuiltinsArena(), OptionalOfV());
  return *kInstance;
}

// list(V)
Type ListOfV() {
  static const absl::NoDestructor<ListType> kInstance(
      checker_internal::BuiltinsArena(), TypeParamType("V"));
  return *kInstance;
}

// optional_type(list(V))
Type OptionalListOfV() {
  static const absl::NoDestructor<OptionalType> kInstance(
      checker_internal::BuiltinsArena(), ListOfV());
  return *kInstance;
}

// map(K, V)
Type MapOfKV() {
  static const absl::NoDestructor<MapType> kInstance(
      checker_internal::BuiltinsArena(), TypeParamType("K"),
      TypeParamType("V"));
  return *kInstance;
}

// optional_type(map(K, V))
Type OptionalMapOfKV() {
  static const absl::NoDestructor<OptionalType> kInstance(
      checker_internal::BuiltinsArena(), MapOfKV());
  return *kInstance;
}

// Function and identifier names declared by the optional extension.
class OptionalNames {
 public:
  static constexpr char kOptionalType[] = "optional_type";
  static constexpr char kOptionalOf[] = "optional.of";
  static constexpr char kOptionalOfNonZeroValue[] = "optional.ofNonZeroValue";
  static constexpr char kOptionalNone[] = "optional.none";
  static constexpr char kOptionalValue[] = "value";
  static constexpr char kOptionalHasValue[] = "hasValue";
  static constexpr char kOptionalOr[] = "or";
  static constexpr char kOptionalOrValue[] = "orValue";
  static constexpr char kOptionalSelect[] = "_?._";
  static constexpr char kOptionalIndex[] = "_[?_]";
  static constexpr char kOptionalFirst[] = "first";
  static constexpr char kOptionalLast[] = "last";
};

// Overload ids for the functions in OptionalNames.
class OptionalOverloads {
 public:
  // Creation
  static constexpr char kOptionalOf[] = "optional_of";
  static constexpr char kOptionalOfNonZeroValue[] = "optional_ofNonZeroValue";
  static constexpr char kOptionalNone[] = "optional_none";
  // Basic accessors
  static constexpr char kOptionalValue[] = "optional_value";
  static constexpr char kOptionalHasValue[] = "optional_hasValue";
  // Chaining `or` overloads.
  static constexpr char kOptionalOr[] = "optional_or_optional";
  static constexpr char kOptionalOrValue[] = "optional_orValue_value";
  // Selection
  static constexpr char kOptionalSelect[] = "select_optional_field";
  // Indexing
  static constexpr char kListOptionalIndexInt[] = "list_optindex_optional_int";
  static constexpr char kOptionalListOptionalIndexInt[] =
      "optional_list_optindex_optional_int";
  static constexpr char kMapOptionalIndexValue[] =
      "map_optindex_optional_value";
  static constexpr char kOptionalMapOptionalIndexValue[] =
      "optional_map_optindex_optional_value";
  static constexpr char kListFirst[] = "list_first";
  static constexpr char kListLast[] = "list_last";
  // Syntactic sugar for chained indexing.
  static constexpr char kOptionalListIndexInt[] = "optional_list_index_int";
  static constexpr char kOptionalMapIndexValue[] = "optional_map_index_value";
};

// Declares the optional extension in `builder`.
//
// All declarations are built before any is registered, so a construction
// failure leaves `builder` untouched. The `first()` / `last()` list member
// functions are only registered when `version >= kListFirstLastVersion`; any
// lower value (including an invalid negative one) gets the original v0/v1
// declaration set.
//
// Returns the first error reported by declaration construction or by
// `builder`.
absl::Status RegisterOptionalDecls(TypeCheckerBuilder& builder, int version) {
  CEL_ASSIGN_OR_RETURN(
      auto of,
      MakeFunctionDecl(OptionalNames::kOptionalOf,
                       MakeOverloadDecl(OptionalOverloads::kOptionalOf,
                                        OptionalOfV(), TypeParamType("V"))));
  CEL_ASSIGN_OR_RETURN(
      auto of_non_zero,
      MakeFunctionDecl(
          OptionalNames::kOptionalOfNonZeroValue,
          MakeOverloadDecl(OptionalOverloads::kOptionalOfNonZeroValue,
                           OptionalOfV(), TypeParamType("V"))));
  CEL_ASSIGN_OR_RETURN(
      auto none,
      MakeFunctionDecl(
          OptionalNames::kOptionalNone,
          MakeOverloadDecl(OptionalOverloads::kOptionalNone, OptionalOfV())));
  CEL_ASSIGN_OR_RETURN(
      auto value,
      MakeFunctionDecl(OptionalNames::kOptionalValue,
                       MakeMemberOverloadDecl(OptionalOverloads::kOptionalValue,
                                              TypeParamType("V"),
                                              OptionalOfV())));
  CEL_ASSIGN_OR_RETURN(
      auto has_value,
      MakeFunctionDecl(
          OptionalNames::kOptionalHasValue,
          MakeMemberOverloadDecl(OptionalOverloads::kOptionalHasValue,
                                 BoolType(), OptionalOfV())));
  CEL_ASSIGN_OR_RETURN(
      auto or_,
      MakeFunctionDecl(
          OptionalNames::kOptionalOr,
          MakeMemberOverloadDecl(OptionalOverloads::kOptionalOr, OptionalOfV(),
                                 OptionalOfV(), OptionalOfV())));
  CEL_ASSIGN_OR_RETURN(
      auto or_value,
      MakeFunctionDecl(
          OptionalNames::kOptionalOrValue,
          MakeMemberOverloadDecl(OptionalOverloads::kOptionalOrValue,
                                 TypeParamType("V"), OptionalOfV(),
                                 TypeParamType("V"))));
  // This is special cased by the type checker -- just adding a Decl to prevent
  // accidental user overloading.
  CEL_ASSIGN_OR_RETURN(
      auto select,
      MakeFunctionDecl(
          OptionalNames::kOptionalSelect,
          MakeOverloadDecl(OptionalOverloads::kOptionalSelect, OptionalOfV(),
                           DynType(), StringType())));
  CEL_ASSIGN_OR_RETURN(
      auto opt_index,
      MakeFunctionDecl(
          OptionalNames::kOptionalIndex,
          MakeOverloadDecl(OptionalOverloads::kOptionalListOptionalIndexInt,
                           OptionalOfV(), OptionalListOfV(), IntType()),
          MakeOverloadDecl(OptionalOverloads::kListOptionalIndexInt,
                           OptionalOfV(), ListOfV(), IntType()),
          MakeOverloadDecl(OptionalOverloads::kMapOptionalIndexValue,
                           OptionalOfV(), MapOfKV(), TypeParamType("K")),
          MakeOverloadDecl(OptionalOverloads::kOptionalMapOptionalIndexValue,
                           OptionalOfV(), OptionalMapOfKV(),
                           TypeParamType("K"))));
  CEL_ASSIGN_OR_RETURN(
      auto first,
      MakeFunctionDecl(OptionalNames::kOptionalFirst,
                       MakeMemberOverloadDecl(OptionalOverloads::kListFirst,
                                              OptionalOfV(), ListOfV())));
  CEL_ASSIGN_OR_RETURN(
      auto last,
      MakeFunctionDecl(OptionalNames::kOptionalLast,
                       MakeMemberOverloadDecl(OptionalOverloads::kListLast,
                                              OptionalOfV(), ListOfV())));
  // Optional overloads for the standard `_[_]` operator; merged (not added)
  // because the standard library declares the same function name.
  CEL_ASSIGN_OR_RETURN(
      auto index,
      MakeFunctionDecl(
          cel::builtin::kIndex,
          MakeOverloadDecl(OptionalOverloads::kOptionalListIndexInt,
                           OptionalOfV(), OptionalListOfV(), IntType()),
          MakeOverloadDecl(OptionalOverloads::kOptionalMapIndexValue,
                           OptionalOfV(), OptionalMapOfKV(),
                           TypeParamType("K"))));

  CEL_RETURN_IF_ERROR(builder.AddVariable(
      MakeVariableDecl(OptionalNames::kOptionalType, TypeOfOptionalOfV())));

  CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(of)));
  CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(of_non_zero)));
  CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(none)));
  CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(value)));
  CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(has_value)));
  CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(or_)));
  CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(or_value)));
  CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(opt_index)));
  CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(select)));
  CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(index)));

  // Versions before kListFirstLastVersion (0 and 1, and any invalid negative
  // value) do not include the list first()/last() helpers.
  if (version < kListFirstLastVersion) {
    return absl::OkStatus();
  }

  CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(first)));
  CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(last)));

  return absl::OkStatus();
}

}  // namespace

// Wraps RegisterOptionalDecls in a CheckerLibrary with id "optional"; the
// requested `version` is captured by value for when the library is applied.
CheckerLibrary OptionalCheckerLibrary(int version) {
  return CheckerLibrary({
      "optional",
      [version](TypeCheckerBuilder& builder) {
        return RegisterOptionalDecls(builder, version);
      },
  });
}

}  // namespace cel
================================================
FILE: checker/optional.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_CHECKER_OPTIONAL_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_OPTIONAL_H_ #include "checker/type_checker_builder.h" namespace cel { constexpr int kOptionalExtensionLatestVersion = 2; // Library for CEL optional definitions. CheckerLibrary OptionalCheckerLibrary( int version = kOptionalExtensionLatestVersion); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_CHECKER_OPTIONAL_H_ ================================================ FILE: checker/optional_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "checker/optional.h"

#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "absl/status/status_matchers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "checker/checker_options.h"
#include "checker/internal/test_ast_helpers.h"
#include "checker/standard_library.h"
#include "checker/type_check_issue.h"
#include "checker/type_checker.h"
#include "checker/type_checker_builder.h"
#include "checker/type_checker_builder_factory.h"
#include "common/ast.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::cel::checker_internal::MakeTestParsedAst;
using ::cel::internal::GetSharedTestingDescriptorPool;
using ::testing::_;
using ::testing::Contains;
using ::testing::Eq;
using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::Key;
using ::testing::Not;
using ::testing::Property;
using ::testing::SizeIs;

// Matches a TypeSpec of the form `optional_type(inner_type)`, explaining the
// first mismatch found through `result_listener`.
MATCHER_P(IsOptionalType, inner_type, "") {
  const TypeSpec& type = arg;
  if (!type.has_abstract_type()) {
    *result_listener << "expected an abstract type";
    return false;
  }
  const auto& abs_type = type.abstract_type();
  if (abs_type.name() != "optional_type") {
    *result_listener << "expected optional_type, got: " << abs_type.name();
    return false;
  }
  if (abs_type.parameter_types().size() != 1) {
    *result_listener << "unexpected number of parameters: "
                     << abs_type.parameter_types().size();
    return false;
  }
  if (inner_type == abs_type.parameter_types()[0]) {
    return true;
  }
  *result_listener << "unexpected inner type: "
                   << abs_type.parameter_types()[0].type_kind().index();
  return false;
}

TEST(OptionalTest, OptSelectDoesNotAnnotateFieldType) {
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<TypeCheckerBuilder> builder,
      CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool()));
  ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk());
  ASSERT_THAT(builder->AddLibrary(OptionalCheckerLibrary()), IsOk());
  builder->set_container("cel.expr.conformance.proto3");
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<TypeChecker> checker,
                       std::move(*builder).Build());

  ASSERT_OK_AND_ASSIGN(auto ast,
                       MakeTestParsedAst("TestAllTypes{}.?single_int64"));
  ASSERT_OK_AND_ASSIGN(auto result, checker->Check(std::move(ast)));
  EXPECT_THAT(result.GetIssues(), IsEmpty());
  ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());

  ASSERT_THAT(checked_ast->root_expr().call_expr().args(), SizeIs(2));
  int64_t field_id = checked_ast->root_expr().call_expr().args()[1].id();
  EXPECT_NE(field_id, 0);
  EXPECT_THAT(checked_ast->type_map(), Not(Contains(Key(field_id))));
  EXPECT_THAT(checked_ast->GetTypeOrDyn(checked_ast->root_expr().id()),
              IsOptionalType(TypeSpec(PrimitiveType::kInt64)));
}

struct TestCase {
  std::string expr;
  testing::Matcher<TypeSpec> result_type_matcher;
  // When non-empty, the check is expected to report an issue containing this
  // substring and `result_type_matcher` is not consulted.
  std::string error_substring;
};

// Shared body of the parameterized suites below: adds the standard and
// optional libraries to `builder`, type-checks `test_case.expr`, and verifies
// either the expected error substring or the type of the root expression.
void RunTestCase(TypeCheckerBuilder& builder, const TestCase& test_case) {
  ASSERT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk());
  ASSERT_THAT(builder.AddLibrary(OptionalCheckerLibrary()), IsOk());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<TypeChecker> checker,
                       std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr));
  ASSERT_OK_AND_ASSIGN(auto result, checker->Check(std::move(ast)));

  if (!test_case.error_substring.empty()) {
    EXPECT_THAT(result.GetIssues(),
                Contains(Property(&TypeCheckIssue::message,
                                  HasSubstr(test_case.error_substring))))
        << absl::StrJoin(result.GetIssues(), "\n",
                         [](std::string* out, const auto& i) {
                           absl::StrAppend(out, i.message());
                         });
    return;
  }

  EXPECT_THAT(result.GetIssues(), IsEmpty())
      << "for expression: " << test_case.expr;
  ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());
  int64_t root_id = checked_ast->root_expr().id();
  EXPECT_THAT(checked_ast->GetTypeOrDyn(root_id),
              test_case.result_type_matcher)
      << "for expression: " << test_case.expr;
}

class OptionalTest : public testing::TestWithParam<TestCase> {};

TEST_P(OptionalTest, Runner) {
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<TypeCheckerBuilder> builder,
      CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool()));
  RunTestCase(*builder, GetParam());
}

INSTANTIATE_TEST_SUITE_P(
    OptionalTests, OptionalTest,
    ::testing::Values(
        TestCase{
            "optional.of('abc')",
            IsOptionalType(TypeSpec(PrimitiveType::kString)),
        },
        TestCase{
            "optional.ofNonZeroValue('')",
            IsOptionalType(TypeSpec(PrimitiveType::kString)),
        },
        TestCase{
            "optional.none()",
            IsOptionalType(TypeSpec(DynTypeSpec())),
        },
        // Odd case -- the correct result might be a bespoke recursively-defined
        // type but CEL doesn't support that. Null is used because it is
        // implicitly assignable to optional types. This allows for a recursive
        // type to be non-trivial and verify the checker is actually avoiding
        // introducing a cyclic type.
        TestCase{
            "[optional.none()].map(x, [?x, null, x])",
            Eq(TypeSpec(ListTypeSpec(std::make_unique<TypeSpec>(ListTypeSpec(
                std::make_unique<TypeSpec>(NullTypeSpec())))))),
        },
        TestCase{
            "optional.of('abc').hasValue()",
            Eq(TypeSpec(PrimitiveType::kBool)),
        },
        TestCase{
            "optional.of('abc').value()",
            Eq(TypeSpec(PrimitiveType::kString)),
        },
        TestCase{
            "type(optional.of('abc')) == optional_type",
            Eq(TypeSpec(PrimitiveType::kBool)),
        },
        TestCase{
            "optional.of('abc').or(optional.of('def'))",
            IsOptionalType(TypeSpec(PrimitiveType::kString)),
        },
        TestCase{"optional.of('abc').or(optional.of(1))", _,
                 "no matching overload for 'or'"},
        TestCase{
            "optional.of('abc').orValue('def')",
            Eq(TypeSpec(PrimitiveType::kString)),
        },
        TestCase{"optional.of('abc').orValue(1)", _,
                 "no matching overload for 'orValue'"},
        TestCase{
            "{'k': 'v'}.?k",
            IsOptionalType(TypeSpec(PrimitiveType::kString)),
        },
        TestCase{"1.?k", _,
                 "expression of type 'int' cannot be the operand of a select "
                 "operation"},
        TestCase{
            "{'k': {'k': 'v'}}.?k.?k2",
            IsOptionalType(TypeSpec(PrimitiveType::kString)),
        },
        TestCase{
            "{'k': {'k': 'v'}}.?k.k2",
            IsOptionalType(TypeSpec(PrimitiveType::kString)),
        },
        TestCase{"{?'k': optional.of('v')}",
                 Eq(TypeSpec(MapTypeSpec(
                     std::make_unique<TypeSpec>(PrimitiveType::kString),
                     std::make_unique<TypeSpec>(PrimitiveType::kString))))},
        TestCase{"{'k': 'v', ?'k2': optional.none()}",
                 Eq(TypeSpec(MapTypeSpec(
                     std::make_unique<TypeSpec>(PrimitiveType::kString),
                     std::make_unique<TypeSpec>(PrimitiveType::kString))))},
        TestCase{"{'k': 'v', ?'k2': 'v'}", _,
                 "expected type 'optional_type(string)' but found 'string'"},
        TestCase{"[?optional.of('v')]",
                 Eq(TypeSpec(ListTypeSpec(
                     std::make_unique<TypeSpec>(PrimitiveType::kString))))},
        TestCase{"['v', ?optional.none()]",
                 Eq(TypeSpec(ListTypeSpec(
                     std::make_unique<TypeSpec>(PrimitiveType::kString))))},
        TestCase{"['v1', ?'v2']", _,
                 "expected type 'optional_type(string)' but found 'string'"},
        TestCase{"[optional.of(dyn('1')), optional.of('2')][0]",
                 IsOptionalType(TypeSpec(DynTypeSpec()))},
        TestCase{"[optional.of('1'), optional.of(dyn('2'))][0]",
                 IsOptionalType(TypeSpec(DynTypeSpec()))},
        TestCase{"[{1: optional.of(1)}, {1: optional.of(dyn(1))}][0][1]",
                 IsOptionalType(TypeSpec(DynTypeSpec()))},
        TestCase{"[{1: optional.of(dyn(1))}, {1: optional.of(1)}][0][1]",
                 IsOptionalType(TypeSpec(DynTypeSpec()))},
        TestCase{"[optional.of('1'), optional.of(2)][0]",
                 Eq(TypeSpec(DynTypeSpec()))},
        TestCase{"cel.expr.conformance.proto3.TestAllTypes{?single_int64: "
                 "optional.of(1)}",
                 Eq(TypeSpec(MessageTypeSpec(
                     "cel.expr.conformance.proto3.TestAllTypes")))},
        TestCase{"[0][?1]", IsOptionalType(TypeSpec(PrimitiveType::kInt64))},
        TestCase{"[[0]][?1][?1]",
                 IsOptionalType(TypeSpec(PrimitiveType::kInt64))},
        TestCase{"[[0]][?1][1]",
                 IsOptionalType(TypeSpec(PrimitiveType::kInt64))},
        TestCase{"{0: 1}[?1]", IsOptionalType(TypeSpec(PrimitiveType::kInt64))},
        TestCase{"{0: {0: 1}}[?1][?1]",
                 IsOptionalType(TypeSpec(PrimitiveType::kInt64))},
        TestCase{"{0: {0: 1}}[?1][1]",
                 IsOptionalType(TypeSpec(PrimitiveType::kInt64))},
        TestCase{"{0: {0: 1}}[?1]['']", _, "no matching overload for '_[_]'"},
        TestCase{"{0: {0: 1}}[?1][?'']", _,
                 "no matching overload for '_[?_]'"},
        TestCase{"[1, 2, 3].first()",
                 IsOptionalType(TypeSpec(PrimitiveType::kInt64))},
        TestCase{"[1, 2, 3].last()",
                 IsOptionalType(TypeSpec(PrimitiveType::kInt64))},
        TestCase{"optional.of('abc').optMap(x, x + 'def')",
                 IsOptionalType(TypeSpec(PrimitiveType::kString))},
        TestCase{"optional.of('abc').optFlatMap(x, optional.of(x + 'def'))",
                 IsOptionalType(TypeSpec(PrimitiveType::kString))},
        TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: "
                 "optional.of(0)}",
                 Eq(TypeSpec(MessageTypeSpec(
                     "cel.expr.conformance.proto3.TestAllTypes")))},
        // Legacy nullability behaviors.
        TestCase{
            "cel.expr.conformance.proto3.TestAllTypes{?single_value: null}",
            Eq(TypeSpec(
                MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")))},
        TestCase{"cel.expr.conformance.proto3.TestAllTypes{?single_value: "
                 "optional.of(null)}",
                 Eq(TypeSpec(MessageTypeSpec(
                     "cel.expr.conformance.proto3.TestAllTypes")))},
        TestCase{"cel.expr.conformance.proto3.TestAllTypes{}.?single_int64 "
                 "== null",
                 Eq(TypeSpec(PrimitiveType::kBool))}));

class OptionalStrictNullAssignmentTest
    : public testing::TestWithParam<TestCase> {};

TEST_P(OptionalStrictNullAssignmentTest, Runner) {
  CheckerOptions options;
  options.enable_legacy_null_assignment = false;
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<TypeCheckerBuilder> builder,
      CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options));
  RunTestCase(*builder, GetParam());
}

INSTANTIATE_TEST_SUITE_P(
    OptionalTests, OptionalStrictNullAssignmentTest,
    ::testing::Values(
        TestCase{
            "cel.expr.conformance.proto3.TestAllTypes{?single_int64: null}", _,
            "expected type of field 'single_int64' is 'optional_type(int)' but "
            "provided type is 'null_type'"},
        TestCase{"cel.expr.conformance.proto3.TestAllTypes{}.?single_int64 "
                 "== null",
                 _, "no matching overload for '_==_'"}));

}  // namespace
}  // namespace cel
================================================
FILE: checker/standard_library.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "checker/standard_library.h"

#include <string>
#include <utility>

#include "absl/base/no_destructor.h"
#include "absl/status/status.h"
#include "checker/internal/builtins_arena.h"
#include "checker/type_checker_builder.h"
#include "common/constant.h"
#include "common/decl.h"
#include "common/standard_definitions.h"
#include "common/type.h"
#include "internal/status_macros.h"

namespace cel {
namespace {

using ::cel::checker_internal::BuiltinsArena;

// Arbitrary type parameter name A.
TypeParamType TypeParamA() { return TypeParamType("A"); }

// Arbitrary type parameter name B.
TypeParamType TypeParamB() { return TypeParamType("B"); } Type ListOfA() { static absl::NoDestructor kInstance( ListType(BuiltinsArena(), TypeParamA())); return *kInstance; } Type MapOfAB() { static absl::NoDestructor kInstance( MapType(BuiltinsArena(), TypeParamA(), TypeParamB())); return *kInstance; } Type TypeOfType() { static absl::NoDestructor kInstance( TypeType(BuiltinsArena(), TypeType())); return *kInstance; } Type TypeOfA() { static absl::NoDestructor kInstance( TypeType(BuiltinsArena(), TypeParamA())); return *kInstance; } Type TypeNullType() { static absl::NoDestructor kInstance( TypeType(BuiltinsArena(), NullType())); return *kInstance; } Type TypeBoolType() { static absl::NoDestructor kInstance( TypeType(BuiltinsArena(), BoolType())); return *kInstance; } Type TypeIntType() { static absl::NoDestructor kInstance( TypeType(BuiltinsArena(), IntType())); return *kInstance; } Type TypeUintType() { static absl::NoDestructor kInstance( TypeType(BuiltinsArena(), UintType())); return *kInstance; } Type TypeDoubleType() { static absl::NoDestructor kInstance( TypeType(BuiltinsArena(), DoubleType())); return *kInstance; } Type TypeStringType() { static absl::NoDestructor kInstance( TypeType(BuiltinsArena(), StringType())); return *kInstance; } Type TypeBytesType() { static absl::NoDestructor kInstance( TypeType(BuiltinsArena(), BytesType())); return *kInstance; } Type TypeDynType() { static absl::NoDestructor kInstance( TypeType(BuiltinsArena(), DynType())); return *kInstance; } Type TypeListType() { static absl::NoDestructor kInstance( TypeType(BuiltinsArena(), ListOfA())); return *kInstance; } Type TypeMapType() { static absl::NoDestructor kInstance( TypeType(BuiltinsArena(), MapOfAB())); return *kInstance; } absl::Status AddArithmeticOps(TypeCheckerBuilder& builder) { FunctionDecl add_op; add_op.set_name(StandardFunctions::kAdd); CEL_RETURN_IF_ERROR(add_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kAddInt, IntType(), IntType(), IntType()))); 
CEL_RETURN_IF_ERROR(add_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kAddDouble, DoubleType(), DoubleType(), DoubleType()))); CEL_RETURN_IF_ERROR(add_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kAddUint, UintType(), UintType(), UintType()))); // timestamp math CEL_RETURN_IF_ERROR(add_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kAddDurationDuration, DurationType(), DurationType(), DurationType()))); CEL_RETURN_IF_ERROR(add_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kAddDurationTimestamp, TimestampType(), DurationType(), TimestampType()))); CEL_RETURN_IF_ERROR(add_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kAddTimestampDuration, TimestampType(), TimestampType(), DurationType()))); // string concat CEL_RETURN_IF_ERROR(add_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kAddBytes, BytesType(), BytesType(), BytesType()))); CEL_RETURN_IF_ERROR(add_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kAddString, StringType(), StringType(), StringType()))); // list concat CEL_RETURN_IF_ERROR(add_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kAddList, ListOfA(), ListOfA(), ListOfA()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(add_op))); FunctionDecl subtract_op; subtract_op.set_name(StandardFunctions::kSubtract); CEL_RETURN_IF_ERROR(subtract_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kSubtractInt, IntType(), IntType(), IntType()))); CEL_RETURN_IF_ERROR(subtract_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kSubtractUint, UintType(), UintType(), UintType()))); CEL_RETURN_IF_ERROR(subtract_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kSubtractDouble, DoubleType(), DoubleType(), DoubleType()))); // Timestamp math CEL_RETURN_IF_ERROR(subtract_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kSubtractDurationDuration, DurationType(), DurationType(), DurationType()))); CEL_RETURN_IF_ERROR(subtract_op.AddOverload( 
MakeOverloadDecl(StandardOverloadIds::kSubtractTimestampDuration, TimestampType(), TimestampType(), DurationType()))); CEL_RETURN_IF_ERROR(subtract_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kSubtractTimestampTimestamp, DurationType(), TimestampType(), TimestampType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(subtract_op))); FunctionDecl multiply_op; multiply_op.set_name(StandardFunctions::kMultiply); CEL_RETURN_IF_ERROR(multiply_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kMultiplyInt, IntType(), IntType(), IntType()))); CEL_RETURN_IF_ERROR(multiply_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kMultiplyUint, UintType(), UintType(), UintType()))); CEL_RETURN_IF_ERROR(multiply_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kMultiplyDouble, DoubleType(), DoubleType(), DoubleType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(multiply_op))); FunctionDecl division_op; division_op.set_name(StandardFunctions::kDivide); CEL_RETURN_IF_ERROR(division_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kDivideInt, IntType(), IntType(), IntType()))); CEL_RETURN_IF_ERROR(division_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kDivideUint, UintType(), UintType(), UintType()))); CEL_RETURN_IF_ERROR(division_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kDivideDouble, DoubleType(), DoubleType(), DoubleType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(division_op))); FunctionDecl modulo_op; modulo_op.set_name(StandardFunctions::kModulo); CEL_RETURN_IF_ERROR(modulo_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kModuloInt, IntType(), IntType(), IntType()))); CEL_RETURN_IF_ERROR(modulo_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kModuloUint, UintType(), UintType(), UintType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(modulo_op))); FunctionDecl negate_op; negate_op.set_name(StandardFunctions::kNeg); CEL_RETURN_IF_ERROR(negate_op.AddOverload( 
MakeOverloadDecl(StandardOverloadIds::kNegateInt, IntType(), IntType()))); CEL_RETURN_IF_ERROR(negate_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kNegateDouble, DoubleType(), DoubleType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(negate_op))); return absl::OkStatus(); } absl::Status AddLogicalOps(TypeCheckerBuilder& builder) { FunctionDecl not_op; not_op.set_name(StandardFunctions::kNot); CEL_RETURN_IF_ERROR(not_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kNot, BoolType(), BoolType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(not_op))); FunctionDecl and_op; and_op.set_name(StandardFunctions::kAnd); CEL_RETURN_IF_ERROR(and_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kAnd, BoolType(), BoolType(), BoolType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(and_op))); FunctionDecl or_op; or_op.set_name(StandardFunctions::kOr); CEL_RETURN_IF_ERROR(or_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kOr, BoolType(), BoolType(), BoolType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(or_op))); FunctionDecl conditional_op; conditional_op.set_name(StandardFunctions::kTernary); CEL_RETURN_IF_ERROR(conditional_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kConditional, TypeParamA(), BoolType(), TypeParamA(), TypeParamA()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(conditional_op))); FunctionDecl not_strictly_false; not_strictly_false.set_name(StandardFunctions::kNotStrictlyFalse); CEL_RETURN_IF_ERROR(not_strictly_false.AddOverload(MakeOverloadDecl( StandardOverloadIds::kNotStrictlyFalse, BoolType(), BoolType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(not_strictly_false))); FunctionDecl not_strictly_false_deprecated; not_strictly_false_deprecated.set_name( StandardFunctions::kNotStrictlyFalseDeprecated); CEL_RETURN_IF_ERROR(not_strictly_false_deprecated.AddOverload( MakeOverloadDecl(StandardOverloadIds::kNotStrictlyFalseDeprecated, BoolType(), BoolType()))); 
CEL_RETURN_IF_ERROR( builder.AddFunction(std::move(not_strictly_false_deprecated))); return absl::OkStatus(); } absl::Status AddTypeConversions(TypeCheckerBuilder& builder) { FunctionDecl to_dyn; to_dyn.set_name(StandardFunctions::kDyn); CEL_RETURN_IF_ERROR(to_dyn.AddOverload( MakeOverloadDecl(StandardOverloadIds::kToDyn, DynType(), TypeParamA()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_dyn))); // Uint FunctionDecl to_uint; to_uint.set_name(StandardFunctions::kUint); CEL_RETURN_IF_ERROR(to_uint.AddOverload(MakeOverloadDecl( StandardOverloadIds::kUintToUint, UintType(), UintType()))); CEL_RETURN_IF_ERROR(to_uint.AddOverload(MakeOverloadDecl( StandardOverloadIds::kIntToUint, UintType(), IntType()))); CEL_RETURN_IF_ERROR(to_uint.AddOverload(MakeOverloadDecl( StandardOverloadIds::kDoubleToUint, UintType(), DoubleType()))); CEL_RETURN_IF_ERROR(to_uint.AddOverload(MakeOverloadDecl( StandardOverloadIds::kStringToUint, UintType(), StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_uint))); // Int FunctionDecl to_int; to_int.set_name(StandardFunctions::kInt); CEL_RETURN_IF_ERROR(to_int.AddOverload( MakeOverloadDecl(StandardOverloadIds::kIntToInt, IntType(), IntType()))); CEL_RETURN_IF_ERROR(to_int.AddOverload(MakeOverloadDecl( StandardOverloadIds::kUintToInt, IntType(), UintType()))); CEL_RETURN_IF_ERROR(to_int.AddOverload(MakeOverloadDecl( StandardOverloadIds::kDoubleToInt, IntType(), DoubleType()))); CEL_RETURN_IF_ERROR(to_int.AddOverload(MakeOverloadDecl( StandardOverloadIds::kStringToInt, IntType(), StringType()))); CEL_RETURN_IF_ERROR(to_int.AddOverload(MakeOverloadDecl( StandardOverloadIds::kTimestampToInt, IntType(), TimestampType()))); CEL_RETURN_IF_ERROR(to_int.AddOverload(MakeOverloadDecl( StandardOverloadIds::kDurationToInt, IntType(), DurationType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_int))); FunctionDecl to_double; to_double.set_name(StandardFunctions::kDouble); 
CEL_RETURN_IF_ERROR(to_double.AddOverload(MakeOverloadDecl( StandardOverloadIds::kDoubleToDouble, DoubleType(), DoubleType()))); CEL_RETURN_IF_ERROR(to_double.AddOverload(MakeOverloadDecl( StandardOverloadIds::kIntToDouble, DoubleType(), IntType()))); CEL_RETURN_IF_ERROR(to_double.AddOverload(MakeOverloadDecl( StandardOverloadIds::kUintToDouble, DoubleType(), UintType()))); CEL_RETURN_IF_ERROR(to_double.AddOverload(MakeOverloadDecl( StandardOverloadIds::kStringToDouble, DoubleType(), StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_double))); FunctionDecl to_bool; to_bool.set_name("bool"); CEL_RETURN_IF_ERROR(to_bool.AddOverload(MakeOverloadDecl( StandardOverloadIds::kBoolToBool, BoolType(), BoolType()))); CEL_RETURN_IF_ERROR(to_bool.AddOverload(MakeOverloadDecl( StandardOverloadIds::kStringToBool, BoolType(), StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_bool))); FunctionDecl to_string; to_string.set_name(StandardFunctions::kString); CEL_RETURN_IF_ERROR(to_string.AddOverload(MakeOverloadDecl( StandardOverloadIds::kStringToString, StringType(), StringType()))); CEL_RETURN_IF_ERROR(to_string.AddOverload(MakeOverloadDecl( StandardOverloadIds::kBytesToString, StringType(), BytesType()))); CEL_RETURN_IF_ERROR(to_string.AddOverload(MakeOverloadDecl( StandardOverloadIds::kBoolToString, StringType(), BoolType()))); CEL_RETURN_IF_ERROR(to_string.AddOverload(MakeOverloadDecl( StandardOverloadIds::kDoubleToString, StringType(), DoubleType()))); CEL_RETURN_IF_ERROR(to_string.AddOverload(MakeOverloadDecl( StandardOverloadIds::kIntToString, StringType(), IntType()))); CEL_RETURN_IF_ERROR(to_string.AddOverload(MakeOverloadDecl( StandardOverloadIds::kUintToString, StringType(), UintType()))); CEL_RETURN_IF_ERROR(to_string.AddOverload(MakeOverloadDecl( StandardOverloadIds::kTimestampToString, StringType(), TimestampType()))); CEL_RETURN_IF_ERROR(to_string.AddOverload(MakeOverloadDecl( StandardOverloadIds::kDurationToString, 
StringType(), DurationType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_string))); FunctionDecl to_bytes; to_bytes.set_name(StandardFunctions::kBytes); CEL_RETURN_IF_ERROR(to_bytes.AddOverload(MakeOverloadDecl( StandardOverloadIds::kBytesToBytes, BytesType(), BytesType()))); CEL_RETURN_IF_ERROR(to_bytes.AddOverload(MakeOverloadDecl( StandardOverloadIds::kStringToBytes, BytesType(), StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_bytes))); FunctionDecl to_timestamp; to_timestamp.set_name(StandardFunctions::kTimestamp); CEL_RETURN_IF_ERROR(to_timestamp.AddOverload( MakeOverloadDecl(StandardOverloadIds::kTimestampToTimestamp, TimestampType(), TimestampType()))); CEL_RETURN_IF_ERROR(to_timestamp.AddOverload(MakeOverloadDecl( StandardOverloadIds::kStringToTimestamp, TimestampType(), StringType()))); CEL_RETURN_IF_ERROR(to_timestamp.AddOverload(MakeOverloadDecl( StandardOverloadIds::kIntToTimestamp, TimestampType(), IntType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_timestamp))); FunctionDecl to_duration; to_duration.set_name(StandardFunctions::kDuration); CEL_RETURN_IF_ERROR(to_duration.AddOverload( MakeOverloadDecl(StandardOverloadIds::kDurationToDuration, DurationType(), DurationType()))); CEL_RETURN_IF_ERROR(to_duration.AddOverload(MakeOverloadDecl( StandardOverloadIds::kStringToDuration, DurationType(), StringType()))); CEL_RETURN_IF_ERROR(to_duration.AddOverload(MakeOverloadDecl( StandardOverloadIds::kIntToDuration, DurationType(), IntType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_duration))); FunctionDecl to_type; to_type.set_name(StandardFunctions::kType); CEL_RETURN_IF_ERROR(to_type.AddOverload(MakeOverloadDecl( StandardOverloadIds::kToType, Type(TypeOfA()), TypeParamA()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(to_type))); return absl::OkStatus(); } absl::Status AddEqualityOps(TypeCheckerBuilder& builder) { FunctionDecl equals_op; equals_op.set_name(StandardFunctions::kEqual); 
CEL_RETURN_IF_ERROR(equals_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kEquals, BoolType(), TypeParamA(), TypeParamA()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(equals_op))); FunctionDecl not_equals_op; not_equals_op.set_name(StandardFunctions::kInequal); CEL_RETURN_IF_ERROR(not_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kNotEquals, BoolType(), TypeParamA(), TypeParamA()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(not_equals_op))); return absl::OkStatus(); } absl::Status AddContainerOps(TypeCheckerBuilder& builder) { FunctionDecl index; index.set_name(StandardFunctions::kIndex); CEL_RETURN_IF_ERROR(index.AddOverload(MakeOverloadDecl( StandardOverloadIds::kIndexList, TypeParamA(), ListOfA(), IntType()))); CEL_RETURN_IF_ERROR(index.AddOverload(MakeOverloadDecl( StandardOverloadIds::kIndexMap, TypeParamB(), MapOfAB(), TypeParamA()))); CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(index))); FunctionDecl in_op; in_op.set_name(StandardFunctions::kIn); CEL_RETURN_IF_ERROR(in_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kInList, BoolType(), TypeParamA(), ListOfA()))); CEL_RETURN_IF_ERROR(in_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kInMap, BoolType(), TypeParamA(), MapOfAB()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(in_op))); FunctionDecl in_function_deprecated; in_function_deprecated.set_name(StandardFunctions::kInFunction); CEL_RETURN_IF_ERROR(in_function_deprecated.AddOverload(MakeOverloadDecl( StandardOverloadIds::kInList, BoolType(), TypeParamA(), ListOfA()))); CEL_RETURN_IF_ERROR(in_function_deprecated.AddOverload(MakeOverloadDecl( StandardOverloadIds::kInMap, BoolType(), TypeParamA(), MapOfAB()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(in_function_deprecated))); FunctionDecl in_op_deprecated; in_op_deprecated.set_name(StandardFunctions::kInDeprecated); CEL_RETURN_IF_ERROR(in_op_deprecated.AddOverload(MakeOverloadDecl( StandardOverloadIds::kInList, BoolType(), 
TypeParamA(), ListOfA()))); CEL_RETURN_IF_ERROR(in_op_deprecated.AddOverload(MakeOverloadDecl( StandardOverloadIds::kInMap, BoolType(), TypeParamA(), MapOfAB()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(in_op_deprecated))); FunctionDecl size; size.set_name(StandardFunctions::kSize); CEL_RETURN_IF_ERROR(size.AddOverload( MakeOverloadDecl(StandardOverloadIds::kSizeList, IntType(), ListOfA()))); CEL_RETURN_IF_ERROR(size.AddOverload(MakeMemberOverloadDecl( StandardOverloadIds::kSizeListMember, IntType(), ListOfA()))); CEL_RETURN_IF_ERROR(size.AddOverload( MakeOverloadDecl(StandardOverloadIds::kSizeMap, IntType(), MapOfAB()))); CEL_RETURN_IF_ERROR(size.AddOverload(MakeMemberOverloadDecl( StandardOverloadIds::kSizeMapMember, IntType(), MapOfAB()))); CEL_RETURN_IF_ERROR(size.AddOverload(MakeOverloadDecl( StandardOverloadIds::kSizeBytes, IntType(), BytesType()))); CEL_RETURN_IF_ERROR(size.AddOverload(MakeMemberOverloadDecl( StandardOverloadIds::kSizeBytesMember, IntType(), BytesType()))); CEL_RETURN_IF_ERROR(size.AddOverload(MakeOverloadDecl( StandardOverloadIds::kSizeString, IntType(), StringType()))); CEL_RETURN_IF_ERROR(size.AddOverload(MakeMemberOverloadDecl( StandardOverloadIds::kSizeStringMember, IntType(), StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(size))); return absl::OkStatus(); } absl::Status AddRelationOps(TypeCheckerBuilder& builder) { FunctionDecl less_op; less_op.set_name(StandardFunctions::kLess); // Numeric types CEL_RETURN_IF_ERROR(less_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kLessInt, BoolType(), IntType(), IntType()))); CEL_RETURN_IF_ERROR(less_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kLessUint, BoolType(), UintType(), UintType()))); CEL_RETURN_IF_ERROR(less_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessDouble, BoolType(), DoubleType(), DoubleType()))); // Non-numeric types CEL_RETURN_IF_ERROR(less_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kLessBool, BoolType(), 
BoolType(), BoolType()))); CEL_RETURN_IF_ERROR(less_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessString, BoolType(), StringType(), StringType()))); CEL_RETURN_IF_ERROR(less_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kLessBytes, BoolType(), BytesType(), BytesType()))); CEL_RETURN_IF_ERROR(less_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessDuration, BoolType(), DurationType(), DurationType()))); CEL_RETURN_IF_ERROR(less_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessTimestamp, BoolType(), TimestampType(), TimestampType()))); FunctionDecl greater_op; greater_op.set_name(StandardFunctions::kGreater); // Numeric types CEL_RETURN_IF_ERROR(greater_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kGreaterInt, BoolType(), IntType(), IntType()))); CEL_RETURN_IF_ERROR(greater_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kGreaterUint, BoolType(), UintType(), UintType()))); CEL_RETURN_IF_ERROR(greater_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterDouble, BoolType(), DoubleType(), DoubleType()))); // Non-numeric types CEL_RETURN_IF_ERROR(greater_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kGreaterBool, BoolType(), BoolType(), BoolType()))); CEL_RETURN_IF_ERROR(greater_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterString, BoolType(), StringType(), StringType()))); CEL_RETURN_IF_ERROR(greater_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterBytes, BoolType(), BytesType(), BytesType()))); CEL_RETURN_IF_ERROR(greater_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterDuration, BoolType(), DurationType(), DurationType()))); CEL_RETURN_IF_ERROR(greater_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterTimestamp, BoolType(), TimestampType(), TimestampType()))); FunctionDecl less_equals_op; less_equals_op.set_name(StandardFunctions::kLessOrEqual); // Numeric types CEL_RETURN_IF_ERROR(less_equals_op.AddOverload(MakeOverloadDecl( 
StandardOverloadIds::kLessEqualsInt, BoolType(), IntType(), IntType()))); CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessEqualsUint, BoolType(), UintType(), UintType()))); CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessEqualsDouble, BoolType(), DoubleType(), DoubleType()))); // Non-numeric types CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessEqualsBool, BoolType(), BoolType(), BoolType()))); CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessEqualsString, BoolType(), StringType(), StringType()))); CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessEqualsBytes, BoolType(), BytesType(), BytesType()))); CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessEqualsDuration, BoolType(), DurationType(), DurationType()))); CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessEqualsTimestamp, BoolType(), TimestampType(), TimestampType()))); FunctionDecl greater_equals_op; greater_equals_op.set_name(StandardFunctions::kGreaterOrEqual); // Numeric types CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsInt, BoolType(), IntType(), IntType()))); CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsUint, BoolType(), UintType(), UintType()))); CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsDouble, BoolType(), DoubleType(), DoubleType()))); // Non-numeric types CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsBool, BoolType(), BoolType(), BoolType()))); CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsString, BoolType(), StringType(), 
StringType()))); CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsBytes, BoolType(), BytesType(), BytesType()))); CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsDuration, BoolType(), DurationType(), DurationType()))); CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsTimestamp, BoolType(), TimestampType(), TimestampType()))); if (builder.options().enable_cross_numeric_comparisons) { // Less CEL_RETURN_IF_ERROR(less_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kLessIntUint, BoolType(), IntType(), UintType()))); CEL_RETURN_IF_ERROR(less_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessIntDouble, BoolType(), IntType(), DoubleType()))); CEL_RETURN_IF_ERROR(less_op.AddOverload(MakeOverloadDecl( StandardOverloadIds::kLessUintInt, BoolType(), UintType(), IntType()))); CEL_RETURN_IF_ERROR(less_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessUintDouble, BoolType(), UintType(), DoubleType()))); CEL_RETURN_IF_ERROR(less_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessDoubleInt, BoolType(), DoubleType(), IntType()))); CEL_RETURN_IF_ERROR(less_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessDoubleUint, BoolType(), DoubleType(), UintType()))); // Greater CEL_RETURN_IF_ERROR(greater_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterIntUint, BoolType(), IntType(), UintType()))); CEL_RETURN_IF_ERROR(greater_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterIntDouble, BoolType(), IntType(), DoubleType()))); CEL_RETURN_IF_ERROR(greater_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterUintInt, BoolType(), UintType(), IntType()))); CEL_RETURN_IF_ERROR(greater_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterUintDouble, BoolType(), UintType(), DoubleType()))); CEL_RETURN_IF_ERROR(greater_op.AddOverload( 
MakeOverloadDecl(StandardOverloadIds::kGreaterDoubleInt, BoolType(), DoubleType(), IntType()))); CEL_RETURN_IF_ERROR(greater_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterDoubleUint, BoolType(), DoubleType(), UintType()))); // LessEqual CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessEqualsIntUint, BoolType(), IntType(), UintType()))); CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessEqualsIntDouble, BoolType(), IntType(), DoubleType()))); CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessEqualsUintInt, BoolType(), UintType(), IntType()))); CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessEqualsUintDouble, BoolType(), UintType(), DoubleType()))); CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessEqualsDoubleInt, BoolType(), DoubleType(), IntType()))); CEL_RETURN_IF_ERROR(less_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kLessEqualsDoubleUint, BoolType(), DoubleType(), UintType()))); // GreaterEqual CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsIntUint, BoolType(), IntType(), UintType()))); CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsIntDouble, BoolType(), IntType(), DoubleType()))); CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsUintInt, BoolType(), UintType(), IntType()))); CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsUintDouble, BoolType(), UintType(), DoubleType()))); CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsDoubleInt, BoolType(), DoubleType(), IntType()))); CEL_RETURN_IF_ERROR(greater_equals_op.AddOverload( 
MakeOverloadDecl(StandardOverloadIds::kGreaterEqualsDoubleUint, BoolType(), DoubleType(), UintType()))); } CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(less_op))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(greater_op))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(less_equals_op))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(greater_equals_op))); return absl::OkStatus(); } absl::Status AddStringFunctions(TypeCheckerBuilder& builder) { FunctionDecl contains; contains.set_name(StandardFunctions::kStringContains); CEL_RETURN_IF_ERROR(contains.AddOverload( MakeMemberOverloadDecl(StandardOverloadIds::kContainsString, BoolType(), StringType(), StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(contains))); FunctionDecl starts_with; starts_with.set_name(StandardFunctions::kStringStartsWith); CEL_RETURN_IF_ERROR(starts_with.AddOverload( MakeMemberOverloadDecl(StandardOverloadIds::kStartsWithString, BoolType(), StringType(), StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(starts_with))); FunctionDecl ends_with; ends_with.set_name(StandardFunctions::kStringEndsWith); CEL_RETURN_IF_ERROR(ends_with.AddOverload( MakeMemberOverloadDecl(StandardOverloadIds::kEndsWithString, BoolType(), StringType(), StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(ends_with))); return absl::OkStatus(); } absl::Status AddRegexFunctions(TypeCheckerBuilder& builder) { FunctionDecl matches; matches.set_name(StandardFunctions::kRegexMatch); CEL_RETURN_IF_ERROR(matches.AddOverload( MakeMemberOverloadDecl(StandardOverloadIds::kMatchesMember, BoolType(), StringType(), StringType()))); CEL_RETURN_IF_ERROR(matches.AddOverload(MakeOverloadDecl( StandardOverloadIds::kMatches, BoolType(), StringType(), StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(matches))); return absl::OkStatus(); } absl::Status AddTimeFunctions(TypeCheckerBuilder& builder) { FunctionDecl get_full_year; 
get_full_year.set_name(StandardFunctions::kFullYear); CEL_RETURN_IF_ERROR(get_full_year.AddOverload(MakeMemberOverloadDecl( StandardOverloadIds::kTimestampToYear, IntType(), TimestampType()))); CEL_RETURN_IF_ERROR(get_full_year.AddOverload( MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToYearWithTz, IntType(), TimestampType(), StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_full_year))); FunctionDecl get_month; get_month.set_name(StandardFunctions::kMonth); CEL_RETURN_IF_ERROR(get_month.AddOverload(MakeMemberOverloadDecl( StandardOverloadIds::kTimestampToMonth, IntType(), TimestampType()))); CEL_RETURN_IF_ERROR(get_month.AddOverload( MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToMonthWithTz, IntType(), TimestampType(), StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_month))); FunctionDecl get_day_of_year; get_day_of_year.set_name(StandardFunctions::kDayOfYear); CEL_RETURN_IF_ERROR(get_day_of_year.AddOverload(MakeMemberOverloadDecl( StandardOverloadIds::kTimestampToDayOfYear, IntType(), TimestampType()))); CEL_RETURN_IF_ERROR(get_day_of_year.AddOverload( MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToDayOfYearWithTz, IntType(), TimestampType(), StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_day_of_year))); FunctionDecl get_day_of_month; get_day_of_month.set_name(StandardFunctions::kDayOfMonth); CEL_RETURN_IF_ERROR(get_day_of_month.AddOverload( MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToDayOfMonth, IntType(), TimestampType()))); CEL_RETURN_IF_ERROR(get_day_of_month.AddOverload( MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToDayOfMonthWithTz, IntType(), TimestampType(), StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_day_of_month))); FunctionDecl get_date; get_date.set_name(StandardFunctions::kDate); CEL_RETURN_IF_ERROR(get_date.AddOverload(MakeMemberOverloadDecl( StandardOverloadIds::kTimestampToDate, IntType(), 
TimestampType()))); CEL_RETURN_IF_ERROR(get_date.AddOverload( MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToDateWithTz, IntType(), TimestampType(), StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_date))); FunctionDecl get_day_of_week; get_day_of_week.set_name(StandardFunctions::kDayOfWeek); CEL_RETURN_IF_ERROR(get_day_of_week.AddOverload(MakeMemberOverloadDecl( StandardOverloadIds::kTimestampToDayOfWeek, IntType(), TimestampType()))); CEL_RETURN_IF_ERROR(get_day_of_week.AddOverload( MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToDayOfWeekWithTz, IntType(), TimestampType(), StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_day_of_week))); FunctionDecl get_hours; get_hours.set_name(StandardFunctions::kHours); CEL_RETURN_IF_ERROR(get_hours.AddOverload(MakeMemberOverloadDecl( StandardOverloadIds::kTimestampToHours, IntType(), TimestampType()))); CEL_RETURN_IF_ERROR(get_hours.AddOverload( MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToHoursWithTz, IntType(), TimestampType(), StringType()))); CEL_RETURN_IF_ERROR(get_hours.AddOverload(MakeMemberOverloadDecl( StandardOverloadIds::kDurationToHours, IntType(), DurationType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_hours))); FunctionDecl get_minutes; get_minutes.set_name(StandardFunctions::kMinutes); CEL_RETURN_IF_ERROR(get_minutes.AddOverload(MakeMemberOverloadDecl( StandardOverloadIds::kTimestampToMinutes, IntType(), TimestampType()))); CEL_RETURN_IF_ERROR(get_minutes.AddOverload( MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToMinutesWithTz, IntType(), TimestampType(), StringType()))); CEL_RETURN_IF_ERROR(get_minutes.AddOverload(MakeMemberOverloadDecl( StandardOverloadIds::kDurationToMinutes, IntType(), DurationType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_minutes))); FunctionDecl get_seconds; get_seconds.set_name(StandardFunctions::kSeconds); 
CEL_RETURN_IF_ERROR(get_seconds.AddOverload(MakeMemberOverloadDecl( StandardOverloadIds::kTimestampToSeconds, IntType(), TimestampType()))); CEL_RETURN_IF_ERROR(get_seconds.AddOverload( MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToSecondsWithTz, IntType(), TimestampType(), StringType()))); CEL_RETURN_IF_ERROR(get_seconds.AddOverload(MakeMemberOverloadDecl( StandardOverloadIds::kDurationToSeconds, IntType(), DurationType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_seconds))); FunctionDecl get_milliseconds; get_milliseconds.set_name(StandardFunctions::kMilliseconds); CEL_RETURN_IF_ERROR(get_milliseconds.AddOverload( MakeMemberOverloadDecl(StandardOverloadIds::kTimestampToMilliseconds, IntType(), TimestampType()))); CEL_RETURN_IF_ERROR(get_milliseconds.AddOverload(MakeMemberOverloadDecl( StandardOverloadIds::kTimestampToMillisecondsWithTz, IntType(), TimestampType(), StringType()))); CEL_RETURN_IF_ERROR(get_milliseconds.AddOverload( MakeMemberOverloadDecl(StandardOverloadIds::kDurationToMilliseconds, IntType(), DurationType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(get_milliseconds))); return absl::OkStatus(); } absl::Status AddTypeConstantVariables(TypeCheckerBuilder& builder) { CEL_RETURN_IF_ERROR(builder.AddVariable( MakeVariableDecl(StandardFunctions::kDyn, TypeDynType()))); CEL_RETURN_IF_ERROR( builder.AddVariable(MakeVariableDecl("bool", TypeBoolType()))); CEL_RETURN_IF_ERROR( builder.AddVariable(MakeVariableDecl("null_type", TypeNullType()))); CEL_RETURN_IF_ERROR(builder.AddVariable( MakeVariableDecl(StandardFunctions::kInt, TypeIntType()))); CEL_RETURN_IF_ERROR(builder.AddVariable( MakeVariableDecl(StandardFunctions::kUint, TypeUintType()))); CEL_RETURN_IF_ERROR(builder.AddVariable( MakeVariableDecl(StandardFunctions::kDouble, TypeDoubleType()))); CEL_RETURN_IF_ERROR(builder.AddVariable( MakeVariableDecl(StandardFunctions::kString, TypeStringType()))); CEL_RETURN_IF_ERROR(builder.AddVariable( 
MakeVariableDecl(StandardFunctions::kBytes, TypeBytesType()))); // Note: timestamp and duration are only referenced by the corresponding // protobuf type names and handled by the type lookup logic. CEL_RETURN_IF_ERROR( builder.AddVariable(MakeVariableDecl("list", TypeListType()))); CEL_RETURN_IF_ERROR( builder.AddVariable(MakeVariableDecl("map", TypeMapType()))); CEL_RETURN_IF_ERROR( builder.AddVariable(MakeVariableDecl("type", TypeOfType()))); return absl::OkStatus(); } absl::Status AddEnumConstants(TypeCheckerBuilder& builder) { VariableDecl pb_null; pb_null.set_name("google.protobuf.NullValue.NULL_VALUE"); pb_null.set_type(IntType()); pb_null.set_value(Constant(int64_t{0})); CEL_RETURN_IF_ERROR(builder.AddVariable(std::move(pb_null))); return absl::OkStatus(); } absl::Status AddStandardLibraryDecls(TypeCheckerBuilder& builder) { CEL_RETURN_IF_ERROR(AddLogicalOps(builder)); CEL_RETURN_IF_ERROR(AddArithmeticOps(builder)); CEL_RETURN_IF_ERROR(AddTypeConversions(builder)); CEL_RETURN_IF_ERROR(AddEqualityOps(builder)); CEL_RETURN_IF_ERROR(AddContainerOps(builder)); CEL_RETURN_IF_ERROR(AddRelationOps(builder)); CEL_RETURN_IF_ERROR(AddStringFunctions(builder)); CEL_RETURN_IF_ERROR(AddRegexFunctions(builder)); CEL_RETURN_IF_ERROR(AddTimeFunctions(builder)); CEL_RETURN_IF_ERROR(AddTypeConstantVariables(builder)); CEL_RETURN_IF_ERROR(AddEnumConstants(builder)); return absl::OkStatus(); } } // namespace // Returns a CheckerLibrary containing all of the standard CEL declarations. CheckerLibrary StandardCheckerLibrary() { return {"stdlib", AddStandardLibraryDecls}; } } // namespace cel ================================================ FILE: checker/standard_library.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_CHECKER_STANDARD_LIBRARY_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_STANDARD_LIBRARY_H_ #include "checker/type_checker_builder.h" namespace cel { // Returns a CheckerLibrary containing all of the standard CEL declarations. CheckerLibrary StandardCheckerLibrary(); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_CHECKER_STANDARD_LIBRARY_H_ ================================================ FILE: checker/standard_library_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "checker/standard_library.h"

#include <memory>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "checker/checker_options.h"
#include "checker/internal/test_ast_helpers.h"
#include "checker/type_checker.h"
#include "checker/type_checker_builder.h"
#include "checker/type_checker_builder_factory.h"
#include "checker/validation_result.h"
#include "common/ast.h"
#include "common/constant.h"
#include "common/decl.h"
#include "common/type.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "google/protobuf/arena.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::Reference;
using ::cel::internal::GetSharedTestingDescriptorPool;
using ::testing::IsEmpty;
using ::testing::Pointee;
using ::testing::Property;

using AstType = cel::TypeSpec;

// The standard library can be added to a fresh builder, which then builds.
TEST(StandardLibraryTest, StandardLibraryAddsDecls) {
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<TypeCheckerBuilder> builder,
      CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool()));
  EXPECT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk());
  EXPECT_THAT(builder->Build(), IsOk());
}

// Adding the standard library a second time is rejected.
TEST(StandardLibraryTest, StandardLibraryErrorsIfAddedTwice) {
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<TypeCheckerBuilder> builder,
      CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool()));
  EXPECT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk());
  EXPECT_THAT(builder->AddLibrary(StandardCheckerLibrary()),
              StatusIs(absl::StatusCode::kAlreadyExists));
}

TEST(StandardLibraryTest, ComprehensionVarsIndirectCyclicParamAssignability) {
  // Declared first so it outlives the builder and checker that reference the
  // arena-allocated types below.
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<TypeCheckerBuilder> builder,
      CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool()));
  ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk());

  // Note: this is atypical -- parameterized variables aren't well supported
  // outside of built-in syntax.
  // e.g. `list : Type(List(A))` is instantiated per reference to bind A to
  // the concrete type of a list in the same assignability context.
  //
  // Validate that parameterization is sanitized to be contextual
  // List(V) -> List(T%1)
  // Map(K, V) -> Map(T%2, T%3)
  Type list_type = ListType(&arena, TypeParamType("V"));
  Type map_type = MapType(&arena, TypeParamType("K"), TypeParamType("V"));

  ASSERT_THAT(builder->AddVariable(MakeVariableDecl("list_var", list_type)),
              IsOk());
  ASSERT_THAT(builder->AddVariable(MakeVariableDecl("map_var", map_type)),
              IsOk());

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<TypeChecker> type_checker,
                       builder->Build());

  ASSERT_OK_AND_ASSIGN(
      auto ast, checker_internal::MakeTestParsedAst(
                    "list_var.exists(v,"
                    " map_var.filter(k, map_var[k] > 1.0).size() > int(v)"
                    ")"));
  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       type_checker->Check(std::move(ast)));

  EXPECT_TRUE(result.IsValid()) << result.FormatError();
  EXPECT_THAT(result.GetIssues(), IsEmpty());
}

TEST(StandardLibraryTest, ComprehensionResultTypeIsSubstituted) {
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<TypeCheckerBuilder> builder,
      CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool()));
  ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk());

  // Test that type for the result list of .map is resolved to a concrete type
  // when it is known. Checks for a bug where the result type is considered to
  // still be flexible and may widen to dyn.
  builder->set_container("cel.expr.conformance.proto2");
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<TypeChecker> type_checker,
                       builder->Build());

  ASSERT_OK_AND_ASSIGN(auto ast, checker_internal::MakeTestParsedAst(
                                     "[TestAllTypes{}]"
                                     ".map(x, x.repeated_nested_message[0])"
                                     ".map(x, x.bb)[0]"));
  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       type_checker->Check(std::move(ast)));

  EXPECT_TRUE(result.IsValid());
  EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError();

  ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());
  TypeSpec type = checked_ast->GetTypeOrDyn(checked_ast->root_expr().id());
  EXPECT_TRUE(type.has_primitive() &&
              type.primitive() == PrimitiveType::kInt64);
}

// Fixture providing a type checker configured with only the standard library.
class StandardLibraryDefinitionsTest : public ::testing::Test {
 public:
  void SetUp() override {
    ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<TypeCheckerBuilder> builder,
        CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool()));
    ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk());
    ASSERT_OK_AND_ASSIGN(stdlib_type_checker_, builder->Build());
  }

 protected:
  std::unique_ptr<TypeChecker> stdlib_type_checker_;
};

// Parameterized over the name of a type-constant identifier.
class StdlibTypeVarDefinitionTest
    : public StandardLibraryDefinitionsTest,
      public testing::WithParamInterface<std::string> {};

TEST_P(StdlibTypeVarDefinitionTest, DefinesTypeConstants) {
  auto ast = std::make_unique<Ast>();
  ast->mutable_root_expr().mutable_ident_expr().set_name(GetParam());
  ast->mutable_root_expr().set_id(1);

  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       stdlib_type_checker_->Check(std::move(ast)));

  EXPECT_THAT(result.GetIssues(), IsEmpty());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> checked_ast, result.ReleaseAst());
  EXPECT_THAT(checked_ast->GetReference(1),
              Pointee(Property(&Reference::name, GetParam())));
  EXPECT_THAT(checked_ast->GetTypeOrDyn(1),
              Property(&AstType::has_type, true));
}

INSTANTIATE_TEST_SUITE_P(StdlibTypeVarDefinitions, StdlibTypeVarDefinitionTest,
                         ::testing::Values("bool", "bytes", "double", "dyn",
                                           "int", "list", "map", "null_type",
                                           "string", "type", "uint"),
                         [](const auto& info) -> std::string {
                           return info.param;
                         });

TEST_F(StandardLibraryDefinitionsTest, DefinesProtoStructNull) {
  auto ast = std::make_unique<Ast>();

  // Builds the select chain google.protobuf.NullValue.NULL_VALUE by hand.
  auto& enumerator = ast->mutable_root_expr();
  enumerator.set_id(4);
  enumerator.mutable_select_expr().set_field("NULL_VALUE");
  auto& enumeration = enumerator.mutable_select_expr().mutable_operand();
  enumeration.set_id(3);
  enumeration.mutable_select_expr().set_field("NullValue");
  auto& protobuf = enumeration.mutable_select_expr().mutable_operand();
  protobuf.set_id(2);
  protobuf.mutable_select_expr().set_field("protobuf");
  auto& google = protobuf.mutable_select_expr().mutable_operand();
  google.set_id(1);
  google.mutable_ident_expr().set_name("google");

  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       stdlib_type_checker_->Check(std::move(ast)));

  EXPECT_THAT(result.GetIssues(), IsEmpty());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> checked_ast, result.ReleaseAst());
  EXPECT_THAT(checked_ast->GetReference(4),
              Pointee(Property(&Reference::name,
                               "google.protobuf.NullValue.NULL_VALUE")));
}

TEST_F(StandardLibraryDefinitionsTest, DefinesTypeType) {
  auto ast = std::make_unique<Ast>();
  auto& ident = ast->mutable_root_expr();
  ident.set_id(1);
  ident.mutable_ident_expr().set_name("type");

  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       stdlib_type_checker_->Check(std::move(ast)));

  EXPECT_THAT(result.GetIssues(), IsEmpty());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> checked_ast, result.ReleaseAst());
  EXPECT_THAT(checked_ast->GetReference(1),
              Pointee(Property(&Reference::name, "type")));
  EXPECT_THAT(checked_ast->GetTypeOrDyn(1),
              Property(&AstType::has_type, true));
}

// One expression to type-check against the standard library, the expected
// validity, and the checker options to use.
struct DefinitionsTestCase {
  std::string expr;
  bool type_check_success = true;
  CheckerOptions options;
};

class StdLibDefinitionsTest
    : public ::testing::TestWithParam<DefinitionsTestCase> {};

// Basic coverage that the standard library definitions are defined.
// This is not intended to be exhaustive since it is expected to be covered by
// spec conformance tests.
// // TODO(uncreated-issue/72): Tests are fairly minimal right now -- it's not possible to // test thoroughly without a more complete implementation of the type checker. // Type-parameterized functions are not yet checkable. TEST_P(StdLibDefinitionsTest, Runner) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), GetParam().options)); ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, builder->Build()); ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, checker_internal::MakeTestParsedAst(GetParam().expr)); ASSERT_OK_AND_ASSIGN(auto result, type_checker->Check(std::move(ast))); EXPECT_EQ(result.IsValid(), GetParam().type_check_success); } INSTANTIATE_TEST_SUITE_P( Strings, StdLibDefinitionsTest, ::testing::Values(DefinitionsTestCase{ /* .expr = */ "'123'.size()", }, DefinitionsTestCase{ /* .expr = */ "size('123')", }, DefinitionsTestCase{ /* .expr = */ "'123' + '123'", }, DefinitionsTestCase{ /* .expr = */ "'123' + '123'", }, DefinitionsTestCase{ /* .expr = */ "'123' + '123'", }, DefinitionsTestCase{ /* .expr = */ "'123'.endsWith('123')", }, DefinitionsTestCase{ /* .expr = */ "'123'.startsWith('123')", }, DefinitionsTestCase{ /* .expr = */ "'123'.contains('123')", }, DefinitionsTestCase{ /* .expr = */ "'123'.matches(r'123')", }, DefinitionsTestCase{ /* .expr = */ "matches('123', r'123')", })); INSTANTIATE_TEST_SUITE_P(TypeCasts, StdLibDefinitionsTest, ::testing::Values(DefinitionsTestCase{ /* .expr = */ "int(1)", }, DefinitionsTestCase{ /* .expr = */ "uint(1)", }, DefinitionsTestCase{ /* .expr = */ "double(1)", }, DefinitionsTestCase{ /* .expr = */ "string(1)", }, DefinitionsTestCase{ /* .expr = */ "bool('true')", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0)", }, DefinitionsTestCase{ /* .expr = */ "duration('1s')", }, DefinitionsTestCase{ /* .expr = */ "type(1)", })); INSTANTIATE_TEST_SUITE_P(Arithmetic, StdLibDefinitionsTest, 
::testing::Values(DefinitionsTestCase{ /* .expr = */ "1 + 2", }, DefinitionsTestCase{ /* .expr = */ "1 - 2", }, DefinitionsTestCase{ /* .expr = */ "1 / 2", }, DefinitionsTestCase{ /* .expr = */ "1 * 2", }, DefinitionsTestCase{ /* .expr = */ "2 % 1", }, DefinitionsTestCase{ /* .expr = */ "-1", })); INSTANTIATE_TEST_SUITE_P( TimeArithmetic, StdLibDefinitionsTest, ::testing::Values(DefinitionsTestCase{ /* .expr = */ "timestamp(0) + duration('1s')", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0) - duration('1s')", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0) - timestamp(0)", }, DefinitionsTestCase{ /* .expr = */ "duration('1s') + duration('1s')", }, DefinitionsTestCase{ /* .expr = */ "duration('1s') - duration('1s')", })); INSTANTIATE_TEST_SUITE_P(NumericComparisons, StdLibDefinitionsTest, ::testing::Values(DefinitionsTestCase{ /* .expr = */ "1 > 2", }, DefinitionsTestCase{ /* .expr = */ "1 < 2", }, DefinitionsTestCase{ /* .expr = */ "1 >= 2", }, DefinitionsTestCase{ /* .expr = */ "1 <= 2", })); INSTANTIATE_TEST_SUITE_P( CrossNumericComparisons, StdLibDefinitionsTest, ::testing::Values( DefinitionsTestCase{ /* .expr = */ "1u < 2", /* .type_check_success = */ true, /* .options = */ {.enable_cross_numeric_comparisons = true}}, DefinitionsTestCase{ /* .expr = */ "1u > 2", /* .type_check_success = */ true, /* .options = */ {.enable_cross_numeric_comparisons = true}}, DefinitionsTestCase{ /* .expr = */ "1u <= 2", /* .type_check_success = */ true, /* .options = */ {.enable_cross_numeric_comparisons = true}}, DefinitionsTestCase{ /* .expr = */ "1u >= 2", /* .type_check_success = */ true, /* .options = */ {.enable_cross_numeric_comparisons = true}})); INSTANTIATE_TEST_SUITE_P( TimeComparisons, StdLibDefinitionsTest, ::testing::Values(DefinitionsTestCase{ /* .expr = */ "duration('1s') < duration('1s')", }, DefinitionsTestCase{ /* .expr = */ "duration('1s') > duration('1s')", }, DefinitionsTestCase{ /* .expr = */ "duration('1s') <= duration('1s')", }, 
DefinitionsTestCase{ /* .expr = */ "duration('1s') >= duration('1s')", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0) < timestamp(0)", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0) > timestamp(0)", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0) <= timestamp(0)", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0) >= timestamp(0)", })); INSTANTIATE_TEST_SUITE_P( TimeAccessors, StdLibDefinitionsTest, ::testing::Values( DefinitionsTestCase{ /* .expr = */ "timestamp(0).getFullYear()", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getFullYear('-08:00')", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getMonth()", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getMonth('-08:00')", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getDayOfYear()", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getDayOfYear('-08:00')", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getDate()", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getDate('-08:00')", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getDayOfWeek()", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getDayOfWeek('-08:00')", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getHours()", }, DefinitionsTestCase{ /* .expr = */ "duration('1s').getHours()", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getHours('-08:00')", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getMinutes()", }, DefinitionsTestCase{ /* .expr = */ "duration('1s').getMinutes()", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getMinutes('-08:00')", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getSeconds()", }, DefinitionsTestCase{ /* .expr = */ "duration('1s').getSeconds()", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getSeconds('-08:00')", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getMilliseconds()", }, DefinitionsTestCase{ /* .expr = */ "duration('1s').getMilliseconds()", }, DefinitionsTestCase{ /* .expr = */ "timestamp(0).getMilliseconds('-08:00')", 
})); INSTANTIATE_TEST_SUITE_P(Logic, StdLibDefinitionsTest, ::testing::Values(DefinitionsTestCase{ /* .expr = */ "true || false", }, DefinitionsTestCase{ /* .expr = */ "true && false", }, DefinitionsTestCase{ /* .expr = */ "!true", }, DefinitionsTestCase{ /* .expr = */ "true ? 1 : 2", })); } // namespace } // namespace cel ================================================ FILE: checker/type_check_issue.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "checker/type_check_issue.h" #include #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "common/source.h" namespace cel { namespace { absl::string_view SeverityString(TypeCheckIssue::Severity severity) { switch (severity) { case TypeCheckIssue::Severity::kInformation: return "INFORMATION"; case TypeCheckIssue::Severity::kWarning: return "WARNING"; case TypeCheckIssue::Severity::kError: return "ERROR"; case TypeCheckIssue::Severity::kDeprecated: return "DEPRECATED"; default: return "SEVERITY_UNSPECIFIED"; } } } // namespace std::string TypeCheckIssue::ToDisplayString(const Source* source) const { int column = location_.column; // convert to 1-based if it's in range. int display_column = column >= 0 ? 
column + 1 : column; if (source) { return absl::StrFormat("%s: %s:%d:%d: %s%s", SeverityString(severity_), source->description(), location_.line, display_column, message_, source->DisplayErrorLocation(location_)); } return absl::StrFormat("%s: :%d:%d: %s", SeverityString(severity_), location_.line, display_column, message_); } } // namespace cel ================================================ FILE: checker/type_check_issue.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECK_ISSUE_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECK_ISSUE_H_ #include #include #include "absl/strings/string_view.h" #include "common/source.h" namespace cel { // Represents a single issue identified in type checking. class TypeCheckIssue { public: enum class Severity { kError, kWarning, kInformation, kDeprecated }; TypeCheckIssue(Severity severity, SourceLocation location, std::string message) : severity_(severity), location_(location), message_(std::move(message)) {} // Factory for error-severity issues. static TypeCheckIssue CreateError(SourceLocation location, std::string message) { return TypeCheckIssue(Severity::kError, location, std::move(message)); } // Factory for error-severity issues. // line is 1-based, column is 0-based. 
static TypeCheckIssue CreateError(int line, int column, std::string message) { return TypeCheckIssue(Severity::kError, SourceLocation{line, column}, std::move(message)); } // Format the issue highlighting the source position. std::string ToDisplayString(const Source* source) const; std::string ToDisplayString(const Source& source) const { return ToDisplayString(&source); } absl::string_view message() const { return message_; } Severity severity() const { return severity_; } SourceLocation location() const { return location_; } private: Severity severity_; SourceLocation location_; std::string message_; }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECK_ISSUE_H_ ================================================ FILE: checker/type_check_issue_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "checker/type_check_issue.h" #include "common/source.h" #include "internal/testing.h" namespace cel { namespace { TEST(TypeCheckIssueTest, DisplayString) { ASSERT_OK_AND_ASSIGN(auto source, NewSource("test{\n\tfield1: 123\n}")); TypeCheckIssue issue = TypeCheckIssue::CreateError(2, 2, "test error"); // Note: The column is displayed as 1 based to match the Go checker. 
EXPECT_EQ(issue.ToDisplayString(*source), "ERROR: :2:3: test error\n" " | field1: 123\n" " | ..^"); } TEST(TypeCheckIssueTest, DisplayStringNoPosition) { ASSERT_OK_AND_ASSIGN(auto source, NewSource("test{\n\tfield1: 123\n}")); TypeCheckIssue issue = TypeCheckIssue::CreateError(-1, -1, "test error"); EXPECT_EQ(issue.ToDisplayString(*source), "ERROR: :-1:-1: test error"); } TEST(TypeCheckIssueTest, DisplayStringDeprecated) { ASSERT_OK_AND_ASSIGN(auto source, NewSource("test{\n\tfield1: 123\n}")); TypeCheckIssue issue = TypeCheckIssue(TypeCheckIssue::Severity::kDeprecated, {-1, -1}, "test error 2"); EXPECT_EQ(issue.ToDisplayString(*source), "DEPRECATED: :-1:-1: test error 2"); } } // namespace } // namespace cel ================================================ FILE: checker/type_checker.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "checker/type_checker.h"

namespace cel {

// All public Check() overloads forward to the private CheckImpl().
// Overloads without an `arena` argument pass nullptr for it. Overloads taking
// `const Ast&` first copy the AST into a new heap-allocated instance
// (std::make_unique), so the caller's AST is left untouched.

absl::StatusOr TypeChecker::Check(
    std::unique_ptr ast) const {
  return CheckImpl(std::move(ast), nullptr);
}

absl::StatusOr TypeChecker::Check(
    std::unique_ptr ast, google::protobuf::Arena* arena) const {
  return CheckImpl(std::move(ast), arena);
}

absl::StatusOr TypeChecker::Check(const Ast& ast) const {
  return CheckImpl(std::make_unique(ast), nullptr);
}

absl::StatusOr TypeChecker::Check(
    const Ast& ast, google::protobuf::Arena* arena) const {
  return CheckImpl(std::make_unique(ast), arena);
}

}  // namespace cel
================================================
FILE: checker/type_checker.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_H_
#define THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_H_

#include
#include

#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
#include "checker/validation_result.h"
#include "common/ast.h"
#include "google/protobuf/arena.h"

namespace cel {

class TypeCheckerBuilder;

// TypeChecker interface.
//
// Checks references and type agreement for a parsed CEL expression.
//
// See Compiler for bundled parse and type check from a source expression
// string.
class TypeChecker {
 public:
  virtual ~TypeChecker() = default;

  // Checks the references and type agreement of the given parsed expression
  // based on the configured CEL environment.
  //
  // Most type checking errors are returned as Issues in the validation result.
  // A non-ok status is returned if type checking can't reasonably complete
  // (e.g. if an internal precondition is violated or an extension returns an
  // error).
  //
  // The `const Ast&` overloads check a copy of `ast`. `arena` is optional; the
  // overloads without it pass nullptr to the implementation.
  // NOTE(review): `arena` is presumably used for allocations made while
  // checking — confirm against the CheckImpl implementations.
  absl::StatusOr Check(std::unique_ptr ast) const;
  absl::StatusOr Check(std::unique_ptr ast,
                                         google::protobuf::Arena* arena) const;
  absl::StatusOr Check(const Ast& ast) const;
  absl::StatusOr Check(const Ast& ast,
                                         google::protobuf::Arena* arena) const;

  // Returns a builder initialized with the configuration of this type checker.
  virtual std::unique_ptr ToBuilder() const = 0;

 private:
  // Implementation hook shared by every Check() overload. `arena` may be null.
  virtual absl::StatusOr CheckImpl(
      std::unique_ptr ast,
      google::protobuf::Arena* absl_nullable arena) const = 0;
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_H_
================================================
FILE: checker/type_checker_builder.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_H_
#define THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_H_

#include
#include

#include "absl/base/nullability.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "checker/checker_options.h"
#include "checker/type_checker.h"
#include "common/container.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/type_introspector.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"

namespace cel {

class TypeCheckerBuilder;

// Functional implementation to apply the library features to a
// TypeCheckerBuilder.
using TypeCheckerBuilderConfigurer = absl::AnyInvocable;

// A named bundle of declarations, applied to a builder via AddLibrary().
struct CheckerLibrary {
  // Optional identifier to avoid collisions re-adding the same declarations.
  // If id is empty, it is not considered.
  std::string id;
  // Callback that applies the library's declarations to the TypeCheckerBuilder
  // passed to it. Configurers should declare through that argument.
  TypeCheckerBuilderConfigurer configure;
};

// Represents a declaration to only use a subset of a library.
struct TypeCheckerSubset {
  // NOTE(review): the factory tests invoke this with (function name,
  // overload id); the exact signature is not visible in this dump.
  using FunctionPredicate = absl::AnyInvocable;

  // The id of the library to subset. Only one subset can be applied per
  // library id.
  //
  // Must be non-empty.
  std::string library_id;

  // Predicate to apply to function overloads. If true, the overload will be
  // included in the subset. If no overload for a function is included, the
  // entire function is excluded.
  FunctionPredicate should_include_overload;
};

// Interface for TypeCheckerBuilders.
class TypeCheckerBuilder {
 public:
  virtual ~TypeCheckerBuilder() = default;

  // Adds a library to the TypeChecker being built.
  //
  // Libraries are applied in the order they are added. They effectively
  // apply before any direct calls to AddVariable, AddFunction, etc.
  virtual absl::Status AddLibrary(CheckerLibrary library) = 0;

  // Adds a subset declaration for a library to the TypeChecker being built.
  //
  // At most one subset can be applied per library id.
  virtual absl::Status AddLibrarySubset(TypeCheckerSubset subset) = 0;

  // Adds a variable declaration that may be referenced in expressions checked
  // with the resulting type checker.
  //
  // NOTE(review): the factory tests show that redeclaring a name is accepted
  // here and reported by Build() as kAlreadyExists.
  virtual absl::Status AddVariable(const VariableDecl& decl) = 0;

  // Adds a variable declaration that may be referenced in expressions checked
  // with the resulting type checker.
  //
  // This version replaces any existing variable declaration with the same name.
  virtual absl::Status AddOrReplaceVariable(const VariableDecl& decl) = 0;

  // Declares struct type by fully qualified name as a context declaration.
  //
  // Context declarations are a way to declare a group of variables based on the
  // definition of a struct type. Each top level field of the struct is declared
  // as an individual variable of the field type.
  //
  // It is an error if the type contains a field that overlaps with another
  // declared variable.
  //
  // Note: only protobuf backed struct types are supported at this time.
  virtual absl::Status AddContextDeclaration(absl::string_view type) = 0;

  // Adds a function declaration that may be referenced in expressions checked
  // with the resulting TypeChecker.
  //
  // NOTE(review): as with AddVariable, the factory tests show a duplicate
  // function name is accepted here and reported by Build() as kAlreadyExists.
  virtual absl::Status AddFunction(const FunctionDecl& decl) = 0;

  // Adds function declaration overloads to the TypeChecker being built.
  //
  // Attempts to merge with any existing overloads for a function decl with the
  // same name. If the overloads are not compatible, an error is returned and
  // no change is made.
  virtual absl::Status MergeFunction(const FunctionDecl& decl) = 0;

  // Sets the expected type for checked expressions.
  //
  // Validation will fail with an ERROR level issue if the deduced type of the
  // expression is not assignable to this type.
  //
  // Note: if set multiple times, the last value is used.
  virtual void SetExpectedType(const Type& type) = 0;

  // Adds a type provider to the TypeChecker being built.
  //
  // Type providers are used to describe custom types with typed field
  // traversal. This is not needed for built-in types or protobuf messages
  // described by the associated descriptor pool.
  virtual void AddTypeProvider(std::unique_ptr provider) = 0;

  // Set the container for the TypeChecker being built.
  //
  // This is used for resolving references in the expressions being built.
  //
  // Prefer setting the container via SetExpressionContainer().
  //
  // Note: if set multiple times, the last value is used. This can lead to
  // surprising behavior if used in a custom library. If container is not a
  // valid container name, the operation is ignored.
  virtual void set_container(absl::string_view container) = 0;

  // Preferred way to set the container used for resolving references in the
  // expressions being checked (see set_container()).
  virtual void SetExpressionContainer(
      ExpressionContainer expression_container) = 0;

  // The current options for the TypeChecker being built.
  virtual const CheckerOptions& options() const = 0;

  // Builds a new TypeChecker instance.
  virtual absl::StatusOr> Build() = 0;

  // Returns a pointer to an arena that can be used to allocate memory for types
  // that will be used by the TypeChecker being built.
  //
  // On Build(), the arena is transferred to the TypeChecker being built.
  virtual google::protobuf::Arena* absl_nonnull arena() = 0;

  // The configured descriptor pool.
  virtual const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool()
      const = 0;
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_H_
================================================
FILE: checker/type_checker_builder_factory.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include "checker/type_checker_builder_factory.h" #include #include #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "checker/checker_options.h" #include "checker/internal/type_checker_builder_impl.h" #include "checker/type_checker_builder.h" #include "internal/noop_delete.h" #include "internal/status_macros.h" #include "internal/well_known_types.h" #include "google/protobuf/descriptor.h" namespace cel { absl::StatusOr> CreateTypeCheckerBuilder( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, const CheckerOptions& options) { ABSL_DCHECK(descriptor_pool != nullptr); return CreateTypeCheckerBuilder( std::shared_ptr( descriptor_pool, internal::NoopDeleteFor()), options); } absl::StatusOr> CreateTypeCheckerBuilder( absl_nonnull std::shared_ptr descriptor_pool, const CheckerOptions& options) { ABSL_DCHECK(descriptor_pool != nullptr); // Verify the standard descriptors, we do not need to keep // `well_known_types::Reflection` at the moment here. CEL_RETURN_IF_ERROR( well_known_types::Reflection().Initialize(descriptor_pool.get())); return std::make_unique( std::move(descriptor_pool), options); } } // namespace cel ================================================ FILE: checker/type_checker_builder_factory.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_FACTORY_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_FACTORY_H_ #include #include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "checker/checker_options.h" #include "checker/type_checker_builder.h" #include "google/protobuf/descriptor.h" namespace cel { // Creates a new `TypeCheckerBuilder`. // // The builder implementation is thread-hostile and should only be used from a // single thread, but the resulting `TypeChecker` instance is thread-safe. // // When passing a raw pointer to a descriptor pool, the descriptor pool must // outlive the type checker builder and the type checker builder it creates. // // The descriptor pool must include the minimally necessary // descriptors required by CEL. Those are the following: // - google.protobuf.NullValue // - google.protobuf.BoolValue // - google.protobuf.Int32Value // - google.protobuf.Int64Value // - google.protobuf.UInt32Value // - google.protobuf.UInt64Value // - google.protobuf.FloatValue // - google.protobuf.DoubleValue // - google.protobuf.BytesValue // - google.protobuf.StringValue // - google.protobuf.Any // - google.protobuf.Duration // - google.protobuf.Timestamp absl::StatusOr> CreateTypeCheckerBuilder( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, const CheckerOptions& options = {}); absl::StatusOr> CreateTypeCheckerBuilder( absl_nonnull std::shared_ptr descriptor_pool, const CheckerOptions& options = {}); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_BUILDER_FACTORY_H_ ================================================ FILE: checker/type_checker_builder_factory_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance 
with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "checker/type_checker_builder_factory.h" #include #include #include #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/strings/string_view.h" #include "checker/checker_options.h" #include "checker/internal/test_ast_helpers.h" #include "checker/optional.h" #include "checker/standard_library.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" #include "common/decl.h" #include "common/type.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" namespace cel { namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::checker_internal::MakeTestParsedAst; using ::cel::internal::GetSharedTestingDescriptorPool; using ::testing::ElementsAre; using ::testing::HasSubstr; using ::testing::Truly; TEST(TypeCheckerBuilderTest, AddVariable) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x")); ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); } TEST(TypeCheckerBuilderTest, AddComplexType) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); MapType map_type(builder->arena(), 
StringType(), IntType()); ASSERT_THAT(builder->AddVariable(MakeVariableDecl("m", map_type)), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); builder.reset(); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("m.foo")); ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); } TEST(TypeCheckerBuilderTest, TypeCheckersIndependent) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); MapType map_type(builder->arena(), StringType(), IntType()); ASSERT_THAT(builder->AddVariable(MakeVariableDecl("m", map_type)), IsOk()); ASSERT_OK_AND_ASSIGN( FunctionDecl fn, MakeFunctionDecl( "foo", MakeOverloadDecl("foo", IntType(), IntType(), IntType()))); ASSERT_THAT(builder->AddFunction(std::move(fn)), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker1, builder->Build()); ASSERT_THAT(builder->AddVariable(MakeVariableDecl("ns.m2", map_type)), IsOk()); builder->set_container("ns"); ASSERT_OK_AND_ASSIGN(auto checker2, builder->Build()); // Test for lifetime issues between separate type checker instances from the // same builder. 
builder.reset(); { ASSERT_OK_AND_ASSIGN(auto ast1, MakeTestParsedAst("foo(m.bar, m.bar)")); ASSERT_OK_AND_ASSIGN(auto ast2, MakeTestParsedAst("foo(m.bar, m2.bar)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, checker1->Check(std::move(ast1))); EXPECT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(ValidationResult result2, checker1->Check(std::move(ast2))); EXPECT_FALSE(result2.IsValid()); } checker1.reset(); { ASSERT_OK_AND_ASSIGN(auto ast1, MakeTestParsedAst("foo(m.bar, m.bar)")); ASSERT_OK_AND_ASSIGN(auto ast2, MakeTestParsedAst("foo(m.bar, m2.bar)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, checker2->Check(std::move(ast1))); EXPECT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(ValidationResult result2, checker2->Check(std::move(ast2))); EXPECT_TRUE(result2.IsValid()); } } TEST(TypeCheckerBuilderTest, AddVariableRedeclaredError) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); // We resolve the variable declarations at the Build() call, so the error // surfaces then. 
ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); EXPECT_THAT(builder->Build(), StatusIs(absl::StatusCode::kAlreadyExists, "variable 'x' declared multiple times")); } TEST(TypeCheckerBuilderTest, AddFunction) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_OK_AND_ASSIGN( auto fn_decl, MakeFunctionDecl( "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()))); ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("add(1, 2)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); } TEST(TypeCheckerBuilderTest, AddFunctionRedeclaredError) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_OK_AND_ASSIGN( auto fn_decl, MakeFunctionDecl( "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()))); ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); EXPECT_THAT(builder->Build(), StatusIs(absl::StatusCode::kAlreadyExists, "function 'add' declared multiple times")); } TEST(TypeCheckerBuilderTest, AddLibrary) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_OK_AND_ASSIGN( auto fn_decl, MakeFunctionDecl( "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()))); ASSERT_THAT(builder->AddLibrary({"", [&](TypeCheckerBuilder& b) { return builder->AddFunction(fn_decl); }}), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("add(1, 2)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); } // Example test lib that adds: // - add(int, int) -> int // - add(double, double) -> double // - sub(int, 
int) -> int // - sub(double, double) -> double absl::Status SubsetTestlibConfigurer(TypeCheckerBuilder& builder) { absl::Status s; CEL_ASSIGN_OR_RETURN( FunctionDecl fn_decl, MakeFunctionDecl( "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()), MakeOverloadDecl("add_double", DoubleType(), DoubleType(), DoubleType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(fn_decl))); CEL_ASSIGN_OR_RETURN( fn_decl, MakeFunctionDecl( "sub", MakeOverloadDecl("sub_int", IntType(), IntType(), IntType()), MakeOverloadDecl("sub_double", DoubleType(), DoubleType(), DoubleType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(fn_decl))); return absl::OkStatus(); } CheckerLibrary SubsetTestlib() { return {"testlib", SubsetTestlibConfigurer}; } TEST(TypeCheckerBuilderTest, AddLibraryIncludeSubset) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); ASSERT_THAT( builder->AddLibrarySubset( {"testlib", [](absl::string_view /*function*/, absl::string_view overload_id) { return (overload_id == "add_int" || overload_id == "sub_int"); }}), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); std::vector results; for (const auto& expr : {"sub(1, 2)", "add(1, 2)", "sub(1.0, 2.0)", "add(1.0, 2.0)"}) { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(expr)); ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); results.push_back(std::move(result)); } ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("add(1, 2)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); EXPECT_THAT(results, ElementsAre(Truly([](const ValidationResult& result) { return result.IsValid(); }), Truly([](const ValidationResult& result) { return result.IsValid(); }), Truly([](const ValidationResult& result) { return !result.IsValid(); }), Truly([](const ValidationResult& result) { return !result.IsValid(); }))); } 
TEST(TypeCheckerBuilderTest, AddLibraryExcludeSubset) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); ASSERT_THAT( builder->AddLibrarySubset( {"testlib", [](absl::string_view /*function*/, absl::string_view overload_id) { return (overload_id != "add_int" && overload_id != "sub_int"); ; }}), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); std::vector results; for (const auto& expr : {"sub(1, 2)", "add(1, 2)", "sub(1.0, 2.0)", "add(1.0, 2.0)"}) { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(expr)); ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); results.push_back(std::move(result)); } ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("add(1, 2)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); EXPECT_THAT(results, ElementsAre(Truly([](const ValidationResult& result) { return !result.IsValid(); }), Truly([](const ValidationResult& result) { return !result.IsValid(); }), Truly([](const ValidationResult& result) { return result.IsValid(); }), Truly([](const ValidationResult& result) { return result.IsValid(); }))); } TEST(TypeCheckerBuilderTest, AddLibrarySubsetRemoveAllOvl) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); ASSERT_THAT(builder->AddLibrarySubset({"testlib", [](absl::string_view function, absl::string_view /*overload_id*/) { return function != "add"; }}), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); std::vector results; for (const auto& expr : {"sub(1, 2)", "add(1, 2)", "sub(1.0, 2.0)", "add(1.0, 2.0)"}) { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(expr)); ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); results.push_back(std::move(result)); } ASSERT_OK_AND_ASSIGN(auto ast, 
MakeTestParsedAst("add(1, 2)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); EXPECT_THAT(results, ElementsAre(Truly([](const ValidationResult& result) { return result.IsValid(); }), Truly([](const ValidationResult& result) { return !result.IsValid(); }), Truly([](const ValidationResult& result) { return result.IsValid(); }), Truly([](const ValidationResult& result) { return !result.IsValid(); }))); } TEST(TypeCheckerBuilderTest, AddLibraryOneSubsetPerLibraryId) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); ASSERT_THAT( builder->AddLibrarySubset( {"testlib", [](absl::string_view function, absl::string_view /*overload_id*/) { return true; }}), IsOk()); EXPECT_THAT( builder->AddLibrarySubset( {"testlib", [](absl::string_view function, absl::string_view /*overload_id*/) { return true; }}), StatusIs(absl::StatusCode::kAlreadyExists)); } TEST(TypeCheckerBuilderTest, AddLibrarySubsetLibraryIdRequireds) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); EXPECT_THAT(builder->AddLibrarySubset({"", [](absl::string_view function, absl::string_view /*overload_id*/) { return function == "add"; }}), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(TypeCheckerBuilderTest, AddContextDeclaration) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_OK_AND_ASSIGN( auto fn_decl, MakeFunctionDecl("increment", MakeOverloadDecl("increment_int", IntType(), IntType()))); ASSERT_THAT(builder->AddContextDeclaration( "cel.expr.conformance.proto3.TestAllTypes"), IsOk()); ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("increment(single_int64)")); 
ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); EXPECT_TRUE(result.IsValid()); } TEST(TypeCheckerBuilderTest, WellKnownTypeContextDeclarationError) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_THAT(builder->AddContextDeclaration("google.protobuf.Any"), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("'google.protobuf.Any' is not a struct"))); } TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclaration) { CheckerOptions options; options.allow_well_known_type_context_declarations = true; ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); ASSERT_THAT(builder->AddContextDeclaration("google.protobuf.Any"), IsOk()); ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, builder->Build()); ASSERT_OK_AND_ASSIGN( auto ast, MakeTestParsedAst( R"cel(value == b'' && type_url == 'type.googleapis.com/google.protobuf.Duration')cel")); ASSERT_OK_AND_ASSIGN(ValidationResult result, type_checker->Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); } TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclarationStruct) { CheckerOptions options; options.allow_well_known_type_context_declarations = true; ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); ASSERT_THAT(builder->AddContextDeclaration("google.protobuf.Struct"), IsOk()); ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, builder->Build()); ASSERT_OK_AND_ASSIGN( auto ast, MakeTestParsedAst(R"cel(fields.foo.bar_list.exists(x, x == 1))cel")); ASSERT_OK_AND_ASSIGN(ValidationResult result, type_checker->Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); } TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclarationValue) { CheckerOptions 
options; options.allow_well_known_type_context_declarations = true; ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); ASSERT_THAT(builder->AddContextDeclaration("google.protobuf.Value"), IsOk()); ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, builder->Build()); ASSERT_OK_AND_ASSIGN( auto ast, MakeTestParsedAst( // Note: one of fields are all added with safe traversal, so // we lose the union discriminator information. R"cel( null_value == 0 && number_value == 0.0 && string_value == '' && list_value == [] && struct_value == {} && bool_value == false)cel")); ASSERT_OK_AND_ASSIGN(ValidationResult result, type_checker->Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); } TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclarationInt64Value) { CheckerOptions options; options.allow_well_known_type_context_declarations = true; ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); ASSERT_THAT(builder->AddContextDeclaration("google.protobuf.Int64Value"), IsOk()); ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, builder->Build()); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(R"cel(value == 0)cel")); ASSERT_OK_AND_ASSIGN(ValidationResult result, type_checker->Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); } TEST(TypeCheckerBuilderTest, ContextDeclarationWithJsonName) { CheckerOptions options; options.use_json_field_names = true; ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); ASSERT_THAT(builder->AddContextDeclaration("cel.cpp.testutil.TestJsonNames"), IsOk()); ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, builder->Build()); 
ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst( R"cel(int32_snake_case_json_name == 1 && int64CamelCaseJsonName == 2 && uint32DefaultJsonName == 3u && // `uint64-custom-json-name` == 4u && single_string == 'shadows' && singleString == 'shadowed')cel")); ASSERT_OK_AND_ASSIGN(ValidationResult result, type_checker->Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_EQ(checked_ast->GetReturnType(), TypeSpec(PrimitiveType::kBool)); EXPECT_THAT( checked_ast->source_info().extensions(), ElementsAre(cel::ExtensionSpec( "json_name", std::make_unique(1, 1), {cel::ExtensionSpec::Component::kRuntime}))); } TEST(TypeCheckerBuilderTest, JsonFieldNameOptionStructCreation) { CheckerOptions options; options.use_json_field_names = true; ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, builder->Build()); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst( R"cel(cel.cpp.testutil.TestJsonNames{ int32_snake_case_json_name: 1, int64CamelCaseJsonName: 2, uint32DefaultJsonName: 3u, `uint64-custom-json-name`: 4u, single_string: 'shadows', singleString: 'shadowed' })cel")); ASSERT_OK_AND_ASSIGN(ValidationResult result, type_checker->Check(std::move(ast))); ASSERT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_EQ(checked_ast->GetReturnType(), TypeSpec(MessageTypeSpec("cel.cpp.testutil.TestJsonNames"))); EXPECT_THAT( checked_ast->source_info().extensions(), ElementsAre(cel::ExtensionSpec( "json_name", std::make_unique(1, 1), {cel::ExtensionSpec::Component::kRuntime}))); } TEST(TypeCheckerBuilderTest, JsonFieldNameOptionFieldAccess) { CheckerOptions options; options.use_json_field_names = true; ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, 
CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_THAT( builder->AddVariable(MakeVariableDecl( "jsonObj", cel::MessageType(builder->descriptor_pool()->FindMessageTypeByName( "cel.cpp.testutil.TestJsonNames")))), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, builder->Build()); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst( R"cel( jsonObj.int32_snake_case_json_name == 1 && jsonObj.int64CamelCaseJsonName == 2 && jsonObj.uint32DefaultJsonName == 3u && jsonObj.`uint64-custom-json-name` == 4u && jsonObj.single_string == 'shadows' && jsonObj.singleString == 'shadowed' && jsonObj.`cel.cpp.testutil.int32_snake_case_ext` == 5 && jsonObj.`cel.cpp.testutil.int64CamelCaseExt` == 6 )cel")); ASSERT_OK_AND_ASSIGN(ValidationResult result, type_checker->Check(std::move(ast))); ASSERT_TRUE(result.IsValid()) << result.FormatError(); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); EXPECT_EQ(checked_ast->GetReturnType(), TypeSpec(PrimitiveType::kBool)); EXPECT_THAT( checked_ast->source_info().extensions(), ElementsAre(cel::ExtensionSpec( "json_name", std::make_unique(1, 1), {cel::ExtensionSpec::Component::kRuntime}))); } TEST(TypeCheckerBuilderTest, AddLibraryRedeclaredError) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_OK_AND_ASSIGN( auto fn_decl, MakeFunctionDecl( "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()))); ASSERT_THAT(builder->AddLibrary({"testlib", [&](TypeCheckerBuilder& b) { return builder->AddFunction(fn_decl); }}), IsOk()); EXPECT_THAT(builder->AddLibrary({"testlib", [&](TypeCheckerBuilder& b) { return builder->AddFunction(fn_decl); }}), StatusIs(absl::StatusCode::kAlreadyExists, HasSubstr("testlib"))); } TEST(TypeCheckerBuilderTest, BuildForwardsLibraryErrors) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, 
CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_OK_AND_ASSIGN( auto fn_decl, MakeFunctionDecl( "add", MakeOverloadDecl("add_int", IntType(), IntType(), IntType()))); ASSERT_THAT(builder->AddLibrary({"", [&](TypeCheckerBuilder& b) { return builder->AddFunction(fn_decl); }}), IsOk()); ASSERT_THAT(builder->AddLibrary({"", [](TypeCheckerBuilder& b) { return absl::InternalError("test error"); }}), IsOk()); EXPECT_THAT(builder->Build(), StatusIs(absl::StatusCode::kInternal, "test error")); } TEST(TypeCheckerBuilderTest, AddFunctionOverlapsWithStdMacroError) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_OK_AND_ASSIGN( auto fn_decl, MakeFunctionDecl("map", MakeMemberOverloadDecl( "ovl_3", ListType(), ListType(), DynType(), DynType()))); EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'map' with 3 argument(s) overlaps " "with predefined macro")); fn_decl.set_name("filter"); EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'filter' with 3 argument(s) overlaps " "with predefined macro")); fn_decl.set_name("exists"); EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'exists' with 3 argument(s) overlaps " "with predefined macro")); fn_decl.set_name("exists_one"); EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'exists_one' with 3 argument(s) " "overlaps with predefined macro")); fn_decl.set_name("all"); EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'all' with 3 argument(s) overlaps " "with predefined macro")); fn_decl.set_name("optMap"); EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'optMap' with 3 argument(s) overlaps " "with predefined 
macro")); fn_decl.set_name("optFlatMap"); EXPECT_THAT( builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'optFlatMap' with 3 argument(s) overlaps " "with predefined macro")); ASSERT_OK_AND_ASSIGN( fn_decl, MakeFunctionDecl( "has", MakeOverloadDecl("ovl_1", BoolType(), DynType()))); EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'has' with 1 argument(s) overlaps " "with predefined macro")); ASSERT_OK_AND_ASSIGN( fn_decl, MakeFunctionDecl("map", MakeMemberOverloadDecl( "ovl_4", ListType(), ListType(), DynType(), DynType(), DynType()))); EXPECT_THAT(builder->AddFunction(fn_decl), StatusIs(absl::StatusCode::kInvalidArgument, "overload for name 'map' with 4 argument(s) overlaps " "with predefined macro")); } TEST(TypeCheckerBuilderTest, AddFunctionNoOverlapWithStdMacroError) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_OK_AND_ASSIGN( auto fn_decl, MakeFunctionDecl("has", MakeMemberOverloadDecl("ovl", BoolType(), DynType(), StringType()))); EXPECT_THAT(builder->AddFunction(fn_decl), IsOk()); } TEST(TypeCheckerBuilderTest, ToBuilderIndependenceAndInheritance) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); ASSERT_OK_AND_ASSIGN( auto fn_decl, MakeFunctionDecl("addOne", MakeOverloadDecl("addOne_int", IntType(), IntType()))); ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker1, builder->Build()); // Exercise checker1. { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("addOne(x)")); ASSERT_OK_AND_ASSIGN(ValidationResult result1, checker1->Check(std::move(ast))); EXPECT_TRUE(result1.IsValid()); } // Start new builder via ToBuilder. 
auto builder2 = checker1->ToBuilder(); ASSERT_THAT(builder2->AddVariable(MakeVariableDecl("y", IntType())), IsOk()); ASSERT_THAT(builder2->AddLibrary(OptionalCheckerLibrary()), IsOk()); builder2->SetExpectedType(IntType()); ASSERT_OK_AND_ASSIGN(auto checker2, builder2->Build()); { ASSERT_OK_AND_ASSIGN( auto ast, MakeTestParsedAst("optional.of(addOne(x)).orValue(0) + y")); ASSERT_OK_AND_ASSIGN(ValidationResult result2, checker2->Check(std::move(ast))); EXPECT_TRUE(result2.IsValid()); } // Demonstrate checker1 is unmodified and independent (still does not know // about y). { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("y")); ASSERT_OK_AND_ASSIGN(ValidationResult result_y_checker1_again, checker1->Check(std::move(ast))); EXPECT_FALSE(result_y_checker1_again.IsValid()); } // Same for optional library functions. { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("optional.none().orValue(x)")); ASSERT_OK_AND_ASSIGN(ValidationResult result, checker1->Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); } } } // namespace } // namespace cel ================================================ FILE: checker/type_checker_subset_factory.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "checker/type_checker_subset_factory.h"

#include <string>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "checker/type_checker_builder.h"

namespace cel {

// Returns a predicate that keeps an overload only if its id is in
// `overload_ids`. The function name passed to the predicate is ignored.
TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate(
    absl::flat_hash_set<std::string> overload_ids) {
  // The set is moved into the closure, so the predicate owns its data and
  // does not depend on the lifetime of the caller's arguments.
  return [overload_ids = std::move(overload_ids)](
             absl::string_view /*function*/, absl::string_view overload_id) {
    // flat_hash_set<std::string> supports heterogeneous lookup, so this does
    // not materialize a temporary std::string.
    return overload_ids.contains(overload_id);
  };
}

// Span convenience overload: the ids are copied into an owned set, so the
// span only needs to be valid for the duration of this call.
TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate(
    absl::Span<const absl::string_view> overload_ids) {
  return IncludeOverloadsByIdPredicate(absl::flat_hash_set<std::string>(
      overload_ids.begin(), overload_ids.end()));
}

// Returns a predicate that drops an overload if its id is in `overload_ids`
// and keeps every other overload. The function name is ignored.
TypeCheckerSubset::FunctionPredicate ExcludeOverloadsByIdPredicate(
    absl::flat_hash_set<std::string> overload_ids) {
  return [overload_ids = std::move(overload_ids)](
             absl::string_view /*function*/, absl::string_view overload_id) {
    return !overload_ids.contains(overload_id);
  };
}

// Span convenience overload: the ids are copied into an owned set, so the
// span only needs to be valid for the duration of this call.
TypeCheckerSubset::FunctionPredicate ExcludeOverloadsByIdPredicate(
    absl::Span<const absl::string_view> overload_ids) {
  return ExcludeOverloadsByIdPredicate(absl::flat_hash_set<std::string>(
      overload_ids.begin(), overload_ids.end()));
}

}  // namespace cel

================================================
FILE: checker/type_checker_subset_factory.h
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Factory functions for creating typical type checker library subsets.
#ifndef THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_SUBSET_FACTORY_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_SUBSET_FACTORY_H_ #include #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "checker/type_checker_builder.h" namespace cel { // Subsets a type checker library to only include the given overload ids. TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate( absl::flat_hash_set overload_ids); TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate( absl::Span overload_ids); // Subsets a type checker library to exclude the given overload ids. TypeCheckerSubset::FunctionPredicate ExcludeOverloadsByIdPredicate( absl::flat_hash_set overload_ids); TypeCheckerSubset::FunctionPredicate ExcludeOverloadsByIdPredicate( absl::Span overload_ids); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_SUBSET_FACTORY_H_ ================================================ FILE: checker/type_checker_subset_factory_test.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "checker/type_checker_subset_factory.h"

#include <memory>

#include "absl/status/status_matchers.h"
#include "absl/strings/string_view.h"
#include "checker/validation_result.h"
#include "common/standard_definitions.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"

using ::absl_testing::IsOk;

namespace cel {
namespace {

// Only the overloads in the allowlist survive subsetting of the standard
// library; anything else fails to type check.
TEST(TypeCheckerSubsetFactoryTest, IncludeOverloadsByIdPredicate) {
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<CompilerBuilder> builder,
      NewCompilerBuilder(internal::GetSharedTestingDescriptorPool()));

  absl::string_view allowlist[] = {
      StandardOverloadIds::kNot,
      StandardOverloadIds::kAnd,
      StandardOverloadIds::kOr,
      StandardOverloadIds::kConditional,
      StandardOverloadIds::kEquals,
      StandardOverloadIds::kNotEquals,
      StandardOverloadIds::kNotStrictlyFalse,
  };

  ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
  ASSERT_THAT(builder->GetCheckerBuilder().AddLibrarySubset({
                  "stdlib",
                  IncludeOverloadsByIdPredicate(allowlist),
              }),
              IsOk());

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<Compiler> compiler, builder->Build());

  ASSERT_OK_AND_ASSIGN(
      ValidationResult r,
      compiler->Compile(
          "!true || !false && (false) ? true : false && 1 == 2 || 3.0 != 2.1"));
  EXPECT_TRUE(r.IsValid());

  ASSERT_OK_AND_ASSIGN(
      r, compiler->Compile("[true, false, true, false].exists(x, x && !x)"));
  EXPECT_TRUE(r.IsValid());

  // Not in allowlist.
  ASSERT_OK_AND_ASSIGN(r, compiler->Compile("1 + 2 < 3"));
  EXPECT_FALSE(r.IsValid());

  ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc' + 'def'"));
  EXPECT_FALSE(r.IsValid());

  ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')"));
  EXPECT_FALSE(r.IsValid());
}

// Only the excluded overloads (both forms of `matches`) are removed; the rest
// of the standard library remains available.
TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) {
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<CompilerBuilder> builder,
      NewCompilerBuilder(internal::GetSharedTestingDescriptorPool()));

  absl::string_view exclude_list[] = {
      StandardOverloadIds::kMatches,
      StandardOverloadIds::kMatchesMember,
  };

  ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
  ASSERT_THAT(builder->GetCheckerBuilder().AddLibrarySubset({
                  "stdlib",
                  ExcludeOverloadsByIdPredicate(exclude_list),
              }),
              IsOk());

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<Compiler> compiler, builder->Build());

  ASSERT_OK_AND_ASSIGN(
      ValidationResult r,
      compiler->Compile(
          "!true || !false && (false) ? true : false && 1 == 2 || 3.0 != 2.1"));
  EXPECT_TRUE(r.IsValid());

  ASSERT_OK_AND_ASSIGN(
      r, compiler->Compile("[true, false, true, false].exists(x, x && !x)"));
  EXPECT_TRUE(r.IsValid());

  // Not in the exclude list, so these overloads remain available.
  ASSERT_OK_AND_ASSIGN(r, compiler->Compile("1 + 2 < 3"));
  EXPECT_TRUE(r.IsValid());

  ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc' + 'def'"));
  EXPECT_TRUE(r.IsValid());

  // Excluded: both the member and the global form of `matches`.
  ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')"));
  EXPECT_FALSE(r.IsValid());

  ASSERT_OK_AND_ASSIGN(r, compiler->Compile("matches(r'foo.*', 'foobar')"));
  EXPECT_FALSE(r.IsValid());
}

}  // namespace
}  // namespace cel

================================================
FILE: checker/validation_result.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "checker/validation_result.h" #include #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "checker/type_check_issue.h" namespace cel { std::string ValidationResult::FormatError() const { return absl::StrJoin( issues_, "\n", [this](std::string* out, const TypeCheckIssue& issue) { absl::StrAppend(out, issue.ToDisplayString(source_.get())); }); } } // namespace cel ================================================ FILE: checker/validation_result.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_CHECKER_VALIDATION_RESULT_H_
#define THIRD_PARTY_CEL_CPP_CHECKER_VALIDATION_RESULT_H_

// NOTE(review): the targets of the following include directives (and several
// template argument lists below) appear to have been lost during extraction;
// restore them from the upstream header.
#include 
#include 
#include 
#include 
#include 

#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "checker/type_check_issue.h"
#include "common/ast.h"
#include "common/source.h"
#include "common/type.h"

namespace cel {

// ValidationResult holds the result of type checking.
//
// Error states are captured as type check issues where possible.
//
// A result is "valid" exactly when it currently holds an AST (see IsValid());
// the issue list may be non-empty in either case.
class ValidationResult {
 public:
  using TypeMap = absl::flat_hash_map;

  // Constructs a result holding an AST together with the issues reported
  // while producing it. The result is valid iff `ast` is non-null.
  ValidationResult(std::unique_ptr ast, std::vector issues)
      : ast_(std::move(ast)), issues_(std::move(issues)) {}

  // Constructs an invalid result (no AST) that carries only `issues`.
  explicit ValidationResult(std::vector issues)
      : ast_(nullptr), issues_(std::move(issues)) {}

  // Returns true if this result currently holds an AST. A successful
  // ReleaseAst() moves the AST out, after which this returns false.
  bool IsValid() const { return ast_ != nullptr; }

  // Returns the AST if validation was successful.
  //
  // This is a non-null pointer if IsValid() is true.
  const Ast* absl_nullable GetAst() const { return ast_.get(); }

  // Transfers ownership of the AST to the caller.
  //
  // Returns FailedPrecondition if there is no AST, either because none was
  // provided (inspect GetIssues()) or because it was already released.
  absl::StatusOr> ReleaseAst() {
    if (ast_ == nullptr) {
      return absl::FailedPreconditionError(
          "ValidationResult is empty. Check for TypeCheckIssues.");
    }
    return std::move(ast_);
  }

  // Returns the recorded issues. The span views internal storage and may be
  // invalidated by a subsequent AddIssue().
  absl::Span GetIssues() const { return issues_; }

  // Appends an issue. This does not change IsValid().
  void AddIssue(TypeCheckIssue issue) { issues_.push_back(std::move(issue)); }

  // The source expression may optionally be set if it is available.
  const cel::Source* absl_nullable GetSource() const { return source_.get(); }

  // Replaces the stored source expression (may be null).
  void SetSource(std::unique_ptr source) { source_ = std::move(source); }

  // Transfers ownership of the source to the caller; GetSource() returns null
  // afterwards.
  absl_nullable std::unique_ptr ReleaseSource() {
    return std::move(source_);
  }

  // Returns the resolved type map for the AST.
  //
  // Only populated if the AST was checked with an explicit arena.
  //
  // The type entries may have storage in the arena or reference type
  // information from the type checker that produced the AST. This means the map
  // is only valid as long as both the type checker and the arena are valid.
  const TypeMap& GetResolvedTypeMap() const { return resolved_type_map_; }

  // Replaces the resolved type map. The lifetime caveats documented on
  // GetResolvedTypeMap() apply to the stored entries.
  void SetResolvedTypeMap(TypeMap resolved_type_map) {
    resolved_type_map_ = std::move(resolved_type_map);
  }

  // Returns a string representation of the issues in the result suitable for
  // display.
  //
  // The result is empty if no issues are present.
  //
  // The result is formatted similarly to CEL-Java and CEL-Go, but we do not
  // give strong guarantees on the format or stability.
  //
  // Example:
  //
  //   ERROR: :1:3: Issue1
  //    | source.cel
  //    | ..^
  //   INFORMATION: :-1:-1: Issue2
  std::string FormatError() const;

 private:
  // Null when no AST was produced, or after ReleaseAst().
  absl_nullable std::unique_ptr ast_;
  TypeMap resolved_type_map_;
  std::vector issues_;
  // Null unless SetSource() was given a source (and again after
  // ReleaseSource()).
  absl_nullable std::unique_ptr source_;
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_CHECKER_VALIDATION_RESULT_H_

================================================
FILE: checker/validation_result_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "checker/validation_result.h"

#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "checker/type_check_issue.h"
#include "common/ast.h"
#include "common/source.h"
#include "internal/testing.h"

namespace cel {
namespace {

using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::testing::_;
using ::testing::IsNull;
using ::testing::NotNull;
using ::testing::SizeIs;

using Severity = TypeCheckIssue::Severity;

// A result built with an AST is valid until the AST is released; afterwards it
// behaves like a result that never had one.
TEST(ValidationResultTest, IsValidWithAst) {
  ValidationResult result(std::make_unique<Ast>(), {});
  EXPECT_TRUE(result.IsValid());
  EXPECT_THAT(result.GetAst(), NotNull());
  EXPECT_THAT(result.ReleaseAst(), IsOkAndHolds(NotNull()));

  EXPECT_FALSE(result.IsValid());
  EXPECT_THAT(result.GetAst(), IsNull());
  EXPECT_THAT(result.ReleaseAst(),
              StatusIs(absl::StatusCode::kFailedPrecondition, _));
}

TEST(ValidationResultTest, IsNotValidWithoutAst) {
  ValidationResult result({});
  EXPECT_FALSE(result.IsValid());
  EXPECT_THAT(result.GetAst(), IsNull());
  EXPECT_THAT(result.ReleaseAst(),
              StatusIs(absl::StatusCode::kFailedPrecondition, _));
}

TEST(ValidationResultTest, GetIssues) {
  ValidationResult result(
      {TypeCheckIssue::CreateError({-1, -1}, "Issue1"),
       TypeCheckIssue(Severity::kInformation, {-1, -1}, "Issue2")});
  EXPECT_FALSE(result.IsValid());
  ASSERT_THAT(result.GetIssues(), SizeIs(2));
  EXPECT_EQ(result.GetIssues()[0].message(), "Issue1");
  EXPECT_EQ(result.GetIssues()[0].severity(), Severity::kError);
  EXPECT_EQ(result.GetIssues()[1].message(), "Issue2");
  EXPECT_EQ(result.GetIssues()[1].severity(), Severity::kInformation);
}

TEST(ValidationResultTest, AddIssueAppends) {
  ValidationResult result({});
  result.AddIssue(TypeCheckIssue::CreateError({-1, -1}, "Issue1"));
  ASSERT_THAT(result.GetIssues(), SizeIs(1));
  EXPECT_EQ(result.GetIssues()[0].message(), "Issue1");
}

// SetSource() takes ownership; ReleaseSource() hands it back and clears it.
TEST(ValidationResultTest, SourceOwnership) {
  ValidationResult result({});
  EXPECT_THAT(result.GetSource(), IsNull());

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<Source> source,
                       NewSource("source.cel", ""));
  result.SetSource(std::move(source));
  EXPECT_THAT(result.GetSource(), NotNull());

  EXPECT_THAT(result.ReleaseSource(), NotNull());
  EXPECT_THAT(result.GetSource(), IsNull());
}

TEST(ValidationResultTest, FormatError) {
  ValidationResult result(
      {TypeCheckIssue::CreateError({1, 2}, "Issue1"),
       TypeCheckIssue(Severity::kInformation, {-1, -1}, "Issue2")});
  EXPECT_FALSE(result.IsValid());

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<Source> source,
                       NewSource("source.cel", ""));
  result.SetSource(std::move(source));

  ASSERT_THAT(result.GetIssues(), SizeIs(2));
  EXPECT_EQ(result.FormatError(),
            "ERROR: :1:3: Issue1\n"
            " | source.cel\n"
            " | ..^\n"
            "INFORMATION: :-1:-1: Issue2");
}

}  // namespace
}  // namespace cel
================================================
FILE: cloudbuild.yaml
================================================
# Cloud Build configuration. Both steps use the same gcc9 image and the same
# test arguments; the clang-11 step additionally sets CC/CXX to clang-11.
# waitFor: ['-'] lets each step start immediately instead of waiting for the
# previous one.
steps:
- name: 'gcr.io/cel-analysis/gcc9@sha256:4d5ff2e55224398807235a44b57e9c5793e922ac46e9ff428536bb8f8e5790ce'
  args:
  - '--output_base=/bazel' # This is mandatory to avoid steps accidentally sharing data.
  - 'test'
  - '...'
  - '--enable_bzlmod'
  - '--copt=-Wno-deprecated-declarations'
  - '--compilation_mode=fastbuild'
  - '--test_output=errors'
  - '--show_timestamps'
  - '--test_tag_filters=-benchmark,-notap'
  - '--jobs=HOST_CPUS*.5'
  - '--local_ram_resources=HOST_RAM*.4'
  - '--remote_cache=https://storage.googleapis.com/cel-cpp-remote-cache'
  - '--google_default_credentials'
  id: gcc-9
  waitFor: ['-']
- name: 'gcr.io/cel-analysis/gcc9@sha256:4d5ff2e55224398807235a44b57e9c5793e922ac46e9ff428536bb8f8e5790ce'
  env:
  - 'CC=clang-11'
  - 'CXX=clang++-11'
  args:
  - '--output_base=/bazel' # This is mandatory to avoid steps accidentally sharing data.
  - 'test'
  - '...'
  - '--enable_bzlmod'
  - '--copt=-Wno-deprecated-declarations'
  - '--compilation_mode=fastbuild'
  - '--test_output=errors'
  - '--show_timestamps'
  - '--test_tag_filters=-benchmark,-notap'
  - '--jobs=HOST_CPUS*.5'
  - '--local_ram_resources=HOST_RAM*.4'
  - '--remote_cache=https://storage.googleapis.com/cel-cpp-remote-cache'
  - '--google_default_credentials'
  id: clang-11
  waitFor: ['-']
timeout: 1h
options:
  machineType: 'E2_HIGHCPU_32'
================================================
FILE: codelab/BUILD
================================================
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

licenses(["notice"])

package(default_visibility = ["//visibility:public"])

# Exports the exercise headers and tests to //codelab/solutions only
# (presumably so the solutions are built against the same headers and tests).
exports_files(
    srcs = glob([
        "exercise*.h",
        "exercise*_test.cc",
    ]),
    visibility = ["//codelab/solutions:__pkg__"],
)

# Exclude exercise tests from tap and glob runs: they fail until the codelab
# exercises are completed (the stubs return errors).
# The solutions directory has test targets that are included to catch breaking changes.
EXERCISE_TEST_TAGS = [
    "manual",
    "notap",
    "norapid",
]

cc_library(
    name = "exercise1",
    srcs = ["exercise1.cc"],
    hdrs = ["exercise1.h"],
    tags = [
        "manual",
        "nobuilder",
    ],
    deps = [
        "//eval/public:activation",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "//internal:status_macros",
        "//parser",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "exercise1_test",
    srcs = ["exercise1_test.cc"],
    tags = EXERCISE_TEST_TAGS,
    deps = [
        ":exercise1",
        "//internal:testing",
        "@com_google_absl//absl/status",
    ],
)

cc_library(
    name = "exercise2",
    srcs = ["exercise2.cc"],
    hdrs = ["exercise2.h"],
    deps = [
        ":cel_compiler",
        "//compiler",
        "//compiler:compiler_factory",
        "//compiler:standard_library",
        "//eval/public:activation",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "//internal:status_macros",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "exercise2_test",
    srcs = ["exercise2_test.cc"],
    tags = EXERCISE_TEST_TAGS,
    deps = [
        ":exercise2",
        "//internal:testing",
        "@com_google_absl//absl/status",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

# Exercise 3 has no library of its own; its tests run against exercise2.
cc_test(
    name = "exercise3_test",
    srcs = ["exercise3_test.cc"],
    tags = EXERCISE_TEST_TAGS,
    deps = [
        ":exercise2",
        "//internal:testing",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
    ],
)

# Header-only helper (CompileToCheckedExpr) used by the exercises.
cc_library(
    name = "cel_compiler",
    hdrs = ["cel_compiler.h"],
    deps = [
        "//checker:validation_result",
        "//common:ast_proto",
        "//compiler",
        "//internal:status_macros",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
    ],
)

cc_test(
    name = "cel_compiler_test",
    srcs = ["cel_compiler_test.cc"],
    deps = [
        ":cel_compiler",
        "//common:decl",
        "//common:type",
        "//compiler",
        "//compiler:compiler_factory",
        "//compiler:standard_library",
        "//eval/public:activation",
        "//eval/public:activation_bind_helper",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_function_adapter",
        "//eval/public:cel_value",
        "//eval/public/testing:matchers",
        "//internal:testing",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "exercise4",
    srcs = ["exercise4.cc"],
    hdrs = ["exercise4.h"],
    deps = [
        ":cel_compiler",
        "//compiler",
        "//compiler:compiler_factory",
        "//compiler:standard_library",
        "//eval/public:activation",
        "//eval/public:activation_bind_helper",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "//internal:status_macros",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "exercise4_test",
    srcs = ["exercise4_test.cc"],
    tags = EXERCISE_TEST_TAGS,
    deps = [
        ":exercise4",
        "//internal:testing",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
    ],
)

cc_library(
    name = "network_functions",
    srcs = ["network_functions.cc"],
    hdrs = ["network_functions.h"],
    deps = [
        "//checker:type_checker_builder",
        "//common:decl",
        "//common:native_type",
        "//common:type",
        "//common:typeinfo",
        "//common:value",
        "//compiler",
        "//internal:status_macros",
        "//runtime:function_adapter",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "//runtime:type_registry",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "network_functions_test",
    srcs = ["network_functions_test.cc"],
    deps = [
        ":network_functions",
        "//common:decl",
        "//common:minimal_descriptor_pool",
        "//common:type",
        "//common:value",
        "//compiler",
        "//compiler:compiler_factory",
        "//compiler:standard_library",
        "//internal:benchmark",
        "//internal:status_macros",
        "//internal:testing",
        "//runtime",
        "//runtime:activation",
        "//runtime:constant_folding",
        "//runtime:runtime_options",
        "//runtime:standard_runtime_builder_factory",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "exercise10",
    srcs = ["exercise10.cc"],
    hdrs = ["exercise10.h"],
    deps = [
        ":network_functions",
        "//checker:validation_result",
        "//common:decl",
        "//common:minimal_descriptor_pool",
        "//common:type",
        "//common:value",
        "//compiler",
        "//compiler:compiler_factory",
        "//compiler:standard_library",
        "//runtime",
        "//runtime:activation",
        "//runtime:runtime_builder",
        "//runtime:runtime_options",
        "//runtime:standard_runtime_builder_factory",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "exercise10_test",
    srcs = ["exercise10_test.cc"],
    tags = EXERCISE_TEST_TAGS,
    deps = [
        ":exercise10",
        "//internal:testing",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
    ],
)
================================================
FILE: codelab/Dockerfile
================================================
# Reference environment for building the repo and running the codelab
# solutions; clang-13 is selected as the compiler via CC/CXX below.
ARG DEBIAN_IMAGE="marketplace.gcr.io/google/debian11:latest"
FROM ${DEBIAN_IMAGE}
ARG BAZELISK_RELEASE="https://github.com/bazelbuild/bazelisk/releases/download/v1.25.0/bazelisk-amd64.deb"

RUN apt update && apt upgrade -y && apt install -y gcc-9 g++-9 clang-13 git curl bash openjdk-11-jdk-headless

RUN curl -L ${BAZELISK_RELEASE} > ./bazelisk.deb

RUN apt install ./bazelisk.deb

RUN git clone https://github.com/google/cel-cpp.git

ENV CXX=clang++-13
ENV CC=clang-13

WORKDIR /cel-cpp

# not generally recommended to cache the bazel build in the image,
# but works ok for prototyping.
RUN bazelisk build ...
&& bazelisk test //codelab/solutions:all ================================================ FILE: codelab/README.md ================================================ # What is CEL? Common Expression Language (CEL) is an expression language that’s fast, portable, and safe to execute in performance-critical applications. CEL is designed to be embedded in an application, with application-specific extensions, and is ideal for extending declarative configurations that your applications might already use. ## What is covered in this Codelab? This codelab is aimed at developers who would like to learn CEL to use services that already support CEL. This Codelab covers common use cases. This codelab doesn't cover how to integrate CEL into your own project. For a more in-depth look at the language, semantics, and features see the [CEL Language Definition on GitHub](https://github.com/google/cel-spec). Some key areas covered are: * [Hello, World: Using CEL to evaluate a String](#hello-world) * [Creating variables](#creating-variables) * [Commutative logical AND/OR](#logical-andor) * [Adding custom functions](#custom-functions) ### Prerequisites This codelab builds upon a basic understanding of Protocol Buffers and C++. If you're not familiar with Protocol Buffers, the first exercise will give you a sense of how CEL works, but because the more advanced examples use Protocol Buffers as the input into CEL, they may be harder to understand. Consider working through one of these tutorials, first. See the devsite for [Protocol Buffers](https://protobuf.dev). Notes on portability: Protocol Buffers are not required to use CEL generally, but the C++ implementation has a hard dependency on the library and some APIs reference protobuf types directly. Automated builds test against gcc9 and clang11 on linux. We accept requests for portability fixes for other OSes and compilers, but don't actively maintain support at this time. 
A simple Dockerfile is provided as a reference for a known-good environment configuration for running the codelab solutions. What you'll need: - Git - Bazel - C/C++ Compiler (GCC, Clang, Visual Studio). - Optional: bazelisk is a wrapper around bazel that simplifies version management. If you use it, substitute `bazelisk` for all bazel commands below. ## GitHub Setup GitHub Repo: The code for this codelab lives in the `codelab` folder of the cel-cpp repo. The solutions are available in the `codelab/solutions` folder of the same repo. Clone and cd into the repo: ``` git clone git@github.com:google/cel-cpp.git cd cel-cpp ``` Make sure everything is working by building the codelab: ``` bazel build //codelab:all ``` ## Hello, World In the tried-and-true tradition of all programming languages, let's start with "Hello, World!". Update exercise1.cc with the following: Using declarations: ```c++ using ::google::api::expr::parser::Parse; using ::google::api::expr::runtime::Activation; using ::google::api::expr::runtime::CelExpression; using ::google::api::expr::runtime::CelExpressionBuilder; using ::google::api::expr::runtime::CelValue; using ::google::api::expr::runtime::CreateCelExpressionBuilder; using ::google::api::expr::runtime::InterpreterOptions; using ::google::api::expr::runtime::RegisterBuiltinFunctions; ``` Implementation: ```c++ absl::StatusOr<std::string> ParseAndEvaluate(absl::string_view cel_expr) { // === Start Codelab === // Setup a default environment for building expressions. InterpreterOptions options; std::unique_ptr<CelExpressionBuilder> builder = CreateCelExpressionBuilder(options); CEL_RETURN_IF_ERROR( RegisterBuiltinFunctions(builder->GetRegistry(), options)); // Parse the expression. This is fine for codelabs, but this skips the type // checking phase. It won't check that functions and variables are available // in the environment, and it won't handle certain ambiguous identifier // expressions (e.g. container lookup vs namespaced name, packaged function // vs.
receiver call style function). ParsedExpr parsed_expr; CEL_ASSIGN_OR_RETURN(parsed_expr, Parse(cel_expr)); // The evaluator uses a proto Arena for incidental allocations during // evaluation. google::protobuf::Arena arena; // The activation provides variables and functions that are bound into the // expression environment. In this example, there's no context expected, so // we just provide an empty one to the evaluator. Activation activation; // Build the expression plan. This assumes that the source expression AST and // the expression builder outlive the CelExpression object. CEL_ASSIGN_OR_RETURN(std::unique_ptr<CelExpression> expression_plan, builder->CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); // Actually run the expression plan. We don't support any environment // variables at the moment so just use an empty activation. CEL_ASSIGN_OR_RETURN(CelValue result, expression_plan->Evaluate(activation, &arena)); // Convert the result to a C++ string. CelValues may reference instances from // either the input expression, or objects allocated on the arena, so we need // to pass ownership (in this case by copying to a new instance and returning // that). return ConvertResult(result); // === End Codelab === } ``` Run the following to check your work: ``` bazel test //codelab:exercise1_test ``` You can add additional test cases or experiment with different return types. Hello, World! Now, let's break down what's happening. ### Setup the Environment CEL applications evaluate an expression against an environment. The standard CEL environment supports all of the types, operators, functions, and macros defined within the language spec. The environment can be customized by providing options to disable macros, declare custom variables and functions, etc. A CelExpressionBuilder maintains the C++ evaluation environment. This creates a builder with the standard environment.
```c++ #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_options.h" ... // Setup a default environment for building expressions. // Breaking behavior changes and optional features are controlled by // InterpreterOptions. InterpreterOptions options; // Environment used for planning and evaluating expressions is managed by a // CelExpressionBuilder. std::unique_ptr<CelExpressionBuilder> builder = CreateCelExpressionBuilder(options); // Add standard function bindings e.g. for +,-,==,||,&& operators. // Custom functions (implementing the CelFunction interface) can be added to the // registry similarly. CEL_RETURN_IF_ERROR( RegisterBuiltinFunctions(builder->GetRegistry(), options)); ``` ### Parse After the environment is configured, you can parse the expression: ```c++ #include "cel/expr/syntax.pb.h" #include "internal/status_macros.h" #include "parser/parser.h" // ... CEL_ASSIGN_OR_RETURN(cel::expr::ParsedExpr parsed_expr, google::api::expr::parser::Parse(cel_expr)); ``` The C++ parser is a stand-alone utility. It's not aware of the evaluation environment and does not perform any semantic checks on the expression. A status is returned if the input string isn't a syntactically valid CEL expression or if it exceeds the configured complexity limits (see cel::ParserOptions and default limits). ### Evaluate After an expression has been parsed (and optionally type-checked) into an AST representation, it can be converted into an evaluable program whose function bindings and evaluation modes can be customized depending on the stack you are using. Once a CEL expression is planned, it can be evaluated against an evaluation context (an activation). The evaluation result will be either a value or an error state. The InterpreterOptions used to create the expression plan are honored at evaluation. C++ uses the proto representation of either a parsed `cel.expr.ParsedExpr` or parsed and type-checked `cel.expr.CheckedExpr` AST directly.
Once a CEL program is planned (represented by a `google::api::expr::runtime::CelExpression`), it can be evaluated against a `google::api::expr::runtime::Activation`. The Activation provides per-evaluation bindings for variables and functions in the expression's environment. ```c++ #include "google/protobuf/arena.h" #include "eval/public/activation.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_value.h" #include "internal/status_macros.h" #include "parser/parser.h" ... // The evaluator uses a proto Arena for incidental allocations during // evaluation. google::protobuf::Arena arena; // The activation provides variables and functions that are bound into the // expression environment. In this example, there's no context expected, so // we just provide an empty one to the evaluator. Activation activation; // Build the expression plan. This assumes that the source expression AST and // the expression builder outlive the CelExpression object. CEL_ASSIGN_OR_RETURN(std::unique_ptr<CelExpression> expression_plan, builder->CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); // Actually run the expression plan. We don't support any environment // variables at the moment so just use an empty activation. CEL_ASSIGN_OR_RETURN(CelValue result, expression_plan->Evaluate(activation, &arena)); // Convert the result to a C++ string. CelValues may reference instances from // either the input expression, or objects allocated on the arena, so we need // to pass ownership (in this case by copying to a new instance and returning // that). return ConvertResult(result); ``` ## Creating variables Most CEL applications will declare variables that can be referenced within expressions. Variable declarations specify a name and a type. A variable's type may either be a CEL builtin type, a protocol buffer well-known type, or any protobuf message type so long as its descriptor is also provided to CEL.
At runtime, the hosting program binds instances of variables to the evaluation context (using the variable name as a key). For the C++ evaluator at runtime, the values are managed by the `google::api::expr::runtime::CelValue` type, a variant over the C++ representations of supported CEL types. Update exercise2.cc: ```c++ // The Variables exercise shows how to declare and use variables in expressions. // There are two overloads for preparing an expression either granularly for // individual variables or using a helper to bind a context proto. // The first overload shows manually populating individual variables in the // evaluation environment. This allows cel_expr to reference 'bool_var'. absl::StatusOr<bool> ParseAndEvaluate(absl::string_view cel_expr, bool bool_var) { Activation activation; google::protobuf::Arena arena; // === Start Codelab === activation.InsertValue("bool_var", CelValue::CreateBool(bool_var)); // === End Codelab === return ParseAndEvaluate(cel_expr, activation, &arena); } ``` Run the following to check your work. You should have fixed the first two test cases in exercise2_test.cc. ``` bazel test //codelab:exercise2_test ``` The second overload uses a protocol buffer message to represent the environment variables. For this use case, there is a helper to automatically bind in fields from a top-level message (see `google::api::expr::runtime::BindProtoToActivation`). In this example, we assume that unset fields should be bound to default values. ```c++ #include "eval/public/activation_bind_helper.h" // ... using ::google::api::expr::runtime::ProtoUnsetFieldOptions; // ...
absl::StatusOr ParseAndEvaluate(absl::string_view cel_expr, const AttributeContext& context) { Activation activation; google::protobuf::Arena arena; // === Start Codelab === CEL_RETURN_IF_ERROR(BindProtoToActivation( &context, &arena, &activation, ProtoUnsetFieldOptions::kBindDefault)); // === End Codelab === return ParseAndEvaluate(cel_expr, activation, &arena); } ``` Note: You can experiment with unset values and the alternative bind option for BindProtoToActivation. With ProtoUnsetFieldOptions::kSkip unset values will not be bound at all, and accesses in expressions will cause errors. ## Logical And/Or One of CEL's more distinctive features is its use of commutative logical operators. Either side of a conditional branch can short-circuit the evaluation, even in the face of errors or partial input. Note: If you are skipping ahead, copy the solution for exercise2 -- we'll be using it to test the behavior of some simple expressions. exercise3_test.cc lists truth tables for simple expressions using the 'or', 'and', and 'ternary' operators. Running the following should result in some failing expectations. ``` bazel test //codelab:exercise3_test ``` Open exercise3_test.cc in your editor: ```c++ TEST(Exercise3Var, LogicalOr) { // Some of these expectations are incorrect. // If a logical operation can short-circuit a branch that results in an error, // CEL evaluation will return the logical result instead of propagating the // error. For logical or, this means if one branch is true, the result will // always be true, regardless of the other branch. 
// Wrong EXPECT_THAT(TruthTableTest("true || (1 / 0 > 2)"), StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); EXPECT_THAT(TruthTableTest("false || (1 / 0 > 2)"), StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); // Wrong EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || true"), StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || false"), StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || (1 / 0 > 2)"), StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); EXPECT_THAT(TruthTableTest("true || true"), IsOkAndHolds(true)); EXPECT_THAT(TruthTableTest("true || false"), IsOkAndHolds(true)); EXPECT_THAT(TruthTableTest("false || true"), IsOkAndHolds(true)); EXPECT_THAT(TruthTableTest("false || false"), IsOkAndHolds(false)); } ``` Updating the two failing cases "true || (1 / 0 > 2)" and "(1 / 0 > 2) || true" should fix this test: ```c++ // ... // Correct EXPECT_THAT(TruthTableTest("true || (1 / 0 > 2)"), IsOkAndHolds(true)); EXPECT_THAT(TruthTableTest("false || (1 / 0 > 2)"), StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero")); // Correct EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || true"), IsOkAndHolds(true)); ``` You can examine the other tests for other cases for corresponding behavior for the 'and' and ternary operators. CEL finds an evaluation order which gives results whenever possible, ignoring errors or even missing data that might occur in other evaluation orders. Applications like IAM conditions rely on this property to minimize the cost of evaluation, deferring the gathering of expensive inputs when a result can be reached without them. 
================================================ FILE: codelab/cel_compiler.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_CODELAB_COMPILER_H_ #define THIRD_PARTY_CEL_CPP_CODELAB_COMPILER_H_ #include "cel/expr/checked.pb.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "checker/validation_result.h" #include "common/ast_proto.h" #include "compiler/compiler.h" #include "internal/status_macros.h" namespace cel_codelab { // Helper for compiling expression and converting to proto. // // Simplifies error handling for brevity in the codelab. inline absl::StatusOr CompileToCheckedExpr( const cel::Compiler& compiler, absl::string_view expr) { CEL_ASSIGN_OR_RETURN(cel::ValidationResult result, compiler.Compile(expr)); if (!result.IsValid() || result.GetAst() == nullptr) { return absl::InvalidArgumentError(result.FormatError()); } cel::expr::CheckedExpr pb; CEL_RETURN_IF_ERROR(cel::AstToCheckedExpr(*result.GetAst(), &pb)); return pb; }; } // namespace cel_codelab #endif // THIRD_PARTY_CEL_CPP_CODELAB_COMPILER_H_ ================================================ FILE: codelab/cel_compiler_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "codelab/cel_compiler.h" #include #include #include "google/rpc/context/attribute_context.pb.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "common/decl.h" #include "common/type.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" #include "compiler/standard_library.h" #include "eval/public/activation.h" #include "eval/public/activation_bind_helper.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_function_adapter.h" #include "eval/public/cel_value.h" #include "eval/public/testing/matchers.h" #include "internal/testing.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel_codelab { namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::BoolType; using ::cel::MakeFunctionDecl; using ::cel::MakeOverloadDecl; using ::cel::MakeVariableDecl; using ::cel::StringType; using ::google::api::expr::runtime::Activation; using ::google::api::expr::runtime::BindProtoToActivation; using ::google::api::expr::runtime::CelValue; using ::google::api::expr::runtime::CreateCelExpressionBuilder; using ::google::api::expr::runtime::FunctionAdapter; using ::google::api::expr::runtime::RegisterBuiltinFunctions; using ::google::api::expr::runtime::test::IsCelBool; using ::google::rpc::context::AttributeContext; using ::testing::HasSubstr; std::unique_ptr MakeDefaultCompilerBuilder() { 
google::protobuf::LinkMessageReflection(); auto builder = cel::NewCompilerBuilder(google::protobuf::DescriptorPool::generated_pool()); ABSL_CHECK_OK(builder.status()); ABSL_CHECK_OK((*builder)->AddLibrary(cel::StandardCompilerLibrary())); ABSL_CHECK_OK((*builder)->GetCheckerBuilder().AddContextDeclaration( "google.rpc.context.AttributeContext")); return std::move(builder).value(); } TEST(DefaultCompiler, Basic) { ASSERT_OK_AND_ASSIGN(auto compiler, MakeDefaultCompilerBuilder()->Build()); EXPECT_THAT(compiler->Compile("1 < 2").status(), IsOk()); } TEST(DefaultCompiler, AddFunctionDecl) { auto builder = MakeDefaultCompilerBuilder(); ASSERT_OK_AND_ASSIGN( cel::FunctionDecl decl, MakeFunctionDecl("IpMatch", MakeOverloadDecl("IpMatch_string_string", BoolType(), StringType(), StringType()))); EXPECT_THAT(builder->GetCheckerBuilder().AddFunction(decl), IsOk()); ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); EXPECT_THAT(CompileToCheckedExpr( *compiler, "IpMatch('255.255.255.255', '255.255.255.255')") .status(), IsOk()); EXPECT_THAT( CompileToCheckedExpr(*compiler, "IpMatch('255.255.255.255', 123436)") .status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("no matching overload"))); } TEST(DefaultCompiler, EndToEnd) { google::protobuf::Arena arena; auto compiler_builder = MakeDefaultCompilerBuilder(); ASSERT_OK_AND_ASSIGN( cel::FunctionDecl func_decl, MakeFunctionDecl("MyFunc", MakeOverloadDecl("MyFunc", BoolType()))); ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddFunction(func_decl), IsOk()); ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable( MakeVariableDecl("my_var", BoolType())), IsOk()); ASSERT_OK_AND_ASSIGN(auto compiler, compiler_builder->Build()); ASSERT_OK_AND_ASSIGN( auto expr, CompileToCheckedExpr( *compiler, "(my_var || MyFunc()) && request.host == 'www.google.com'")); auto builder = CreateCelExpressionBuilder(google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory()); 
ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry()), IsOk()); ASSERT_THAT(FunctionAdapter::CreateAndRegister( "MyFunc", false, [](google::protobuf::Arena*) { return true; }, builder->GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto plan, builder->CreateExpression(&expr)); AttributeContext context; context.mutable_request()->set_host("www.google.com"); Activation activation; ASSERT_THAT(BindProtoToActivation(&context, &arena, &activation), IsOk()); activation.InsertValue("my_var", CelValue::CreateBool(false)); ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); EXPECT_THAT(result, IsCelBool(true)); } } // namespace } // namespace cel_codelab ================================================ FILE: codelab/exercise1.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "codelab/exercise1.h"

#include
#include

#include "cel/expr/syntax.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "eval/public/activation.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_expr_builder_factory.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_options.h"
#include "eval/public/cel_value.h"
#include "internal/status_macros.h"
#include "parser/parser.h"
#include "google/protobuf/arena.h"

namespace cel_codelab {
namespace {

using ::google::api::expr::runtime::Activation;
using ::google::api::expr::runtime::CelValue;

// Copies a string-typed CelValue into an owned C++ string.
//
// Any other value type yields an InvalidArgument error that names the actual
// type. The copy is deliberate: the evaluator may represent strings as views
// backed by the input expression, so the result must not borrow from them.
absl::StatusOr ConvertResult(const CelValue& value) {
  CelValue::StringHolder string_holder;
  if (!value.GetValue(&string_holder)) {
    return absl::InvalidArgumentError(absl::StrCat(
        "expected string result got '", CelValue::TypeName(value.type()), "'"));
  }
  return std::string(string_holder.value());
}

}  // namespace

absl::StatusOr ParseAndEvaluate(absl::string_view cel_expr) {
  // === Start Codelab ===
  // Step 1: parse the expression using ::google::api::expr::parser::Parse;
  // it returns a cel::expr::ParsedExpr message.

  // Step 2: set up a default environment for building expressions, e.g.
  //   google::api::expr::runtime::InterpreterOptions options;
  //   std::unique_ptr builder =
  //       CreateCelExpressionBuilder(options);

  // Step 3: register the standard functions, e.g.
  //   CEL_RETURN_IF_ERROR(
  //       RegisterBuiltinFunctions(builder->GetRegistry(), options));

  // Proto Arena the evaluator uses for incidental allocations while it runs.
  google::protobuf::Arena arena;

  // Variables and functions bound into the expression environment. This
  // exercise expects no context, so the evaluator gets an empty activation.
  Activation activation;

  // Step 4: using the CelExpressionBuilder and the ParsedExpr, create an
  // execution plan (google::api::expr::runtime::CelExpression), evaluate it,
  // and return the result. Use the ConvertResult helper above to copy the
  // value out for return.
  return absl::UnimplementedError("Not yet implemented");
  // === End Codelab ===
}

}  // namespace cel_codelab
================================================
FILE: codelab/exercise1.h
================================================
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE1_H_
#define THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE1_H_

#include

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"

namespace cel_codelab {

// Parses a CEL expression and evaluates it with no special setup of the
// evaluation environment. The expression is expected to produce a string
// value.
absl::StatusOr ParseAndEvaluate(absl::string_view cel_expr);

}  // namespace cel_codelab

#endif  // THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE1_H_
================================================
FILE: codelab/exercise10.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "codelab/exercise10.h" #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "checker/validation_result.h" #include "codelab/network_functions.h" #include "common/decl.h" #include "common/minimal_descriptor_pool.h" #include "common/type.h" #include "common/value.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" #include "compiler/standard_library.h" #include "runtime/activation.h" #include "runtime/runtime.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "google/protobuf/arena.h" namespace cel_codelab { namespace { absl::StatusOr> ConfigureCompiler() { absl::StatusOr> compiler_builder = cel::NewCompilerBuilder(cel::GetMinimalDescriptorPool()); if (!compiler_builder.ok()) { return std::move(compiler_builder).status(); } absl::Status s = (*compiler_builder)->AddLibrary(cel::StandardCompilerLibrary()); // =========================================================================== // Codelab: Update compiler builder with functions from network_functions.h // and add a varible for the input IP. // =========================================================================== if (!s.ok()) return s; return (*compiler_builder)->Build(); } absl::StatusOr> ConfigureRuntime() { cel::RuntimeOptions runtime_options; // Note: this is needed to resolve net.Address as a `type` constant. 
runtime_options.enable_qualified_type_identifiers = true; absl::StatusOr runtime_builder = cel::CreateStandardRuntimeBuilder(cel::GetMinimalDescriptorPool(), runtime_options); // =========================================================================== // Codelab: Update runtime builder with functions from network_functions.h // =========================================================================== return std::move(runtime_builder).value().Build(); } } // namespace absl::StatusOr CompileAndEvaluateExercise10(absl::string_view expression, absl::string_view ip) { absl::StatusOr> compiler = ConfigureCompiler(); if (!compiler.ok()) { return std::move(compiler).status(); } absl::StatusOr> runtime = ConfigureRuntime(); if (!runtime.ok()) { return std::move(runtime).status(); } absl::StatusOr checked = (*compiler)->Compile(expression); if (!checked.ok()) { return std::move(checked).status(); } if (!checked->IsValid() || checked->GetAst() == nullptr) { return absl::InvalidArgumentError(checked->FormatError()); } absl::StatusOr> program = (*runtime)->CreateProgram(checked->ReleaseAst().value()); if (!program.ok()) { return std::move(program).status(); } cel::Activation activation; google::protobuf::Arena arena; activation.InsertOrAssignValue("ip", cel::StringValue::From(ip, &arena)); absl::StatusOr result = (*program)->Evaluate(&arena, activation); if (!result.ok()) { return std::move(result).status(); } if (result->IsBool()) { return result->GetBool(); } if (result->IsError()) { return result->GetError().ToStatus(); } return absl::InvalidArgumentError( absl::StrCat("unexpected result type: ", result->DebugString())); } } // namespace cel_codelab ================================================ FILE: codelab/exercise10.h ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE10_H_ #define THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE10_H_ #include "absl/status/statusor.h" #include "absl/strings/string_view.h" namespace cel_codelab { // Exercise10 -- extension types. // // This function compiles an expression then evaluates, expecting a bool // return type. // // Example: // net.ParseAddressMatcher("8.8.0.0-8.8.255.255") // .containsAddress( // net.parseAddress(ip) // ) // // Variables: // ip - string // // Functions: // net.ParseAddress(string) -> net.Address // net.ParseAddressMatcher(string) -> net.AddressMatcher // (net.AddressMatcher). absl::StatusOr CompileAndEvaluateExercise10(absl::string_view expression, absl::string_view ip); } // namespace cel_codelab #endif // THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE10_H_ ================================================ FILE: codelab/exercise10_test.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "codelab/exercise10.h"

#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "internal/testing.h"

namespace cel_codelab {
namespace {

using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::testing::HasSubstr;

// True when `ip` lies inside the 8.8.4.0-8.8.4.255 range.
constexpr char kRangeContainsIpExpr[] =
    R"cel( net.parseAddressMatcher("8.8.4.0-8.8.4.255") .containsAddress( net.parseAddress(ip) ) )cel";

// True when `ip` parses to the same address as 8.8.4.4.
constexpr char kEqualsFixedIpExpr[] =
    R"cel( net.parseAddress("8.8.4.4") == net.parseAddress(ip) )cel";

TEST(Exercise10, IpInRange) {
  EXPECT_THAT(CompileAndEvaluateExercise10(kRangeContainsIpExpr, "8.8.4.4"),
              IsOkAndHolds(true));
}

TEST(Exercise10, IpNotInRange) {
  EXPECT_THAT(CompileAndEvaluateExercise10(kRangeContainsIpExpr, "8.8.8.8"),
              IsOkAndHolds(false));
}

TEST(Exercise10, IpEqual) {
  EXPECT_THAT(CompileAndEvaluateExercise10(kEqualsFixedIpExpr, "8.8.4.4"),
              IsOkAndHolds(true));
}

TEST(Exercise10, IpInequal) {
  EXPECT_THAT(CompileAndEvaluateExercise10(kEqualsFixedIpExpr, "8.8.8.8"),
              IsOkAndHolds(false));
}

TEST(Exercise10, IpInvalid) {
  // "8.8" is an incomplete address and is expected to be rejected when the
  // program runs.
  EXPECT_THAT(CompileAndEvaluateExercise10(kEqualsFixedIpExpr, "8.8"),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("invalid address")));
}

}  // namespace
}  // namespace cel_codelab
================================================
FILE: codelab/exercise1_test.cc
================================================
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include "codelab/exercise1.h" #include "absl/status/status.h" #include "internal/testing.h" namespace cel_codelab { namespace { using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; TEST(Exercise1, PrintHelloWorld) { EXPECT_THAT(ParseAndEvaluate("'Hello, World!'"), IsOkAndHolds("Hello, World!")); } TEST(Exercise1, WrongTypeResultError) { EXPECT_THAT(ParseAndEvaluate("true"), StatusIs(absl::StatusCode::kInvalidArgument, "expected string result got 'bool'")); } TEST(Exercise1, Conditional) { EXPECT_THAT(ParseAndEvaluate("(1 < 0)? 'Hello, World!' : '¡Hola, Mundo!'"), IsOkAndHolds("¡Hola, Mundo!")); } } // namespace } // namespace cel_codelab ================================================ FILE: codelab/exercise2.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "codelab/exercise2.h"

#include

#include "cel/expr/checked.pb.h"
#include "cel/expr/syntax.pb.h"
#include "google/rpc/context/attribute_context.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "codelab/cel_compiler.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "eval/public/activation.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_expr_builder_factory.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_options.h"
#include "eval/public/cel_value.h"
#include "internal/status_macros.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel_codelab {
namespace {

using ::cel::expr::CheckedExpr;
using ::google::api::expr::runtime::Activation;
using ::google::api::expr::runtime::CelError;
using ::google::api::expr::runtime::CelExpression;
using ::google::api::expr::runtime::CelExpressionBuilder;
using ::google::api::expr::runtime::CelValue;
using ::google::api::expr::runtime::CreateCelExpressionBuilder;
using ::google::api::expr::runtime::InterpreterOptions;
using ::google::api::expr::runtime::RegisterBuiltinFunctions;
using ::google::rpc::context::AttributeContext;

// Builds a compiler with the standard CEL library over the generated
// descriptor pool.
//
// The generated pool is used for simplicity. The drawback is that it exposes
// every message type linked into the binary, not just the ones this CEL
// environment expects.
absl::StatusOr> MakeCelCompiler() {
  google::protobuf::LinkMessageReflection();
  CEL_ASSIGN_OR_RETURN(
      std::unique_ptr builder,
      cel::NewCompilerBuilder(google::protobuf::DescriptorPool::generated_pool()));
  CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary()));

  // === Start Codelab ===
  // Add 'AttributeContext' as a context message to the type checker and a
  // boolean variable 'bool_var'. Relevant functions are on the
  // TypeCheckerBuilder class (see CompilerBuilder::GetCheckerBuilder).
  //
  // We're reusing the same compiler for both evaluation paths here for brevity,
  // but it's likely a better fit to configure a separate compiler per use case.
  // === End Codelab ===

  return builder->Build();
}

// Plans and evaluates an already type-checked expression against
// `activation`, passing `arena` to the evaluator.
//
// Returns the bool result, the CEL error status if evaluation produced an
// error value, or InvalidArgument for any other result type.
absl::StatusOr EvalCheckedExpr(const CheckedExpr& checked_expr,
                                     const Activation& activation,
                                     google::protobuf::Arena* arena) {
  // A default-configured expression builder with the builtin functions.
  InterpreterOptions options;
  std::unique_ptr builder = CreateCelExpressionBuilder(
      google::protobuf::DescriptorPool::generated_pool(),
      google::protobuf::MessageFactory::generated_factory(), options);
  CEL_RETURN_IF_ERROR(
      RegisterBuiltinFunctions(builder->GetRegistry(), options));

  // The plan could be reused across different inputs; here it is built just
  // in time for a single evaluation.
  CEL_ASSIGN_OR_RETURN(std::unique_ptr plan,
                       builder->CreateExpression(&checked_expr));
  CEL_ASSIGN_OR_RETURN(CelValue result, plan->Evaluate(activation, arena));

  if (bool bool_result; result.GetValue(&bool_result)) {
    return bool_result;
  }
  if (const CelError* error; result.GetValue(&error)) {
    return *error;
  }
  return absl::InvalidArgumentError(absl::StrCat(
      "expected 'bool' result got '", result.DebugString(), "'"));
}

}  // namespace

absl::StatusOr CompileAndEvaluateWithBoolVar(absl::string_view cel_expr,
                                                   bool bool_var) {
  CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler,
                       MakeCelCompiler());
  CEL_ASSIGN_OR_RETURN(CheckedExpr checked_expr,
                       CompileToCheckedExpr(*compiler, cel_expr));

  google::protobuf::Arena arena;
  Activation activation;

  // === Start Codelab ===
  // Update the activation to bind the bool argument to 'bool_var'
  // === End Codelab ===

  return EvalCheckedExpr(checked_expr, activation, &arena);
}

absl::StatusOr CompileAndEvaluateWithContext(
    absl::string_view cel_expr, const AttributeContext& context) {
  CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler,
                       MakeCelCompiler());
  CEL_ASSIGN_OR_RETURN(CheckedExpr checked_expr,
                       CompileToCheckedExpr(*compiler, cel_expr));

  google::protobuf::Arena arena;
  Activation activation;

  // === Start Codelab ===
  // Update the activation to bind the AttributeContext.
  // === End Codelab ===

  return EvalCheckedExpr(checked_expr, activation, &arena);
}

}  // namespace cel_codelab
================================================
FILE: codelab/exercise2.h
================================================
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE1_H_ #define THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE1_H_ #include "google/rpc/context/attribute_context.pb.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" namespace cel_codelab { // Compile a cel expression and evaluate it. Binds a simple boolean to the // activation as 'bool_var' for use in the expression. // // cel_expr should result in a bool, otherwise an InvalidArgument error is // returned. absl::StatusOr CompileAndEvaluateWithBoolVar(absl::string_view cel_expr, bool bool_var); // Compile a cel expression and evaluate it. Binds an instance of the // AttributeContext message to the activation (binding the subfields directly). absl::StatusOr CompileAndEvaluateWithContext( absl::string_view cel_expr, const google::rpc::context::AttributeContext& context); } // namespace cel_codelab #endif // THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE1_H_ ================================================ FILE: codelab/exercise2_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and
// limitations under the License.

#include "codelab/exercise2.h"

#include "google/rpc/context/attribute_context.pb.h"
#include "absl/status/status.h"
#include "internal/testing.h"
#include "google/protobuf/text_format.h"

namespace cel_codelab {
namespace {

using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::google::rpc::context::AttributeContext;
using ::google::protobuf::TextFormat;
using ::testing::HasSubstr;

// `bool_var` is bound to the second argument and can be combined with
// boolean literals.
TEST(Exercise2Var, Simple) {
  EXPECT_THAT(CompileAndEvaluateWithBoolVar("bool_var", false),
              IsOkAndHolds(false));
  EXPECT_THAT(CompileAndEvaluateWithBoolVar("bool_var", true),
              IsOkAndHolds(true));
  EXPECT_THAT(CompileAndEvaluateWithBoolVar("bool_var || true", false),
              IsOkAndHolds(true));
  EXPECT_THAT(CompileAndEvaluateWithBoolVar("bool_var && false", true),
              IsOkAndHolds(false));
}

// A non-bool result is reported as InvalidArgument.
TEST(Exercise2Var, WrongTypeResultError) {
  EXPECT_THAT(CompileAndEvaluateWithBoolVar("'not a bool'", false),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("expected 'bool' result got 'string")));
}

// Fields of the bound AttributeContext (source, request, destination) are
// referenced directly as top-level identifiers.
TEST(Exercise2Context, Simple) {
  AttributeContext context;
  ASSERT_TRUE(TextFormat::ParseFromString(R"pb( source { ip: "192.168.28.1" } request { host: "www.example.com" } destination { ip: "192.168.56.1" } )pb",
                                          &context));
  EXPECT_THAT(
      CompileAndEvaluateWithContext("source.ip == '192.168.28.1'", context),
      IsOkAndHolds(true));
  EXPECT_THAT(CompileAndEvaluateWithContext("request.host == 'api.example.com'",
                                            context),
              IsOkAndHolds(false));
  EXPECT_THAT(CompileAndEvaluateWithContext("request.host == 'www.example.com'",
                                            context),
              IsOkAndHolds(true));
  EXPECT_THAT(CompileAndEvaluateWithContext("destination.ip != '192.168.56.1'",
                                            context),
              IsOkAndHolds(false));
}

TEST(Exercise2Context, WrongTypeResultError) {
  AttributeContext context;

  // For this codelab, we expect the bind default option which will return
  // proto api defaults for unset fields.
  // `request.host` therefore evaluates to a string instead of an error, and
  // the expected failure comes from the non-bool result type.
  EXPECT_THAT(CompileAndEvaluateWithContext("request.host", context),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("expected 'bool' result got 'string")));
}

}  // namespace
}  // namespace cel_codelab
================================================
FILE: codelab/exercise3_test.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Codelab exercise: the expectations marked "// Wrong" below are known to be
// incorrect and are left for the reader to fix.

#include "google/rpc/context/attribute_context.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "codelab/exercise2.h"
#include "internal/testing.h"

namespace cel_codelab {
namespace {

using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::google::rpc::context::AttributeContext;

// Helper for a simple CelExpression with no context.
// `bool_var` is bound to false, but none of the expressions below use it.
absl::StatusOr TruthTableTest(absl::string_view statement) {
  return CompileAndEvaluateWithBoolVar(statement, /*unused*/ false);
}

TEST(Exercise3, LogicalOr) {
  // Some of these expectations are incorrect.
  // If a logical operation can short-circuit a branch that results in an error,
  // CEL evaluation will return the logical result instead of propagating the
  // error. For logical or, this means if one branch is true, the result will
  // always be true, regardless of the other branch.
  // Wrong
  EXPECT_THAT(TruthTableTest("true || (1 / 0 > 2)"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  EXPECT_THAT(TruthTableTest("false || (1 / 0 > 2)"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  // Wrong
  EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || true"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || false"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || (1 / 0 > 2)"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  EXPECT_THAT(TruthTableTest("true || true"), IsOkAndHolds(true));
  EXPECT_THAT(TruthTableTest("true || false"), IsOkAndHolds(true));
  EXPECT_THAT(TruthTableTest("false || true"), IsOkAndHolds(true));
  EXPECT_THAT(TruthTableTest("false || false"), IsOkAndHolds(false));
}

// Mirrors LogicalOr: the "// Wrong" cases pair an error branch with a false
// branch, on either side of the operator.
TEST(Exercise3, LogicalAnd) {
  EXPECT_THAT(TruthTableTest("true && (1 / 0 > 2)"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  // Wrong
  EXPECT_THAT(TruthTableTest("false && (1 / 0 > 2)"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  EXPECT_THAT(TruthTableTest("(1 / 0 > 2) && true"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  // Wrong
  EXPECT_THAT(TruthTableTest("(1 / 0 > 2) && false"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  EXPECT_THAT(TruthTableTest("(1 / 0 > 2) && (1 / 0 > 2)"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  EXPECT_THAT(TruthTableTest("true && true"), IsOkAndHolds(true));
  EXPECT_THAT(TruthTableTest("true && false"), IsOkAndHolds(false));
  EXPECT_THAT(TruthTableTest("false && true"), IsOkAndHolds(false));
  EXPECT_THAT(TruthTableTest("false && false"), IsOkAndHolds(false));
}

// The "// Wrong" case below places the error in the branch that the
// condition does not select.
TEST(Exercise3, Ternary) {
  EXPECT_THAT(TruthTableTest("(1 / 0 > 2) ? false : false"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  EXPECT_THAT(TruthTableTest("true ? (1 / 0 > 2) : false"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  // Wrong
  EXPECT_THAT(TruthTableTest("false ? (1 / 0 > 2) : false"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
}

TEST(Exercise3, BadFieldAccess) {
  AttributeContext context;

  // This type of error is normally caught by the type checker, to allow
  // it to surface here we use the dyn() operator to defer checking to runtime.
  // typo-ed field name from 'request.host'
  EXPECT_THAT(
      CompileAndEvaluateWithContext(
          "dyn(request).hostname == 'localhost' && true", context),
      StatusIs(absl::StatusCode::kNotFound, "no_such_field : hostname"));
  // Wrong
  EXPECT_THAT(
      CompileAndEvaluateWithContext(
          "dyn(request).hostname == 'localhost' && false", context),
      StatusIs(absl::StatusCode::kNotFound, "no_such_field : hostname"));
  // Wrong
  EXPECT_THAT(
      CompileAndEvaluateWithContext(
          "dyn(request).hostname == 'localhost' || true", context),
      StatusIs(absl::StatusCode::kNotFound, "no_such_field : hostname"));
  EXPECT_THAT(
      CompileAndEvaluateWithContext(
          "dyn(request).hostname == 'localhost' || false", context),
      StatusIs(absl::StatusCode::kNotFound, "no_such_field : hostname"));
}

}  // namespace
}  // namespace cel_codelab
================================================
FILE: codelab/exercise4.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "codelab/exercise4.h"

#include

#include "cel/expr/checked.pb.h"
#include "google/rpc/context/attribute_context.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "codelab/cel_compiler.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "eval/public/activation.h"
#include "eval/public/activation_bind_helper.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_expr_builder_factory.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_options.h"
#include "eval/public/cel_value.h"
#include "internal/status_macros.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel_codelab {
namespace {

using ::cel::expr::CheckedExpr;
using ::google::api::expr::runtime::Activation;
using ::google::api::expr::runtime::BindProtoToActivation;
using ::google::api::expr::runtime::CelError;
using ::google::api::expr::runtime::CelExpressionBuilder;
using ::google::api::expr::runtime::CelValue;
using ::google::api::expr::runtime::CreateCelExpressionBuilder;
using ::google::api::expr::runtime::InterpreterOptions;
using ::google::api::expr::runtime::RegisterBuiltinFunctions;
using ::google::rpc::context::AttributeContext;

// Builds a compiler with the standard CEL library and the fields of
// AttributeContext declared as variables.
absl::StatusOr> MakeConfiguredCompiler() {
  // Set up handling for protobuf types.
  // Using the generated descriptor pool is simpler to configure, but often
  // adds more types than necessary.
  google::protobuf::LinkMessageReflection();
  CEL_ASSIGN_OR_RETURN(
      std::unique_ptr builder,
      cel::NewCompilerBuilder(google::protobuf::DescriptorPool::generated_pool()));
  CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary()));
  // Adds fields of AttributeContext as variables.
  CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddContextDeclaration(
      AttributeContext::descriptor()->full_name()));

  // Codelab part 1:
  // Add a declaration for the map.contains(string, V) function.
  // Hint: use cel::MakeFunctionDecl and cel::TypeCheckerBuilder::MergeFunction.

  return builder->Build();
}

// Owns the expression builder, its options, and the arena used to plan and
// evaluate checked expressions against an AttributeContext.
//
// Evaluate() passes arena_ both to activation binding and to evaluation, so
// allocations made there live as long as the Evaluator.
class Evaluator {
 public:
  Evaluator()
      : builder_(CreateCelExpressionBuilder(
            google::protobuf::DescriptorPool::generated_pool(),
            google::protobuf::MessageFactory::generated_factory(), options_)) {}

  // Registers the functions available to evaluated expressions. Registration
  // uses the same options_ that configured builder_, so the registry and the
  // planner agree.
  absl::Status SetupEvaluatorEnvironment() {
    CEL_RETURN_IF_ERROR(
        RegisterBuiltinFunctions(builder_->GetRegistry(), options_));

    // Codelab part 2:
    // Register the map.contains(string, value) function.
    // Hint: use `CelFunctionAdapter::CreateAndRegister` to adapt from a free
    // function ContainsExtensionFunction.
    return absl::OkStatus();
  }

  // Binds the fields of `context` into a fresh activation, plans `expr`, and
  // evaluates it. Returns the bool result, the CEL error status if evaluation
  // produced an error value, or InvalidArgument for any other result type.
  absl::StatusOr Evaluate(const CheckedExpr& expr,
                                const AttributeContext& context) {
    Activation activation;
    CEL_RETURN_IF_ERROR(BindProtoToActivation(&context, &arena_, &activation));
    CEL_ASSIGN_OR_RETURN(auto plan, builder_->CreateExpression(&expr));
    CEL_ASSIGN_OR_RETURN(CelValue result, plan->Evaluate(activation, &arena_));

    if (bool value; result.GetValue(&value)) {
      return value;
    } else if (const CelError * value; result.GetValue(&value)) {
      return *value;
    } else {
      return absl::InvalidArgumentError(
          absl::StrCat("unexpected return type: ", result.DebugString()));
    }
  }

 private:
  google::protobuf::Arena arena_;
  // Must be declared before builder_: members are initialized in declaration
  // order, and builder_'s initializer reads options_.
  InterpreterOptions options_;
  std::unique_ptr builder_;
};

}  // namespace

absl::StatusOr EvaluateWithExtensionFunction(
    absl::string_view expr, const AttributeContext& context) {
  // Prepare a checked expression.
  CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler,
                       MakeConfiguredCompiler());
  CEL_ASSIGN_OR_RETURN(auto checked_expr,
                       CompileToCheckedExpr(*compiler, expr));

  // Prepare an evaluation environment.
  Evaluator evaluator;
  CEL_RETURN_IF_ERROR(evaluator.SetupEvaluatorEnvironment());

  // Evaluate a checked expression against a particular activation
  return evaluator.Evaluate(checked_expr, context);
}

}  // namespace cel_codelab
================================================
FILE: codelab/exercise4.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE4_H_
#define THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE4_H_

#include "google/rpc/context/attribute_context.pb.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"

namespace cel_codelab {

// Compile and evaluate an expression with google.rpc.context.AttributeContext
// as context.
// The environment includes the custom map member function
// .contains(string, string).
absl::StatusOr EvaluateWithExtensionFunction(
    absl::string_view cel_expr,
    const google::rpc::context::AttributeContext& context);

}  // namespace cel_codelab

#endif  // THIRD_PARTY_CEL_CPP_CODELAB_EXERCISE4_H_
================================================
FILE: codelab/exercise4_test.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "codelab/exercise4.h" #include "google/protobuf/struct.pb.h" #include "google/rpc/context/attribute_context.pb.h" #include "internal/testing.h" #include "google/protobuf/text_format.h" namespace cel_codelab { namespace { using ::absl_testing::IsOkAndHolds; using ::google::rpc::context::AttributeContext; TEST(EvaluateWithExtensionFunction, Baseline) { AttributeContext context; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"(request { path: "/" auth { claims { fields { key: "group" value {string_value: "admin"} } } } })", &context)); EXPECT_THAT(EvaluateWithExtensionFunction("request.path == '/'", context), IsOkAndHolds(true)); } TEST(EvaluateWithExtensionFunction, ContainsTrue) { AttributeContext context; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"(request { path: "/" auth { claims { fields { key: "group" value {string_value: "admin"} } } } })", &context)); EXPECT_THAT(EvaluateWithExtensionFunction( "request.auth.claims.contains('group', 'admin')", context), IsOkAndHolds(true)); } TEST(EvaluateWithExtensionFunction, ContainsFalse) { AttributeContext context; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"(request { path: "/" })", &context)); EXPECT_THAT(EvaluateWithExtensionFunction( "request.auth.claims.contains('group', 'admin')", context), IsOkAndHolds(false)); } } // namespace } // namespace cel_codelab ================================================ FILE: codelab/network_functions.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache 
License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "codelab/network_functions.h" #include #include #include #include #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "checker/type_checker_builder.h" #include "common/decl.h" #include "common/native_type.h" #include "common/type.h" #include "common/typeinfo.h" #include "common/value.h" #include "compiler/compiler.h" #include "internal/status_macros.h" #include "runtime/function_adapter.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" #include "runtime/type_registry.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel_codelab { namespace { // TODO(uncreated-issue/86): This is how internal extensions create types, but it isn't // a good pattern for client extensions (since they can't pool into one eternal // arena). 
google::protobuf::Arena* absl_nonnull BuiltinsArena() { static absl::NoDestructor arena; return arena.get(); } cel::Type AddressType() { static cel::Type kInstance( cel::OpaqueType(BuiltinsArena(), "net.Address", {})); return kInstance; } cel::Type TypeOfAddressType() { static cel::Type kInstance(cel::TypeType(BuiltinsArena(), AddressType())); return kInstance; } cel::Type AddressMatcherType() { static cel::Type kInstance( cel::OpaqueType(BuiltinsArena(), "net.AddressMatcher", {})); return kInstance; } cel::Type TypeOfAddressMatcherType() { static cel::Type kInstance( cel::TypeType(BuiltinsArena(), AddressMatcherType())); return kInstance; } absl::StatusOr ParseAddressImpl(absl::string_view str, uint32_t* ipv4_out, absl::Span ipv6_out) { if (str.size() < 2 || str.size() > 39) { return absl::InvalidArgumentError("unsupported address format (length)"); } if (absl::StrContains(str, ":")) { if (ipv6_out.size() < 16) { return absl::InternalError("invalid outbuffer in parse call"); } return absl::InvalidArgumentError("unsupported address format (ipv6)"); } uint32_t ipv4 = 0; int octet = 0; for (auto part : absl::StrSplit(str, '.')) { if (octet >= 4) { return absl::InvalidArgumentError( "unsupported address format (invalid ipv4)"); } int octet_val; if (!absl::SimpleAtoi(part, &octet_val) || octet_val > 255 || octet_val < 0) { return absl::InvalidArgumentError( "unsupported address format (invalid ipv4)"); } ipv4 <<= 8; ipv4 |= (uint32_t)octet_val; octet++; } if (octet != 4) { return absl::InvalidArgumentError( "unsupported address format (invalid ipv4)"); } *ipv4_out = ipv4; return IpVersion::kIPv4; } absl::Status ConfigureNetworkFunctions(cel::TypeCheckerBuilder& builder) { // Type identifiers CEL_RETURN_IF_ERROR(builder.AddVariable( MakeVariableDecl("net.Address", TypeOfAddressType()))); CEL_RETURN_IF_ERROR(builder.AddVariable( MakeVariableDecl("net.AddressMatcher", TypeOfAddressMatcherType()))); CEL_RETURN_IF_ERROR(builder.AddVariable( 
MakeVariableDecl("net.addressZeroValue", AddressType()))); // net.parseAddress(string) -> net.Address CEL_ASSIGN_OR_RETURN( auto decl, MakeFunctionDecl("net.parseAddress", MakeOverloadDecl("net_parseAddress_string", AddressType(), cel::StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(decl)); // net.parseAddressOrZero(string) -> net.Address CEL_ASSIGN_OR_RETURN( decl, MakeFunctionDecl("net.parseAddressOrZero", MakeOverloadDecl("net_parseAddressOrZero_string", AddressType(), cel::StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(decl)); // net.parseAddressMatcher(string) -> net.AddressMatcher CEL_ASSIGN_OR_RETURN( decl, MakeFunctionDecl( "net.parseAddressMatcher", MakeOverloadDecl("net_parseAddressMatcher_string", AddressMatcherType(), cel::StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(decl)); // (net.AddressMatcher).containsAddress(net.Address) -> bool CEL_ASSIGN_OR_RETURN( decl, MakeFunctionDecl( "containsAddress", MakeMemberOverloadDecl( "net_AddressMatcher_containsAddress_net_Address", cel::BoolType(), AddressMatcherType(), AddressType()), MakeMemberOverloadDecl( "net_AddressMatcher_containsAddress_string", cel::BoolType(), AddressMatcherType(), cel::StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(decl)); return absl::OkStatus(); } // ============================================================================= // Opaque Value type implementations for NetworkAddressRep. 
// ============================================================================= cel::NativeTypeId NetworkAddressRepGetTypeId( const cel::OpaqueValueDispatcher* dispatcher, cel::OpaqueValueContent content) { return cel::TypeId(); } google::protobuf::Arena* absl_nullable NetworkAddressRepGetArena( const cel::OpaqueValueDispatcher* absl_nonnull dispatcher, cel::OpaqueValueContent content) { return nullptr; } absl::string_view NetworkAddressRepGetTypeName( const cel::OpaqueValueDispatcher* absl_nonnull dispatcher, cel::OpaqueValueContent content) { return "net.Address"; } std::string NetworkAddressRepDebugString( const cel::OpaqueValueDispatcher* absl_nonnull dispatcher, cel::OpaqueValueContent content) { return absl::StrCat("net.parseAddress('", content.To().Format(), "')"); } cel::OpaqueType NetworkAddressRepGetRuntimeType( const cel::OpaqueValueDispatcher* absl_nonnull dispatcher, cel::OpaqueValueContent content) { return AddressType().GetOpaque(); } absl::Status NetworkAddressRepEqual( const cel::OpaqueValueDispatcher* absl_nonnull, cel::OpaqueValueContent content, const cel::OpaqueValue& other, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull, cel::Value* absl_nonnull result) { if (other.GetTypeId() != cel::TypeId()) { *result = cel::BoolValue(false); return absl::OkStatus(); } const NetworkAddressRep rep = content.To(); absl::optional other_rep = NetworkAddressRep::Unwrap(other); ABSL_DCHECK(other_rep.has_value()); *result = cel::BoolValue(rep.IsEqualTo(*other_rep)); return absl::OkStatus(); } cel::OpaqueValue NetworkAddressRepClone( const cel::OpaqueValueDispatcher* absl_nonnull, cel::OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) { const NetworkAddressRep* rep = content.To(); ABSL_DCHECK(rep != nullptr); return NetworkAddressRep::MakeValue(*rep).GetOpaque(); } // Opaque Value types can be implemented either with a shared dispatcher or // with a 
// subclass (using vtable dispatch).
//
// We use the shared dispatcher here since the address type has a compact
// representation and we don't need to support different implementations at
// runtime.
//
// If the data structure is more complex, benefits from runtime polymorphism, or
// doesn't have easily defined move, swap, and copy operations, it's
// recommended to use a subclass instead.
//
// The initializers below are positional. Their order must match the member
// order of cel::OpaqueValueDispatcher; the /*.Field=*/ comments name the
// member each slot is intended to fill.
static const cel::OpaqueValueDispatcher kAddressDispatcher{
    /*.GetTypeId=*/NetworkAddressRepGetTypeId,
    /*.GetArena=*/NetworkAddressRepGetArena,
    /*.GetTypeName=*/NetworkAddressRepGetTypeName,
    /*.DebugString=*/NetworkAddressRepDebugString,
    /*.GetRuntimeType=*/NetworkAddressRepGetRuntimeType,
    /*.Equal=*/NetworkAddressRepEqual,
    /*.Clone=*/NetworkAddressRepClone};

// =============================================================================
// Opaque Value type implementations for NetworkAddressMatcher.
// =============================================================================

// Implementation of the OpaqueValueInterface for NetworkAddressMatcher.
//
// This is simpler to implement, but adds an extra allocation and pointer
// indirection for every matcher. This is recommended if the data structure is
// more complex.
class NetworkAddressMatcherImpl : public cel::OpaqueValueInterface { public: explicit NetworkAddressMatcherImpl(NetworkAddressMatcher rep) : rep_(std::move(rep)) {} const NetworkAddressMatcher& rep() const { return rep_; } // implement the OpaqueValueInterface std::string DebugString() const final { return absl::StrCat("net.ParseAddressMatcher('", "TODO(uncreated-issue/86)", "')"); } absl::string_view GetTypeName() const final { return "net.AddressMatcher"; } cel::OpaqueType GetRuntimeType() const final { return AddressMatcherType().GetOpaque(); } absl::Status Equal(const cel::OpaqueValue& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, cel::Value* absl_nonnull result) const final { if (other.GetTypeId() != cel::TypeId()) { *result = cel::BoolValue(false); return absl::OkStatus(); } const NetworkAddressMatcherImpl* other_rep = static_cast(other.interface()); *result = cel::BoolValue(rep_.IsEqualTo(other_rep->rep_)); return absl::OkStatus(); } cel::OpaqueValue Clone(google::protobuf::Arena* absl_nonnull arena) const final { return NetworkAddressMatcher::MakeValue(arena, rep_).GetOpaque(); } cel::NativeTypeId GetNativeTypeId() const final { return cel::TypeId(); } private: NetworkAddressMatcher rep_; }; // ============================================================================= // Extension function implementations. 
// ============================================================================= cel::Value parseAddress( const cel::StringValue& str, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { std::string buf; absl::string_view addr = str.ToStringView(&buf); absl::optional rep = NetworkAddressRep::Parse(addr); if (!rep.has_value()) { return cel::ErrorValue(absl::InvalidArgumentError("invalid address")); } return NetworkAddressRep::MakeValue(*rep); } cel::Value parseAddressOrZero(const cel::StringValue& str) { std::string buf; absl::string_view addr = str.ToStringView(&buf); absl::optional rep = NetworkAddressRep::Parse(addr); static const NetworkAddressRep kZero; if (!rep.has_value()) { return NetworkAddressRep::MakeValue(kZero); } return NetworkAddressRep::MakeValue(*rep); } cel::Value parseAddressMatcher( const cel::StringValue& str, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { std::string buf; absl::string_view addr = str.ToStringView(&buf); absl::optional rep = NetworkAddressMatcher::Parse(addr); if (!rep.has_value()) { return cel::ErrorValue( absl::InvalidArgumentError("invalid address matcher")); } return NetworkAddressMatcher::MakeValue(arena, std::move(rep).value()); } cel::Value containsAddress(const cel::OpaqueValue& matcher, const cel::OpaqueValue& addr) { const auto* matcher_rep = NetworkAddressMatcher::Unwrap(matcher); auto addr_rep = NetworkAddressRep::Unwrap(addr); if (matcher_rep == nullptr || !addr_rep.has_value()) { // dispatcher should catch this, but right now only distiguishes at the // kind level. 
return cel::ErrorValue(absl::InvalidArgumentError("no matching overload")); } return cel::BoolValue(matcher_rep->Match(*addr_rep)); } } // namespace cel::Value NetworkAddressRep::MakeValue(const NetworkAddressRep& rep) { return UnsafeOpaqueValue(&kAddressDispatcher, cel::OpaqueValueContent::From(rep)); } absl::optional NetworkAddressRep::Unwrap( const cel::Value& value) { auto opaque = value.AsOpaque(); if (!opaque.has_value() || opaque->GetTypeId() != cel::TypeId()) { return absl::nullopt; } // Note: safety depends on: // 1) correctly implementing GetTypeId // 2) the TypeId is unique // 3) all calls to UnsafeOpaqueValue with the dispatcher provide the expected // content type. return opaque->content().To(); } absl::optional NetworkAddressRep::Parse( absl::string_view str) { uint32_t ipv4 = 0; char ipv6[16]; auto version = ParseAddressImpl(str, &ipv4, ipv6); if (!version.ok()) { return absl::nullopt; } if (*version != IpVersion::kIPv4) { return absl::nullopt; } NetworkAddressRep rep; rep.version_ = *version; rep.addr_.v4 = ipv4; return rep; } bool NetworkAddressRep::IsEqualTo(const NetworkAddressRep& other) const { if (version_ != other.version_) { return false; } if (version_ == IpVersion::kIPv4) { return addr_.v4 == other.addr_.v4; } return false; } bool NetworkAddressRep::IsLessThan(const NetworkAddressRep& other) const { if (version_ != other.version_) { return version_ < other.version_; } if (version_ == IpVersion::kIPv4) { return addr_.v4 < other.addr_.v4; } return false; } absl::optional NetworkAddressMatcher::Parse( absl::string_view str) { // range style addr-addr int dash_pos = str.find('-'); if (dash_pos == absl::string_view::npos) { // TODO(uncreated-issue/86): CIDR style addr/prefix-length return absl::nullopt; } absl::string_view min_str = str.substr(0, dash_pos); absl::string_view max_str = str.substr(dash_pos + 1); NetworkRangev4 v4; NetworkRangev6 v6; auto min_parse = ParseAddressImpl(min_str, &v4.min_incl, v6.min_incl); if (!min_parse.ok()) { 
return absl::nullopt; } auto max_parse = ParseAddressImpl(max_str, &v4.max_incl, v6.max_incl); if (!max_parse.ok()) { return absl::nullopt; } if (*min_parse != *max_parse) { return absl::nullopt; } NetworkAddressMatcher rep; if (*min_parse == IpVersion::kIPv4) { if (v4.min_incl > v4.max_incl) { return absl::nullopt; } rep.ranges_v4_.push_back(v4); } else if (*min_parse == IpVersion::kIPv6) { return absl::nullopt; } return rep; } cel::Value NetworkAddressMatcher::MakeValue(google::protobuf::Arena* arena, NetworkAddressMatcher rep) { auto* iface = google::protobuf::Arena::Create(arena, std::move(rep)); return cel::OpaqueValue(iface, arena); } const NetworkAddressMatcher* NetworkAddressMatcher::Unwrap( const cel::Value& value) { auto opaque = value.AsOpaque(); if (!opaque.has_value() || opaque->interface() == nullptr || opaque->GetTypeId() != cel::TypeId()) { return nullptr; } // Note: the safety of down casting like this depends on guaranteeing the // GetTypeId implementation is correct and is a unique ID. The CEL runtime // does not inspect or modify the interface type outside calling the interface // member functions. 
return &(static_cast(opaque->interface()) ->rep()); } bool NetworkAddressMatcher::Match(const NetworkAddressRep& addr) const { if (addr.IsZeroValue()) { return false; } if (addr.IsIPv4()) { for (const auto& range : ranges_v4_) { if (addr.GetIPv4() >= range.min_incl && addr.GetIPv4() <= range.max_incl) { return true; } } } // TODO(uncreated-issue/86): ipv6 support return false; } bool NetworkAddressMatcher::IsEqualTo( const NetworkAddressMatcher& other) const { if (ranges_v4_.size() != other.ranges_v4_.size()) { return false; } for (int i = 0; i < ranges_v4_.size(); ++i) { if (ranges_v4_[i].min_incl != other.ranges_v4_[i].min_incl || ranges_v4_[i].max_incl != other.ranges_v4_[i].max_incl) { return false; } } return true; } cel::CompilerLibrary NetworkFunctionsCompilerLibrary() { return cel::CompilerLibrary("cel_codelab.net", ConfigureNetworkFunctions); } absl::Status RegisterNetworkTypes(cel::TypeRegistry& registry, const cel::RuntimeOptions& options) { CEL_RETURN_IF_ERROR(registry.RegisterType(AddressType().GetOpaque())); CEL_RETURN_IF_ERROR(registry.RegisterType(AddressMatcherType().GetOpaque())); return absl::OkStatus(); } absl::Status RegisterNetworkFunctions(cel::FunctionRegistry& registry, const cel::RuntimeOptions& options) { // TODO(uncreated-issue/86): remaining functions auto s = cel::UnaryFunctionAdapter:: RegisterGlobalOverload("net.parseAddress", &parseAddress, registry); s.Update(cel::UnaryFunctionAdapter:: RegisterGlobalOverload("net.parseAddressOrZero", &parseAddressOrZero, registry)); s.Update(cel::UnaryFunctionAdapter:: RegisterGlobalOverload("net.parseAddressMatcher", &parseAddressMatcher, registry)); s.Update(cel::BinaryFunctionAdapter< cel::Value, const cel::OpaqueValue&, const cel::OpaqueValue&>::RegisterMemberOverload("containsAddress", &containsAddress, registry)); return s; } } // namespace cel_codelab ================================================ FILE: codelab/network_functions.h ================================================ // 
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Example extension library for introducing an OpaqueValue type.
//
// The address handling is simplified for the example, and IPv6 is
// unimplemented. Do not use this as-is.

#ifndef THIRD_PARTY_CEL_CPP_CODELAB_NETWORK_FUNCTIONS_H_
#define THIRD_PARTY_CEL_CPP_CODELAB_NETWORK_FUNCTIONS_H_

#include
#include
#include

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/value.h"
#include "compiler/compiler.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"
#include "runtime/type_registry.h"
#include "google/protobuf/arena.h"

namespace cel_codelab {

// IP version tag stored in a NetworkAddressRep. The enumerator values equal
// the protocol version numbers. kUnset (0) marks the zero (invalid) address.
enum class IpVersion : uint8_t {
  kUnset = 0,
  kIPv4 = 4,
  kIPv6 = 6,  // unimplemented, but present for illustration.
};

// Represents a network address. To simplify the CEL type representation, this
// only supports IPv4.
//
// The default-constructed value (version IpVersion::kUnset, address 0) is
// special. It represents an invalid address and compares unequal to anything
// except itself. For the purposes of ordering, it compares less than any valid
// address.
//
// The example extension functions include a version that returns a zero value
// on error and a version that returns a CEL error.
//
// This class is stored inline in the OpaqueValue because it is compact and
// trivially copyable.
class NetworkAddressRep {
 public:
  // Creates a Value that wraps the given NetworkAddressRep. The representation
  // is stored inline in the returned value; no arena is involved.
  static cel::Value MakeValue(const NetworkAddressRep& rep);

  // Unwraps a Value into a NetworkAddressRep. Returns nullopt if the value is
  // not a NetworkAddress.
  static absl::optional Unwrap(const cel::Value& value);

  // Parses a string representation of a network address. Returns nullopt if
  // the string is not a valid network address.
  //
  // TODO(uncreated-issue/86): error handling simplified for example, real usage should
  // provide some diagnostic for the parse failure.
  static absl::optional Parse(absl::string_view str);

  // Zero value for an invalid address.
  NetworkAddressRep() : addr_({0}), version_(IpVersion::kUnset) {}

  NetworkAddressRep(const NetworkAddressRep& other) = default;
  NetworkAddressRep(NetworkAddressRep&& other) = default;
  NetworkAddressRep& operator=(const NetworkAddressRep& other) = default;
  NetworkAddressRep& operator=(NetworkAddressRep&& other) = default;

  IpVersion version() const { return version_; }

  // True for the default-constructed (invalid) address.
  bool IsZeroValue() const { return version_ == IpVersion::kUnset; }
  bool IsIPv4() const { return version_ == IpVersion::kIPv4; }
  // Always false: IPv6 is not implemented.
  bool IsIPv6() const { return false; }

  // The IPv4 address as an integer, or nullopt if this is not an IPv4 address.
  absl::optional TryGetIPv4() const {
    if (version_ == IpVersion::kIPv4) {
      return addr_.v4;
    }
    return absl::nullopt;
  }

  // Placeholder: always returns an empty view (IPv6 is not implemented).
  absl::string_view TryGetIPv6() const { return absl::string_view(); }

  // Human-readable form: dotted quad for IPv4, "null" for the zero value.
  std::string Format() const {
    if (version_ == IpVersion::kUnset) {
      return "null";
    }
    if (version_ == IpVersion::kIPv4) {
      // Most significant byte first.
      return absl::StrCat(
          (addr_.v4 & 0xFF000000) >> 24, ".", (addr_.v4 & 0x00FF0000) >> 16,
          ".", (addr_.v4 & 0x0000FF00) >> 8, ".", (addr_.v4 & 0x000000FF));
    }
    return "v6 not yet implemented";
  }

  // Raw IPv4 integer; only meaningful when IsIPv4() is true.
  uint32_t GetIPv4() const { return addr_.v4; }

  bool IsEqualTo(const NetworkAddressRep& other) const;
  bool IsLessThan(const NetworkAddressRep& other) const;

 private:
  union {
    uint32_t v0;  // zero value
    // Integer representation of an IPv4 address (system byte order)
    uint32_t v4;
    // TO_DO : add ipv6. this prevents storing the value inline due to size, so
    // skipped here.
  } addr_;
  IpVersion version_;
};

// Represents a matcher for network addresses.
//
// Simple implementation that just stores a list of matching ranges.
//
// This is too big to store inline and has non-trivial copy and move behavior,
// so the inline representation is a pointer to an arena-allocated object.
class NetworkAddressMatcher {
 public:
  // Creates a Value that wraps the given NetworkAddressMatcher. `rep` is moved
  // into a wrapper object allocated on `arena`.
  static cel::Value MakeValue(google::protobuf::Arena* arena,
                              NetworkAddressMatcher rep);

  // Unwraps a Value into a NetworkAddressMatcher. Returns nullptr if the value
  // is not a NetworkAddressMatcher.
  static const NetworkAddressMatcher* Unwrap(const cel::Value& value);

  // Parses a string representation of a network address matcher. Returns
  // nullopt if the string is not a valid network address matcher.
  //
  // TODO(uncreated-issue/86): supports a simple IPv4 range for illustration: e.g.
  // 8.8.0.0-8.8.255.255
  static absl::optional Parse(absl::string_view str);

  // Default value for an empty matcher. Matches nothing.
  NetworkAddressMatcher() = default;

  NetworkAddressMatcher(const NetworkAddressMatcher& other) = default;
  NetworkAddressMatcher(NetworkAddressMatcher&& other) = default;
  NetworkAddressMatcher& operator=(const NetworkAddressMatcher& other) =
      default;
  NetworkAddressMatcher& operator=(NetworkAddressMatcher&& other) = default;

  bool IsEmpty() const { return ranges_v4_.empty(); }

  bool IsEqualTo(const NetworkAddressMatcher& other) const;

  // True if `addr` falls inside any stored (inclusive) range.
  bool Match(const NetworkAddressRep& addr) const;

 private:
  // Inclusive range [min_incl, max_incl] of IPv4 addresses as integers.
  struct NetworkRangev4 {
    uint32_t min_incl;
    uint32_t max_incl;
  };

  // placeholder for illustration, not implemented.
  struct NetworkRangev6 {
    char min_incl[16];
    char max_incl[16];
  };

  friend void swap(NetworkAddressMatcher& lhs, NetworkAddressMatcher& rhs) {
    using std::swap;
    swap(lhs.ranges_v4_, rhs.ranges_v4_);
  }

  // Intended to hold sorted, non-overlapping ranges of matching IPv4
  // addresses. NOTE(review): nothing here enforces that invariant; Parse()
  // currently produces at most one range, so it holds trivially.
  std::vector ranges_v4_;
};

// Returns a compiler library that adds the network functions to the type
// checker.
cel::CompilerLibrary NetworkFunctionsCompilerLibrary();

// Registers the network functions in a runtime for evaluation.
absl::Status RegisterNetworkFunctions(cel::FunctionRegistry& registry,
                                      const cel::RuntimeOptions& options);

// Registers the network types in a runtime for evaluation. This is needed
// for resolving the type name to a runtime type `net.Address != type('foo')`.
absl::Status RegisterNetworkTypes(cel::TypeRegistry& registry,
                                  const cel::RuntimeOptions& options);

}  // namespace cel_codelab

#endif  // THIRD_PARTY_CEL_CPP_CODELAB_NETWORK_FUNCTIONS_H_
================================================
FILE: codelab/network_functions_test.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "codelab/network_functions.h"

#include
#include
#include

#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "common/decl.h"
#include "common/minimal_descriptor_pool.h"
#include "common/type.h"
#include "common/value.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "internal/benchmark.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "runtime/activation.h"
#include "runtime/constant_folding.h"
#include "runtime/runtime.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "google/protobuf/arena.h"

namespace cel_codelab {
namespace {

using ::absl_testing::IsOk;
using ::cel::Activation;
using ::cel::Compiler;
using ::cel::Program;
using ::cel::Runtime;
using ::cel::RuntimeOptions;
using ::cel::StringValue;
using ::testing::HasSubstr;

// A type-checking case. `expr` must check cleanly unless
// `type_check_err_substr` is set; in that case the formatted checker error
// must contain it.
struct TestCase {
  std::string name;
  std::string expr;
  std::string type_check_err_substr;
};

class NetworkFunctionsCheckerTest : public testing::TestWithParam {};

TEST_P(NetworkFunctionsCheckerTest, DeclarationsTest) {
  const TestCase& test_case = GetParam();
  ASSERT_OK_AND_ASSIGN(
      auto compiler_builder,
      cel::NewCompilerBuilder(cel::GetMinimalDescriptorPool()));
  ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()),
              IsOk());
  ASSERT_THAT(compiler_builder->AddLibrary(NetworkFunctionsCompilerLibrary()),
              IsOk());
  ASSERT_OK_AND_ASSIGN(auto compiler, compiler_builder->Build());

  ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expr));

  if (!test_case.type_check_err_substr.empty()) {
    EXPECT_THAT(result.FormatError(),
                HasSubstr(test_case.type_check_err_substr));
    return;
  }

  EXPECT_TRUE(result.IsValid()) << result.FormatError();
}

INSTANTIATE_TEST_SUITE_P(
    NetworkFunctionsCheckerTests, NetworkFunctionsCheckerTest,
    testing::ValuesIn({
        {"type_identifier_addr", "net.Address != type(1)"},
        {"type_identifier_addr_2", "net.Address != list"},
        {"type_identifier_addr_matcher", "net.AddressMatcher != type(1)"},
        {"parse_address", "net.parseAddress('1.2.3.4')"},
        {"parse_address_or_zero", "net.parseAddressOrZero('1.2.3.4')"},
        {"parse_address_no_match", "net.parseAddress(1.0)",
         "no matching overload for 'net.parseAddress'"},
        {"address_zero", "net.addressZeroValue"},
        {"equals", "net.parseAddress('1.2.3.4') != net.addressZeroValue"},
        {"address_matcher_parse",
         "net.parseAddressMatcher('8.8.8.0-8.8.8.255')"},
        {"address_matcher_parse_invalid",
         "net.parseAddressMatcher('8.8.8.0-8.8.4.255')"},
        {"address_matcher_contains",
         "net.parseAddressMatcher('8.8.8.0-8.8.8.255').containsAddress(net."
         "parseAddress('8.8.8.1'))"},
        {"address_matcher_contains_string",
         "net.parseAddressMatcher('8.8.8.0-8.8.8.255').containsAddress('8.8.8."
         "1')"},
    }),
    [](const testing::TestParamInfo& info) {
      return info.param.name;
    });

// An evaluation case. If `runtime_err_substr` is set, evaluation must yield a
// CEL error whose message contains it. Otherwise the result must be a bool
// equal to `expected_value`.
struct RuntimeTestCase {
  std::string name;
  std::string expr;
  std::string runtime_err_substr;
  bool expected_value = true;
};

class NetworkFunctionsRuntimeTest
    : public testing::TestWithParam {};

TEST_P(NetworkFunctionsRuntimeTest, EvaluationTest) {
  const RuntimeTestCase& test_case = GetParam();
  ASSERT_OK_AND_ASSIGN(
      auto compiler_builder,
      cel::NewCompilerBuilder(cel::GetMinimalDescriptorPool()));
  ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()),
              IsOk());
  ASSERT_THAT(compiler_builder->AddLibrary(NetworkFunctionsCompilerLibrary()),
              IsOk());
  ASSERT_OK_AND_ASSIGN(auto compiler, compiler_builder->Build());

  ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expr));
  ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst());

  RuntimeOptions runtime_options;
  runtime_options.enable_qualified_type_identifiers = true;
  ASSERT_OK_AND_ASSIGN(auto runtime_builder,
                       CreateStandardRuntimeBuilder(
                           cel::GetMinimalDescriptorPool(), runtime_options));
  ASSERT_THAT(
      RegisterNetworkTypes(runtime_builder.type_registry(), runtime_options),
      IsOk());
  ASSERT_THAT(RegisterNetworkFunctions(runtime_builder.function_registry(),
                                       runtime_options),
              IsOk());
  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build());
  ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast)));

  google::protobuf::Arena arena;
  Activation activation;
  ASSERT_OK_AND_ASSIGN(auto eval_result, program->Evaluate(&arena, activation));

  if (!test_case.runtime_err_substr.empty()) {
    ASSERT_TRUE(eval_result.IsError())
        << "Expected error, but got: " << eval_result.DebugString();
    EXPECT_THAT(eval_result.GetError().ToStatus().message(),
                HasSubstr(test_case.runtime_err_substr));
    return;
  }

  // Compare against expected_value directly, so that cases expecting `false`
  // are verified too rather than silently skipped.
  ASSERT_TRUE(eval_result.IsBool()) << eval_result.DebugString();
  EXPECT_EQ(static_cast<bool>(eval_result.GetBool()), test_case.expected_value)
      << eval_result.DebugString();
}

INSTANTIATE_TEST_SUITE_P(
    NetworkFunctionsRuntimeTests, NetworkFunctionsRuntimeTest,
    testing::ValuesIn(
        {{"type_identifier_addr", "net.Address != type(1)"},
         {"type_identifier_addr_2", "net.Address != list"},
         {"type_identifier_addr_matcher", "net.AddressMatcher != type(1)"},
         {"parse_address",
          "net.parseAddress('1.2.3.4') == net.parseAddress('1.2.3.4')"},
         {"parse_address_2",
          "net.parseAddress('1.2.3.4') != net.parseAddress('2.3.4.5')"},
         {"parse_address_invalid",
          "net.parseAddress('256.2.3.4') != net.parseAddress('1.2.3.4')",
          "invalid address"},
         {"parse_address_or_zero",
          "net.parseAddressOrZero('256.2.3.4') != "
          "net.parseAddressOrZero('1.2.3.4')"},
         {"parse_address_matcher",
          "net.parseAddressMatcher('8.8.8.0-8.8.8.255') != "
          "net.parseAddressMatcher('8.8.8.0-8.8.8.127')"},
         {"address_matcher_matches",
          "net.parseAddressMatcher('8.8.8.0-8.8.8.255').containsAddress(net."
          "parseAddress('8.8.8.1'))"}}),
    [](const testing::TestParamInfo& info) {
      return info.param.name;
    });

// Owns a compiler and runtime configured with the network extension, so
// benchmarks can build programs from expression text.
class BenchmarkState {
 public:
  // `optimize` enables constant folding in the runtime.
  static absl::StatusOr Create(bool optimize) {
    CEL_ASSIGN_OR_RETURN(
        auto compiler_builder,
        cel::NewCompilerBuilder(cel::GetMinimalDescriptorPool()));
    CEL_RETURN_IF_ERROR(
        compiler_builder->AddLibrary(cel::StandardCompilerLibrary()));
    CEL_RETURN_IF_ERROR(
        compiler_builder->AddLibrary(NetworkFunctionsCompilerLibrary()));
    compiler_builder->GetCheckerBuilder()
        .AddVariable(MakeVariableDecl("ip", cel::StringType()))
        .IgnoreError();
    CEL_ASSIGN_OR_RETURN(auto compiler, compiler_builder->Build());

    RuntimeOptions runtime_options;
    CEL_ASSIGN_OR_RETURN(auto runtime_builder,
                         CreateStandardRuntimeBuilder(
                             cel::GetMinimalDescriptorPool(), runtime_options));
    CEL_RETURN_IF_ERROR(
        RegisterNetworkTypes(runtime_builder.type_registry(), runtime_options));
    CEL_RETURN_IF_ERROR(RegisterNetworkFunctions(
        runtime_builder.function_registry(), runtime_options));
    if (optimize) {
      CEL_RETURN_IF_ERROR(
          cel::extensions::EnableConstantFolding(runtime_builder));
    }
    CEL_ASSIGN_OR_RETURN(auto runtime, std::move(runtime_builder).Build());

    return BenchmarkState(std::move(compiler), std::move(runtime));
  }

  // Compiles and plans `expr`; a type-check failure becomes
  // InvalidArgumentError.
  absl::StatusOr> MakeProgram(absl::string_view expr) {
    CEL_ASSIGN_OR_RETURN(auto result, compiler_->Compile(expr));
    if (!result.IsValid()) {
      return absl::InvalidArgumentError(result.FormatError());
    }
    CEL_ASSIGN_OR_RETURN(auto ast, result.ReleaseAst());
    return runtime_->CreateProgram(std::move(ast));
  }

 private:
  BenchmarkState(std::unique_ptr c, std::unique_ptr r)
      : compiler_(std::move(c)), runtime_(std::move(r)) {}

  std::unique_ptr compiler_;
  std::unique_ptr runtime_;
};

void BM_ParseAddress(benchmark::State& state) {
  const bool optimize = state.range(0) != 0;
  auto runner = BenchmarkState::Create(optimize);
  ABSL_CHECK_OK(runner.status());
  auto program = runner->MakeProgram("net.parseAddress('1.2.3.4')");
  ABSL_CHECK_OK(program.status());

  google::protobuf::Arena arena;
  Activation activation;
  for (auto _ : state) {
    auto result = (*program)->Evaluate(&arena, activation);
    ABSL_CHECK_OK(result.status());
  }
}

void BM_ParseAddressVar(benchmark::State& state) {
  const bool optimize = state.range(0) != 0;
  auto runner = BenchmarkState::Create(optimize);
  ABSL_CHECK_OK(runner.status());
  auto program = runner->MakeProgram("net.parseAddress(ip)");
  ABSL_CHECK_OK(program.status());

  google::protobuf::Arena arena;
  Activation activation;
  activation.InsertOrAssignValue("ip", StringValue::From("8.8.8.8", &arena));
  for (auto _ : state) {
    auto result = (*program)->Evaluate(&arena, activation);
    ABSL_CHECK_OK(result.status());
  }
}

void BM_ParseAddressMatcher(benchmark::State& state) {
  const bool optimize = state.range(0) != 0;
  auto runner = BenchmarkState::Create(optimize);
  ABSL_CHECK_OK(runner.status());
  auto program =
      runner->MakeProgram("net.parseAddressMatcher('8.8.8.0-8.8.8.255')");
  ABSL_CHECK_OK(program.status());

  google::protobuf::Arena arena;
  Activation activation;
  for (auto _ : state) {
    auto result = (*program)->Evaluate(&arena, activation);
    ABSL_CHECK_OK(result.status());
  }
}

void BM_ParseAddressMatcherMatches(benchmark::State& state) {
  const bool optimize = state.range(0) != 0;
  auto runner = BenchmarkState::Create(optimize);
  ABSL_CHECK_OK(runner.status());
  auto program = runner->MakeProgram(
      "net.parseAddressMatcher('8.8.8.0-8.8.8.255').containsAddress(net."
      "parseAddress('8.8.8.1'))");
  ABSL_CHECK_OK(program.status());

  google::protobuf::Arena arena;
  Activation activation;
  for (auto _ : state) {
    auto result = (*program)->Evaluate(&arena, activation);
    ABSL_CHECK_OK(result.status());
  }
}

void BM_ParseAddressMatcherMatchesVar(benchmark::State& state) {
  const bool optimize = state.range(0) != 0;
  auto runner = BenchmarkState::Create(optimize);
  ABSL_CHECK_OK(runner.status());
  auto program = runner->MakeProgram(
      "net.parseAddressMatcher('8.8.0.0-8.8.255.255').containsAddress(net."
      "parseAddress(ip))");
  ABSL_CHECK_OK(program.status());

  google::protobuf::Arena arena;
  Activation activation;
  activation.InsertOrAssignValue("ip", StringValue::From("8.8.4.4", &arena));
  for (auto _ : state) {
    auto result = (*program)->Evaluate(&arena, activation);
    ABSL_CHECK_OK(result.status());
  }
}

BENCHMARK(BM_ParseAddress)->Arg(0)->Arg(1);
BENCHMARK(BM_ParseAddressVar)->Arg(0)->Arg(1);
BENCHMARK(BM_ParseAddressMatcher)->Arg(0)->Arg(1);
BENCHMARK(BM_ParseAddressMatcherMatches)->Arg(0)->Arg(1);
BENCHMARK(BM_ParseAddressMatcherMatchesVar)->Arg(0)->Arg(1);

}  // namespace
}  // namespace cel_codelab
================================================
FILE: codelab/solutions/BUILD
================================================
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

package(default_visibility = ["//visibility:public"])

licenses(["notice"])

# Reference solution for exercise 1: parse (no type checking) and evaluate
# with the eval/public API. The header is declared in //codelab.
cc_library(
    name = "exercise1",
    srcs = ["exercise1.cc"],
    hdrs = ["//codelab:exercise1.h"],
    deps = [
        "//eval/public:activation",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "//internal:status_macros",
        "//parser",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

# Builds the //codelab exercise 1 test source against this solution.
cc_test(
    name = "exercise1_test",
    srcs = ["//codelab:exercise1_test.cc"],
    deps = [
        ":exercise1",
        "//internal:testing",
        "@com_google_absl//absl/status",
    ],
)

# Reference solution for exercise 2: compile (with the type checker) and
# evaluate with the eval/public API.
cc_library(
    name = "exercise2",
    srcs = ["exercise2.cc"],
    hdrs = ["//codelab:exercise2.h"],
    deps = [
        "//checker:type_checker_builder",
        "//codelab:cel_compiler",
        "//common:decl",
        "//common:type",
        "//compiler",
        "//compiler:compiler_factory",
        "//compiler:standard_library",
        "//eval/public:activation",
        "//eval/public:activation_bind_helper",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "//internal:status_macros",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

# Builds the //codelab exercise 2 test source against this solution.
cc_test(
    name = "exercise2_test",
    srcs = ["//codelab:exercise2_test.cc"],
    deps = [
        ":exercise2",
        "//internal:testing",
        "@com_google_absl//absl/status",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

# Exercise 3 has no library of its own; its test (kept in this package)
# drives the exercise 2 solution.
cc_test(
    name = "exercise3_test",
    srcs = ["exercise3_test.cc"],
    deps = [
        ":exercise2",
        "//internal:testing",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
    ],
)

# Reference solution for exercise 4: declare and implement a custom
# (receiver-style) extension function.
cc_library(
    name = "exercise4",
    srcs = ["exercise4.cc"],
    hdrs = ["//codelab:exercise4.h"],
    deps = [
        "//codelab:cel_compiler",
        "//common:decl",
        "//common:type",
        "//compiler",
        "//compiler:compiler_factory",
        "//compiler:standard_library",
        "//eval/public:activation",
        "//eval/public:activation_bind_helper",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_function_adapter",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "//internal:status_macros",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

# Builds the //codelab exercise 4 test source against this solution.
cc_test(
    name = "exercise4_test",
    srcs = ["//codelab:exercise4_test.cc"],
    deps = [
        ":exercise4",
        "//internal:testing",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
    ],
)

# Reference solution for exercise 10: network extension functions used with
# the cel::Compiler / cel::Runtime APIs.
cc_library(
    name = "exercise10",
    srcs = ["exercise10.cc"],
    hdrs = ["//codelab:exercise10.h"],
    deps = [
        "//checker:validation_result",
        "//codelab:network_functions",
        "//common:decl",
        "//common:minimal_descriptor_pool",
        "//common:type",
        "//common:value",
        "//compiler",
        "//compiler:compiler_factory",
        "//compiler:standard_library",
        "//runtime",
        "//runtime:activation",
        "//runtime:runtime_builder",
        "//runtime:runtime_options",
        "//runtime:standard_runtime_builder_factory",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_protobuf//:protobuf",
    ],
)

# Builds the //codelab exercise 10 test source against this solution.
cc_test(
    name = "exercise10_test",
    srcs = ["//codelab:exercise10_test.cc"],
    deps = [
        ":exercise10",
        "//internal:testing",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_protobuf//:protobuf",
    ],
)

================================================
FILE: codelab/solutions/exercise1.cc
================================================
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "codelab/exercise1.h"

#include
#include

#include "cel/expr/syntax.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "eval/public/activation.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_expr_builder_factory.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_options.h"
#include "eval/public/cel_value.h"
#include "internal/status_macros.h"
#include "parser/parser.h"
#include "google/protobuf/arena.h"

// NOTE(review): two #include directives above are empty and template argument
// lists (after absl::StatusOr / std::unique_ptr) are missing in this extract;
// they look stripped rather than intentional -- confirm against upstream.

namespace cel_codelab {
namespace {

using ::cel::expr::ParsedExpr;
using ::google::api::expr::parser::Parse;
using ::google::api::expr::runtime::Activation;
using ::google::api::expr::runtime::CelExpression;
using ::google::api::expr::runtime::CelExpressionBuilder;
using ::google::api::expr::runtime::CelValue;
using ::google::api::expr::runtime::CreateCelExpressionBuilder;
using ::google::api::expr::runtime::InterpreterOptions;
using ::google::api::expr::runtime::RegisterBuiltinFunctions;

// Convert the CelValue result to a C++ string if it is string typed.
// Otherwise, return an invalid argument error naming the actual type. This
// takes a copy to avoid lifecycle concerns (the evaluator may represent
// strings as string views backed by the input expression).
absl::StatusOr ConvertResult(const CelValue& value) {
  if (CelValue::StringHolder inner_value; value.GetValue(&inner_value)) {
    return std::string(inner_value.value());
  } else {
    return absl::InvalidArgumentError(absl::StrCat(
        "expected string result got '", CelValue::TypeName(value.type()),
        "'"));
  }
}

}  // namespace

// Parses `cel_expr` (without type checking), plans it with the builtin
// functions registered, evaluates it against an empty activation, and returns
// the string result. Errors from parsing, planning, or evaluation are
// propagated; a non-string result yields InvalidArgument.
absl::StatusOr ParseAndEvaluate(absl::string_view cel_expr) {
  // === Start Codelab ===
  // Setup a default environment for building expressions.
  InterpreterOptions options;
  std::unique_ptr builder =
      CreateCelExpressionBuilder(options);
  CEL_RETURN_IF_ERROR(
      RegisterBuiltinFunctions(builder->GetRegistry(), options));

  // Parse the expression. This is fine for codelabs, but this skips the type
  // checking phase. It won't check that functions and variables are available
  // in the environment, and it won't handle certain ambiguous identifier
  // expressions (e.g. container lookup vs namespaced name, packaged function
  // vs. receiver call style function).
  ParsedExpr parsed_expr;
  CEL_ASSIGN_OR_RETURN(parsed_expr, Parse(cel_expr));

  // The evaluator uses a proto Arena for incidental allocations during
  // evaluation.
  google::protobuf::Arena arena;
  // The activation provides variables and functions that are bound into the
  // expression environment. In this example, there's no context expected, so
  // we just provide an empty one to the evaluator.
  Activation activation;

  // Build the expression plan. This assumes that the source expression AST and
  // the expression builder outlive the CelExpression object.
  CEL_ASSIGN_OR_RETURN(std::unique_ptr expression_plan,
                       builder->CreateExpression(&parsed_expr.expr(),
                                                 &parsed_expr.source_info()));

  // Actually run the expression plan. We don't support any environment
  // variables at the moment so just use an empty activation.
  CEL_ASSIGN_OR_RETURN(CelValue result,
                       expression_plan->Evaluate(activation, &arena));

  // Convert the result to a c++ string. CelValues may reference instances from
  // either the input expression, or objects allocated on the arena, so we need
  // to pass ownership (in this case by copying to a new instance and returning
  // that).
  return ConvertResult(result);
  // === End Codelab ===
}

}  // namespace cel_codelab

================================================
FILE: codelab/solutions/exercise10.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "codelab/exercise10.h" #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "checker/validation_result.h" #include "codelab/network_functions.h" #include "common/decl.h" #include "common/minimal_descriptor_pool.h" #include "common/type.h" #include "common/value.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" #include "compiler/standard_library.h" #include "runtime/activation.h" #include "runtime/runtime.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "google/protobuf/arena.h" namespace cel_codelab { namespace { absl::StatusOr> ConfigureCompiler() { absl::StatusOr> compiler_builder = cel::NewCompilerBuilder(cel::GetMinimalDescriptorPool()); if (!compiler_builder.ok()) { return std::move(compiler_builder).status(); } absl::Status s = (*compiler_builder)->AddLibrary(cel::StandardCompilerLibrary()); // =========================================================================== // Codelab: Update compiler builder with functions from network_functions.h // and add a varible for the input IP. 
// =========================================================================== s.Update((*compiler_builder)->AddLibrary(NetworkFunctionsCompilerLibrary())); s.Update((*compiler_builder) ->GetCheckerBuilder() .AddVariable(cel::MakeVariableDecl("ip", cel::StringType()))); if (!s.ok()) return s; return (*compiler_builder)->Build(); } absl::StatusOr> ConfigureRuntime() { cel::RuntimeOptions runtime_options; // Note: this is needed to resolve net.Address as a `type` constant. runtime_options.enable_qualified_type_identifiers = true; absl::StatusOr runtime_builder = cel::CreateStandardRuntimeBuilder(cel::GetMinimalDescriptorPool(), runtime_options); // =========================================================================== // Codelab: Update runtime builder with functions from network_functions.h // =========================================================================== absl::Status s = RegisterNetworkTypes(runtime_builder->type_registry(), runtime_options); s.Update(RegisterNetworkFunctions(runtime_builder->function_registry(), runtime_options)); if (!s.ok()) return s; return std::move(runtime_builder).value().Build(); } } // namespace absl::StatusOr CompileAndEvaluateExercise10(absl::string_view expression, absl::string_view ip) { absl::StatusOr> compiler = ConfigureCompiler(); if (!compiler.ok()) { return std::move(compiler).status(); } absl::StatusOr> runtime = ConfigureRuntime(); if (!runtime.ok()) { return std::move(runtime).status(); } absl::StatusOr checked = (*compiler)->Compile(expression); if (!checked.ok()) { return std::move(checked).status(); } if (!checked->IsValid() || checked->GetAst() == nullptr) { return absl::InvalidArgumentError(checked->FormatError()); } absl::StatusOr> program = (*runtime)->CreateProgram(checked->ReleaseAst().value()); if (!program.ok()) { return std::move(program).status(); } cel::Activation activation; google::protobuf::Arena arena; activation.InsertOrAssignValue("ip", cel::StringValue::From(ip, &arena)); absl::StatusOr 
result = (*program)->Evaluate(&arena, activation); if (!result.ok()) { return std::move(result).status(); } if (result->IsBool()) { return result->GetBool(); } if (result->IsError()) { return result->GetError().ToStatus(); } return absl::InvalidArgumentError( absl::StrCat("unexpected result type: ", result->DebugString())); } } // namespace cel_codelab ================================================ FILE: codelab/solutions/exercise2.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "codelab/exercise2.h"

#include

#include "cel/expr/checked.pb.h"
#include "cel/expr/syntax.pb.h"
#include "google/rpc/context/attribute_context.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "checker/type_checker_builder.h"
#include "codelab/cel_compiler.h"
#include "common/decl.h"
#include "common/type.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "eval/public/activation.h"
#include "eval/public/activation_bind_helper.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_expr_builder_factory.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_options.h"
#include "eval/public/cel_value.h"
#include "internal/status_macros.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

// NOTE(review): one #include directive above is empty, and template argument
// lists (after absl::StatusOr / std::unique_ptr and on LinkMessageReflection)
// are missing in this extract; they look stripped rather than intentional --
// confirm against upstream.

namespace cel_codelab {
namespace {

using ::cel::expr::CheckedExpr;
using ::google::api::expr::runtime::Activation;
using ::google::api::expr::runtime::CelError;
using ::google::api::expr::runtime::CelExpression;
using ::google::api::expr::runtime::CelExpressionBuilder;
using ::google::api::expr::runtime::CelValue;
using ::google::api::expr::runtime::CreateCelExpressionBuilder;
using ::google::api::expr::runtime::InterpreterOptions;
using ::google::api::expr::runtime::ProtoUnsetFieldOptions;
using ::google::api::expr::runtime::RegisterBuiltinFunctions;
using ::google::rpc::context::AttributeContext;

// Builds a compiler over the generated descriptor pool with the standard
// library, a `bool_var` variable of type bool, and the AttributeContext
// message registered via AddContextDeclaration.
absl::StatusOr> MakeCelCompiler() {
  // Note: we are using the generated descriptor pool here for simplicity, but
  // it has the drawback of including all message types that are linked into the
  // binary instead of just the ones expected for the CEL environment.
  google::protobuf::LinkMessageReflection();
  CEL_ASSIGN_OR_RETURN(
      std::unique_ptr builder,
      cel::NewCompilerBuilder(google::protobuf::DescriptorPool::generated_pool()));
  CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary()));
  // === Start Codelab ===
  cel::TypeCheckerBuilder& checker_builder = builder->GetCheckerBuilder();
  CEL_RETURN_IF_ERROR(checker_builder.AddVariable(
      cel::MakeVariableDecl("bool_var", cel::BoolType())));
  CEL_RETURN_IF_ERROR(checker_builder.AddContextDeclaration(
      AttributeContext::descriptor()->full_name()));
  // === End Codelab ===
  return builder->Build();
}

// Plan an already type-checked CEL expression and evaluate it against the
// given activation and arena. Returns the bool result, a CelError result as
// its (error) status, or InvalidArgument for any other result type.
absl::StatusOr EvalCheckedExpr(const CheckedExpr& checked_expr,
                                     const Activation& activation,
                                     google::protobuf::Arena* arena) {
  // Setup a default environment for building expressions.
  InterpreterOptions options;
  std::unique_ptr builder = CreateCelExpressionBuilder(
      google::protobuf::DescriptorPool::generated_pool(),
      google::protobuf::MessageFactory::generated_factory(), options);
  CEL_RETURN_IF_ERROR(
      RegisterBuiltinFunctions(builder->GetRegistry(), options));

  // Note, the expression_plan below is reusable for different inputs, but we
  // create one just in time for evaluation here.
  CEL_ASSIGN_OR_RETURN(std::unique_ptr expression_plan,
                       builder->CreateExpression(&checked_expr));

  CEL_ASSIGN_OR_RETURN(CelValue result,
                       expression_plan->Evaluate(activation, arena));

  if (bool value; result.GetValue(&value)) {
    return value;
  } else if (const CelError * value; result.GetValue(&value)) {
    return *value;
  } else {
    return absl::InvalidArgumentError(absl::StrCat(
        "expected 'bool' result got '", result.DebugString(), "'"));
  }
}

}  // namespace

// Compiles `cel_expr` and evaluates it with `bool_var` bound to `bool_var`.
absl::StatusOr CompileAndEvaluateWithBoolVar(absl::string_view cel_expr,
                                                   bool bool_var) {
  CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler,
                       MakeCelCompiler());
  CEL_ASSIGN_OR_RETURN(CheckedExpr checked_expr,
                       CompileToCheckedExpr(*compiler, cel_expr));

  Activation activation;
  google::protobuf::Arena arena;
  // === Start Codelab ===
  activation.InsertValue("bool_var", CelValue::CreateBool(bool_var));
  // === End Codelab ===

  return EvalCheckedExpr(checked_expr, activation, &arena);
}

// Compiles `cel_expr` and evaluates it with the fields of `context` bound into
// the activation (using ProtoUnsetFieldOptions::kBindDefault for unset fields).
absl::StatusOr CompileAndEvaluateWithContext(
    absl::string_view cel_expr, const AttributeContext& context) {
  CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler,
                       MakeCelCompiler());
  CEL_ASSIGN_OR_RETURN(CheckedExpr checked_expr,
                       CompileToCheckedExpr(*compiler, cel_expr));

  Activation activation;
  google::protobuf::Arena arena;
  // === Start Codelab ===
  // BindProtoToActivation has no using-declaration in this file; it is found by
  // argument-dependent lookup through the google::api::expr::runtime arguments
  // (Activation*, ProtoUnsetFieldOptions).
  CEL_RETURN_IF_ERROR(BindProtoToActivation(
      &context, &arena, &activation, ProtoUnsetFieldOptions::kBindDefault));
  // === End Codelab ===

  return EvalCheckedExpr(checked_expr, activation, &arena);
}

}  // namespace cel_codelab

================================================
FILE: codelab/solutions/exercise3_test.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "google/rpc/context/attribute_context.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "codelab/exercise2.h"
#include "internal/testing.h"

// NOTE(review): the template argument after absl::StatusOr below is missing in
// this extract; it looks stripped rather than intentional -- confirm upstream.

namespace cel_codelab {
namespace {

using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::google::rpc::context::AttributeContext;

// Helper for a simple CEL expression with no context. The expressions in the
// truth-table tests never reference `bool_var`, so the bound value is
// irrelevant.
absl::StatusOr TruthTableTest(absl::string_view statement) {
  return CompileAndEvaluateWithBoolVar(statement, /*unused*/ false);
}

// `||` absorbs errors commutatively: a `true` operand on either side wins over
// an error on the other side; otherwise the error propagates.
TEST(Exercise3, LogicalOr) {
  EXPECT_THAT(TruthTableTest("true || (1 / 0 > 2)"), IsOkAndHolds(true));
  EXPECT_THAT(TruthTableTest("false || (1 / 0 > 2)"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || true"), IsOkAndHolds(true));
  EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || false"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  EXPECT_THAT(TruthTableTest("(1 / 0 > 2) || (1 / 0 > 2)"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  EXPECT_THAT(TruthTableTest("true || true"), IsOkAndHolds(true));
  EXPECT_THAT(TruthTableTest("true || false"), IsOkAndHolds(true));
  EXPECT_THAT(TruthTableTest("false || true"), IsOkAndHolds(true));
  EXPECT_THAT(TruthTableTest("false || false"), IsOkAndHolds(false));
}

// `&&` mirrors `||`: a `false` operand on either side wins over an error on the
// other side; otherwise the error propagates.
TEST(Exercise3, LogicalAnd) {
  EXPECT_THAT(TruthTableTest("true && (1 / 0 > 2)"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  EXPECT_THAT(TruthTableTest("false && (1 / 0 > 2)"), IsOkAndHolds(false));
  EXPECT_THAT(TruthTableTest("(1 / 0 > 2) && true"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  EXPECT_THAT(TruthTableTest("(1 / 0 > 2) && false"), IsOkAndHolds(false));
  EXPECT_THAT(TruthTableTest("(1 / 0 > 2) && (1 / 0 > 2)"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  EXPECT_THAT(TruthTableTest("true && true"), IsOkAndHolds(true));
  EXPECT_THAT(TruthTableTest("true && false"), IsOkAndHolds(false));
  EXPECT_THAT(TruthTableTest("false && true"), IsOkAndHolds(false));
  EXPECT_THAT(TruthTableTest("false && false"), IsOkAndHolds(false));
}

// An error in the condition propagates; an error in a branch only matters if
// that branch is the one selected.
TEST(Exercise3, Ternary) {
  EXPECT_THAT(TruthTableTest("(1 / 0 > 2) ? false : false"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  EXPECT_THAT(TruthTableTest("true ? (1 / 0 > 2) : false"),
              StatusIs(absl::StatusCode::kInvalidArgument, "divide by zero"));
  EXPECT_THAT(TruthTableTest("false ? (1 / 0 > 2) : false"),
              IsOkAndHolds(false));
}

TEST(Exercise3Context, BadFieldAccess) {
  AttributeContext context;

  // This type of error is normally caught by the type checker; to allow
  // it to pass, we use the dyn() operator to defer checking to runtime.
  // typo-ed field name from 'request.host'
  EXPECT_THAT(
      CompileAndEvaluateWithContext(
          "dyn(request).hostname == 'localhost' && true", context),
      StatusIs(absl::StatusCode::kNotFound, "no_such_field : hostname"));
  EXPECT_THAT(CompileAndEvaluateWithContext(
                  "dyn(request).hostname == 'localhost' && false", context),
              IsOkAndHolds(false));
  EXPECT_THAT(CompileAndEvaluateWithContext(
                  "dyn(request).hostname == 'localhost' || true", context),
              IsOkAndHolds(true));
  EXPECT_THAT(
      CompileAndEvaluateWithContext(
          "dyn(request).hostname == 'localhost' || false", context),
      StatusIs(absl::StatusCode::kNotFound, "no_such_field : hostname"));
}

}  // namespace
}  // namespace cel_codelab

================================================
FILE: codelab/solutions/exercise4.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "codelab/exercise4.h"

#include

#include "cel/expr/checked.pb.h"
#include "google/rpc/context/attribute_context.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "codelab/cel_compiler.h"
#include "common/decl.h"
#include "common/type.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "eval/public/activation.h"
#include "eval/public/activation_bind_helper.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_expr_builder_factory.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_function_adapter.h"
#include "eval/public/cel_options.h"
#include "eval/public/cel_value.h"
#include "internal/status_macros.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

// NOTE(review): one #include directive above is empty, and template argument
// lists (after absl::StatusOr / absl::optional / std::unique_ptr, on
// LinkMessageReflection, and the leading FunctionAdapter argument) are missing
// in this extract; they look stripped rather than intentional -- confirm
// against upstream.

namespace cel_codelab {
namespace {

using ::cel::expr::CheckedExpr;
using ::google::api::expr::runtime::Activation;
using ::google::api::expr::runtime::BindProtoToActivation;
using ::google::api::expr::runtime::CelError;
using ::google::api::expr::runtime::CelExpressionBuilder;
using ::google::api::expr::runtime::CelMap;
using ::google::api::expr::runtime::CelValue;
using ::google::api::expr::runtime::CreateCelExpressionBuilder;
using ::google::api::expr::runtime::FunctionAdapter;
using ::google::api::expr::runtime::InterpreterOptions;
using ::google::api::expr::runtime::RegisterBuiltinFunctions;
using ::google::rpc::context::AttributeContext;

// Handle the parametric type overload with a single generic CelValue overload.
// Returns true iff `map` contains `key` and the stored entry equals `value`.
// Only int64 and string values are compared; any other type combination
// yields false. `arena` is not used here (presumably it is part of the
// signature FunctionAdapter expects -- confirm).
absl::StatusOr ContainsExtensionFunction(google::protobuf::Arena* arena,
                                               const CelMap* map,
                                               CelValue::StringHolder key,
                                               const CelValue& value) {
  absl::optional entry = (*map)[CelValue::CreateString(key)];
  if (!entry.has_value()) {
    return false;
  }
  if (value.IsInt64() && entry->IsInt64()) {
    return value.Int64OrDie() == entry->Int64OrDie();
  } else if (value.IsString() && entry->IsString()) {
    return value.StringOrDie().value() == entry->StringOrDie().value();
  }
  return false;
}

// Builds a compiler over the generated descriptor pool with the standard
// library, the AttributeContext fields as variables, and an extra
// receiver-style overload map.contains(string, V) -> bool.
absl::StatusOr> MakeConfiguredCompiler() {
  // Setup for handling protobuf types.
  // Using the generated descriptor pool is simpler to configure, but often
  // adds more types than necessary.
  google::protobuf::LinkMessageReflection();
  CEL_ASSIGN_OR_RETURN(
      std::unique_ptr builder,
      cel::NewCompilerBuilder(google::protobuf::DescriptorPool::generated_pool()));
  CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary()));
  // Adds fields of AttributeContext as variables.
  CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddContextDeclaration(
      AttributeContext::descriptor()->full_name()));

  // Codelab part 1:
  // Add a declaration for the map.contains(string, V) function.
  auto& checker_builder = builder->GetCheckerBuilder();
  // Note: we use MakeMemberOverloadDecl instead of MakeOverloadDecl
  // because the function is receiver style, meaning that it is called as
  // e1.f(e2) instead of f(e1, e2).
  CEL_ASSIGN_OR_RETURN(
      cel::FunctionDecl decl,
      cel::MakeFunctionDecl(
          "contains",
          cel::MakeMemberOverloadDecl(
              "map_contains_string_string", cel::BoolType(),
              cel::MapType(checker_builder.arena(), cel::StringType(),
                           cel::TypeParamType("V")),
              cel::StringType(), cel::TypeParamType("V"))));
  // Note: we use MergeFunction instead of AddFunction because we are adding
  // an overload to an already declared function with the same name.
  CEL_RETURN_IF_ERROR(checker_builder.MergeFunction(decl));
  return builder->Build();
}

// Owns an eval/public expression builder and the arena used for every
// evaluation. Call SetupEvaluatorEnvironment() before Evaluate() so that the
// builtins and the `contains` extension are registered.
class Evaluator {
 public:
  // builder_ is assigned in the constructor body, after every member
  // (including options_, declared after builder_) has been constructed, so
  // reading options_ here is safe.
  Evaluator() {
    builder_ = CreateCelExpressionBuilder(
        google::protobuf::DescriptorPool::generated_pool(),
        google::protobuf::MessageFactory::generated_factory(), options_);
  }

  // Registers the builtin functions and the receiver-style `contains`
  // extension on builder_'s registry.
  absl::Status SetupEvaluatorEnvironment() {
    // NOTE(review): unlike the builder construction above, options_ is not
    // passed here; presumably equivalent while options_ holds defaults, but
    // passing it would keep the two in sync -- confirm.
    CEL_RETURN_IF_ERROR(RegisterBuiltinFunctions(builder_->GetRegistry()));

    // Codelab part 2:
    // Register the map.contains(string, V) implementation (the value argument
    // is taken as a generic CelValue).
    // Hint: use `FunctionAdapter::CreateAndRegister` to adapt from a free
    // function ContainsExtensionFunction.
    using AdapterT = FunctionAdapter, const CelMap*,
                                     CelValue::StringHolder, CelValue>;
    CEL_RETURN_IF_ERROR(AdapterT::CreateAndRegister(
        "contains", /*receiver_style=*/true, &ContainsExtensionFunction,
        builder_->GetRegistry()));

    return absl::OkStatus();
  }

  // Binds the fields of `context` as variables, plans `expr`, and evaluates it.
  // Returns the bool result, a CelError result as its (error) status, or
  // InvalidArgument for any other result type. All allocations go to arena_,
  // which this class never resets.
  absl::StatusOr Evaluate(const CheckedExpr& expr,
                                const AttributeContext& context) {
    Activation activation;
    CEL_RETURN_IF_ERROR(BindProtoToActivation(&context, &arena_, &activation));
    CEL_ASSIGN_OR_RETURN(auto plan, builder_->CreateExpression(&expr));
    CEL_ASSIGN_OR_RETURN(CelValue result, plan->Evaluate(activation, &arena_));

    if (bool value; result.GetValue(&value)) {
      return value;
    } else if (const CelError* value; result.GetValue(&value)) {
      return *value;
    } else {
      return absl::InvalidArgumentError(
          absl::StrCat("unexpected return type: ", result.DebugString()));
    }
  }

 private:
  google::protobuf::Arena arena_;
  std::unique_ptr builder_;
  InterpreterOptions options_;
};

}  // namespace

// Compiles `expr` with the `contains` extension declared, then evaluates it
// against `context` with the extension implementation registered.
absl::StatusOr EvaluateWithExtensionFunction(
    absl::string_view expr, const AttributeContext& context) {
  // Prepare a checked expression.
  CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler,
                       MakeConfiguredCompiler());
  CEL_ASSIGN_OR_RETURN(auto checked_expr,
                       CompileToCheckedExpr(*compiler, expr));

  // Prepare an evaluation environment.
  Evaluator evaluator;
  CEL_RETURN_IF_ERROR(evaluator.SetupEvaluatorEnvironment());

  // Evaluate the checked expression against the given AttributeContext.
  return evaluator.Evaluate(checked_expr, context);
}

}  // namespace cel_codelab

================================================
FILE: common/BUILD
================================================
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

package(default_visibility = ["//visibility:public"])

licenses(["notice"])

# Depends on :expr and :source plus //common/ast:metadata.
cc_library(
    name = "ast",
    srcs = ["ast.cc"],
    hdrs = ["ast.h"],
    deps = [
        ":expr",
        ":source",
        "//common/ast:metadata",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings:string_view",
    ],
)

cc_test(
    name = "ast_test",
    srcs = ["ast_test.cc"],
    deps = [
        ":ast",
        ":expr",
        ":source",
        "//internal:testing",
        "@com_google_absl//absl/container:flat_hash_map",
    ],
)

# Depends only on :constant within this package.
cc_library(
    name = "expr",
    srcs = ["expr.cc"],
    hdrs = ["expr.h"],
    deps = [
        ":constant",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:variant",
    ],
)
cc_test(
    name = "expr_test",
    srcs = ["expr_test.cc"],
    deps = [
        ":expr",
        "//internal:testing",
    ],
)

cc_library(
    name = "navigable_ast",
    srcs = ["navigable_ast.cc"],
    hdrs = ["navigable_ast.h"],
    deps = [
        ":ast_traverse",
        ":ast_visitor",
        ":ast_visitor_base",
        ":expr",
        "//common/ast:navigable_ast_internal",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/types:optional",
    ],
)

cc_test(
    name = "navigable_ast_test",
    srcs = ["navigable_ast_test.cc"],
    deps = [
        ":ast",
        ":expr",
        ":navigable_ast",
        ":source",
        ":standard_definitions",
        "//internal:status_macros",
        "//internal:testing",
        "//parser",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
    ],
)

cc_library(
    name = "decl",
    srcs = ["decl.cc"],
    hdrs = ["decl.h"],
    deps = [
        ":constant",
        ":type",
        ":type_kind",
        "//common/internal:signature",
        "//internal:status_macros",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
    ],
)

cc_test(
    name = "decl_test",
    srcs = ["decl_test.cc"],
    deps = [
        ":constant",
        ":decl",
        ":type",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "@com_google_absl//absl/log:die_if_null",
        "@com_google_absl//absl/status",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "reference",
    srcs = ["reference.cc"],
    hdrs = ["reference.h"],
    deps = [
        ":constant",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:variant",
    ],
)

cc_test(
    name = "reference_test",
    srcs = ["reference_test.cc"],
    deps = [
        ":constant",
        ":reference",
        "//internal:testing",
    ],
)

cc_library(
    name = "ast_rewrite",
    srcs = ["ast_rewrite.cc"],
    hdrs = ["ast_rewrite.h"],
    deps = [
        ":ast_visitor",
        ":constant",
        ":expr",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_log",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:variant",
    ],
)

cc_test(
    name = "ast_rewrite_test",
    srcs = ["ast_rewrite_test.cc"],
    deps = [
        ":ast",
        ":ast_rewrite",
        ":ast_visitor",
        ":expr",
        "//common/ast:expr_proto",
        "//extensions/protobuf:ast_converters",
        "//internal:testing",
        "//parser",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "ast_traverse",
    srcs = ["ast_traverse.cc"],
    hdrs = ["ast_traverse.h"],
    deps = [
        ":ast_visitor",
        ":constant",
        ":expr",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/log:absl_log",
        "@com_google_absl//absl/types:variant",
    ],
)

cc_test(
    name = "ast_traverse_test",
    srcs = ["ast_traverse_test.cc"],
    deps = [
        ":ast_traverse",
        ":ast_visitor",
        ":constant",
        ":expr",
        "//internal:testing",
    ],
)

cc_library(
    name = "ast_visitor",
    hdrs = ["ast_visitor.h"],
    deps = [
        ":constant",
        ":expr",
    ],
)

cc_library(
    name = "ast_visitor_base",
    hdrs = ["ast_visitor_base.h"],
    deps = [
        ":ast_visitor",
        ":constant",
        ":expr",
    ],
)

cc_library(
    name = "constant",
    srcs = ["constant.cc"],
    hdrs = ["constant.h"],
    deps = [
        "//internal:strings",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:variant",
    ],
)

cc_test(
    name = "constant_test",
    srcs = ["constant_test.cc"],
    deps = [
        ":constant",
        "//internal:testing",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/time",
    ],
)

cc_library(
    name = "expr_factory",
    hdrs = ["expr_factory.h"],
    deps = [
        ":constant",
        ":expr",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "operators",
    srcs = [
        "operators.cc",
    ],
    hdrs = [
        "operators.h",
    ],
    deps = [
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
    ],
)

cc_library(
    name = "any",
    srcs = ["any.cc"],
    hdrs = ["any.h"],
    deps = [
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_protobuf//:any_cc_proto",
    ],
)

cc_test(
    name = "any_test",
    srcs = ["any_test.cc"],
    deps = [
        ":any",
        "//internal:testing",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:any_cc_proto",
    ],
)

cc_library(
    name = "casting",
    hdrs = ["casting.h"],
    deps = [
        "//common/internal:casting",
        "@com_google_absl//absl/base:core_headers",
    ],
)

# Header-only target with no declared dependencies.
cc_library(
    name = "json",
    hdrs = ["json.h"],
)

cc_library(
    name = "kind",
    srcs = ["kind.cc"],
    hdrs = ["kind.h"],
    deps = [
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/strings",
    ],
)

cc_test(
    name = "kind_test",
    srcs = ["kind_test.cc"],
    deps = [
        ":kind",
        ":type_kind",
        ":value_kind",
        "//internal:testing",
    ],
)

cc_library(
    name = "memory",
    srcs = ["memory.cc"],
    hdrs = ["memory.h"],
    deps = [
        ":allocator",
        ":arena",
        ":data",
        ":native_type",
        ":reference_count",
        "//common/internal:metadata",
        "//common/internal:reference_count",
        "//internal:exceptions",
        "//internal:to_address",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/numeric:bits",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "memory_test",
    srcs = ["memory_test.cc"],
    deps = [
        ":allocator",
        ":data",
        ":memory",
        ":native_type",
        "//common/internal:reference_count",
        "//internal:testing",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/debugging:leak_check",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/types:optional",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
    ],
)

# Test-support library; `testonly = True` means only test-only targets may
# depend on it.
cc_library(
    name = "memory_testing",
    testonly = True,
    hdrs = ["memory_testing.h"],
    deps = [
        ":memory",
        "//internal:testing",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_protobuf//:protobuf",
    ],
)

# Test-only, header-only target with no declared dependencies.
cc_library(
    name = "type_testing",
    testonly = True,
    hdrs = ["type_testing.h"],
)

# Test-support library for `:value`; marked `testonly = True`.
cc_library(
    name = "value_testing",
    testonly = True,
    srcs = ["value_testing.cc"],
    hdrs = ["value_testing.h"],
    deps = [
        ":value",
        ":value_kind",
        "//internal:equals_text_proto",
        "//internal:parse_text_proto",
        "//internal:testing_descriptor_pool",
        "//internal:testing_message_factory",
        "//internal:testing_no_main",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:die_if_null",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/time",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
    ],
)

cc_test(
    name = "value_testing_test",
    srcs = ["value_testing_test.cc"],
    deps = [
        ":value",
        ":value_testing",
        "//internal:testing",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/time",
    ],
)

cc_library(
    name = "type_kind",
    hdrs = ["type_kind.h"],
    deps = [
        ":kind",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "value_kind",
    hdrs = ["value_kind.h"],
    deps = [
        ":kind",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "source",
    srcs = ["source.cc"],
    hdrs = ["source.h"],
    deps = [
        "//internal:unicode",
        "//internal:utf8",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:variant",
    ],
)

cc_test(
    name = "source_test",
    srcs = ["source_test.cc"],
    deps = [
        ":source",
        "//internal:testing",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/types:optional",
    ],
)

cc_library(
    name = "native_type",
    hdrs = ["native_type.h"],
    deps = [
        ":typeinfo",
    ],
)

# Sources and headers are globbed from types/ (excluding *_test files) and
# combined with the top-level type and type_introspector files.
cc_library(
    name = "type",
    srcs = glob(
        [
            "types/*.cc",
        ],
        exclude = [
            "types/*_test.cc",
        ],
    ) + [
        "type.cc",
        "type_introspector.cc",
    ],
    hdrs = glob(
        [
            "types/*.h",
        ],
        exclude = [
            "types/*_test.h",
        ],
    ) + [
        "type.h",
        "type_introspector.h",
    ],
    deps = [
        ":type_kind",
        "//internal:string_pool",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/log:die_if_null",
        "@com_google_absl//absl/meta:type_traits",
        "@com_google_absl//absl/numeric:bits",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:variant",
        "@com_google_absl//absl/utility",
        "@com_google_protobuf//:protobuf",
    ],
)

# One test binary built from every types/*_test.cc plus type_test.cc.
cc_test(
    name = "type_test",
    srcs = glob([
        "types/*_test.cc",
    ]) + [
        "type_test.cc",
    ],
    deps = [
        ":memory",
        ":type",
        ":type_kind",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/hash:hash_testing",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/log:die_if_null",
        "@com_google_protobuf//:protobuf",
    ],
)

# Sources and headers are globbed from values/ (excluding *_test files) and
# combined with the top-level legacy_value, type_reflector, and value files.
cc_library(
    name = "value",
    srcs = glob(
        [
            "values/*.cc",
        ],
        exclude = [
            "values/*_test.cc",
        ],
    ) + [
        "legacy_value.cc",
        "value.cc",
    ],
    hdrs = glob(
        [
            "values/*.h",
        ],
        exclude = [
            "values/*_test.h",
        ],
    ) + [
        "legacy_value.h",
        "type_reflector.h",
        "value.h",
    ],
    deps = [
        ":allocator",
        ":any",
        ":arena",
        ":casting",
        ":kind",
        ":memory",
        ":native_type",
        ":optional_ref",
        ":type",
        ":typeinfo",
        ":unknown",
        ":value_kind",
        "//base:attributes",
        "//common/internal:byte_string",
        "//common/internal:reference_count",
        "//eval/internal:cel_value_equal",
        "//eval/public:cel_value",
        "//eval/public:message_wrapper",
        "//eval/public/containers:field_backed_list_impl",
        "//eval/public/containers:field_backed_map_impl",
        "//eval/public/structs:cel_proto_wrap_util",
        "//eval/public/structs:legacy_type_adapter",
        "//eval/public/structs:legacy_type_info_apis",
        "//eval/public/structs:proto_message_type_adapter",
        "//extensions/protobuf/internal:map_reflection",
        "//extensions/protobuf/internal:qualify",
        "//internal:casts",
        "//internal:empty_descriptors",
        "//internal:json",
        "//internal:manual",
        "//internal:message_equality",
        "//internal:number",
        "//internal:protobuf_runtime_version",
        "//internal:status_macros",
        "//internal:strings",
        "//internal:time",
        "//internal:utf8",
        "//internal:well_known_types",
        "//runtime:runtime_options",
        "//runtime/internal:errors",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/functional:function_ref",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/log:absl_log",
        "@com_google_absl//absl/meta:type_traits",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:variant",
        "@com_google_absl//absl/utility",
        "@com_google_protobuf//:any_cc_proto",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:empty_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
        "@com_google_protobuf//:timestamp_cc_proto",
        "@com_google_protobuf//:wrappers_cc_proto",
        "@com_google_protobuf//src/google/protobuf/io",
    ],
)

# One test binary built from every values/*_test.cc plus
# type_reflector_test.cc and value_test.cc.
cc_test(
    name = "value_test",
    srcs = glob([
        "values/*_test.cc",
    ]) + [
        "type_reflector_test.cc",
        "value_test.cc",
    ],
    deps = [
        ":casting",
        ":memory",
        ":native_type",
        ":type",
        ":value",
        ":value_kind",
        ":value_testing",
        "//base:attributes",
        "//internal:parse_text_proto",
        "//internal:status_macros",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//internal:testing_message_factory",
        "//runtime:runtime_options",
        "//runtime/internal:errors",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/functional:function_ref",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/log:die_if_null",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:cord_test_helpers",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:optional",
        "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
        "@com_google_protobuf//:type_cc_proto",
        "@com_google_protobuf//src/google/protobuf/io",
    ],
)

# Header-only; its only dependency is //base/internal:unknown_set.
cc_library(
    name = "unknown",
    hdrs = ["unknown.h"],
    deps = ["//base/internal:unknown_set"],
)

# `:legacy_value` is an alias that resolves to `:value`.
alias(
    name = "legacy_value",
    actual = ":value",
)

cc_library(
    name = "arena",
    hdrs = ["arena.h"],
    deps = [
        "@com_google_absl//absl/base:nullability",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "reference_count",
    hdrs = ["reference_count.h"],
    deps = ["//common/internal:reference_count"],
)

cc_library(
    name = "allocator",
    hdrs = ["allocator.h"],
    deps = [
        ":arena",
        ":data",
        "//internal:new",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/log:die_if_null",
        "@com_google_absl//absl/numeric:bits",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "allocator_test",
    srcs = ["allocator_test.cc"],
    deps = [
        ":allocator",
        "//internal:testing",
        "@com_google_absl//absl/strings",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "data",
    hdrs = ["data.h"],
    deps = [
        "//common/internal:metadata",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "data_test",
    srcs = ["data_test.cc"],
    deps = [
        ":data",
        "//common/internal:reference_count",
        "//internal:testing",
        "@com_google_absl//absl/base:nullability",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "optional_ref",
    hdrs = ["optional_ref.h"],
    deps = [
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/utility",
    ],
)

# arena_string.h and arena_string_view.h are exported from a single target.
cc_library(
    name = "arena_string",
    hdrs = [
        "arena_string.h",
        "arena_string_view.h",
    ],
    deps = [
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

# Tagged `no_test_msvc`. NOTE(review): presumably this excludes the test from
# MSVC CI runs — confirm against the CI configuration.
cc_test(
    name = "arena_string_test",
    srcs = [
        "arena_string_test.cc",
        "arena_string_view_test.cc",
    ],
    tags = ["no_test_msvc"],
    deps = [
        ":arena_string",
        "//internal:testing",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/hash:hash_testing",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "arena_string_pool",
    hdrs = ["arena_string_pool.h"],
    deps = [
        ":arena_string",
        "//internal:string_pool",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "arena_string_pool_test",
    srcs = ["arena_string_pool_test.cc"],
    tags = ["no_test_msvc"],
    deps = [
        ":arena_string_pool",
        "//internal:testing",
        "@com_google_absl//absl/strings:cord_test_helpers",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "minimal_descriptor_pool",
    srcs = ["minimal_descriptor_pool.cc"],
    hdrs = ["minimal_descriptor_pool.h"],
    deps = [
        "//internal:minimal_descriptors",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "minimal_descriptor_pool_test",
    srcs = ["minimal_descriptor_pool_test.cc"],
    deps = [
        ":minimal_descriptor_pool",
        "//internal:testing",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "minimal_descriptor_database",
    srcs = ["minimal_descriptor_database.cc"],
    hdrs = ["minimal_descriptor_database.h"],
    deps = [
        "//internal:minimal_descriptors",
        "@com_google_absl//absl/base:nullability",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "minimal_descriptor_database_test",
    srcs = ["minimal_descriptor_database_test.cc"],
    deps = [
        ":minimal_descriptor_database",
        "//internal:testing",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "function_descriptor",
    srcs = [
        "function_descriptor.cc",
    ],
    hdrs = [
        "function_descriptor.h",
    ],
    deps = [
        ":kind",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "decl_proto",
    srcs = ["decl_proto.cc"],
    hdrs = ["decl_proto.h"],
    deps = [
        ":decl",
        ":type",
        ":type_proto",
        "//internal:status_macros",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:variant",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "decl_proto_test",
    srcs = ["decl_proto_test.cc"],
    deps = [
        ":decl",
        ":decl_proto",
        ":decl_proto_v1alpha1",
        "//internal:testing",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:variant",
        "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

# Builds on :decl_proto and additionally depends on the
# google.api.expr.v1alpha1 checked proto.
cc_library(
    name = "decl_proto_v1alpha1",
    srcs = ["decl_proto_v1alpha1.cc"],
    hdrs = ["decl_proto_v1alpha1.h"],
    deps = [
        ":decl",
        ":decl_proto",
        ":type",
        ":type_proto",
        "//internal:status_macros",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:variant",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "type_proto",
    srcs = ["type_proto.cc"],
    hdrs = ["type_proto.h"],
    deps = [
        ":type",
        ":type_kind",
        "//internal:status_macros",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
    ],
)

cc_test(
    name = "type_proto_test",
    srcs = ["type_proto_test.cc"],
    deps = [
        ":type",
        ":type_kind",
        ":type_proto",
        "//internal:proto_matchers",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "ast_proto",
    srcs = ["ast_proto.cc"],
    hdrs = ["ast_proto.h"],
    deps = [
        ":ast",
        ":constant",
        ":expr",
        "//base:ast",
        "//common/ast:constant_proto",
        "//common/ast:expr_proto",
        "//common/ast:source_info_proto",
        "//internal:status_macros",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:variant",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:struct_cc_proto",
        "@com_google_protobuf//:timestamp_cc_proto",
    ],
)

cc_test(
    name = "ast_proto_test",
    srcs = [
        "ast_proto_test.cc",
    ],
    deps = [
        ":ast",
        ":ast_proto",
        ":decl",
        ":expr",
        ":source",
        ":type",
        "//compiler",
        "//compiler:compiler_factory",
        "//compiler:optional",
        "//compiler:standard_library",
        "//extensions:comprehensions_v2",
        "//internal:proto_matchers",
        "//internal:status_macros",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:variant",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
        "@com_google_protobuf//:timestamp_cc_proto",
    ],
)

cc_library(
    name = "standard_definitions",
    hdrs = [
        "standard_definitions.h",
    ],
    deps = [
        "@com_google_absl//absl/strings:string_view",
    ],
)

cc_library(
    name = "typeinfo",
    srcs = ["typeinfo.cc"],
    hdrs = ["typeinfo.h"],
    deps = [
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/base:config",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/meta:type_traits",
        "@com_google_absl//absl/strings",
    ],
)

cc_test(
    name = "typeinfo_test",
    srcs = ["typeinfo_test.cc"],
    deps = [
        ":typeinfo",
        "//internal:testing",
        "@com_google_absl//absl/hash:hash_testing",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "container",
    srcs = ["container.cc"],
    hdrs = ["container.h"],
    deps = [
        "//internal:lexis",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)

cc_test(
    name = "container_test",
    srcs = ["container_test.cc"],
    deps = [
        ":container",
        "//internal:testing",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
    ],
)
================================================
FILE: common/allocator.h
================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_ALLOCATOR_H_ #define THIRD_PARTY_CEL_CPP_COMMON_ALLOCATOR_H_ #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/base/optimization.h" #include "absl/log/absl_check.h" #include "absl/log/die_if_null.h" #include "absl/numeric/bits.h" #include "common/arena.h" #include "common/data.h" #include "internal/new.h" #include "google/protobuf/arena.h" namespace cel { // Identifies the kind of allocator in use. `kArena` presumably corresponds to // `ArenaAllocator` (memory owned by `google::protobuf::Arena`) and // `kNewDelete` to `NewDeleteAllocator` (memory owned by `operator new`). enum class AllocatorKind { kArena = 1, kNewDelete = 2, }; // Abseil stringification hook for `AllocatorKind` (used by, e.g., // `absl::StrCat` and `absl::StrFormat`): appends "ARENA", "NEW_DELETE", or, // for any value that is not one of the enumerators, "ERROR". template void AbslStringify(S& sink, AllocatorKind kind) { switch (kind) { case AllocatorKind::kArena: sink.Append("ARENA"); return; case AllocatorKind::kNewDelete: sink.Append("NEW_DELETE"); return; default: sink.Append("ERROR"); return; } } // Forward declarations of the allocator templates defined below. template class NewDeleteAllocator; template class ArenaAllocator; template class Allocator; // `NewDeleteAllocator<>` is a type-erased vocabulary type capable of performing // allocation/deallocation and construction/destruction using memory owned by // `operator new`.
template <> class NewDeleteAllocator { public: using size_type = size_t; using difference_type = ptrdiff_t; using propagate_on_container_copy_assignment = std::true_type; using propagate_on_container_move_assignment = std::true_type; using propagate_on_container_swap = std::true_type; using is_always_equal = std::true_type; NewDeleteAllocator() = default; NewDeleteAllocator(const NewDeleteAllocator&) = default; NewDeleteAllocator& operator=(const NewDeleteAllocator&) = default; template >> // NOLINTNEXTLINE(google-explicit-constructor) constexpr NewDeleteAllocator( [[maybe_unused]] const NewDeleteAllocator& other) noexcept {} // Allocates at least `nbytes` bytes with a minimum alignment of `alignment` // from the underlying memory resource. When the underlying memory resource is // `operator new`, `deallocate_bytes` must be called at some point, otherwise // calling `deallocate_bytes` is optional. The caller must not pass an object // constructed in the returned memory to `delete_object`; doing so is undefined // behavior. `alignment` must be a power of two (checked in debug builds), and // a request for zero bytes returns `nullptr` without allocating. ABSL_MUST_USE_RESULT void* allocate_bytes( size_type nbytes, size_type alignment = alignof(std::max_align_t)) { ABSL_DCHECK(absl::has_single_bit(alignment)); if (nbytes == 0) { return nullptr; } return internal::AlignedNew(nbytes, static_cast(alignment)); } // Deallocates memory previously returned by `allocate_bytes`. `p` and `nbytes` // must be both null/zero or both non-null/non-zero (checked in debug builds). // `nbytes` and `alignment` are forwarded to `internal::SizedAlignedDelete`, so // presumably they must match the values used for the allocation.
void deallocate_bytes( void* p, size_type nbytes, size_type alignment = alignof(std::max_align_t)) noexcept { ABSL_DCHECK((p == nullptr && nbytes == 0) || (p != nullptr && nbytes != 0)); ABSL_DCHECK(absl::has_single_bit(alignment)); internal::SizedAlignedDelete(p, nbytes, static_cast(alignment)); } template ABSL_MUST_USE_RESULT T* allocate_object(size_type n = 1) { return static_cast(allocate_bytes(sizeof(T) * n, alignof(T))); } template void deallocate_object(T* p, size_type n = 1) { deallocate_bytes(p, sizeof(T) * n, alignof(T)); } // Allocates memory suitable for an object of type `T` and constructs the // object by forwarding the provided arguments. Because the underlying memory // resource is `operator new`, `delete_object` must eventually be called. template ABSL_MUST_USE_RESULT T* new_object(Args&&... args) { return new T(std::forward(args)...); } // Destructs the object of type `T` located at address `p` and deallocates the // memory; `p` must have been previously returned by `new_object` and must not // be null. template void delete_object(T* p) noexcept { ABSL_DCHECK(p != nullptr); delete p; } void delete_object(std::nullptr_t) = delete; private: template friend class NewDeleteAllocator; }; // `NewDeleteAllocator<T>` is an extension of `NewDeleteAllocator<>` which // adheres to the named C++ requirements for `Allocator`, allowing it to be used // in places which accept custom STL allocators.
template class NewDeleteAllocator : public NewDeleteAllocator { public: static_assert(!std::is_const_v, "T must not be const qualified"); static_assert(!std::is_volatile_v, "T must not be volatile qualified"); static_assert(std::is_object_v, "T must be an object type"); using value_type = T; using pointer = value_type*; using const_pointer = const value_type*; using reference = value_type&; using const_reference = const value_type&; using NewDeleteAllocator::NewDeleteAllocator; template >> // NOLINTNEXTLINE(google-explicit-constructor) constexpr NewDeleteAllocator( [[maybe_unused]] const NewDeleteAllocator& other) noexcept {} // Allocates uninitialized storage for `n` objects of type `T` via // `internal::AlignedNew`; the hint argument is ignored. pointer allocate(size_type n, const void* /*hint*/ = nullptr) { return reinterpret_cast(internal::AlignedNew( n * sizeof(T), static_cast(alignof(T)))); } // C++23 `allocate_at_least`: `result.count` is the byte size reported by // `internal::SizeReturningAlignedNew` divided by `sizeof(T)`. #if defined(__cpp_lib_allocate_at_least) && \ __cpp_lib_allocate_at_least >= 202302L std::allocation_result allocate_at_least(size_type n) { void* addr; size_type size; std::tie(addr, size) = internal::SizeReturningAlignedNew( n * sizeof(T), static_cast(alignof(T))); std::allocation_result result; result.ptr = reinterpret_cast(addr); result.count = size / sizeof(T); return result; } #endif void deallocate(pointer p, size_type n) noexcept { internal::SizedAlignedDelete(p, n * sizeof(T), static_cast(alignof(T))); } template void construct(U* p, Args&&... args) { ::new (static_cast(p)) U(std::forward(args)...); } template void destroy(U* p) noexcept { std::destroy_at(p); } }; // Every `NewDeleteAllocator` draws from `operator new`, so any two instances // compare equal. template inline bool operator==(NewDeleteAllocator, NewDeleteAllocator) noexcept { return true; } template inline bool operator!=(NewDeleteAllocator lhs, NewDeleteAllocator rhs) noexcept { return !operator==(lhs, rhs); } // Class template argument deduction guides. NewDeleteAllocator() -> NewDeleteAllocator; template NewDeleteAllocator(const NewDeleteAllocator&) -> NewDeleteAllocator; // `ArenaAllocator<>` is a type-erased vocabulary type capable of performing // allocation/deallocation and construction/destruction using memory owned by // `google::protobuf::Arena`.
// NOTE(review): angle-bracketed text (template parameter and argument lists,
// `#include <...>` targets) appears to have been stripped from this copy, e.g.
// `template class ArenaAllocator : public ArenaAllocator`. Restore it from the
// upstream header before compiling; the code tokens below are kept as found.

// `ArenaAllocator<>` is the untyped specialization. It allocates from a
// `google::protobuf::Arena` that is never null: default construction and
// construction from `nullptr` are deleted, and the `Arena*` constructor aborts
// on null via `ABSL_DIE_IF_NULL`. Copy construction is allowed; copy
// assignment is deleted.
template <>
class ArenaAllocator {
 public:
  using size_type = size_t;
  using difference_type = ptrdiff_t;
  using propagate_on_container_copy_assignment = std::true_type;
  using propagate_on_container_move_assignment = std::true_type;
  using propagate_on_container_swap = std::true_type;

  ArenaAllocator() = delete;
  ArenaAllocator(const ArenaAllocator&) = default;
  ArenaAllocator& operator=(const ArenaAllocator&) = delete;
  ArenaAllocator(std::nullptr_t) = delete;

  // Converts from another `ArenaAllocator` specialization; the result uses the
  // same arena as `other`.
  template >>
  // NOLINTNEXTLINE(google-explicit-constructor)
  constexpr ArenaAllocator(const ArenaAllocator& other) noexcept
      : arena_(other.arena()) {}

  // NOLINTNEXTLINE(google-explicit-constructor)
  ArenaAllocator(google::protobuf::Arena* absl_nonnull arena) noexcept
      : arena_(ABSL_DIE_IF_NULL(arena))  // Crash OK
  {}

  // Returns the backing arena. `ABSL_ASSUME` lets the optimizer rely on the
  // non-null invariant established by the constructors.
  constexpr google::protobuf::Arena* absl_nonnull arena() const noexcept {
    ABSL_ASSUME(arena_ != nullptr);
    return arena_;
  }

  // Allocates at least `nbytes` bytes with a minimum alignment of `alignment`
  // (debug-checked to be a power of two) from the arena. Returns `nullptr`
  // when `nbytes` is zero. Because `deallocate_bytes` below releases nothing,
  // calling it is optional. The caller must not pass an object constructed in
  // the returned memory to `delete_object`; doing so is undefined behavior.
  ABSL_MUST_USE_RESULT void* allocate_bytes(
      size_type nbytes, size_type alignment = alignof(std::max_align_t)) {
    ABSL_DCHECK(absl::has_single_bit(alignment));
    if (nbytes == 0) {
      return nullptr;
    }
    return arena()->AllocateAligned(nbytes, alignment);
  }

  // Counterpart of `allocate_bytes`. Nothing is released here; only debug
  // checks run (a null `p` must pair with zero `nbytes` and vice versa, and
  // `alignment` must be a power of two).
  void deallocate_bytes(
      void* p, size_type nbytes,
      size_type alignment = alignof(std::max_align_t)) noexcept {
    ABSL_DCHECK((p == nullptr && nbytes == 0) || (p != nullptr && nbytes != 0));
    ABSL_DCHECK(absl::has_single_bit(alignment));
  }

  // Allocates uninitialized storage for `n` objects of type `T`.
  // NOTE(review): `sizeof(T) * n` is not checked for overflow.
  template
  ABSL_MUST_USE_RESULT T* allocate_object(size_type n = 1) {
    return static_cast(allocate_bytes(sizeof(T) * n, alignof(T)));
  }

  // Counterpart of `allocate_object`; forwards to `deallocate_bytes`.
  template
  void deallocate_object(T* p, size_type n = 1) {
    deallocate_bytes(p, sizeof(T) * n, alignof(T));
  }

  // Allocates memory suitable for an object of type `T` on the arena and
  // constructs the object by forwarding the provided arguments. On the
  // non-`Arena::Create` path the destructor is registered with the arena via
  // `OwnDestructor` unless `ArenaTraits<>` reports the object as trivially
  // destructible. `delete_object` only performs debug checks, so calling it
  // is optional.
  template
  ABSL_MUST_USE_RESULT T* new_object(Args&&... args) {
    using U = std::remove_const_t;
    U* object;
    if constexpr (google::protobuf::Arena::is_arena_constructable::value) {
      // Classes derived from `cel::Data` are manually allocated and constructed
      // as those class support determining whether the destructor is skippable
      // at runtime.
      // NOTE(review): the comment above appears to describe the manual
      // placement-new path in the `else` branch; this branch delegates to
      // `Arena::Create`.
      object = google::protobuf::Arena::Create(arena(), std::forward(args)...);
    } else {
      // Types reported as arena-constructible receive the arena as a leading
      // constructor argument.
      if constexpr (ArenaTraits<>::constructible()) {
        object = ::new (static_cast(arena()->AllocateAligned(
            sizeof(U), alignof(U)))) U(arena(), std::forward(args)...);
      } else {
        object = ::new (static_cast(arena()->AllocateAligned(
            sizeof(U), alignof(U)))) U(std::forward(args)...);
      }
      // Register a destructor only when one may actually be needed; the
      // decision can depend on the specific object at runtime.
      if constexpr (!ArenaTraits<>::always_trivially_destructible()) {
        if (!ArenaTraits<>::trivially_destructible(*object)) {
          arena()->OwnDestructor(object);
        }
      }
    }
    if constexpr (google::protobuf::Arena::is_arena_constructable::value ||
                  std::is_base_of_v) {
      ABSL_DCHECK_EQ(object->GetArena(), arena());
    }
    return object;
  }

  // Debug-checks that `p` is non-null and, for arena-aware types, that it
  // belongs to this arena. It neither runs the destructor nor frees memory.
  // `p` must have been previously returned by `new_object`.
  template
  void delete_object(T* p) noexcept {
    using U = std::remove_const_t;
    ABSL_DCHECK(p != nullptr);
    if constexpr (google::protobuf::Arena::is_arena_constructable::value ||
                  std::is_base_of_v) {
      ABSL_DCHECK_EQ(p->GetArena(), arena());
    }
  }

  void delete_object(std::nullptr_t) = delete;

 private:
  template
  friend class ArenaAllocator;

  google::protobuf::Arena* absl_nonnull arena_;
};

// `ArenaAllocator<T>` is an extension of `ArenaAllocator<>` which adheres to
// the named C++ requirements for `Allocator`, allowing it to be used in places
// which accept custom STL allocators. `deallocate` is a no-op.
template
class ArenaAllocator : public ArenaAllocator {
 private:
  using Base = ArenaAllocator;

 public:
  static_assert(!std::is_const_v, "T must not be const qualified");
  static_assert(!std::is_volatile_v, "T must not be volatile qualified");
  static_assert(std::is_object_v, "T must be an object type");

  using value_type = T;
  using pointer = value_type*;
  using const_pointer = const value_type*;
  using reference = value_type&;
  using const_reference = const value_type&;

  using ArenaAllocator::ArenaAllocator;

  // Rebinding conversion from an allocator for a different element type.
  template >>
  // NOLINTNEXTLINE(google-explicit-constructor)
  constexpr ArenaAllocator(const ArenaAllocator& other) noexcept
      : Base(other) {}

  // Returns uninitialized arena storage for `n` objects of type `T`; the hint
  // is ignored.
  // NOTE(review): `n * sizeof(T)` is not checked for overflow.
  pointer allocate(size_type n, const void* /*hint*/ = nullptr) {
    return static_cast(
        arena()->AllocateAligned(n * sizeof(T), alignof(T)));
  }

#if defined(__cpp_lib_allocate_at_least) && \
    __cpp_lib_allocate_at_least >= 202302L
  // Always reports exactly `n` elements.
  std::allocation_result allocate_at_least(size_type n) {
    std::allocation_result result;
    result.ptr = allocate(n);
    result.count = n;
    return result;
  }
#endif

  // No-op: nothing is released per allocation.
  void deallocate(pointer, size_type) noexcept {}

  // Placement-constructs a `U` at `p`. Arena-constructable types are rejected
  // at compile time.
  template
  void construct(U* p, Args&&... args) {
    static_assert(!google::protobuf::Arena::is_arena_constructable::value);
    ::new (static_cast(p)) U(std::forward(args)...);
  }

  // Runs the destructor of the object at `p` without releasing its memory.
  template
  void destroy(U* p) noexcept {
    static_assert(!google::protobuf::Arena::is_arena_constructable::value);
    std::destroy_at(p);
  }
};

// Arena allocators compare equal exactly when they share the same arena.
template
inline bool operator==(ArenaAllocator lhs, ArenaAllocator rhs) noexcept {
  return lhs.arena() == rhs.arena();
}

template
inline bool operator!=(ArenaAllocator lhs, ArenaAllocator rhs) noexcept {
  return !operator==(lhs, rhs);
}

ArenaAllocator(google::protobuf::Arena* absl_nonnull) -> ArenaAllocator;

template
ArenaAllocator(const ArenaAllocator&) -> ArenaAllocator;

// `Allocator<>` is a type-erased vocabulary type capable of performing
// allocation/deallocation and construction/destruction using memory owned by
// `google::protobuf::Arena` or `operator new`. Every operation dispatches on
// `arena()`: non-null uses `ArenaAllocator`, null uses `NewDeleteAllocator`.
template <>
class Allocator {
 public:
  using size_type = size_t;
  using difference_type = ptrdiff_t;
  using propagate_on_container_copy_assignment = std::true_type;
  using propagate_on_container_move_assignment = std::true_type;
  using propagate_on_container_swap = std::true_type;

  Allocator() = delete;
  Allocator(const Allocator&) = default;
  Allocator& operator=(const Allocator&) = delete;
  Allocator(std::nullptr_t) = delete;

  template >>
  // NOLINTNEXTLINE(google-explicit-constructor)
  constexpr Allocator(const Allocator& other) noexcept
      : arena_(other.arena_) {}

  // A null `arena` selects the `operator new` backend.
  // NOLINTNEXTLINE(google-explicit-constructor)
  constexpr Allocator(google::protobuf::Arena* absl_nullable arena) noexcept
      : arena_(arena) {}

  // Converting from a `NewDeleteAllocator` selects the `operator new` backend.
  template
  // NOLINTNEXTLINE(google-explicit-constructor)
  constexpr Allocator(
      [[maybe_unused]] const NewDeleteAllocator& other) noexcept
      : arena_(nullptr) {}

  // Converting from an `ArenaAllocator` adopts its arena.
  template
  // NOLINTNEXTLINE(google-explicit-constructor)
  constexpr Allocator(const ArenaAllocator& other) noexcept
      : arena_(other.arena()) {}

  // Returns the backing arena, or `nullptr` for the `operator new` backend.
  constexpr google::protobuf::Arena* absl_nullable arena() const noexcept {
    return arena_;
  }

  // Allocates at least `nbytes` bytes with a minimum alignment of `alignment`
  // from the underlying memory resource. When the underlying memory resource is
  // `operator new`, `deallocate_bytes` must be called at some point, otherwise
  // calling `deallocate_bytes` is optional. The caller must not pass an object
  // constructed in the returned memory to `delete_object`; doing so is
  // undefined behavior.
  ABSL_MUST_USE_RESULT void* allocate_bytes(
      size_type nbytes, size_type alignment = alignof(std::max_align_t)) {
    return arena() != nullptr
               ? ArenaAllocator(arena()).allocate_bytes(nbytes, alignment)
               : NewDeleteAllocator().allocate_bytes(nbytes, alignment);
  }

  // Deallocates memory previously returned by `allocate_bytes`, dispatching to
  // the same backend that produced it.
  void deallocate_bytes(
      void* p, size_type nbytes,
      size_type alignment = alignof(std::max_align_t)) noexcept {
    arena() != nullptr
        ? ArenaAllocator(arena()).deallocate_bytes(p, nbytes, alignment)
        : NewDeleteAllocator().deallocate_bytes(p, nbytes, alignment);
  }

  // Allocates uninitialized storage for `n` objects of type `T`.
  template
  ABSL_MUST_USE_RESULT T* allocate_object(size_type n = 1) {
    return arena() != nullptr
               ? ArenaAllocator(arena()).allocate_object(n)
               : NewDeleteAllocator().allocate_object(n);
  }

  // Counterpart of `allocate_object`.
  template
  void deallocate_object(T* p, size_type n = 1) {
    arena() != nullptr ? ArenaAllocator(arena()).deallocate_object(p, n)
                       : NewDeleteAllocator().deallocate_object(p, n);
  }

  // Allocates memory suitable for an object of type `T` and constructs the
  // object by forwarding the provided arguments. When the underlying memory
  // resource is `operator new` (no arena), `delete_object` must eventually be
  // called; with an arena, calling it is optional.
  template
  ABSL_MUST_USE_RESULT T* new_object(Args&&... args) {
    return arena() != nullptr
               ? ArenaAllocator(arena()).new_object(
                     std::forward(args)...)
               : NewDeleteAllocator().new_object(
                     std::forward(args)...);
  }

  // Releases an object previously returned by `new_object`. Without an arena
  // this forwards to `NewDeleteAllocator::delete_object`. With an arena it
  // forwards to `ArenaAllocator<>::delete_object`, which only performs debug
  // checks.
  template
  void delete_object(T* p) noexcept {
    arena() != nullptr ? ArenaAllocator(arena()).delete_object(p)
                       : NewDeleteAllocator().delete_object(p);
  }

  void delete_object(std::nullptr_t) = delete;

 private:
  template
  friend class Allocator;

  google::protobuf::Arena* absl_nullable arena_;
};

// `Allocator<T>` is an extension of `Allocator<>` which adheres to the named
// C++ requirements for `Allocator`, allowing it to be used in places which
// accept custom STL allocators. Each operation dispatches on `arena()`.
template
class Allocator : public Allocator {
 public:
  static_assert(!std::is_const_v, "T must not be const qualified");
  static_assert(!std::is_volatile_v, "T must not be volatile qualified");
  static_assert(std::is_object_v, "T must be an object type");

  using value_type = T;
  using pointer = value_type*;
  using const_pointer = const value_type*;
  using reference = value_type&;
  using const_reference = const value_type&;

  using Allocator::Allocator;

  // Rebinding conversion from an allocator for a different element type.
  template >>
  // NOLINTNEXTLINE(google-explicit-constructor)
  constexpr Allocator(const Allocator& other) noexcept
      : Allocator(other.arena_) {}

  pointer allocate(size_type n, const void* /*hint*/ = nullptr) {
    return arena() != nullptr ? ArenaAllocator(arena()).allocate(n)
                              : NewDeleteAllocator().allocate(n);
  }

#if defined(__cpp_lib_allocate_at_least) && \
    __cpp_lib_allocate_at_least >= 202302L
  std::allocation_result allocate_at_least(size_type n) {
    return arena() != nullptr
               ? ArenaAllocator(arena()).allocate_at_least(n)
               : NewDeleteAllocator().allocate_at_least(n);
  }
#endif

  void deallocate(pointer p, size_type n) noexcept {
    arena() != nullptr ? ArenaAllocator(arena()).deallocate(p, n)
                       : NewDeleteAllocator().deallocate(p, n);
  }

  template
  void construct(U* p, Args&&... args) {
    arena() != nullptr
        ? ArenaAllocator(arena()).construct(p, std::forward(args)...)
        : NewDeleteAllocator().construct(p, std::forward(args)...);
  }

  template
  void destroy(U* p) noexcept {
    arena() != nullptr ? ArenaAllocator(arena()).destroy(p)
                       : NewDeleteAllocator().destroy(p);
  }
};

// Allocators compare equal exactly when they share the same arena (two
// `operator new` allocators both have a null arena and compare equal).
template
inline bool operator==(Allocator lhs, Allocator rhs) noexcept {
  return lhs.arena() == rhs.arena();
}

template
inline bool operator!=(Allocator lhs, Allocator rhs) noexcept {
  return !operator==(lhs, rhs);
}

Allocator(google::protobuf::Arena* absl_nullable) -> Allocator;

template
Allocator(const Allocator&) -> Allocator;

template
Allocator(const NewDeleteAllocator&) -> Allocator;

template
Allocator(const ArenaAllocator&) -> Allocator;

// Returns a `NewDeleteAllocator` for a non-void element type.
template
inline NewDeleteAllocator NewDeleteAllocatorFor() noexcept {
  static_assert(!std::is_void_v);
  return NewDeleteAllocator();
}

// Returns an `Allocator` (not an `ArenaAllocator`) bound to `arena`, which
// must be non-null (debug-checked).
template
inline Allocator ArenaAllocatorFor(
    google::protobuf::Arena* absl_nonnull arena) noexcept {
  static_assert(!std::is_void_v);
  ABSL_DCHECK(arena != nullptr);
  return Allocator(arena);
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_ALLOCATOR_H_

================================================
FILE: common/allocator_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Tests for the allocator types declared in "common/allocator.h":
// `NewDeleteAllocator`, `ArenaAllocator` and the type-erased `Allocator`.
// NOTE(review): angle-bracketed text (the system `#include` target and
// template arguments such as those of `new_object`, `allocate_object`,
// `static_cast` and the `std::is_*_constructible_v` checks) appears to have
// been stripped from this copy; restore it from upstream before compiling.
#include "common/allocator.h"

#include

#include "absl/strings/str_cat.h"
#include "internal/testing.h"
#include "google/protobuf/arena.h"

namespace cel {
namespace {

using ::testing::NotNull;

// Named `AllocatorKind` values stringify via `absl::StrCat`; the value
// produced by `static_cast(0)` is expected to stringify as "ERROR".
TEST(AllocatorKind, AbslStringify) {
  EXPECT_EQ(absl::StrCat(AllocatorKind::kArena), "ARENA");
  EXPECT_EQ(absl::StrCat(AllocatorKind::kNewDelete), "NEW_DELETE");
  EXPECT_EQ(absl::StrCat(static_cast(0)), "ERROR");
}

// Raw byte allocation (17 bytes, 8-byte alignment) through each backend.
TEST(NewDeleteAllocator, Bytes) {
  auto allocator = NewDeleteAllocator<>();
  void* p = allocator.allocate_bytes(17, 8);
  EXPECT_THAT(p, NotNull());
  allocator.deallocate_bytes(p, 17, 8);
}

TEST(ArenaAllocator, Bytes) {
  google::protobuf::Arena arena;
  auto allocator = ArenaAllocator<>(&arena);
  void* p = allocator.allocate_bytes(17, 8);
  EXPECT_THAT(p, NotNull());
  allocator.deallocate_bytes(p, 17, 8);
}

// A 17-byte aggregate with a trivial destructor, used as the payload type in
// the object allocation tests below.
struct TrivialObject {
  char data[17];
};

// `new_object` / `delete_object` round trip through each backend.
TEST(NewDeleteAllocator, NewDeleteObject) {
  auto allocator = NewDeleteAllocator<>();
  auto* p = allocator.new_object();
  EXPECT_THAT(p, NotNull());
  allocator.delete_object(p);
}

TEST(ArenaAllocator, NewDeleteObject) {
  google::protobuf::Arena arena;
  auto allocator = ArenaAllocator<>(&arena);
  auto* p = allocator.new_object();
  EXPECT_THAT(p, NotNull());
  allocator.delete_object(p);
}

// Uninitialized storage for a single object.
TEST(NewDeleteAllocator, Object) {
  auto allocator = NewDeleteAllocator<>();
  auto* p = allocator.allocate_object();
  EXPECT_THAT(p, NotNull());
  allocator.deallocate_object(p);
}

TEST(ArenaAllocator, Object) {
  google::protobuf::Arena arena;
  auto allocator = ArenaAllocator<>(&arena);
  auto* p = allocator.allocate_object();
  EXPECT_THAT(p, NotNull());
  allocator.deallocate_object(p);
}

// Uninitialized storage for two objects.
TEST(NewDeleteAllocator, ObjectArray) {
  auto allocator = NewDeleteAllocator<>();
  auto* p = allocator.allocate_object(2);
  EXPECT_THAT(p, NotNull());
  allocator.deallocate_object(p, 2);
}

TEST(ArenaAllocator, ObjectArray) {
  google::protobuf::Arena arena;
  auto allocator = ArenaAllocator<>(&arena);
  auto* p = allocator.allocate_object(2);
  EXPECT_THAT(p, NotNull());
  allocator.deallocate_object(p, 2);
}

// Exercises the STL-style typed interface: allocate, construct, destroy and
// deallocate.
TEST(NewDeleteAllocator, T) {
  auto allocator = NewDeleteAllocatorFor();
  auto* p = allocator.allocate(1);
  EXPECT_THAT(p, NotNull());
  allocator.construct(p);
  allocator.destroy(p);
  allocator.deallocate(p, 1);
}

TEST(ArenaAllocator, T) {
  google::protobuf::Arena arena;
  auto allocator = ArenaAllocatorFor(&arena);
  auto* p = allocator.allocate(1);
  EXPECT_THAT(p, NotNull());
  allocator.construct(p);
  allocator.destroy(p);
  allocator.deallocate(p, 1);
}

// The tests below check which allocator specializations are (trivially)
// constructible from one another.
TEST(NewDeleteAllocator, CopyConstructible) {
  EXPECT_TRUE(
      (std::is_trivially_constructible_v, const NewDeleteAllocator&>));
  EXPECT_TRUE(
      (std::is_trivially_constructible_v, const NewDeleteAllocator&>));
  EXPECT_TRUE((std::is_constructible_v, const NewDeleteAllocator&>));
  EXPECT_TRUE((std::is_constructible_v, const NewDeleteAllocator&>));
  EXPECT_TRUE((std::is_constructible_v, const NewDeleteAllocator&>));
  EXPECT_TRUE((std::is_constructible_v, const NewDeleteAllocator&>));
}

TEST(ArenaAllocator, CopyConstructible) {
  EXPECT_TRUE((std::is_trivially_constructible_v, const ArenaAllocator&>));
  EXPECT_TRUE((std::is_trivially_constructible_v, const ArenaAllocator&>));
  EXPECT_TRUE((std::is_constructible_v, const ArenaAllocator&>));
  EXPECT_TRUE((std::is_constructible_v, const ArenaAllocator&>));
  EXPECT_TRUE((std::is_constructible_v, const ArenaAllocator&>));
  EXPECT_TRUE((std::is_constructible_v, const ArenaAllocator&>));
}

TEST(Allocator, CopyConstructible) {
  EXPECT_TRUE((std::is_trivially_constructible_v, const Allocator&>));
  EXPECT_TRUE((std::is_trivially_constructible_v, const Allocator&>));
  EXPECT_TRUE(
      (std::is_constructible_v, const Allocator&>));
  EXPECT_TRUE(
      (std::is_constructible_v, const Allocator&>));
  EXPECT_TRUE(
      (std::is_constructible_v, const Allocator&>));
  EXPECT_TRUE(
      (std::is_constructible_v, const Allocator&>));
  EXPECT_TRUE((std::is_constructible_v, const NewDeleteAllocator&>));
  EXPECT_TRUE((std::is_constructible_v, const NewDeleteAllocator&>));
  EXPECT_TRUE((std::is_constructible_v, const NewDeleteAllocator&>));
  EXPECT_TRUE((std::is_constructible_v, const NewDeleteAllocator&>));
  EXPECT_TRUE((std::is_constructible_v, const NewDeleteAllocator&>));
  EXPECT_TRUE((std::is_constructible_v, const NewDeleteAllocator&>));
  EXPECT_TRUE(
      (std::is_constructible_v, const ArenaAllocator&>));
  EXPECT_TRUE(
      (std::is_constructible_v, const ArenaAllocator&>));
  EXPECT_TRUE(
      (std::is_constructible_v, const ArenaAllocator&>));
  EXPECT_TRUE(
      (std::is_constructible_v, const ArenaAllocator&>));
  EXPECT_TRUE(
      (std::is_constructible_v, const ArenaAllocator&>));
  EXPECT_TRUE(
      (std::is_constructible_v, const ArenaAllocator&>));
}

}  // namespace
}  // namespace cel

================================================
FILE: common/any.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/any.h"

#include "absl/base/nullability.h"
#include "absl/strings/string_view.h"

namespace cel {

// Splits `type_url` at its final '/'.
//
// Returns false when the URL contains no '/' at all, or when nothing follows
// the final '/'. On success, `*prefix` (when non-null) receives everything up
// to and including that '/', and `*type_name` (when non-null) receives the
// text after it.
bool ParseTypeUrl(absl::string_view type_url,
                  absl::string_view* absl_nullable prefix,
                  absl::string_view* absl_nullable type_name) {
  const auto last_slash = type_url.find_last_of('/');
  const bool has_type_name = last_slash != absl::string_view::npos &&
                             last_slash + 1 != type_url.size();
  if (!has_type_name) {
    return false;
  }
  const auto split_at = last_slash + 1;
  if (prefix != nullptr) {
    *prefix = type_url.substr(0, split_at);
  }
  if (type_name != nullptr) {
    *type_name = type_url.substr(split_at);
  }
  return true;
}

}  // namespace cel

================================================
FILE: common/any.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_ANY_H_
#define THIRD_PARTY_CEL_CPP_COMMON_ANY_H_

// NOTE(review): angle-bracketed text (the first `#include` target and the
// `static_cast` template arguments) appears to have been stripped from this
// copy; restore it from upstream before compiling.

#include

#include "google/protobuf/any.pb.h"
#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"

namespace cel {

// Builds a `google::protobuf::Any` with the given type URL whose value is the
// contents of `value`, converted via `static_cast`.
inline google::protobuf::Any MakeAny(absl::string_view type_url,
                                     const absl::Cord& value) {
  google::protobuf::Any any;
  any.set_type_url(type_url);
  any.set_value(static_cast(value));
  return any;
}

// Overload of `MakeAny` taking the serialized value as a string view.
inline google::protobuf::Any MakeAny(absl::string_view type_url,
                                     absl::string_view value) {
  google::protobuf::Any any;
  any.set_type_url(type_url);
  any.set_value(value);
  return any;
}

// Returns the serialized value of `any` as an `absl::Cord`.
inline absl::Cord GetAnyValueAsCord(const google::protobuf::Any& any) {
  return absl::Cord(any.value());
}

// Returns a copy of the serialized value of `any` as a `std::string`.
inline std::string GetAnyValueAsString(const google::protobuf::Any& any) {
  return std::string(any.value());
}

// Replaces the serialized value of `*any` (which must be non-null) with the
// contents of `value`.
inline void SetAnyValueFromCord(google::protobuf::Any* absl_nonnull any,
                                const absl::Cord& value) {
  any->set_value(static_cast(value));
}

// Returns a view of the serialized value of `any`. Both parameters are marked
// `ABSL_ATTRIBUTE_LIFETIME_BOUND`, so the view must not outlive either.
// NOTE(review): `scratch` is not used by this implementation; the view always
// points into `any.value()`.
inline absl::string_view GetAnyValueAsStringView(
    const google::protobuf::Any& any ABSL_ATTRIBUTE_LIFETIME_BOUND,
    std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return absl::string_view(any.value());
}

// Default type URL prefix, including the trailing '/'.
inline constexpr absl::string_view kTypeGoogleApisComPrefix =
    "type.googleapis.com/";

// Returns `prefix` + "/" + `type_name`, after stripping at most one trailing
// '/' from `prefix`, so "foo" and "foo/" yield the same URL.
inline std::string MakeTypeUrlWithPrefix(absl::string_view prefix,
                                         absl::string_view type_name) {
  return absl::StrCat(absl::StripSuffix(prefix, "/"), "/", type_name);
}

// Returns the type URL for `type_name` using `kTypeGoogleApisComPrefix`.
inline std::string MakeTypeUrl(absl::string_view type_name) {
  return MakeTypeUrlWithPrefix(kTypeGoogleApisComPrefix, type_name);
}

// Splits `type_url` at its last '/'. Returns false if there is no '/' or if
// nothing follows it. On success, a non-null `prefix` receives the text up to
// and including the last '/', and a non-null `type_name` receives the rest.
bool ParseTypeUrl(absl::string_view type_url,
                  absl::string_view* absl_nullable prefix,
                  absl::string_view* absl_nullable type_name);

// Convenience overload that only extracts the type name.
inline bool ParseTypeUrl(absl::string_view type_url,
                         absl::string_view* absl_nullable type_name) {
  return ParseTypeUrl(type_url, nullptr, type_name);
}

// Convenience overload that only validates `type_url`.
inline bool ParseTypeUrl(absl::string_view type_url) {
  return ParseTypeUrl(type_url, nullptr);
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_ANY_H_

================================================
FILE: common/any_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "common/any.h"

#include

#include "google/protobuf/any.pb.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "internal/testing.h"

namespace cel {
namespace {

// A value stored from a Cord reads back identically through all three getters.
TEST(Any, Value) {
  google::protobuf::Any any;
  std::string scratch;
  SetAnyValueFromCord(&any, absl::Cord("Hello World!"));
  EXPECT_EQ(GetAnyValueAsCord(any), "Hello World!");
  EXPECT_EQ(GetAnyValueAsString(any), "Hello World!");
  EXPECT_EQ(GetAnyValueAsStringView(any, scratch), "Hello World!");
}

// A trailing '/' on the prefix does not produce a doubled separator.
TEST(MakeTypeUrlWithPrefix, Basic) {
  EXPECT_EQ(MakeTypeUrlWithPrefix("foo", "bar.Baz"), "foo/bar.Baz");
  EXPECT_EQ(MakeTypeUrlWithPrefix("foo/", "bar.Baz"), "foo/bar.Baz");
}

TEST(MakeTypeUrl, Basic) {
  EXPECT_EQ(MakeTypeUrl("bar.Baz"), "type.googleapis.com/bar.Baz");
}

// URLs without a '/' or with nothing after the last '/' are rejected.
TEST(ParseTypeUrl, Valid) {
  EXPECT_TRUE(ParseTypeUrl("type.googleapis.com/bar.Baz"));
  EXPECT_FALSE(ParseTypeUrl("type.googleapis.com"));
  EXPECT_FALSE(ParseTypeUrl("type.googleapis.com/"));
  EXPECT_FALSE(ParseTypeUrl("type.googleapis.com/foo/"));
}

TEST(ParseTypeUrl, TypeName) {
  absl::string_view type_name;
  EXPECT_TRUE(ParseTypeUrl("type.googleapis.com/bar.Baz", &type_name));
  EXPECT_EQ(type_name, "bar.Baz");
  EXPECT_FALSE(ParseTypeUrl("type.googleapis.com", &type_name));
  EXPECT_FALSE(ParseTypeUrl("type.googleapis.com/", &type_name));
  EXPECT_FALSE(ParseTypeUrl("type.googleapis.com/foo/", &type_name));
}

// The prefix keeps its trailing '/'.
TEST(ParseTypeUrl, PrefixAndTypeName) {
  absl::string_view prefix;
  absl::string_view type_name;
  EXPECT_TRUE(ParseTypeUrl("type.googleapis.com/bar.Baz", &prefix, &type_name));
  EXPECT_EQ(prefix, "type.googleapis.com/");
  EXPECT_EQ(type_name, "bar.Baz");
  EXPECT_FALSE(ParseTypeUrl("type.googleapis.com", &prefix, &type_name));
  EXPECT_FALSE(ParseTypeUrl("type.googleapis.com/", &prefix, &type_name));
  EXPECT_FALSE(ParseTypeUrl("type.googleapis.com/foo/", &prefix, &type_name));
}

}  // namespace
}  // namespace cel

================================================
FILE: common/arena.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_ARENA_H_
#define THIRD_PARTY_CEL_CPP_COMMON_ARENA_H_

// NOTE(review): angle-bracketed text (`#include` targets, template parameter
// and argument lists) appears to have been stripped from this copy; restore it
// from upstream before compiling.

#include
#include

#include "absl/base/nullability.h"
#include "google/protobuf/arena.h"

namespace cel {

// Traits describing how a type interacts with `google::protobuf::Arena`. The
// explicit specialization below is the one used as `ArenaTraits<>` elsewhere.
template
struct ArenaTraits;

namespace common_internal {

// Instantiating this fires `static_assert`s that reject void, reference,
// volatile, const and array types. It derives from `std::false_type`, so its
// `::type` can lead a `std::disjunction` to validate `T` without affecting the
// result.
template
struct AssertArenaType : std::false_type {
  static_assert(!std::is_void_v, "T must not be void");
  static_assert(!std::is_reference_v, "T must not be a reference");
  static_assert(!std::is_volatile_v, "T must not be volatile qualified");
  static_assert(!std::is_const_v, "T must not be const qualified");
  static_assert(!std::is_array_v, "T must not be an array");
};

// `type` is `std::false_type` unless `ArenaTraits` exposes a `constructible`
// member for `T`, in which case `type` is that member.
template
struct ArenaTraitsConstructible {
  using type = std::false_type;
};

template
struct ArenaTraitsConstructible<
    T, std::void_t::constructible)>> {
  using type = typename ArenaTraits::constructible;
};

// Returns `ptr->GetArena()`, or null when `ptr` is null, for types accepted by
// the `enable_if_t` condition. The overload below returns null for all other
// types.
template
std::enable_if_t::value, google::protobuf::Arena* absl_nullable>
GetArena(const T* absl_nullable ptr) {
  return ptr != nullptr ? ptr->GetArena() : nullptr;
}

template
std::enable_if_t::value, google::protobuf::Arena* absl_nullable>
GetArena([[maybe_unused]] const T* absl_nullable ptr) {
  return nullptr;
}

// Detects whether `ArenaTraits` provides a `trivially_destructible` hook that
// accepts a `T`.
template
struct HasArenaTraitsTriviallyDestructible : std::false_type {};

template
struct HasArenaTraitsTriviallyDestructible<
    T, std::void_t::trivially_destructible(
           std::declval()))>> : std::true_type {};

}  // namespace common_internal

template <>
struct ArenaTraits {
  // True when `ArenaTraits` reports `U` as constructible with the arena as a
  // leading constructor argument. The `AssertArenaType` operand validates `U`.
  template
  using constructible = std::disjunction<
      typename common_internal::AssertArenaType::type,
      typename common_internal::ArenaTraitsConstructible::type>;

  // Compile-time check: true when `U` is `std::is_trivially_destructible`
  // (after `AssertArenaType` validation).
  template
  using always_trivially_destructible =
      std::disjunction::type,
                       std::is_trivially_destructible>;

  // Runtime check: returns true if the destructor of `obj` need not run. The
  // cases are tried in order:
  // 1. The type is always trivially destructible.
  // 2. The type is protobuf-destructor-skippable and `obj` lives on an arena.
  // 3. A detected `ArenaTraits` hook decides.
  // Otherwise the result is false.
  template
  static bool trivially_destructible(const U& obj) {
    static_assert(!std::is_void_v, "T must not be void");
    static_assert(!std::is_reference_v, "T must not be a reference");
    static_assert(!std::is_volatile_v, "T must not be volatile qualified");
    static_assert(!std::is_const_v, "T must not be const qualified");
    static_assert(!std::is_array_v, "T must not be an array");
    if constexpr (always_trivially_destructible()) {
      return true;
    } else if constexpr (google::protobuf::Arena::is_destructor_skippable::value) {
      return obj.GetArena() != nullptr;
    } else if constexpr (common_internal::HasArenaTraitsTriviallyDestructible<
                             U>::value) {
      return ArenaTraits::trivially_destructible(obj);
    } else {
      return false;
    }
  }
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_ARENA_H_

================================================
FILE: common/arena_string.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_H_
#define THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_H_

#include
#include
#include
#include
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/casts.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/strings/string_view.h"
#include "common/arena_string_view.h"
#include "google/protobuf/arena.h"

namespace cel {

class ArenaStringPool;

// Bug in current Abseil LTS. Fixed in
// https://github.com/abseil/abseil-cpp/commit/fd7713cb9a97c49096211ff40de280b6cebbb21c
// which is not yet in an LTS.
// `CEL_ATTRIBUTE_ARENA_STRING_OWNER` expands to `ABSL_ATTRIBUTE_OWNER` only on
// clang 13+ (or clang without `__clang_major__`), and to nothing otherwise.
#if defined(__clang__) && (!defined(__clang_major__) || __clang_major__ >= 13)
#define CEL_ATTRIBUTE_ARENA_STRING_OWNER ABSL_ATTRIBUTE_OWNER
#else
#define CEL_ATTRIBUTE_ARENA_STRING_OWNER
#endif

namespace common_internal {

// Discriminator for `ArenaStringRep`, stored in a 1-bit field.
enum class ArenaStringKind : unsigned int {
  kSmall = 0,
  kLarge,
};

// Inline representation: up to `kArenaStringSmallCapacity` bytes are stored
// directly in `data`, and `size` is a 7-bit field.
// NOTE(review): the `23 - sizeof(Arena*)` sizing presumably targets a 24-byte
// representation; consider locking that in with a `static_assert`.
// The field order is load-bearing (see `ArenaStringRep`); do not reorder.
struct ArenaStringSmallRep final {
  ArenaStringKind kind : 1;
  uint8_t size : 7;
  char data[23 - sizeof(google::protobuf::Arena*)];
  google::protobuf::Arena* absl_nullable arena;
};

// Out-of-line representation: `data` points at storage the string does not
// copy, with a size field that is one bit narrower than `size_t`.
struct ArenaStringLargeRep final {
  ArenaStringKind kind : 1;
  size_t size : sizeof(size_t) * 8 - 1;
  const char* absl_nonnull data;
  google::protobuf::Arena* absl_nullable arena;
};

// Maximum number of bytes stored inline by `ArenaStringSmallRep`.
inline constexpr size_t kArenaStringSmallCapacity =
    sizeof(ArenaStringSmallRep::data);

// Storage for `ArenaString`. Every alternative begins with the same 1-bit
// `kind` field. `ArenaString` reads it through the leading anonymous struct to
// decide which alternative is active.
union ArenaStringRep final {
  struct {
    ArenaStringKind kind : 1;
  };
  ArenaStringSmallRep small;
  ArenaStringLargeRep large;
};

}  // namespace common_internal

// `ArenaString` is a read-only string which is either backed by a static string
// literal or owned by the `ArenaStringPool` that created it. It is compatible
// with `absl::string_view` and is implicitly convertible to it. Strings of at
// most `common_internal::kArenaStringSmallCapacity` bytes are copied inline;
// longer strings are referenced, not copied.
class CEL_ATTRIBUTE_ARENA_STRING_OWNER ArenaString final { public: using traits_type = std::char_traits; using value_type = char; using pointer = char*; using const_pointer = const char*; using reference = char&; using const_reference = const char&; using const_iterator = const_pointer; using iterator = const_iterator; using const_reverse_iterator = std::reverse_iterator; using reverse_iterator = const_reverse_iterator; using size_type = size_t; using difference_type = ptrdiff_t; using absl_internal_is_view = std::false_type; ArenaString() : ArenaString(static_cast(nullptr)) {} ArenaString(const ArenaString&) = default; ArenaString& operator=(const ArenaString&) = default; explicit ArenaString( google::protobuf::Arena* absl_nullable arena ABSL_ATTRIBUTE_LIFETIME_BOUND) : ArenaString(absl::string_view(), arena) {} ArenaString(std::nullptr_t) = delete; ArenaString(absl::string_view string, google::protobuf::Arena* absl_nullable arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { if (string.size() <= common_internal::kArenaStringSmallCapacity) { rep_.small.kind = common_internal::ArenaStringKind::kSmall; rep_.small.size = string.size(); std::memcpy(rep_.small.data, string.data(), string.size()); rep_.small.arena = arena; } else { rep_.large.kind = common_internal::ArenaStringKind::kLarge; rep_.large.size = string.size(); rep_.large.data = string.data(); rep_.large.arena = arena; } } ArenaString(absl::string_view, std::nullptr_t) = delete; explicit ArenaString(ArenaStringView other) : ArenaString(absl::implicit_cast(other), other.arena()) {} google::protobuf::Arena* absl_nullable arena() const { switch (rep_.kind) { case common_internal::ArenaStringKind::kSmall: return rep_.small.arena; case common_internal::ArenaStringKind::kLarge: return rep_.large.arena; } } size_type size() const { switch (rep_.kind) { case common_internal::ArenaStringKind::kSmall: return rep_.small.size; case common_internal::ArenaStringKind::kLarge: return rep_.large.size; } } bool empty() const { return 
size() == 0; } size_type max_size() const { return std::numeric_limits::max() >> 1; } absl_nonnull const_pointer data() const ABSL_ATTRIBUTE_LIFETIME_BOUND { switch (rep_.kind) { case common_internal::ArenaStringKind::kSmall: return rep_.small.data; case common_internal::ArenaStringKind::kLarge: return rep_.large.data; } } const_reference front() const ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK(!empty()); return data()[0]; } const_reference back() const ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK(!empty()); return data()[size() - 1]; } const_reference operator[](size_type index) const ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK_LT(index, size()); return data()[index]; } void remove_prefix(size_type n) { ABSL_DCHECK_LE(n, size()); switch (rep_.kind) { case common_internal::ArenaStringKind::kSmall: std::memmove(rep_.small.data, rep_.small.data + n, rep_.small.size - n); rep_.small.size = rep_.small.size - n; break; case common_internal::ArenaStringKind::kLarge: rep_.large.data += n; rep_.large.size = rep_.large.size - n; break; } } void remove_suffix(size_type n) { ABSL_DCHECK_LE(n, size()); switch (rep_.kind) { case common_internal::ArenaStringKind::kSmall: rep_.small.size = rep_.small.size - n; break; case common_internal::ArenaStringKind::kLarge: rep_.large.size = rep_.large.size - n; break; } } const_iterator begin() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return data(); } const_iterator cbegin() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return begin(); } const_iterator end() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return data() + size(); } const_iterator cend() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return end(); } const_reverse_iterator rbegin() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::make_reverse_iterator(end()); } const_reverse_iterator crbegin() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return rbegin(); } const_reverse_iterator rend() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::make_reverse_iterator(begin()); } const_reverse_iterator crend() const 
ABSL_ATTRIBUTE_LIFETIME_BOUND { return rend(); } private: friend class ArenaStringView; common_internal::ArenaStringRep rep_; }; inline ArenaStringView::ArenaStringView( const ArenaString& arena_string ABSL_ATTRIBUTE_LIFETIME_BOUND) { switch (arena_string.rep_.kind) { case common_internal::ArenaStringKind::kSmall: string_ = absl::string_view(arena_string.rep_.small.data, arena_string.rep_.small.size); arena_ = arena_string.rep_.small.arena; break; case common_internal::ArenaStringKind::kLarge: string_ = absl::string_view(arena_string.rep_.large.data, arena_string.rep_.large.size); arena_ = arena_string.rep_.large.arena; break; } } inline ArenaStringView& ArenaStringView::operator=( const ArenaString& arena_string ABSL_ATTRIBUTE_LIFETIME_BOUND) { switch (arena_string.rep_.kind) { case common_internal::ArenaStringKind::kSmall: string_ = absl::string_view(arena_string.rep_.small.data, arena_string.rep_.small.size); arena_ = arena_string.rep_.small.arena; break; case common_internal::ArenaStringKind::kLarge: string_ = absl::string_view(arena_string.rep_.large.data, arena_string.rep_.large.size); arena_ = arena_string.rep_.large.arena; break; } return *this; } inline bool operator==(const ArenaString& lhs, const ArenaString& rhs) { return absl::implicit_cast(lhs) == absl::implicit_cast(rhs); } inline bool operator==(const ArenaString& lhs, absl::string_view rhs) { return absl::implicit_cast(lhs) == rhs; } inline bool operator==(absl::string_view lhs, const ArenaString& rhs) { return lhs == absl::implicit_cast(rhs); } inline bool operator!=(const ArenaString& lhs, const ArenaString& rhs) { return absl::implicit_cast(lhs) != absl::implicit_cast(rhs); } inline bool operator!=(const ArenaString& lhs, absl::string_view rhs) { return absl::implicit_cast(lhs) != rhs; } inline bool operator!=(absl::string_view lhs, const ArenaString& rhs) { return lhs != absl::implicit_cast(rhs); } inline bool operator<(const ArenaString& lhs, const ArenaString& rhs) { return 
absl::implicit_cast(lhs) < absl::implicit_cast(rhs); } inline bool operator<(const ArenaString& lhs, absl::string_view rhs) { return absl::implicit_cast(lhs) < rhs; } inline bool operator<(absl::string_view lhs, const ArenaString& rhs) { return lhs < absl::implicit_cast(rhs); } inline bool operator<=(const ArenaString& lhs, const ArenaString& rhs) { return absl::implicit_cast(lhs) <= absl::implicit_cast(rhs); } inline bool operator<=(const ArenaString& lhs, absl::string_view rhs) { return absl::implicit_cast(lhs) <= rhs; } inline bool operator<=(absl::string_view lhs, const ArenaString& rhs) { return lhs <= absl::implicit_cast(rhs); } inline bool operator>(const ArenaString& lhs, const ArenaString& rhs) { return absl::implicit_cast(lhs) > absl::implicit_cast(rhs); } inline bool operator>(const ArenaString& lhs, absl::string_view rhs) { return absl::implicit_cast(lhs) > rhs; } inline bool operator>(absl::string_view lhs, const ArenaString& rhs) { return lhs > absl::implicit_cast(rhs); } inline bool operator>=(const ArenaString& lhs, const ArenaString& rhs) { return absl::implicit_cast(lhs) >= absl::implicit_cast(rhs); } inline bool operator>=(const ArenaString& lhs, absl::string_view rhs) { return absl::implicit_cast(lhs) >= rhs; } inline bool operator>=(absl::string_view lhs, const ArenaString& rhs) { return lhs >= absl::implicit_cast(rhs); } template H AbslHashValue(H state, const ArenaString& arena_string) { return H::combine(std::move(state), absl::implicit_cast(arena_string)); } #undef CEL_ATTRIBUTE_ARENA_STRING_OWNER } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_H_ ================================================ FILE: common/arena_string_pool.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_POOL_H_ #define THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_POOL_H_ #include #include #include #include "absl/base/attributes.h" #include "absl/base/casts.h" #include "absl/base/nullability.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "common/arena_string_view.h" #include "internal/string_pool.h" #include "google/protobuf/arena.h" namespace cel { class ArenaStringPool; absl_nonnull std::unique_ptr NewArenaStringPool( google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); class ArenaStringPool final { public: ArenaStringPool(const ArenaStringPool&) = delete; ArenaStringPool(ArenaStringPool&&) = delete; ArenaStringPool& operator=(const ArenaStringPool&) = delete; ArenaStringPool& operator=(ArenaStringPool&&) = delete; ArenaStringView InternString(const char* absl_nullable string) { return ArenaStringView(strings_.InternString(string), strings_.arena()); } ArenaStringView InternString(absl::string_view string) { return ArenaStringView(strings_.InternString(string), strings_.arena()); } ArenaStringView InternString(std::string&& string) { return ArenaStringView(strings_.InternString(std::move(string)), strings_.arena()); } ArenaStringView InternString(const absl::Cord& string) { return ArenaStringView(strings_.InternString(string), strings_.arena()); } ArenaStringView InternString(ArenaStringView string) { if (string.arena() == strings_.arena()) { return string; } return InternString(absl::implicit_cast(string)); } private: friend absl_nonnull std::unique_ptr 
NewArenaStringPool( google::protobuf::Arena* absl_nonnull); explicit ArenaStringPool(google::protobuf::Arena* absl_nonnull arena) : strings_(arena) {} internal::StringPool strings_; }; inline absl_nonnull std::unique_ptr NewArenaStringPool( google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { return std::unique_ptr(new ArenaStringPool(arena)); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_POOL_H_ ================================================ FILE: common/arena_string_pool_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/arena_string_pool.h"

#include <string>

#include "absl/strings/cord_test_helpers.h"
#include "absl/strings/string_view.h"
#include "common/arena_string_view.h"
#include "internal/testing.h"
#include "google/protobuf/arena.h"

namespace cel {
namespace {

TEST(ArenaStringPool, InternCString) {
  google::protobuf::Arena arena;
  auto string_pool = NewArenaStringPool(&arena);
  auto expected = string_pool->InternString("Hello World!");
  auto got = string_pool->InternString("Hello World!");
  EXPECT_EQ(expected.data(), got.data());
}

TEST(ArenaStringPool, InternStringView) {
  google::protobuf::Arena arena;
  auto string_pool = NewArenaStringPool(&arena);
  auto expected = string_pool->InternString(absl::string_view("Hello World!"));
  auto got = string_pool->InternString("Hello World!");
  EXPECT_EQ(expected.data(), got.data());
}

TEST(ArenaStringPool, InternStringSmall) {
  google::protobuf::Arena arena;
  auto string_pool = NewArenaStringPool(&arena);
  auto expected = string_pool->InternString(std::string("Hello World!"));
  auto got = string_pool->InternString("Hello World!");
  EXPECT_EQ(expected.data(), got.data());
}

TEST(ArenaStringPool, InternStringLarge) {
  google::protobuf::Arena arena;
  auto string_pool = NewArenaStringPool(&arena);
  auto expected = string_pool->InternString(
      std::string("This string is larger than std::string itself!"));
  auto got = string_pool->InternString(
      "This string is larger than std::string itself!");
  EXPECT_EQ(expected.data(), got.data());
}

TEST(ArenaStringPool, InternCord) {
  google::protobuf::Arena arena;
  auto string_pool = NewArenaStringPool(&arena);
  auto expected = string_pool->InternString(absl::MakeFragmentedCord(
      {"This string is larger", " ", "than absl::Cord itself!"}));
  auto got = string_pool->InternString(
      "This string is larger than absl::Cord itself!");
  EXPECT_EQ(expected.data(), got.data());
}

// Re-interning a view that the pool itself produced yields the same data.
TEST(ArenaStringPool, InternArenaStringViewFromSamePool) {
  google::protobuf::Arena arena;
  auto string_pool = NewArenaStringPool(&arena);
  auto expected = string_pool->InternString("Hello World!");
  auto got = string_pool->InternString(expected);
  EXPECT_EQ(expected.data(), got.data());
  EXPECT_EQ(expected.arena(), got.arena());
}

// A view associated with another arena resolves to the pool's instance.
TEST(ArenaStringPool, InternArenaStringViewFromOtherArena) {
  google::protobuf::Arena arena;
  google::protobuf::Arena other_arena;
  auto string_pool = NewArenaStringPool(&arena);
  auto expected = string_pool->InternString("Hello World!");
  auto got =
      string_pool->InternString(ArenaStringView("Hello World!", &other_arena));
  EXPECT_EQ(expected.data(), got.data());
  EXPECT_EQ(expected.arena(), got.arena());
}

}  // namespace
}  // namespace cel

================================================
FILE: common/arena_string_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "common/arena_string.h"

#include "absl/base/nullability.h"
#include "absl/hash/hash.h"
#include "absl/hash/hash_testing.h"
#include "absl/strings/string_view.h"
#include "internal/testing.h"
#include "google/protobuf/arena.h"

namespace cel {
namespace {

using ::testing::Eq;
using ::testing::Ge;
using ::testing::Gt;
using ::testing::IsEmpty;
using ::testing::Le;
using ::testing::Lt;
using ::testing::Ne;
using ::testing::Not;
using ::testing::NotNull;
using ::testing::SizeIs;

class ArenaStringTest : public ::testing::Test {
 protected:
  google::protobuf::Arena* absl_nonnull arena() { return &arena_; }

 private:
  google::protobuf::Arena arena_;
};

TEST_F(ArenaStringTest, Default) {
  ArenaString string;
  EXPECT_THAT(string, IsEmpty());
  EXPECT_THAT(string, SizeIs(0));
  EXPECT_THAT(string, Eq(ArenaString()));
}

TEST_F(ArenaStringTest, Small) {
  static constexpr absl::string_view kSmall = "Hello World!";

  ArenaString string(kSmall, arena());
  EXPECT_THAT(string, Not(IsEmpty()));
  EXPECT_THAT(string, SizeIs(kSmall.size()));
  EXPECT_THAT(string.data(), NotNull());
  EXPECT_THAT(string, Eq(kSmall));
}

TEST_F(ArenaStringTest, Large) {
  static constexpr absl::string_view kLarge =
      "This string is larger than the inline storage!";

  ArenaString string(kLarge, arena());
  EXPECT_THAT(string, Not(IsEmpty()));
  EXPECT_THAT(string, SizeIs(kLarge.size()));
  EXPECT_THAT(string.data(), NotNull());
  EXPECT_THAT(string, Eq(kLarge));
}

TEST_F(ArenaStringTest, Iterator) {
  ArenaString string = ArenaString("Hello World!", arena());
  auto it = string.cbegin();
  EXPECT_THAT(*it++, Eq('H'));
  EXPECT_THAT(*it++, Eq('e'));
  EXPECT_THAT(*it++, Eq('l'));
  EXPECT_THAT(*it++, Eq('l'));
  EXPECT_THAT(*it++, Eq('o'));
  EXPECT_THAT(*it++, Eq(' '));
  EXPECT_THAT(*it++, Eq('W'));
  EXPECT_THAT(*it++, Eq('o'));
  EXPECT_THAT(*it++, Eq('r'));
  EXPECT_THAT(*it++, Eq('l'));
  EXPECT_THAT(*it++, Eq('d'));
  EXPECT_THAT(*it++, Eq('!'));
  EXPECT_THAT(it, Eq(string.cend()));
}

TEST_F(ArenaStringTest, ReverseIterator) {
  ArenaString string = ArenaString("Hello World!", arena());
  auto it = string.crbegin();
  EXPECT_THAT(*it++, Eq('!'));
  EXPECT_THAT(*it++, Eq('d'));
  EXPECT_THAT(*it++, Eq('l'));
  EXPECT_THAT(*it++, Eq('r'));
  EXPECT_THAT(*it++, Eq('o'));
  EXPECT_THAT(*it++, Eq('W'));
  EXPECT_THAT(*it++, Eq(' '));
  EXPECT_THAT(*it++, Eq('o'));
  EXPECT_THAT(*it++, Eq('l'));
  EXPECT_THAT(*it++, Eq('l'));
  EXPECT_THAT(*it++, Eq('e'));
  EXPECT_THAT(*it++, Eq('H'));
  EXPECT_THAT(it, Eq(string.crend()));
}

TEST_F(ArenaStringTest, RemovePrefix) {
  ArenaString string = ArenaString("Hello World!", arena());
  string.remove_prefix(6);
  EXPECT_EQ(string, "World!");
}

TEST_F(ArenaStringTest, RemoveSuffix) {
  ArenaString string = ArenaString("Hello World!", arena());
  string.remove_suffix(7);
  EXPECT_EQ(string, "Hello");
}

TEST_F(ArenaStringTest, Equal) {
  EXPECT_THAT(ArenaString("1", arena()), Eq(ArenaString("1", arena())));
}

TEST_F(ArenaStringTest, NotEqual) {
  EXPECT_THAT(ArenaString("1", arena()), Ne(ArenaString("2", arena())));
}

TEST_F(ArenaStringTest, Less) {
  EXPECT_THAT(ArenaString("1", arena()), Lt(ArenaString("2", arena())));
}

TEST_F(ArenaStringTest, LessEqual) {
  EXPECT_THAT(ArenaString("1", arena()), Le(ArenaString("1", arena())));
}

TEST_F(ArenaStringTest, Greater) {
  EXPECT_THAT(ArenaString("2", arena()), Gt(ArenaString("1", arena())));
}

TEST_F(ArenaStringTest, GreaterEqual) {
  EXPECT_THAT(ArenaString("1", arena()), Ge(ArenaString("1", arena())));
}

TEST_F(ArenaStringTest, ImplementsAbslHashCorrectly) {
  EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly(
      {ArenaString("", arena()), ArenaString("Hello World!", arena()),
       ArenaString("How much wood could a woodchuck chuck if a "
                   "woodchuck could chuck wood?",
                   arena())}));
}

TEST_F(ArenaStringTest, Hash) {
  EXPECT_EQ(absl::HashOf(ArenaString("Hello World!", arena())),
            absl::HashOf(absl::string_view("Hello World!")));
}

}  // namespace
}  // namespace cel

================================================
FILE: common/arena_string_view.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_VIEW_H_
#define THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_VIEW_H_

#include <cstddef>
#include <limits>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/base/casts.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/arena.h"

namespace cel {

class ArenaString;

// Bug in current Abseil LTS. Fixed in
// https://github.com/abseil/abseil-cpp/commit/fd7713cb9a97c49096211ff40de280b6cebbb21c
// which is not yet in an LTS.
#if defined(__clang__) && (!defined(__clang_major__) || __clang_major__ >= 13) #define CEL_ATTRIBUTE_ARENA_STRING_VIEW ABSL_ATTRIBUTE_VIEW #else #define CEL_ATTRIBUTE_ARENA_STRING_VIEW #endif class CEL_ATTRIBUTE_ARENA_STRING_VIEW ArenaStringView final { public: using traits_type = std::char_traits; using value_type = char; using pointer = char*; using const_pointer = const char*; using reference = char&; using const_reference = const char&; using const_iterator = typename absl::string_view::const_pointer; using iterator = typename absl::string_view::const_iterator; using const_reverse_iterator = typename absl::string_view::const_reverse_iterator; using reverse_iterator = typename absl::string_view::reverse_iterator; using size_type = size_t; using difference_type = ptrdiff_t; using absl_internal_is_view = std::true_type; ArenaStringView() = default; ArenaStringView(const ArenaStringView&) = default; ArenaStringView& operator=(const ArenaStringView&) = default; // NOLINTNEXTLINE(google-explicit-constructor) ArenaStringView( const ArenaString& arena_string ABSL_ATTRIBUTE_LIFETIME_BOUND); // NOLINTNEXTLINE(google-explicit-constructor) ArenaStringView& operator=( const ArenaString& arena_string ABSL_ATTRIBUTE_LIFETIME_BOUND); ArenaStringView& operator=(ArenaString&&) = delete; explicit ArenaStringView( google::protobuf::Arena* absl_nullable arena ABSL_ATTRIBUTE_LIFETIME_BOUND) : arena_(arena) {} ArenaStringView(std::nullptr_t) = delete; ArenaStringView(absl::string_view string ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nullable arena ABSL_ATTRIBUTE_LIFETIME_BOUND) : string_(string), arena_(arena) {} ArenaStringView(absl::string_view, std::nullptr_t) = delete; google::protobuf::Arena* absl_nullable arena() const { return arena_; } size_type size() const { return string_.size(); } bool empty() const { return string_.empty(); } size_type max_size() const { return std::numeric_limits::max() >> 1; } absl_nonnull const_pointer data() const { return 
string_.data(); } const_reference front() const { ABSL_DCHECK(!empty()); return string_.front(); } const_reference back() const { ABSL_DCHECK(!empty()); return string_.back(); } const_reference operator[](size_type index) const { ABSL_DCHECK_LT(index, size()); return string_[index]; } void remove_prefix(size_type n) { ABSL_DCHECK_LE(n, size()); string_.remove_prefix(n); } void remove_suffix(size_type n) { ABSL_DCHECK_LE(n, size()); string_.remove_suffix(n); } const_iterator begin() const { return string_.begin(); } const_iterator cbegin() const { return string_.cbegin(); } const_iterator end() const { return string_.end(); } const_iterator cend() const { return string_.cend(); } const_reverse_iterator rbegin() const { return string_.rbegin(); } const_reverse_iterator crbegin() const { return string_.crbegin(); } const_reverse_iterator rend() const { return string_.rend(); } const_reverse_iterator crend() const { return string_.crend(); } // NOLINTNEXTLINE(google-explicit-constructor) operator absl::string_view() const { return string_; } private: absl::string_view string_; google::protobuf::Arena* absl_nullable arena_ = nullptr; }; inline bool operator==(ArenaStringView lhs, ArenaStringView rhs) { return absl::implicit_cast(lhs) == absl::implicit_cast(rhs); } inline bool operator==(ArenaStringView lhs, absl::string_view rhs) { return absl::implicit_cast(lhs) == rhs; } inline bool operator==(absl::string_view lhs, ArenaStringView rhs) { return lhs == absl::implicit_cast(rhs); } inline bool operator!=(ArenaStringView lhs, ArenaStringView rhs) { return absl::implicit_cast(lhs) != absl::implicit_cast(rhs); } inline bool operator!=(ArenaStringView lhs, absl::string_view rhs) { return absl::implicit_cast(lhs) != rhs; } inline bool operator!=(absl::string_view lhs, ArenaStringView rhs) { return lhs != absl::implicit_cast(rhs); } inline bool operator<(ArenaStringView lhs, ArenaStringView rhs) { return absl::implicit_cast(lhs) < absl::implicit_cast(rhs); } inline bool 
operator<(ArenaStringView lhs, absl::string_view rhs) { return absl::implicit_cast(lhs) < rhs; } inline bool operator<(absl::string_view lhs, ArenaStringView rhs) { return lhs < absl::implicit_cast(rhs); } inline bool operator<=(ArenaStringView lhs, ArenaStringView rhs) { return absl::implicit_cast(lhs) <= absl::implicit_cast(rhs); } inline bool operator<=(ArenaStringView lhs, absl::string_view rhs) { return absl::implicit_cast(lhs) <= rhs; } inline bool operator<=(absl::string_view lhs, ArenaStringView rhs) { return lhs <= absl::implicit_cast(rhs); } inline bool operator>(ArenaStringView lhs, ArenaStringView rhs) { return absl::implicit_cast(lhs) > absl::implicit_cast(rhs); } inline bool operator>(ArenaStringView lhs, absl::string_view rhs) { return absl::implicit_cast(lhs) > rhs; } inline bool operator>(absl::string_view lhs, ArenaStringView rhs) { return lhs > absl::implicit_cast(rhs); } inline bool operator>=(ArenaStringView lhs, ArenaStringView rhs) { return absl::implicit_cast(lhs) >= absl::implicit_cast(rhs); } inline bool operator>=(ArenaStringView lhs, absl::string_view rhs) { return absl::implicit_cast(lhs) >= rhs; } inline bool operator>=(absl::string_view lhs, ArenaStringView rhs) { return lhs >= absl::implicit_cast(rhs); } template H AbslHashValue(H state, ArenaStringView arena_string_view) { return H::combine(std::move(state), absl::implicit_cast(arena_string_view)); } #undef CEL_ATTRIBUTE_ARENA_STRING_VIEW } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_ARENA_STRING_VIEW_H_ ================================================ FILE: common/arena_string_view_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/arena_string_view.h" #include "absl/base/nullability.h" #include "absl/hash/hash.h" #include "absl/hash/hash_testing.h" #include "absl/strings/string_view.h" #include "internal/testing.h" #include "google/protobuf/arena.h" namespace cel { namespace { using ::testing::Eq; using ::testing::Ge; using ::testing::Gt; using ::testing::IsEmpty; using ::testing::Le; using ::testing::Lt; using ::testing::Ne; using ::testing::SizeIs; class ArenaStringViewTest : public ::testing::Test { protected: google::protobuf::Arena* absl_nonnull arena() { return &arena_; } private: google::protobuf::Arena arena_; }; TEST_F(ArenaStringViewTest, Default) { ArenaStringView string; EXPECT_THAT(string, IsEmpty()); EXPECT_THAT(string, SizeIs(0)); EXPECT_THAT(string, Eq(ArenaStringView())); } TEST_F(ArenaStringViewTest, Iterator) { ArenaStringView string = ArenaStringView("Hello World!", arena()); auto it = string.cbegin(); EXPECT_THAT(*it++, Eq('H')); EXPECT_THAT(*it++, Eq('e')); EXPECT_THAT(*it++, Eq('l')); EXPECT_THAT(*it++, Eq('l')); EXPECT_THAT(*it++, Eq('o')); EXPECT_THAT(*it++, Eq(' ')); EXPECT_THAT(*it++, Eq('W')); EXPECT_THAT(*it++, Eq('o')); EXPECT_THAT(*it++, Eq('r')); EXPECT_THAT(*it++, Eq('l')); EXPECT_THAT(*it++, Eq('d')); EXPECT_THAT(*it++, Eq('!')); EXPECT_THAT(it, Eq(string.cend())); } TEST_F(ArenaStringViewTest, ReverseIterator) { ArenaStringView string = ArenaStringView("Hello World!", arena()); auto it = string.crbegin(); EXPECT_THAT(*it++, Eq('!')); EXPECT_THAT(*it++, Eq('d')); EXPECT_THAT(*it++, Eq('l')); EXPECT_THAT(*it++, Eq('r')); 
EXPECT_THAT(*it++, Eq('o')); EXPECT_THAT(*it++, Eq('W')); EXPECT_THAT(*it++, Eq(' ')); EXPECT_THAT(*it++, Eq('o')); EXPECT_THAT(*it++, Eq('l')); EXPECT_THAT(*it++, Eq('l')); EXPECT_THAT(*it++, Eq('e')); EXPECT_THAT(*it++, Eq('H')); EXPECT_THAT(it, Eq(string.crend())); } TEST_F(ArenaStringViewTest, RemovePrefix) { ArenaStringView string = ArenaStringView("Hello World!", arena()); string.remove_prefix(6); EXPECT_EQ(string, "World!"); } TEST_F(ArenaStringViewTest, RemoveSuffix) { ArenaStringView string = ArenaStringView("Hello World!", arena()); string.remove_suffix(7); EXPECT_EQ(string, "Hello"); } TEST_F(ArenaStringViewTest, Equal) { EXPECT_THAT(ArenaStringView("1", arena()), Eq(ArenaStringView("1", arena()))); } TEST_F(ArenaStringViewTest, NotEqual) { EXPECT_THAT(ArenaStringView("1", arena()), Ne(ArenaStringView("2", arena()))); } TEST_F(ArenaStringViewTest, Less) { EXPECT_THAT(ArenaStringView("1", arena()), Lt(ArenaStringView("2", arena()))); } TEST_F(ArenaStringViewTest, LessEqual) { EXPECT_THAT(ArenaStringView("1", arena()), Le(ArenaStringView("1", arena()))); } TEST_F(ArenaStringViewTest, Greater) { EXPECT_THAT(ArenaStringView("2", arena()), Gt(ArenaStringView("1", arena()))); } TEST_F(ArenaStringViewTest, GreaterEqual) { EXPECT_THAT(ArenaStringView("1", arena()), Ge(ArenaStringView("1", arena()))); } TEST_F(ArenaStringViewTest, ImplementsAbslHashCorrectly) { EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( {ArenaStringView("", arena()), ArenaStringView("Hello World!", arena()), ArenaStringView("How much wood could a woodchuck chuck if a " "woodchuck could chuck wood?", arena())})); } TEST_F(ArenaStringViewTest, Hash) { EXPECT_EQ(absl::HashOf(ArenaStringView("Hello World!", arena())), absl::HashOf(absl::string_view("Hello World!"))); } } // namespace } // namespace cel ================================================ FILE: common/ast/BUILD ================================================ load("@rules_cc//cc:cc_library.bzl", "cc_library") 
load("@rules_cc//cc:cc_test.bzl", "cc_test") # Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Internal AST implementation and utilities # These are needed by various parts of the CEL-C++ library, but are not intended for public use at # this time. package(default_visibility = ["//visibility:public"]) cc_library( name = "constant_proto", srcs = ["constant_proto.cc"], hdrs = ["constant_proto.h"], deps = [ "//common:constant", "//internal:proto_time_encoding", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:overload", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:variant", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:struct_cc_proto", ], ) cc_library( name = "expr_proto", srcs = ["expr_proto.cc"], hdrs = ["expr_proto.h"], deps = [ ":constant_proto", "//common:constant", "//common:expr", "//internal:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:overload", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:struct_cc_proto", ], ) cc_test( name = "expr_proto_test", srcs = ["expr_proto_test.cc"], deps = [ ":expr_proto", "//common:expr", "//internal:proto_matchers", "//internal:testing", 
"@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "source_info_proto", srcs = ["source_info_proto.cc"], hdrs = ["source_info_proto.h"], deps = [ ":expr_proto", "//common:ast", "//internal:status_macros", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:duration_cc_proto", "@com_google_protobuf//:struct_cc_proto", "@com_google_protobuf//:timestamp_cc_proto", ], ) cc_library( name = "metadata", srcs = ["metadata.cc"], hdrs = ["metadata.h"], deps = [ "//common:constant", "//common:expr", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:overload", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", ], ) cc_test( name = "metadata_test", srcs = ["metadata_test.cc"], deps = [ ":metadata", "//common:expr", "//internal:testing", "@com_google_absl//absl/types:variant", ], ) cc_library( name = "navigable_ast_internal", srcs = ["navigable_ast_kinds.cc"], hdrs = [ "navigable_ast_internal.h", "navigable_ast_kinds.h", ], deps = [ "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", ], ) cc_test( name = "navigable_ast_internal_test", srcs = ["navigable_ast_internal_test.cc"], deps = [ ":navigable_ast_internal", "//internal:testing", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", 
"@com_google_absl//absl/types:span", ], ) ================================================ FILE: common/ast/constant_proto.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/ast/constant_proto.h" #include #include #include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "absl/base/nullability.h" #include "absl/functional/overload.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/time/time.h" #include "absl/types/variant.h" #include "common/constant.h" #include "internal/proto_time_encoding.h" namespace cel::ast_internal { using ConstantProto = cel::expr::Constant; absl::Status ConstantToProto(const Constant& constant, ConstantProto* absl_nonnull proto) { return absl::visit(absl::Overload( [proto](absl::monostate) -> absl::Status { proto->clear_constant_kind(); return absl::OkStatus(); }, [proto](std::nullptr_t) -> absl::Status { proto->set_null_value(google::protobuf::NULL_VALUE); return absl::OkStatus(); }, [proto](bool value) -> absl::Status { proto->set_bool_value(value); return absl::OkStatus(); }, [proto](int64_t value) -> absl::Status { proto->set_int64_value(value); return absl::OkStatus(); }, [proto](uint64_t value) -> absl::Status { proto->set_uint64_value(value); return absl::OkStatus(); }, [proto](double value) -> absl::Status { proto->set_double_value(value); return absl::OkStatus(); }, [proto](const 
BytesConstant& value) -> absl::Status { proto->set_bytes_value(value); return absl::OkStatus(); }, [proto](const StringConstant& value) -> absl::Status { proto->set_string_value(value); return absl::OkStatus(); }, [proto](absl::Duration value) -> absl::Status { return internal::EncodeDuration( value, proto->mutable_duration_value()); }, [proto](absl::Time value) -> absl::Status { return internal::EncodeTime( value, proto->mutable_timestamp_value()); }), constant.kind()); } absl::Status ConstantFromProto(const ConstantProto& proto, Constant& constant) { switch (proto.constant_kind_case()) { case ConstantProto::CONSTANT_KIND_NOT_SET: constant = Constant{}; break; case ConstantProto::kNullValue: constant.set_null_value(); break; case ConstantProto::kBoolValue: constant.set_bool_value(proto.bool_value()); break; case ConstantProto::kInt64Value: constant.set_int_value(proto.int64_value()); break; case ConstantProto::kUint64Value: constant.set_uint_value(proto.uint64_value()); break; case ConstantProto::kDoubleValue: constant.set_double_value(proto.double_value()); break; case ConstantProto::kStringValue: constant.set_string_value(proto.string_value()); break; case ConstantProto::kBytesValue: constant.set_bytes_value(proto.bytes_value()); break; case ConstantProto::kDurationValue: constant.set_duration_value( internal::DecodeDuration(proto.duration_value())); break; case ConstantProto::kTimestampValue: constant.set_timestamp_value( internal::DecodeTime(proto.timestamp_value())); break; default: return absl::InvalidArgumentError( absl::StrCat("unexpected ConstantKindCase: ", static_cast(proto.constant_kind_case()))); } return absl::OkStatus(); } } // namespace cel::ast_internal ================================================ FILE: common/ast/constant_proto.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_CONSTANT_PROTO_H_ #define THIRD_PARTY_CEL_CPP_COMMON_AST_CONSTANT_PROTO_H_ #include "cel/expr/syntax.pb.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "common/constant.h" namespace cel::ast_internal { // `ConstantToProto` converts from native `Constant` to its protocol buffer // message equivalent. absl::Status ConstantToProto(const Constant& constant, cel::expr::Constant* absl_nonnull proto); // `ConstantToProto` converts to native `Constant` from its protocol buffer // message equivalent. absl::Status ConstantFromProto(const cel::expr::Constant& proto, Constant& constant); } // namespace cel::ast_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_AST_CONSTANT_PROTO_H_ ================================================ FILE: common/ast/expr_proto.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/ast/expr_proto.h" #include #include #include #include #include #include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/functional/overload.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/types/variant.h" #include "common/ast/constant_proto.h" #include "common/constant.h" #include "common/expr.h" #include "internal/status_macros.h" namespace cel::ast_internal { namespace { using ExprProto = cel::expr::Expr; using ConstantProto = cel::expr::Constant; using StructExprProto = cel::expr::Expr::CreateStruct; /* NOTE(review): template argument lists appear to have been stripped from this copy of the file (header names after bare #include, static_cast(...), std::stack> frames_); the code is reproduced unchanged. */ /* ExprToProtoState converts a native Expr tree into a cel::expr::Expr proto without recursion. ExprToProto seeds an explicit stack with the root (expr, proto) pair; each ExprToProtoImpl step fills in one proto node and pushes that node's child expressions as new frames. The first non-OK status stops the walk and is returned. */ class ExprToProtoState final { private: struct Frame final { const Expr* absl_nonnull expr; cel::expr::Expr* absl_nonnull proto; }; public: absl::Status ExprToProto(const Expr& expr, cel::expr::Expr* absl_nonnull proto) { Push(expr, proto); Frame frame; while (Pop(frame)) { CEL_RETURN_IF_ERROR(ExprToProtoImpl(*frame.expr, frame.proto)); } return absl::OkStatus(); } private: /* Dispatches on expr.kind(); an UnspecifiedExpr becomes a proto carrying only the id. */ absl::Status ExprToProtoImpl(const Expr& expr, cel::expr::Expr* absl_nonnull proto) { return absl::visit( absl::Overload( [&expr, proto](const UnspecifiedExpr&) -> absl::Status { proto->Clear(); proto->set_id(expr.id()); return absl::OkStatus(); }, [this, &expr, proto](const Constant& const_expr) -> absl::Status { return ConstExprToProto(expr, const_expr, proto); }, [this, &expr, proto](const IdentExpr& ident_expr) -> absl::Status { return IdentExprToProto(expr, ident_expr, proto); }, [this, &expr, proto](const SelectExpr& select_expr) -> absl::Status { return SelectExprToProto(expr, select_expr, proto); }, [this, &expr, proto](const CallExpr& call_expr) -> absl::Status { return CallExprToProto(expr, call_expr, proto); }, [this, &expr, proto](const ListExpr& list_expr) -> absl::Status { return ListExprToProto(expr, list_expr, proto); }, [this, &expr, proto](const StructExpr& struct_expr) -> absl::Status { return StructExprToProto(expr, struct_expr,
proto); }, [this, &expr, proto](const MapExpr& map_expr) -> absl::Status { return MapExprToProto(expr, map_expr, proto); }, [this, &expr, proto]( const ComprehensionExpr& comprehension_expr) -> absl::Status { return ComprehensionExprToProto(expr, comprehension_expr, proto); }), expr.kind()); } /* Each *ExprToProto helper clears `proto`, copies the id and scalar fields, and defers child expressions to Push instead of converting them recursively. */ absl::Status ConstExprToProto(const Expr& expr, const Constant& const_expr, ExprProto* absl_nonnull proto) { proto->Clear(); proto->set_id(expr.id()); return ConstantToProto(const_expr, proto->mutable_const_expr()); } absl::Status IdentExprToProto(const Expr& expr, const IdentExpr& ident_expr, ExprProto* absl_nonnull proto) { proto->Clear(); auto* ident_proto = proto->mutable_ident_expr(); proto->set_id(expr.id()); ident_proto->set_name(ident_expr.name()); return absl::OkStatus(); } absl::Status SelectExprToProto(const Expr& expr, const SelectExpr& select_expr, ExprProto* absl_nonnull proto) { proto->Clear(); auto* select_proto = proto->mutable_select_expr(); proto->set_id(expr.id()); if (select_expr.has_operand()) { Push(select_expr.operand(), select_proto->mutable_operand()); } select_proto->set_field(select_expr.field()); select_proto->set_test_only(select_expr.test_only()); return absl::OkStatus(); } absl::Status CallExprToProto(const Expr& expr, const CallExpr& call_expr, ExprProto* absl_nonnull proto) { proto->Clear(); auto* call_proto = proto->mutable_call_expr(); proto->set_id(expr.id()); if (call_expr.has_target()) { Push(call_expr.target(), call_proto->mutable_target()); } call_proto->set_function(call_expr.function()); if (!call_expr.args().empty()) { call_proto->mutable_args()->Reserve( static_cast(call_expr.args().size())); for (const auto& argument : call_expr.args()) { Push(argument, call_proto->add_args()); } } return absl::OkStatus(); } /* Every element gets a proto slot (an element without an expr is left empty) so positions line up with optional_indices, which records the indices of optional elements. */ absl::Status ListExprToProto(const Expr& expr, const ListExpr& list_expr, ExprProto* absl_nonnull proto) { proto->Clear(); auto* list_proto = proto->mutable_list_expr(); proto->set_id(expr.id()); if
(!list_expr.elements().empty()) { list_proto->mutable_elements()->Reserve( static_cast(list_expr.elements().size())); for (size_t i = 0; i < list_expr.elements().size(); ++i) { const auto& element_expr = list_expr.elements()[i]; auto* element_proto = list_proto->add_elements(); if (element_expr.has_expr()) { Push(element_expr.expr(), element_proto); } if (element_expr.optional()) { list_proto->add_optional_indices(static_cast(i)); } } } return absl::OkStatus(); } absl::Status StructExprToProto(const Expr& expr, const StructExpr& struct_expr, ExprProto* absl_nonnull proto) { proto->Clear(); auto* struct_proto = proto->mutable_struct_expr(); proto->set_id(expr.id()); struct_proto->set_message_name(struct_expr.name()); if (!struct_expr.fields().empty()) { struct_proto->mutable_entries()->Reserve( static_cast(struct_expr.fields().size())); for (const auto& field_expr : struct_expr.fields()) { auto* field_proto = struct_proto->add_entries(); field_proto->set_id(field_expr.id()); field_proto->set_field_key(field_expr.name()); if (field_expr.has_value()) { Push(field_expr.value(), field_proto->mutable_value()); } if (field_expr.optional()) { field_proto->set_optional_entry(true); } } } return absl::OkStatus(); } /* Map literals are encoded as a struct_expr with no message_name, with keys stored in map_key; ExprFromProtoImpl relies on the empty message_name to tell maps from message literals. */ absl::Status MapExprToProto(const Expr& expr, const MapExpr& map_expr, ExprProto* absl_nonnull proto) { proto->Clear(); auto* map_proto = proto->mutable_struct_expr(); proto->set_id(expr.id()); if (!map_expr.entries().empty()) { map_proto->mutable_entries()->Reserve( static_cast(map_expr.entries().size())); for (const auto& entry_expr : map_expr.entries()) { auto* entry_proto = map_proto->add_entries(); entry_proto->set_id(entry_expr.id()); if (entry_expr.has_key()) { Push(entry_expr.key(), entry_proto->mutable_map_key()); } if (entry_expr.has_value()) { Push(entry_expr.value(), entry_proto->mutable_value()); } if (entry_expr.optional()) { entry_proto->set_optional_entry(true); } } } return absl::OkStatus(); } absl::Status ComprehensionExprToProto( const Expr&
expr, const ComprehensionExpr& comprehension_expr, ExprProto* absl_nonnull proto) { proto->Clear(); auto* comprehension_proto = proto->mutable_comprehension_expr(); proto->set_id(expr.id()); comprehension_proto->set_iter_var(comprehension_expr.iter_var()); comprehension_proto->set_iter_var2(comprehension_expr.iter_var2()); if (comprehension_expr.has_iter_range()) { Push(comprehension_expr.iter_range(), comprehension_proto->mutable_iter_range()); } comprehension_proto->set_accu_var(comprehension_expr.accu_var()); if (comprehension_expr.has_accu_init()) { Push(comprehension_expr.accu_init(), comprehension_proto->mutable_accu_init()); } if (comprehension_expr.has_loop_condition()) { Push(comprehension_expr.loop_condition(), comprehension_proto->mutable_loop_condition()); } if (comprehension_expr.has_loop_step()) { Push(comprehension_expr.loop_step(), comprehension_proto->mutable_loop_step()); } if (comprehension_expr.has_result()) { Push(comprehension_expr.result(), comprehension_proto->mutable_result()); } return absl::OkStatus(); } /* Push queues a pending (expr, proto) pair; Pop moves the most recently pushed frame into `frame` and returns false once the stack is empty. */ void Push(const Expr& expr, ExprProto* absl_nonnull proto) { frames_.push(Frame{&expr, proto}); } bool Pop(Frame& frame) { if (frames_.empty()) { return false; } frame = frames_.top(); frames_.pop(); return true; } std::stack> frames_; }; /* ExprFromProtoState is the inverse walk: it rebuilds a native Expr from a proto using the same explicit frame stack. A struct_expr with an empty message_name is decoded as a map literal; an unrecognized expr kind case is rejected with InvalidArgumentError. */ class ExprFromProtoState final { private: struct Frame final { const ExprProto* absl_nonnull proto; Expr* absl_nonnull expr; }; public: absl::Status ExprFromProto(const ExprProto& proto, Expr& expr) { Push(proto, expr); Frame frame; while (Pop(frame)) { CEL_RETURN_IF_ERROR(ExprFromProtoImpl(*frame.proto, *frame.expr)); } return absl::OkStatus(); } private: absl::Status ExprFromProtoImpl(const ExprProto& proto, Expr& expr) { switch (proto.expr_kind_case()) { case ExprProto::EXPR_KIND_NOT_SET: expr.Clear(); expr.set_id(proto.id()); return absl::OkStatus(); case ExprProto::kConstExpr: return ConstExprFromProto(proto, proto.const_expr(), expr); case ExprProto::kIdentExpr: return IdentExprFromProto(proto,
proto.ident_expr(), expr); case ExprProto::kSelectExpr: return SelectExprFromProto(proto, proto.select_expr(), expr); case ExprProto::kCallExpr: return CallExprFromProto(proto, proto.call_expr(), expr); case ExprProto::kListExpr: return ListExprFromProto(proto, proto.list_expr(), expr); case ExprProto::kStructExpr: if (proto.struct_expr().message_name().empty()) { return MapExprFromProto(proto, proto.struct_expr(), expr); } return StructExprFromProto(proto, proto.struct_expr(), expr); case ExprProto::kComprehensionExpr: return ComprehensionExprFromProto(proto, proto.comprehension_expr(), expr); default: return absl::InvalidArgumentError( absl::StrCat("unexpected ExprKindCase: ", static_cast(proto.expr_kind_case()))); } } absl::Status ConstExprFromProto(const ExprProto& proto, const ConstantProto& const_proto, Expr& expr) { expr.Clear(); expr.set_id(proto.id()); return ConstantFromProto(const_proto, expr.mutable_const_expr()); } absl::Status IdentExprFromProto(const ExprProto& proto, const ExprProto::Ident& ident_proto, Expr& expr) { expr.Clear(); expr.set_id(proto.id()); auto& ident_expr = expr.mutable_ident_expr(); ident_expr.set_name(ident_proto.name()); return absl::OkStatus(); } absl::Status SelectExprFromProto(const ExprProto& proto, const ExprProto::Select& select_proto, Expr& expr) { expr.Clear(); expr.set_id(proto.id()); auto& select_expr = expr.mutable_select_expr(); if (select_proto.has_operand()) { Push(select_proto.operand(), select_expr.mutable_operand()); } select_expr.set_field(select_proto.field()); select_expr.set_test_only(select_proto.test_only()); return absl::OkStatus(); } /* NOTE(review): this helper and the list/struct/map helpers below reserve() the destination container to its final size before calling add_*() and hand the address of each new child to Push. If those containers are std::vector-like, that exact-size reserve is presumably what keeps previously queued Expr pointers valid, so it should not be dropped; confirm against common/expr.h. */ absl::Status CallExprFromProto(const ExprProto& proto, const ExprProto::Call& call_proto, Expr& expr) { expr.Clear(); expr.set_id(proto.id()); auto& call_expr = expr.mutable_call_expr(); call_expr.set_function(call_proto.function()); if (call_proto.has_target()) { Push(call_proto.target(), call_expr.mutable_target()); } call_expr.mutable_args().reserve(
static_cast(call_proto.args().size())); for (const auto& argument_proto : call_proto.args()) { Push(argument_proto, call_expr.add_args()); } return absl::OkStatus(); } /* Every proto element is queued for decoding; an element is flagged optional when its index appears in optional_indices. NOTE(review): std::find scans optional_indices once per element (O(elements x optional indices)), and the optional_indices reference is re-read on every iteration; collecting the indices into a hash set once before the loop would make this linear if large list literals matter. */ absl::Status ListExprFromProto(const ExprProto& proto, const ExprProto::CreateList& list_proto, Expr& expr) { expr.Clear(); expr.set_id(proto.id()); auto& list_expr = expr.mutable_list_expr(); list_expr.mutable_elements().reserve( static_cast(list_proto.elements().size())); for (int i = 0; i < list_proto.elements().size(); ++i) { const auto& element_proto = list_proto.elements()[i]; auto& element_expr = list_expr.add_elements(); Push(element_proto, element_expr.mutable_expr()); const auto& optional_indicies_proto = list_proto.optional_indices(); element_expr.set_optional(std::find(optional_indicies_proto.begin(), optional_indicies_proto.end(), i) != optional_indicies_proto.end()); } return absl::OkStatus(); } /* Decodes a message literal. Entries keyed by map_key are rejected, and entries with no key kind set are accepted like field_key entries. */ absl::Status StructExprFromProto(const ExprProto& proto, const StructExprProto& struct_proto, Expr& expr) { expr.Clear(); expr.set_id(proto.id()); auto& struct_expr = expr.mutable_struct_expr(); struct_expr.set_name(struct_proto.message_name()); struct_expr.mutable_fields().reserve( static_cast(struct_proto.entries().size())); for (const auto& field_proto : struct_proto.entries()) { switch (field_proto.key_kind_case()) { case StructExprProto::Entry::KEY_KIND_NOT_SET: ABSL_FALLTHROUGH_INTENDED; case StructExprProto::Entry::kFieldKey: break; case StructExprProto::Entry::kMapKey: return absl::InvalidArgumentError("encountered map entry in struct"); default: return absl::InvalidArgumentError(absl::StrCat( "unexpected struct field kind: ", field_proto.key_kind_case())); } auto& field_expr = struct_expr.add_fields(); field_expr.set_id(field_proto.id()); field_expr.set_name(field_proto.field_key()); if (field_proto.has_value()) { Push(field_proto.value(), field_expr.mutable_value()); } field_expr.set_optional(field_proto.optional_entry()); } return absl::OkStatus(); } /* MapExprFromProto mirrors StructExprFromProto for map literals, rejecting field_key entries instead of map_key ones. */ absl::Status
MapExprFromProto(const ExprProto& proto, const ExprProto::CreateStruct& map_proto, Expr& expr) { expr.Clear(); expr.set_id(proto.id()); auto& map_expr = expr.mutable_map_expr(); map_expr.mutable_entries().reserve( static_cast(map_proto.entries().size())); for (const auto& entry_proto : map_proto.entries()) { switch (entry_proto.key_kind_case()) { case StructExprProto::Entry::KEY_KIND_NOT_SET: ABSL_FALLTHROUGH_INTENDED; case StructExprProto::Entry::kMapKey: break; case StructExprProto::Entry::kFieldKey: return absl::InvalidArgumentError("encountered struct field in map"); default: return absl::InvalidArgumentError(absl::StrCat( "unexpected map entry kind: ", entry_proto.key_kind_case())); } auto& entry_expr = map_expr.add_entries(); entry_expr.set_id(entry_proto.id()); if (entry_proto.has_map_key()) { Push(entry_proto.map_key(), entry_expr.mutable_key()); } if (entry_proto.has_value()) { Push(entry_proto.value(), entry_expr.mutable_value()); } entry_expr.set_optional(entry_proto.optional_entry()); } return absl::OkStatus(); } absl::Status ComprehensionExprFromProto( const ExprProto& proto, const ExprProto::Comprehension& comprehension_proto, Expr& expr) { expr.Clear(); expr.set_id(proto.id()); auto& comprehension_expr = expr.mutable_comprehension_expr(); comprehension_expr.set_iter_var(comprehension_proto.iter_var()); comprehension_expr.set_iter_var2(comprehension_proto.iter_var2()); comprehension_expr.set_accu_var(comprehension_proto.accu_var()); if (comprehension_proto.has_iter_range()) { Push(comprehension_proto.iter_range(), comprehension_expr.mutable_iter_range()); } if (comprehension_proto.has_accu_init()) { Push(comprehension_proto.accu_init(), comprehension_expr.mutable_accu_init()); } if (comprehension_proto.has_loop_condition()) { Push(comprehension_proto.loop_condition(), comprehension_expr.mutable_loop_condition()); } if (comprehension_proto.has_loop_step()) { Push(comprehension_proto.loop_step(), comprehension_expr.mutable_loop_step()); } if
(comprehension_proto.has_result()) { Push(comprehension_proto.result(), comprehension_expr.mutable_result()); } return absl::OkStatus(); } void Push(const ExprProto& proto, Expr& expr) { frames_.push(Frame{&proto, &expr}); } bool Pop(Frame& frame) { if (frames_.empty()) { return false; } frame = frames_.top(); frames_.pop(); return true; } std::stack> frames_; }; } // namespace /* Public entry points: each call runs a fresh, single-use state object. */ absl::Status ExprToProto(const Expr& expr, cel::expr::Expr* absl_nonnull proto) { ExprToProtoState state; return state.ExprToProto(expr, proto); } absl::Status ExprFromProto(const cel::expr::Expr& proto, Expr& expr) { ExprFromProtoState state; return state.ExprFromProto(proto, expr); } } // namespace cel::ast_internal ================================================ FILE: common/ast/expr_proto.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_EXPR_PROTO_H_ #define THIRD_PARTY_CEL_CPP_COMMON_AST_EXPR_PROTO_H_ #include "cel/expr/syntax.pb.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "common/expr.h" namespace cel::ast_internal { /* Converts `expr`, including all nested expressions, into `proto`. Returns the first error reported while converting a node. */ absl::Status ExprToProto(const Expr& expr, cel::expr::Expr* absl_nonnull proto); /* Rebuilds `expr` from `proto`. Returns InvalidArgumentError for unrecognized expression or constant kinds, and for struct entries whose key kind does not match the literal (a field_key inside a map literal, or a map_key inside a message literal). */ absl::Status ExprFromProto(const cel::expr::Expr& proto, Expr& expr); } // namespace cel::ast_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_AST_EXPR_PROTO_H_ ================================================ FILE: common/ast/expr_proto_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "common/ast/expr_proto.h" #include #include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "common/expr.h" #include "internal/proto_matchers.h" #include "internal/testing.h" #include "google/protobuf/text_format.h" namespace cel::ast_internal { namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::internal::test::EqualsProto; using ExprProto = cel::expr::Expr; struct ExprRoundtripTestCase { std::string input; }; using ExprRoundTripTest = ::testing::TestWithParam; /* Each case is a text-format Expr proto that must come back unchanged after ExprFromProto followed by ExprToProto. */ TEST_P(ExprRoundTripTest, RoundTrip) { const auto& test_case = GetParam(); ExprProto original_proto; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(test_case.input, &original_proto)); Expr expr; ASSERT_THAT(ExprFromProto(original_proto, expr), IsOk()); ExprProto proto; ASSERT_THAT(ExprToProto(expr, &proto), IsOk()); EXPECT_THAT(proto, EqualsProto(original_proto)); } INSTANTIATE_TEST_SUITE_P( ExprRoundTripTest, ExprRoundTripTest, ::testing::ValuesIn({ {R"pb( )pb"}, {R"pb( id: 1 )pb"}, {R"pb( id: 1 const_expr {} )pb"}, {R"pb( id: 1 const_expr { null_value: NULL_VALUE } )pb"}, {R"pb( id: 1 const_expr { bool_value: true } )pb"}, {R"pb( id: 1 const_expr { int64_value: 1 } )pb"}, {R"pb( id: 1 const_expr { uint64_value: 1 } )pb"}, {R"pb( id: 1 const_expr { double_value: 1 } )pb"}, {R"pb( id: 1 const_expr { string_value: "foo" } )pb"}, {R"pb( id: 1 const_expr { bytes_value: "foo" } )pb"}, {R"pb( id: 1 const_expr { duration_value { seconds: 1 nanos: 1 } } )pb"}, {R"pb( id: 1 const_expr { timestamp_value { seconds: 1 nanos: 1 } } )pb"}, {R"pb( id: 1 ident_expr { name: "foo" } )pb"}, {R"pb( id: 1 select_expr { operand { id: 2 ident_expr { name: "bar" } } field: "foo" test_only: true } )pb"}, {R"pb( id: 1 call_expr { target { id: 2 ident_expr { name: "bar" } } function: "foo" args { id: 3 ident_expr { name: "baz" } } } )pb"}, {R"pb( id: 1 list_expr { elements { id: 2 ident_expr { name: "bar" } } elements { id:
3 ident_expr { name: "baz" } } optional_indices: 0 } )pb"}, {R"pb( id: 1 struct_expr { message_name: "google.type.Expr" entries { id: 2 field_key: "description" value { id: 3 const_expr { string_value: "foo" } } optional_entry: true } entries { id: 4 field_key: "expr" value { id: 5 const_expr { string_value: "bar" } } } } )pb"}, {R"pb( id: 1 struct_expr { entries { id: 2 map_key { id: 3 const_expr { string_value: "description" } } value { id: 4 const_expr { string_value: "foo" } } optional_entry: true } entries { id: 5 map_key { id: 6 const_expr { string_value: "expr" } } value { id: 7 const_expr { string_value: "foo" } } optional_entry: true } } )pb"}, {R"pb( id: 1 comprehension_expr { iter_var: "foo" iter_range { id: 2 list_expr {} } accu_var: "bar" accu_init { id: 3 list_expr {} } loop_condition { id: 4 const_expr { bool_value: true } } loop_step { id: 4 ident_expr { name: "bar" } } result { id: 5 ident_expr { name: "foo" } } } )pb"}, {R"pb( id: 1 comprehension_expr { iter_var: "foo" iter_var2: "baz" iter_range { id: 2 list_expr {} } accu_var: "bar" accu_init { id: 3 list_expr {} } loop_condition { id: 4 const_expr { bool_value: true } } loop_step { id: 4 ident_expr { name: "bar" } } result { id: 5 ident_expr { name: "foo" } } } )pb"}, })); /* A struct_expr without message_name is decoded as a map literal, so an entry keyed by field_key must be rejected. */ TEST(ExprFromProto, StructFieldInMap) { ExprProto original_proto; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(R"pb( id: 1 struct_expr: { entries: { id: 2 field_key: "foo" value: { id: 3 ident_expr: { name: "bar" } } } } )pb", &original_proto)); Expr expr; ASSERT_THAT(ExprFromProto(original_proto, expr), StatusIs(absl::StatusCode::kInvalidArgument)); } /* A struct_expr with a message_name is a message literal, so an entry keyed by map_key must be rejected. */ TEST(ExprFromProto, MapEntryInStruct) { ExprProto original_proto; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(R"pb( id: 1 struct_expr: { message_name: "some.Message" entries: { id: 2 map_key: { id: 3 ident_expr: { name: "foo" } } value: { id: 4 ident_expr: { name: "bar" } } } } )pb", &original_proto)); Expr expr;
ASSERT_THAT(ExprFromProto(original_proto, expr), StatusIs(absl::StatusCode::kInvalidArgument)); } } // namespace } // namespace cel::ast_internal ================================================ FILE: common/ast/metadata.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/ast/metadata.h" #include #include #include #include #include #include #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/functional/overload.h" #include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "absl/types/variant.h" namespace cel { namespace { /* Shared TypeSpec holding UnsetTypeSpec; the pointer-backed accessors below return it when their member is null. */ const TypeSpec& DefaultTypeSpec() { static absl::NoDestructor type(TypeSpecKind{UnsetTypeSpec()}); return *type; } /* CEL spelling of a primitive type (kInt64 prints as "int", kUint64 as "uint"); unrecognized values print a placeholder. */ std::string FormatPrimitive(PrimitiveType t) { switch (t) { case PrimitiveType::kBool: return "bool"; case PrimitiveType::kInt64: return "int"; case PrimitiveType::kUint64: return "uint"; case PrimitiveType::kDouble: return "double"; case PrimitiveType::kString: return "string"; case PrimitiveType::kBytes: return "bytes"; default: return "*unspecified primitive*"; } } /* Fully-qualified protobuf name of a well-known type; unrecognized values print a placeholder. */ std::string FormatWellKnown(WellKnownTypeSpec t) { switch (t) { case WellKnownTypeSpec::kAny: return "google.protobuf.Any"; case WellKnownTypeSpec::kDuration: return "google.protobuf.Duration"; case WellKnownTypeSpec::kTimestamp: return "google.protobuf.Timestamp"; default: return "*unspecified well known*"; } } /* Work items for the explicit stack used by FormatTypeSpec: literal text to append, or a TypeSpec still to be expanded. */ using FormatIns =
std::variant; using FormatStack = std::vector; /* Appends the text for one TypeSpec node to `out`. Child types are not formatted recursively: they are pushed onto `stack` together with their separators and closing parenthesis, in reverse order, so the caller pops them in print order. A `type` spec whose nested type equals a default TypeSpec prints as bare "type". */ void HandleFormatTypeSpec(const TypeSpec& t, FormatStack& stack, std::string* out) { if (t.has_dyn()) { absl::StrAppend(out, "dyn"); } else if (t.has_null()) { absl::StrAppend(out, "null"); } else if (t.has_primitive()) { absl::StrAppend(out, FormatPrimitive(t.primitive())); } else if (t.has_wrapper()) { absl::StrAppend(out, "wrapper(", FormatPrimitive(t.wrapper()), ")"); } else if (t.has_well_known()) { absl::StrAppend(out, FormatWellKnown(t.well_known())); return; } else if (t.has_abstract_type()) { const auto& abs_type = t.abstract_type(); if (abs_type.parameter_types().empty()) { absl::StrAppend(out, abs_type.name()); return; } absl::StrAppend(out, abs_type.name(), "("); stack.push_back(")"); for (size_t i = abs_type.parameter_types().size(); i > 0; --i) { stack.push_back(&abs_type.parameter_types()[i - 1]); if (i > 1) { stack.push_back(", "); } } } else if (t.has_type()) { if (t.type() == TypeSpec()) { absl::StrAppend(out, "type"); return; } absl::StrAppend(out, "type("); stack.push_back(")"); stack.push_back(&t.type()); } else if (t.has_message_type()) { absl::StrAppend(out, t.message_type().type()); } else if (t.has_type_param()) { absl::StrAppend(out, t.type_param().type()); } else if (t.has_list_type()) { absl::StrAppend(out, "list("); stack.push_back(")"); stack.push_back(&t.list_type().elem_type()); } else if (t.has_map_type()) { absl::StrAppend(out, "map("); stack.push_back(")"); stack.push_back(&t.map_type().value_type()); stack.push_back(", "); stack.push_back(&t.map_type().key_type()); } else { absl::StrAppend(out, "*error*"); } } /* Deep copy of a TypeSpecKind: the owning-pointer alternative is cloned (a null pointer becomes a freshly allocated default value), and every other alternative is copied as-is. */ TypeSpecKind CopyImpl(const TypeSpecKind& other) { return absl::visit( absl::Overload( [](const std::unique_ptr& other) -> TypeSpecKind { if (other == nullptr) { return std::make_unique(); } return std::make_unique(*other); }, [](const auto& other) -> TypeSpecKind { // Other variants define copy ctor.
return other; }), other); } } // namespace const ExtensionSpec::Version& ExtensionSpec::Version::DefaultInstance() { static absl::NoDestructor instance; return *instance; } const ExtensionSpec& ExtensionSpec::DefaultInstance() { static absl::NoDestructor instance; return *instance; } /* Copy operations clone the optional Version so the copy never shares it with `other`. */ ExtensionSpec::ExtensionSpec(const ExtensionSpec& other) : id_(other.id_), affected_components_(other.affected_components_), version_(other.version_ == nullptr ? nullptr : std::make_unique(*other.version_)) {} ExtensionSpec& ExtensionSpec::operator=(const ExtensionSpec& other) { id_ = other.id_; affected_components_ = other.affected_components_; if (other.version_ != nullptr) { version_ = std::make_unique(other.version()); } else { version_ = nullptr; } return *this; } /* The pointer-backed accessors fall back to DefaultTypeSpec() when unset, and the operator== overloads compare through these accessors, so an unset child compares equal to a child set to a value equal to DefaultTypeSpec(). */ const TypeSpec& ListTypeSpec::elem_type() const { if (elem_type_ != nullptr) { return *elem_type_; } return DefaultTypeSpec(); } bool ListTypeSpec::operator==(const ListTypeSpec& other) const { return elem_type() == other.elem_type(); } const TypeSpec& MapTypeSpec::key_type() const { if (key_type_ != nullptr) { return *key_type_; } return DefaultTypeSpec(); } const TypeSpec& MapTypeSpec::value_type() const { if (value_type_ != nullptr) { return *value_type_; } return DefaultTypeSpec(); } bool MapTypeSpec::operator==(const MapTypeSpec& other) const { return key_type() == other.key_type() && value_type() == other.value_type(); } const TypeSpec& FunctionTypeSpec::result_type() const { if (result_type_ != nullptr) { return *result_type_; } return DefaultTypeSpec(); } bool FunctionTypeSpec::operator==(const FunctionTypeSpec& other) const { return result_type() == other.result_type() && arg_types_ == other.arg_types_; } /* Returns the nested type when the owning-pointer alternative is active and non-null, otherwise DefaultTypeSpec(). */ const TypeSpec& TypeSpec::type() const { auto* value = absl::get_if>(&type_kind_); if (value != nullptr) { if (*value != nullptr) return **value; } return DefaultTypeSpec(); } TypeSpec::TypeSpec(const TypeSpec& other) : type_kind_(CopyImpl(other.type_kind_)) {} TypeSpec& TypeSpec::operator=(const TypeSpec&
other) { type_kind_ = CopyImpl(other.type_kind_); return *this; } /* NOTE(review): copying always allocates result_type_, even when `other` has none (result_type() then yields DefaultTypeSpec(), which is copied); confirm that turning an unset result type into a set one is intended. */ FunctionTypeSpec::FunctionTypeSpec(const FunctionTypeSpec& other) : result_type_(std::make_unique(other.result_type())), arg_types_(other.arg_types()) {} FunctionTypeSpec& FunctionTypeSpec::operator=(const FunctionTypeSpec& other) { result_type_ = std::make_unique(other.result_type()); arg_types_ = other.arg_types(); return *this; } /* Renders `t` as human-readable text, e.g. "list(int)" or "map(string, dyn)". */ std::string FormatTypeSpec(const TypeSpec& t) { // Use a stack to avoid recursion. // Probably overly defensive, but fuzzers will often notice the recursion // and try to trigger it. std::string out; FormatStack seq; seq.push_back(&t); while (!seq.empty()) { FormatIns ins = std::move(seq.back()); seq.pop_back(); if (std::holds_alternative(ins)) { absl::StrAppend(&out, std::get(ins)); continue; } ABSL_DCHECK(std::holds_alternative(ins)); HandleFormatTypeSpec(*std::get(ins), seq, &out); } return out; } } // namespace cel ================================================ FILE: common/ast/metadata.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Type definitions for auxiliary structures in the AST. // // These are more direct equivalents to the public protobuf definitions.
// // IWYU pragma: private, include "common/ast.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_METADATA_H_ #define THIRD_PARTY_CEL_CPP_COMMON_AST_METADATA_H_ #include #include #include #include #include #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" #include "absl/types/variant.h" #include "common/constant.h" #include "common/expr.h" namespace cel { // An extension that was requested for the source expression. class ExtensionSpec { public: // Extension version as a (major, minor) pair; both default to 0. class Version { public: Version() : major_(0), minor_(0) {} Version(int64_t major, int64_t minor) : major_(major), minor_(minor) {} Version(const Version& other) = default; Version(Version&& other) = default; Version& operator=(const Version& other) = default; Version& operator=(Version&& other) = default; static const Version& DefaultInstance(); // Major version changes indicate different required support level from // the required components. int64_t major() const { return major_; } void set_major(int64_t val) { major_ = val; } // Minor version changes must not change the observed behavior from // existing implementations, but may be provided informationally. int64_t minor() const { return minor_; } void set_minor(int64_t val) { minor_ = val; } bool operator==(const Version& other) const { return major_ == other.major_ && minor_ == other.minor_; } bool operator!=(const Version& other) const { return !operator==(other); } private: int64_t major_; int64_t minor_; }; // CEL component specifier. enum class Component { // Unspecified, default. kUnspecified, // Parser. Converts a CEL string to an AST. kParser, // Type checker. Checks that references in an AST are defined and types // agree. kTypeChecker, // Runtime. Evaluates a parsed and optionally checked CEL AST against a // context.
kRuntime }; static const ExtensionSpec& DefaultInstance(); ExtensionSpec() = default; ExtensionSpec(std::string id, std::unique_ptr version, std::vector affected_components) : id_(std::move(id)), affected_components_(std::move(affected_components)), version_(std::move(version)) {} ExtensionSpec(const ExtensionSpec& other); ExtensionSpec(ExtensionSpec&& other) = default; ExtensionSpec& operator=(const ExtensionSpec& other); ExtensionSpec& operator=(ExtensionSpec&& other) = default; // Identifier for the extension. Example: constant_folding const std::string& id() const { return id_; } void set_id(std::string id) { id_ = std::move(id); } // If set, the listed components must understand the extension for the // expression to evaluate correctly. // // This field has set semantics, repeated values should be deduplicated. // NOTE(review): operator== below compares this vector element-wise, so // order and duplicates still affect equality. const std::vector& affected_components() const { return affected_components_; } std::vector& mutable_affected_components() { return affected_components_; } // Version info. May be skipped if it isn't meaningful for the extension. // (for example constant_folding might always be v0.0). version() returns // Version::DefaultInstance() when no version is stored, and // mutable_version() allocates a default Version on first use. const Version& version() const { if (version_ == nullptr) { return Version::DefaultInstance(); } return *version_; } Version& mutable_version() { if (version_ == nullptr) { version_ = std::make_unique(); } return *version_; } void set_version(std::unique_ptr version) { version_ = std::move(version); } // Equality compares version(), so an unset version equals an explicit 0.0. bool operator==(const ExtensionSpec& other) const { return id_ == other.id_ && affected_components_ == other.affected_components_ && version() == other.version(); } bool operator!=(const ExtensionSpec& other) const { return !operator==(other); } private: std::string id_; std::vector affected_components_; std::unique_ptr version_; }; // Source information collected at parse time.
class SourceInfo { public: SourceInfo() = default; SourceInfo(std::string syntax_version, std::string location, std::vector line_offsets, absl::flat_hash_map positions, absl::flat_hash_map macro_calls, std::vector extensions) : syntax_version_(std::move(syntax_version)), location_(std::move(location)), line_offsets_(std::move(line_offsets)), positions_(std::move(positions)), macro_calls_(std::move(macro_calls)), extensions_(std::move(extensions)) {} SourceInfo(const SourceInfo& other) = default; SourceInfo(SourceInfo&& other) = default; SourceInfo& operator=(const SourceInfo& other) = default; SourceInfo& operator=(SourceInfo&& other) = default; /* Setters take their argument by value and move it into place; the mutable_* accessors return the stored container by reference. */ void set_syntax_version(std::string syntax_version) { syntax_version_ = std::move(syntax_version); } void set_location(std::string location) { location_ = std::move(location); } void set_line_offsets(std::vector line_offsets) { line_offsets_ = std::move(line_offsets); } void set_positions(absl::flat_hash_map positions) { positions_ = std::move(positions); } void set_macro_calls(absl::flat_hash_map macro_calls) { macro_calls_ = std::move(macro_calls); } const std::string& syntax_version() const { return syntax_version_; } const std::string& location() const { return location_; } const std::vector& line_offsets() const { return line_offsets_; } std::vector& mutable_line_offsets() { return line_offsets_; } const absl::flat_hash_map& positions() const { return positions_; } absl::flat_hash_map& mutable_positions() { return positions_; } const absl::flat_hash_map& macro_calls() const { return macro_calls_; } absl::flat_hash_map& mutable_macro_calls() { return macro_calls_; } /* Member-wise equality over all six fields. */ bool operator==(const SourceInfo& other) const { return syntax_version_ == other.syntax_version_ && location_ == other.location_ && line_offsets_ == other.line_offsets_ && positions_ == other.positions_ && macro_calls_ == other.macro_calls_ && extensions_ == other.extensions_; } bool operator!=(const SourceInfo& other) const { return !operator==(other); }
const std::vector& extensions() const { return extensions_; } std::vector& mutable_extensions() { return extensions_; } private: // The syntax version of the source, e.g. `cel1`. std::string syntax_version_; // The location name. All position information attached to an expression is // relative to this location. // // The location could be a file, UI element, or similar. For example, // `acme/app/AnvilPolicy.cel`. std::string location_; // Monotonically increasing list of code point offsets where newlines // `\n` appear. // // The line number of a given position is the index `i` where for a given // `id` the `line_offsets[i] < id_positions[id] < line_offsets[i+1]`. The // column may be derived from `id_positions[id] - line_offsets[i]`. // // TODO(uncreated-issue/14): clarify this documentation std::vector line_offsets_; // A map from the parse node id (e.g. `Expr.id`) to the code point offset // within source. absl::flat_hash_map positions_; // A map from the parse node id where a macro replacement was made to the // call `Expr` that resulted in a macro expansion. // // For example, `has(value.field)` is a function call that is replaced by a // `test_only` field selection in the AST. Likewise, the call // `list.exists(e, e > 10)` translates to a comprehension expression. The key // in the map corresponds to the expression id of the expanded macro, and the // value is the call `Expr` that was replaced. absl::flat_hash_map macro_calls_; // A list of tags for extensions that were used while parsing or type checking // the source expression. For example, optimizations that require special // runtime support may be specified. // // These are used to check feature support between components in separate // implementations. This can be used to either skip redundant work or // report an error if the extension is unsupported. std::vector extensions_; }; // CEL primitive types. enum class PrimitiveType { // Unspecified type. kPrimitiveTypeUnspecified = 0, // Boolean type.
kBool = 1, // Int64 type. // // Proto-based integer values are widened to int64. kInt64 = 2, // Uint64 type. // // Proto-based unsigned integer values are widened to uint64. kUint64 = 3, // Double type. // // Proto-based float values are widened to double values. kDouble = 4, // String type. kString = 5, // Bytes type. kBytes = 6, }; // Well-known protobuf types treated with first-class support in CEL. // // TODO(uncreated-issue/15): represent well-known via abstract types (or however) // they will be named. enum class WellKnownTypeSpec { // Unspecified type. kWellKnownTypeUnspecified = 0, // Well-known protobuf.Any type. // // Any types are a polymorphic message type. During type-checking they are // treated like `DYN` types, but at runtime they are resolved to a specific // message type specified at evaluation time. kAny = 1, // Well-known protobuf.Timestamp type, internally referenced as `timestamp`. kTimestamp = 2, // Well-known protobuf.Duration type, internally referenced as `duration`. kDuration = 3, }; // forward declare for recursive types. class TypeSpec; // List type with typed elements, e.g. `list`. elem_type() returns a shared // default TypeSpec when no element type is set; use has_elem_type() to // distinguish the two cases. class ListTypeSpec { public: ListTypeSpec() = default; ListTypeSpec(const ListTypeSpec& rhs); ListTypeSpec& operator=(const ListTypeSpec& rhs); ListTypeSpec(ListTypeSpec&& rhs) = default; ListTypeSpec& operator=(ListTypeSpec&& rhs) = default; explicit ListTypeSpec(std::unique_ptr elem_type); void set_elem_type(std::unique_ptr elem_type); bool has_elem_type() const { return elem_type_ != nullptr; } const TypeSpec& elem_type() const; TypeSpec& mutable_elem_type(); bool operator==(const ListTypeSpec& other) const; private: std::unique_ptr elem_type_; }; // Map type specifier with parameterized key and value types, e.g. // `map`.
class MapTypeSpec { public: MapTypeSpec() = default; MapTypeSpec(std::unique_ptr key_type, std::unique_ptr value_type); MapTypeSpec(const MapTypeSpec& rhs); MapTypeSpec& operator=(const MapTypeSpec& rhs); MapTypeSpec(MapTypeSpec&& rhs) = default; MapTypeSpec& operator=(MapTypeSpec&& rhs) = default; void set_key_type(std::unique_ptr key_type); void set_value_type(std::unique_ptr value_type); bool has_key_type() const { return key_type_ != nullptr; } bool has_value_type() const { return value_type_ != nullptr; } const TypeSpec& key_type() const; const TypeSpec& value_type() const; bool operator==(const MapTypeSpec& other) const; TypeSpec& mutable_key_type(); TypeSpec& mutable_value_type(); private: // The type of the key. std::unique_ptr key_type_; // The type of the value. std::unique_ptr value_type_; }; // Function type specifiers with result and arg types. // // NOTE: function type represents a lambda-style argument to another function. // Supported through macros, but not yet a first-class concept in CEL. class FunctionTypeSpec { public: FunctionTypeSpec() = default; FunctionTypeSpec(std::unique_ptr result_type, std::vector arg_types); FunctionTypeSpec(const FunctionTypeSpec& other); FunctionTypeSpec& operator=(const FunctionTypeSpec& other); FunctionTypeSpec(FunctionTypeSpec&&) = default; FunctionTypeSpec& operator=(FunctionTypeSpec&&) = default; void set_result_type(std::unique_ptr result_type); void set_arg_types(std::vector arg_types); bool has_result_type() const { return result_type_ != nullptr; } const TypeSpec& result_type() const; TypeSpec& mutable_result_type(); const std::vector& arg_types() const { return arg_types_; } std::vector& mutable_arg_types() { return arg_types_; } bool operator==(const FunctionTypeSpec& other) const; private: // Result type of the function. std::unique_ptr result_type_; // Argument types of the function. std::vector arg_types_; }; // Application defined abstract type. 
// // Abstract types provide a name as an identifier for the application, and // optionally one or more type parameters. // // For cel::Type representation, see OpaqueType. class AbstractType { public: AbstractType() = default; AbstractType(std::string name, std::vector parameter_types); void set_name(std::string name) { name_ = std::move(name); } void set_parameter_types(std::vector parameter_types); const std::string& name() const { return name_; } const std::vector& parameter_types() const { return parameter_types_; } std::vector& mutable_parameter_types() { return parameter_types_; } bool operator==(const AbstractType& other) const; private: // The fully qualified name of this abstract type. std::string name_; // Parameter types for this abstract type. std::vector parameter_types_; }; // Wrapper of a primitive type, e.g. `google.protobuf.Int64Value`. class PrimitiveTypeWrapper { public: explicit PrimitiveTypeWrapper(PrimitiveType type) : type_(std::move(type)) {} void set_type(PrimitiveType type) { type_ = std::move(type); } const PrimitiveType& type() const { return type_; } PrimitiveType& mutable_type() { return type_; } bool operator==(const PrimitiveTypeWrapper& other) const { return type_ == other.type_; } private: PrimitiveType type_; }; // Protocol buffer message type specifier. // // The `message_type` string specifies the qualified message type name. For // example, `google.plus.Profile`. This must be mapped to a google::protobuf::Descriptor // for type checking. class MessageTypeSpec { public: MessageTypeSpec() = default; explicit MessageTypeSpec(std::string type) : type_(std::move(type)) {} void set_type(std::string type) { type_ = std::move(type); } const std::string& type() const { return type_; } bool operator==(const MessageTypeSpec& other) const { return type_ == other.type_; } private: std::string type_; }; // TypeSpec param type. // // The `type_param` string specifies the type parameter name, e.g. 
`list` // would be a `list_type` whose element type was a `type_param` type // named `E`. class ParamTypeSpec { public: ParamTypeSpec() = default; explicit ParamTypeSpec(std::string type) : type_(std::move(type)) {} void set_type(std::string type) { type_ = std::move(type); } const std::string& type() const { return type_; } bool operator==(const ParamTypeSpec& other) const { return type_ == other.type_; } private: std::string type_; }; // Error type specifier. // // During type-checking if an expression is an error, its type is propagated // as the `ERROR` type. This permits the type-checker to discover other // errors present in the expression. enum class ErrorTypeSpec { kValue = 0 }; using UnsetTypeSpec = absl::monostate; struct DynTypeSpec {}; inline bool operator==(const DynTypeSpec&, const DynTypeSpec&) { return true; } inline bool operator!=(const DynTypeSpec&, const DynTypeSpec&) { return false; } struct NullTypeSpec {}; inline bool operator==(const NullTypeSpec&, const NullTypeSpec&) { return true; } inline bool operator!=(const NullTypeSpec&, const NullTypeSpec&) { return false; } using TypeSpecKind = absl::variant, ErrorTypeSpec, AbstractType>; // Analogous to cel::expr::Type. // Represents a CEL type. 
// // TODO(uncreated-issue/15): align with value.proto class TypeSpec { public: TypeSpec() = default; explicit TypeSpec(TypeSpecKind type_kind) : type_kind_(std::move(type_kind)) {} TypeSpec(const TypeSpec& other); TypeSpec& operator=(const TypeSpec& other); TypeSpec(TypeSpec&&) = default; TypeSpec& operator=(TypeSpec&&) = default; void set_type_kind(TypeSpecKind type_kind) { type_kind_ = std::move(type_kind); } const TypeSpecKind& type_kind() const { return type_kind_; } TypeSpecKind& mutable_type_kind() { return type_kind_; } bool has_dyn() const { return absl::holds_alternative(type_kind_); } bool has_null() const { return absl::holds_alternative(type_kind_); } bool has_primitive() const { return absl::holds_alternative(type_kind_); } bool has_wrapper() const { return absl::holds_alternative(type_kind_); } bool has_well_known() const { return absl::holds_alternative(type_kind_); } bool has_list_type() const { return absl::holds_alternative(type_kind_); } bool has_map_type() const { return absl::holds_alternative(type_kind_); } bool has_function() const { return absl::holds_alternative(type_kind_); } bool has_message_type() const { return absl::holds_alternative(type_kind_); } bool has_type_param() const { return absl::holds_alternative(type_kind_); } bool has_type() const { return absl::holds_alternative>(type_kind_); } bool has_error() const { return absl::holds_alternative(type_kind_); } bool has_abstract_type() const { return absl::holds_alternative(type_kind_); } NullTypeSpec null() const { auto* value = absl::get_if(&type_kind_); if (value != nullptr) { return *value; } return {}; } PrimitiveType primitive() const { auto* value = absl::get_if(&type_kind_); if (value != nullptr) { return *value; } return PrimitiveType::kPrimitiveTypeUnspecified; } PrimitiveType wrapper() const { auto* value = absl::get_if(&type_kind_); if (value != nullptr) { return value->type(); } return PrimitiveType::kPrimitiveTypeUnspecified; } WellKnownTypeSpec well_known() const { 
auto* value = absl::get_if(&type_kind_); if (value != nullptr) { return *value; } return WellKnownTypeSpec::kWellKnownTypeUnspecified; } const ListTypeSpec& list_type() const { auto* value = absl::get_if(&type_kind_); if (value != nullptr) { return *value; } static const ListTypeSpec* default_list_type = new ListTypeSpec(); return *default_list_type; } const MapTypeSpec& map_type() const { auto* value = absl::get_if(&type_kind_); if (value != nullptr) { return *value; } static const MapTypeSpec* default_map_type = new MapTypeSpec(); return *default_map_type; } const FunctionTypeSpec& function() const { auto* value = absl::get_if(&type_kind_); if (value != nullptr) { return *value; } static const FunctionTypeSpec* default_function_type = new FunctionTypeSpec(); return *default_function_type; } const MessageTypeSpec& message_type() const { auto* value = absl::get_if(&type_kind_); if (value != nullptr) { return *value; } static const MessageTypeSpec* default_message_type = new MessageTypeSpec(); return *default_message_type; } const ParamTypeSpec& type_param() const { auto* value = absl::get_if(&type_kind_); if (value != nullptr) { return *value; } static const ParamTypeSpec* default_param_type = new ParamTypeSpec(); return *default_param_type; } const TypeSpec& type() const; ErrorTypeSpec error_type() const { auto* value = absl::get_if(&type_kind_); if (value != nullptr) { return *value; } return ErrorTypeSpec::kValue; } const AbstractType& abstract_type() const { auto* value = absl::get_if(&type_kind_); if (value != nullptr) { return *value; } static const AbstractType* default_abstract_type = new AbstractType(); return *default_abstract_type; } bool operator==(const TypeSpec& other) const { if (absl::holds_alternative>(type_kind_) && absl::holds_alternative>(other.type_kind_)) { const auto& self_type = absl::get>(type_kind_); const auto& other_type = absl::get>(other.type_kind_); if (self_type == nullptr || other_type == nullptr) { return self_type == other_type; } 
return *self_type == *other_type; } return type_kind_ == other.type_kind_; } private: TypeSpecKind type_kind_; }; // Returns a string representation of the given TypeSpec. std::string FormatTypeSpec(const TypeSpec& t); // Describes a resolved reference to a declaration. class Reference { public: Reference() = default; Reference(std::string name, std::vector overload_id, Constant value) : name_(std::move(name)), overload_id_(std::move(overload_id)), value_(std::move(value)) {} Reference(const Reference& other) = default; Reference& operator=(const Reference& other) = default; Reference(Reference&&) = default; Reference& operator=(Reference&&) = default; void set_name(std::string name) { name_ = std::move(name); } void set_overload_id(std::vector overload_id) { overload_id_ = std::move(overload_id); } void set_value(Constant value) { value_ = std::move(value); } const std::string& name() const { return name_; } const std::vector& overload_id() const { return overload_id_; } const Constant& value() const { if (value_.has_value()) { return value_.value(); } static const Constant* default_constant = new Constant; return *default_constant; } std::vector& mutable_overload_id() { return overload_id_; } Constant& mutable_value() { if (!value_.has_value()) { value_.emplace(); } return *value_; } bool has_value() const { return value_.has_value(); } bool operator==(const Reference& other) const { return name_ == other.name_ && overload_id_ == other.overload_id_ && value() == other.value(); } private: // The fully qualified name of the declaration. std::string name_; // For references to functions, this is a list of `Overload.overload_id` // values which match according to typing rules. // // If the list has more than one element, overload resolution among the // presented candidates must happen at runtime because of dynamic types. The // type checker attempts to narrow down this list as much as possible. // // Empty if this is not a reference to a [Decl.FunctionDecl][]. 
std::vector overload_id_; // For references to constants, this may contain the value of the // constant if known at compile time. absl::optional value_; }; //////////////////////////////////////////////////////////////////////// // Out-of-line method declarations //////////////////////////////////////////////////////////////////////// inline ListTypeSpec::ListTypeSpec(const ListTypeSpec& rhs) : elem_type_(std::make_unique(rhs.elem_type())) {} inline ListTypeSpec& ListTypeSpec::operator=(const ListTypeSpec& rhs) { elem_type_ = std::make_unique(rhs.elem_type()); return *this; } inline ListTypeSpec::ListTypeSpec(std::unique_ptr elem_type) : elem_type_(std::move(elem_type)) {} inline void ListTypeSpec::set_elem_type(std::unique_ptr elem_type) { elem_type_ = std::move(elem_type); } inline TypeSpec& ListTypeSpec::mutable_elem_type() { if (elem_type_ == nullptr) { elem_type_ = std::make_unique(); } return *elem_type_; } inline MapTypeSpec::MapTypeSpec(std::unique_ptr key_type, std::unique_ptr value_type) : key_type_(std::move(key_type)), value_type_(std::move(value_type)) {} inline MapTypeSpec::MapTypeSpec(const MapTypeSpec& rhs) : key_type_(std::make_unique(rhs.key_type())), value_type_(std::make_unique(rhs.value_type())) {} inline MapTypeSpec& MapTypeSpec::operator=(const MapTypeSpec& rhs) { key_type_ = std::make_unique(rhs.key_type()); value_type_ = std::make_unique(rhs.value_type()); return *this; } inline void MapTypeSpec::set_key_type(std::unique_ptr key_type) { key_type_ = std::move(key_type); } inline void MapTypeSpec::set_value_type(std::unique_ptr value_type) { value_type_ = std::move(value_type); } inline TypeSpec& MapTypeSpec::mutable_key_type() { if (key_type_ == nullptr) { key_type_ = std::make_unique(); } return *key_type_; } inline TypeSpec& MapTypeSpec::mutable_value_type() { if (value_type_ == nullptr) { value_type_ = std::make_unique(); } return *value_type_; } inline void FunctionTypeSpec::set_result_type( std::unique_ptr result_type) { result_type_ = 
std::move(result_type); } inline TypeSpec& FunctionTypeSpec::mutable_result_type() { if (result_type_ == nullptr) { result_type_ = std::make_unique(); } return *result_type_; } //////////////////////////////////////////////////////////////////////// // Implementation details //////////////////////////////////////////////////////////////////////// inline FunctionTypeSpec::FunctionTypeSpec(std::unique_ptr result_type, std::vector arg_types) : result_type_(std::move(result_type)), arg_types_(std::move(arg_types)) {} inline void FunctionTypeSpec::set_arg_types(std::vector arg_types) { arg_types_ = std::move(arg_types); } inline AbstractType::AbstractType(std::string name, std::vector parameter_types) : name_(std::move(name)), parameter_types_(std::move(parameter_types)) {} inline void AbstractType::set_parameter_types( std::vector parameter_types) { parameter_types_ = std::move(parameter_types); } inline bool AbstractType::operator==(const AbstractType& other) const { return name_ == other.name_ && parameter_types_ == other.parameter_types_; } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_AST_METADATA_H_ ================================================ FILE: common/ast/metadata_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/ast/metadata.h"

#include <memory>
#include <string>
#include <utility>

#include "absl/types/variant.h"
#include "common/expr.h"
#include "internal/testing.h"

namespace cel {
namespace {

using ::testing::ElementsAre;

TEST(AstTest, ListTypeSpecMutableConstruction) {
  ListTypeSpec type;
  type.mutable_elem_type() = TypeSpec(PrimitiveType::kBool);
  EXPECT_EQ(absl::get<PrimitiveType>(type.elem_type().type_kind()),
            PrimitiveType::kBool);
}

TEST(AstTest, MapTypeSpecMutableConstruction) {
  MapTypeSpec type;
  type.mutable_key_type() = TypeSpec(PrimitiveType::kBool);
  type.mutable_value_type() = TypeSpec(PrimitiveType::kBool);
  EXPECT_EQ(absl::get<PrimitiveType>(type.key_type().type_kind()),
            PrimitiveType::kBool);
  EXPECT_EQ(absl::get<PrimitiveType>(type.value_type().type_kind()),
            PrimitiveType::kBool);
}

TEST(AstTest, MapTypeSpecComparatorKeyType) {
  MapTypeSpec type;
  type.mutable_key_type() = TypeSpec(PrimitiveType::kBool);
  EXPECT_FALSE(type == MapTypeSpec());
}

TEST(AstTest, MapTypeSpecComparatorValueType) {
  MapTypeSpec type;
  type.mutable_value_type() = TypeSpec(PrimitiveType::kBool);
  EXPECT_FALSE(type == MapTypeSpec());
}

TEST(AstTest, FunctionTypeSpecMutableConstruction) {
  FunctionTypeSpec type;
  type.mutable_result_type() = TypeSpec(PrimitiveType::kBool);
  EXPECT_EQ(absl::get<PrimitiveType>(type.result_type().type_kind()),
            PrimitiveType::kBool);
}

TEST(AstTest, FunctionTypeSpecComparatorArgTypes) {
  FunctionTypeSpec type;
  type.mutable_arg_types().emplace_back(TypeSpec());
  EXPECT_FALSE(type == FunctionTypeSpec());
}

TEST(AstTest, ListTypeSpecDefaults) {
  EXPECT_EQ(ListTypeSpec().elem_type(), TypeSpec());
}

TEST(AstTest, MapTypeSpecDefaults) {
  EXPECT_EQ(MapTypeSpec().key_type(), TypeSpec());
  EXPECT_EQ(MapTypeSpec().value_type(), TypeSpec());
}

TEST(AstTest, FunctionTypeSpecDefaults) {
  EXPECT_EQ(FunctionTypeSpec().result_type(), TypeSpec());
}

TEST(AstTest, TypeDefaults) {
  EXPECT_EQ(TypeSpec().null(), NullTypeSpec());
  EXPECT_EQ(TypeSpec().primitive(), PrimitiveType::kPrimitiveTypeUnspecified);
  EXPECT_EQ(TypeSpec().wrapper(), PrimitiveType::kPrimitiveTypeUnspecified);
  EXPECT_EQ(TypeSpec().well_known(),
            WellKnownTypeSpec::kWellKnownTypeUnspecified);
  EXPECT_EQ(TypeSpec().list_type(), ListTypeSpec());
  EXPECT_EQ(TypeSpec().map_type(), MapTypeSpec());
  EXPECT_EQ(TypeSpec().function(), FunctionTypeSpec());
  EXPECT_EQ(TypeSpec().message_type(), MessageTypeSpec());
  EXPECT_EQ(TypeSpec().type_param(), ParamTypeSpec());
  EXPECT_EQ(TypeSpec().type(), TypeSpec());
  EXPECT_EQ(TypeSpec().error_type(), ErrorTypeSpec());
  EXPECT_EQ(TypeSpec().abstract_type(), AbstractType());
}

TEST(AstTest, TypeComparatorTest) {
  TypeSpec type;
  type.set_type_kind(std::make_unique<TypeSpec>(PrimitiveType::kBool));

  EXPECT_TRUE(type ==
              TypeSpec(std::make_unique<TypeSpec>(PrimitiveType::kBool)));
  EXPECT_FALSE(type == TypeSpec(PrimitiveType::kBool));
  EXPECT_FALSE(type == TypeSpec(std::unique_ptr<TypeSpec>()));
  EXPECT_FALSE(type ==
               TypeSpec(std::make_unique<TypeSpec>(PrimitiveType::kInt64)));
}

// Each mutable_*_expr() call switches the active kind; the EXPECT_FALSE after
// each step checks that the previously active kind was cleared.
TEST(AstTest, ExprMutableConstruction) {
  Expr expr;
  expr.mutable_const_expr().set_bool_value(true);
  ASSERT_TRUE(expr.has_const_expr());
  EXPECT_TRUE(expr.const_expr().bool_value());
  expr.mutable_ident_expr().set_name("expr");
  ASSERT_TRUE(expr.has_ident_expr());
  EXPECT_FALSE(expr.has_const_expr());
  EXPECT_EQ(expr.ident_expr().name(), "expr");
  expr.mutable_select_expr().set_field("field");
  ASSERT_TRUE(expr.has_select_expr());
  EXPECT_FALSE(expr.has_ident_expr());
  EXPECT_EQ(expr.select_expr().field(), "field");
  expr.mutable_call_expr().set_function("function");
  ASSERT_TRUE(expr.has_call_expr());
  EXPECT_FALSE(expr.has_select_expr());
  EXPECT_EQ(expr.call_expr().function(), "function");
  expr.mutable_list_expr();
  EXPECT_TRUE(expr.has_list_expr());
  EXPECT_FALSE(expr.has_call_expr());
  expr.mutable_struct_expr().set_name("name");
  ASSERT_TRUE(expr.has_struct_expr());
  EXPECT_EQ(expr.struct_expr().name(), "name");
  EXPECT_FALSE(expr.has_list_expr());
  expr.mutable_comprehension_expr().set_accu_var("accu_var");
  ASSERT_TRUE(expr.has_comprehension_expr());
  // The previously active kind here is the struct, not the list.
  EXPECT_FALSE(expr.has_struct_expr());
  EXPECT_EQ(expr.comprehension_expr().accu_var(), "accu_var");
}

TEST(AstTest, ReferenceConstantDefaultValue) {
  Reference reference;
  EXPECT_EQ(reference.value(), Constant());
}

TEST(AstTest, TypeCopyable) {
  TypeSpec type = TypeSpec(PrimitiveType::kBool);
  TypeSpec type2 = type;
  EXPECT_TRUE(type2.has_primitive());
  EXPECT_EQ(type2, type);

  type = TypeSpec(ListTypeSpec(std::make_unique<TypeSpec>(PrimitiveType::kBool)));
  type2 = type;
  EXPECT_TRUE(type2.has_list_type());
  EXPECT_EQ(type2, type);

  type = TypeSpec(MapTypeSpec(std::make_unique<TypeSpec>(PrimitiveType::kBool),
                              std::make_unique<TypeSpec>(PrimitiveType::kBool)));
  type2 = type;
  EXPECT_TRUE(type2.has_map_type());
  EXPECT_EQ(type2, type);

  type = TypeSpec(
      FunctionTypeSpec(std::make_unique<TypeSpec>(PrimitiveType::kBool), {}));
  type2 = type;
  EXPECT_TRUE(type2.has_function());
  EXPECT_EQ(type2, type);

  type = TypeSpec(AbstractType("optional", {TypeSpec(PrimitiveType::kBool)}));
  type2 = type;
  EXPECT_TRUE(type2.has_abstract_type());
  EXPECT_EQ(type2, type);
}

TEST(AstTest, TypeMoveable) {
  TypeSpec type = TypeSpec(PrimitiveType::kBool);
  TypeSpec type2 = type;
  TypeSpec type3 = std::move(type);
  EXPECT_TRUE(type2.has_primitive());
  EXPECT_EQ(type2, type3);

  type = TypeSpec(ListTypeSpec(std::make_unique<TypeSpec>(PrimitiveType::kBool)));
  type2 = type;
  type3 = std::move(type);
  EXPECT_TRUE(type2.has_list_type());
  EXPECT_EQ(type2, type3);

  type = TypeSpec(MapTypeSpec(std::make_unique<TypeSpec>(PrimitiveType::kBool),
                              std::make_unique<TypeSpec>(PrimitiveType::kBool)));
  type2 = type;
  type3 = std::move(type);
  EXPECT_TRUE(type2.has_map_type());
  EXPECT_EQ(type2, type3);

  type = TypeSpec(
      FunctionTypeSpec(std::make_unique<TypeSpec>(PrimitiveType::kBool), {}));
  type2 = type;
  type3 = std::move(type);
  EXPECT_TRUE(type2.has_function());
  EXPECT_EQ(type2, type3);

  type = TypeSpec(AbstractType("optional", {TypeSpec(PrimitiveType::kBool)}));
  type2 = type;
  type3 = std::move(type);
  EXPECT_TRUE(type2.has_abstract_type());
  EXPECT_EQ(type2, type3);
}

TEST(AstTest, NestedTypeKindCopyAssignable) {
  ListTypeSpec list_type(std::make_unique<TypeSpec>(PrimitiveType::kBool));
  ListTypeSpec list_type2;
  list_type2 = list_type;

  EXPECT_EQ(list_type2, list_type);

  MapTypeSpec map_type(std::make_unique<TypeSpec>(PrimitiveType::kBool),
                       std::make_unique<TypeSpec>(PrimitiveType::kBool));
  MapTypeSpec map_type2;
  map_type2 = map_type;

  // Previously the map copy-assignment was performed but never verified.
  EXPECT_EQ(map_type2, map_type);

  AbstractType abstract_type(
      "abstract", {TypeSpec(PrimitiveType::kBool), TypeSpec(PrimitiveType::kBool)});
  AbstractType abstract_type2;
  abstract_type2 = abstract_type;

  EXPECT_EQ(abstract_type2, abstract_type);

  FunctionTypeSpec function_type(
      std::make_unique<TypeSpec>(PrimitiveType::kBool),
      {TypeSpec(PrimitiveType::kBool), TypeSpec(PrimitiveType::kBool)});
  FunctionTypeSpec function_type2;
  function_type2 = function_type;

  EXPECT_EQ(function_type2, function_type);
}

TEST(AstTest, ExtensionSupported) {
  SourceInfo source_info;

  source_info.mutable_extensions().push_back(
      ExtensionSpec("constant_folding", nullptr, {}));

  EXPECT_EQ(source_info.extensions()[0],
            ExtensionSpec("constant_folding", nullptr, {}));
}

TEST(AstTest, ExtensionSpecEquality) {
  ExtensionSpec extension1("constant_folding", nullptr, {});

  EXPECT_EQ(extension1, ExtensionSpec("constant_folding", nullptr, {}));

  EXPECT_NE(extension1,
            ExtensionSpec("constant_folding",
                          std::make_unique<ExtensionSpec::Version>(1, 0), {}));
  EXPECT_NE(extension1, ExtensionSpec("constant_folding", nullptr,
                                      {ExtensionSpec::Component::kRuntime}));
  EXPECT_EQ(extension1,
            ExtensionSpec("constant_folding",
                          std::make_unique<ExtensionSpec::Version>(0, 0), {}));
}

TEST(AstTest, ExtensionCopyMove) {
  ExtensionSpec a("constant_folding", nullptr, {});
  a.mutable_version().set_major(1);
  a.mutable_version().set_minor(2);
  a.mutable_affected_components().push_back(ExtensionSpec::Component::kRuntime);

  ExtensionSpec b(a);
  EXPECT_EQ(b.id(), "constant_folding");
  EXPECT_EQ(b.version().major(), 1);
  EXPECT_EQ(b.version().minor(), 2);
  EXPECT_THAT(b.affected_components(),
              ElementsAre(ExtensionSpec::Component::kRuntime));

  ExtensionSpec c(std::move(b));
  EXPECT_EQ(c, a);

  a.set_version(nullptr);
  b = a;
  EXPECT_EQ(b.id(), "constant_folding");
  EXPECT_EQ(b.version().major(), 0);
  EXPECT_EQ(b.version().minor(), 0);
  EXPECT_THAT(b.affected_components(),
              ElementsAre(ExtensionSpec::Component::kRuntime));

  c = std::move(b);
  EXPECT_EQ(c, a);
}

}  // namespace
}  // namespace cel

================================================
FILE: common/ast/navigable_ast_internal.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_NAVIGABLE_AST_INTERNAL_H_
#define THIRD_PARTY_CEL_CPP_COMMON_AST_NAVIGABLE_AST_INTERNAL_H_

#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <type_traits>
#include <vector>

#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/absl_check.h"
#include "absl/types/span.h"
#include "common/ast/navigable_ast_kinds.h"  // IWYU pragma: keep

namespace cel::common_internal {

// Implementation for range used for traversals backed by an absl::Span.
//
// This is intended to abstract the metadata layout from clients using the
// traversal methods in navigable_expr.h
//
// RangeTraits provide type info needed to construct the span and adapt to the
// range element type.
template class NavigableAstRange { private: using UnderlyingType = typename RangeTraits::UnderlyingType; using PtrType = const UnderlyingType*; using SpanType = absl::Span; public: class Iterator { public: using difference_type = ptrdiff_t; using value_type = decltype(RangeTraits::Adapt(*PtrType())); using iterator_category = std::bidirectional_iterator_tag; Iterator() : ptr_(nullptr), span_() {} Iterator(SpanType span, size_t i) : ptr_(span.data() + i), span_(span) {} value_type operator*() const { ABSL_DCHECK(ptr_ != nullptr); ABSL_DCHECK(span_.data() != nullptr); ABSL_DCHECK_GE(ptr_, span_.data()); ABSL_DCHECK_LT(ptr_, span_.data() + span_.size()); return RangeTraits::Adapt(*ptr_); } template std::enable_if_t::value, std::add_pointer_t>> operator->() const { return &operator*(); } Iterator& operator++() { ++ptr_; return *this; } Iterator operator++(int) { Iterator tmp = *this; ++ptr_; return tmp; } Iterator& operator--() { --ptr_; return *this; } Iterator operator--(int) { Iterator tmp = *this; --ptr_; return tmp; } bool operator==(const Iterator& other) const { return ptr_ == other.ptr_ && span_ == other.span_; } bool operator!=(const Iterator& other) const { return !(*this == other); } private: PtrType ptr_; SpanType span_; }; explicit NavigableAstRange(SpanType span) : span_(span) {} Iterator begin() const { return Iterator(span_, 0); } Iterator end() const { return Iterator(span_, span_.size()); } explicit operator bool() const { return !span_.empty(); } private: SpanType span_; }; template struct NavigableAstMetadata; // Internal implementation for data-structures handling cross-referencing nodes. // // This is exposed separately to allow building up the AST relationships // without exposing too much mutable state on the client facing classes. 
template struct NavigableAstNodeData { typename AstTraits::NodeType* parent; const typename AstTraits::ExprType* expr; ChildKind parent_relation; NodeKind node_kind; const NavigableAstMetadata* absl_nonnull metadata; size_t index; size_t tree_size; size_t height; int child_index; std::vector children; }; template struct NavigableAstMetadata { // The nodes in the AST in preorder. // // unique_ptr is used to guarantee pointer stability in the other tables. std::vector> nodes; std::vector postorder; absl::flat_hash_map id_to_node; absl::flat_hash_map expr_to_node; }; template struct PostorderTraits { using UnderlyingType = const AstNode*; static const AstNode& Adapt(const AstNode* const node) { return *node; } }; template struct PreorderTraits { using UnderlyingType = std::unique_ptr; static const AstNode& Adapt(const std::unique_ptr& node) { return *node; } }; // Base class for NavigableAstNode and NavigableProtoAstNode. template class NavigableAstNodeBase { private: using MetadataType = NavigableAstMetadata; using NodeDataType = NavigableAstNodeData; using Derived = typename AstTraits::NodeType; using ExprType = typename AstTraits::ExprType; public: using PreorderRange = NavigableAstRange>; using PostorderRange = NavigableAstRange>; // The parent of this node or nullptr if it is a root. const Derived* absl_nullable parent() const { return data_.parent; } const ExprType* absl_nonnull expr() const { return data_.expr; } // The index of this node in the parent's children. -1 if this is a root. int child_index() const { return data_.child_index; } // The type of traversal from parent to this node. ChildKind parent_relation() const { return data_.parent_relation; } // The type of this node, analogous to Expr::ExprKindCase. NodeKind node_kind() const { return data_.node_kind; } // The number of nodes in the tree rooted at this node (including self). 
size_t tree_size() const { return data_.tree_size; } // The height of this node in the tree (the number of descendants including // self on the longest path). size_t height() const { return data_.height; } absl::Span children() const { return absl::MakeConstSpan(data_.children); } // Range over the descendants of this node (including self) using preorder // semantics. Each node is visited immediately before all of its descendants. PreorderRange DescendantsPreorder() const { return PreorderRange(absl::MakeConstSpan(data_.metadata->nodes) .subspan(data_.index, data_.tree_size)); } // Range over the descendants of this node (including self) using postorder // semantics. Each node is visited immediately after all of its descendants. PostorderRange DescendantsPostorder() const { return PostorderRange(absl::MakeConstSpan(data_.metadata->postorder) .subspan(data_.index, data_.tree_size)); } private: friend Derived; NavigableAstNodeBase() = default; NavigableAstNodeBase(const NavigableAstNodeBase&) = delete; NavigableAstNodeBase& operator=(const NavigableAstNodeBase&) = delete; protected: NodeDataType data_; }; // Shared implementation for NavigableAst and NavigableProtoAst. // // AstTraits provides type info for the derived classes that implement building // the traversal metadata. It provides the following types: // // ExprType is the expression node type of the source AST. // // AstType is the subclass of NavigableAstBase for the implementation. // // NodeType is the subclass of NavigableAstNodeBase for the implementation. 
template class NavigableAstBase { private: using MetadataType = NavigableAstMetadata; using Derived = typename AstTraits::AstType; using NodeType = typename AstTraits::NodeType; using ExprType = typename AstTraits::ExprType; public: NavigableAstBase(const NavigableAstBase&) = delete; NavigableAstBase& operator=(const NavigableAstBase&) = delete; NavigableAstBase(NavigableAstBase&&) = default; NavigableAstBase& operator=(NavigableAstBase&&) = default; // Return ptr to the AST node with id if present. Otherwise returns nullptr. // // If ids are non-unique, the first pre-order node encountered with id is // returned. const NodeType* absl_nullable FindId(int64_t id) const { auto it = metadata_->id_to_node.find(id); if (it == metadata_->id_to_node.end()) { return nullptr; } return it->second; } // Return ptr to the AST node representing the given Expr protobuf node. const NodeType* absl_nullable FindExpr( const ExprType* absl_nonnull expr) const { auto it = metadata_->expr_to_node.find(expr); if (it == metadata_->expr_to_node.end()) { return nullptr; } return it->second; } // The root of the AST. const NodeType& Root() const { return *metadata_->nodes[0]; } // Check whether the source AST used unique IDs for each node. // // This is typically the case, but older versions of the parsers didn't // guarantee uniqueness for nodes generated by some macros and ASTs modified // outside of CEL's parse/type check may not have unique IDs. bool IdsAreUnique() const { return metadata_->id_to_node.size() == metadata_->nodes.size(); } // Equality operators test for identity. They are intended to distinguish // moved-from or uninitialized instances from initialized. bool operator==(const NavigableAstBase& other) const { return metadata_ == other.metadata_; } bool operator!=(const NavigableAstBase& other) const { return metadata_ != other.metadata_; } // Return true if this instance is initialized. 
explicit operator bool() const { return metadata_ != nullptr; } private: friend Derived; NavigableAstBase() = default; explicit NavigableAstBase(std::unique_ptr metadata) : metadata_(std::move(metadata)) {} std::unique_ptr metadata_; }; } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_AST_NAVIGABLE_AST_INTERNAL_H_ ================================================ FILE: common/ast/navigable_ast_internal_test.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/ast/navigable_ast_internal.h"

#include
#include

#include "absl/base/casts.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "common/ast/navigable_ast_kinds.h"
#include "internal/testing.h"

namespace cel::common_internal {
namespace {

// Adapts each int to a double offset by 0.5 so the test can tell adapted
// values apart from the underlying storage.
struct TestRangeTraits {
  using UnderlyingType = int;
  static double Adapt(const UnderlyingType& value) {
    return static_cast(value) + 0.5;
  }
};

// Walks the range forward to end() and back to begin(), checking that every
// element goes through TestRangeTraits::Adapt.
TEST(NavigableAstRangeTest, BasicIteration) {
  std::vector values{1, 2, 3};
  NavigableAstRange range(absl::MakeConstSpan(values));

  auto it = range.begin();
  EXPECT_EQ(*it, 1.5);
  EXPECT_EQ(*++it, 2.5);
  EXPECT_EQ(*++it, 3.5);
  EXPECT_EQ(++it, range.end());
  EXPECT_EQ(*--it, 3.5);
  EXPECT_EQ(*--it, 2.5);
  EXPECT_EQ(*--it, 1.5);
  EXPECT_EQ(it, range.begin());
}

TEST(NodeKind, Stringify) {
  // Note: the specific values are not important or guaranteed to be stable,
  // they are only intended to make test outputs clearer.
  EXPECT_EQ(absl::StrCat(NodeKind::kConstant), "Constant");
  EXPECT_EQ(absl::StrCat(NodeKind::kIdent), "Ident");
  EXPECT_EQ(absl::StrCat(NodeKind::kSelect), "Select");
  EXPECT_EQ(absl::StrCat(NodeKind::kCall), "Call");
  EXPECT_EQ(absl::StrCat(NodeKind::kList), "List");
  EXPECT_EQ(absl::StrCat(NodeKind::kMap), "Map");
  EXPECT_EQ(absl::StrCat(NodeKind::kStruct), "Struct");
  EXPECT_EQ(absl::StrCat(NodeKind::kComprehension), "Comprehension");
  EXPECT_EQ(absl::StrCat(NodeKind::kUnspecified), "Unspecified");
  EXPECT_EQ(absl::StrCat(absl::bit_cast(255)), "Unknown NodeKind 255");
}

TEST(ChildKind, Stringify) {
  // Note: the specific values are not important or guaranteed to be stable,
  // they are only intended to make test outputs clearer.
  EXPECT_EQ(absl::StrCat(ChildKind::kSelectOperand), "SelectOperand");
  EXPECT_EQ(absl::StrCat(ChildKind::kCallReceiver), "CallReceiver");
  EXPECT_EQ(absl::StrCat(ChildKind::kCallArg), "CallArg");
  EXPECT_EQ(absl::StrCat(ChildKind::kListElem), "ListElem");
  EXPECT_EQ(absl::StrCat(ChildKind::kMapKey), "MapKey");
  EXPECT_EQ(absl::StrCat(ChildKind::kMapValue), "MapValue");
  EXPECT_EQ(absl::StrCat(ChildKind::kStructValue), "StructValue");
  EXPECT_EQ(absl::StrCat(ChildKind::kComprehensionRange),
            "ComprehensionRange");
  EXPECT_EQ(absl::StrCat(ChildKind::kComprehensionInit), "ComprehensionInit");
  EXPECT_EQ(absl::StrCat(ChildKind::kComprehensionCondition),
            "ComprehensionCondition");
  EXPECT_EQ(absl::StrCat(ChildKind::kComprehensionLoopStep),
            "ComprehensionLoopStep");
  EXPECT_EQ(absl::StrCat(ChildKind::kComprensionResult),
            "ComprehensionResult");
  EXPECT_EQ(absl::StrCat(ChildKind::kUnspecified), "Unspecified");
  EXPECT_EQ(absl::StrCat(absl::bit_cast(255)), "Unknown ChildKind 255");
}

}  // namespace
}  // namespace cel::common_internal
================================================
FILE: common/ast/navigable_ast_kinds.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/ast/navigable_ast_kinds.h"

#include

#include "absl/strings/str_cat.h"

namespace cel {

// Maps each ChildKind to a display label. Values outside the enum fall back to
// "Unknown ChildKind <n>".
std::string ChildKindName(ChildKind kind) {
  const char* label = nullptr;
  switch (kind) {
    case ChildKind::kUnspecified:
      label = "Unspecified";
      break;
    case ChildKind::kSelectOperand:
      label = "SelectOperand";
      break;
    case ChildKind::kCallReceiver:
      label = "CallReceiver";
      break;
    case ChildKind::kCallArg:
      label = "CallArg";
      break;
    case ChildKind::kListElem:
      label = "ListElem";
      break;
    case ChildKind::kMapKey:
      label = "MapKey";
      break;
    case ChildKind::kMapValue:
      label = "MapValue";
      break;
    case ChildKind::kStructValue:
      label = "StructValue";
      break;
    case ChildKind::kComprehensionRange:
      label = "ComprehensionRange";
      break;
    case ChildKind::kComprehensionInit:
      label = "ComprehensionInit";
      break;
    case ChildKind::kComprehensionCondition:
      label = "ComprehensionCondition";
      break;
    case ChildKind::kComprehensionLoopStep:
      label = "ComprehensionLoopStep";
      break;
    case ChildKind::kComprensionResult:
      label = "ComprehensionResult";
      break;
    default:
      break;
  }
  if (label == nullptr) {
    return absl::StrCat("Unknown ChildKind ", static_cast(kind));
  }
  return label;
}

// Maps each NodeKind to a display label. Values outside the enum fall back to
// "Unknown NodeKind <n>".
std::string NodeKindName(NodeKind kind) {
  const char* label = nullptr;
  switch (kind) {
    case NodeKind::kUnspecified:
      label = "Unspecified";
      break;
    case NodeKind::kConstant:
      label = "Constant";
      break;
    case NodeKind::kIdent:
      label = "Ident";
      break;
    case NodeKind::kSelect:
      label = "Select";
      break;
    case NodeKind::kCall:
      label = "Call";
      break;
    case NodeKind::kList:
      label = "List";
      break;
    case NodeKind::kMap:
      label = "Map";
      break;
    case NodeKind::kStruct:
      label = "Struct";
      break;
    case NodeKind::kComprehension:
      label = "Comprehension";
      break;
    default:
      break;
  }
  if (label == nullptr) {
    return absl::StrCat("Unknown NodeKind ", static_cast(kind));
  }
  return label;
}

}  // namespace cel
================================================
FILE: common/ast/navigable_ast_kinds.h
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private #ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_NAVIGABLE_AST_KINDS_H_ #define THIRD_PARTY_CEL_CPP_COMMON_AST_NAVIGABLE_AST_KINDS_H_ #include #include "absl/strings/str_format.h" namespace cel { // The traversal relationship from parent to the given node in a NavigableAst. enum class ChildKind { kUnspecified, kSelectOperand, kCallReceiver, kCallArg, kListElem, kMapKey, kMapValue, kStructValue, kComprehensionRange, kComprehensionInit, kComprehensionCondition, kComprehensionLoopStep, kComprensionResult }; // The type of the node in a NavigableAst. enum class NodeKind { kUnspecified, kConstant, kIdent, kSelect, kCall, kList, kMap, kStruct, kComprehension, }; // Human readable ChildKind name. Provided for test readability -- do not depend // on the specific values. std::string ChildKindName(ChildKind kind); template void AbslStringify(Sink& sink, ChildKind kind) { absl::Format(&sink, "%s", ChildKindName(kind)); } // Human readable NodeKind name. Provided for test readability -- do not depend // on the specific values. std::string NodeKindName(NodeKind kind); template void AbslStringify(Sink& sink, NodeKind kind) { absl::Format(&sink, "%s", NodeKindName(kind)); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_AST_NAVIGABLE_AST_KINDS_H_ ================================================ FILE: common/ast/source_info_proto.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/ast/source_info_proto.h" #include #include #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" #include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "absl/status/status.h" #include "common/ast.h" #include "common/ast/expr_proto.h" #include "internal/status_macros.h" namespace cel::ast_internal { using ::cel::ast_internal::ExprToProto; using ExprPb = cel::expr::Expr; using ParsedExprPb = cel::expr::ParsedExpr; using CheckedExprPb = cel::expr::CheckedExpr; using ExtensionPb = cel::expr::SourceInfo::Extension; absl::Status SourceInfoToProto(const cel::SourceInfo& source_info, cel::expr::SourceInfo* out) { cel::expr::SourceInfo& result = *out; result.set_syntax_version(source_info.syntax_version()); result.set_location(source_info.location()); for (int32_t line_offset : source_info.line_offsets()) { result.add_line_offsets(line_offset); } for (auto pos_iter = source_info.positions().begin(); pos_iter != source_info.positions().end(); ++pos_iter) { (*result.mutable_positions())[pos_iter->first] = pos_iter->second; } for (auto macro_iter = source_info.macro_calls().begin(); macro_iter != source_info.macro_calls().end(); ++macro_iter) { ExprPb& dest_macro = (*result.mutable_macro_calls())[macro_iter->first]; CEL_RETURN_IF_ERROR(ExprToProto(macro_iter->second, &dest_macro)); } for (const auto& extension : source_info.extensions()) { auto* extension_pb = result.add_extensions(); extension_pb->set_id(extension.id()); auto* version_pb = 
extension_pb->mutable_version(); version_pb->set_major(extension.version().major()); version_pb->set_minor(extension.version().minor()); for (auto component : extension.affected_components()) { switch (component) { case cel::ExtensionSpec::Component::kParser: extension_pb->add_affected_components(ExtensionPb::COMPONENT_PARSER); break; case cel::ExtensionSpec::Component::kTypeChecker: extension_pb->add_affected_components( ExtensionPb::COMPONENT_TYPE_CHECKER); break; case cel::ExtensionSpec::Component::kRuntime: extension_pb->add_affected_components(ExtensionPb::COMPONENT_RUNTIME); break; default: extension_pb->add_affected_components( ExtensionPb::COMPONENT_UNSPECIFIED); break; } } } return absl::OkStatus(); } } // namespace cel::ast_internal ================================================ FILE: common/ast/source_info_proto.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_SOURCE_INFO_PROTO_H_ #define THIRD_PARTY_CEL_CPP_COMMON_AST_SOURCE_INFO_PROTO_H_ #include "cel/expr/syntax.pb.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "common/ast.h" namespace cel::ast_internal { // Conversion utility for the CEL-C++ source info representation to the protobuf // representation. 
absl::Status SourceInfoToProto(const SourceInfo& source_info, cel::expr::SourceInfo* absl_nonnull out); } // namespace cel::ast_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_AST_SOURCE_INFO_PROTO_H_ ================================================ FILE: common/ast.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/ast.h" #include #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "common/ast/metadata.h" #include "common/source.h" namespace cel { namespace { const TypeSpec& DynSingleton() { static absl::NoDestructor singleton{TypeSpecKind(DynTypeSpec())}; return *singleton; } } // namespace const TypeSpec* absl_nullable Ast::GetType(int64_t expr_id) const { auto iter = type_map_.find(expr_id); if (iter == type_map_.end()) { return nullptr; } return &iter->second; } const TypeSpec& Ast::GetTypeOrDyn(int64_t expr_id) const { if (const TypeSpec* type = GetType(expr_id); type != nullptr) { return *type; } return DynSingleton(); } const TypeSpec& Ast::GetReturnType() const { return GetTypeOrDyn(root_expr().id()); } const Reference* absl_nullable Ast::GetReference(int64_t expr_id) const { auto iter = reference_map_.find(expr_id); if (iter == reference_map_.end()) { return nullptr; } return &iter->second; } SourceLocation Ast::ComputeSourceLocation(int64_t expr_id) const { const auto& source_info = this->source_info(); auto iter = 
source_info.positions().find(expr_id); if (iter == source_info.positions().end()) { return SourceLocation{}; } int32_t absolute_position = iter->second; if (absolute_position < 0) { return SourceLocation{}; } // Find the first line offset that is greater than the absolute position. int32_t line_idx = -1; int32_t offset = 0; for (int32_t i = 0; i < source_info.line_offsets().size(); ++i) { int32_t next_offset = source_info.line_offsets()[i]; if (next_offset <= offset) { // Line offset is not monotonically increasing, so line information is // invalid. return SourceLocation{}; } if (absolute_position < next_offset) { line_idx = i; break; } offset = next_offset; } if (line_idx < 0 || line_idx >= source_info.line_offsets().size()) { return SourceLocation{}; } int32_t rel_position = absolute_position - offset; return SourceLocation{line_idx + 1, rel_position}; } } // namespace cel ================================================ FILE: common/ast.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_H_ #define THIRD_PARTY_CEL_CPP_COMMON_AST_H_ #include #include #include #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "common/ast/metadata.h" // IWYU pragma: export #include "common/expr.h" #include "common/source.h" namespace cel { // In memory representation of a CEL abstract syntax tree. 
// // If AST inspection or manipulation is needed, prefer to use an existing tool // or traverse the protobuf representation rather than directly manipulating // through this class. See `cel::NavigableAst` and `cel::AstTraverse`. // // Type and reference maps are only populated if the AST is checked. Any changes // to the AST are not automatically reflected in the type or reference maps. // // To create a new instance from a protobuf representation, use the conversion // utilities in `common/ast_proto.h`. class Ast final { public: using ReferenceMap = absl::flat_hash_map; using TypeMap = absl::flat_hash_map; Ast() : is_checked_(false) {} Ast(Expr expr, SourceInfo source_info) : root_expr_(std::move(expr)), source_info_(std::move(source_info)), is_checked_(false) {} Ast(Expr expr, SourceInfo source_info, ReferenceMap reference_map, TypeMap type_map, std::string expr_version) : root_expr_(std::move(expr)), source_info_(std::move(source_info)), reference_map_(std::move(reference_map)), type_map_(std::move(type_map)), expr_version_(std::move(expr_version)), is_checked_(true) {} Ast(const Ast& other) = default; Ast& operator=(const Ast& other) = default; Ast(Ast&& other) = default; Ast& operator=(Ast&& other) = default; // Deprecated. Use `is_checked()` instead. bool IsChecked() const { return is_checked_; } bool is_checked() const { return is_checked_; } void set_is_checked(bool is_checked) { is_checked_ = is_checked; } // The root expression of the AST. // // This is the entry point for evaluation and determines the overall result // of the expression given a context. const Expr& root_expr() const { return root_expr_; } Expr& mutable_root_expr() { return root_expr_; } // Metadata about the source expression. const SourceInfo& source_info() const { return source_info_; } SourceInfo& mutable_source_info() { return source_info_; } // Returns the type of the expression with the given `expr_id`. 
// // Returns `nullptr` if the expression node is not found or has dynamic type. const TypeSpec* absl_nullable GetType(int64_t expr_id) const; const TypeSpec& GetTypeOrDyn(int64_t expr_id) const; const TypeSpec& GetReturnType() const; // Returns the resolved reference for the expression with the given `expr_id`. // // Returns `nullptr` if the expression node is not found or no reference was // resolved. const Reference* absl_nullable GetReference(int64_t expr_id) const; // A map from expression ids to resolved references. // // The following entries are in this table: // // - An Ident or Select expression is represented here if it resolves to a // declaration. For instance, if `a.b.c` is represented by // `select(select(id(a), b), c)`, and `a.b` resolves to a declaration, // while `c` is a field selection, then the reference is attached to the // nested select expression (but not to the id or or the outer select). // In turn, if `a` resolves to a declaration and `b.c` are field selections, // the reference is attached to the ident expression. // - Every Call expression has an entry here, identifying the function being // called. // - Every CreateStruct expression for a message has an entry, identifying // the message. // // Unpopulated if the AST is not checked. const ReferenceMap& reference_map() const { return reference_map_; } ReferenceMap& mutable_reference_map() { return reference_map_; } // A map from expression ids to types. // // Every expression node which has a type different than DYN has a mapping // here. If an expression has type DYN, it is omitted from this map to save // space. // // Unpopulated if the AST is not checked. const TypeMap& type_map() const { return type_map_; } TypeMap& mutable_type_map() { return type_map_; } // The expr version indicates the major / minor version number of the `expr` // representation. 
// // The most common reason for a version change will be to indicate to the CEL // runtimes that transformations have been performed on the expr during static // analysis. absl::string_view expr_version() const { return expr_version_; } void set_expr_version(absl::string_view expr_version) { expr_version_ = expr_version; } // Computes the source location (line and column) for the given expression ID // from the source info (which stores absolute positions). // // Returns a default (empty) source location if the expression ID is not found // or the source info is not populated correctly. SourceLocation ComputeSourceLocation(int64_t expr_id) const; private: Expr root_expr_; SourceInfo source_info_; ReferenceMap reference_map_; TypeMap type_map_; std::string expr_version_; bool is_checked_; }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_AST_H_ ================================================ FILE: common/ast_proto.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/ast_proto.h"

#include
#include
#include
#include
#include

#include "cel/expr/checked.pb.h"
#include "cel/expr/syntax.pb.h"
#include "google/protobuf/duration.pb.h"
#include "google/protobuf/struct.pb.h"
#include "google/protobuf/timestamp.pb.h"
#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/variant.h"
#include "common/ast.h"
#include "common/ast/constant_proto.h"
#include "common/ast/expr_proto.h"
#include "common/ast/source_info_proto.h"
#include "common/constant.h"
#include "common/expr.h"
#include "internal/status_macros.h"

namespace cel {
namespace {

using ::cel::ast_internal::ConstantFromProto;
using ::cel::ast_internal::ConstantToProto;
using ::cel::ast_internal::ExprFromProto;
using ::cel::ast_internal::ExprToProto;

using ExprPb = cel::expr::Expr;
using ParsedExprPb = cel::expr::ParsedExpr;
using CheckedExprPb = cel::expr::CheckedExpr;
using SourceInfoPb = cel::expr::SourceInfo;
using ExtensionPb = cel::expr::SourceInfo::Extension;
using ReferencePb = cel::expr::Reference;
using TypePb = cel::expr::Type;
// NOTE(review): ExtensionPb is aliased a second time here. Redeclaring an
// alias to the same type is legal but redundant; consider dropping one.
using ExtensionPb = cel::expr::SourceInfo::Extension;

// Converts a protobuf Expr into the native cel::Expr, propagating any error
// reported by ExprFromProto.
absl::StatusOr ExprValueFromProto(const ExprPb& expr) {
  Expr result;
  CEL_RETURN_IF_ERROR(ExprFromProto(expr, result));
  return result;
}

// Builds a native SourceInfo from its protobuf form. Macro-call expressions are
// converted recursively (first failure is returned). Unrecognized extension
// components map to kUnspecified.
absl::StatusOr ConvertProtoSourceInfoToNative(
    const cel::expr::SourceInfo& source_info) {
  absl::flat_hash_map macro_calls;
  for (const auto& pair : source_info.macro_calls()) {
    auto native_expr = ExprValueFromProto(pair.second);
    if (!native_expr.ok()) {
      return native_expr.status();
    }
    macro_calls.emplace(pair.first, *(std::move(native_expr)));
  }
  std::vector extensions;
  extensions.reserve(source_info.extensions_size());
  for (const auto& extension : source_info.extensions()) {
    std::vector components;
    components.reserve(extension.affected_components().size());
    for (const auto& component : extension.affected_components()) {
      switch (component) {
        case ExtensionPb::COMPONENT_PARSER:
          components.push_back(ExtensionSpec::Component::kParser);
          break;
        case ExtensionPb::COMPONENT_TYPE_CHECKER:
          components.push_back(ExtensionSpec::Component::kTypeChecker);
          break;
        case ExtensionPb::COMPONENT_RUNTIME:
          components.push_back(ExtensionSpec::Component::kRuntime);
          break;
        default:
          components.push_back(ExtensionSpec::Component::kUnspecified);
          break;
      }
    }
    extensions.push_back(ExtensionSpec(
        extension.id(),
        std::make_unique(extension.version().major(),
                         extension.version().minor()),
        std::move(components)));
  }

  return SourceInfo(
      source_info.syntax_version(), source_info.location(),
      std::vector(source_info.line_offsets().begin(),
                  source_info.line_offsets().end()),
      absl::flat_hash_map(source_info.positions().begin(),
                          source_info.positions().end()),
      std::move(macro_calls), std::move(extensions));
}

// Forward declaration: the ToNative overloads below and this function are
// mutually recursive for nested types.
absl::StatusOr ConvertProtoTypeToNative(
    const cel::expr::Type& type);

// Maps a protobuf primitive type enum to its native counterpart; unknown
// values yield InvalidArgumentError.
absl::StatusOr ToNative(
    cel::expr::Type::PrimitiveType primitive_type) {
  switch (primitive_type) {
    case cel::expr::Type::PRIMITIVE_TYPE_UNSPECIFIED:
      return PrimitiveType::kPrimitiveTypeUnspecified;
    case cel::expr::Type::BOOL:
      return PrimitiveType::kBool;
    case cel::expr::Type::INT64:
      return PrimitiveType::kInt64;
    case cel::expr::Type::UINT64:
      return PrimitiveType::kUint64;
    case cel::expr::Type::DOUBLE:
      return PrimitiveType::kDouble;
    case cel::expr::Type::STRING:
      return PrimitiveType::kString;
    case cel::expr::Type::BYTES:
      return PrimitiveType::kBytes;
    default:
      return absl::InvalidArgumentError(
          "Illegal type specified for "
          "cel::expr::Type::PrimitiveType.");
  }
}

// Maps a protobuf well-known type enum to WellKnownTypeSpec; unknown values
// yield InvalidArgumentError.
absl::StatusOr ToNative(
    cel::expr::Type::WellKnownType well_known_type) {
  switch (well_known_type) {
    case cel::expr::Type::WELL_KNOWN_TYPE_UNSPECIFIED:
      return WellKnownTypeSpec::kWellKnownTypeUnspecified;
    case cel::expr::Type::ANY:
      return WellKnownTypeSpec::kAny;
    case cel::expr::Type::TIMESTAMP:
      return WellKnownTypeSpec::kTimestamp;
    case cel::expr::Type::DURATION:
      return WellKnownTypeSpec::kDuration;
    default:
      return absl::InvalidArgumentError(
          "Illegal type specified for "
          "cel::expr::Type::WellKnownType.");
  }
}

// Converts a list type by recursively converting its element type.
absl::StatusOr ToNative(
    const cel::expr::Type::ListType& list_type) {
  auto native_elem_type = ConvertProtoTypeToNative(list_type.elem_type());
  if (!native_elem_type.ok()) {
    return native_elem_type.status();
  }
  return ListTypeSpec(
      std::make_unique(*(std::move(native_elem_type))));
}

// Converts a map type by recursively converting key then value types.
absl::StatusOr ToNative(
    const cel::expr::Type::MapType& map_type) {
  auto native_key_type = ConvertProtoTypeToNative(map_type.key_type());
  if (!native_key_type.ok()) {
    return native_key_type.status();
  }
  auto native_value_type = ConvertProtoTypeToNative(map_type.value_type());
  if (!native_value_type.ok()) {
    return native_value_type.status();
  }
  return MapTypeSpec(
      std::make_unique(*(std::move(native_key_type))),
      std::make_unique(*(std::move(native_value_type))));
}

// Converts a function type: argument types first (in order), then the result
// type. Returns the first conversion error encountered.
absl::StatusOr ToNative(
    const cel::expr::Type::FunctionType& function_type) {
  std::vector arg_types;
  arg_types.reserve(function_type.arg_types_size());
  for (const auto& arg_type : function_type.arg_types()) {
    auto native_arg = ConvertProtoTypeToNative(arg_type);
    if (!native_arg.ok()) {
      return native_arg.status();
    }
    arg_types.emplace_back(*(std::move(native_arg)));
  }
  auto native_result = ConvertProtoTypeToNative(function_type.result_type());
  if (!native_result.ok()) {
    return native_result.status();
  }
  return FunctionTypeSpec(
      std::make_unique(*(std::move(native_result))),
      std::move(arg_types));
}

// Converts an abstract (named, parameterized) type and its parameter types.
absl::StatusOr ToNative(
    const cel::expr::Type::AbstractType& abstract_type) {
  std::vector parameter_types;
  for (const auto& parameter_type : abstract_type.parameter_types()) {
    auto native_parameter_type = ConvertProtoTypeToNative(parameter_type);
    if (!native_parameter_type.ok()) {
      return native_parameter_type.status();
    }
    parameter_types.emplace_back(*(std::move(native_parameter_type)));
  }
  return AbstractType(abstract_type.name(), std::move(parameter_types));
}

// Converts a protobuf Type into a native TypeSpec, dispatching on the
// `type_kind` oneof. An unset oneof maps to UnsetTypeSpec; an unrecognized
// case yields InvalidArgumentError.
absl::StatusOr ConvertProtoTypeToNative(
    const cel::expr::Type& type) {
  switch (type.type_kind_case()) {
    case cel::expr::Type::kDyn:
      return TypeSpec(DynTypeSpec());
    case cel::expr::Type::kNull:
      return TypeSpec(NullTypeSpec());
    case cel::expr::Type::kPrimitive: {
      auto native_primitive = ToNative(type.primitive());
      if (!native_primitive.ok()) {
        return native_primitive.status();
      }
      return TypeSpec(*(std::move(native_primitive)));
    }
    case cel::expr::Type::kWrapper: {
      // The wrapper field holds a primitive enum; convert it, then wrap it.
      auto native_wrapper = ToNative(type.wrapper());
      if (!native_wrapper.ok()) {
        return native_wrapper.status();
      }
      return TypeSpec(PrimitiveTypeWrapper(*(std::move(native_wrapper))));
    }
    case cel::expr::Type::kWellKnown: {
      auto native_well_known = ToNative(type.well_known());
      if (!native_well_known.ok()) {
        return native_well_known.status();
      }
      return TypeSpec(*std::move(native_well_known));
    }
    case cel::expr::Type::kListType: {
      auto native_list_type = ToNative(type.list_type());
      if (!native_list_type.ok()) {
        return native_list_type.status();
      }
      return TypeSpec(*(std::move(native_list_type)));
    }
    case cel::expr::Type::kMapType: {
      auto native_map_type = ToNative(type.map_type());
      if (!native_map_type.ok()) {
        return native_map_type.status();
      }
      return TypeSpec(*(std::move(native_map_type)));
    }
    case cel::expr::Type::kFunction: {
      auto native_function = ToNative(type.function());
      if (!native_function.ok()) {
        return native_function.status();
      }
      return TypeSpec(*(std::move(native_function)));
    }
    case cel::expr::Type::kMessageType:
      return TypeSpec(MessageTypeSpec(type.message_type()));
    case cel::expr::Type::kTypeParam:
      return TypeSpec(ParamTypeSpec(type.type_param()));
    case cel::expr::Type::kType: {
      // A `type` with an empty nested type is represented with a null
      // pointer rather than a nested UnsetTypeSpec.
      if (type.type().type_kind_case() ==
          cel::expr::Type::TypeKindCase::TYPE_KIND_NOT_SET) {
        return TypeSpec(std::unique_ptr());
      }
      auto native_type = ConvertProtoTypeToNative(type.type());
      if (!native_type.ok()) {
        return native_type.status();
      }
      return TypeSpec(std::make_unique(*std::move(native_type)));
    }
    case cel::expr::Type::kError:
      return TypeSpec(ErrorTypeSpec::kValue);
    case cel::expr::Type::kAbstractType: {
      auto native_abstract = ToNative(type.abstract_type());
      if (!native_abstract.ok()) {
        return native_abstract.status();
      }
      return TypeSpec(*(std::move(native_abstract)));
    }
    case cel::expr::Type::TYPE_KIND_NOT_SET:
      return TypeSpec(UnsetTypeSpec());
    default:
      return absl::InvalidArgumentError(
          "Illegal type specified for cel::expr::Type.");
  }
}

// Converts a protobuf Reference (name, overload ids, optional constant value)
// into the native Reference.
absl::StatusOr ConvertProtoReferenceToNative(
    const cel::expr::Reference& reference) {
  Reference ret_val;
  ret_val.set_name(reference.name());
  ret_val.mutable_overload_id().reserve(reference.overload_id_size());
  for (const auto& elem : reference.overload_id()) {
    ret_val.mutable_overload_id().emplace_back(elem);
  }
  if (reference.has_value()) {
    CEL_RETURN_IF_ERROR(
        ConstantFromProto(reference.value(), ret_val.mutable_value()));
  }
  return ret_val;
}

// Inverse of ConvertProtoReferenceToNative.
absl::StatusOr ReferenceToProto(const Reference& reference) {
  ReferencePb result;
  result.set_name(reference.name());
  for (const auto& overload_id : reference.overload_id()) {
    result.add_overload_id(overload_id);
  }
  if (reference.has_value()) {
    CEL_RETURN_IF_ERROR(
        ConstantToProto(reference.value(), result.mutable_value()));
  }
  return result;
}

// Forward declaration: TypeKindToProtoVisitor recurses through this function
// for nested types.
absl::Status TypeToProto(const TypeSpec& type, TypePb* result);

// absl::visit visitor that writes one TypeSpec alternative into `result`.
struct TypeKindToProtoVisitor {
  absl::Status operator()(PrimitiveType primitive) {
    switch (primitive) {
      case PrimitiveType::kPrimitiveTypeUnspecified:
        result->set_primitive(TypePb::PRIMITIVE_TYPE_UNSPECIFIED);
        return absl::OkStatus();
      case PrimitiveType::kBool:
        result->set_primitive(TypePb::BOOL);
        return absl::OkStatus();
      case PrimitiveType::kInt64:
        result->set_primitive(TypePb::INT64);
        return absl::OkStatus();
      case PrimitiveType::kUint64:
        result->set_primitive(TypePb::UINT64);
        return absl::OkStatus();
      case PrimitiveType::kDouble:
        result->set_primitive(TypePb::DOUBLE);
        return absl::OkStatus();
      case PrimitiveType::kString:
        result->set_primitive(TypePb::STRING);
        return absl::OkStatus();
      case PrimitiveType::kBytes:
        result->set_primitive(TypePb::BYTES);
        return absl::OkStatus();
      default:
        break;
    }
    return absl::InvalidArgumentError("Unsupported primitive type");
  }

  absl::Status operator()(PrimitiveTypeWrapper wrapper) {
    // Encode the wrapped primitive first, then move that enum into the
    // `wrapper` field. Both fields belong to the `type_kind` oneof, so this
    // replaces the primitive.
    CEL_RETURN_IF_ERROR(this->operator()(wrapper.type()));
    auto wrapped = result->primitive();
    result->set_wrapper(wrapped);
    return absl::OkStatus();
  }

  absl::Status operator()(UnsetTypeSpec) {
    result->clear_type_kind();
    return absl::OkStatus();
  }

  absl::Status operator()(DynTypeSpec) {
    result->mutable_dyn();
    return absl::OkStatus();
  }

  absl::Status operator()(ErrorTypeSpec) {
    result->mutable_error();
    return absl::OkStatus();
  }

  absl::Status operator()(NullTypeSpec) {
    result->set_null(google::protobuf::NULL_VALUE);
    return absl::OkStatus();
  }

  absl::Status operator()(const ListTypeSpec& list_type) {
    return TypeToProto(list_type.elem_type(),
                       result->mutable_list_type()->mutable_elem_type());
  }

  absl::Status operator()(const MapTypeSpec& map_type) {
    CEL_RETURN_IF_ERROR(TypeToProto(
        map_type.key_type(), result->mutable_map_type()->mutable_key_type()));
    return TypeToProto(map_type.value_type(),
                       result->mutable_map_type()->mutable_value_type());
  }

  absl::Status operator()(const MessageTypeSpec& message_type) {
    result->set_message_type(message_type.type());
    return absl::OkStatus();
  }

  absl::Status operator()(const WellKnownTypeSpec& well_known_type) {
    switch (well_known_type) {
      case WellKnownTypeSpec::kWellKnownTypeUnspecified:
        result->set_well_known(TypePb::WELL_KNOWN_TYPE_UNSPECIFIED);
        return absl::OkStatus();
      case WellKnownTypeSpec::kAny:
        result->set_well_known(TypePb::ANY);
        return absl::OkStatus();
      case WellKnownTypeSpec::kDuration:
        result->set_well_known(TypePb::DURATION);
        return absl::OkStatus();
      case WellKnownTypeSpec::kTimestamp:
        result->set_well_known(TypePb::TIMESTAMP);
        return absl::OkStatus();
      default:
        break;
    }
    return absl::InvalidArgumentError("Unsupported well-known type");
  }

  absl::Status operator()(const FunctionTypeSpec& function_type) {
    CEL_RETURN_IF_ERROR(
        TypeToProto(function_type.result_type(),
                    result->mutable_function()->mutable_result_type()));

    for (const TypeSpec& arg_type : function_type.arg_types()) {
      CEL_RETURN_IF_ERROR(
          TypeToProto(arg_type, result->mutable_function()->add_arg_types()));
    }
    return absl::OkStatus();
  }

  absl::Status operator()(const AbstractType& type) {
    auto* abstract_type_pb = result->mutable_abstract_type();
    abstract_type_pb->set_name(type.name());
    for (const TypeSpec& type_param : type.parameter_types()) {
      CEL_RETURN_IF_ERROR(
          TypeToProto(type_param, abstract_type_pb->add_parameter_types()));
    }
    return absl::OkStatus();
  }

  absl::Status operator()(const std::unique_ptr& type_type) {
    // A null nested type is serialized as a default-constructed TypeSpec.
    return TypeToProto((type_type != nullptr) ? *type_type : TypeSpec(),
                       result->mutable_type());
  }

  absl::Status operator()(const ParamTypeSpec& param_type) {
    result->set_type_param(param_type.type());
    return absl::OkStatus();
  }

  // Destination message; not owned.
  TypePb* result;
};

// Serializes a native TypeSpec into `result` by visiting its type_kind.
absl::Status TypeToProto(const TypeSpec& type, TypePb* result) {
  return absl::visit(TypeKindToProtoVisitor{result}, type.type_kind());
}

}  // namespace

// When `source_info` is null, the resulting AST gets a default SourceInfo.
absl::StatusOr> CreateAstFromParsedExpr(
    const cel::expr::Expr& expr,
    const cel::expr::SourceInfo* source_info) {
  CEL_ASSIGN_OR_RETURN(auto runtime_expr, ExprValueFromProto(expr));
  SourceInfo runtime_source_info;
  if (source_info != nullptr) {
    CEL_ASSIGN_OR_RETURN(runtime_source_info,
                         ConvertProtoSourceInfoToNative(*source_info));
  }
  return std::make_unique(std::move(runtime_expr),
                          std::move(runtime_source_info));
}

absl::StatusOr> CreateAstFromParsedExpr(
    const ParsedExprPb& parsed_expr) {
  return CreateAstFromParsedExpr(parsed_expr.expr(),
                                 &parsed_expr.source_info());
}

// Writes the root expression and source info of `ast` into `out`; reference
// and type maps are not part of ParsedExpr and are not written.
absl::Status AstToParsedExpr(const Ast& ast,
                             cel::expr::ParsedExpr* absl_nonnull out) {
  ParsedExprPb& parsed_expr = *out;
  CEL_RETURN_IF_ERROR(ExprToProto(ast.root_expr(), parsed_expr.mutable_expr()));
  CEL_RETURN_IF_ERROR(ast_internal::SourceInfoToProto(
      ast.source_info(), parsed_expr.mutable_source_info()));

  return absl::OkStatus();
}

absl::StatusOr> CreateAstFromCheckedExpr(
    const CheckedExprPb&
checked_expr) { CEL_ASSIGN_OR_RETURN(Expr expr, ExprValueFromProto(checked_expr.expr())); CEL_ASSIGN_OR_RETURN(SourceInfo source_info, ConvertProtoSourceInfoToNative( checked_expr.source_info())); Ast::ReferenceMap reference_map; for (const auto& pair : checked_expr.reference_map()) { auto native_reference = ConvertProtoReferenceToNative(pair.second); if (!native_reference.ok()) { return native_reference.status(); } reference_map.emplace(pair.first, *(std::move(native_reference))); } Ast::TypeMap type_map; for (const auto& pair : checked_expr.type_map()) { auto native_type = ConvertProtoTypeToNative(pair.second); if (!native_type.ok()) { return native_type.status(); } type_map.emplace(pair.first, *(std::move(native_type))); } return std::make_unique(std::move(expr), std::move(source_info), std::move(reference_map), std::move(type_map), checked_expr.expr_version()); } absl::Status AstToCheckedExpr( const Ast& ast, cel::expr::CheckedExpr* absl_nonnull out) { if (!ast.is_checked()) { return absl::InvalidArgumentError("AST is not type-checked"); } CheckedExprPb& checked_expr = *out; checked_expr.set_expr_version(ast.expr_version()); CEL_RETURN_IF_ERROR( ExprToProto(ast.root_expr(), checked_expr.mutable_expr())); CEL_RETURN_IF_ERROR(ast_internal::SourceInfoToProto( ast.source_info(), checked_expr.mutable_source_info())); for (auto it = ast.reference_map().begin(); it != ast.reference_map().end(); ++it) { ReferencePb& dest_reference = (*checked_expr.mutable_reference_map())[it->first]; CEL_ASSIGN_OR_RETURN(dest_reference, ReferenceToProto(it->second)); } for (auto it = ast.type_map().begin(); it != ast.type_map().end(); ++it) { TypePb& dest_type = (*checked_expr.mutable_type_map())[it->first]; CEL_RETURN_IF_ERROR(TypeToProto(it->second, &dest_type)); } return absl::OkStatus(); } } // namespace cel ================================================ FILE: common/ast_proto.h ================================================ // Copyright 2022 Google LLC // // Licensed under the 
Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_PROTO_H_ #define THIRD_PARTY_CEL_CPP_COMMON_AST_PROTO_H_ #include #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "base/ast.h" namespace cel { // Creates a runtime AST from a parsed-only protobuf AST. // May return a non-ok Status if the AST is malformed (e.g. unset required // fields). absl::StatusOr> CreateAstFromParsedExpr( const cel::expr::Expr& expr, const cel::expr::SourceInfo* source_info = nullptr); absl::StatusOr> CreateAstFromParsedExpr( const cel::expr::ParsedExpr& parsed_expr); absl::Status AstToParsedExpr(const Ast& ast, cel::expr::ParsedExpr* absl_nonnull out); // Creates a runtime AST from a checked protobuf AST. // May return a non-ok Status if the AST is malformed (e.g. unset required // fields). absl::StatusOr> CreateAstFromCheckedExpr( const cel::expr::CheckedExpr& checked_expr); absl::Status AstToCheckedExpr(const Ast& ast, cel::expr::CheckedExpr* absl_nonnull out); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_AST_PROTO_H_ ================================================ FILE: common/ast_proto_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/ast_proto.h" #include #include #include #include #include #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" #include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "common/ast.h" #include "common/decl.h" #include "common/expr.h" #include "common/source.h" #include "common/type.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" #include "compiler/optional.h" #include "compiler/standard_library.h" #include "extensions/comprehensions_v2.h" #include "internal/proto_matchers.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "google/protobuf/text_format.h" namespace cel { namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::PrimitiveType; using ::cel::WellKnownTypeSpec; using ::cel::internal::test::EqualsProto; using ::cel::expr::CheckedExpr; using ::cel::expr::ParsedExpr; using ::testing::HasSubstr; using TypePb = cel::expr::Type; absl::StatusOr ConvertProtoTypeToNative( const cel::expr::Type& type) { CheckedExpr checked_expr; checked_expr.mutable_expr()->mutable_ident_expr()->set_name("foo"); (*checked_expr.mutable_type_map())[1] = type; CEL_ASSIGN_OR_RETURN(auto ast, 
CreateAstFromCheckedExpr(checked_expr)); const auto& type_map = ast->type_map(); auto iter = type_map.find(1); if (iter != type_map.end()) { return iter->second; } return absl::InternalError("conversion failed but reported success"); } TEST(AstConvertersTest, PrimitiveTypeUnspecifiedToNative) { cel::expr::Type type; type.set_primitive(cel::expr::Type::PRIMITIVE_TYPE_UNSPECIFIED); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_primitive()); EXPECT_EQ(native_type->primitive(), PrimitiveType::kPrimitiveTypeUnspecified); } TEST(AstConvertersTest, PrimitiveTypeBoolToNative) { cel::expr::Type type; type.set_primitive(cel::expr::Type::BOOL); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_primitive()); EXPECT_EQ(native_type->primitive(), PrimitiveType::kBool); } TEST(AstConvertersTest, PrimitiveTypeInt64ToNative) { cel::expr::Type type; type.set_primitive(cel::expr::Type::INT64); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_primitive()); EXPECT_EQ(native_type->primitive(), PrimitiveType::kInt64); } TEST(AstConvertersTest, PrimitiveTypeUint64ToNative) { cel::expr::Type type; type.set_primitive(cel::expr::Type::UINT64); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_primitive()); EXPECT_EQ(native_type->primitive(), PrimitiveType::kUint64); } TEST(AstConvertersTest, PrimitiveTypeDoubleToNative) { cel::expr::Type type; type.set_primitive(cel::expr::Type::DOUBLE); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_primitive()); EXPECT_EQ(native_type->primitive(), PrimitiveType::kDouble); } TEST(AstConvertersTest, PrimitiveTypeStringToNative) { cel::expr::Type type; type.set_primitive(cel::expr::Type::STRING); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_primitive()); EXPECT_EQ(native_type->primitive(), PrimitiveType::kString); } TEST(AstConvertersTest, PrimitiveTypeBytesToNative) { 
cel::expr::Type type; type.set_primitive(cel::expr::Type::BYTES); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_primitive()); EXPECT_EQ(native_type->primitive(), PrimitiveType::kBytes); } TEST(AstConvertersTest, PrimitiveTypeError) { cel::expr::Type type; type.set_primitive(::cel::expr::Type_PrimitiveType(7)); auto native_type = ConvertProtoTypeToNative(type); EXPECT_EQ(native_type.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(native_type.status().message(), ::testing::HasSubstr("Illegal type specified for " "cel::expr::Type::PrimitiveType.")); } TEST(AstConvertersTest, WellKnownTypeUnspecifiedToNative) { cel::expr::Type type; type.set_well_known(cel::expr::Type::WELL_KNOWN_TYPE_UNSPECIFIED); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_well_known()); EXPECT_EQ(native_type->well_known(), WellKnownTypeSpec::kWellKnownTypeUnspecified); } TEST(AstConvertersTest, WellKnownTypeAnyToNative) { cel::expr::Type type; type.set_well_known(cel::expr::Type::ANY); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_well_known()); EXPECT_EQ(native_type->well_known(), WellKnownTypeSpec::kAny); } TEST(AstConvertersTest, WellKnownTypeTimestampToNative) { cel::expr::Type type; type.set_well_known(cel::expr::Type::TIMESTAMP); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_well_known()); EXPECT_EQ(native_type->well_known(), WellKnownTypeSpec::kTimestamp); } TEST(AstConvertersTest, WellKnownTypeDuraionToNative) { cel::expr::Type type; type.set_well_known(cel::expr::Type::DURATION); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_well_known()); EXPECT_EQ(native_type->well_known(), WellKnownTypeSpec::kDuration); } TEST(AstConvertersTest, WellKnownTypeError) { cel::expr::Type type; type.set_well_known(::cel::expr::Type_WellKnownType(4)); auto native_type = ConvertProtoTypeToNative(type); 
EXPECT_EQ(native_type.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(native_type.status().message(), ::testing::HasSubstr("Illegal type specified for " "cel::expr::Type::WellKnownType.")); } TEST(AstConvertersTest, ListTypeToNative) { cel::expr::Type type; type.mutable_list_type()->mutable_elem_type()->set_primitive( cel::expr::Type::BOOL); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_list_type()); auto& native_list_type = native_type->list_type(); ASSERT_TRUE(native_list_type.elem_type().has_primitive()); EXPECT_EQ(native_list_type.elem_type().primitive(), PrimitiveType::kBool); } TEST(AstConvertersTest, MapTypeToNative) { cel::expr::Type type; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( map_type { key_type { primitive: BOOL } value_type { primitive: DOUBLE } } )pb", &type)); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_map_type()); auto& native_map_type = native_type->map_type(); ASSERT_TRUE(native_map_type.key_type().has_primitive()); EXPECT_EQ(native_map_type.key_type().primitive(), PrimitiveType::kBool); ASSERT_TRUE(native_map_type.value_type().has_primitive()); EXPECT_EQ(native_map_type.value_type().primitive(), PrimitiveType::kDouble); } TEST(AstConvertersTest, FunctionTypeToNative) { cel::expr::Type type; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( function { result_type { primitive: BOOL } arg_types { primitive: DOUBLE } arg_types { primitive: STRING } } )pb", &type)); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_function()); auto& native_function_type = native_type->function(); ASSERT_TRUE(native_function_type.result_type().has_primitive()); EXPECT_EQ(native_function_type.result_type().primitive(), PrimitiveType::kBool); ASSERT_TRUE(native_function_type.arg_types().at(0).has_primitive()); EXPECT_EQ(native_function_type.arg_types().at(0).primitive(), PrimitiveType::kDouble); 
ASSERT_TRUE(native_function_type.arg_types().at(1).has_primitive()); EXPECT_EQ(native_function_type.arg_types().at(1).primitive(), PrimitiveType::kString); } TEST(AstConvertersTest, AbstractTypeToNative) { cel::expr::Type type; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( abstract_type { name: "name" parameter_types { primitive: DOUBLE } parameter_types { primitive: STRING } } )pb", &type)); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_abstract_type()); auto& native_abstract_type = native_type->abstract_type(); EXPECT_EQ(native_abstract_type.name(), "name"); ASSERT_TRUE(native_abstract_type.parameter_types().at(0).has_primitive()); EXPECT_EQ(native_abstract_type.parameter_types().at(0).primitive(), PrimitiveType::kDouble); ASSERT_TRUE(native_abstract_type.parameter_types().at(1).has_primitive()); EXPECT_EQ(native_abstract_type.parameter_types().at(1).primitive(), PrimitiveType::kString); } TEST(AstConvertersTest, DynamicTypeToNative) { cel::expr::Type type; type.mutable_dyn(); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_dyn()); } TEST(AstConvertersTest, NullTypeToNative) { cel::expr::Type type; type.set_null(google::protobuf::NULL_VALUE); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_null()); EXPECT_EQ(native_type->null(), NullTypeSpec()); } TEST(AstConvertersTest, PrimitiveTypeWrapperToNative) { cel::expr::Type type; type.set_wrapper(cel::expr::Type::BOOL); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_wrapper()); EXPECT_EQ(native_type->wrapper(), PrimitiveType::kBool); } TEST(AstConvertersTest, MessageTypeToNative) { cel::expr::Type type; type.set_message_type("message"); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_message_type()); EXPECT_EQ(native_type->message_type().type(), "message"); } TEST(AstConvertersTest, ParamTypeToNative) { cel::expr::Type type; 
type.set_type_param("param"); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_type_param()); EXPECT_EQ(native_type->type_param().type(), "param"); } TEST(AstConvertersTest, NestedTypeToNative) { cel::expr::Type type; type.mutable_type()->mutable_dyn(); auto native_type = ConvertProtoTypeToNative(type); ASSERT_TRUE(native_type->has_type()); EXPECT_TRUE(native_type->type().has_dyn()); } TEST(AstConvertersTest, TypeTypeDefault) { auto native_type = ConvertProtoTypeToNative(cel::expr::Type()); ASSERT_THAT(native_type, IsOk()); EXPECT_TRUE(absl::holds_alternative(native_type->type_kind())); } TEST(AstConvertersTest, ReferenceToNative) { cel::expr::CheckedExpr reference_wrapper; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( reference_map { key: 1 value { name: "name" overload_id: "id1" overload_id: "id2" value { bool_value: true } } })pb", &reference_wrapper)); ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromCheckedExpr(reference_wrapper)); const auto& native_references = ast->reference_map(); auto native_reference = native_references.at(1); EXPECT_EQ(native_reference.name(), "name"); EXPECT_EQ(native_reference.overload_id(), std::vector({"id1", "id2"})); EXPECT_TRUE(native_reference.value().bool_value()); } TEST(AstConvertersTest, SourceInfoToNative) { cel::expr::ParsedExpr source_info_wrapper; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( source_info { syntax_version: "version" location: "location" line_offsets: 1 line_offsets: 2 positions { key: 1 value: 2 } positions { key: 3 value: 4 } macro_calls { key: 1 value { ident_expr { name: "name" } } } })pb", &source_info_wrapper)); ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromParsedExpr(source_info_wrapper)); const auto& native_source_info = ast->source_info(); EXPECT_EQ(native_source_info.syntax_version(), "version"); EXPECT_EQ(native_source_info.location(), "location"); EXPECT_EQ(native_source_info.line_offsets(), std::vector({1, 2})); 
EXPECT_EQ(native_source_info.positions().at(1), 2); EXPECT_EQ(native_source_info.positions().at(3), 4); ASSERT_TRUE(native_source_info.macro_calls().at(1).has_ident_expr()); ASSERT_EQ(native_source_info.macro_calls().at(1).ident_expr().name(), "name"); } TEST(AstConvertersTest, CheckedExprToAst) { CheckedExpr checked_expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( reference_map { key: 1 value { name: "name" overload_id: "id1" overload_id: "id2" value { bool_value: true } } } type_map { key: 1 value { dyn {} } } source_info { syntax_version: "version" location: "location" line_offsets: 1 line_offsets: 2 positions { key: 1 value: 2 } positions { key: 3 value: 4 } macro_calls { key: 1 value { ident_expr { name: "name" } } } } expr_version: "version" expr { ident_expr { name: "expr" } } )pb", &checked_expr)); ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromCheckedExpr(checked_expr)); ASSERT_TRUE(ast->IsChecked()); } TEST(AstConvertersTest, AstToCheckedExprBasic) { Ast ast; ast.mutable_root_expr().set_id(1); ast.mutable_root_expr().mutable_ident_expr().set_name("expr"); ast.mutable_source_info().set_syntax_version("version"); ast.mutable_source_info().set_location("location"); ast.mutable_source_info().mutable_line_offsets().push_back(1); ast.mutable_source_info().mutable_line_offsets().push_back(2); ast.mutable_source_info().mutable_positions().insert({1, 2}); ast.mutable_source_info().mutable_positions().insert({3, 4}); Expr macro; macro.mutable_ident_expr().set_name("name"); ast.mutable_source_info().mutable_macro_calls().insert({1, std::move(macro)}); Reference reference; reference.set_name("name"); reference.mutable_overload_id().push_back("id1"); reference.mutable_overload_id().push_back("id2"); reference.mutable_value().set_bool_value(true); TypeSpec type; type.set_type_kind(DynTypeSpec()); ast.mutable_reference_map().insert({1, std::move(reference)}); ast.mutable_type_map().insert({1, std::move(type)}); ast.set_expr_version("version"); 
ast.set_is_checked(true); CheckedExpr checked_expr; ASSERT_THAT(AstToCheckedExpr(ast, &checked_expr), IsOk()); EXPECT_THAT(checked_expr, EqualsProto(R"pb( reference_map { key: 1 value { name: "name" overload_id: "id1" overload_id: "id2" value { bool_value: true } } } type_map { key: 1 value { dyn {} } } source_info { syntax_version: "version" location: "location" line_offsets: 1 line_offsets: 2 positions { key: 1 value: 2 } positions { key: 3 value: 4 } macro_calls { key: 1 value { ident_expr { name: "name" } } } } expr_version: "version" expr { id: 1 ident_expr { name: "expr" } } )pb")); } constexpr absl::string_view kTypesTestCheckedExpr = R"pb(reference_map: { key: 1 value: { name: "x" } } type_map: { key: 1 value: { primitive: INT64 } } source_info: { location: "" line_offsets: 2 positions: { key: 1 value: 0 } } expr: { id: 1 ident_expr: { name: "x" } })pb"; struct CheckedExprToAstTypesTestCase { absl::string_view type; }; class CheckedExprToAstTypesTest : public testing::TestWithParam { public: void SetUp() override { ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kTypesTestCheckedExpr, &checked_expr_)); } protected: CheckedExpr checked_expr_; }; TEST_P(CheckedExprToAstTypesTest, CheckedExprToAstTypes) { TypePb test_type; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(GetParam().type, &test_type)); (*checked_expr_.mutable_type_map())[1] = test_type; ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromCheckedExpr(checked_expr_)); CheckedExpr checked_expr; ASSERT_THAT(AstToCheckedExpr(*ast, &checked_expr), IsOk()); EXPECT_THAT(checked_expr, EqualsProto(checked_expr_)); } INSTANTIATE_TEST_SUITE_P( Types, CheckedExprToAstTypesTest, testing::ValuesIn({ {R"pb(list_type { elem_type { primitive: INT64 } })pb"}, {R"pb(map_type { key_type { primitive: STRING } value_type { primitive: INT64 } })pb"}, {R"pb(message_type: "com.example.TestType")pb"}, {R"pb(primitive: BOOL)pb"}, {R"pb(primitive: INT64)pb"}, {R"pb(primitive: UINT64)pb"}, {R"pb(primitive: 
DOUBLE)pb"}, {R"pb(primitive: STRING)pb"}, {R"pb(primitive: BYTES)pb"}, {R"pb(wrapper: BOOL)pb"}, {R"pb(wrapper: INT64)pb"}, {R"pb(wrapper: UINT64)pb"}, {R"pb(wrapper: DOUBLE)pb"}, {R"pb(wrapper: STRING)pb"}, {R"pb(wrapper: BYTES)pb"}, {R"pb(well_known: TIMESTAMP)pb"}, {R"pb(well_known: DURATION)pb"}, {R"pb(well_known: ANY)pb"}, {R"pb(dyn {})pb"}, {R"pb(error {})pb"}, {R"pb(null: NULL_VALUE)pb"}, {R"pb( abstract_type { name: "MyType" parameter_types { primitive: INT64 } } )pb"}, {R"pb( type { primitive: INT64 } )pb"}, {R"pb( type { type {} } )pb"}, {R"pb(type_param: "T")pb"}, {R"pb( function { result_type { primitive: INT64 } arg_types { primitive: INT64 } } )pb"}, })); TEST(AstConvertersTest, ParsedExprToAst) { ParsedExpr parsed_expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( source_info { syntax_version: "version" location: "location" line_offsets: 1 line_offsets: 2 positions { key: 1 value: 2 } positions { key: 3 value: 4 } macro_calls { key: 1 value { ident_expr { name: "name" } } } } expr { ident_expr { name: "expr" } } )pb", &parsed_expr)); ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromParsedExpr(parsed_expr)); } TEST(AstConvertersTest, AstToParsedExprBasic) { Expr expr; expr.set_id(1); expr.mutable_ident_expr().set_name("expr"); SourceInfo source_info; source_info.set_syntax_version("version"); source_info.set_location("location"); source_info.mutable_line_offsets().push_back(1); source_info.mutable_line_offsets().push_back(2); source_info.mutable_positions().insert({1, 2}); source_info.mutable_positions().insert({3, 4}); Expr macro; macro.mutable_ident_expr().set_name("name"); source_info.mutable_macro_calls().insert({1, std::move(macro)}); Ast ast(std::move(expr), std::move(source_info)); ParsedExpr parsed_expr; ASSERT_THAT(AstToParsedExpr(ast, &parsed_expr), IsOk()); EXPECT_THAT(parsed_expr, EqualsProto(R"pb( source_info { syntax_version: "version" location: "location" line_offsets: 1 line_offsets: 2 positions { key: 1 value: 2 } 
positions { key: 3 value: 4 } macro_calls { key: 1 value { ident_expr { name: "name" } } } } expr { id: 1 ident_expr { name: "expr" } } )pb")); } TEST(AstConvertersTest, ExprToAst) { cel::expr::Expr expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( ident_expr { name: "expr" } )pb", &expr)); ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromParsedExpr(expr)); } TEST(AstConvertersTest, ExprAndSourceInfoToAst) { cel::expr::Expr expr; cel::expr::SourceInfo source_info; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( syntax_version: "version" location: "location" line_offsets: 1 line_offsets: 2 positions { key: 1 value: 2 } positions { key: 3 value: 4 } macro_calls { key: 1 value { ident_expr { name: "name" } } } )pb", &source_info)); ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( ident_expr { name: "expr" } )pb", &expr)); ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromParsedExpr(expr, &source_info)); } TEST(AstConvertersTest, EmptyNodeRoundTrip) { ParsedExpr parsed_expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr { id: 1 select_expr { operand { id: 2 # no kind set. } field: "field" } } source_info {} )pb", &parsed_expr)); ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromParsedExpr(parsed_expr)); ParsedExpr copy; ASSERT_THAT(AstToParsedExpr(*ast, ©), IsOk()); EXPECT_THAT(copy, EqualsProto(parsed_expr)); } TEST(AstConvertersTest, DurationConstantRoundTrip) { ParsedExpr parsed_expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr { id: 1 const_expr { # deprecated, but support existing ASTs. 
duration_value { seconds: 10 } } } source_info {} )pb", &parsed_expr)); ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromParsedExpr(parsed_expr)); ParsedExpr copy; ASSERT_THAT(AstToParsedExpr(*ast, ©), IsOk()); EXPECT_THAT(copy, EqualsProto(parsed_expr)); } TEST(AstConvertersTest, TimestampConstantRoundTrip) { ParsedExpr parsed_expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr { id: 1 const_expr { # deprecated, but support existing ASTs. timestamp_value { seconds: 10 } } } source_info {} )pb", &parsed_expr)); ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromParsedExpr(parsed_expr)); ParsedExpr copy; ASSERT_THAT(AstToParsedExpr(*ast, ©), IsOk()); EXPECT_THAT(copy, EqualsProto(parsed_expr)); } struct ConversionRoundTripCase { absl::string_view expr; }; class ConversionRoundTripTest : public testing::TestWithParam { public: ConversionRoundTripTest() { auto builder = cel::NewCompilerBuilder(internal::GetTestingDescriptorPool()).value(); builder->AddLibrary(cel::StandardCompilerLibrary()).IgnoreError(); builder->AddLibrary(OptionalCompilerLibrary()).IgnoreError(); builder->AddLibrary(extensions::ComprehensionsV2CompilerLibrary()) .IgnoreError(); builder->GetCheckerBuilder().set_container("cel.expr.conformance.proto3"); builder->GetCheckerBuilder() .AddVariable(MakeVariableDecl("ident", IntType())) .IgnoreError(); builder->GetCheckerBuilder() .AddVariable(MakeVariableDecl("map_ident", JsonMapType())) .IgnoreError(); compiler_ = builder->Build().value(); } absl::StatusOr ParseToProto(absl::string_view expr) { CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(expr)); CEL_ASSIGN_OR_RETURN(auto result, compiler_->GetParser().Parse(*source)); ParsedExpr parsed_expr; CEL_RETURN_IF_ERROR(AstToParsedExpr(*result, &parsed_expr)); return parsed_expr; } absl::StatusOr CompileToProto(absl::string_view expr) { CEL_ASSIGN_OR_RETURN(auto result, compiler_->Compile(expr)); if (!result.IsValid()) { return absl::InvalidArgumentError(absl::StrCat( "Compilation failed: '", 
expr, "': ", result.FormatError())); } CEL_ASSIGN_OR_RETURN(auto ast, result.ReleaseAst()); CheckedExpr checked_expr; CEL_RETURN_IF_ERROR(AstToCheckedExpr(*ast, &checked_expr)); return checked_expr; } protected: std::unique_ptr compiler_; }; TEST_P(ConversionRoundTripTest, ParsedExprCopyable) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseToProto(GetParam().expr)); ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, CreateAstFromParsedExpr(parsed_expr)); CheckedExpr expr_pb; EXPECT_THAT(AstToCheckedExpr(*ast, &expr_pb), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("AST is not type-checked"))); ParsedExpr proto_out; ASSERT_THAT(AstToParsedExpr(*ast, &proto_out), IsOk()); EXPECT_THAT(proto_out, EqualsProto(parsed_expr)); } TEST_P(ConversionRoundTripTest, ExprCopyable) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseToProto(GetParam().expr)); ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, CreateAstFromParsedExpr(parsed_expr)); Expr copy = ast->root_expr(); ast->mutable_root_expr() = std::move(copy); ParsedExpr parsed_pb_out; CheckedExpr checked_pb_out; EXPECT_THAT(AstToCheckedExpr(*ast, &checked_pb_out), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("AST is not type-checked"))); ASSERT_THAT(AstToParsedExpr(*ast, &parsed_pb_out), IsOk()); EXPECT_THAT(parsed_pb_out, EqualsProto(parsed_expr)); } TEST_P(ConversionRoundTripTest, CheckedExprRoundTrip) { ASSERT_OK_AND_ASSIGN(CheckedExpr checked_expr, CompileToProto(GetParam().expr)); ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, CreateAstFromCheckedExpr(checked_expr)); CheckedExpr checked_pb_out; ASSERT_THAT(AstToCheckedExpr(*ast, &checked_pb_out), IsOk()); EXPECT_THAT(checked_pb_out, EqualsProto(checked_expr)); } TEST_P(ConversionRoundTripTest, CheckedExprCopyRoundTrip) { ASSERT_OK_AND_ASSIGN(CheckedExpr checked_expr, CompileToProto(GetParam().expr)); ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, CreateAstFromCheckedExpr(checked_expr)); Ast copy = *ast; CheckedExpr checked_pb_out; ASSERT_THAT(AstToCheckedExpr(copy, 
&checked_pb_out), IsOk()); EXPECT_THAT(checked_pb_out, EqualsProto(checked_expr)); } INSTANTIATE_TEST_SUITE_P( ExpressionCases, ConversionRoundTripTest, testing::ValuesIn( {{R"cel(null == null)cel"}, {R"cel(1 == 2)cel"}, {R"cel(1u == 2u)cel"}, {R"cel(1.1 == 2.1)cel"}, {R"cel(b"1" == b"2")cel"}, {R"cel("42" == "42")cel"}, {R"cel("s".startsWith("s") == true)cel"}, {R"cel([1, 2, 3] == [1, 2, 3])cel"}, {R"cel([1, 2, 3].all(i, e, i == e - 1) == true)cel"}, {R"cel(TestAllTypes{single_int64: 42}.single_int64 == 42)cel"}, {R"cel([1, 2, 3].map(x, x + 2).size() == 3)cel"}, {R"cel({"a": 1, "b": 2}["a"] == 1)cel"}, {R"cel(ident == 42)cel"}, {R"cel(map_ident.field == 42)cel"}, {R"cel({?"abc": {}[?1]}.?abc.orValue(42) == 42)cel"}, {R"cel([1, 2, ?optional.none()].size() == 2)cel"}})); TEST(ExtensionConversionRoundTripTest, RoundTrip) { ParsedExpr parsed_expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr { id: 1 ident_expr { name: "unused" } } source_info { extensions { id: "extension" version { major: 1 minor: 2 } affected_components: COMPONENT_UNSPECIFIED affected_components: COMPONENT_PARSER affected_components: COMPONENT_TYPE_CHECKER affected_components: COMPONENT_RUNTIME } } )pb", &parsed_expr)); ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, CreateAstFromParsedExpr(parsed_expr)); CheckedExpr expr_pb; EXPECT_THAT(AstToCheckedExpr(*ast, &expr_pb), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("AST is not type-checked"))); ParsedExpr copy; ASSERT_THAT(AstToParsedExpr(*ast, ©), IsOk()); EXPECT_THAT(copy, EqualsProto(parsed_expr)); } } // namespace } // namespace cel ================================================ FILE: common/ast_rewrite.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/ast_rewrite.h" #include #include #include "absl/log/absl_log.h" #include "absl/types/span.h" #include "absl/types/variant.h" #include "common/ast_visitor.h" #include "common/constant.h" #include "common/expr.h" namespace cel { namespace { struct ArgRecord { // Not null. Expr* expr; // For records that are direct arguments to call, we need to call // the CallArg visitor immediately after the argument is evaluated. const Expr* calling_expr; int call_arg; }; struct ComprehensionRecord { // Not null. Expr* expr; const ComprehensionExpr* comprehension; const Expr* comprehension_expr; ComprehensionArg comprehension_arg; bool use_comprehension_callbacks; }; struct ExprRecord { // Not null. 
Expr* expr; }; using StackRecordKind = absl::variant; struct StackRecord { public: static constexpr int kTarget = -2; explicit StackRecord(Expr* e) { ExprRecord record; record.expr = e; record_variant = record; } StackRecord(Expr* e, ComprehensionExpr* comprehension, Expr* comprehension_expr, ComprehensionArg comprehension_arg, bool use_comprehension_callbacks) { if (use_comprehension_callbacks) { ComprehensionRecord record; record.expr = e; record.comprehension = comprehension; record.comprehension_expr = comprehension_expr; record.comprehension_arg = comprehension_arg; record.use_comprehension_callbacks = use_comprehension_callbacks; record_variant = record; return; } ArgRecord record; record.expr = e; record.calling_expr = comprehension_expr; record.call_arg = comprehension_arg; record_variant = record; } StackRecord(Expr* e, const Expr* call, int argnum) { ArgRecord record; record.expr = e; record.calling_expr = call; record.call_arg = argnum; record_variant = record; } Expr* expr() const { return absl::get(record_variant).expr; } bool IsExprRecord() const { return absl::holds_alternative(record_variant); } StackRecordKind record_variant; bool visited = false; }; struct PreVisitor { void operator()(const ExprRecord& record) { struct { AstVisitor* visitor; const Expr* expr; void operator()(const Constant&) { // No pre-visit action. } void operator()(const IdentExpr&) { // No pre-visit action. } void operator()(const SelectExpr& select) { visitor->PreVisitSelect(*expr, select); } void operator()(const CallExpr& call) { visitor->PreVisitCall(*expr, call); } void operator()(const ListExpr&) { // No pre-visit action. } void operator()(const StructExpr&) { // No pre-visit action. } void operator()(const MapExpr&) { // No pre-visit action. } void operator()(const ComprehensionExpr& comprehension) { visitor->PreVisitComprehension(*expr, comprehension); } void operator()(const UnspecifiedExpr&) { // No pre-visit action. 
} } handler{visitor, record.expr}; visitor->PreVisitExpr(*record.expr); absl::visit(handler, record.expr->kind()); } // Do nothing for Arg variant. void operator()(const ArgRecord&) {} void operator()(const ComprehensionRecord& record) { visitor->PreVisitComprehensionSubexpression(*record.comprehension_expr, *record.comprehension, record.comprehension_arg); } AstVisitor* visitor; }; void PreVisit(const StackRecord& record, AstVisitor* visitor) { absl::visit(PreVisitor{visitor}, record.record_variant); } struct PostVisitor { void operator()(const ExprRecord& record) { struct { AstVisitor* visitor; const Expr* expr; void operator()(const Constant& constant) { visitor->PostVisitConst(*expr, constant); } void operator()(const IdentExpr& ident) { visitor->PostVisitIdent(*expr, ident); } void operator()(const SelectExpr& select) { visitor->PostVisitSelect(*expr, select); } void operator()(const CallExpr& call) { visitor->PostVisitCall(*expr, call); } void operator()(const ListExpr& create_list) { visitor->PostVisitList(*expr, create_list); } void operator()(const StructExpr& create_struct) { visitor->PostVisitStruct(*expr, create_struct); } void operator()(const MapExpr& map_expr) { visitor->PostVisitMap(*expr, map_expr); } void operator()(const ComprehensionExpr& comprehension) { visitor->PostVisitComprehension(*expr, comprehension); } void operator()(const UnspecifiedExpr&) { ABSL_LOG(ERROR) << "Unsupported Expr kind"; } } handler{visitor, record.expr}; absl::visit(handler, record.expr->kind()); visitor->PostVisitExpr(*record.expr); } void operator()(const ArgRecord& record) { if (record.call_arg == StackRecord::kTarget) { visitor->PostVisitTarget(*record.calling_expr); } else { visitor->PostVisitArg(*record.calling_expr, record.call_arg); } } void operator()(const ComprehensionRecord& record) { visitor->PostVisitComprehensionSubexpression(*record.comprehension_expr, *record.comprehension, record.comprehension_arg); } AstVisitor* visitor; }; void PostVisit(const 
StackRecord& record, AstVisitor* visitor) { absl::visit(PostVisitor{visitor}, record.record_variant); } void PushSelectDeps(SelectExpr* select_expr, std::stack* stack) { if (select_expr->has_operand()) { stack->push(StackRecord(&select_expr->mutable_operand())); } } void PushCallDeps(CallExpr* call_expr, Expr* expr, std::stack* stack) { const int arg_size = call_expr->args().size(); // Our contract is that we visit arguments in order. To do that, we need // to push them onto the stack in reverse order. for (int i = arg_size - 1; i >= 0; --i) { stack->push(StackRecord(&call_expr->mutable_args()[i], expr, i)); } // Are we receiver-style? if (call_expr->has_target()) { stack->push( StackRecord(&call_expr->mutable_target(), expr, StackRecord::kTarget)); } } void PushListDeps(ListExpr* list_expr, std::stack* stack) { auto& elements = list_expr->mutable_elements(); for (auto it = elements.rbegin(); it != elements.rend(); ++it) { auto& element = *it; stack->push(StackRecord(&element.mutable_expr())); } } void PushStructDeps(StructExpr* struct_expr, std::stack* stack) { auto& entries = struct_expr->mutable_fields(); for (auto it = entries.rbegin(); it != entries.rend(); ++it) { auto& entry = *it; // The contract is to visit key, then value. So put them on the stack // in the opposite order. if (entry.has_value()) { stack->push(StackRecord(&entry.mutable_value())); } } } void PushMapDeps(MapExpr* struct_expr, std::stack* stack) { auto& entries = struct_expr->mutable_entries(); for (auto it = entries.rbegin(); it != entries.rend(); ++it) { auto& entry = *it; // The contract is to visit key, then value. So put them on the stack // in the opposite order. if (entry.has_value()) { stack->push(StackRecord(&entry.mutable_value())); } // The contract is to visit key, then value. So put them on the stack // in the opposite order. 
if (entry.has_key()) { stack->push(StackRecord(&entry.mutable_key())); } } } void PushComprehensionDeps(ComprehensionExpr* c, Expr* expr, std::stack* stack, bool use_comprehension_callbacks) { StackRecord iter_range(&c->mutable_iter_range(), c, expr, ITER_RANGE, use_comprehension_callbacks); StackRecord accu_init(&c->mutable_accu_init(), c, expr, ACCU_INIT, use_comprehension_callbacks); StackRecord loop_condition(&c->mutable_loop_condition(), c, expr, LOOP_CONDITION, use_comprehension_callbacks); StackRecord loop_step(&c->mutable_loop_step(), c, expr, LOOP_STEP, use_comprehension_callbacks); StackRecord result(&c->mutable_result(), c, expr, RESULT, use_comprehension_callbacks); // Push them in reverse order. stack->push(result); stack->push(loop_step); stack->push(loop_condition); stack->push(accu_init); stack->push(iter_range); } struct PushDepsVisitor { void operator()(const ExprRecord& record) { struct { std::stack& stack; const RewriteTraversalOptions& options; const ExprRecord& record; void operator()(const Constant&) {} void operator()(const IdentExpr&) {} void operator()(const SelectExpr&) { PushSelectDeps(&record.expr->mutable_select_expr(), &stack); } void operator()(const CallExpr&) { PushCallDeps(&record.expr->mutable_call_expr(), record.expr, &stack); } void operator()(const ListExpr&) { PushListDeps(&record.expr->mutable_list_expr(), &stack); } void operator()(const StructExpr&) { PushStructDeps(&record.expr->mutable_struct_expr(), &stack); } void operator()(const MapExpr&) { PushMapDeps(&record.expr->mutable_map_expr(), &stack); } void operator()(const ComprehensionExpr&) { PushComprehensionDeps(&record.expr->mutable_comprehension_expr(), record.expr, &stack, options.use_comprehension_callbacks); } void operator()(const UnspecifiedExpr&) {} } handler{stack, options, record}; absl::visit(handler, record.expr->kind()); } void operator()(const ArgRecord& record) { stack.push(StackRecord(record.expr)); } void operator()(const ComprehensionRecord& record) 
{ stack.push(StackRecord(record.expr)); } std::stack& stack; const RewriteTraversalOptions& options; }; void PushDependencies(const StackRecord& record, std::stack& stack, const RewriteTraversalOptions& options) { absl::visit(PushDepsVisitor{stack, options}, record.record_variant); } } // namespace bool AstRewrite(Expr& expr, AstRewriter& visitor, RewriteTraversalOptions options) { std::stack stack; std::vector traversal_path; stack.push(StackRecord(&expr)); bool rewritten = false; while (!stack.empty()) { StackRecord& record = stack.top(); if (!record.visited) { if (record.IsExprRecord()) { traversal_path.push_back(record.expr()); visitor.TraversalStackUpdate(absl::MakeSpan(traversal_path)); if (visitor.PreVisitRewrite(*record.expr())) { rewritten = true; } } PreVisit(record, &visitor); PushDependencies(record, stack, options); record.visited = true; } else { PostVisit(record, &visitor); if (record.IsExprRecord()) { if (visitor.PostVisitRewrite(*record.expr())) { rewritten = true; } traversal_path.pop_back(); visitor.TraversalStackUpdate(absl::MakeSpan(traversal_path)); } stack.pop(); } } return rewritten; } } // namespace cel ================================================ FILE: common/ast_rewrite.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_REWRITE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_AST_REWRITE_H_ #include "absl/base/nullability.h" #include "absl/types/span.h" #include "common/ast_visitor.h" #include "common/constant.h" #include "common/expr.h" namespace cel { // Traversal options for AstRewrite. struct RewriteTraversalOptions { // If enabled, use comprehension specific callbacks instead of the general // arguments callbacks. bool use_comprehension_callbacks; RewriteTraversalOptions() : use_comprehension_callbacks(false) {} }; // Interface for AST rewriters. // Extends AstVisitor interface with update methods. // see AstRewrite for more details on usage. class AstRewriter : public AstVisitor { public: ~AstRewriter() override {} // Rewrite a sub expression before visiting. // Occurs before visiting Expr. If expr is modified, it the new value will be // visited. virtual bool PreVisitRewrite(Expr& expr) = 0; // Rewrite a sub expression after visiting. // Occurs after visiting expr and it's children. If expr is modified, the old // sub expression is visited. virtual bool PostVisitRewrite(Expr& expr) = 0; // Notify the visitor of updates to the traversal stack. virtual void TraversalStackUpdate( absl::Span path) = 0; }; // Trivial implementation for AST rewriters. // Virtual methods are overridden with no-op callbacks. 
class AstRewriterBase : public AstRewriter { public: ~AstRewriterBase() override {} void PreVisitExpr(const Expr&) override {} void PostVisitExpr(const Expr&) override {} void PostVisitConst(const Expr&, const Constant&) override {} void PostVisitIdent(const Expr&, const IdentExpr&) override {} void PreVisitSelect(const Expr&, const SelectExpr&) override {} void PostVisitSelect(const Expr&, const SelectExpr&) override {} void PreVisitCall(const Expr&, const CallExpr&) override {} void PostVisitCall(const Expr&, const CallExpr&) override {} void PreVisitComprehension(const Expr&, const ComprehensionExpr&) override {} void PostVisitComprehension(const Expr&, const ComprehensionExpr&) override {} void PostVisitArg(const Expr&, int) override {} void PostVisitTarget(const Expr&) override {} void PostVisitList(const Expr&, const ListExpr&) override {} void PostVisitStruct(const Expr&, const StructExpr&) override {} void PostVisitMap(const Expr&, const MapExpr&) override {} bool PreVisitRewrite(Expr& expr) override { return false; } bool PostVisitRewrite(Expr& expr) override { return false; } void TraversalStackUpdate( absl::Span path) override {} }; // Traverses the AST representation in an expr proto. Returns true if any // rewrites occur. // // Rewrites may happen before and/or after visiting an expr subtree. If a // change happens during the pre-visit rewrite, the updated subtree will be // visited. If a change happens during the post-visit rewrite, the old subtree // will be visited. // // expr: root node of the tree. // source_info: optional additional parse information about the expression // visitor: the callback object that receives the visitation notifications // options: options for traversal. see RewriteTraversalOptions. Defaults are // used if not sepecified. 
// // Traversal order follows the pattern: // PreVisitRewrite // PreVisitExpr // ..PreVisit{ExprKind} // ....PreVisit{ArgumentIndex} // .......PreVisitExpr (subtree) // .......PostVisitExpr (subtree) // ....PostVisit{ArgumentIndex} // ..PostVisit{ExprKind} // PostVisitExpr // PostVisitRewrite // // Example callback order for fn(1, var): // PreVisitExpr // ..PreVisitCall(fn) // ......PreVisitExpr // ........PostVisitConst(1) // ......PostVisitExpr // ....PostVisitArg(fn, 0) // ......PreVisitExpr // ........PostVisitIdent(var) // ......PostVisitExpr // ....PostVisitArg(fn, 1) // ..PostVisitCall(fn) // PostVisitExpr bool AstRewrite(Expr& expr, AstRewriter& visitor, RewriteTraversalOptions options = RewriteTraversalOptions()); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_AST_REWRITE_H_ ================================================ FILE: common/ast_rewrite_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/ast_rewrite.h"

#include <memory>
#include <string>
#include <vector>

#include "cel/expr/syntax.pb.h"
#include "absl/status/status_matchers.h"
#include "common/ast.h"
#include "common/ast/expr_proto.h"
#include "common/ast_visitor.h"
#include "common/expr.h"
#include "extensions/protobuf/ast_converters.h"
#include "internal/testing.h"
#include "parser/parser.h"
#include "google/protobuf/text_format.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::cel::ast_internal::ExprFromProto;
using ::cel::extensions::CreateAstFromParsedExpr;
using ::testing::_;
using ::testing::ElementsAre;
using ::testing::InSequence;
using ::testing::Ref;

// gMock implementation of every AstRewriter callback, so tests can assert on
// the exact notifications AstRewrite produces.
class MockAstRewriter : public AstRewriter {
 public:
  // Expr handler.
  MOCK_METHOD(void, PreVisitExpr, (const Expr& expr), (override));

  // Expr handler.
  MOCK_METHOD(void, PostVisitExpr, (const Expr& expr), (override));

  // Constant node handler.
  MOCK_METHOD(void, PostVisitConst,
              (const Expr& expr, const Constant& const_expr), (override));

  // Ident node handler.
  MOCK_METHOD(void, PostVisitIdent,
              (const Expr& expr, const IdentExpr& ident_expr), (override));

  // Select node handler group
  MOCK_METHOD(void, PreVisitSelect,
              (const Expr& expr, const SelectExpr& select_expr), (override));

  MOCK_METHOD(void, PostVisitSelect,
              (const Expr& expr, const SelectExpr& select_expr), (override));

  // Call node handler group
  MOCK_METHOD(void, PreVisitCall,
              (const Expr& expr, const CallExpr& call_expr), (override));
  MOCK_METHOD(void, PostVisitCall,
              (const Expr& expr, const CallExpr& call_expr), (override));

  // Comprehension node handler group
  MOCK_METHOD(void, PreVisitComprehension,
              (const Expr& expr, const ComprehensionExpr& comprehension_expr),
              (override));
  MOCK_METHOD(void, PostVisitComprehension,
              (const Expr& expr, const ComprehensionExpr& comprehension_expr),
              (override));

  // Comprehension subexpression handler group. These are only issued when
  // RewriteTraversalOptions::use_comprehension_callbacks is enabled.
  MOCK_METHOD(void, PreVisitComprehensionSubexpression,
              (const Expr& expr, const ComprehensionExpr& comprehension_expr,
               ComprehensionArg comprehension_arg),
              (override));
  MOCK_METHOD(void, PostVisitComprehensionSubexpression,
              (const Expr& expr, const ComprehensionExpr& comprehension_expr,
               ComprehensionArg comprehension_arg),
              (override));

  // We provide finer granularity for Call and Comprehension node callbacks
  // to allow special handling for short-circuiting.
  MOCK_METHOD(void, PostVisitTarget, (const Expr& expr), (override));
  MOCK_METHOD(void, PostVisitArg, (const Expr& expr, int arg_num), (override));

  // List node handler group
  MOCK_METHOD(void, PostVisitList,
              (const Expr& expr, const ListExpr& list_expr), (override));

  // Struct node handler group
  MOCK_METHOD(void, PostVisitStruct,
              (const Expr& expr, const StructExpr& struct_expr), (override));

  // Map node handler group
  MOCK_METHOD(void, PostVisitMap, (const Expr& expr, const MapExpr& map_expr),
              (override));

  // Rewrite hooks.
  MOCK_METHOD(bool, PreVisitRewrite, (Expr & expr), (override));
  MOCK_METHOD(bool, PostVisitRewrite, (Expr & expr), (override));

  // Traversal stack notifications.
  MOCK_METHOD(void, TraversalStackUpdate, (absl::Span<const Expr*> path),
              (override));
};

TEST(AstCrawlerTest, CheckCrawlConstant) {
  MockAstRewriter handler;
  Expr expr;
  auto& const_expr = expr.mutable_const_expr();

  EXPECT_CALL(handler, PostVisitConst(Ref(expr), Ref(const_expr))).Times(1);

  AstRewrite(expr, handler);
}

TEST(AstCrawlerTest, CheckCrawlIdent) {
  MockAstRewriter handler;
  Expr expr;
  auto& ident_expr = expr.mutable_ident_expr();

  EXPECT_CALL(handler, PostVisitIdent(Ref(expr), Ref(ident_expr))).Times(1);

  AstRewrite(expr, handler);
}

// Test handling of Select node when operand is not set.
TEST(AstCrawlerTest, CheckCrawlSelectNotCrashingPostVisitAbsentOperand) {
  MockAstRewriter handler;
  Expr expr;
  // Only the select kind is set; its operand stays absent, so no child is
  // pushed for it.
  auto& select_expr = expr.mutable_select_expr();

  // The select node itself still receives its kind-specific post-visit.
  EXPECT_CALL(handler, PostVisitSelect(Ref(expr), Ref(select_expr))).Times(1);

  AstRewrite(expr, handler);
}

// Test handling of Select node
TEST(AstCrawlerTest, CheckCrawlSelect) {
  MockAstRewriter handler;
  Expr expr;
  auto& select_expr = expr.mutable_select_expr();
  auto& operand = select_expr.mutable_operand();
  auto& ident_expr = operand.mutable_ident_expr();

  testing::InSequence seq;

  // Lowest level entry will be called first: the operand is post-visited
  // before the select that owns it.
  EXPECT_CALL(handler, PostVisitIdent(Ref(operand), Ref(ident_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitSelect(Ref(expr), Ref(select_expr))).Times(1);

  AstRewrite(expr, handler);
}

// Test handling of Call node without receiver
TEST(AstCrawlerTest, CheckCrawlCallNoReceiver) {
  MockAstRewriter handler;
  // Builds a receiver-less call shaped like fn(<const>, <ident>).
  Expr expr;
  auto& call_expr = expr.mutable_call_expr();
  // Reserve up front so the reference to arg0 stays valid when arg1 is
  // emplaced.
  call_expr.mutable_args().reserve(2);
  Expr& arg0 = call_expr.mutable_args().emplace_back();
  auto& const_expr = arg0.mutable_const_expr();
  Expr& arg1 = call_expr.mutable_args().emplace_back();
  auto& ident_expr = arg1.mutable_ident_expr();

  testing::InSequence seq;

  // The call is pre-visited before any of its children.
  EXPECT_CALL(handler, PreVisitCall(Ref(expr), Ref(call_expr))).Times(1);
  // There is no target, so PostVisitTarget must never fire.
  EXPECT_CALL(handler, PostVisitTarget(_)).Times(0);

  // Arg0
  EXPECT_CALL(handler, PostVisitConst(Ref(arg0), Ref(const_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(Ref(arg0))).Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), 0)).Times(1);

  // Arg1
  EXPECT_CALL(handler, PostVisitIdent(Ref(arg1), Ref(ident_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(Ref(arg1))).Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), 1)).Times(1);

  // Back to call
  EXPECT_CALL(handler, PostVisitCall(Ref(expr), Ref(call_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(Ref(expr))).Times(1);

  AstRewrite(expr, handler);
}

// Test handling of Call node with receiver
TEST(AstCrawlerTest, CheckCrawlCallReceiver) {
  MockAstRewriter handler;
  // Builds a receiver-style call shaped like <ident>.fn(<const>, <ident>).
  Expr expr;
  auto& call_expr = expr.mutable_call_expr();
  Expr& target = call_expr.mutable_target();
  auto& target_ident = target.mutable_ident_expr();
  // Reserve up front so the reference to arg0 stays valid when arg1 is
  // emplaced.
  call_expr.mutable_args().reserve(2);
  Expr& arg0 = call_expr.mutable_args().emplace_back();
  auto& const_expr = arg0.mutable_const_expr();
  Expr& arg1 = call_expr.mutable_args().emplace_back();
  auto& ident_expr = arg1.mutable_ident_expr();

  testing::InSequence seq;

  // The call is pre-visited before any of its children.
  EXPECT_CALL(handler, PreVisitCall(Ref(expr), Ref(call_expr))).Times(1);

  // Target: visited before the arguments and reported via PostVisitTarget.
  EXPECT_CALL(handler, PostVisitIdent(Ref(target), Ref(target_ident))).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(Ref(target))).Times(1);
  EXPECT_CALL(handler, PostVisitTarget(Ref(expr))).Times(1);

  // Arg0
  EXPECT_CALL(handler, PostVisitConst(Ref(arg0), Ref(const_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(Ref(arg0))).Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), 0)).Times(1);

  // Arg1
  EXPECT_CALL(handler, PostVisitIdent(Ref(arg1), Ref(ident_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(Ref(arg1))).Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), 1)).Times(1);

  // Back to call
  EXPECT_CALL(handler, PostVisitCall(Ref(expr), Ref(call_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(Ref(expr))).Times(1);

  AstRewrite(expr, handler);
}

// Test handling of Comprehension node with comprehension-specific callbacks
// enabled.
TEST(AstCrawlerTest, CheckCrawlComprehension) {
  MockAstRewriter handler;
  Expr expr;
  auto& c = expr.mutable_comprehension_expr();
  auto& iter_range = c.mutable_iter_range();
  auto& iter_range_expr = iter_range.mutable_const_expr();
  auto& accu_init = c.mutable_accu_init();
  auto& accu_init_expr = accu_init.mutable_ident_expr();
  auto& loop_condition = c.mutable_loop_condition();
  auto& loop_condition_expr = loop_condition.mutable_const_expr();
  auto& loop_step = c.mutable_loop_step();
  auto& loop_step_expr = loop_step.mutable_ident_expr();
  auto& result = c.mutable_result();
  auto& result_expr = result.mutable_const_expr();

  testing::InSequence seq;

  // The comprehension is pre-visited first. Each subexpression is then
  // bracketed by Pre/PostVisitComprehensionSubexpression.
  EXPECT_CALL(handler, PreVisitComprehension(Ref(expr), Ref(c))).Times(1);
  // ITER_RANGE
  EXPECT_CALL(handler,
              PreVisitComprehensionSubexpression(Ref(expr), Ref(c), ITER_RANGE))
      .Times(1);
  EXPECT_CALL(handler, PostVisitConst(Ref(iter_range), Ref(iter_range_expr)))
      .Times(1);
  EXPECT_CALL(handler,
              PostVisitComprehensionSubexpression(Ref(expr), Ref(c), ITER_RANGE))
      .Times(1);

  // ACCU_INIT
  EXPECT_CALL(handler,
              PreVisitComprehensionSubexpression(Ref(expr), Ref(c), ACCU_INIT))
      .Times(1);
  EXPECT_CALL(handler, PostVisitIdent(Ref(accu_init), Ref(accu_init_expr)))
      .Times(1);
  EXPECT_CALL(handler,
              PostVisitComprehensionSubexpression(Ref(expr), Ref(c), ACCU_INIT))
      .Times(1);

  // LOOP CONDITION
  EXPECT_CALL(handler, PreVisitComprehensionSubexpression(Ref(expr), Ref(c),
                                                          LOOP_CONDITION))
      .Times(1);
  EXPECT_CALL(handler,
              PostVisitConst(Ref(loop_condition), Ref(loop_condition_expr)))
      .Times(1);
  EXPECT_CALL(handler, PostVisitComprehensionSubexpression(Ref(expr), Ref(c),
                                                           LOOP_CONDITION))
      .Times(1);

  // LOOP STEP
  EXPECT_CALL(handler,
              PreVisitComprehensionSubexpression(Ref(expr), Ref(c), LOOP_STEP))
      .Times(1);
  EXPECT_CALL(handler, PostVisitIdent(Ref(loop_step), Ref(loop_step_expr)))
      .Times(1);
  EXPECT_CALL(handler,
              PostVisitComprehensionSubexpression(Ref(expr), Ref(c), LOOP_STEP))
      .Times(1);

  // RESULT
  EXPECT_CALL(handler,
              PreVisitComprehensionSubexpression(Ref(expr), Ref(c), RESULT))
      .Times(1);
  EXPECT_CALL(handler, PostVisitConst(Ref(result), Ref(result_expr))).Times(1);
  EXPECT_CALL(handler,
              PostVisitComprehensionSubexpression(Ref(expr), Ref(c), RESULT))
      .Times(1);

  EXPECT_CALL(handler, PostVisitComprehension(Ref(expr), Ref(c))).Times(1);

  // Comprehension subexpression callbacks are opt-in.
  RewriteTraversalOptions opts;
  opts.use_comprehension_callbacks = true;
  AstRewrite(expr, handler, opts);
}

// Test handling of Comprehension node with the default options. Here each
// subexpression is reported through PostVisitArg, using its ComprehensionArg
// value as the argument index.
TEST(AstCrawlerTest, CheckCrawlComprehensionLegacyCallbacks) {
  MockAstRewriter handler;
  Expr expr;
  auto& c = expr.mutable_comprehension_expr();
  auto& iter_range = c.mutable_iter_range();
  auto& iter_range_expr = iter_range.mutable_const_expr();
  auto& accu_init = c.mutable_accu_init();
  auto& accu_init_expr = accu_init.mutable_ident_expr();
  auto& loop_condition = c.mutable_loop_condition();
  auto& loop_condition_expr = loop_condition.mutable_const_expr();
  auto& loop_step = c.mutable_loop_step();
  auto& loop_step_expr = loop_step.mutable_ident_expr();
  auto& result = c.mutable_result();
  auto& result_expr = result.mutable_const_expr();

  testing::InSequence seq;

  // The comprehension is pre-visited first, then each subexpression in order.
  EXPECT_CALL(handler, PreVisitComprehension(Ref(expr), Ref(c))).Times(1);
  // ITER_RANGE
  EXPECT_CALL(handler, PostVisitConst(Ref(iter_range), Ref(iter_range_expr)))
      .Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), ITER_RANGE)).Times(1);

  // ACCU_INIT
  EXPECT_CALL(handler, PostVisitIdent(Ref(accu_init), Ref(accu_init_expr)))
      .Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), ACCU_INIT)).Times(1);

  // LOOP CONDITION
  EXPECT_CALL(handler,
              PostVisitConst(Ref(loop_condition), Ref(loop_condition_expr)))
      .Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), LOOP_CONDITION)).Times(1);

  // LOOP STEP
  EXPECT_CALL(handler, PostVisitIdent(Ref(loop_step), Ref(loop_step_expr)))
      .Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), LOOP_STEP)).Times(1);

  // RESULT
  EXPECT_CALL(handler, PostVisitConst(Ref(result), Ref(result_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), RESULT)).Times(1);

  EXPECT_CALL(handler, PostVisitComprehension(Ref(expr), Ref(c))).Times(1);

  AstRewrite(expr, handler);
}

// Test handling of List node.
TEST(AstCrawlerTest, CheckList) {
  MockAstRewriter handler;
  Expr expr;
  auto& list_expr = expr.mutable_list_expr();
  // Reserve up front so the reference to arg0 stays valid when the second
  // element is emplaced.
  list_expr.mutable_elements().reserve(2);
  auto& arg0 = list_expr.mutable_elements().emplace_back().mutable_expr();
  auto& const_expr = arg0.mutable_const_expr();
  auto& arg1 = list_expr.mutable_elements().emplace_back().mutable_expr();
  auto& ident_expr = arg1.mutable_ident_expr();

  testing::InSequence seq;

  // Elements are visited in order, before the list itself is post-visited.
  EXPECT_CALL(handler, PostVisitConst(Ref(arg0), Ref(const_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitIdent(Ref(arg1), Ref(ident_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitList(Ref(expr), Ref(list_expr))).Times(1);

  AstRewrite(expr, handler);
}

// Test handling of Struct node.
TEST(AstCrawlerTest, CheckStruct) {
  MockAstRewriter handler;
  Expr expr;
  auto& struct_expr = expr.mutable_struct_expr();
  auto& entry0 = struct_expr.mutable_fields().emplace_back();

  auto& value = entry0.mutable_value().mutable_ident_expr();

  testing::InSequence seq;

  // The field value is visited before the struct is post-visited.
  EXPECT_CALL(handler, PostVisitIdent(Ref(entry0.value()), Ref(value)))
      .Times(1);
  EXPECT_CALL(handler, PostVisitStruct(Ref(expr), Ref(struct_expr))).Times(1);

  AstRewrite(expr, handler);
}

// Test handling of Map node.
TEST(AstCrawlerTest, CheckMap) {
  MockAstRewriter handler;
  Expr expr;
  auto& map_expr = expr.mutable_map_expr();
  auto& entry0 = map_expr.mutable_entries().emplace_back();

  auto& key = entry0.mutable_key().mutable_const_expr();
  auto& value = entry0.mutable_value().mutable_ident_expr();

  testing::InSequence seq;

  // Within an entry the key is visited before the value; the map itself is
  // post-visited last.
  EXPECT_CALL(handler, PostVisitConst(Ref(entry0.key()), Ref(key))).Times(1);
  EXPECT_CALL(handler, PostVisitIdent(Ref(entry0.value()), Ref(value)))
      .Times(1);
  EXPECT_CALL(handler, PostVisitMap(Ref(expr), Ref(map_expr))).Times(1);

  AstRewrite(expr, handler);
}

// Test generic Expr handlers.
TEST(AstCrawlerTest, CheckExprHandlers) {
  MockAstRewriter handler;
  Expr expr;
  // Three nodes in total: the map, its key, and its value.
  auto& map_expr = expr.mutable_map_expr();
  auto& entry0 = map_expr.mutable_entries().emplace_back();

  entry0.mutable_key().mutable_const_expr();
  entry0.mutable_value().mutable_ident_expr();

  // Every node receives exactly one generic pre- and post-visit.
  EXPECT_CALL(handler, PreVisitExpr(_)).Times(3);
  EXPECT_CALL(handler, PostVisitExpr(_)).Times(3);

  AstRewrite(expr, handler);
}

// Test the rewrite hooks and traversal-stack notifications.
TEST(AstCrawlerTest, CheckExprRewriteHandlers) {
  MockAstRewriter handler;

  // Builds `top.mid.var`: select(var) -> select(mid) -> ident(top).
  Expr select_expr;
  select_expr.mutable_select_expr().set_field("var");
  auto& inner_select_expr = select_expr.mutable_select_expr().mutable_operand();
  inner_select_expr.mutable_select_expr().set_field("mid");
  auto& ident = inner_select_expr.mutable_select_expr().mutable_operand();
  ident.mutable_ident_expr().set_name("top");

  // The path grows by one node before each PreVisitRewrite and shrinks by one
  // after each PostVisitRewrite.
  {
    InSequence sequence;
    EXPECT_CALL(handler,
                TraversalStackUpdate(testing::ElementsAre(&select_expr)));
    EXPECT_CALL(handler, PreVisitRewrite(Ref(select_expr)));

    EXPECT_CALL(handler, TraversalStackUpdate(testing::ElementsAre(
                             &select_expr, &inner_select_expr)));
    EXPECT_CALL(handler, PreVisitRewrite(Ref(inner_select_expr)));

    EXPECT_CALL(handler, TraversalStackUpdate(testing::ElementsAre(
                             &select_expr, &inner_select_expr, &ident)));
    EXPECT_CALL(handler, PreVisitRewrite(Ref(ident)));

    EXPECT_CALL(handler, PostVisitRewrite(Ref(ident)));
    EXPECT_CALL(handler, TraversalStackUpdate(testing::ElementsAre(
                             &select_expr, &inner_select_expr)));

    EXPECT_CALL(handler, PostVisitRewrite(Ref(inner_select_expr)));
    EXPECT_CALL(handler,
                TraversalStackUpdate(testing::ElementsAre(&select_expr)));

    EXPECT_CALL(handler, PostVisitRewrite(Ref(select_expr)));
    EXPECT_CALL(handler, TraversalStackUpdate(testing::IsEmpty()));
  }

  // The mocked rewrite hooks return the default `false`, so no rewrite is
  // reported.
  EXPECT_FALSE(AstRewrite(select_expr, handler));
}

// Simple rewrite that replaces a select path with a dot-qualified identifier.
class RewriterExample : public AstRewriterBase { public: RewriterExample() {} bool PostVisitRewrite(Expr& expr) override { if (target_.has_value() && expr.id() == *target_) { expr.mutable_ident_expr().set_name("com.google.Identifier"); return true; } return false; } void PostVisitIdent(const Expr& expr, const IdentExpr& ident) override { if (path_.size() >= 3) { if (ident.name() == "com") { const Expr* p1 = path_.at(path_.size() - 2); const Expr* p2 = path_.at(path_.size() - 3); if (p1->has_select_expr() && p1->select_expr().field() == "google" && p2->has_select_expr() && p2->select_expr().field() == "Identifier") { target_ = p2->id(); } } } } void TraversalStackUpdate(absl::Span path) override { path_ = path; } private: absl::Span path_; absl::optional target_; }; TEST(AstRewrite, SelectRewriteExample) { ASSERT_OK_AND_ASSIGN( std::unique_ptr ast, CreateAstFromParsedExpr( google::api::expr::parser::Parse("com.google.Identifier").value())); RewriterExample example; ASSERT_TRUE(AstRewrite(ast->mutable_root_expr(), example)); cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString( R"pb( id: 3 ident_expr { name: "com.google.Identifier" } )pb", &expected_expr); cel::Expr expected_native; ASSERT_THAT(ExprFromProto(expected_expr, expected_native), IsOk()); EXPECT_EQ(ast->root_expr(), expected_native); } // Rewrites x -> y -> z to demonstrate traversal when a node is rewritten on // both passes. 
class PreRewriterExample : public AstRewriterBase { public: PreRewriterExample() {} bool PreVisitRewrite(Expr& expr) override { if (expr.ident_expr().name() == "x") { expr.mutable_ident_expr().set_name("y"); return true; } return false; } bool PostVisitRewrite(Expr& expr) override { if (expr.ident_expr().name() == "y") { expr.mutable_ident_expr().set_name("z"); return true; } return false; } void PostVisitIdent(const Expr& expr, const IdentExpr& ident) override { visited_idents_.push_back(ident.name()); } const std::vector& visited_idents() const { return visited_idents_; } private: std::vector visited_idents_; }; TEST(AstRewrite, PreAndPostVisitExpample) { ASSERT_OK_AND_ASSIGN( std::unique_ptr ast, CreateAstFromParsedExpr(google::api::expr::parser::Parse("x").value())); PreRewriterExample visitor; ASSERT_TRUE(AstRewrite(ast->mutable_root_expr(), visitor)); cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString( R"pb( id: 1 ident_expr { name: "z" } )pb", &expected_expr); cel::Expr expected_native; ASSERT_THAT(ExprFromProto(expected_expr, expected_native), IsOk()); EXPECT_EQ(ast->root_expr(), expected_native); EXPECT_THAT(visitor.visited_idents(), ElementsAre("y")); } } // namespace } // namespace cel ================================================ FILE: common/ast_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/ast.h"

#include <utility>

#include "absl/container/flat_hash_map.h"
#include "common/expr.h"
#include "common/source.h"
#include "internal/testing.h"

namespace cel {
namespace {

using ::testing::Pointee;
using ::testing::Truly;

TEST(AstImpl, RawExprCtor) {
  // arrange
  // make ast for 2 + 1 == 3
  Expr expr;
  auto& call = expr.mutable_call_expr();
  expr.set_id(5);
  call.set_function("_==_");
  auto& eq_lhs = call.mutable_args().emplace_back();
  eq_lhs.mutable_call_expr().set_function("_+_");
  eq_lhs.set_id(3);
  auto& sum_lhs = eq_lhs.mutable_call_expr().mutable_args().emplace_back();
  sum_lhs.mutable_const_expr().set_int_value(2);
  sum_lhs.set_id(1);
  auto& sum_rhs = eq_lhs.mutable_call_expr().mutable_args().emplace_back();
  sum_rhs.mutable_const_expr().set_int_value(1);
  sum_rhs.set_id(2);
  auto& eq_rhs = call.mutable_args().emplace_back();
  eq_rhs.mutable_const_expr().set_int_value(3);
  eq_rhs.set_id(4);

  SourceInfo source_info;
  source_info.mutable_positions()[5] = 6;

  // act
  Ast ast(std::move(expr), std::move(source_info));

  // assert
  ASSERT_FALSE(ast.is_checked());
  EXPECT_EQ(ast.GetTypeOrDyn(1), TypeSpec(DynTypeSpec()));
  EXPECT_EQ(ast.GetReturnType(), TypeSpec(DynTypeSpec()));
  EXPECT_EQ(ast.GetReference(1), nullptr);
  EXPECT_TRUE(ast.root_expr().has_call_expr());
  EXPECT_EQ(ast.root_expr().call_expr().function(), "_==_");
  EXPECT_EQ(ast.root_expr().id(), 5);  // Parser IDs leaf to root.
  EXPECT_EQ(ast.source_info().positions().at(5), 6);  // start pos of ==
}

TEST(AstImpl, CheckedExprCtor) {
  Expr expr;
  expr.mutable_ident_expr().set_name("int_value");
  expr.set_id(1);
  Reference ref;
  ref.set_name("com.int_value");
  Ast::ReferenceMap reference_map;
  reference_map[1] = Reference(ref);
  Ast::TypeMap type_map;
  type_map[1] = TypeSpec(PrimitiveType::kInt64);
  SourceInfo source_info;
  source_info.set_syntax_version("1.0");

  Ast ast(std::move(expr), std::move(source_info), std::move(reference_map),
          std::move(type_map), "1.0");

  ASSERT_TRUE(ast.is_checked());
  EXPECT_EQ(ast.GetTypeOrDyn(1), TypeSpec(PrimitiveType::kInt64));
  EXPECT_THAT(ast.GetReference(1),
              Pointee(Truly([&ref](const Reference& arg) {
                return arg.name() == ref.name();
              })));
  EXPECT_EQ(ast.GetReturnType(), TypeSpec(PrimitiveType::kInt64));
  EXPECT_TRUE(ast.root_expr().has_ident_expr());
  EXPECT_EQ(ast.root_expr().ident_expr().name(), "int_value");
  EXPECT_EQ(ast.root_expr().id(), 1);
  EXPECT_EQ(ast.source_info().syntax_version(), "1.0");
  EXPECT_EQ(ast.expr_version(), "1.0");
}

TEST(AstImpl, CheckedExprDeepCopy) {
  Expr root;
  root.set_id(3);
  root.mutable_call_expr().set_function("_==_");
  root.mutable_call_expr().mutable_args().resize(2);
  auto& lhs = root.mutable_call_expr().mutable_args()[0];
  auto& rhs = root.mutable_call_expr().mutable_args()[1];
  Ast::TypeMap type_map;
  Ast::ReferenceMap reference_map;
  SourceInfo source_info;

  type_map[3] = TypeSpec(PrimitiveType::kBool);

  lhs.mutable_ident_expr().set_name("int_value");
  lhs.set_id(1);
  Reference ref;
  ref.set_name("com.int_value");
  reference_map[1] = std::move(ref);
  type_map[1] = TypeSpec(PrimitiveType::kInt64);

  rhs.mutable_const_expr().set_int_value(2);
  rhs.set_id(2);
  type_map[2] = TypeSpec(PrimitiveType::kInt64);
  source_info.set_syntax_version("1.0");

  Ast ast(std::move(root), std::move(source_info), std::move(reference_map),
          std::move(type_map), "1.0");

  // Use the same accessor as the other tests in this file.
  ASSERT_TRUE(ast.is_checked());
  EXPECT_EQ(ast.GetTypeOrDyn(1), TypeSpec(PrimitiveType::kInt64));
  EXPECT_THAT(ast.GetReference(1),
              Pointee(Truly([](const Reference& arg) {
                return arg.name() == "com.int_value";
              })));
  EXPECT_EQ(ast.GetReturnType(), TypeSpec(PrimitiveType::kBool));
  EXPECT_TRUE(ast.root_expr().has_call_expr());
  EXPECT_EQ(ast.root_expr().call_expr().function(), "_==_");
  EXPECT_EQ(ast.root_expr().id(), 3);
  EXPECT_EQ(ast.source_info().syntax_version(), "1.0");
}

TEST(AstImpl, ComputeSourceLocation) {
  SourceInfo source_info;
  source_info.set_line_offsets({10, 20, 30});
  source_info.mutable_positions()[1] = 0;   // Start of first line
  source_info.mutable_positions()[2] = 5;   // Middle of first line
  source_info.mutable_positions()[3] = 10;  // Start of second line
  source_info.mutable_positions()[4] = 15;
  source_info.mutable_positions()[5] = 20;
  source_info.mutable_positions()[6] = 25;

  Ast ast(Expr{}, std::move(source_info));

  EXPECT_EQ(ast.ComputeSourceLocation(1), (SourceLocation{1, 0}));
  EXPECT_EQ(ast.ComputeSourceLocation(2), (SourceLocation{1, 5}));
  EXPECT_EQ(ast.ComputeSourceLocation(3), (SourceLocation{2, 0}));
  EXPECT_EQ(ast.ComputeSourceLocation(4), (SourceLocation{2, 5}));
  EXPECT_EQ(ast.ComputeSourceLocation(5), (SourceLocation{3, 0}));
  EXPECT_EQ(ast.ComputeSourceLocation(6), (SourceLocation{3, 5}));
}

TEST(AstImpl, ComputeSourceLocationFailures) {
  SourceInfo source_info;
  source_info.set_line_offsets({10, 20});
  source_info.mutable_positions()[1] = -1;  // Negative position
  source_info.mutable_positions()[2] = 25;  // Beyond last line offset
  // ID 3 is missing

  Ast ast;
  ast.mutable_source_info() = std::move(source_info);
  EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{});
  EXPECT_EQ(ast.ComputeSourceLocation(2), SourceLocation{});
  EXPECT_EQ(ast.ComputeSourceLocation(3), SourceLocation{});
}

TEST(AstImpl, ComputeSourceLocationInvalidLineOffsets) {
  {
    // Empty line offsets
    Ast ast;
    EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{});
  }
  {
    // Non-monotonic
    SourceInfo source_info;
    source_info.set_line_offsets({10, 5});
    source_info.mutable_positions()[1] = 12;
    Ast ast(Expr{}, std::move(source_info));
    EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{});
  }
}

}  // namespace
}  // namespace cel

================================================
FILE: common/ast_traverse.cc
================================================
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "common/ast_traverse.h"

#include <memory>
#include <stack>

#include "absl/log/absl_log.h"
#include "absl/types/variant.h"
#include "common/ast_visitor.h"
#include "common/constant.h"
#include "common/expr.h"

namespace cel {

namespace {

// Stack entry for a direct argument (or call target) of a call, or for a
// comprehension subexpression when comprehension callbacks are disabled.
struct ArgRecord {
  // Not null.
  const Expr* expr;

  // For records that are direct arguments to call, we need to call
  // the CallArg visitor immediately after the argument is evaluated.
  const Expr* calling_expr;
  // Argument index, a ComprehensionArg value, or StackRecord::kTarget.
  int call_arg;
};

// Stack entry for a comprehension subexpression when comprehension callbacks
// are enabled.
struct ComprehensionRecord {
  // Not null.
  const Expr* expr;

  const ComprehensionExpr* comprehension;
  const Expr* comprehension_expr;
  ComprehensionArg comprehension_arg;
  bool use_comprehension_callbacks;
};

// Stack entry for visiting an expression node itself.
struct ExprRecord {
  // Not null.
  const Expr* expr;
};

using StackRecordKind =
    absl::variant<ExprRecord, ArgRecord, ComprehensionRecord>;

// One pending unit of work on the explicit traversal stack. Each record is
// handled twice: once on the way down (pre-visit, then its children are
// pushed) and once on the way up (post-visit, then it is popped).
struct StackRecord {
 public:
  // `call_arg` sentinel marking the receiver (target) of a call.
  static constexpr int kTarget = -2;

  // Record for visiting `e` itself.
  explicit StackRecord(const Expr* e) { record_variant = ExprRecord{e}; }

  // Record for comprehension subexpression `e` of `comprehension_expr`. With
  // comprehension callbacks it becomes a ComprehensionRecord; otherwise it is
  // reported through PostVisitArg with `comprehension_arg` as the arg number.
  StackRecord(const Expr* e, const ComprehensionExpr* comprehension,
              const Expr* comprehension_expr,
              ComprehensionArg comprehension_arg,
              bool use_comprehension_callbacks) {
    if (use_comprehension_callbacks) {
      record_variant =
          ComprehensionRecord{e, comprehension, comprehension_expr,
                              comprehension_arg, use_comprehension_callbacks};
    } else {
      record_variant = ArgRecord{e, comprehension_expr,
                                 static_cast<int>(comprehension_arg)};
    }
  }

  // Record for argument number `argnum` (or kTarget) of the call `call`.
  StackRecord(const Expr* e, const Expr* call, int argnum) {
    record_variant = ArgRecord{e, call, argnum};
  }

  StackRecordKind record_variant;
  bool visited = false;
};

struct PreVisitor {
  void operator()(const ExprRecord& record) {
    const Expr* expr = record.expr;
    visitor->PreVisitExpr(*expr);
    if (expr->has_select_expr()) {
      visitor->PreVisitSelect(*expr, expr->select_expr());
    } else if (expr->has_call_expr()) {
      visitor->PreVisitCall(*expr, expr->call_expr());
    } else if (expr->has_comprehension_expr()) {
      visitor->PreVisitComprehension(*expr, expr->comprehension_expr());
    } else {
      // No pre-visit action.
    }
  }

  // Do nothing for Arg variant.
  void operator()(const ArgRecord&) {}

  void operator()(const ComprehensionRecord& record) {
    visitor->PreVisitComprehensionSubexpression(*record.comprehension_expr,
                                                *record.comprehension,
                                                record.comprehension_arg);
  }

  AstVisitor* visitor;
};

void PreVisit(const StackRecord& record, AstVisitor* visitor) {
  absl::visit(PreVisitor{visitor}, record.record_variant);
}

struct PostVisitor {
  void operator()(const ExprRecord& record) {
    const Expr* expr = record.expr;
    struct {
      AstVisitor* visitor;
      const Expr* expr;
      void operator()(const Constant& constant) {
        visitor->PostVisitConst(*expr, expr->const_expr());
      }
      void operator()(const IdentExpr& ident) {
        visitor->PostVisitIdent(*expr, expr->ident_expr());
      }
      void operator()(const SelectExpr& select) {
        visitor->PostVisitSelect(*expr, expr->select_expr());
      }
      void operator()(const CallExpr& call) {
        visitor->PostVisitCall(*expr, expr->call_expr());
      }
      void operator()(const ListExpr& create_list) {
        visitor->PostVisitList(*expr, expr->list_expr());
      }
      void operator()(const StructExpr& create_struct) {
        visitor->PostVisitStruct(*expr, expr->struct_expr());
      }
      void operator()(const MapExpr& map_expr) {
        visitor->PostVisitMap(*expr, expr->map_expr());
      }
      void operator()(const ComprehensionExpr& comprehension) {
        visitor->PostVisitComprehension(*expr, expr->comprehension_expr());
      }
      void operator()(const UnspecifiedExpr&) {
        ABSL_LOG(ERROR) << "Unsupported Expr kind";
      }
    } handler{visitor, record.expr};
    absl::visit(handler, record.expr->kind());

    visitor->PostVisitExpr(*expr);
  }

  void operator()(const ArgRecord& record) {
    if (record.call_arg == StackRecord::kTarget) {
      visitor->PostVisitTarget(*record.calling_expr);
    } else {
      visitor->PostVisitArg(*record.calling_expr, record.call_arg);
    }
  }

  void operator()(const ComprehensionRecord& record) {
    visitor->PostVisitComprehensionSubexpression(*record.comprehension_expr,
                                                 *record.comprehension,
                                                 record.comprehension_arg);
  }

  AstVisitor* visitor;
};

void PostVisit(const StackRecord& record, AstVisitor* visitor) {
  absl::visit(PostVisitor{visitor}, record.record_variant);
}

void PushSelectDeps(const SelectExpr* select_expr,
                    std::stack<StackRecord>* stack) {
  if (select_expr->has_operand()) {
    stack->push(StackRecord(&select_expr->operand()));
  }
}

void PushCallDeps(const CallExpr* call_expr, const Expr* expr,
                  std::stack<StackRecord>* stack) {
  const int arg_size = static_cast<int>(call_expr->args().size());
  // Our contract is that we visit arguments in order. To do that, we need
  // to push them onto the stack in reverse order.
  for (int i = arg_size - 1; i >= 0; --i) {
    stack->push(StackRecord(&call_expr->args()[i], expr, i));
  }
  // Are we receiver-style? The target is pushed last so it is visited first.
  if (call_expr->has_target()) {
    stack->push(StackRecord(&call_expr->target(), expr, StackRecord::kTarget));
  }
}

void PushListDeps(const ListExpr* list_expr, std::stack<StackRecord>* stack) {
  const auto& elements = list_expr->elements();
  // Push in reverse so elements are visited first to last.
  for (auto it = elements.rbegin(); it != elements.rend(); ++it) {
    const auto& element = *it;
    stack->push(StackRecord(&element.expr()));
  }
}

void PushStructDeps(const StructExpr* struct_expr,
                    std::stack<StackRecord>* stack) {
  const auto& entries = struct_expr->fields();
  // Only field values are visited as subexpressions. Fields are pushed in
  // reverse so they are visited first to last.
  for (auto it = entries.rbegin(); it != entries.rend(); ++it) {
    const auto& entry = *it;
    if (entry.has_value()) {
      stack->push(StackRecord(&entry.value()));
    }
  }
}

void PushMapDeps(const MapExpr* map_expr, std::stack<StackRecord>* stack) {
  const auto& entries = map_expr->entries();
  for (auto it = entries.rbegin(); it != entries.rend(); ++it) {
    const auto& entry = *it;
    // The contract is to visit key, then value. So put them on the stack
    // in the opposite order.
    if (entry.has_value()) {
      stack->push(StackRecord(&entry.value()));
    }
    if (entry.has_key()) {
      stack->push(StackRecord(&entry.key()));
    }
  }
}

void PushComprehensionDeps(const ComprehensionExpr* c, const Expr* expr,
                           std::stack<StackRecord>* stack,
                           bool use_comprehension_callbacks) {
  StackRecord iter_range(&c->iter_range(), c, expr, ITER_RANGE,
                         use_comprehension_callbacks);
  StackRecord accu_init(&c->accu_init(), c, expr, ACCU_INIT,
                        use_comprehension_callbacks);
  StackRecord loop_condition(&c->loop_condition(), c, expr, LOOP_CONDITION,
                             use_comprehension_callbacks);
  StackRecord loop_step(&c->loop_step(), c, expr, LOOP_STEP,
                        use_comprehension_callbacks);
  StackRecord result(&c->result(), c, expr, RESULT,
                     use_comprehension_callbacks);
  // Push them in reverse order.
  stack->push(result);
  stack->push(loop_step);
  stack->push(loop_condition);
  stack->push(accu_init);
  stack->push(iter_range);
}

struct PushDepsVisitor {
  void operator()(const ExprRecord& record) {
    struct {
      std::stack<StackRecord>& stack;
      const TraversalOptions& options;
      const ExprRecord& record;
      void operator()(const Constant& constant) {}
      void operator()(const IdentExpr& ident) {}
      void operator()(const SelectExpr& select) {
        PushSelectDeps(&record.expr->select_expr(), &stack);
      }
      void operator()(const CallExpr& call) {
        PushCallDeps(&record.expr->call_expr(), record.expr, &stack);
      }
      void operator()(const ListExpr& create_list) {
        PushListDeps(&record.expr->list_expr(), &stack);
      }
      void operator()(const StructExpr& create_struct) {
        PushStructDeps(&record.expr->struct_expr(), &stack);
      }
      void operator()(const MapExpr& map_expr) {
        PushMapDeps(&record.expr->map_expr(), &stack);
      }
      void operator()(const ComprehensionExpr& comprehension) {
        PushComprehensionDeps(&record.expr->comprehension_expr(), record.expr,
                              &stack, options.use_comprehension_callbacks);
      }
      void operator()(const UnspecifiedExpr&) {}
    } handler{stack, options, record};
    absl::visit(handler, record.expr->kind());
  }

  // Argument-style records wrap a subexpression; expand it as a plain node.
  void operator()(const ArgRecord& record) {
    stack.push(StackRecord(record.expr));
  }

  void operator()(const ComprehensionRecord& record) {
    stack.push(StackRecord(record.expr));
  }

  std::stack<StackRecord>& stack;
  const TraversalOptions& options;
};

void PushDependencies(const StackRecord& record,
                      std::stack<StackRecord>& stack,
                      const TraversalOptions& options) {
  absl::visit(PushDepsVisitor{stack, options}, record.record_variant);
}

// Performs one step on a non-empty `stack`: the first time the top record is
// seen it is pre-visited and its children are pushed; the second time it is
// post-visited and popped.
void TraversalStep(std::stack<StackRecord>& stack, AstVisitor& visitor,
                   const TraversalOptions& options) {
  StackRecord& record = stack.top();
  if (!record.visited) {
    PreVisit(record, &visitor);
    // `record` stays valid across these pushes: std::stack defaults to
    // std::deque, whose push_back does not invalidate references to existing
    // elements.
    PushDependencies(record, stack, options);
    record.visited = true;
  } else {
    PostVisit(record, &visitor);
    stack.pop();
  }
}

}  // namespace

namespace common_internal {

struct AstTraversalState {
  std::stack<StackRecord> stack;
};

}  // namespace common_internal

AstTraversal AstTraversal::Create(const cel::Expr& ast,
                                  const TraversalOptions& options) {
  AstTraversal instance(options);
  instance.state_ = std::make_unique<common_internal::AstTraversalState>();
  instance.state_->stack.push(StackRecord(&ast));
  return instance;
}

AstTraversal::AstTraversal(TraversalOptions options) : options_(options) {}

AstTraversal::~AstTraversal() = default;

bool AstTraversal::Step(AstVisitor& visitor) {
  if (IsDone()) {
    return false;
  }
  auto& stack = state_->stack;
  TraversalStep(stack, visitor, options_);
  return !stack.empty();
}

// A default-constructed or moved-from traversal has no state and is done.
bool AstTraversal::IsDone() { return state_ == nullptr || state_->stack.empty(); }

void AstTraverse(const Expr& expr, AstVisitor& visitor,
                 TraversalOptions options) {
  std::stack<StackRecord> stack;
  stack.push(StackRecord(&expr));

  while (!stack.empty()) {
    TraversalStep(stack, visitor, options);
  }
}

}  // namespace cel

================================================
FILE: common/ast_traverse.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_TRAVERSE_NATIVE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_AST_TRAVERSE_NATIVE_H_ #include #include "absl/base/attributes.h" #include "common/ast_visitor.h" #include "common/expr.h" namespace cel { namespace common_internal { struct AstTraversalState; } struct TraversalOptions { // Enable use of the comprehension specific callbacks. bool use_comprehension_callbacks = false; }; // Helper class for managing the traversal of the AST. // Allows caller to step through the traversal. // // Usage: // // AstTraversal traversal = AstTraversal::Create(expr); // // MyVisitor visitor(); // while(!traversal.IsDone()) { // traversal.Step(visitor); // } // // This class is thread-hostile and should only be used in synchronous code. class AstTraversal { public: static AstTraversal Create(const cel::Expr& ast ABSL_ATTRIBUTE_LIFETIME_BOUND, const TraversalOptions& options = {}); ~AstTraversal(); AstTraversal(const AstTraversal&) = delete; AstTraversal& operator=(const AstTraversal&) = delete; AstTraversal(AstTraversal&&) = default; AstTraversal& operator=(AstTraversal&&) = default; // Advances the traversal. Returns true if there is more work to do. This is a // no-op if the traversal is done and IsDone() is true. bool Step(AstVisitor& visitor); // Returns true if there is no work left to do. bool IsDone(); private: explicit AstTraversal(TraversalOptions options); TraversalOptions options_; std::unique_ptr state_; }; // Traverses the AST representation in an expr proto. // // expr: root node of the tree. 
// source_info: optional additional parse information about the expression // visitor: the callback object that receives the visitation notifications // // Traversal order follows the pattern: // PreVisitExpr // ..PreVisit{ExprKind} // ....PreVisit{ArgumentIndex} // .......PreVisitExpr (subtree) // .......PostVisitExpr (subtree) // ....PostVisit{ArgumentIndex} // ..PostVisit{ExprKind} // PostVisitExpr // // Example callback order for fn(1, var): // PreVisitExpr // ..PreVisitCall(fn) // ......PreVisitExpr // ........PostVisitConst(1) // ......PostVisitExpr // ....PostVisitArg(fn, 0) // ......PreVisitExpr // ........PostVisitIdent(var) // ......PostVisitExpr // ....PostVisitArg(fn, 1) // ..PostVisitCall(fn) // PostVisitExpr void AstTraverse(const Expr& expr, AstVisitor& visitor, TraversalOptions options = TraversalOptions()); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_AST_TRAVERSE_NATIVE_H_ ================================================ FILE: common/ast_traverse_test.cc ================================================ // Copyright 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/ast_traverse.h" #include "common/ast_visitor.h" #include "common/constant.h" #include "common/expr.h" #include "internal/testing.h" namespace cel::ast_internal { namespace { using ::testing::_; using ::testing::Ref; class MockAstVisitor : public AstVisitor { public: // Expr handler. 
MOCK_METHOD(void, PreVisitExpr, (const Expr& expr), (override)); // Expr handler. MOCK_METHOD(void, PostVisitExpr, (const Expr& expr), (override)); MOCK_METHOD(void, PostVisitConst, (const Expr& expr, const Constant& const_expr), (override)); // Ident node handler. MOCK_METHOD(void, PostVisitIdent, (const Expr& expr, const IdentExpr& ident_expr), (override)); // Select node handler group MOCK_METHOD(void, PreVisitSelect, (const Expr& expr, const SelectExpr& select_expr), (override)); MOCK_METHOD(void, PostVisitSelect, (const Expr& expr, const SelectExpr& select_expr), (override)); // Call node handler group MOCK_METHOD(void, PreVisitCall, (const Expr& expr, const CallExpr& call_expr), (override)); MOCK_METHOD(void, PostVisitCall, (const Expr& expr, const CallExpr& call_expr), (override)); // Comprehension node handler group MOCK_METHOD(void, PreVisitComprehension, (const Expr& expr, const ComprehensionExpr& comprehension_expr), (override)); MOCK_METHOD(void, PostVisitComprehension, (const Expr& expr, const ComprehensionExpr& comprehension_expr), (override)); // Comprehension node handler group MOCK_METHOD(void, PreVisitComprehensionSubexpression, (const Expr& expr, const ComprehensionExpr& comprehension_expr, ComprehensionArg comprehension_arg), (override)); MOCK_METHOD(void, PostVisitComprehensionSubexpression, (const Expr& expr, const ComprehensionExpr& comprehension_expr, ComprehensionArg comprehension_arg), (override)); // We provide finer granularity for Call and Comprehension node callbacks // to allow special handling for short-circuiting. 
MOCK_METHOD(void, PostVisitTarget, (const Expr& expr), (override)); MOCK_METHOD(void, PostVisitArg, (const Expr& expr, int arg_num), (override)); // List node handler group MOCK_METHOD(void, PostVisitList, (const Expr& expr, const ListExpr& list_expr), (override)); // Struct node handler group MOCK_METHOD(void, PostVisitStruct, (const Expr& expr, const StructExpr& struct_expr), (override)); // Map node handler group MOCK_METHOD(void, PostVisitMap, (const Expr& expr, const MapExpr& map_expr), (override)); }; TEST(AstCrawlerTest, CheckCrawlConstant) { MockAstVisitor handler; Expr expr; auto& const_expr = expr.mutable_const_expr(); EXPECT_CALL(handler, PostVisitConst(Ref(expr), Ref(const_expr))).Times(1); AstTraverse(expr, handler); } TEST(AstCrawlerTest, CheckCrawlIdent) { MockAstVisitor handler; Expr expr; auto& ident_expr = expr.mutable_ident_expr(); EXPECT_CALL(handler, PostVisitIdent(Ref(expr), Ref(ident_expr))).Times(1); AstTraverse(expr, handler); } // Test handling of Select node when operand is not set. 
TEST(AstCrawlerTest, CheckCrawlSelectNotCrashingPostVisitAbsentOperand) {
  MockAstVisitor handler;

  Expr expr;
  auto& select_expr = expr.mutable_select_expr();

  // With no operand there is nothing to descend into; the select itself is
  // still post-visited.
  EXPECT_CALL(handler, PostVisitSelect(Ref(expr), Ref(select_expr))).Times(1);

  AstTraverse(expr, handler);
}

// Test handling of Select node
TEST(AstCrawlerTest, CheckCrawlSelect) {
  MockAstVisitor handler;

  Expr expr;
  auto& select_expr = expr.mutable_select_expr();
  auto& operand = select_expr.mutable_operand();
  auto& ident_expr = operand.mutable_ident_expr();

  testing::InSequence seq;

  // Lowest level entry will be called first
  EXPECT_CALL(handler, PostVisitIdent(Ref(operand), Ref(ident_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitSelect(Ref(expr), Ref(select_expr))).Times(1);

  AstTraverse(expr, handler);
}

// Test handling of Call node without receiver
TEST(AstCrawlerTest, CheckCrawlCallNoReceiver) {
  MockAstVisitor handler;

  // Models a global call: f(<const>, <ident>)
  Expr expr;
  auto& call_expr = expr.mutable_call_expr();
  // Reserve so the references below are not invalidated by emplace_back.
  call_expr.mutable_args().reserve(2);
  auto& arg0 = call_expr.mutable_args().emplace_back();
  auto& const_expr = arg0.mutable_const_expr();
  auto& arg1 = call_expr.mutable_args().emplace_back();
  auto& ident_expr = arg1.mutable_ident_expr();

  testing::InSequence seq;

  // The call is pre-visited before any of its children; with no receiver,
  // PostVisitTarget must never fire.
  EXPECT_CALL(handler, PreVisitCall(Ref(expr), Ref(call_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitTarget(_)).Times(0);

  // Arg0
  EXPECT_CALL(handler, PostVisitConst(Ref(arg0), Ref(const_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(Ref(arg0))).Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), 0)).Times(1);

  // Arg1
  EXPECT_CALL(handler, PostVisitIdent(Ref(arg1), Ref(ident_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(Ref(arg1))).Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), 1)).Times(1);

  // Back to call
  EXPECT_CALL(handler, PostVisitCall(Ref(expr), Ref(call_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(Ref(expr))).Times(1);

  AstTraverse(expr, handler);
}

// Test handling of Call node with receiver
TEST(AstCrawlerTest, CheckCrawlCallReceiver) {
  MockAstVisitor handler;

  // Models a receiver-style call: <ident>.f(<const>, <ident>)
  Expr expr;
  auto& call_expr = expr.mutable_call_expr();
  auto& target = call_expr.mutable_target();
  auto& target_ident = target.mutable_ident_expr();
  // Reserve so the references below are not invalidated by emplace_back.
  call_expr.mutable_args().reserve(2);
  auto& arg0 = call_expr.mutable_args().emplace_back();
  auto& const_expr = arg0.mutable_const_expr();
  auto& arg1 = call_expr.mutable_args().emplace_back();
  auto& ident_expr = arg1.mutable_ident_expr();

  testing::InSequence seq;

  // The call is pre-visited before any of its children.
  EXPECT_CALL(handler, PreVisitCall(Ref(expr), Ref(call_expr))).Times(1);

  // Target (visited before the arguments)
  EXPECT_CALL(handler, PostVisitIdent(Ref(target), Ref(target_ident))).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(Ref(target))).Times(1);
  EXPECT_CALL(handler, PostVisitTarget(Ref(expr))).Times(1);

  // Arg0
  EXPECT_CALL(handler, PostVisitConst(Ref(arg0), Ref(const_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(Ref(arg0))).Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), 0)).Times(1);

  // Arg1
  EXPECT_CALL(handler, PostVisitIdent(Ref(arg1), Ref(ident_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(Ref(arg1))).Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), 1)).Times(1);

  // Back to call
  EXPECT_CALL(handler, PostVisitCall(Ref(expr), Ref(call_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(Ref(expr))).Times(1);

  AstTraverse(expr, handler);
}

// Test handling of Comprehension node with the comprehension-specific
// subexpression callbacks enabled.
TEST(AstCrawlerTest, CheckCrawlComprehension) {
  MockAstVisitor handler;

  Expr expr;
  auto& c = expr.mutable_comprehension_expr();
  auto& iter_range = c.mutable_iter_range();
  auto& iter_range_expr = iter_range.mutable_const_expr();
  auto& accu_init = c.mutable_accu_init();
  auto& accu_init_expr = accu_init.mutable_ident_expr();
  auto& loop_condition = c.mutable_loop_condition();
  auto& loop_condition_expr = loop_condition.mutable_const_expr();
  auto& loop_step = c.mutable_loop_step();
  auto& loop_step_expr = loop_step.mutable_ident_expr();
  auto& result = c.mutable_result();
  auto& result_expr = result.mutable_const_expr();

  testing::InSequence seq;

  // The comprehension is pre-visited first, then each subexpression is
  // bracketed by Pre/PostVisitComprehensionSubexpression.
  EXPECT_CALL(handler, PreVisitComprehension(Ref(expr), Ref(c))).Times(1);

  // ITER_RANGE
  EXPECT_CALL(handler,
              PreVisitComprehensionSubexpression(Ref(expr), Ref(c), ITER_RANGE))
      .Times(1);
  EXPECT_CALL(handler, PostVisitConst(Ref(iter_range), Ref(iter_range_expr)))
      .Times(1);
  EXPECT_CALL(handler,
              PostVisitComprehensionSubexpression(Ref(expr), Ref(c), ITER_RANGE))
      .Times(1);

  // ACCU_INIT
  EXPECT_CALL(handler,
              PreVisitComprehensionSubexpression(Ref(expr), Ref(c), ACCU_INIT))
      .Times(1);
  EXPECT_CALL(handler, PostVisitIdent(Ref(accu_init), Ref(accu_init_expr)))
      .Times(1);
  EXPECT_CALL(handler,
              PostVisitComprehensionSubexpression(Ref(expr), Ref(c), ACCU_INIT))
      .Times(1);

  // LOOP CONDITION
  EXPECT_CALL(handler, PreVisitComprehensionSubexpression(Ref(expr), Ref(c),
                                                          LOOP_CONDITION))
      .Times(1);
  EXPECT_CALL(handler,
              PostVisitConst(Ref(loop_condition), Ref(loop_condition_expr)))
      .Times(1);
  EXPECT_CALL(handler, PostVisitComprehensionSubexpression(Ref(expr), Ref(c),
                                                           LOOP_CONDITION))
      .Times(1);

  // LOOP STEP
  EXPECT_CALL(handler,
              PreVisitComprehensionSubexpression(Ref(expr), Ref(c), LOOP_STEP))
      .Times(1);
  EXPECT_CALL(handler, PostVisitIdent(Ref(loop_step), Ref(loop_step_expr)))
      .Times(1);
  EXPECT_CALL(handler,
              PostVisitComprehensionSubexpression(Ref(expr), Ref(c), LOOP_STEP))
      .Times(1);

  // RESULT
  EXPECT_CALL(handler,
              PreVisitComprehensionSubexpression(Ref(expr), Ref(c), RESULT))
      .Times(1);

  EXPECT_CALL(handler, PostVisitConst(Ref(result), Ref(result_expr))).Times(1);

  EXPECT_CALL(handler,
              PostVisitComprehensionSubexpression(Ref(expr), Ref(c), RESULT))
      .Times(1);

  EXPECT_CALL(handler, PostVisitComprehension(Ref(expr), Ref(c))).Times(1);

  TraversalOptions opts;
  opts.use_comprehension_callbacks = true;
  AstTraverse(expr, handler, opts);
}

// Test handling of Comprehension node with the default options: each
// subexpression is reported through PostVisitArg with a ComprehensionArg
// value as the argument number.
TEST(AstCrawlerTest, CheckCrawlComprehensionLegacyCallbacks) {
  MockAstVisitor handler;

  Expr expr;
  auto& c = expr.mutable_comprehension_expr();
  auto& iter_range = c.mutable_iter_range();
  auto& iter_range_expr = iter_range.mutable_const_expr();
  auto& accu_init = c.mutable_accu_init();
  auto& accu_init_expr = accu_init.mutable_ident_expr();
  auto& loop_condition = c.mutable_loop_condition();
  auto& loop_condition_expr = loop_condition.mutable_const_expr();
  auto& loop_step = c.mutable_loop_step();
  auto& loop_step_expr = loop_step.mutable_ident_expr();
  auto& result = c.mutable_result();
  auto& result_expr = result.mutable_const_expr();

  testing::InSequence seq;

  // The comprehension is pre-visited before any subexpression.
  EXPECT_CALL(handler, PreVisitComprehension(Ref(expr), Ref(c))).Times(1);

  // ITER_RANGE
  EXPECT_CALL(handler, PostVisitConst(Ref(iter_range), Ref(iter_range_expr)))
      .Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), ITER_RANGE)).Times(1);

  // ACCU_INIT
  EXPECT_CALL(handler, PostVisitIdent(Ref(accu_init), Ref(accu_init_expr)))
      .Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), ACCU_INIT)).Times(1);

  // LOOP CONDITION
  EXPECT_CALL(handler,
              PostVisitConst(Ref(loop_condition), Ref(loop_condition_expr)))
      .Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), LOOP_CONDITION)).Times(1);

  // LOOP STEP
  EXPECT_CALL(handler, PostVisitIdent(Ref(loop_step), Ref(loop_step_expr)))
      .Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), LOOP_STEP)).Times(1);

  // RESULT
  EXPECT_CALL(handler, PostVisitConst(Ref(result), Ref(result_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitArg(Ref(expr), RESULT)).Times(1);

  EXPECT_CALL(handler, PostVisitComprehension(Ref(expr), Ref(c))).Times(1);

  AstTraverse(expr, handler);
}

// Test handling of List node.
TEST(AstCrawlerTest, CheckList) {
  MockAstVisitor handler;

  Expr expr;
  auto& list_expr = expr.mutable_list_expr();
  // Reserve so the references below are not invalidated by emplace_back.
  list_expr.mutable_elements().reserve(2);
  auto& arg0 = list_expr.mutable_elements().emplace_back().mutable_expr();
  auto& const_expr = arg0.mutable_const_expr();
  auto& arg1 = list_expr.mutable_elements().emplace_back().mutable_expr();
  auto& ident_expr = arg1.mutable_ident_expr();

  testing::InSequence seq;

  // Elements are visited first to last, then the list itself.
  EXPECT_CALL(handler, PostVisitConst(Ref(arg0), Ref(const_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitIdent(Ref(arg1), Ref(ident_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitList(Ref(expr), Ref(list_expr))).Times(1);

  AstTraverse(expr, handler);
}

// Test handling of Struct node.
TEST(AstCrawlerTest, CheckStruct) {
  MockAstVisitor handler;

  Expr expr;
  auto& struct_expr = expr.mutable_struct_expr();
  auto& entry0 = struct_expr.mutable_fields().emplace_back();

  auto& value = entry0.mutable_value().mutable_ident_expr();

  testing::InSequence seq;

  // Only the field value is a subexpression; it is visited before the struct.
  EXPECT_CALL(handler, PostVisitIdent(Ref(entry0.value()), Ref(value)))
      .Times(1);
  EXPECT_CALL(handler, PostVisitStruct(Ref(expr), Ref(struct_expr))).Times(1);

  AstTraverse(expr, handler);
}

// Test handling of Map node.
TEST(AstCrawlerTest, CheckMap) {
  MockAstVisitor handler;

  Expr expr;
  auto& map_expr = expr.mutable_map_expr();
  auto& entry0 = map_expr.mutable_entries().emplace_back();

  auto& key = entry0.mutable_key().mutable_const_expr();
  auto& value = entry0.mutable_value().mutable_ident_expr();

  testing::InSequence seq;

  // Each entry's key is visited before its value, then the map itself.
  EXPECT_CALL(handler, PostVisitConst(Ref(entry0.key()), Ref(key))).Times(1);
  EXPECT_CALL(handler, PostVisitIdent(Ref(entry0.value()), Ref(value)))
      .Times(1);
  EXPECT_CALL(handler, PostVisitMap(Ref(expr), Ref(map_expr))).Times(1);

  AstTraverse(expr, handler);
}

// Test generic Expr handlers.
TEST(AstCrawlerTest, CheckExprHandlers) {
  MockAstVisitor handler;

  Expr expr;
  auto& map_expr = expr.mutable_map_expr();
  auto& entry0 = map_expr.mutable_entries().emplace_back();

  entry0.mutable_key().mutable_const_expr();
  entry0.mutable_value().mutable_ident_expr();

  // Three Expr nodes (map, key, value), each pre- and post-visited once.
  EXPECT_CALL(handler, PreVisitExpr(_)).Times(3);
  EXPECT_CALL(handler, PostVisitExpr(_)).Times(3);

  AstTraverse(expr, handler);
}

// Stops stepping before the traversal completes and checks that the
// remaining post-visit has not happened yet.
TEST(AstTraversal, Interrupt) {
  MockAstVisitor handler;

  Expr expr;
  auto& select_expr = expr.mutable_select_expr();
  auto& operand = select_expr.mutable_operand();
  auto& ident_expr = operand.mutable_ident_expr();

  testing::InSequence seq;

  auto traversal = AstTraversal::Create(expr);

  EXPECT_CALL(handler, PreVisitExpr(_)).Times(2);
  EXPECT_CALL(handler, PostVisitIdent(Ref(operand), Ref(ident_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitSelect(Ref(expr), Ref(select_expr))).Times(0);

  // Step 1 pre-visits the select, step 2 pre-visits the operand, step 3
  // post-visits the operand. The select's post-visit is still pending.
  EXPECT_TRUE(traversal.Step(handler));
  EXPECT_TRUE(traversal.Step(handler));
  EXPECT_TRUE(traversal.Step(handler));
  EXPECT_FALSE(traversal.IsDone());
}

// Steps until Step() reports no more work and checks the traversal is done.
TEST(AstTraversal, NoInterrupt) {
  MockAstVisitor handler;

  Expr expr;
  auto& select_expr = expr.mutable_select_expr();
  auto& operand = select_expr.mutable_operand();
  auto& ident_expr = operand.mutable_ident_expr();

  testing::InSequence seq;

  auto traversal = AstTraversal::Create(expr);

  EXPECT_CALL(handler, PostVisitIdent(Ref(operand), Ref(ident_expr))).Times(1);
  EXPECT_CALL(handler, PostVisitSelect(Ref(expr), Ref(select_expr))).Times(1);

  while (traversal.Step(handler)) continue;
  EXPECT_TRUE(traversal.IsDone());
}

}  // namespace

}  // namespace cel::ast_internal

================================================
FILE: common/ast_visitor.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_VISITOR_NATIVE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_AST_VISITOR_NATIVE_H_ #include "common/constant.h" #include "common/expr.h" namespace cel { // ComprehensionArg specifies arg_num values passed to PostVisitArg // for subexpressions of Comprehension. enum ComprehensionArg { ITER_RANGE, ACCU_INIT, LOOP_CONDITION, LOOP_STEP, RESULT, }; // Callback handler class, used in conjunction with AstTraverse. // Methods of this class are invoked when AST nodes with corresponding // types are processed. // // For all types with children, the children will be visited in the natural // order from first to last. For structs, keys are visited before values. class AstVisitor { public: virtual ~AstVisitor() = default; // Expr node handler method. Called for all Expr nodes. // Is invoked before child Expr nodes being processed. virtual void PreVisitExpr(const Expr&) = 0; // Expr node handler method. Called for all Expr nodes. // Is invoked after child Expr nodes are processed. virtual void PostVisitExpr(const Expr&) = 0; // Const node handler. // Invoked after child nodes are processed. virtual void PostVisitConst(const Expr&, const Constant&) = 0; // Ident node handler. // Invoked after child nodes are processed. virtual void PostVisitIdent(const Expr&, const IdentExpr&) = 0; // Select node handler // Invoked before child nodes are processed. virtual void PreVisitSelect(const Expr&, const SelectExpr&) = 0; // Select node handler // Invoked after child nodes are processed. 
virtual void PostVisitSelect(const Expr&, const SelectExpr&) = 0; // Call node handler group // We provide finer granularity for Call node callbacks to allow special // handling for short-circuiting // PreVisitCall is invoked before child nodes are processed. virtual void PreVisitCall(const Expr&, const CallExpr&) = 0; // Invoked after all child nodes are processed. virtual void PostVisitCall(const Expr&, const CallExpr&) = 0; // Invoked after target node is processed. // Expr is the call expression. virtual void PostVisitTarget(const Expr&) = 0; // Invoked before all child nodes are processed. virtual void PreVisitComprehension(const Expr&, const ComprehensionExpr&) = 0; // Invoked before comprehension child node is processed. virtual void PreVisitComprehensionSubexpression( const Expr&, const ComprehensionExpr& compr, ComprehensionArg comprehension_arg) {} // Invoked after comprehension child node is processed. virtual void PostVisitComprehensionSubexpression( const Expr&, const ComprehensionExpr& compr, ComprehensionArg comprehension_arg) {} // Invoked after all child nodes are processed. virtual void PostVisitComprehension(const Expr&, const ComprehensionExpr&) = 0; // Invoked after each argument node processed. // For Call arg_num is the index of the argument. // For Comprehension arg_num is specified by ComprehensionArg. // Expr is the call expression. virtual void PostVisitArg(const Expr&, int arg_num) = 0; // List node handler // Invoked after child nodes are processed. virtual void PostVisitList(const Expr&, const ListExpr&) = 0; // Struct node handler // Invoked after child nodes are processed. virtual void PostVisitStruct(const Expr&, const StructExpr&) = 0; // Map node handler // Invoked after child nodes are processed. 
virtual void PostVisitMap(const Expr&, const MapExpr&) = 0; }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_AST_VISITOR_NATIVE_H_ ================================================ FILE: common/ast_visitor_base.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_VISITOR_BASE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_AST_VISITOR_BASE_H_ #include "common/ast_visitor.h" #include "common/constant.h" #include "common/expr.h" namespace cel { // Trivial base implementation of AstVisitor. class AstVisitorBase : public AstVisitor { public: AstVisitorBase() = default; // Non-copyable AstVisitorBase(const AstVisitorBase&) = delete; AstVisitorBase& operator=(AstVisitorBase const&) = delete; ~AstVisitorBase() override {} // Const node handler. // Invoked after child nodes are processed. void PostVisitConst(const Expr&, const Constant&) override {} // Ident node handler. // Invoked after child nodes are processed. void PostVisitIdent(const Expr&, const IdentExpr&) override {} void PreVisitSelect(const Expr&, const SelectExpr&) override {} // Select node handler // Invoked after child nodes are processed. void PostVisitSelect(const Expr&, const SelectExpr&) override {} // Call node handler group // We provide finer granularity for Call node callbacks to allow special // handling for short-circuiting // PreVisitCall is invoked before child nodes are processed. 
void PreVisitCall(const Expr&, const CallExpr&) override {} // Invoked after all child nodes are processed. void PostVisitCall(const Expr&, const CallExpr&) override {} // Invoked before all child nodes are processed. void PreVisitComprehension(const Expr&, const ComprehensionExpr&) override {} // Invoked after all child nodes are processed. void PostVisitComprehension(const Expr&, const ComprehensionExpr&) override {} // Invoked after each argument node processed. // For Call arg_num is the index of the argument. // For Comprehension arg_num is specified by ComprehensionArg. // Expr is the call expression. void PostVisitArg(const Expr&, int) override {} // Invoked after target node processed. void PostVisitTarget(const Expr&) override {} // List node handler // Invoked after child nodes are processed. void PostVisitList(const Expr&, const ListExpr&) override {} // Struct node handler // Invoked after child nodes are processed. void PostVisitStruct(const Expr&, const StructExpr&) override {} // Map node handler // Invoked after child nodes are processed. void PostVisitMap(const Expr&, const MapExpr&) override {} }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_AST_VISITOR_BASE_H_ ================================================ FILE: common/casting.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_COMMON_CASTING_H_ #define THIRD_PARTY_CEL_CPP_COMMON_CASTING_H_ #include "absl/base/attributes.h" #include "common/internal/casting.h" namespace cel { // `InstanceOf(const From&)` determines whether `From` holds or is `To`. // // `To` must be a plain non-union class type that is not qualified. // // We expose `InstanceOf` this way to avoid ADL. // // Example: // // if (InstanceOf(superclass)) { // Cast(superclass).SomeMethod(); // } template ABSL_DEPRECATED("Use Is member functions instead.") inline constexpr common_internal::InstanceOfImpl InstanceOf{}; // `Cast(From)` is a "checked cast". In debug builds an assertion is emitted // which verifies `From` is an instance-of `To`. In non-debug builds, invalid // casts are undefined behavior. // // We expose `Cast` this way to avoid ADL. // // Example: // // if (InstanceOf(superclass)) { // Cast(superclass).SomeMethod(); // } template ABSL_DEPRECATED( "Use explicit conversion functions instead through static_cast.") inline constexpr common_internal::CastImpl Cast{}; // `As(From)` is a "checking cast". The result is explicitly convertible to // `bool`, such that it can be used with `if` statements. The result can be // accessed with `operator*` or `operator->`. The return type should be treated // as an implementation detail, with no assumptions on the concrete type. You // should use `auto`. // // `As` is analogous to the paradigm `if (InstanceOf(a)) Cast(a)`. // // We expose `As` this way to avoid ADL. 
// // Example: // // if (auto subclass = As(superclass); subclass) { // subclass->SomeMethod(); // } template ABSL_DEPRECATED("Use As member functions instead.") inline constexpr common_internal::AsImpl As{}; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_INSTANCE_OF_H_ ================================================ FILE: common/constant.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/constant.h" #include #include #include #include "absl/base/no_destructor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "internal/strings.h" namespace cel { const BytesConstant& BytesConstant::default_instance() { static const absl::NoDestructor instance; return *instance; } const StringConstant& StringConstant::default_instance() { static const absl::NoDestructor instance; return *instance; } const Constant& Constant::default_instance() { static const absl::NoDestructor instance; return *instance; } std::string FormatNullConstant() { return "null"; } std::string FormatBoolConstant(bool value) { return value ? 
std::string("true") : std::string("false"); } std::string FormatIntConstant(int64_t value) { return absl::StrCat(value); } std::string FormatUintConstant(uint64_t value) { return absl::StrCat(value, "u"); } std::string FormatDoubleConstant(double value) { if (std::isfinite(value)) { if (std::floor(value) != value) { // The double is not representable as a whole number, so use // absl::StrCat which will add decimal places. return absl::StrCat(value); } // absl::StrCat historically would represent 0.0 as 0, and we want the // decimal places so ZetaSQL correctly assumes the type as double // instead of int64. std::string stringified = absl::StrCat(value); if (!absl::StrContains(stringified, '.')) { absl::StrAppend(&stringified, ".0"); } return stringified; } if (std::isnan(value)) { return "nan"; } if (std::signbit(value)) { return "-infinity"; } return "+infinity"; } std::string FormatBytesConstant(absl::string_view value) { return internal::FormatBytesLiteral(value); } std::string FormatStringConstant(absl::string_view value) { return internal::FormatStringLiteral(value); } std::string FormatDurationConstant(absl::Duration value) { return absl::StrCat("duration(\"", absl::FormatDuration(value), "\")"); } std::string FormatTimestampConstant(absl::Time value) { return absl::StrCat( "timestamp(\"", absl::FormatTime("%Y-%m-%d%ET%H:%M:%E*SZ", value, absl::UTCTimeZone()), "\")"); } } // namespace cel ================================================ FILE: common/constant.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_CONSTANT_H_ #define THIRD_PARTY_CEL_CPP_COMMON_CONSTANT_H_ #include #include #include #include #include #include "absl/base/attributes.h" #include "absl/functional/overload.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/variant.h" namespace cel { class Expr; class Constant; class BytesConstant; class StringConstant; class VariableDecl; class BytesConstant final : public std::string { public: explicit BytesConstant(std::string string) : std::string(std::move(string)) {} explicit BytesConstant(absl::string_view string) : BytesConstant(std::string(string)) {} explicit BytesConstant(const char* string) : BytesConstant(absl::NullSafeStringView(string)) {} BytesConstant() = default; BytesConstant(const BytesConstant&) = default; BytesConstant(BytesConstant&&) = default; BytesConstant& operator=(const BytesConstant&) = default; BytesConstant& operator=(BytesConstant&&) = default; BytesConstant(const StringConstant&) = delete; BytesConstant(StringConstant&&) = delete; BytesConstant& operator=(const StringConstant&) = delete; BytesConstant& operator=(StringConstant&&) = delete; private: static const BytesConstant& default_instance(); friend class Constant; }; class StringConstant final : public std::string { public: explicit StringConstant(std::string string) : std::string(std::move(string)) {} explicit StringConstant(absl::string_view string) : StringConstant(std::string(string)) {} explicit StringConstant(const char* string) : StringConstant(absl::NullSafeStringView(string)) {} 
StringConstant() = default; StringConstant(const StringConstant&) = default; StringConstant(StringConstant&&) = default; StringConstant& operator=(const StringConstant&) = default; StringConstant& operator=(StringConstant&&) = default; StringConstant(const BytesConstant&) = delete; StringConstant(BytesConstant&&) = delete; StringConstant& operator=(const BytesConstant&) = delete; StringConstant& operator=(BytesConstant&&) = delete; private: static const StringConstant& default_instance(); friend class Constant; }; namespace common_internal { template struct ConstantKindIndexer { static constexpr size_t value = std::conditional_t, std::integral_constant, ConstantKindIndexer>::value; }; template struct ConstantKindIndexer { static constexpr size_t value = std::conditional_t< std::is_same_v, std::integral_constant, std::integral_constant>::value; }; template struct ConstantKindImpl { using VariantType = absl::variant; template static constexpr size_t IndexOf() { return ConstantKindIndexer<0, U, Ts...>::value; } }; using ConstantKind = ConstantKindImpl; static_assert(ConstantKind::IndexOf() == 0); static_assert(ConstantKind::IndexOf() == 1); static_assert(ConstantKind::IndexOf() == 2); static_assert(ConstantKind::IndexOf() == 3); static_assert(ConstantKind::IndexOf() == 4); static_assert(ConstantKind::IndexOf() == 5); static_assert(ConstantKind::IndexOf() == 6); static_assert(ConstantKind::IndexOf() == 7); static_assert(ConstantKind::IndexOf() == 8); static_assert(ConstantKind::IndexOf() == 9); static_assert(ConstantKind::IndexOf() == absl::variant_npos); } // namespace common_internal // Constant is a variant composed of all the literal types support by the Common // Expression Language. 
using ConstantKind = common_internal::ConstantKind::VariantType; enum class ConstantKindCase { kUnspecified, kNull, kBool, kInt, kUint, kDouble, kBytes, kString, kDuration, kTimestamp, }; template constexpr size_t ConstantKindIndexOf() { return common_internal::ConstantKind::IndexOf(); } // Returns the `null` literal. std::string FormatNullConstant(); inline std::string FormatNullConstant(std::nullptr_t) { return FormatNullConstant(); } // Formats `value` as a bool literal. std::string FormatBoolConstant(bool value); // Formats `value` as a int literal. std::string FormatIntConstant(int64_t value); // Formats `value` as a uint literal. std::string FormatUintConstant(uint64_t value); // Formats `value` as a double literal-like representation. Due to Common // Expression Language not having NaN or infinity literals, the result will not // always be syntactically valid. std::string FormatDoubleConstant(double value); // Formats `value` as a bytes literal. std::string FormatBytesConstant(absl::string_view value); // Formats `value` as a string literal. std::string FormatStringConstant(absl::string_view value); // Formats `value` as a duration constant. std::string FormatDurationConstant(absl::Duration value); // Formats `value` as a timestamp constant. std::string FormatTimestampConstant(absl::Time value); // Represents a primitive literal. // // This is similar as the primitives supported in the well-known type // `google.protobuf.Value`, but richer so it can represent CEL's full range of // primitives. // // Lists and structs are not included as constants as these aggregate types may // contain [Expr][] elements which require evaluation and are thus not constant. // // Examples of constants include: `"hello"`, `b'bytes'`, `1u`, `4.2`, `-2`, // `true`, `null`. 
class Constant final { public: Constant() = default; Constant(const Constant&) = default; Constant(Constant&&) = default; Constant& operator=(const Constant&) = default; Constant& operator=(Constant&&) = default; explicit Constant(ConstantKind kind) : kind_(std::move(kind)) {} ABSL_MUST_USE_RESULT const ConstantKind& kind() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return kind_; } ABSL_DEPRECATED("Use kind()") ABSL_MUST_USE_RESULT const ConstantKind& constant_kind() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return kind(); } ABSL_MUST_USE_RESULT bool has_value() const { return !absl::holds_alternative(kind()); } ABSL_MUST_USE_RESULT bool has_null_value() const { return absl::holds_alternative(kind()); } ABSL_MUST_USE_RESULT std::nullptr_t null_value() const { return nullptr; } void set_null_value() { mutable_kind().emplace(); } void set_null_value(std::nullptr_t) { set_null_value(); } ABSL_MUST_USE_RESULT bool has_bool_value() const { return absl::holds_alternative(kind()); } void set_bool_value(bool value) { mutable_kind().emplace(value); } ABSL_MUST_USE_RESULT bool bool_value() const { return get_value(); } ABSL_MUST_USE_RESULT bool has_int_value() const { return absl::holds_alternative(kind()); } void set_int_value(int64_t value) { mutable_kind().emplace(value); } ABSL_MUST_USE_RESULT int64_t int_value() const { return get_value(); } ABSL_MUST_USE_RESULT bool has_uint_value() const { return absl::holds_alternative(kind()); } void set_uint_value(uint64_t value) { mutable_kind().emplace(value); } ABSL_MUST_USE_RESULT uint64_t uint_value() const { return get_value(); } ABSL_DEPRECATED("Use has_int_value") ABSL_MUST_USE_RESULT bool has_int64_value() const { return has_int_value(); } ABSL_DEPRECATED("Use set_int_value()") void set_int64_value(int64_t value) { set_int_value(value); } ABSL_DEPRECATED("Use int_value()") ABSL_MUST_USE_RESULT int64_t int64_value() const { return int_value(); } ABSL_DEPRECATED("Use has_uint_value()") ABSL_MUST_USE_RESULT bool has_uint64_value() 
const { return has_uint_value(); } ABSL_DEPRECATED("Use set_uint_value()") void set_uint64_value(uint64_t value) { set_uint_value(value); } ABSL_DEPRECATED("Use uint_value()") ABSL_MUST_USE_RESULT uint64_t uint64_value() const { return uint_value(); } ABSL_MUST_USE_RESULT bool has_double_value() const { return absl::holds_alternative(kind()); } void set_double_value(double value) { mutable_kind().emplace(value); } ABSL_MUST_USE_RESULT double double_value() const { return get_value(); } ABSL_MUST_USE_RESULT bool has_bytes_value() const { return absl::holds_alternative(kind()); } void set_bytes_value(BytesConstant value) { mutable_kind().emplace(std::move(value)); } void set_bytes_value(std::string value) { set_bytes_value(BytesConstant{std::move(value)}); } void set_bytes_value(absl::string_view value) { set_bytes_value(BytesConstant{value}); } void set_bytes_value(const char* value) { set_bytes_value(absl::NullSafeStringView(value)); } ABSL_MUST_USE_RESULT const std::string& bytes_value() const ABSL_ATTRIBUTE_LIFETIME_BOUND { if (const auto* alt = absl::get_if(&kind()); alt) { return *alt; } return BytesConstant::default_instance(); } ABSL_MUST_USE_RESULT std::string release_bytes_value() { std::string string; if (auto* alt = absl::get_if(&mutable_kind()); alt) { string.swap(*alt); } mutable_kind().emplace(); return string; } ABSL_MUST_USE_RESULT bool has_string_value() const { return absl::holds_alternative(kind()); } void set_string_value(StringConstant value) { mutable_kind().emplace(std::move(value)); } void set_string_value(std::string value) { set_string_value(StringConstant{std::move(value)}); } void set_string_value(absl::string_view value) { set_string_value(StringConstant{value}); } void set_string_value(const char* value) { set_string_value(absl::NullSafeStringView(value)); } ABSL_MUST_USE_RESULT const std::string& string_value() const ABSL_ATTRIBUTE_LIFETIME_BOUND { if (const auto* alt = absl::get_if(&kind()); alt) { return *alt; } return 
StringConstant::default_instance(); } ABSL_MUST_USE_RESULT std::string release_string_value() { std::string string; if (auto* alt = absl::get_if(&mutable_kind()); alt) { string.swap(*alt); } mutable_kind().emplace(); return string; } ABSL_DEPRECATED("duration is no longer considered a builtin type") ABSL_MUST_USE_RESULT bool has_duration_value() const { return absl::holds_alternative(kind()); } ABSL_DEPRECATED("duration is no longer considered a builtin type") void set_duration_value(absl::Duration value) { mutable_kind().emplace(value); } ABSL_DEPRECATED("duration is no longer considered a builtin type") ABSL_MUST_USE_RESULT absl::Duration duration_value() const { return get_value(); } ABSL_DEPRECATED("timestamp is no longer considered a builtin type") ABSL_MUST_USE_RESULT bool has_timestamp_value() const { return absl::holds_alternative(kind()); } ABSL_DEPRECATED("timestamp is no longer considered a builtin type") void set_timestamp_value(absl::Time value) { mutable_kind().emplace(value); } ABSL_DEPRECATED("timestamp is no longer considered a builtin type") ABSL_MUST_USE_RESULT absl::Time timestamp_value() const { return get_value(); } ABSL_DEPRECATED("Use has_timestamp_value()") ABSL_MUST_USE_RESULT bool has_time_value() const { return has_timestamp_value(); } ABSL_DEPRECATED("Use set_timestamp_value()") void set_time_value(absl::Time value) { set_timestamp_value(value); } ABSL_DEPRECATED("Use timestamp_value()") ABSL_MUST_USE_RESULT absl::Time time_value() const { return timestamp_value(); } ConstantKindCase kind_case() const { static_assert(absl::variant_size_v == 10); if (kind_.index() <= 10) { return static_cast(kind_.index()); } return ConstantKindCase::kUnspecified; } private: friend class Expr; friend class VariableDecl; static const Constant& default_instance(); ABSL_MUST_USE_RESULT ConstantKind& mutable_kind() ABSL_ATTRIBUTE_LIFETIME_BOUND { return kind_; } template T get_value() const { if (const auto* alt = absl::get_if(&kind()); alt) { return *alt; } 
return T{}; } ConstantKind kind_; }; inline bool operator==(const Constant& lhs, const Constant& rhs) { return lhs.kind() == rhs.kind(); } inline bool operator!=(const Constant& lhs, const Constant& rhs) { return lhs.kind() != rhs.kind(); } template void AbslStringify(Sink& sink, const Constant& constant) { absl::visit( absl::Overload( [&sink](absl::monostate) -> void { sink.Append(""); }, [&sink](std::nullptr_t value) -> void { sink.Append(FormatNullConstant(value)); }, [&sink](bool value) -> void { sink.Append(FormatBoolConstant(value)); }, [&sink](int64_t value) -> void { sink.Append(FormatIntConstant(value)); }, [&sink](uint64_t value) -> void { sink.Append(FormatUintConstant(value)); }, [&sink](double value) -> void { sink.Append(FormatDoubleConstant(value)); }, [&sink](const BytesConstant& value) -> void { sink.Append(FormatBytesConstant(value)); }, [&sink](const StringConstant& value) -> void { sink.Append(FormatStringConstant(value)); }, [&sink](absl::Duration value) -> void { sink.Append(FormatDurationConstant(value)); }, [&sink](absl::Time value) -> void { sink.Append(FormatTimestampConstant(value)); }), constant.kind()); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_CONSTANT_H_ ================================================ FILE: common/constant_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/constant.h"

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <string>

#include "absl/strings/has_absl_stringify.h"
#include "absl/strings/str_format.h"
#include "absl/time/time.h"
#include "internal/testing.h"

namespace cel {
namespace {

using ::testing::IsEmpty;
using ::testing::IsFalse;
using ::testing::IsTrue;

// Each accessor test below stores a non-default value, so the getter check
// proves the setter actually stored it rather than the getter falling back
// to the type's default.

TEST(Constant, NullValue) {
  Constant const_expr;
  EXPECT_THAT(const_expr.has_null_value(), IsFalse());
  const_expr.set_null_value();
  EXPECT_THAT(const_expr.has_null_value(), IsTrue());
  EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf<std::nullptr_t>());
  EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kNull);
}

TEST(Constant, BoolValue) {
  Constant const_expr;
  EXPECT_THAT(const_expr.has_bool_value(), IsFalse());
  EXPECT_EQ(const_expr.bool_value(), false);
  const_expr.set_bool_value(true);
  EXPECT_THAT(const_expr.has_bool_value(), IsTrue());
  EXPECT_EQ(const_expr.bool_value(), true);
  EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf<bool>());
  EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kBool);
}

TEST(Constant, IntValue) {
  Constant const_expr;
  EXPECT_THAT(const_expr.has_int_value(), IsFalse());
  EXPECT_EQ(const_expr.int_value(), 0);
  const_expr.set_int_value(42);
  EXPECT_THAT(const_expr.has_int_value(), IsTrue());
  EXPECT_EQ(const_expr.int_value(), 42);
  EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf<int64_t>());
  EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kInt);
}

TEST(Constant, UintValue) {
  Constant const_expr;
  EXPECT_THAT(const_expr.has_uint_value(), IsFalse());
  EXPECT_EQ(const_expr.uint_value(), 0);
  const_expr.set_uint_value(42);
  EXPECT_THAT(const_expr.has_uint_value(), IsTrue());
  EXPECT_EQ(const_expr.uint_value(), 42);
  EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf<uint64_t>());
  EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kUint);
}

TEST(Constant, DoubleValue) {
  Constant const_expr;
  EXPECT_THAT(const_expr.has_double_value(), IsFalse());
  EXPECT_EQ(const_expr.double_value(), 0);
  const_expr.set_double_value(4.5);
  EXPECT_THAT(const_expr.has_double_value(), IsTrue());
  EXPECT_EQ(const_expr.double_value(), 4.5);
  EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf<double>());
  EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kDouble);
}

TEST(Constant, BytesValue) {
  Constant const_expr;
  EXPECT_THAT(const_expr.has_bytes_value(), IsFalse());
  EXPECT_THAT(const_expr.bytes_value(), IsEmpty());
  const_expr.set_bytes_value("foo");
  EXPECT_THAT(const_expr.has_bytes_value(), IsTrue());
  EXPECT_EQ(const_expr.bytes_value(), "foo");
  EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf<BytesConstant>());
  EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kBytes);
}

TEST(Constant, StringValue) {
  Constant const_expr;
  EXPECT_THAT(const_expr.has_string_value(), IsFalse());
  EXPECT_THAT(const_expr.string_value(), IsEmpty());
  const_expr.set_string_value("foo");
  EXPECT_THAT(const_expr.has_string_value(), IsTrue());
  EXPECT_EQ(const_expr.string_value(), "foo");
  EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf<StringConstant>());
  EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kString);
}

TEST(Constant, DurationValue) {
  Constant const_expr;
  EXPECT_THAT(const_expr.has_duration_value(), IsFalse());
  EXPECT_EQ(const_expr.duration_value(), absl::ZeroDuration());
  const_expr.set_duration_value(absl::Seconds(1));
  EXPECT_THAT(const_expr.has_duration_value(), IsTrue());
  EXPECT_EQ(const_expr.duration_value(), absl::Seconds(1));
  EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf<absl::Duration>());
  EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kDuration);
}

TEST(Constant, TimestampValue) {
  Constant const_expr;
  EXPECT_THAT(const_expr.has_timestamp_value(), IsFalse());
  EXPECT_EQ(const_expr.timestamp_value(), absl::UnixEpoch());
  const_expr.set_timestamp_value(absl::UnixEpoch() + absl::Seconds(1));
  EXPECT_THAT(const_expr.has_timestamp_value(), IsTrue());
  EXPECT_EQ(const_expr.timestamp_value(), absl::UnixEpoch() + absl::Seconds(1));
  EXPECT_EQ(const_expr.kind().index(), ConstantKindIndexOf<absl::Time>());
  EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kTimestamp);
}

TEST(Constant, DefaultConstructed) {
  Constant const_expr;
  EXPECT_EQ(const_expr.kind_case(), ConstantKindCase::kUnspecified);
}

TEST(Constant, Equality) {
  EXPECT_EQ(Constant{}, Constant{});

  Constant lhs_const_expr;
  Constant rhs_const_expr;

  lhs_const_expr.set_null_value();
  rhs_const_expr.set_null_value();
  EXPECT_EQ(lhs_const_expr, rhs_const_expr);
  EXPECT_EQ(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);

  lhs_const_expr.set_bool_value(false);
  rhs_const_expr.set_null_value();
  EXPECT_NE(lhs_const_expr, rhs_const_expr);
  EXPECT_NE(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);
  rhs_const_expr.set_bool_value(false);
  EXPECT_EQ(lhs_const_expr, rhs_const_expr);
  EXPECT_EQ(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);

  lhs_const_expr.set_int_value(0);
  rhs_const_expr.set_null_value();
  EXPECT_NE(lhs_const_expr, rhs_const_expr);
  EXPECT_NE(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);
  rhs_const_expr.set_int_value(0);
  EXPECT_EQ(lhs_const_expr, rhs_const_expr);
  EXPECT_EQ(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);

  lhs_const_expr.set_uint_value(0);
  rhs_const_expr.set_null_value();
  EXPECT_NE(lhs_const_expr, rhs_const_expr);
  EXPECT_NE(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);
  rhs_const_expr.set_uint_value(0);
  EXPECT_EQ(lhs_const_expr, rhs_const_expr);
  EXPECT_EQ(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);

  lhs_const_expr.set_double_value(0);
  rhs_const_expr.set_null_value();
  EXPECT_NE(lhs_const_expr, rhs_const_expr);
  EXPECT_NE(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);
  rhs_const_expr.set_double_value(0);
  EXPECT_EQ(lhs_const_expr, rhs_const_expr);
  EXPECT_EQ(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);

  lhs_const_expr.set_bytes_value("foo");
  rhs_const_expr.set_null_value();
  EXPECT_NE(lhs_const_expr, rhs_const_expr);
  EXPECT_NE(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);
  rhs_const_expr.set_bytes_value("foo");
  EXPECT_EQ(lhs_const_expr, rhs_const_expr);
  EXPECT_EQ(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);

  lhs_const_expr.set_string_value("foo");
  rhs_const_expr.set_null_value();
  EXPECT_NE(lhs_const_expr, rhs_const_expr);
  EXPECT_NE(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);
  rhs_const_expr.set_string_value("foo");
  EXPECT_EQ(lhs_const_expr, rhs_const_expr);
  EXPECT_EQ(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);

  lhs_const_expr.set_duration_value(absl::ZeroDuration());
  rhs_const_expr.set_null_value();
  EXPECT_NE(lhs_const_expr, rhs_const_expr);
  EXPECT_NE(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);
  rhs_const_expr.set_duration_value(absl::ZeroDuration());
  EXPECT_EQ(lhs_const_expr, rhs_const_expr);
  EXPECT_EQ(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);

  lhs_const_expr.set_timestamp_value(absl::UnixEpoch());
  rhs_const_expr.set_null_value();
  EXPECT_NE(lhs_const_expr, rhs_const_expr);
  EXPECT_NE(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);
  rhs_const_expr.set_timestamp_value(absl::UnixEpoch());
  EXPECT_EQ(lhs_const_expr, rhs_const_expr);
  EXPECT_EQ(rhs_const_expr, lhs_const_expr);
  EXPECT_NE(lhs_const_expr, Constant{});
  EXPECT_NE(Constant{}, rhs_const_expr);
}

// Renders `constant` through its AbslStringify overload.
std::string Stringify(const Constant& constant) {
  return absl::StrFormat("%v", constant);
}

TEST(Constant, HasAbslStringify) {
  EXPECT_TRUE(absl::HasAbslStringify<Constant>::value);
}

TEST(Constant, AbslStringify) {
  Constant constant;
  EXPECT_EQ(Stringify(constant), "<unspecified>");
  constant.set_null_value();
  EXPECT_EQ(Stringify(constant), "null");
  constant.set_bool_value(true);
  EXPECT_EQ(Stringify(constant), "true");
  constant.set_int_value(1);
  EXPECT_EQ(Stringify(constant), "1");
  constant.set_uint_value(1);
  EXPECT_EQ(Stringify(constant), "1u");
  constant.set_double_value(1);
  EXPECT_EQ(Stringify(constant), "1.0");
  constant.set_double_value(1.1);
  EXPECT_EQ(Stringify(constant), "1.1");
  constant.set_double_value(NAN);
  EXPECT_EQ(Stringify(constant), "nan");
  constant.set_double_value(INFINITY);
  EXPECT_EQ(Stringify(constant), "+infinity");
  constant.set_double_value(-INFINITY);
  EXPECT_EQ(Stringify(constant), "-infinity");
  constant.set_bytes_value("foo");
  EXPECT_EQ(Stringify(constant), "b\"foo\"");
  constant.set_string_value("foo");
  EXPECT_EQ(Stringify(constant), "\"foo\"");
  constant.set_duration_value(absl::Seconds(1));
  EXPECT_EQ(Stringify(constant), "duration(\"1s\")");
  constant.set_timestamp_value(absl::UnixEpoch() + absl::Seconds(1));
  EXPECT_EQ(Stringify(constant), "timestamp(\"1970-01-01T00:00:01Z\")");
}

}  // namespace
}  // namespace cel

================================================
FILE: common/container.cc
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include "common/container.h" #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "internal/lexis.h" namespace cel { namespace { bool IsValidQualifiedName(absl::string_view name) { auto dot_pos = name.find('.'); while (dot_pos != absl::string_view::npos) { if (!internal::LexisIsIdentifier(name.substr(0, dot_pos))) { return false; } name = name.substr(dot_pos + 1); dot_pos = name.find('.'); } return internal::LexisIsIdentifier(name); } bool IsValidAlias(absl::string_view alias) { return internal::LexisIsIdentifier(alias); } bool IsAbbreviationImpl(absl::string_view alias, absl::string_view name) { auto pos = name.rfind('.'); return pos != std::string::npos && pos > 0 && pos < name.size() - 1 && alias == name.substr(pos + 1); } } // namespace bool ExpressionContainer::AliasListing::IsAbbreviation() const { return IsAbbreviationImpl(alias, name); } absl::StatusOr MakeExpressionContainer( absl::string_view name) { ExpressionContainer container; absl::Status status = container.SetContainer(name); if (!status.ok()) { return status; } return container; } absl::Status ExpressionContainer::SetContainer(absl::string_view name) { if (name.empty()) { container_ = ""; return absl::OkStatus(); } if (!IsValidQualifiedName(name)) { return absl::InvalidArgumentError( absl::StrCat("invalid qualified name: ", name)); } for (const auto& entry : aliases_) { const std::string& alias = entry.first; if (name == alias || (name.size() > alias.size() && absl::string_view(name).substr(0, alias.size()) == alias && name.at(alias.size()) == '.')) { return absl::InvalidArgumentError( absl::StrCat("container name collides with alias: ", alias)); } } container_ = std::string(name); return absl::OkStatus(); } absl::Status ExpressionContainer::AddAbbreviation(absl::string_view abrev) { 
if (!IsValidQualifiedName(abrev)) { return absl::InvalidArgumentError( absl::StrCat("invalid qualified name: ", abrev)); } auto pos = abrev.rfind('.'); if (pos == 0 || pos == absl::string_view::npos || pos == abrev.size() - 1) { return absl::InvalidArgumentError( absl::StrCat("invalid qualified name: ", abrev, ", wanted name of the form 'qualified.name'")); } absl::string_view alias = abrev.substr(pos + 1); return AddAlias(alias, abrev); } absl::Status ExpressionContainer::AddAlias(absl::string_view alias, absl::string_view name) { if (!IsValidAlias(alias)) { return absl::InvalidArgumentError(absl::StrCat( "alias must be non-empty and simple (not qualified): ", alias)); } if (!IsValidQualifiedName(name)) { return absl::InvalidArgumentError( absl::StrCat("invalid qualified name: ", name)); } if (auto it = aliases_.find(alias); it != aliases_.end()) { return absl::InvalidArgumentError(absl::StrCat( "alias collides with existing reference: ", alias, " -> ", it->second)); } if (container_ == alias || (container_.size() > alias.size() && absl::string_view(container_).substr(0, alias.size()) == alias && container_.at(alias.size()) == '.')) { return absl::InvalidArgumentError( absl::StrCat("alias collides with container name: ", alias)); } aliases_.insert({std::string(alias), std::string(name)}); return absl::OkStatus(); } absl::string_view ExpressionContainer::FindAlias( absl::string_view alias) const { auto it = aliases_.find(alias); if (it != aliases_.end()) { return it->second; } return ""; } std::vector ExpressionContainer::ListAbbreviations() const { std::vector res; for (const auto& entry : aliases_) { if (IsAbbreviationImpl(entry.first, entry.second)) { res.push_back(entry.second); } } return res; } std::vector ExpressionContainer::ListAliases() const { std::vector res; for (const auto& entry : aliases_) { res.push_back({entry.first, entry.second}); } return res; } } // namespace cel ================================================ FILE: common/container.h 
================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ #define THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ #include #include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" namespace cel { // ExpressionContainer represents the namespace configuration for a CEL // expression. // // The container defines the default resolution order for names referenced in // the expression. It generally maps to a protobuf package and follows // approximately the same resolution rules as protobuf or C++ namespaces. // // Aliases declare short names that can be referenced without resolving against // the scopes defined by the container. An alias cannot be a prefix of the // container name, (otherwise re-type-checking an expression could // change the meaning). Aliases are always unqualified identifiers. // // An abbreviation is a special case of alias that behaves like an import or // using declaration in other languages. (pkg.TypeName -> TypeName). // // For better traceability, prefer using abbreviations over aliases. 
class ExpressionContainer { public: struct AliasListing { std::string alias; std::string name; bool IsAbbreviation() const; }; ExpressionContainer() = default; ExpressionContainer(const ExpressionContainer&) = default; ExpressionContainer(ExpressionContainer&&) = default; ExpressionContainer& operator=(const ExpressionContainer&) = default; ExpressionContainer& operator=(ExpressionContainer&&) = default; // Returns the full name of the container. // // The default value is an empty string meaning no container. absl::string_view container() const { return container_; } // Sets the container name. // // Returns an error if the container name is malformed or conflicts with an // existing alias. absl::Status SetContainer(absl::string_view name); // Adds an abbreviation to the container. // // Returns an error if the abbreviation is malformed or conflicts with the // container or an existing alias. absl::Status AddAbbreviation(absl::string_view abrev); // Adds an alias to the container. // // Returns an error if the alias is malformed or conflicts with the container // or an existing alias. absl::Status AddAlias(absl::string_view alias, absl::string_view name); // Returns the full name of the alias or an empty string if not found. // // The returned string view may be invalidated by updates to the // ExpressionContainer. absl::string_view FindAlias(absl::string_view alias) const; // Utility method for listing the abbreviations in the container. // Order is not guaranteed. std::vector ListAbbreviations() const; // Utility method for listing the aliases in the container. // Includes abbreviations. // Order is not guaranteed. std::vector ListAliases() const; // Removes all aliases and abbreviations from the container. void clear() { container_.clear(); aliases_.clear(); } private: std::string container_; // alias -> full name. absl::flat_hash_map aliases_; }; // Factory function for creating an ExpressionContainer. 
absl::StatusOr MakeExpressionContainer( absl::string_view name); // Factory function for creating an ExpressionContainer with a list of // abbreviations. template absl::StatusOr MakeExpressionContainer( absl::string_view name, Args&&... abbrevs) { ExpressionContainer container; absl::Status status = container.SetContainer(name); if (!status.ok()) { return status; } absl::string_view abbrevs_view[] = {std::forward(abbrevs)...}; for (absl::string_view abrev : abbrevs_view) { status.Update(container.AddAbbreviation(abrev)); if (!status.ok()) { return status; } } return container; } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ ================================================ FILE: common/container_test.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/container.h"

#include "absl/status/status.h"
#include "internal/testing.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::testing::Eq;
using ::testing::IsEmpty;
using ::testing::SizeIs;
using ::testing::UnorderedElementsAre;

TEST(ExpressionContainerTest, DefaultConstructed) {
  ExpressionContainer container;
  EXPECT_THAT(container.container(), IsEmpty());
  EXPECT_THAT(container.FindAlias("foo"), IsEmpty());
}

TEST(ExpressionContainerTest, MakeExpressionContainer) {
  ASSERT_OK_AND_ASSIGN(ExpressionContainer container,
                       MakeExpressionContainer("my.container"));
  EXPECT_THAT(container.container(), Eq("my.container"));

  EXPECT_THAT(MakeExpressionContainer("..invalid"),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

TEST(ExpressionContainerTest, MakeExpressionContainerWithAbbrevs) {
  ASSERT_OK_AND_ASSIGN(
      ExpressionContainer container,
      MakeExpressionContainer("my.container", "pkg.Abbr", "qual.pkg.Abbr2"));
  EXPECT_THAT(container.container(), Eq("my.container"));
  EXPECT_THAT(container.FindAlias("Abbr"), Eq("pkg.Abbr"));
  EXPECT_THAT(container.FindAlias("Abbr2"), Eq("qual.pkg.Abbr2"));

  EXPECT_THAT(MakeExpressionContainer("my.container", "invalid"),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

TEST(ExpressionContainerTest, SetContainer) {
  ExpressionContainer container;
  EXPECT_THAT(container.SetContainer("my.container.name"), IsOk());
  EXPECT_THAT(container.container(), Eq("my.container.name"));

  EXPECT_THAT(container.SetContainer("..invalid"),
              StatusIs(absl::StatusCode::kInvalidArgument));
  EXPECT_THAT(container.SetContainer("foo.1invalid"),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

TEST(ExpressionContainerTest, AddAlias) {
  ASSERT_OK_AND_ASSIGN(ExpressionContainer container,
                       MakeExpressionContainer("my.container"));
  EXPECT_THAT(container.AddAlias("foo", "bar.baz"), IsOk());
  EXPECT_THAT(container.FindAlias("foo"), Eq("bar.baz"));
}

TEST(ExpressionContainerTest, DuplicateAliasRejected) {
  ASSERT_OK_AND_ASSIGN(ExpressionContainer container,
                       MakeExpressionContainer("my.container"));
  EXPECT_THAT(container.AddAlias("foo", "bar.baz"), IsOk());
  EXPECT_THAT(container.AddAlias("foo", "other.name"),
              StatusIs(absl::StatusCode::kInvalidArgument));
  // An abbreviation whose last segment is an existing alias also collides.
  EXPECT_THAT(container.AddAbbreviation("pkg.foo"),
              StatusIs(absl::StatusCode::kInvalidArgument));
  // The original mapping is preserved.
  EXPECT_THAT(container.FindAlias("foo"), Eq("bar.baz"));
}

TEST(ExpressionContainerTest, AddAbbreviation) {
  ASSERT_OK_AND_ASSIGN(ExpressionContainer container,
                       MakeExpressionContainer("my.container"));
  EXPECT_THAT(container.AddAbbreviation("qual.pkg.TypeName"), IsOk());
  EXPECT_THAT(container.FindAlias("TypeName"), Eq("qual.pkg.TypeName"));
}

TEST(ExpressionContainerTest, ListAbbreviationsAndAliases) {
  ASSERT_OK_AND_ASSIGN(ExpressionContainer container,
                       MakeExpressionContainer("my.container"));
  EXPECT_THAT(container.AddAbbreviation("qual.pkg.Abbr"), IsOk());
  EXPECT_THAT(container.AddAlias("AliasSym", "some.long.name"), IsOk());

  EXPECT_THAT(container.ListAbbreviations(),
              UnorderedElementsAre("qual.pkg.Abbr"));

  auto aliases = container.ListAliases();
  EXPECT_THAT(aliases, SizeIs(2));
  for (const auto& listing : aliases) {
    if (listing.alias == "Abbr") {
      EXPECT_EQ(listing.name, "qual.pkg.Abbr");
      EXPECT_TRUE(listing.IsAbbreviation());
    } else {
      EXPECT_EQ(listing.alias, "AliasSym");
      EXPECT_EQ(listing.name, "some.long.name");
      EXPECT_FALSE(listing.IsAbbreviation());
    }
  }
}

TEST(ExpressionContainerTest, InvalidAbbreviation) {
  ASSERT_OK_AND_ASSIGN(ExpressionContainer container,
                       MakeExpressionContainer("my.container"));
  EXPECT_THAT(container.AddAbbreviation(""),
              StatusIs(absl::StatusCode::kInvalidArgument));
  EXPECT_THAT(container.AddAbbreviation("pkg"),
              StatusIs(absl::StatusCode::kInvalidArgument));
  EXPECT_THAT(container.AddAbbreviation(".pkg"),
              StatusIs(absl::StatusCode::kInvalidArgument));
  EXPECT_THAT(container.AddAbbreviation("pkg."),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

TEST(ExpressionContainerTest, InvalidAlias) {
  ASSERT_OK_AND_ASSIGN(ExpressionContainer container,
                       MakeExpressionContainer("my.container"));
  EXPECT_THAT(container.AddAlias("", "bar"),
              StatusIs(absl::StatusCode::kInvalidArgument));
  EXPECT_THAT(container.AddAlias("foo.bar", "baz"),
              StatusIs(absl::StatusCode::kInvalidArgument));
  EXPECT_THAT(container.AddAlias("foo", ".baz"),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

TEST(ExpressionContainerTest, CollidesWithContainer) {
  ASSERT_OK_AND_ASSIGN(ExpressionContainer container,
                       MakeExpressionContainer("my.container"));
  EXPECT_THAT(container.AddAlias("my", "bar"),
              StatusIs(absl::StatusCode::kInvalidArgument));
  // Only leading segments collide; a non-leading segment is fine.
  EXPECT_THAT(container.AddAlias("container", "bar"), IsOk());
}

TEST(ExpressionContainerTest, ContainerCollidesWithAlias) {
  ExpressionContainer container;
  EXPECT_THAT(container.AddAlias("my", "bar"), IsOk());
  EXPECT_THAT(container.SetContainer("my"),
              StatusIs(absl::StatusCode::kInvalidArgument));
  EXPECT_THAT(container.SetContainer("my.container"),
              StatusIs(absl::StatusCode::kInvalidArgument));
  // A shared character prefix that is not a whole segment does not collide.
  EXPECT_THAT(container.SetContainer("myother.container"), IsOk());
}

TEST(ExpressionContainerTest, Clear) {
  ASSERT_OK_AND_ASSIGN(ExpressionContainer container,
                       MakeExpressionContainer("my.container", "pkg.Abbr"));
  container.clear();
  EXPECT_THAT(container.container(), IsEmpty());
  EXPECT_THAT(container.FindAlias("Abbr"), IsEmpty());
  EXPECT_THAT(container.ListAliases(), IsEmpty());
}

}  // namespace
}  // namespace cel

================================================
FILE: common/data.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_DATA_H_
#define THIRD_PARTY_CEL_CPP_COMMON_DATA_H_

#include <cstddef>
#include <cstdint>

#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "common/internal/metadata.h"
#include "google/protobuf/arena.h"

namespace cel {

class Data;

template <typename T>
struct Ownable;

template <typename T>
struct Borrowable;

namespace common_internal {

class ReferenceCount;

void SetDataReferenceCount(const Data* absl_nonnull data,
                           const ReferenceCount* absl_nonnull refcount);

const ReferenceCount* absl_nullable GetDataReferenceCount(
    const Data* absl_nonnull data);

}  // namespace common_internal

// `Data` is one of the base classes of objects that can be managed by
// `MemoryManager`, the other is `google::protobuf::MessageLite`.
//
// The owner is stored as a single tagged pointer in `owner_`: the bits selected
// by `kOwnerBits` say whether the remaining pointer bits refer to an arena or
// to a reference count.
class Data {
 public:
  Data(const Data&) = default;
  Data(Data&&) = default;
  ~Data() = default;
  Data& operator=(const Data&) = default;
  Data& operator=(Data&&) = default;

  // Returns the owning arena, or nullptr if this object is not arena-owned.
  google::protobuf::Arena* absl_nullable GetArena() const {
    return (owner_ & kOwnerBits) == kOwnerArenaBit
               ? reinterpret_cast<google::protobuf::Arena*>(owner_ &
                                                            kOwnerPointerMask)
               : nullptr;
  }

 protected:
  // At this point, the reference count has not been created. So we create it
  // unowned and set the reference count after. In theory we could create the
  // reference count ahead of time and then update it with the data it has to
  // delete, but that is a bit counter intuitive. Doing it this way is also
  // similar to how std::enable_shared_from_this works.
  Data() = default;

  Data(std::nullptr_t) = delete;

  // Tags `arena` as the owner; a null arena leaves the object unowned.
  explicit Data(google::protobuf::Arena* absl_nullable arena)
      : owner_(reinterpret_cast<uintptr_t>(arena) |
               (arena != nullptr ? kOwnerArenaBit : kOwnerNone)) {}

 private:
  static constexpr uintptr_t kOwnerNone = common_internal::kMetadataOwnerNone;
  static constexpr uintptr_t kOwnerReferenceCountBit =
      common_internal::kMetadataOwnerReferenceCountBit;
  static constexpr uintptr_t kOwnerArenaBit =
      common_internal::kMetadataOwnerArenaBit;
  static constexpr uintptr_t kOwnerBits = common_internal::kMetadataOwnerBits;
  static constexpr uintptr_t kOwnerPointerMask =
      common_internal::kMetadataOwnerPointerMask;

  friend void common_internal::SetDataReferenceCount(
      const Data* absl_nonnull data,
      const common_internal::ReferenceCount* absl_nonnull refcount);
  friend const common_internal::ReferenceCount* absl_nullable
  common_internal::GetDataReferenceCount(const Data* absl_nonnull data);
  template <typename T>
  friend struct Ownable;
  template <typename T>
  friend struct Borrowable;

  mutable uintptr_t owner_ = kOwnerNone;
};

namespace common_internal {

// Attaches `refcount` as the owner of `data`. `data` must not already have an
// owner (checked in debug builds).
inline void SetDataReferenceCount(const Data* absl_nonnull data,
                                  const ReferenceCount* absl_nonnull refcount) {
  ABSL_DCHECK_EQ(data->owner_, Data::kOwnerNone);
  data->owner_ =
      reinterpret_cast<uintptr_t>(refcount) | Data::kOwnerReferenceCountBit;
}

// Returns the reference count owning `data`, or nullptr if `data` is not
// reference-count owned.
inline const ReferenceCount* absl_nullable GetDataReferenceCount(
    const Data* absl_nonnull data) {
  return (data->owner_ & Data::kOwnerBits) == Data::kOwnerReferenceCountBit
             ? reinterpret_cast<const ReferenceCount*>(data->owner_ &
                                                       Data::kOwnerPointerMask)
             : nullptr;
}

}  // namespace common_internal

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_DATA_H_

================================================
FILE: common/data_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Tests for `cel::Data` ownership tracking: arena-owned instances and
// instances whose lifetime is managed by a reference count.

#include "common/data.h"

#include "absl/base/nullability.h"
#include "common/internal/reference_count.h"
#include "internal/testing.h"
#include "google/protobuf/arena.h"

namespace cel {
namespace {

using ::testing::IsNull;

class DataTest final : public Data {
 public:
  DataTest() noexcept : Data() {}

  explicit DataTest(google::protobuf::Arena* absl_nullable arena) noexcept
      : Data(arena) {}
};

class DataReferenceCount final : public common_internal::ReferenceCounted {
 public:
  explicit DataReferenceCount(const DataTest* data) : data_(data) {}

 private:
  // `Data` has a non-virtual destructor, so the object must be deleted through
  // its most-derived type; deleting via `const Data*` would be undefined.
  void Finalize() noexcept override { delete data_; }

  const DataTest* data_;
};

TEST(Data, Arena) {
  google::protobuf::Arena arena;
  DataTest data(&arena);
  EXPECT_EQ(data.GetArena(), &arena);
  EXPECT_THAT(common_internal::GetDataReferenceCount(&data), IsNull());
}

TEST(Data, ReferenceCount) {
  auto* data = new DataTest();
  EXPECT_THAT(data->GetArena(), IsNull());
  auto* refcount = new DataReferenceCount(data);
  common_internal::SetDataReferenceCount(data, refcount);
  EXPECT_EQ(common_internal::GetDataReferenceCount(data), refcount);
  common_internal::StrongUnref(refcount);
}

}  // namespace
}  // namespace cel

================================================
FILE: common/decl.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/decl.h"

#include <cstddef>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "common/internal/signature.h"
#include "common/type.h"
#include "common/type_kind.h"

namespace cel {

namespace common_internal {

// A value of type `from` is assignable to `to` when any of these hold:
// - the types are equal;
// - `to` is dyn;
// - `to` is a wrapper type and `from` is null or assignable to the wrapped
//   primitive;
// - both have the same kind and name and pairwise-assignable parameters.
bool TypeIsAssignable(const Type& to, const Type& from) {
  if (to == from) {
    return true;
  }
  const auto to_kind = to.kind();
  if (to_kind == TypeKind::kDyn) {
    return true;
  }
  switch (to_kind) {
    case TypeKind::kBoolWrapper:
      return TypeIsAssignable(NullType{}, from) ||
             TypeIsAssignable(BoolType{}, from);
    case TypeKind::kIntWrapper:
      return TypeIsAssignable(NullType{}, from) ||
             TypeIsAssignable(IntType{}, from);
    case TypeKind::kUintWrapper:
      return TypeIsAssignable(NullType{}, from) ||
             TypeIsAssignable(UintType{}, from);
    case TypeKind::kDoubleWrapper:
      return TypeIsAssignable(NullType{}, from) ||
             TypeIsAssignable(DoubleType{}, from);
    case TypeKind::kBytesWrapper:
      return TypeIsAssignable(NullType{}, from) ||
             TypeIsAssignable(BytesType{}, from);
    case TypeKind::kStringWrapper:
      return TypeIsAssignable(NullType{}, from) ||
             TypeIsAssignable(StringType{}, from);
    default:
      break;
  }
  const auto from_kind = from.kind();
  if (to_kind != from_kind || to.name() != from.name()) {
    return false;
  }
  auto to_params = to.GetParameters();
  auto from_params = from.GetParameters();
  const auto params_size = to_params.size();
  if (params_size != from_params.size()) {
    return false;
  }
  for (size_t i = 0; i < params_size; ++i) {
    if (!TypeIsAssignable(to_params[i], from_params[i])) {
      return false;
    }
  }
  return true;
}

}  // namespace common_internal

namespace {

// Two overloads overlap (and would be ambiguous) if all of these hold:
// - they have the same member-ness and arity;
// - every pair of argument types is assignable in at least one direction.
bool SignaturesOverlap(const OverloadDecl& lhs, const OverloadDecl& rhs) {
  if (lhs.member() != rhs.member()) {
    return false;
  }
  const auto& lhs_args = lhs.args();
  const auto& rhs_args = rhs.args();
  const auto args_size = lhs_args.size();
  if (args_size != rhs_args.size()) {
    return false;
  }
  for (size_t i = 0; i < args_size; ++i) {
    if (!common_internal::TypeIsAssignable(lhs_args[i], rhs_args[i]) &&
        !common_internal::TypeIsAssignable(rhs_args[i], lhs_args[i])) {
      return false;
    }
  }
  return true;
}

// Adds `overload` to `overloads` (keyed by id) and appends a copy to
// `insertion_order`.
//
// - If the overload has no id, one is generated from `function_name` and the
//   signature via `MakeOverloadSignature`.
// - Does nothing if `status` already holds an error.
// - Otherwise sets `status` to AlreadyExists for a duplicate id, or to
//   InvalidArgument when the signature overlaps an existing overload.
template <typename Overload>
void AddOverloadInternal(std::string_view function_name,
                         std::vector<OverloadDecl>& insertion_order,
                         OverloadDeclHashSet& overloads, Overload&& overload,
                         absl::Status& status) {
  if (!status.ok()) {
    return;
  }
  if (overload.id().empty()) {
    // Moves when `overload` is an rvalue; it is not used again on this path.
    OverloadDecl overload_decl = std::forward<Overload>(overload);
    absl::StatusOr<std::string> overload_id =
        common_internal::MakeOverloadSignature(
            function_name, overload_decl.args(), overload_decl.member());
    if (!overload_id.ok()) {
      status = overload_id.status();
      return;
    }
    overload_decl.set_id(std::move(*overload_id));
    AddOverloadInternal(function_name, insertion_order, overloads,
                        std::move(overload_decl), status);
    return;
  }
  if (auto it = overloads.find(overload.id()); it != overloads.end()) {
    status = absl::AlreadyExistsError(
        absl::StrCat("overload already exists: ", overload.id()));
    return;
  }
  for (const auto& existing : overloads) {
    if (SignaturesOverlap(overload, existing)) {
      status = absl::InvalidArgumentError(
          absl::StrCat("overload signature collision: ", existing.id(),
                       " collides with ", overload.id()));
      return;
    }
  }
  const auto inserted = overloads.insert(std::forward<Overload>(overload));
  ABSL_DCHECK(inserted.second);
  insertion_order.push_back(*inserted.first);
}

// Recursively adds the names of all type parameters referenced by `type`
// (including through list, map, opaque and function types) to `type_params`.
void CollectTypeParams(absl::flat_hash_set<std::string>& type_params,
                       const Type& type) {
  const auto kind = type.kind();
  switch (kind) {
    case TypeKind::kList: {
      const auto& list_type = type.GetList();
      CollectTypeParams(type_params, list_type.element());
    } break;
    case TypeKind::kMap: {
      const auto& map_type = type.GetMap();
      CollectTypeParams(type_params, map_type.key());
      CollectTypeParams(type_params, map_type.value());
    } break;
    case TypeKind::kOpaque: {
      const auto& opaque_type = type.GetOpaque();
      for (const auto& param : opaque_type.GetParameters()) {
        CollectTypeParams(type_params, param);
      }
    } break;
    case TypeKind::kFunction: {
      const auto& function_type = type.GetFunction();
      CollectTypeParams(type_params, function_type.result());
      for (const auto& arg : function_type.args()) {
        CollectTypeParams(type_params, arg);
      }
    } break;
    case TypeKind::kTypeParam:
      type_params.emplace(type.GetTypeParam().name());
      break;
    default:
      break;
  }
}

}  // namespace

absl::flat_hash_set<std::string> OverloadDecl::GetTypeParams() const {
  absl::flat_hash_set<std::string> type_params;
  CollectTypeParams(type_params, result());
  for (const auto& arg : args()) {
    CollectTypeParams(type_params, arg);
  }
  return type_params;
}

void FunctionDecl::AddOverloadImpl(const OverloadDecl& overload,
                                   absl::Status& status) {
  AddOverloadInternal(name_, overloads_.insertion_order, overloads_.set,
                      overload, status);
}

void FunctionDecl::AddOverloadImpl(OverloadDecl&& overload,
                                   absl::Status& status) {
  AddOverloadInternal(name_, overloads_.insertion_order, overloads_.set,
                      std::move(overload), status);
}

}  // namespace cel

================================================
FILE: common/decl.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_DECL_H_ #define THIRD_PARTY_CEL_CPP_COMMON_DECL_H_ #include #include #include #include #include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/container/flat_hash_set.h" #include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "common/constant.h" #include "common/type.h" #include "internal/status_macros.h" namespace cel { class VariableDecl; class OverloadDecl; class FunctionDecl; // `VariableDecl` represents a declaration of a variable, composed of its name // and type, and optionally a constant value. class VariableDecl final { public: VariableDecl() = default; VariableDecl(const VariableDecl&) = default; VariableDecl(VariableDecl&&) = default; VariableDecl& operator=(const VariableDecl&) = default; VariableDecl& operator=(VariableDecl&&) = default; ABSL_MUST_USE_RESULT const std::string& name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return name_; } void set_name(std::string name) { name_ = std::move(name); } void set_name(absl::string_view name) { name_.assign(name.data(), name.size()); } void set_name(const char* name) { set_name(absl::NullSafeStringView(name)); } ABSL_MUST_USE_RESULT std::string release_name() { std::string released; released.swap(name_); return released; } ABSL_MUST_USE_RESULT const Type& type() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return type_; } ABSL_MUST_USE_RESULT Type& mutable_type() ABSL_ATTRIBUTE_LIFETIME_BOUND { return type_; } void set_type(Type type) { mutable_type() = std::move(type); } ABSL_MUST_USE_RESULT bool has_value() const { return value_.has_value(); } ABSL_MUST_USE_RESULT const Constant& value() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return has_value() ? 
*value_ : Constant::default_instance(); } Constant& mutable_value() ABSL_ATTRIBUTE_LIFETIME_BOUND { if (!has_value()) { value_.emplace(); } return *value_; } void set_value(absl::optional value) { value_ = std::move(value); } void set_value(Constant value) { mutable_value() = std::move(value); } ABSL_MUST_USE_RESULT Constant release_value() { absl::optional released; released.swap(value_); return std::move(released).value_or(Constant{}); } private: std::string name_; Type type_ = DynType{}; absl::optional value_; }; inline VariableDecl MakeVariableDecl(absl::string_view name, Type type) { VariableDecl variable_decl; variable_decl.set_name(std::string(name)); variable_decl.set_type(std::move(type)); return variable_decl; } inline VariableDecl MakeConstantVariableDecl(std::string name, Type type, Constant value) { VariableDecl variable_decl; variable_decl.set_name(std::move(name)); variable_decl.set_type(std::move(type)); variable_decl.set_value(std::move(value)); return variable_decl; } inline bool operator==(const VariableDecl& lhs, const VariableDecl& rhs) { return lhs.name() == rhs.name() && lhs.type() == rhs.type() && lhs.has_value() == rhs.has_value() && lhs.value() == rhs.value(); } inline bool operator!=(const VariableDecl& lhs, const VariableDecl& rhs) { return !operator==(lhs, rhs); } // `OverloadDecl` represents a single overload of `FunctionDecl`. 
class OverloadDecl final { public: OverloadDecl() = default; OverloadDecl(const OverloadDecl&) = default; OverloadDecl(OverloadDecl&&) = default; OverloadDecl& operator=(const OverloadDecl&) = default; OverloadDecl& operator=(OverloadDecl&&) = default; ABSL_MUST_USE_RESULT const std::string& id() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return id_; } void set_id(std::string id) { id_ = std::move(id); } void set_id(absl::string_view id) { id_.assign(id.data(), id.size()); } void set_id(const char* id) { set_id(absl::NullSafeStringView(id)); } ABSL_MUST_USE_RESULT std::string release_id() { std::string released; released.swap(id_); return released; } ABSL_MUST_USE_RESULT const std::vector& args() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return args_; } ABSL_MUST_USE_RESULT std::vector& mutable_args() ABSL_ATTRIBUTE_LIFETIME_BOUND { return args_; } ABSL_MUST_USE_RESULT std::vector release_args() { std::vector released; released.swap(mutable_args()); return released; } ABSL_MUST_USE_RESULT const Type& result() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return result_; } ABSL_MUST_USE_RESULT Type& mutable_result() ABSL_ATTRIBUTE_LIFETIME_BOUND { return result_; } void set_result(Type result) { mutable_result() = std::move(result); } ABSL_MUST_USE_RESULT bool member() const { return member_; } void set_member(bool member) { member_ = member; } absl::flat_hash_set GetTypeParams() const; private: std::string id_; std::vector args_; Type result_ = DynType{}; bool member_ = false; }; inline bool operator==(const OverloadDecl& lhs, const OverloadDecl& rhs) { return lhs.id() == rhs.id() && absl::c_equal(lhs.args(), rhs.args()) && lhs.result() == rhs.result() && lhs.member() == rhs.member(); } inline bool operator!=(const OverloadDecl& lhs, const OverloadDecl& rhs) { return !operator==(lhs, rhs); } template OverloadDecl MakeOverloadDecl(Type result, Args&&... 
args) { OverloadDecl overload_decl; overload_decl.set_result(std::move(result)); overload_decl.set_member(false); auto& mutable_args = overload_decl.mutable_args(); mutable_args.reserve(sizeof...(Args)); (mutable_args.push_back(std::forward(args)), ...); return overload_decl; } // Prefer the version of `MakeOverloadDecl` that does not specify the id. // This version is less robust than the version that automatically generates a // descriptive overload id at the time the overload is added to the function // declaration. template OverloadDecl MakeOverloadDecl(absl::string_view id, Type result, Args&&... args) { OverloadDecl overload_decl; overload_decl.set_id(std::string(id)); overload_decl.set_result(std::move(result)); overload_decl.set_member(false); auto& mutable_args = overload_decl.mutable_args(); mutable_args.reserve(sizeof...(Args)); (mutable_args.push_back(std::forward(args)), ...); return overload_decl; } template OverloadDecl MakeMemberOverloadDecl(Type result, Args&&... args) { OverloadDecl overload_decl; overload_decl.set_result(std::move(result)); overload_decl.set_member(true); auto& mutable_args = overload_decl.mutable_args(); mutable_args.reserve(sizeof...(Args)); (mutable_args.push_back(std::forward(args)), ...); return overload_decl; } // Avoid this version of `MakeMemberOverloadDecl`, it is less robust than the // version that automatically generates a descriptive overload id at the time // the overload is added to the function declaration. template OverloadDecl MakeMemberOverloadDecl(absl::string_view id, Type result, Args&&... 
args) { OverloadDecl overload_decl; overload_decl.set_id(std::string(id)); overload_decl.set_result(std::move(result)); overload_decl.set_member(true); auto& mutable_args = overload_decl.mutable_args(); mutable_args.reserve(sizeof...(Args)); (mutable_args.push_back(std::forward(args)), ...); return overload_decl; } struct OverloadDeclHash { using is_transparent = void; size_t operator()(const OverloadDecl& overload_decl) const { return (*this)(overload_decl.id()); } size_t operator()(absl::string_view id) const { return absl::HashOf(id); } }; struct OverloadDeclEqualTo { using is_transparent = void; bool operator()(const OverloadDecl& lhs, const OverloadDecl& rhs) const { return (*this)(lhs.id(), rhs.id()); } bool operator()(const OverloadDecl& lhs, absl::string_view rhs) const { return (*this)(lhs.id(), rhs); } bool operator()(absl::string_view lhs, const OverloadDecl& rhs) const { return (*this)(lhs, rhs.id()); } bool operator()(absl::string_view lhs, absl::string_view rhs) const { return lhs == rhs; } }; using OverloadDeclHashSet = absl::flat_hash_set; template absl::StatusOr MakeFunctionDecl(std::string name, Overloads&&... overloads); // `FunctionDecl` represents a function declaration. 
class FunctionDecl final { public: FunctionDecl() = default; FunctionDecl(const FunctionDecl&) = default; FunctionDecl(FunctionDecl&&) = default; FunctionDecl& operator=(const FunctionDecl&) = default; FunctionDecl& operator=(FunctionDecl&&) = default; ABSL_MUST_USE_RESULT const std::string& name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return name_; } void set_name(std::string name) { name_ = std::move(name); } void set_name(absl::string_view name) { name_.assign(name.data(), name.size()); } /* A null pointer is treated as an empty name (via absl::NullSafeStringView). */ void set_name(const char* name) { set_name(absl::NullSafeStringView(name)); } /* Returns the current name and leaves this declaration's name empty. */ ABSL_MUST_USE_RESULT std::string release_name() { std::string released; released.swap(name_); return released; } /* Adds `overload`; any error detected by AddOverloadImpl is returned as the status. */ absl::Status AddOverload(const OverloadDecl& overload) { absl::Status status; AddOverloadImpl(overload, status); return status; } absl::Status AddOverload(OverloadDecl&& overload) { absl::Status status; AddOverloadImpl(std::move(overload), status); return status; } /* Overloads in the order they were added. The span refers to storage owned by this object (ABSL_ATTRIBUTE_LIFETIME_BOUND); presumably it is invalidated when overloads are added or released. */ ABSL_MUST_USE_RESULT absl::Span overloads() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return overloads_.insertion_order; } /* Moves all overloads out in insertion order and clears both the ordered list and the id set, leaving this declaration with no overloads. */ std::vector release_overloads() { std::vector released = std::move(overloads_.insertion_order); overloads_.insertion_order.clear(); overloads_.set.clear(); return released; } private: /* Overloads are stored twice: in insertion_order for iteration and equality, and in an id-keyed OverloadDeclHashSet (presumably consulted by AddOverloadImpl to detect duplicate ids; TODO confirm). Reserve sizes both containers. */ struct Overloads { std::vector insertion_order; OverloadDeclHashSet set; void Reserve(size_t size) { insertion_order.reserve(size); set.reserve(size); } }; template friend absl::StatusOr MakeFunctionDecl( std::string name, Overloads&&...
overloads); /* Shared implementation behind both AddOverload overloads and MakeFunctionDecl; reports failure by writing to `status`. */ void AddOverloadImpl(const OverloadDecl& overload, absl::Status& status); void AddOverloadImpl(OverloadDecl&& overload, absl::Status& status); std::string name_; Overloads overloads_; }; /* Equal when the names match and the overloads are pairwise equal in insertion order. */ inline bool operator==(const FunctionDecl& lhs, const FunctionDecl& rhs) { return lhs.name() == rhs.name() && absl::c_equal(lhs.overloads(), rhs.overloads()); } inline bool operator!=(const FunctionDecl& lhs, const FunctionDecl& rhs) { return !operator==(lhs, rhs); } /* Builds a FunctionDecl named `name` from a pack of overloads. Every overload is passed to AddOverloadImpl with one shared status, which is checked only after the whole pack has been processed; on error that status is returned instead of a declaration. Whether later overloads are skipped once `status` holds an error depends on AddOverloadImpl, which is not visible here. */ template absl::StatusOr MakeFunctionDecl(std::string name, Overloads&&... overloads) { FunctionDecl function_decl; function_decl.set_name(std::move(name)); function_decl.overloads_.Reserve(sizeof...(Overloads)); absl::Status status; (function_decl.AddOverloadImpl(std::forward(overloads), status), ...); CEL_RETURN_IF_ERROR(status); return function_decl; } namespace common_internal { // Checks whether `from` is assignable to `to`. // This can probably be in a better place, it is here currently to ease testing. bool TypeIsAssignable(const Type& to, const Type& from); } // namespace common_internal } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_DECL_H_ ================================================ FILE: common/decl_proto.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "common/decl_proto.h" #include #include #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "common/decl.h" #include "common/type.h" #include "common/type_proto.h" #include "internal/status_macros.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { absl::StatusOr VariableDeclFromProto( absl::string_view name, const cel::expr::Decl::IdentDecl& variable, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena) { CEL_ASSIGN_OR_RETURN(Type type, TypeFromProto(variable.type(), descriptor_pool, arena)); return cel::MakeVariableDecl(std::string(name), type); } absl::StatusOr FunctionDeclFromProto( absl::string_view name, const cel::expr::Decl::FunctionDecl& function, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena) { cel::FunctionDecl decl; decl.set_name(name); for (const auto& overload_pb : function.overloads()) { cel::OverloadDecl ovl_decl; ovl_decl.set_id(overload_pb.overload_id()); ovl_decl.set_member(overload_pb.is_instance_function()); CEL_ASSIGN_OR_RETURN( cel::Type result, TypeFromProto(overload_pb.result_type(), descriptor_pool, arena)); ovl_decl.set_result(result); std::vector param_types; param_types.reserve(overload_pb.params_size()); for (const auto& param_type_pb : overload_pb.params()) { CEL_ASSIGN_OR_RETURN( param_types.emplace_back(), TypeFromProto(param_type_pb, descriptor_pool, arena)); } ovl_decl.mutable_args() = std::move(param_types); CEL_RETURN_IF_ERROR(decl.AddOverload(std::move(ovl_decl))); } return decl; } absl::StatusOr> DeclFromProto( const cel::expr::Decl& decl, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena) { if (decl.has_ident()) { return 
VariableDeclFromProto(decl.name(), decl.ident(), descriptor_pool, arena); } else if (decl.has_function()) { return FunctionDeclFromProto(decl.name(), decl.function(), descriptor_pool, arena); } return absl::InvalidArgumentError("empty google.api.expr.Decl proto"); } } // namespace cel ================================================ FILE: common/decl_proto.h ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_DECL_PROTO_H_ #define THIRD_PARTY_CEL_CPP_COMMON_DECL_PROTO_H_ #include "cel/expr/checked.pb.h" #include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "common/decl.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { // Creates a VariableDecl from a google.api.expr.Decl.IdentDecl proto. absl::StatusOr VariableDeclFromProto( absl::string_view name, const cel::expr::Decl::IdentDecl& variable, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena); // Creates a FunctionDecl from a google.api.expr.Decl.FunctionDecl proto. 
/* Fails if a result or parameter type cannot be converted, or if FunctionDecl::AddOverload rejects one of the overloads. */ absl::StatusOr FunctionDeclFromProto( absl::string_view name, const cel::expr::Decl::FunctionDecl& function, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena); // Creates a VariableDecl or FunctionDecl from a cel.expr.Decl proto. /* Returns InvalidArgument if the proto sets neither ident nor function. */ absl::StatusOr> DeclFromProto( const cel::expr::Decl& decl, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_DECL_PROTO_H_ ================================================ FILE: common/decl_proto_test.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "common/decl_proto.h" #include #include "google/api/expr/v1alpha1/checked.pb.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/types/variant.h" #include "common/decl.h" #include "common/decl_proto_v1alpha1.h" #include "internal/testing.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/text_format.h" namespace cel { namespace { using ::absl_testing::StatusIs; /* Expected outcome of converting a test case's Decl proto. */ enum class DeclType { kVariable, kFunction, kInvalid }; /* A text-format Decl proto paired with the outcome the converters should produce for it. */ struct TestCase { std::string proto_decl; DeclType decl_type; }; class DeclFromProtoTest : public ::testing::TestWithParam {}; /* Parses proto_decl as a cel.expr.Decl and checks that DeclFromProto yields the expected variant alternative, or InvalidArgument for kInvalid cases. */ TEST_P(DeclFromProtoTest, FromProtoWorks) { const TestCase& test_case = GetParam(); google::protobuf::Arena arena; const google::protobuf::DescriptorPool* descriptor_pool = google::protobuf::DescriptorPool::generated_pool(); cel::expr::Decl decl_pb; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(test_case.proto_decl, &decl_pb)); absl::StatusOr> decl_or = DeclFromProto(decl_pb, descriptor_pool, &arena); switch (test_case.decl_type) { case DeclType::kVariable: { ASSERT_OK_AND_ASSIGN(auto decl, decl_or); EXPECT_TRUE(absl::holds_alternative(decl)); break; } case DeclType::kFunction: { ASSERT_OK_AND_ASSIGN(auto decl, decl_or); EXPECT_TRUE(absl::holds_alternative(decl)); break; } case DeclType::kInvalid: { EXPECT_THAT(decl_or, StatusIs(absl::StatusCode::kInvalidArgument)); break; } } } // Tests that the same cases pass through DeclFromV1Alpha1Proto, which converts the v1alpha1 proto to the unversioned proto. // Both paths share the same underlying implementation.
TEST_P(DeclFromProtoTest, FromV1Alpha1ProtoWorks) { const TestCase& test_case = GetParam(); google::protobuf::Arena arena; const google::protobuf::DescriptorPool* descriptor_pool = google::protobuf::DescriptorPool::generated_pool(); google::api::expr::v1alpha1::Decl decl_pb; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(test_case.proto_decl, &decl_pb)); absl::StatusOr> decl_or = DeclFromV1Alpha1Proto(decl_pb, descriptor_pool, &arena); switch (test_case.decl_type) { case DeclType::kVariable: { ASSERT_OK_AND_ASSIGN(auto decl, decl_or); EXPECT_TRUE(absl::holds_alternative(decl)); break; } case DeclType::kFunction: { ASSERT_OK_AND_ASSIGN(auto decl, decl_or); EXPECT_TRUE(absl::holds_alternative(decl)); break; } case DeclType::kInvalid: { EXPECT_THAT(decl_or, StatusIs(absl::StatusCode::kInvalidArgument)); break; } } } // TODO(uncreated-issue/80): Add tests for round-trip conversion after the ToProto // functions are implemented. /* Cases: a variable; a function with a global, a member and a type-parameterized overload; a variable whose message type cannot be resolved (invalid); and a Decl with neither ident nor function (invalid). */ INSTANTIATE_TEST_SUITE_P( DeclFromProtoTest, DeclFromProtoTest, testing::Values( TestCase{ R"pb( name: "foo_var" ident { type { primitive: BOOL } })pb", DeclType::kVariable}, TestCase{ R"pb( name: "foo_fn" function { overloads { overload_id: "foo_fn_int" params { primitive: INT64 } result_type { primitive: BOOL } } overloads { overload_id: "int_foo_fn" is_instance_function: true params { primitive: INT64 } result_type { primitive: BOOL } } overloads { overload_id: "foo_fn_T" params { type_param: "T" } type_params: "T" result_type { primitive: BOOL } } })pb", DeclType::kFunction}, // Need a descriptor to look up a struct type. TestCase{ R"pb( name: "foo_fn" ident { type { message_type: "com.example.UnknownType" } })pb", DeclType::kInvalid}, // Empty decl is invalid.
TestCase{R"pb(name: "foo_fn")pb", DeclType::kInvalid})); } // namespace } // namespace cel ================================================ FILE: common/decl_proto_v1alpha1.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/decl_proto_v1alpha1.h" #include "cel/expr/checked.pb.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "common/decl.h" #include "common/decl_proto.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { /* Each converter below serializes the v1alpha1 message and re-parses the bytes into the unversioned cel.expr message, then delegates to the matching unversioned converter. This assumes the two schemas are wire-compatible (same field numbers and types); a re-parse failure is reported as InternalError. */ absl::StatusOr VariableDeclFromV1Alpha1Proto( absl::string_view name, const google::api::expr::v1alpha1::Decl::IdentDecl& variable, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena) { cel::expr::Decl::IdentDecl unversioned; if (!unversioned.MergeFromString(variable.SerializeAsString())) { return absl::InternalError( "failed to convert versioned to unversioned Decl proto"); } return VariableDeclFromProto(name, unversioned, descriptor_pool, arena); } absl::StatusOr FunctionDeclFromV1Alpha1Proto( absl::string_view name, const google::api::expr::v1alpha1::Decl::FunctionDecl& function, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena) { cel::expr::Decl::FunctionDecl
unversioned; if (!unversioned.MergeFromString(function.SerializeAsString())) { return absl::InternalError( "failed to convert versioned to unversioned Decl proto"); } return FunctionDeclFromProto(name, unversioned, descriptor_pool, arena); } absl::StatusOr> DeclFromV1Alpha1Proto( const google::api::expr::v1alpha1::Decl& decl, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena) { cel::expr::Decl unversioned; if (!unversioned.MergeFromString(decl.SerializeAsString())) { return absl::InternalError( "failed to convert versioned to unversioned Decl proto"); } return DeclFromProto(unversioned, descriptor_pool, arena); } } // namespace cel ================================================ FILE: common/decl_proto_v1alpha1.h ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Converters from versioned (v1alpha1) Decl protos to the equivalent CEL C++ types.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_DECL_PROTO_V1ALPHA1_H_ #define THIRD_PARTY_CEL_CPP_COMMON_DECL_PROTO_V1ALPHA1_H_ #include "google/api/expr/v1alpha1/checked.pb.h" #include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "common/decl.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { /* Each function below re-encodes its v1alpha1 argument as the unversioned cel.expr message and delegates to the corresponding function declared in common/decl_proto.h. A failure to re-parse the message is reported as InternalError; other errors come from the unversioned converter. */ // Creates a VariableDecl from a google.api.expr.v1alpha1.Decl.IdentDecl proto. absl::StatusOr VariableDeclFromV1Alpha1Proto( absl::string_view name, const google::api::expr::v1alpha1::Decl::IdentDecl& variable, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena); // Creates a FunctionDecl from a google.api.expr.v1alpha1.Decl.FunctionDecl // proto. absl::StatusOr FunctionDeclFromV1Alpha1Proto( absl::string_view name, const google::api::expr::v1alpha1::Decl::FunctionDecl& function, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena); // Creates a VariableDecl or FunctionDecl from a google.api.expr.v1alpha1.Decl // proto. absl::StatusOr> DeclFromV1Alpha1Proto( const google::api::expr::v1alpha1::Decl& decl, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_DECL_PROTO_V1ALPHA1_H_ ================================================ FILE: common/decl_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/decl.h" #include #include #include "absl/log/die_if_null.h" // IWYU pragma: keep #include "absl/status/status.h" #include "common/constant.h" #include "common/type.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "google/protobuf/arena.h" namespace cel { namespace { using ::absl_testing::StatusIs; using ::cel::internal::GetTestingDescriptorPool; using ::testing::ElementsAre; using ::testing::IsEmpty; using ::testing::Property; using ::testing::UnorderedElementsAre; TEST(VariableDecl, Name) { VariableDecl variable_decl; EXPECT_THAT(variable_decl.name(), IsEmpty()); variable_decl.set_name("foo"); EXPECT_EQ(variable_decl.name(), "foo"); EXPECT_EQ(variable_decl.release_name(), "foo"); EXPECT_THAT(variable_decl.name(), IsEmpty()); } TEST(VariableDecl, Type) { VariableDecl variable_decl; EXPECT_EQ(variable_decl.type(), DynType{}); variable_decl.set_type(StringType{}); EXPECT_EQ(variable_decl.type(), StringType{}); } TEST(VariableDecl, Value) { VariableDecl variable_decl; EXPECT_FALSE(variable_decl.has_value()); EXPECT_EQ(variable_decl.value(), Constant{}); Constant value; value.set_bool_value(true); variable_decl.set_value(value); EXPECT_TRUE(variable_decl.has_value()); EXPECT_EQ(variable_decl.value(), value); EXPECT_EQ(variable_decl.release_value(), value); EXPECT_EQ(variable_decl.value(), Constant{}); } /* Test helper: a Constant holding the given bool. */ Constant MakeBoolConstant(bool value) { Constant constant; constant.set_bool_value(value); return constant; } TEST(VariableDecl, Equality) { VariableDecl variable_decl; EXPECT_EQ(variable_decl,
VariableDecl{}); variable_decl.mutable_value().set_bool_value(true); EXPECT_NE(variable_decl, VariableDecl{}); /* NOTE(review): the two MakeVariableDecl comparisons below are textually identical, as are the two MakeConstantVariableDecl comparisons. If they were meant to exercise different argument forms, that difference is not visible here; confirm. */ EXPECT_EQ(MakeVariableDecl("foo", StringType{}), MakeVariableDecl("foo", StringType{})); EXPECT_EQ(MakeVariableDecl("foo", StringType{}), MakeVariableDecl("foo", StringType{})); EXPECT_EQ( MakeConstantVariableDecl("foo", StringType{}, MakeBoolConstant(true)), MakeConstantVariableDecl("foo", StringType{}, MakeBoolConstant(true))); EXPECT_EQ( MakeConstantVariableDecl("foo", StringType{}, MakeBoolConstant(true)), MakeConstantVariableDecl("foo", StringType{}, MakeBoolConstant(true))); } TEST(OverloadDecl, Id) { OverloadDecl overload_decl; EXPECT_THAT(overload_decl.id(), IsEmpty()); overload_decl.set_id("foo"); EXPECT_EQ(overload_decl.id(), "foo"); EXPECT_EQ(overload_decl.release_id(), "foo"); EXPECT_THAT(overload_decl.id(), IsEmpty()); } TEST(OverloadDecl, Result) { OverloadDecl overload_decl; EXPECT_EQ(overload_decl.result(), DynType{}); overload_decl.set_result(StringType{}); EXPECT_EQ(overload_decl.result(), StringType{}); } TEST(OverloadDecl, Args) { OverloadDecl overload_decl; EXPECT_THAT(overload_decl.args(), IsEmpty()); overload_decl.mutable_args().push_back(StringType{}); EXPECT_THAT(overload_decl.args(), ElementsAre(StringType{})); EXPECT_THAT(overload_decl.release_args(), ElementsAre(StringType{})); EXPECT_THAT(overload_decl.args(), IsEmpty()); } TEST(OverloadDecl, Member) { OverloadDecl overload_decl; EXPECT_FALSE(overload_decl.member()); overload_decl.set_member(true); EXPECT_TRUE(overload_decl.member()); } TEST(OverloadDecl, Equality) { OverloadDecl overload_decl; EXPECT_EQ(overload_decl, OverloadDecl{}); overload_decl.set_member(true); EXPECT_NE(overload_decl, OverloadDecl{}); } /* Type params are collected from type parameters nested inside list, map, opaque and function types. */ TEST(OverloadDecl, GetTypeParams) { google::protobuf::Arena arena; auto overload_decl = MakeOverloadDecl( "foo", ListType(&arena, TypeParamType("A")), MapType(&arena, TypeParamType("B"), TypeParamType("C")), OpaqueType(&arena, "bar", {FunctionType(&arena,
TypeParamType("D"), {})})); EXPECT_THAT(overload_decl.GetTypeParams(), UnorderedElementsAre("A", "B", "C", "D")); } TEST(FunctionDecl, Name) { FunctionDecl function_decl; EXPECT_THAT(function_decl.name(), IsEmpty()); function_decl.set_name("foo"); EXPECT_EQ(function_decl.name(), "foo"); EXPECT_EQ(function_decl.release_name(), "foo"); EXPECT_THAT(function_decl.name(), IsEmpty()); } /* Overloads keep insertion order. The final AddOverload is rejected with InvalidArgument even though the id "qux" is new; presumably this is because its argument signature overlaps the existing "foo" overload (confirm against AddOverloadImpl). */ TEST(FunctionDecl, Overloads) { ASSERT_OK_AND_ASSIGN( auto function_decl, MakeFunctionDecl( "hello", MakeOverloadDecl("foo", StringType{}, StringType{}), MakeMemberOverloadDecl("bar", StringType{}, StringType{}), MakeOverloadDecl("baz", IntType{}, IntType{}))); EXPECT_THAT(function_decl.overloads(), ElementsAre(Property(&OverloadDecl::id, "foo"), Property(&OverloadDecl::id, "bar"), Property(&OverloadDecl::id, "baz"))); EXPECT_THAT(function_decl.AddOverload( MakeOverloadDecl("qux", DynType{}, StringType{})), StatusIs(absl::StatusCode::kInvalidArgument)); } /* Overloads added without an explicit id get one generated from the function name and argument types; for member overloads the first argument type becomes the receiver prefix (e.g. "string.hello()"). */ TEST(FunctionDecl, OverloadId) { google::protobuf::Arena arena; const auto* descriptor = ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindMessageTypeByName( "cel.expr.conformance.proto3.TestAllTypes")); ASSERT_OK_AND_ASSIGN( auto function_decl, MakeFunctionDecl( "hello", MakeOverloadDecl(DoubleType{}), MakeOverloadDecl(StringType{}, StringType{}), MakeOverloadDecl(IntType{}, IntType{}, UintType{}), MakeOverloadDecl(IntType{}, ListType(&arena, TypeParamType("A"))), MakeOverloadDecl(IntType{}, MapType(&arena, TypeParamType("B"), TypeParamType("C"))), MakeOverloadDecl( IntType{}, OpaqueType(&arena, "bar", {FunctionType(&arena, TypeParamType("D"), {})})), MakeOverloadDecl(IntType{}, AnyType{}), MakeOverloadDecl(IntType{}, DurationType{}), MakeOverloadDecl(IntType{}, TimestampType{}), MakeOverloadDecl(IntType{}, IntWrapperType{}), MakeOverloadDecl(IntType{}, MessageType(descriptor)), MakeMemberOverloadDecl(StringType{}, StringType{}), MakeMemberOverloadDecl(StringType{}, StringType{}, ListType(&arena, BoolType{})),
MakeMemberOverloadDecl(StringType{}, StringType{}, BoolType{}, DynType{}))); EXPECT_THAT( function_decl.overloads(), ElementsAre(Property(&OverloadDecl::id, "hello()"), Property(&OverloadDecl::id, "hello(string)"), Property(&OverloadDecl::id, "hello(int,uint)"), Property(&OverloadDecl::id, "hello(list<~A>)"), Property(&OverloadDecl::id, "hello(map<~B,~C>)"), Property(&OverloadDecl::id, "hello(bar>)"), Property(&OverloadDecl::id, "hello(any)"), Property(&OverloadDecl::id, "hello(duration)"), Property(&OverloadDecl::id, "hello(timestamp)"), Property(&OverloadDecl::id, "hello(int_wrapper)"), Property(&OverloadDecl::id, "hello(cel.expr.conformance.proto3.TestAllTypes)"), Property(&OverloadDecl::id, "string.hello()"), Property(&OverloadDecl::id, "string.hello(list)"), Property(&OverloadDecl::id, "string.hello(bool,dyn)"))); } /* Each wrapper type accepts itself, null, and its corresponding primitive type, and rejects an unrelated type (duration). */ using common_internal::TypeIsAssignable; TEST(TypeIsAssignable, BoolWrapper) { EXPECT_TRUE(TypeIsAssignable(BoolWrapperType{}, BoolWrapperType{})); EXPECT_TRUE(TypeIsAssignable(BoolWrapperType{}, NullType{})); EXPECT_TRUE(TypeIsAssignable(BoolWrapperType{}, BoolType{})); EXPECT_FALSE(TypeIsAssignable(BoolWrapperType{}, DurationType{})); } TEST(TypeIsAssignable, IntWrapper) { EXPECT_TRUE(TypeIsAssignable(IntWrapperType{}, IntWrapperType{})); EXPECT_TRUE(TypeIsAssignable(IntWrapperType{}, NullType{})); EXPECT_TRUE(TypeIsAssignable(IntWrapperType{}, IntType{})); EXPECT_FALSE(TypeIsAssignable(IntWrapperType{}, DurationType{})); } TEST(TypeIsAssignable, UintWrapper) { EXPECT_TRUE(TypeIsAssignable(UintWrapperType{}, UintWrapperType{})); EXPECT_TRUE(TypeIsAssignable(UintWrapperType{}, NullType{})); EXPECT_TRUE(TypeIsAssignable(UintWrapperType{}, UintType{})); EXPECT_FALSE(TypeIsAssignable(UintWrapperType{}, DurationType{})); } TEST(TypeIsAssignable, DoubleWrapper) { EXPECT_TRUE(TypeIsAssignable(DoubleWrapperType{}, DoubleWrapperType{})); EXPECT_TRUE(TypeIsAssignable(DoubleWrapperType{}, NullType{})); EXPECT_TRUE(TypeIsAssignable(DoubleWrapperType{},
DoubleType{})); EXPECT_FALSE(TypeIsAssignable(DoubleWrapperType{}, DurationType{})); } TEST(TypeIsAssignable, BytesWrapper) { EXPECT_TRUE(TypeIsAssignable(BytesWrapperType{}, BytesWrapperType{})); EXPECT_TRUE(TypeIsAssignable(BytesWrapperType{}, NullType{})); EXPECT_TRUE(TypeIsAssignable(BytesWrapperType{}, BytesType{})); EXPECT_FALSE(TypeIsAssignable(BytesWrapperType{}, DurationType{})); } TEST(TypeIsAssignable, StringWrapper) { EXPECT_TRUE(TypeIsAssignable(StringWrapperType{}, StringWrapperType{})); EXPECT_TRUE(TypeIsAssignable(StringWrapperType{}, NullType{})); EXPECT_TRUE(TypeIsAssignable(StringWrapperType{}, StringType{})); EXPECT_FALSE(TypeIsAssignable(StringWrapperType{}, DurationType{})); } /* Optional types: an optional of string is assignable to an optional of dyn, but not to an optional of bool. */ TEST(TypeIsAssignable, Complex) { google::protobuf::Arena arena; EXPECT_TRUE(TypeIsAssignable(OptionalType(&arena, DynType{}), OptionalType(&arena, StringType{}))); EXPECT_FALSE(TypeIsAssignable(OptionalType(&arena, BoolType{}), OptionalType(&arena, StringType{}))); } } // namespace } // namespace cel ================================================ FILE: common/expr.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "common/expr.h"

#include <cstddef>
#include <utility>
#include <vector>

#include "absl/base/no_destructor.h"
#include "absl/functional/overload.h"
#include "absl/log/absl_check.h"
#include "absl/types/variant.h"
#include "common/constant.h"

namespace cel {

namespace {

// One unit of pending work for `CloneImpl`: copy the node `src` into `dst`.
struct CopyStackRecord {
  const Expr* src;
  Expr* dst;
};

// Copies the id and the node-local payload of `element.src` into
// `element.dst`.
//
// Child expressions are not copied recursively. For every child that is
// present in `src`, a record is pushed onto `stack` for the caller to process
// later. Children that are absent in `src` (checked with the `has_*()`
// accessors) are left absent in `dst`. This mirrors the source's presence
// state instead of materializing empty children.
void CopyNode(CopyStackRecord element, std::vector<CopyStackRecord>& stack) {
  const Expr* src = element.src;
  Expr* dst = element.dst;
  dst->set_id(src->id());
  absl::visit(
      absl::Overload(
          [=](const UnspecifiedExpr&) {
            dst->mutable_kind().emplace<UnspecifiedExpr>();
          },
          [=](const IdentExpr& i) {
            dst->mutable_ident_expr().set_name(i.name());
          },
          [=](const Constant& c) { dst->mutable_const_expr() = c; },
          [&](const SelectExpr& s) {
            SelectExpr& dst_select = dst->mutable_select_expr();
            dst_select.set_field(s.field());
            dst_select.set_test_only(s.test_only());
            if (s.has_operand()) {
              stack.push_back({&s.operand(), &dst_select.mutable_operand()});
            }
          },
          [&](const CallExpr& c) {
            CallExpr& dst_call = dst->mutable_call_expr();
            dst_call.set_function(c.function());
            if (c.has_target()) {
              stack.push_back({&c.target(), &dst_call.mutable_target()});
            }
            // Size the destination once, before taking element addresses, so
            // the pointers pushed below are not invalidated by reallocation.
            std::vector<Expr>& dst_args = dst_call.mutable_args();
            dst_args.resize(c.args().size());
            for (size_t i = 0; i < dst_args.size(); ++i) {
              stack.push_back({&c.args()[i], &dst_args[i]});
            }
          },
          [&](const ListExpr& l) {
            std::vector<ListExprElement>& dst_elements =
                dst->mutable_list_expr().mutable_elements();
            dst_elements.resize(l.elements().size());
            for (size_t i = 0; i < dst_elements.size(); ++i) {
              const ListExprElement& src_element = l.elements()[i];
              ListExprElement& dst_element = dst_elements[i];
              dst_element.set_optional(src_element.optional());
              if (src_element.has_expr()) {
                stack.push_back(
                    {&src_element.expr(), &dst_element.mutable_expr()});
              }
            }
          },
          [&](const StructExpr& s) {
            StructExpr& dst_struct = dst->mutable_struct_expr();
            dst_struct.set_name(s.name());
            std::vector<StructExprField>& dst_fields =
                dst_struct.mutable_fields();
            dst_fields.resize(s.fields().size());
            for (size_t i = 0; i < dst_fields.size(); ++i) {
              const StructExprField& src_field = s.fields()[i];
              StructExprField& dst_field = dst_fields[i];
              dst_field.set_id(src_field.id());
              dst_field.set_name(src_field.name());
              dst_field.set_optional(src_field.optional());
              if (src_field.has_value()) {
                stack.push_back(
                    {&src_field.value(), &dst_field.mutable_value()});
              }
            }
          },
          [&](const MapExpr& m) {
            std::vector<MapExprEntry>& dst_entries =
                dst->mutable_map_expr().mutable_entries();
            dst_entries.resize(m.entries().size());
            for (size_t i = 0; i < dst_entries.size(); ++i) {
              const MapExprEntry& src_entry = m.entries()[i];
              MapExprEntry& dst_entry = dst_entries[i];
              dst_entry.set_id(src_entry.id());
              dst_entry.set_optional(src_entry.optional());
              if (src_entry.has_key()) {
                stack.push_back({&src_entry.key(), &dst_entry.mutable_key()});
              }
              if (src_entry.has_value()) {
                stack.push_back(
                    {&src_entry.value(), &dst_entry.mutable_value()});
              }
            }
          },
          [&](const ComprehensionExpr& c) {
            ComprehensionExpr& dst_comprehension =
                dst->mutable_comprehension_expr();
            dst_comprehension.set_iter_var(c.iter_var());
            dst_comprehension.set_iter_var2(c.iter_var2());
            dst_comprehension.set_accu_var(c.accu_var());
            if (c.has_accu_init()) {
              stack.push_back(
                  {&c.accu_init(), &dst_comprehension.mutable_accu_init()});
            }
            if (c.has_iter_range()) {
              stack.push_back(
                  {&c.iter_range(), &dst_comprehension.mutable_iter_range()});
            }
            if (c.has_loop_condition()) {
              stack.push_back({&c.loop_condition(),
                               &dst_comprehension.mutable_loop_condition()});
            }
            if (c.has_loop_step()) {
              stack.push_back(
                  {&c.loop_step(), &dst_comprehension.mutable_loop_step()});
            }
            if (c.has_result()) {
              stack.push_back(
                  {&c.result(), &dst_comprehension.mutable_result()});
            }
          }),
      src->kind());
}

// Deep-copies `expr` into `dst`.
//
// An explicit work stack is used instead of recursion, so the native call
// stack depth does not grow with the depth of the expression tree.
void CloneImpl(const Expr& expr, Expr& dst) {
  std::vector<CopyStackRecord> stack;
  stack.push_back({&expr, &dst});
  while (!stack.empty()) {
    CopyStackRecord element = stack.back();
    stack.pop_back();
    CopyNode(element, stack);
  }
}

}  // namespace

const UnspecifiedExpr& UnspecifiedExpr::default_instance() {
  static const absl::NoDestructor<UnspecifiedExpr> instance;
  return *instance;
}

const IdentExpr& IdentExpr::default_instance() {
  static const absl::NoDestructor<IdentExpr> instance;
  return *instance;
}

const SelectExpr& SelectExpr::default_instance() {
  static const absl::NoDestructor<SelectExpr> instance;
  return *instance;
}

const CallExpr& CallExpr::default_instance() {
  static const absl::NoDestructor<CallExpr> instance;
  return *instance;
}

const ListExpr& ListExpr::default_instance() {
  static const absl::NoDestructor<ListExpr> instance;
  return *instance;
}

const StructExpr& StructExpr::default_instance() {
  static const absl::NoDestructor<StructExpr> instance;
  return *instance;
}

const MapExpr& MapExpr::default_instance() {
  static const absl::NoDestructor<MapExpr> instance;
  return *instance;
}

const ComprehensionExpr& ComprehensionExpr::default_instance() {
  static const absl::NoDestructor<ComprehensionExpr> instance;
  return *instance;
}

const Expr& Expr::default_instance() {
  static const absl::NoDestructor<Expr> instance;
  return *instance;
}

// Copy-assignment clones into a temporary first and then move-assigns it, so
// `*this` is only replaced once the whole clone has been built.
Expr& Expr::operator=(const Expr& other) {
  if (this == &other) {
    return *this;
  }
  Expr tmp;
  CloneImpl(other, tmp);
  *this = std::move(tmp);
  return *this;
}

Expr::Expr(const Expr& other) { CloneImpl(other, *this); }

namespace common_internal {
// Tag type whose definition is only visible in this file; it is passed to
// `Expr::SetNext` by the erase helpers below.
struct ExprEraseTag {};
}  // namespace common_internal

namespace {

// Links `child` in front of the current head `*tail` of the erase list and
// makes `child` the new head.
void PushForErase(Expr** tail, Expr& child) {
  static common_internal::ExprEraseTag tag;
  child.SetNext(tag, *tail);
  *tail = &child;
}

// Prepends every present direct child of `cur` to the erase list headed by
// `*tail`. Absent optional children are skipped (checked with `has_*()`), so
// erasure never has to materialize a child just to destroy it.
void Expand(Expr** tail, Expr* cur) {
  switch (cur->kind_case()) {
    case ExprKindCase::kSelectExpr: {
      SelectExpr& select = cur->mutable_select_expr();
      if (select.has_operand()) {
        PushForErase(tail, select.mutable_operand());
      }
      break;
    }
    case ExprKindCase::kCallExpr: {
      CallExpr& call = cur->mutable_call_expr();
      if (call.has_target()) {
        PushForErase(tail, call.mutable_target());
      }
      for (Expr& arg : call.mutable_args()) {
        PushForErase(tail, arg);
      }
      break;
    }
    case ExprKindCase::kListExpr: {
      for (ListExprElement& element :
           cur->mutable_list_expr().mutable_elements()) {
        if (element.has_expr()) {
          PushForErase(tail, element.mutable_expr());
        }
      }
      break;
    }
    case ExprKindCase::kStructExpr: {
      for (StructExprField& field :
           cur->mutable_struct_expr().mutable_fields()) {
        if (field.has_value()) {
          PushForErase(tail, field.mutable_value());
        }
      }
      break;
    }
    case ExprKindCase::kMapExpr: {
      for (MapExprEntry& entry : cur->mutable_map_expr().mutable_entries()) {
        if (entry.has_key()) {
          PushForErase(tail, entry.mutable_key());
        }
        if (entry.has_value()) {
          PushForErase(tail, entry.mutable_value());
        }
      }
      break;
    }
    case ExprKindCase::kComprehensionExpr: {
      ComprehensionExpr& comprehension = cur->mutable_comprehension_expr();
      if (comprehension.has_accu_init()) {
        PushForErase(tail, comprehension.mutable_accu_init());
      }
      if (comprehension.has_iter_range()) {
        PushForErase(tail, comprehension.mutable_iter_range());
      }
      if (comprehension.has_loop_condition()) {
        PushForErase(tail, comprehension.mutable_loop_condition());
      }
      if (comprehension.has_loop_step()) {
        PushForErase(tail, comprehension.mutable_loop_step());
      }
      if (comprehension.has_result()) {
        PushForErase(tail, comprehension.mutable_result());
      }
      break;
    }
    default:
      // Leaf node, nothing to expand.
      // Also a fallback in case we add a new node type.
      // Note: already in the deleter list so we can't delete now, will be
      // deleted after ordering the AST.
      break;
  }
}

}  // namespace

// Clears the whole subtree rooted at `this` without recursion.
//
// Phase 1 threads every node onto a singly linked list through `u_.next`,
// one generation at a time. Each pass expands only the nodes added by the
// previous pass (the list segment from `tail` up to `prev_tail`). Children
// are prepended, so deeper nodes end up closer to the head. Phase 2 walks
// the list from the head and clears each node, so a node is always cleared
// after all of its descendants.
void Expr::FlattenedErase() {
  this->u_.next = nullptr;
  Expr* prev_tail = nullptr;
  Expr* tail = this;
  while (tail != prev_tail) {
    Expr* next_prev_tail = tail;
    Expr* expand_ptr = tail;
    while (expand_ptr != prev_tail) {
      ABSL_DCHECK(expand_ptr != nullptr);  // Linked list is broken or changed.
      // Read the link before expanding: Expand() only touches children, but
      // reading first keeps the traversal independent of that detail.
      Expr* next_expand_ptr = expand_ptr->u_.next;
      Expand(&tail, expand_ptr);
      expand_ptr = next_expand_ptr;
    }
    prev_tail = next_prev_tail;
  }

  Expr* node = tail;
  while (node != nullptr) {
    Expr* next = node->u_.next;
    node->Clear();
    node = next;
  }
}

}  // namespace cel
================================================
FILE: common/expr.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_EXPR_H_
#define THIRD_PARTY_CEL_CPP_COMMON_EXPR_H_

#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/attributes.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "common/constant.h"

namespace cel {

using ExprId = int64_t;

class Expr;
class UnspecifiedExpr;
class IdentExpr;
class SelectExpr;
class CallExpr;
class ListExprElement;
class ListExpr;
class StructExprField;
class StructExpr;
class MapExprEntry;
class MapExpr;
class ComprehensionExpr;

inline constexpr absl::string_view kAccumulatorVariableName = "@result";
inline constexpr absl::string_view kDeprecatedAccumulatorVariableName =
    "__result__";

bool operator==(const Expr& lhs, const Expr& rhs);

inline bool operator!=(const Expr& lhs, const Expr& rhs) {
  return !operator==(lhs, rhs);
}

bool operator==(const ListExprElement& lhs, const ListExprElement& rhs);

inline bool operator!=(const ListExprElement& lhs,
                       const ListExprElement& rhs) {
  return !operator==(lhs, rhs);
}

bool operator==(const StructExprField& lhs, const StructExprField& rhs);

inline bool operator!=(const StructExprField& lhs,
                       const StructExprField& rhs) {
  return !operator==(lhs, rhs);
}

bool operator==(const MapExprEntry& lhs, const MapExprEntry& rhs);

inline bool operator!=(const MapExprEntry& lhs, const MapExprEntry& rhs) {
  return !operator==(lhs, rhs);
}

// `UnspecifiedExpr` is the default alternative of `Expr`. It is used for
// default construction of `Expr` or as a placeholder for when errors occur.
class UnspecifiedExpr final {
 public:
  UnspecifiedExpr() = default;
  UnspecifiedExpr(UnspecifiedExpr&&) = default;
  UnspecifiedExpr& operator=(UnspecifiedExpr&&) = default;

  UnspecifiedExpr(const UnspecifiedExpr&) = delete;
  UnspecifiedExpr& operator=(const UnspecifiedExpr&) = delete;

  void Clear() {}

  friend void swap(UnspecifiedExpr&, UnspecifiedExpr&) noexcept {}

 private:
  friend class Expr;

  static const UnspecifiedExpr& default_instance();
};

inline bool operator==(const UnspecifiedExpr&, const UnspecifiedExpr&) {
  return true;
}

inline bool operator!=(const UnspecifiedExpr& lhs,
                       const UnspecifiedExpr& rhs) {
  return !operator==(lhs, rhs);
}

// `IdentExpr` is an alternative of `Expr`, representing an identifier.
class IdentExpr final {
 public:
  IdentExpr() = default;
  IdentExpr(IdentExpr&&) = default;
  IdentExpr& operator=(IdentExpr&&) = default;

  explicit IdentExpr(std::string name) { set_name(std::move(name)); }

  explicit IdentExpr(absl::string_view name) { set_name(name); }

  explicit IdentExpr(const char* name) { set_name(name); }

  IdentExpr(const IdentExpr&) = delete;
  IdentExpr& operator=(const IdentExpr&) = delete;

  void Clear() { name_.clear(); }

  // Holds a single, unqualified identifier, possibly preceded by a '.'.
  //
  // Qualified names are represented by the [Expr.Select][] expression.
  ABSL_MUST_USE_RESULT const std::string& name() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return name_;
  }

  void set_name(std::string name) { name_ = std::move(name); }

  void set_name(absl::string_view name) {
    name_.assign(name.data(), name.size());
  }

  void set_name(const char* name) { set_name(absl::NullSafeStringView(name)); }

  ABSL_MUST_USE_RESULT std::string release_name() { return release(name_); }

  friend void swap(IdentExpr& lhs, IdentExpr& rhs) noexcept {
    using std::swap;
    swap(lhs.name_, rhs.name_);
  }

 private:
  friend class Expr;

  static const IdentExpr& default_instance();

  // Moves `property` out, leaving it empty.
  static std::string release(std::string& property) {
    std::string result;
    result.swap(property);
    return result;
  }

  std::string name_;
};

inline bool operator==(const IdentExpr& lhs, const IdentExpr& rhs) {
  return lhs.name() == rhs.name();
}

inline bool operator!=(const IdentExpr& lhs, const IdentExpr& rhs) {
  return !operator==(lhs, rhs);
}

// `SelectExpr` is an alternative of `Expr`, representing field access.
class SelectExpr final {
 public:
  SelectExpr() = default;
  SelectExpr(SelectExpr&&) = default;
  SelectExpr& operator=(SelectExpr&&) = default;

  SelectExpr(const SelectExpr&) = delete;
  SelectExpr& operator=(const SelectExpr&) = delete;

  void Clear();

  ABSL_MUST_USE_RESULT bool has_operand() const { return operand_ != nullptr; }

  // The target of the selection expression.
  //
  // For example, in the select expression `request.auth`, the `request`
  // portion of the expression is the `operand`.
  ABSL_MUST_USE_RESULT const Expr& operand() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND;

  Expr& mutable_operand() ABSL_ATTRIBUTE_LIFETIME_BOUND;

  void set_operand(Expr operand);

  void set_operand(std::unique_ptr<Expr> operand);

  ABSL_MUST_USE_RESULT std::unique_ptr<Expr> release_operand();

  // The name of the field to select.
  //
  // For example, in the select expression `request.auth`, the `auth` portion
  // of the expression would be the `field`.
  ABSL_MUST_USE_RESULT const std::string& field() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return field_;
  }

  void set_field(std::string field) { field_ = std::move(field); }

  void set_field(absl::string_view field) {
    field_.assign(field.data(), field.size());
  }

  void set_field(const char* field) {
    set_field(absl::NullSafeStringView(field));
  }

  ABSL_MUST_USE_RESULT std::string release_field() { return release(field_); }

  // Whether the select is to be interpreted as a field presence test.
  //
  // This results from the macro `has(request.auth)`.
  ABSL_MUST_USE_RESULT bool test_only() const { return test_only_; }

  void set_test_only(bool test_only) { test_only_ = test_only; }

  friend void swap(SelectExpr& lhs, SelectExpr& rhs) noexcept {
    using std::swap;
    swap(lhs.operand_, rhs.operand_);
    swap(lhs.field_, rhs.field_);
    swap(lhs.test_only_, rhs.test_only_);
  }

 private:
  friend class Expr;

  static const SelectExpr& default_instance();

  // Moves `property` out, leaving it empty.
  static std::string release(std::string& property) {
    std::string result;
    result.swap(property);
    return result;
  }

  static std::unique_ptr<Expr> release(std::unique_ptr<Expr>& property);

  std::unique_ptr<Expr> operand_;
  std::string field_;
  bool test_only_ = false;
};

inline bool operator==(const SelectExpr& lhs, const SelectExpr& rhs) {
  return lhs.operand() == rhs.operand() && lhs.field() == rhs.field() &&
         lhs.test_only() == rhs.test_only();
}

inline bool operator!=(const SelectExpr& lhs, const SelectExpr& rhs) {
  return !operator==(lhs, rhs);
}

// `CallExpr` is an alternative of `Expr`, representing a function call.
class CallExpr final {
 public:
  CallExpr() = default;
  CallExpr(CallExpr&&) = default;
  CallExpr& operator=(CallExpr&&) = default;

  CallExpr(const CallExpr&) = delete;
  CallExpr& operator=(const CallExpr&) = delete;

  void Clear();

  // The name of the function or method being called.
  ABSL_MUST_USE_RESULT const std::string& function() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return function_;
  }

  void set_function(std::string function) { function_ = std::move(function); }

  void set_function(absl::string_view function) {
    function_.assign(function.data(), function.size());
  }

  void set_function(const char* function) {
    set_function(absl::NullSafeStringView(function));
  }

  ABSL_MUST_USE_RESULT std::string release_function() {
    return release(function_);
  }

  ABSL_MUST_USE_RESULT bool has_target() const { return target_ != nullptr; }

  // The target of an method call-style expression. For example, `x` in
  // `x.f()`.
  ABSL_MUST_USE_RESULT const Expr& target() const ABSL_ATTRIBUTE_LIFETIME_BOUND;

  Expr& mutable_target() ABSL_ATTRIBUTE_LIFETIME_BOUND;

  void set_target(Expr target);

  void set_target(std::unique_ptr<Expr> target);

  ABSL_MUST_USE_RESULT std::unique_ptr<Expr> release_target();

  // The arguments.
  ABSL_MUST_USE_RESULT const std::vector<Expr>& args() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return args_;
  }

  ABSL_MUST_USE_RESULT std::vector<Expr>& mutable_args()
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return args_;
  }

  void set_args(std::vector<Expr> args);

  void set_args(absl::Span<Expr> args);

  Expr& add_args() ABSL_ATTRIBUTE_LIFETIME_BOUND;

  ABSL_MUST_USE_RESULT std::vector<Expr> release_args();

  friend void swap(CallExpr& lhs, CallExpr& rhs) noexcept {
    using std::swap;
    swap(lhs.function_, rhs.function_);
    swap(lhs.target_, rhs.target_);
    swap(lhs.args_, rhs.args_);
  }

 private:
  friend class Expr;

  static const CallExpr& default_instance();

  // Moves `property` out, leaving it empty.
  static std::string release(std::string& property) {
    std::string result;
    result.swap(property);
    return result;
  }

  static std::unique_ptr<Expr> release(std::unique_ptr<Expr>& property);

  std::string function_;
  std::unique_ptr<Expr> target_;
  std::vector<Expr> args_;
};

bool operator==(const CallExpr& lhs, const CallExpr& rhs);

inline bool operator!=(const CallExpr& lhs, const CallExpr& rhs) {
  return !operator==(lhs, rhs);
}

// `ListExprElement` represents an element in `ListExpr`.
class ListExprElement final { public: ListExprElement() = default; ListExprElement(ListExprElement&&) = default; ListExprElement& operator=(ListExprElement&&) = default; ListExprElement(const ListExprElement&) = delete; ListExprElement& operator=(const ListExprElement&) = delete; void Clear(); ABSL_MUST_USE_RESULT bool has_expr() const { return expr_ != nullptr; } ABSL_MUST_USE_RESULT const Expr& expr() const ABSL_ATTRIBUTE_LIFETIME_BOUND; ABSL_MUST_USE_RESULT Expr& mutable_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND; void set_expr(Expr expr); void set_expr(std::unique_ptr expr); ABSL_MUST_USE_RESULT Expr release_expr(); ABSL_MUST_USE_RESULT bool optional() const { return optional_; } void set_optional(bool optional) { optional_ = optional; } friend void swap(ListExprElement& lhs, ListExprElement& rhs) noexcept; private: static Expr release(std::unique_ptr& property); std::unique_ptr expr_; bool optional_ = false; }; // `ListExpr` is an alternative of `Expr`, representing a list. class ListExpr final { public: ListExpr() = default; ListExpr(ListExpr&&) = default; ListExpr& operator=(ListExpr&&) = default; ListExpr(const ListExpr&) = delete; ListExpr& operator=(const ListExpr&) = delete; void Clear(); // The elements of the list. 
ABSL_MUST_USE_RESULT const std::vector& elements() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return elements_; } ABSL_MUST_USE_RESULT std::vector& mutable_elements() ABSL_ATTRIBUTE_LIFETIME_BOUND { return elements_; } void set_elements(std::vector elements); void set_elements(absl::Span elements); ListExprElement& add_elements() ABSL_ATTRIBUTE_LIFETIME_BOUND; ABSL_MUST_USE_RESULT std::vector release_elements(); friend void swap(ListExpr& lhs, ListExpr& rhs) noexcept { using std::swap; swap(lhs.elements_, rhs.elements_); } private: friend class Expr; static const ListExpr& default_instance(); std::vector elements_; }; bool operator==(const ListExpr& lhs, const ListExpr& rhs); inline bool operator!=(const ListExpr& lhs, const ListExpr& rhs) { return !operator==(lhs, rhs); } // `StructExprField` represents a field in `StructExpr`. class StructExprField final { public: StructExprField() = default; StructExprField(StructExprField&&) = default; StructExprField& operator=(StructExprField&&) = default; StructExprField(const StructExprField&) = delete; StructExprField& operator=(const StructExprField&) = delete; void Clear(); ABSL_MUST_USE_RESULT ExprId id() const { return id_; } void set_id(ExprId id) { id_ = id; } ABSL_MUST_USE_RESULT const std::string& name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return name_; } void set_name(std::string name) { name_ = std::move(name); } void set_name(absl::string_view name) { name_.assign(name.data(), name.size()); } void set_name(const char* name) { set_name(absl::NullSafeStringView(name)); } ABSL_MUST_USE_RESULT std::string release_name() { return std::move(name_); } ABSL_MUST_USE_RESULT bool has_value() const { return value_ != nullptr; } ABSL_MUST_USE_RESULT const Expr& value() const ABSL_ATTRIBUTE_LIFETIME_BOUND; ABSL_MUST_USE_RESULT Expr& mutable_value() ABSL_ATTRIBUTE_LIFETIME_BOUND; void set_value(Expr value); void set_value(std::unique_ptr value); ABSL_MUST_USE_RESULT Expr release_value(); ABSL_MUST_USE_RESULT bool optional() 
const { return optional_; } void set_optional(bool optional) { optional_ = optional; } friend void swap(StructExprField& lhs, StructExprField& rhs) noexcept; private: static Expr release(std::unique_ptr& property); ExprId id_ = 0; std::string name_; std::unique_ptr value_; bool optional_ = false; }; // `StructExpr` is an alternative of `Expr`, representing a struct. class StructExpr final { public: StructExpr() = default; StructExpr(StructExpr&&) = default; StructExpr& operator=(StructExpr&&) = default; StructExpr(const StructExpr&) = delete; StructExpr& operator=(const StructExpr&) = delete; void Clear(); // The type name of the struct to be created. ABSL_MUST_USE_RESULT const std::string& name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return name_; } void set_name(std::string name) { name_ = std::move(name); } void set_name(absl::string_view name) { name_.assign(name.data(), name.size()); } void set_name(const char* name) { set_name(absl::NullSafeStringView(name)); } ABSL_MUST_USE_RESULT std::string release_name() { return release(name_); } // The fields of the struct. 
ABSL_MUST_USE_RESULT const std::vector& fields() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return fields_; } ABSL_MUST_USE_RESULT std::vector& mutable_fields() ABSL_ATTRIBUTE_LIFETIME_BOUND { return fields_; } void set_fields(std::vector fields); void set_fields(absl::Span fields); StructExprField& add_fields() ABSL_ATTRIBUTE_LIFETIME_BOUND; ABSL_MUST_USE_RESULT std::vector release_fields(); friend void swap(StructExpr& lhs, StructExpr& rhs) noexcept { using std::swap; swap(lhs.name_, rhs.name_); swap(lhs.fields_, rhs.fields_); } private: friend class Expr; static const StructExpr& default_instance(); static std::string release(std::string& property) { std::string result; result.swap(property); return result; } std::string name_; std::vector fields_; }; bool operator==(const StructExpr& lhs, const StructExpr& rhs); inline bool operator!=(const StructExpr& lhs, const StructExpr& rhs) { return !operator==(lhs, rhs); } // `MapExprEntry` represents an entry in `MapExpr`. class MapExprEntry final { public: MapExprEntry() = default; MapExprEntry(MapExprEntry&&) = default; MapExprEntry& operator=(MapExprEntry&&) = default; MapExprEntry(const MapExprEntry&) = delete; MapExprEntry& operator=(const MapExprEntry&) = delete; void Clear(); ABSL_MUST_USE_RESULT ExprId id() const { return id_; } void set_id(ExprId id) { id_ = id; } ABSL_MUST_USE_RESULT bool has_key() const { return key_ != nullptr; } ABSL_MUST_USE_RESULT const Expr& key() const ABSL_ATTRIBUTE_LIFETIME_BOUND; ABSL_MUST_USE_RESULT Expr& mutable_key() ABSL_ATTRIBUTE_LIFETIME_BOUND; void set_key(Expr key); void set_key(std::unique_ptr key); ABSL_MUST_USE_RESULT Expr release_key(); ABSL_MUST_USE_RESULT bool has_value() const { return value_ != nullptr; } ABSL_MUST_USE_RESULT const Expr& value() const ABSL_ATTRIBUTE_LIFETIME_BOUND; ABSL_MUST_USE_RESULT Expr& mutable_value() ABSL_ATTRIBUTE_LIFETIME_BOUND; void set_value(Expr value); void set_value(std::unique_ptr value); ABSL_MUST_USE_RESULT Expr release_value(); 
ABSL_MUST_USE_RESULT bool optional() const { return optional_; } void set_optional(bool optional) { optional_ = optional; } friend void swap(MapExprEntry& lhs, MapExprEntry& rhs) noexcept; private: static Expr release(std::unique_ptr& property); ExprId id_ = 0; std::unique_ptr key_; std::unique_ptr value_; bool optional_ = false; }; // `MapExpr` is an alternative of `Expr`, representing a map. class MapExpr final { public: MapExpr() = default; MapExpr(MapExpr&&) = default; MapExpr& operator=(MapExpr&&) = default; MapExpr(const MapExpr&) = delete; MapExpr& operator=(const MapExpr&) = delete; void Clear(); // The entries of the map. ABSL_MUST_USE_RESULT const std::vector& entries() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return entries_; } ABSL_MUST_USE_RESULT std::vector& mutable_entries() ABSL_ATTRIBUTE_LIFETIME_BOUND { return entries_; } void set_entries(std::vector entries); void set_entries(absl::Span entries); MapExprEntry& add_entries() ABSL_ATTRIBUTE_LIFETIME_BOUND; ABSL_MUST_USE_RESULT std::vector release_entries(); friend void swap(MapExpr& lhs, MapExpr& rhs) noexcept { using std::swap; swap(lhs.entries_, rhs.entries_); } private: friend class Expr; static const MapExpr& default_instance(); std::vector entries_; }; bool operator==(const MapExpr& lhs, const MapExpr& rhs); inline bool operator!=(const MapExpr& lhs, const MapExpr& rhs) { return !operator==(lhs, rhs); } // `ComprehensionExpr` is an alternative of `Expr`, representing a // comprehension. These are always synthetic as there is no way to express them // directly in the Common Expression Language, and are created by macros. 
class ComprehensionExpr final { public: ComprehensionExpr() = default; ComprehensionExpr(ComprehensionExpr&&) = default; ComprehensionExpr& operator=(ComprehensionExpr&&) = default; ComprehensionExpr(const ComprehensionExpr&) = delete; ComprehensionExpr& operator=(const ComprehensionExpr&) = delete; void Clear(); ABSL_MUST_USE_RESULT const std::string& iter_var() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return iter_var_; } void set_iter_var(std::string iter_var) { iter_var_ = std::move(iter_var); } void set_iter_var(absl::string_view iter_var) { iter_var_.assign(iter_var.data(), iter_var.size()); } void set_iter_var(const char* iter_var) { set_iter_var(absl::NullSafeStringView(iter_var)); } ABSL_MUST_USE_RESULT std::string release_iter_var() { return release(iter_var_); } ABSL_MUST_USE_RESULT const std::string& iter_var2() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return iter_var2_; } void set_iter_var2(std::string iter_var2) { iter_var2_ = std::move(iter_var2); } void set_iter_var2(absl::string_view iter_var2) { iter_var2_.assign(iter_var2.data(), iter_var2.size()); } void set_iter_var2(const char* iter_var2) { set_iter_var2(absl::NullSafeStringView(iter_var2)); } ABSL_MUST_USE_RESULT std::string release_iter_var2() { return release(iter_var2_); } ABSL_MUST_USE_RESULT bool has_iter_range() const { return iter_range_ != nullptr; } ABSL_MUST_USE_RESULT const Expr& iter_range() const ABSL_ATTRIBUTE_LIFETIME_BOUND; Expr& mutable_iter_range() ABSL_ATTRIBUTE_LIFETIME_BOUND; void set_iter_range(Expr iter_range); void set_iter_range(std::unique_ptr iter_range); ABSL_MUST_USE_RESULT std::unique_ptr release_iter_range(); ABSL_MUST_USE_RESULT const std::string& accu_var() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return accu_var_; } void set_accu_var(std::string accu_var) { accu_var_ = std::move(accu_var); } void set_accu_var(absl::string_view accu_var) { accu_var_.assign(accu_var.data(), accu_var.size()); } void set_accu_var(const char* accu_var) { 
set_accu_var(absl::NullSafeStringView(accu_var)); } ABSL_MUST_USE_RESULT std::string release_accu_var() { return release(accu_var_); } ABSL_MUST_USE_RESULT bool has_accu_init() const { return accu_init_ != nullptr; } ABSL_MUST_USE_RESULT const Expr& accu_init() const ABSL_ATTRIBUTE_LIFETIME_BOUND; Expr& mutable_accu_init() ABSL_ATTRIBUTE_LIFETIME_BOUND; void set_accu_init(Expr accu_init); void set_accu_init(std::unique_ptr accu_init); ABSL_MUST_USE_RESULT std::unique_ptr release_accu_init(); ABSL_MUST_USE_RESULT bool has_loop_condition() const { return loop_condition_ != nullptr; } ABSL_MUST_USE_RESULT const Expr& loop_condition() const ABSL_ATTRIBUTE_LIFETIME_BOUND; Expr& mutable_loop_condition() ABSL_ATTRIBUTE_LIFETIME_BOUND; void set_loop_condition(Expr loop_condition); void set_loop_condition(std::unique_ptr loop_condition); ABSL_MUST_USE_RESULT std::unique_ptr release_loop_condition(); ABSL_MUST_USE_RESULT bool has_loop_step() const { return loop_step_ != nullptr; } ABSL_MUST_USE_RESULT const Expr& loop_step() const ABSL_ATTRIBUTE_LIFETIME_BOUND; Expr& mutable_loop_step() ABSL_ATTRIBUTE_LIFETIME_BOUND; void set_loop_step(Expr loop_step); void set_loop_step(std::unique_ptr loop_step); ABSL_MUST_USE_RESULT std::unique_ptr release_loop_step(); ABSL_MUST_USE_RESULT bool has_result() const { return result_ != nullptr; } ABSL_MUST_USE_RESULT const Expr& result() const ABSL_ATTRIBUTE_LIFETIME_BOUND; Expr& mutable_result() ABSL_ATTRIBUTE_LIFETIME_BOUND; void set_result(Expr result); void set_result(std::unique_ptr result); ABSL_MUST_USE_RESULT std::unique_ptr release_result(); friend void swap(ComprehensionExpr& lhs, ComprehensionExpr& rhs) noexcept { using std::swap; swap(lhs.iter_var_, rhs.iter_var_); swap(lhs.iter_var2_, rhs.iter_var2_); swap(lhs.iter_range_, rhs.iter_range_); swap(lhs.accu_var_, rhs.accu_var_); swap(lhs.accu_init_, rhs.accu_init_); swap(lhs.loop_condition_, rhs.loop_condition_); swap(lhs.loop_step_, rhs.loop_step_); swap(lhs.result_, rhs.result_); 
} private: friend class Expr; static const ComprehensionExpr& default_instance(); static std::string release(std::string& property) { std::string result; result.swap(property); return result; } static std::unique_ptr release(std::unique_ptr& property); std::string iter_var_; std::string iter_var2_; std::unique_ptr iter_range_; std::string accu_var_; std::unique_ptr accu_init_; std::unique_ptr loop_condition_; std::unique_ptr loop_step_; std::unique_ptr result_; }; inline bool operator==(const ComprehensionExpr& lhs, const ComprehensionExpr& rhs) { return lhs.iter_var() == rhs.iter_var() && lhs.iter_range() == rhs.iter_range() && lhs.accu_var() == rhs.accu_var() && lhs.accu_init() == rhs.accu_init() && lhs.loop_condition() == rhs.loop_condition() && lhs.loop_step() == rhs.loop_step() && lhs.result() == rhs.result(); } inline bool operator!=(const ComprehensionExpr& lhs, const ComprehensionExpr& rhs) { return !operator==(lhs, rhs); } using ExprKind = absl::variant; enum class ExprKindCase { kUnspecifiedExpr, kConstant, kIdentExpr, kSelectExpr, kCallExpr, kListExpr, kStructExpr, kMapExpr, kComprehensionExpr, }; namespace common_internal { struct ExprEraseTag; } // namespace common_internal // `Expr` is a node in the Common Expression Language's abstract syntax tree. It // is composed of a numeric ID and a kind variant. 
class Expr final { public: Expr() = default; Expr(Expr&&) = default; Expr& operator=(Expr&&); Expr(const Expr&); Expr& operator=(const Expr&); void Clear(); ABSL_MUST_USE_RESULT ExprId id() const { return u_.id; } void set_id(ExprId id) { u_.id = id; } ABSL_MUST_USE_RESULT const ExprKind& kind() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return kind_; } ABSL_MUST_USE_RESULT ExprKind& mutable_kind() ABSL_ATTRIBUTE_LIFETIME_BOUND { return kind_; } void set_kind(ExprKind kind); ABSL_MUST_USE_RESULT ExprKind release_kind(); ABSL_MUST_USE_RESULT bool has_const_expr() const { return absl::holds_alternative(kind()); } ABSL_MUST_USE_RESULT const Constant& const_expr() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return get_kind(); } Constant& mutable_const_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND { return try_emplace_kind(); } void set_const_expr(Constant const_expr) { try_emplace_kind() = std::move(const_expr); } ABSL_MUST_USE_RESULT Constant release_const_expr() { return release_kind(); } ABSL_MUST_USE_RESULT bool has_ident_expr() const { return absl::holds_alternative(kind()); } ABSL_MUST_USE_RESULT const IdentExpr& ident_expr() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return get_kind(); } IdentExpr& mutable_ident_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND { return try_emplace_kind(); } void set_ident_expr(IdentExpr ident_expr) { try_emplace_kind() = std::move(ident_expr); } ABSL_MUST_USE_RESULT IdentExpr release_ident_expr() { return release_kind(); } ABSL_MUST_USE_RESULT bool has_select_expr() const { return absl::holds_alternative(kind()); } ABSL_MUST_USE_RESULT const SelectExpr& select_expr() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return get_kind(); } SelectExpr& mutable_select_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND { return try_emplace_kind(); } void set_select_expr(SelectExpr select_expr) { try_emplace_kind() = std::move(select_expr); } ABSL_MUST_USE_RESULT SelectExpr release_select_expr() { return release_kind(); } ABSL_MUST_USE_RESULT bool has_call_expr() const { return 
absl::holds_alternative(kind()); } ABSL_MUST_USE_RESULT const CallExpr& call_expr() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return get_kind(); } CallExpr& mutable_call_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND { return try_emplace_kind(); } void set_call_expr(CallExpr call_expr); ABSL_MUST_USE_RESULT CallExpr release_call_expr(); ABSL_MUST_USE_RESULT bool has_list_expr() const { return absl::holds_alternative(kind()); } ABSL_MUST_USE_RESULT const ListExpr& list_expr() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return get_kind(); } ListExpr& mutable_list_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND { return try_emplace_kind(); } void set_list_expr(ListExpr list_expr); ABSL_MUST_USE_RESULT ListExpr release_list_expr(); ABSL_MUST_USE_RESULT bool has_struct_expr() const { return absl::holds_alternative(kind()); } ABSL_MUST_USE_RESULT const StructExpr& struct_expr() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return get_kind(); } StructExpr& mutable_struct_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND { return try_emplace_kind(); } void set_struct_expr(StructExpr struct_expr); ABSL_MUST_USE_RESULT StructExpr release_struct_expr(); ABSL_MUST_USE_RESULT bool has_map_expr() const { return absl::holds_alternative(kind()); } ABSL_MUST_USE_RESULT const MapExpr& map_expr() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return get_kind(); } MapExpr& mutable_map_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND { return try_emplace_kind(); } void set_map_expr(MapExpr map_expr); ABSL_MUST_USE_RESULT MapExpr release_map_expr(); ABSL_MUST_USE_RESULT bool has_comprehension_expr() const { return absl::holds_alternative(kind()); } ABSL_MUST_USE_RESULT const ComprehensionExpr& comprehension_expr() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return get_kind(); } ComprehensionExpr& mutable_comprehension_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND { return try_emplace_kind(); } void set_comprehension_expr(ComprehensionExpr comprehension_expr) { try_emplace_kind() = std::move(comprehension_expr); } ABSL_MUST_USE_RESULT ComprehensionExpr 
release_comprehension_expr() { return release_kind(); } ExprKindCase kind_case() const; friend void swap(Expr& lhs, Expr& rhs) noexcept; // Erases the expr in place without recursion. void FlattenedErase(); inline void SetNext(common_internal::ExprEraseTag&, Expr* next); private: friend class IdentExpr; friend class SelectExpr; friend class CallExpr; friend class ListExpr; friend class StructExpr; friend class MapExpr; friend class ComprehensionExpr; friend class ListExprElement; friend class StructExprField; friend class MapExprEntry; static const Expr& default_instance(); template ABSL_MUST_USE_RESULT T& try_emplace_kind(Args&&... args) ABSL_ATTRIBUTE_LIFETIME_BOUND { if (auto* alt = absl::get_if(&mutable_kind()); alt) { return *alt; } return kind_.emplace(std::forward(args)...); } template ABSL_MUST_USE_RESULT const T& get_kind() const ABSL_ATTRIBUTE_LIFETIME_BOUND { if (const auto* alt = absl::get_if(&kind()); alt) { return *alt; } return T::default_instance(); } template ABSL_MUST_USE_RESULT T release_kind(); union { ExprId id = 0; // Intrusive pointer to the next element in the destructor chain. // Only set from FlattenedErase. Expr* next; } u_; ExprKind kind_; }; inline bool operator==(const Expr& lhs, const Expr& rhs) { return lhs.id() == rhs.id() && lhs.kind() == rhs.kind(); } inline bool operator==(const CallExpr& lhs, const CallExpr& rhs) { return lhs.function() == rhs.function() && lhs.target() == rhs.target() && absl::c_equal(lhs.args(), rhs.args()); } inline void SelectExpr::Clear() { operand_.reset(); field_.clear(); test_only_ = false; } ABSL_MUST_USE_RESULT inline std::unique_ptr SelectExpr::release_operand() { return release(operand_); } inline const Expr& SelectExpr::operand() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return has_operand() ? 
*operand_ : Expr::default_instance(); } inline Expr& SelectExpr::mutable_operand() ABSL_ATTRIBUTE_LIFETIME_BOUND { if (!has_operand()) { operand_ = std::make_unique(); } return *operand_; } inline void SelectExpr::set_operand(Expr operand) { mutable_operand() = std::move(operand); } inline void SelectExpr::set_operand(std::unique_ptr operand) { operand_ = std::move(operand); } inline std::unique_ptr SelectExpr::release( std::unique_ptr& property) { std::unique_ptr result; result.swap(property); return result; } inline void ComprehensionExpr::Clear() { iter_var_.clear(); iter_range_.reset(); accu_var_.clear(); accu_init_.reset(); loop_condition_.reset(); loop_step_.reset(); result_.reset(); } inline const Expr& ComprehensionExpr::iter_range() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return has_iter_range() ? *iter_range_ : Expr::default_instance(); } inline Expr& ComprehensionExpr::mutable_iter_range() ABSL_ATTRIBUTE_LIFETIME_BOUND { if (!has_iter_range()) { iter_range_ = std::make_unique(); } return *iter_range_; } inline void ComprehensionExpr::set_iter_range(Expr iter_range) { mutable_iter_range() = std::move(iter_range); } inline void ComprehensionExpr::set_iter_range( std::unique_ptr iter_range) { iter_range_ = std::move(iter_range); } ABSL_MUST_USE_RESULT inline std::unique_ptr ComprehensionExpr::release_iter_range() { return release(iter_range_); } inline const Expr& ComprehensionExpr::accu_init() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return has_accu_init() ? 
*accu_init_ : Expr::default_instance(); } ABSL_MUST_USE_RESULT inline std::unique_ptr ComprehensionExpr::release_accu_init() { return release(accu_init_); } inline Expr& ComprehensionExpr::mutable_accu_init() ABSL_ATTRIBUTE_LIFETIME_BOUND { if (!has_accu_init()) { accu_init_ = std::make_unique(); } return *accu_init_; } inline void ComprehensionExpr::set_accu_init(Expr accu_init) { mutable_accu_init() = std::move(accu_init); } inline void ComprehensionExpr::set_accu_init(std::unique_ptr accu_init) { accu_init_ = std::move(accu_init); } inline const Expr& ComprehensionExpr::loop_step() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return has_loop_step() ? *loop_step_ : Expr::default_instance(); } inline Expr& ComprehensionExpr::mutable_loop_step() ABSL_ATTRIBUTE_LIFETIME_BOUND { if (!has_loop_step()) { loop_step_ = std::make_unique(); } return *loop_step_; } inline void ComprehensionExpr::set_loop_step(Expr loop_step) { mutable_loop_step() = std::move(loop_step); } inline void ComprehensionExpr::set_loop_step(std::unique_ptr loop_step) { loop_step_ = std::move(loop_step); } ABSL_MUST_USE_RESULT inline std::unique_ptr ComprehensionExpr::release_loop_step() { return release(loop_step_); } inline const Expr& ComprehensionExpr::loop_condition() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return has_loop_condition() ? 
*loop_condition_ : Expr::default_instance(); } ABSL_MUST_USE_RESULT inline std::unique_ptr ComprehensionExpr::release_loop_condition() { return release(loop_condition_); } inline Expr& ComprehensionExpr::mutable_loop_condition() ABSL_ATTRIBUTE_LIFETIME_BOUND { if (!has_loop_condition()) { loop_condition_ = std::make_unique(); } return *loop_condition_; } inline void ComprehensionExpr::set_loop_condition(Expr loop_condition) { mutable_loop_condition() = std::move(loop_condition); } inline void ComprehensionExpr::set_loop_condition( std::unique_ptr loop_condition) { loop_condition_ = std::move(loop_condition); } inline const Expr& ComprehensionExpr::result() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return has_result() ? *result_ : Expr::default_instance(); } inline Expr& ComprehensionExpr::mutable_result() ABSL_ATTRIBUTE_LIFETIME_BOUND { if (!has_result()) { result_ = std::make_unique(); } return *result_; } inline void ComprehensionExpr::set_result(Expr result) { mutable_result() = std::move(result); } inline void ComprehensionExpr::set_result(std::unique_ptr result) { result_ = std::move(result); } ABSL_MUST_USE_RESULT inline std::unique_ptr ComprehensionExpr::release_result() { return release(result_); } inline std::unique_ptr ComprehensionExpr::release( std::unique_ptr& property) { std::unique_ptr result; result.swap(property); return result; } inline bool operator==(const ListExprElement& lhs, const ListExprElement& rhs) { return lhs.expr() == rhs.expr() && lhs.optional() == rhs.optional(); } inline bool operator==(const ListExpr& lhs, const ListExpr& rhs) { return absl::c_equal(lhs.elements(), rhs.elements()); } inline bool operator==(const StructExprField& lhs, const StructExprField& rhs) { return lhs.id() == rhs.id() && lhs.name() == rhs.name() && lhs.value() == rhs.value() && lhs.optional() == rhs.optional(); } inline bool operator==(const StructExpr& lhs, const StructExpr& rhs) { return lhs.name() == rhs.name() && absl::c_equal(lhs.fields(), rhs.fields()); } 
inline bool operator==(const MapExprEntry& lhs, const MapExprEntry& rhs) { return lhs.id() == rhs.id() && lhs.key() == rhs.key() && lhs.value() == rhs.value() && lhs.optional() == rhs.optional(); } inline bool operator==(const MapExpr& lhs, const MapExpr& rhs) { return absl::c_equal(lhs.entries(), rhs.entries()); } inline void MapExpr::Clear() { entries_.clear(); } inline void MapExpr::set_entries(std::vector entries) { entries_ = std::move(entries); } inline void MapExpr::set_entries(absl::Span entries) { entries_.clear(); entries_.reserve(entries.size()); for (auto& entry : entries) { entries_.push_back(std::move(entry)); } } inline MapExprEntry& MapExpr::add_entries() ABSL_ATTRIBUTE_LIFETIME_BOUND { return mutable_entries().emplace_back(); } inline std::vector MapExpr::release_entries() { std::vector entries; entries.swap(entries_); return entries; } inline void Expr::Clear() { u_.id = 0; mutable_kind().emplace(); } inline Expr& Expr::operator=(Expr&&) = default; inline void Expr::set_kind(ExprKind kind) { kind_ = std::move(kind); } inline ABSL_MUST_USE_RESULT ExprKind Expr::release_kind() { ExprKind kind = std::move(kind_); kind_.emplace(); return kind; } inline void Expr::set_call_expr(CallExpr call_expr) { try_emplace_kind() = std::move(call_expr); } inline ABSL_MUST_USE_RESULT CallExpr Expr::release_call_expr() { return release_kind(); } inline void Expr::set_list_expr(ListExpr list_expr) { try_emplace_kind() = std::move(list_expr); } inline ListExpr Expr::release_list_expr() { return release_kind(); } inline void Expr::set_struct_expr(StructExpr struct_expr) { try_emplace_kind() = std::move(struct_expr); } inline StructExpr Expr::release_struct_expr() { return release_kind(); } inline void Expr::set_map_expr(MapExpr map_expr) { try_emplace_kind() = std::move(map_expr); } inline MapExpr Expr::release_map_expr() { return release_kind(); } template ABSL_MUST_USE_RESULT T Expr::release_kind() { T result; if (auto* alt = absl::get_if(&mutable_kind()); alt) { 
result = std::move(*alt); } kind_.emplace(); return result; } inline ExprKindCase Expr::kind_case() const { static_assert(absl::variant_size_v == 9); if (kind_.index() <= 9) { return static_cast(kind_.index()); } return ExprKindCase::kUnspecifiedExpr; } inline void swap(Expr& lhs, Expr& rhs) noexcept { using std::swap; swap(lhs.u_, rhs.u_); swap(lhs.kind_, rhs.kind_); } inline void CallExpr::Clear() { function_.clear(); target_.reset(); args_.clear(); } inline const Expr& CallExpr::target() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return has_target() ? *target_ : Expr::default_instance(); } inline Expr& CallExpr::mutable_target() ABSL_ATTRIBUTE_LIFETIME_BOUND { if (!has_target()) { target_ = std::make_unique(); } return *target_; } inline void CallExpr::set_target(Expr target) { mutable_target() = std::move(target); } inline void CallExpr::set_target(std::unique_ptr target) { target_ = std::move(target); } ABSL_MUST_USE_RESULT inline std::unique_ptr CallExpr::release_target() { return release(target_); } inline void CallExpr::set_args(std::vector args) { args_ = std::move(args); } inline void CallExpr::set_args(absl::Span args) { args_.clear(); args_.reserve(args.size()); for (auto& arg : args) { args_.push_back(std::move(arg)); } } inline Expr& CallExpr::add_args() ABSL_ATTRIBUTE_LIFETIME_BOUND { return mutable_args().emplace_back(); } inline std::vector CallExpr::release_args() { std::vector args; args.swap(args_); return args; } inline std::unique_ptr CallExpr::release( std::unique_ptr& property) { std::unique_ptr result; result.swap(property); return result; } inline void ListExprElement::Clear() { expr_.reset(); optional_ = false; } inline ABSL_MUST_USE_RESULT const Expr& ListExprElement::expr() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return has_expr() ? 
*expr_ : Expr::default_instance(); } inline ABSL_MUST_USE_RESULT Expr& ListExprElement::mutable_expr() ABSL_ATTRIBUTE_LIFETIME_BOUND { if (!has_expr()) { expr_ = std::make_unique(); } return *expr_; } inline void ListExprElement::set_expr(Expr expr) { mutable_expr() = std::move(expr); } inline void ListExprElement::set_expr(std::unique_ptr expr) { expr_ = std::move(expr); } inline ABSL_MUST_USE_RESULT Expr ListExprElement::release_expr() { return release(expr_); } inline void swap(ListExprElement& lhs, ListExprElement& rhs) noexcept { using std::swap; swap(lhs.expr_, rhs.expr_); swap(lhs.optional_, rhs.optional_); } inline Expr ListExprElement::release(std::unique_ptr& property) { std::unique_ptr result; result.swap(property); if (result != nullptr) { return std::move(*result); } return Expr{}; } inline void ListExpr::Clear() { elements_.clear(); } inline void ListExpr::set_elements(std::vector elements) { elements_ = std::move(elements); } inline void ListExpr::set_elements(absl::Span elements) { elements_.clear(); elements_.reserve(elements.size()); for (auto& element : elements) { elements_.push_back(std::move(element)); } } inline ListExprElement& ListExpr::add_elements() ABSL_ATTRIBUTE_LIFETIME_BOUND { return mutable_elements().emplace_back(); } inline std::vector ListExpr::release_elements() { std::vector elements; elements.swap(elements_); return elements; } inline void StructExprField::Clear() { id_ = 0; name_.clear(); value_.reset(); optional_ = false; } inline ABSL_MUST_USE_RESULT const Expr& StructExprField::value() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return has_value() ? 
*value_ : Expr::default_instance(); } inline ABSL_MUST_USE_RESULT Expr& StructExprField::mutable_value() ABSL_ATTRIBUTE_LIFETIME_BOUND { if (!has_value()) { value_ = std::make_unique(); } return *value_; } inline void StructExprField::set_value(Expr value) { mutable_value() = std::move(value); } inline void StructExprField::set_value(std::unique_ptr value) { value_ = std::move(value); } inline ABSL_MUST_USE_RESULT Expr StructExprField::release_value() { return release(value_); } inline void swap(StructExprField& lhs, StructExprField& rhs) noexcept { using std::swap; swap(lhs.id_, rhs.id_); swap(lhs.name_, rhs.name_); swap(lhs.value_, rhs.value_); swap(lhs.optional_, rhs.optional_); } inline Expr StructExprField::release(std::unique_ptr& property) { std::unique_ptr result; result.swap(property); if (result != nullptr) { return std::move(*result); } return Expr{}; } inline void StructExpr::Clear() { name_.clear(); fields_.clear(); } inline void StructExpr::set_fields(std::vector fields) { fields_ = std::move(fields); } inline void StructExpr::set_fields(absl::Span fields) { fields_.clear(); fields_.reserve(fields.size()); for (auto& field : fields) { fields_.push_back(std::move(field)); } } inline StructExprField& StructExpr::add_fields() ABSL_ATTRIBUTE_LIFETIME_BOUND { return mutable_fields().emplace_back(); } inline std::vector StructExpr::release_fields() { std::vector fields; fields.swap(fields_); return fields; } inline void MapExprEntry::Clear() { id_ = 0; key_.reset(); value_.reset(); optional_ = false; } inline ABSL_MUST_USE_RESULT const Expr& MapExprEntry::key() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return has_key() ? 
*key_ : Expr::default_instance(); } inline ABSL_MUST_USE_RESULT Expr& MapExprEntry::mutable_key() ABSL_ATTRIBUTE_LIFETIME_BOUND { if (!has_key()) { key_ = std::make_unique(); } return *key_; } inline void MapExprEntry::set_key(Expr key) { mutable_key() = std::move(key); } inline void MapExprEntry::set_key(std::unique_ptr key) { key_ = std::move(key); } inline ABSL_MUST_USE_RESULT Expr MapExprEntry::release_key() { return release(key_); } inline ABSL_MUST_USE_RESULT const Expr& MapExprEntry::value() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return has_value() ? *value_ : Expr::default_instance(); } inline ABSL_MUST_USE_RESULT Expr& MapExprEntry::mutable_value() ABSL_ATTRIBUTE_LIFETIME_BOUND { if (!has_value()) { value_ = std::make_unique(); } return *value_; } inline void MapExprEntry::set_value(Expr value) { mutable_value() = std::move(value); } inline void MapExprEntry::set_value(std::unique_ptr value) { value_ = std::move(value); } inline ABSL_MUST_USE_RESULT Expr MapExprEntry::release_value() { return release(value_); } inline void swap(MapExprEntry& lhs, MapExprEntry& rhs) noexcept { using std::swap; swap(lhs.id_, rhs.id_); swap(lhs.key_, rhs.key_); swap(lhs.value_, rhs.value_); swap(lhs.optional_, rhs.optional_); } inline Expr MapExprEntry::release(std::unique_ptr& property) { std::unique_ptr result; result.swap(property); if (result != nullptr) { return std::move(*result); } return Expr{}; } inline void Expr::SetNext(common_internal::ExprEraseTag&, Expr* next) { u_.next = next; } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_EXPR_H_ ================================================ FILE: common/expr_factory.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_EXPR_FACTORY_H_ #define THIRD_PARTY_CEL_CPP_COMMON_EXPR_FACTORY_H_ #include #include #include #include #include #include #include #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "common/constant.h" #include "common/expr.h" namespace cel { class MacroExprFactory; class ParserMacroExprFactory; class ExprFactory { protected: // `IsExprLike` determines whether `T` is some `Expr`. Currently that means // either `Expr` or `std::unique_ptr`. This allows us to make the // factory functions generic and avoid redefining them for every argument // combination. template struct IsExprLike : std::bool_constant, std::is_same>>> {}; // `IsStringLike` determines whether `T` is something that looks like a // string. Currently that means `const char*`, `std::string`, or // `absl::string_view`. This allows us to make the factory functions generic // and avoid redefining them for every argument combination. This is necessary // to avoid copies if possible. template struct IsStringLike : std::bool_constant, std::is_same, std::is_same, std::is_same>> { }; template struct IsStringLike : std::true_type {}; // `IsArrayLike` determines whether `T` is something that looks like an array // or span of some element. 
template struct IsArrayLike : std::false_type {}; template struct IsArrayLike> : std::true_type {}; template struct IsArrayLike> : std::true_type {}; public: ExprFactory(const ExprFactory&) = delete; ExprFactory(ExprFactory&&) = delete; ExprFactory& operator=(const ExprFactory&) = delete; ExprFactory& operator=(ExprFactory&&) = delete; virtual ~ExprFactory() = default; Expr NewUnspecified(ExprId id) { Expr expr; expr.set_id(id); return expr; } Expr NewConst(ExprId id, Constant value) { Expr expr; expr.set_id(id); expr.mutable_const_expr() = std::move(value); return expr; } Expr NewNullConst(ExprId id) { Constant constant; constant.set_null_value(); return NewConst(id, std::move(constant)); } Expr NewBoolConst(ExprId id, bool value) { Constant constant; constant.set_bool_value(value); return NewConst(id, std::move(constant)); } Expr NewIntConst(ExprId id, int64_t value) { Constant constant; constant.set_int_value(value); return NewConst(id, std::move(constant)); } Expr NewUintConst(ExprId id, uint64_t value) { Constant constant; constant.set_uint_value(value); return NewConst(id, std::move(constant)); } Expr NewDoubleConst(ExprId id, double value) { Constant constant; constant.set_double_value(value); return NewConst(id, std::move(constant)); } Expr NewBytesConst(ExprId id, BytesConstant value) { Constant constant; constant.set_bytes_value(std::move(value)); return NewConst(id, std::move(constant)); } Expr NewBytesConst(ExprId id, std::string value) { Constant constant; constant.set_bytes_value(std::move(value)); return NewConst(id, std::move(constant)); } Expr NewBytesConst(ExprId id, absl::string_view value) { Constant constant; constant.set_bytes_value(value); return NewConst(id, std::move(constant)); } Expr NewBytesConst(ExprId id, const char* value) { Constant constant; constant.set_bytes_value(value); return NewConst(id, std::move(constant)); } Expr NewStringConst(ExprId id, StringConstant value) { Constant constant; 
constant.set_string_value(std::move(value)); return NewConst(id, std::move(constant)); } Expr NewStringConst(ExprId id, std::string value) { Constant constant; constant.set_string_value(std::move(value)); return NewConst(id, std::move(constant)); } Expr NewStringConst(ExprId id, absl::string_view value) { Constant constant; constant.set_string_value(value); return NewConst(id, std::move(constant)); } Expr NewStringConst(ExprId id, const char* value) { Constant constant; constant.set_string_value(value); return NewConst(id, std::move(constant)); } template ::value>> Expr NewIdent(ExprId id, Name name) { Expr expr; expr.set_id(id); auto& ident_expr = expr.mutable_ident_expr(); ident_expr.set_name(std::move(name)); return expr; } absl::string_view AccuVarName() { return accu_var_; } Expr NewAccuIdent(ExprId id) { return NewIdent(id, AccuVarName()); } template ::value>, typename = std::enable_if_t::value>> Expr NewSelect(ExprId id, Operand operand, Field field) { Expr expr; expr.set_id(id); auto& select_expr = expr.mutable_select_expr(); select_expr.set_operand(std::move(operand)); select_expr.set_field(std::move(field)); select_expr.set_test_only(false); return expr; } template ::value>, typename = std::enable_if_t::value>> Expr NewPresenceTest(ExprId id, Operand operand, Field field) { Expr expr; expr.set_id(id); auto& select_expr = expr.mutable_select_expr(); select_expr.set_operand(std::move(operand)); select_expr.set_field(std::move(field)); select_expr.set_test_only(true); return expr; } template ::value>, typename = std::enable_if_t::value>> Expr NewCall(ExprId id, Function function, Args args) { Expr expr; expr.set_id(id); auto& call_expr = expr.mutable_call_expr(); call_expr.set_function(std::move(function)); call_expr.set_args(std::move(args)); return expr; } template ::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>> Expr NewMemberCall(ExprId id, Function function, Target target, Args args) { Expr expr; expr.set_id(id); auto& 
call_expr = expr.mutable_call_expr(); call_expr.set_function(std::move(function)); call_expr.set_target(std::move(target)); call_expr.set_args(std::move(args)); return expr; } template ::value>> ListExprElement NewListElement(Expr expr, bool optional = false) { ListExprElement element; element.set_expr(std::move(expr)); element.set_optional(optional); return element; } template ::value>> Expr NewList(ExprId id, Elements elements) { Expr expr; expr.set_id(id); auto& list_expr = expr.mutable_list_expr(); list_expr.set_elements(std::move(elements)); return expr; } template ::value>, typename = std::enable_if_t::value>> StructExprField NewStructField(ExprId id, Name name, Value value, bool optional = false) { StructExprField field; field.set_id(id); field.set_name(std::move(name)); field.set_value(std::move(value)); field.set_optional(optional); return field; } template < typename Name, typename Fields, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>> Expr NewStruct(ExprId id, Name name, Fields fields) { Expr expr; expr.set_id(id); auto& struct_expr = expr.mutable_struct_expr(); struct_expr.set_name(std::move(name)); struct_expr.set_fields(std::move(fields)); return expr; } template ::value>, typename = std::enable_if_t::value>> MapExprEntry NewMapEntry(ExprId id, Key key, Value value, bool optional = false) { MapExprEntry entry; entry.set_id(id); entry.set_key(std::move(key)); entry.set_value(std::move(value)); entry.set_optional(optional); return entry; } template ::value>> Expr NewMap(ExprId id, Entries entries) { Expr expr; expr.set_id(id); auto& map_expr = expr.mutable_map_expr(); map_expr.set_entries(std::move(entries)); return expr; } template ::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>> Expr NewComprehension(ExprId id, IterVar iter_var, IterRange 
iter_range, AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, LoopStep loop_step, Result result) { return NewComprehension(id, std::move(iter_var), "", std::move(iter_range), std::move(accu_var), std::move(accu_init), std::move(loop_condition), std::move(loop_step), std::move(result)); } template ::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>> Expr NewComprehension(ExprId id, IterVar iter_var, IterVar2 iter_var2, IterRange iter_range, AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, LoopStep loop_step, Result result) { Expr expr; expr.set_id(id); auto& comprehension_expr = expr.mutable_comprehension_expr(); comprehension_expr.set_iter_var(std::move(iter_var)); comprehension_expr.set_iter_var2(std::move(iter_var2)); comprehension_expr.set_iter_range(std::move(iter_range)); comprehension_expr.set_accu_var(std::move(accu_var)); comprehension_expr.set_accu_init(std::move(accu_init)); comprehension_expr.set_loop_condition(std::move(loop_condition)); comprehension_expr.set_loop_step(std::move(loop_step)); comprehension_expr.set_result(std::move(result)); return expr; } private: friend class MacroExprFactory; friend class ParserMacroExprFactory; ExprFactory() : accu_var_(kAccumulatorVariableName) {} std::string accu_var_; }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_EXPR_FACTORY_H_ ================================================ FILE: common/expr_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/expr.h" #include #include "internal/testing.h" namespace cel { namespace { using ::testing::_; using ::testing::Eq; using ::testing::IsEmpty; using ::testing::IsFalse; using ::testing::IsTrue; using ::testing::SizeIs; using ::testing::VariantWith; Expr MakeUnspecifiedExpr(ExprId id) { Expr expr; expr.set_id(id); return expr; } ListExprElement MakeListExprElement(Expr expr, bool optional = false) { ListExprElement element; element.set_expr(std::move(expr)); element.set_optional(optional); return element; } StructExprField MakeStructExprField(ExprId id, const char* name, Expr value, bool optional = false) { StructExprField field; field.set_id(id); field.set_name(name); field.set_value(std::move(value)); field.set_optional(optional); return field; } MapExprEntry MakeMapExprEntry(ExprId id, Expr key, Expr value, bool optional = false) { MapExprEntry entry; entry.set_id(id); entry.set_key(std::move(key)); entry.set_value(std::move(value)); entry.set_optional(optional); return entry; } TEST(UnspecifiedExpr, Equality) { EXPECT_EQ(UnspecifiedExpr{}, UnspecifiedExpr{}); } TEST(IdentExpr, Name) { IdentExpr ident_expr; EXPECT_THAT(ident_expr.name(), IsEmpty()); ident_expr.set_name("foo"); EXPECT_THAT(ident_expr.name(), Eq("foo")); auto name = ident_expr.release_name(); EXPECT_THAT(name, Eq("foo")); EXPECT_THAT(ident_expr.name(), IsEmpty()); } TEST(IdentExpr, Equality) { EXPECT_EQ(IdentExpr{}, IdentExpr{}); IdentExpr ident_expr; ident_expr.set_name("foo"); EXPECT_NE(IdentExpr{}, ident_expr); } TEST(SelectExpr, Operand) { SelectExpr select_expr; 
EXPECT_THAT(select_expr.has_operand(), IsFalse()); EXPECT_EQ(select_expr.operand(), Expr{}); select_expr.set_operand(MakeUnspecifiedExpr(1)); EXPECT_THAT(select_expr.has_operand(), IsTrue()); EXPECT_EQ(select_expr.operand(), MakeUnspecifiedExpr(1)); auto operand = select_expr.release_operand(); EXPECT_THAT(select_expr.has_operand(), IsFalse()); EXPECT_EQ(select_expr.operand(), Expr{}); } TEST(SelectExpr, Field) { SelectExpr select_expr; EXPECT_THAT(select_expr.field(), IsEmpty()); select_expr.set_field("foo"); EXPECT_THAT(select_expr.field(), Eq("foo")); auto field = select_expr.release_field(); EXPECT_THAT(field, Eq("foo")); EXPECT_THAT(select_expr.field(), IsEmpty()); } TEST(SelectExpr, TestOnly) { SelectExpr select_expr; EXPECT_THAT(select_expr.test_only(), IsFalse()); select_expr.set_test_only(true); EXPECT_THAT(select_expr.test_only(), IsTrue()); } TEST(SelectExpr, Equality) { EXPECT_EQ(SelectExpr{}, SelectExpr{}); SelectExpr select_expr; select_expr.set_test_only(true); EXPECT_NE(SelectExpr{}, select_expr); } TEST(CallExpr, Function) { CallExpr call_expr; EXPECT_THAT(call_expr.function(), IsEmpty()); call_expr.set_function("foo"); EXPECT_THAT(call_expr.function(), Eq("foo")); auto function = call_expr.release_function(); EXPECT_THAT(function, Eq("foo")); EXPECT_THAT(call_expr.function(), IsEmpty()); } TEST(CallExpr, Target) { CallExpr call_expr; EXPECT_THAT(call_expr.has_target(), IsFalse()); EXPECT_EQ(call_expr.target(), Expr{}); call_expr.set_target(MakeUnspecifiedExpr(1)); EXPECT_THAT(call_expr.has_target(), IsTrue()); EXPECT_EQ(call_expr.target(), MakeUnspecifiedExpr(1)); auto operand = call_expr.release_target(); EXPECT_THAT(call_expr.has_target(), IsFalse()); EXPECT_EQ(call_expr.target(), Expr{}); } TEST(CallExpr, Args) { CallExpr call_expr; EXPECT_THAT(call_expr.args(), IsEmpty()); call_expr.mutable_args().push_back(MakeUnspecifiedExpr(1)); ASSERT_THAT(call_expr.args(), SizeIs(1)); EXPECT_EQ(call_expr.args()[0], MakeUnspecifiedExpr(1)); auto args = 
call_expr.release_args(); static_cast(args); EXPECT_THAT(call_expr.args(), IsEmpty()); } TEST(CallExpr, Equality) { EXPECT_EQ(CallExpr{}, CallExpr{}); CallExpr call_expr; call_expr.mutable_args().push_back(MakeUnspecifiedExpr(1)); EXPECT_NE(CallExpr{}, call_expr); } TEST(ListExprElement, Expr) { ListExprElement element; EXPECT_THAT(element.has_expr(), IsFalse()); EXPECT_EQ(element.expr(), Expr{}); element.set_expr(MakeUnspecifiedExpr(1)); EXPECT_THAT(element.has_expr(), IsTrue()); EXPECT_EQ(element.expr(), MakeUnspecifiedExpr(1)); auto operand = element.release_expr(); EXPECT_THAT(element.has_expr(), IsFalse()); EXPECT_EQ(element.expr(), Expr{}); } TEST(ListExprElement, Optional) { ListExprElement element; EXPECT_THAT(element.optional(), IsFalse()); element.set_optional(true); EXPECT_THAT(element.optional(), IsTrue()); } TEST(ListExprElement, Equality) { EXPECT_EQ(ListExprElement{}, ListExprElement{}); ListExprElement element; element.set_optional(true); EXPECT_NE(ListExprElement{}, element); } TEST(ListExpr, Elements) { ListExpr list_expr; EXPECT_THAT(list_expr.elements(), IsEmpty()); list_expr.mutable_elements().push_back( MakeListExprElement(MakeUnspecifiedExpr(1))); ASSERT_THAT(list_expr.elements(), SizeIs(1)); EXPECT_EQ(list_expr.elements()[0], MakeListExprElement(MakeUnspecifiedExpr(1))); auto elements = list_expr.release_elements(); static_cast(elements); EXPECT_THAT(list_expr.elements(), IsEmpty()); } TEST(ListExpr, Equality) { EXPECT_EQ(ListExpr{}, ListExpr{}); ListExpr list_expr; list_expr.mutable_elements().push_back( MakeListExprElement(MakeUnspecifiedExpr(0), true)); EXPECT_NE(ListExpr{}, list_expr); } TEST(StructExprField, Id) { StructExprField field; EXPECT_THAT(field.id(), Eq(0)); field.set_id(1); EXPECT_THAT(field.id(), Eq(1)); } TEST(StructExprField, Name) { StructExprField field; EXPECT_THAT(field.name(), IsEmpty()); field.set_name("foo"); EXPECT_THAT(field.name(), Eq("foo")); auto name = field.release_name(); EXPECT_THAT(name, Eq("foo")); 
EXPECT_THAT(field.name(), IsEmpty()); } TEST(StructExprField, Value) { StructExprField field; EXPECT_THAT(field.has_value(), IsFalse()); EXPECT_EQ(field.value(), Expr{}); field.set_value(MakeUnspecifiedExpr(1)); EXPECT_THAT(field.has_value(), IsTrue()); EXPECT_EQ(field.value(), MakeUnspecifiedExpr(1)); auto value = field.release_value(); EXPECT_THAT(field.has_value(), IsFalse()); EXPECT_EQ(field.value(), Expr{}); } TEST(StructExprField, Optional) { StructExprField field; EXPECT_THAT(field.optional(), IsFalse()); field.set_optional(true); EXPECT_THAT(field.optional(), IsTrue()); } TEST(StructExprField, Equality) { EXPECT_EQ(StructExprField{}, StructExprField{}); StructExprField field; field.set_optional(true); EXPECT_NE(StructExprField{}, field); } TEST(StructExpr, Name) { StructExpr struct_expr; EXPECT_THAT(struct_expr.name(), IsEmpty()); struct_expr.set_name("foo"); EXPECT_THAT(struct_expr.name(), Eq("foo")); auto name = struct_expr.release_name(); EXPECT_THAT(name, Eq("foo")); EXPECT_THAT(struct_expr.name(), IsEmpty()); } TEST(StructExpr, Fields) { StructExpr struct_expr; EXPECT_THAT(struct_expr.fields(), IsEmpty()); struct_expr.mutable_fields().push_back( MakeStructExprField(1, "foo", MakeUnspecifiedExpr(1))); ASSERT_THAT(struct_expr.fields(), SizeIs(1)); EXPECT_EQ(struct_expr.fields()[0], MakeStructExprField(1, "foo", MakeUnspecifiedExpr(1))); auto fields = struct_expr.release_fields(); static_cast(fields); EXPECT_THAT(struct_expr.fields(), IsEmpty()); } TEST(StructExpr, Equality) { EXPECT_EQ(StructExpr{}, StructExpr{}); StructExpr struct_expr; struct_expr.mutable_fields().push_back( MakeStructExprField(0, "", MakeUnspecifiedExpr(0), true)); EXPECT_NE(StructExpr{}, struct_expr); } TEST(MapExprEntry, Id) { MapExprEntry entry; EXPECT_THAT(entry.id(), Eq(0)); entry.set_id(1); EXPECT_THAT(entry.id(), Eq(1)); } TEST(MapExprEntry, Key) { MapExprEntry entry; EXPECT_THAT(entry.has_key(), IsFalse()); EXPECT_EQ(entry.key(), Expr{}); entry.set_key(MakeUnspecifiedExpr(1)); 
EXPECT_THAT(entry.has_key(), IsTrue()); EXPECT_EQ(entry.key(), MakeUnspecifiedExpr(1)); auto key = entry.release_key(); static_cast(key); EXPECT_THAT(entry.has_key(), IsFalse()); EXPECT_EQ(entry.key(), Expr{}); } TEST(MapExprEntry, Value) { MapExprEntry entry; EXPECT_THAT(entry.has_value(), IsFalse()); EXPECT_EQ(entry.value(), Expr{}); entry.set_value(MakeUnspecifiedExpr(1)); EXPECT_THAT(entry.has_value(), IsTrue()); EXPECT_EQ(entry.value(), MakeUnspecifiedExpr(1)); auto value = entry.release_value(); static_cast(value); EXPECT_THAT(entry.has_value(), IsFalse()); EXPECT_EQ(entry.value(), Expr{}); } TEST(MapExprEntry, Optional) { MapExprEntry entry; EXPECT_THAT(entry.optional(), IsFalse()); entry.set_optional(true); EXPECT_THAT(entry.optional(), IsTrue()); } TEST(MapExprEntry, Equality) { EXPECT_EQ(StructExprField{}, StructExprField{}); StructExprField field; field.set_optional(true); EXPECT_NE(StructExprField{}, field); } TEST(MapExpr, Entries) { MapExpr map_expr; EXPECT_THAT(map_expr.entries(), IsEmpty()); map_expr.mutable_entries().push_back( MakeMapExprEntry(1, MakeUnspecifiedExpr(1), MakeUnspecifiedExpr(1))); ASSERT_THAT(map_expr.entries(), SizeIs(1)); EXPECT_EQ(map_expr.entries()[0], MakeMapExprEntry(1, MakeUnspecifiedExpr(1), MakeUnspecifiedExpr(1))); auto entries = map_expr.release_entries(); static_cast(entries); EXPECT_THAT(map_expr.entries(), IsEmpty()); } TEST(MapExpr, Equality) { EXPECT_EQ(MapExpr{}, MapExpr{}); MapExpr map_expr; map_expr.mutable_entries().push_back(MakeMapExprEntry( 0, MakeUnspecifiedExpr(0), MakeUnspecifiedExpr(0), true)); EXPECT_NE(MapExpr{}, map_expr); } TEST(ComprehensionExpr, IterVar) { ComprehensionExpr comprehension_expr; EXPECT_THAT(comprehension_expr.iter_var(), IsEmpty()); comprehension_expr.set_iter_var("foo"); EXPECT_THAT(comprehension_expr.iter_var(), Eq("foo")); auto iter_var = comprehension_expr.release_iter_var(); EXPECT_THAT(iter_var, Eq("foo")); EXPECT_THAT(comprehension_expr.iter_var(), IsEmpty()); } 
TEST(ComprehensionExpr, IterRange) { ComprehensionExpr comprehension_expr; EXPECT_THAT(comprehension_expr.has_iter_range(), IsFalse()); EXPECT_EQ(comprehension_expr.iter_range(), Expr{}); comprehension_expr.set_iter_range(MakeUnspecifiedExpr(1)); EXPECT_THAT(comprehension_expr.has_iter_range(), IsTrue()); EXPECT_EQ(comprehension_expr.iter_range(), MakeUnspecifiedExpr(1)); auto operand = comprehension_expr.release_iter_range(); EXPECT_THAT(comprehension_expr.has_iter_range(), IsFalse()); EXPECT_EQ(comprehension_expr.iter_range(), Expr{}); } TEST(ComprehensionExpr, AccuVar) { ComprehensionExpr comprehension_expr; EXPECT_THAT(comprehension_expr.accu_var(), IsEmpty()); comprehension_expr.set_accu_var("foo"); EXPECT_THAT(comprehension_expr.accu_var(), Eq("foo")); auto accu_var = comprehension_expr.release_accu_var(); EXPECT_THAT(accu_var, Eq("foo")); EXPECT_THAT(comprehension_expr.accu_var(), IsEmpty()); } TEST(ComprehensionExpr, AccuInit) { ComprehensionExpr comprehension_expr; EXPECT_THAT(comprehension_expr.has_accu_init(), IsFalse()); EXPECT_EQ(comprehension_expr.accu_init(), Expr{}); comprehension_expr.set_accu_init(MakeUnspecifiedExpr(1)); EXPECT_THAT(comprehension_expr.has_accu_init(), IsTrue()); EXPECT_EQ(comprehension_expr.accu_init(), MakeUnspecifiedExpr(1)); auto operand = comprehension_expr.release_accu_init(); EXPECT_THAT(comprehension_expr.has_accu_init(), IsFalse()); EXPECT_EQ(comprehension_expr.accu_init(), Expr{}); } TEST(ComprehensionExpr, LoopCondition) { ComprehensionExpr comprehension_expr; EXPECT_THAT(comprehension_expr.has_loop_condition(), IsFalse()); EXPECT_EQ(comprehension_expr.loop_condition(), Expr{}); comprehension_expr.set_loop_condition(MakeUnspecifiedExpr(1)); EXPECT_THAT(comprehension_expr.has_loop_condition(), IsTrue()); EXPECT_EQ(comprehension_expr.loop_condition(), MakeUnspecifiedExpr(1)); auto operand = comprehension_expr.release_loop_condition(); EXPECT_THAT(comprehension_expr.has_loop_condition(), IsFalse()); 
EXPECT_EQ(comprehension_expr.loop_condition(), Expr{}); } TEST(ComprehensionExpr, LoopStep) { ComprehensionExpr comprehension_expr; EXPECT_THAT(comprehension_expr.has_loop_step(), IsFalse()); EXPECT_EQ(comprehension_expr.loop_step(), Expr{}); comprehension_expr.set_loop_step(MakeUnspecifiedExpr(1)); EXPECT_THAT(comprehension_expr.has_loop_step(), IsTrue()); EXPECT_EQ(comprehension_expr.loop_step(), MakeUnspecifiedExpr(1)); auto operand = comprehension_expr.release_loop_step(); EXPECT_THAT(comprehension_expr.has_loop_step(), IsFalse()); EXPECT_EQ(comprehension_expr.loop_step(), Expr{}); } TEST(ComprehensionExpr, Result) { ComprehensionExpr comprehension_expr; EXPECT_THAT(comprehension_expr.has_result(), IsFalse()); EXPECT_EQ(comprehension_expr.result(), Expr{}); comprehension_expr.set_result(MakeUnspecifiedExpr(1)); EXPECT_THAT(comprehension_expr.has_result(), IsTrue()); EXPECT_EQ(comprehension_expr.result(), MakeUnspecifiedExpr(1)); auto operand = comprehension_expr.release_result(); EXPECT_THAT(comprehension_expr.has_result(), IsFalse()); EXPECT_EQ(comprehension_expr.result(), Expr{}); } TEST(ComprehensionExpr, Equality) { EXPECT_EQ(ComprehensionExpr{}, ComprehensionExpr{}); ComprehensionExpr comprehension_expr; comprehension_expr.set_result(MakeUnspecifiedExpr(1)); EXPECT_NE(ComprehensionExpr{}, comprehension_expr); } TEST(Expr, Unspecified) { Expr expr; EXPECT_THAT(expr.id(), Eq(ExprId{0})); EXPECT_THAT(expr.kind(), VariantWith(_)); EXPECT_EQ(expr.kind_case(), ExprKindCase::kUnspecifiedExpr); EXPECT_EQ(expr, Expr{}); } TEST(Expr, Ident) { Expr expr; EXPECT_THAT(expr.has_ident_expr(), IsFalse()); EXPECT_EQ(expr.ident_expr(), IdentExpr{}); auto& ident_expr = expr.mutable_ident_expr(); EXPECT_THAT(expr.has_ident_expr(), IsTrue()); EXPECT_NE(expr, Expr{}); ident_expr.set_name("foo"); EXPECT_NE(expr.ident_expr(), IdentExpr{}); EXPECT_EQ(expr.kind_case(), ExprKindCase::kIdentExpr); static_cast(expr.release_ident_expr()); EXPECT_THAT(expr.has_ident_expr(), IsFalse()); 
EXPECT_EQ(expr.ident_expr(), IdentExpr{}); EXPECT_EQ(expr, Expr{}); } TEST(Expr, Select) { Expr expr; EXPECT_THAT(expr.has_select_expr(), IsFalse()); EXPECT_EQ(expr.select_expr(), SelectExpr{}); auto& select_expr = expr.mutable_select_expr(); EXPECT_THAT(expr.has_select_expr(), IsTrue()); EXPECT_NE(expr, Expr{}); select_expr.set_field("foo"); EXPECT_NE(expr.select_expr(), SelectExpr{}); EXPECT_EQ(expr.kind_case(), ExprKindCase::kSelectExpr); static_cast(expr.release_select_expr()); EXPECT_THAT(expr.has_select_expr(), IsFalse()); EXPECT_EQ(expr.select_expr(), SelectExpr{}); EXPECT_EQ(expr, Expr{}); } TEST(Expr, Call) { Expr expr; EXPECT_THAT(expr.has_call_expr(), IsFalse()); EXPECT_EQ(expr.call_expr(), CallExpr{}); auto& call_expr = expr.mutable_call_expr(); EXPECT_THAT(expr.has_call_expr(), IsTrue()); EXPECT_NE(expr, Expr{}); call_expr.set_function("foo"); EXPECT_NE(expr.call_expr(), CallExpr{}); EXPECT_EQ(expr.kind_case(), ExprKindCase::kCallExpr); static_cast(expr.release_call_expr()); EXPECT_THAT(expr.has_call_expr(), IsFalse()); EXPECT_EQ(expr.call_expr(), CallExpr{}); EXPECT_EQ(expr, Expr{}); } TEST(Expr, List) { Expr expr; EXPECT_THAT(expr.has_list_expr(), IsFalse()); EXPECT_EQ(expr.list_expr(), ListExpr{}); auto& list_expr = expr.mutable_list_expr(); EXPECT_THAT(expr.has_list_expr(), IsTrue()); EXPECT_NE(expr, Expr{}); list_expr.mutable_elements().push_back(MakeListExprElement(Expr{}, true)); EXPECT_NE(expr.list_expr(), ListExpr{}); EXPECT_EQ(expr.kind_case(), ExprKindCase::kListExpr); static_cast(expr.release_list_expr()); EXPECT_THAT(expr.has_list_expr(), IsFalse()); EXPECT_EQ(expr.list_expr(), ListExpr{}); EXPECT_EQ(expr, Expr{}); } TEST(Expr, Struct) { Expr expr; EXPECT_THAT(expr.has_struct_expr(), IsFalse()); EXPECT_EQ(expr.struct_expr(), StructExpr{}); auto& struct_expr = expr.mutable_struct_expr(); EXPECT_THAT(expr.has_struct_expr(), IsTrue()); EXPECT_NE(expr, Expr{}); struct_expr.set_name("foo"); EXPECT_NE(expr.struct_expr(), StructExpr{}); 
EXPECT_EQ(expr.kind_case(), ExprKindCase::kStructExpr); static_cast(expr.release_struct_expr()); EXPECT_THAT(expr.has_struct_expr(), IsFalse()); EXPECT_EQ(expr.struct_expr(), StructExpr{}); EXPECT_EQ(expr, Expr{}); } TEST(Expr, Map) { Expr expr; EXPECT_THAT(expr.has_map_expr(), IsFalse()); EXPECT_EQ(expr.map_expr(), MapExpr{}); auto& map_expr = expr.mutable_map_expr(); EXPECT_THAT(expr.has_map_expr(), IsTrue()); EXPECT_NE(expr, Expr{}); map_expr.mutable_entries().push_back(MakeMapExprEntry(1, Expr{}, Expr{})); EXPECT_NE(expr.map_expr(), MapExpr{}); EXPECT_EQ(expr.kind_case(), ExprKindCase::kMapExpr); static_cast(expr.release_map_expr()); EXPECT_THAT(expr.has_map_expr(), IsFalse()); EXPECT_EQ(expr.map_expr(), MapExpr{}); EXPECT_EQ(expr, Expr{}); } TEST(Expr, Comprehension) { Expr expr; EXPECT_THAT(expr.has_comprehension_expr(), IsFalse()); EXPECT_EQ(expr.comprehension_expr(), ComprehensionExpr{}); auto& comprehension_expr = expr.mutable_comprehension_expr(); EXPECT_THAT(expr.has_comprehension_expr(), IsTrue()); EXPECT_NE(expr, Expr{}); comprehension_expr.set_iter_var("foo"); EXPECT_NE(expr.comprehension_expr(), ComprehensionExpr{}); EXPECT_EQ(expr.kind_case(), ExprKindCase::kComprehensionExpr); static_cast(expr.release_comprehension_expr()); EXPECT_THAT(expr.has_comprehension_expr(), IsFalse()); EXPECT_EQ(expr.comprehension_expr(), ComprehensionExpr{}); EXPECT_EQ(expr, Expr{}); } TEST(Expr, CopyCtor) { Expr expr; expr.mutable_select_expr().set_field("foo"); Expr& operand = expr.mutable_select_expr().mutable_operand(); operand.mutable_ident_expr().set_name("bar"); Expr expr_copy = expr; EXPECT_EQ(expr_copy.select_expr().field(), "foo"); EXPECT_EQ(expr_copy.select_expr().operand().ident_expr().name(), "bar"); } TEST(Expr, CopyAssignChildReference) { Expr expr; expr.mutable_select_expr().set_field("foo"); Expr& operand = expr.mutable_select_expr().mutable_operand(); operand.mutable_call_expr().set_function("bar"); auto& args = 
operand.mutable_call_expr().mutable_args(); args.emplace_back().mutable_ident_expr().set_name("baz"); args.emplace_back().mutable_ident_expr().set_name("qux"); expr = expr.mutable_select_expr().mutable_operand(); EXPECT_EQ(expr.call_expr().function(), "bar"); EXPECT_EQ(expr.call_expr().args().size(), 2); EXPECT_EQ(expr.call_expr().args()[0].ident_expr().name(), "baz"); EXPECT_EQ(expr.call_expr().args()[1].ident_expr().name(), "qux"); } TEST(Expr, FlattenedErase) { Expr expr; auto& list_expr = expr.mutable_list_expr(); list_expr.mutable_elements() .emplace_back() .mutable_expr() .mutable_ident_expr() .set_name("foo"); list_expr.mutable_elements() .emplace_back() .mutable_expr() .mutable_select_expr() .mutable_operand() .mutable_ident_expr() .set_name("foo"); auto& call_expr = list_expr.mutable_elements() .emplace_back() .mutable_expr() .mutable_call_expr(); call_expr.set_function("foo"); call_expr.mutable_target().mutable_ident_expr().set_name("bar"); call_expr.mutable_args().emplace_back().mutable_ident_expr().set_name("baz"); auto& struct_expr = list_expr.mutable_elements() .emplace_back() .mutable_expr() .mutable_struct_expr(); struct_expr.set_name("foo"); auto& field = struct_expr.mutable_fields().emplace_back(); field.set_name("bar"); field.mutable_value().mutable_ident_expr().set_name("baz"); auto& map_expr = list_expr.mutable_elements() .emplace_back() .mutable_expr() .mutable_map_expr(); auto& map_entry = map_expr.mutable_entries().emplace_back(); map_entry.mutable_key().mutable_const_expr().set_string_value("foo"); map_entry.mutable_value().mutable_ident_expr().set_name("bar"); auto& comprehension_expr = list_expr.mutable_elements() .emplace_back() .mutable_expr() .mutable_comprehension_expr(); comprehension_expr.set_iter_var("foo"); comprehension_expr.set_accu_var("bar"); comprehension_expr.set_iter_range(Expr{}); comprehension_expr.set_accu_init(Expr{}); comprehension_expr.set_loop_condition(Expr{}); comprehension_expr.set_loop_step(Expr{}); 
comprehension_expr.set_result(Expr{}); expr.FlattenedErase(); EXPECT_EQ(expr.kind_case(), ExprKindCase::kUnspecifiedExpr); } Expr MakeNestedList(int size) { Expr e; Expr* node = &e; e.set_id(1); for (int i = 0; i < size; ++i) { node = &node->mutable_list_expr() .mutable_elements() .emplace_back() .mutable_expr(); node->set_id(i + 2); } return e; } TEST(Expr, FlattenedErase256k) { // Large expr to ensure we're not recursing. Would likely hit stack limits // with default destructor. constexpr int size = 256 * 1024; Expr expr = MakeNestedList(size); expr.FlattenedErase(); EXPECT_EQ(expr.kind_case(), ExprKindCase::kUnspecifiedExpr); } TEST(Expr, Id) { Expr expr; EXPECT_THAT(expr.id(), Eq(0)); expr.set_id(1); EXPECT_THAT(expr.id(), Eq(1)); } } // namespace } // namespace cel ================================================ FILE: common/function_descriptor.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/function_descriptor.h"

#include <algorithm>
#include <cstddef>

#include "absl/base/macros.h"
#include "absl/types/span.h"
#include "common/kind.h"

namespace cel {

// Returns true when this descriptor has the same call style and arity as the
// given shape and every argument position is compatible. `Kind::kAny` on
// either side acts as a wildcard for its position.
bool FunctionDescriptor::ShapeMatches(bool receiver_style,
                                      absl::Span<const Kind> types) const {
  if (this->receiver_style() != receiver_style) {
    return false;
  }
  // The four-iterator std::equal also rejects ranges of different length.
  return std::equal(this->types().begin(), this->types().end(), types.begin(),
                    types.end(), [](Kind lhs, Kind rhs) {
                      return lhs == Kind::kAny || rhs == Kind::kAny ||
                             lhs == rhs;
                    });
}

// Descriptors sharing the same underlying Impl are trivially equal; otherwise
// name, call style and argument kinds must all match.
bool FunctionDescriptor::operator==(const FunctionDescriptor& other) const {
  return impl_.get() == other.impl_.get() ||
         (name() == other.name() &&
          receiver_style() == other.receiver_style() &&
          types() == other.types());
}

// Strict weak ordering by (name, receiver_style, types), with the argument
// kinds compared lexicographically (a proper prefix orders first).
bool FunctionDescriptor::operator<(const FunctionDescriptor& other) const {
  if (impl_.get() == other.impl_.get()) {
    return false;
  }
  // A single three-way comparison instead of separate `<` and `!=` passes.
  if (const int name_order = name().compare(other.name()); name_order != 0) {
    return name_order < 0;
  }
  if (receiver_style() != other.receiver_style()) {
    return receiver_style() < other.receiver_style();
  }
  return std::lexicographical_compare(types().begin(), types().end(),
                                      other.types().begin(),
                                      other.types().end());
}

}  // namespace cel

================================================
FILE: common/function_descriptor.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_FUNCTION_DESCRIPTOR_H_
#define THIRD_PARTY_CEL_CPP_COMMON_FUNCTION_DESCRIPTOR_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "common/kind.h"

namespace cel {

struct FunctionDescriptorOptions {
  // If true (strict, default), error or unknown arguments are propagated
  // instead of calling the function. If false (non-strict), the function may
  // receive error or unknown values as arguments.
  bool is_strict = true;

  // Whether the function is impure or context-sensitive.
  //
  // Impure functions depend on state other than the arguments received during
  // the CEL expression evaluation or have visible side effects. This breaks
  // some of the assumptions of the CEL evaluation model. This flag is used as a
  // hint to the planner that some optimizations are not safe or not effective.
  bool is_contextual = false;
};

// Coarsely describes a function for the purpose of runtime resolution of
// overloads.
class FunctionDescriptor final { public: FunctionDescriptor(absl::string_view name, bool receiver_style, std::vector types, bool is_strict) : impl_(std::make_shared( name, std::move(types), receiver_style, FunctionDescriptorOptions{is_strict, /*is_contextual=*/false})) {} FunctionDescriptor(absl::string_view name, bool receiver_style, std::vector types, bool is_strict, bool is_contextual) : impl_(std::make_shared( name, std::move(types), receiver_style, FunctionDescriptorOptions{is_strict, is_contextual})) {} FunctionDescriptor(absl::string_view name, bool is_receiver_style, std::vector types, FunctionDescriptorOptions options = {}) : impl_(std::make_shared(name, std::move(types), is_receiver_style, options)) {} // Function name. const std::string& name() const { return impl_->name; } // Whether function is receiver style i.e. true means arg0.name(args[1:]...). bool receiver_style() const { return impl_->is_receiver_style; } // The argument types the function accepts. // // TODO(uncreated-issue/17): make this kinds const std::vector& types() const { return impl_->types; } // if true (strict, default), error or unknown arguments are propagated // instead of calling the function. if false (non-strict), the function may // receive error or unknown values as arguments. bool is_strict() const { return impl_->options.is_strict; } // Whether the function is contextual (impure). // // Contextual functions depend on state other than the arguments received in // the CEL expression evaluation or have visible side effects. This breaks // some of the assumptions of CEL. This flag is used as a hint to the planner // that some optimizations are not safe or not effective. bool is_contextual() const { return impl_->options.is_contextual; } // Helper for matching a descriptor. This tests that the shape is the same -- // |other| accepts the same number and types of arguments and is the same call // style). 
bool ShapeMatches(const FunctionDescriptor& other) const { return ShapeMatches(other.receiver_style(), other.types()); } bool ShapeMatches(bool receiver_style, absl::Span types) const; bool operator==(const FunctionDescriptor& other) const; bool operator<(const FunctionDescriptor& other) const; private: struct Impl final { Impl(absl::string_view name, std::vector types, bool is_receiver_style, FunctionDescriptorOptions options) : name(name), types(std::move(types)), is_receiver_style(is_receiver_style), options(options) {} std::string name; std::vector types; bool is_receiver_style; FunctionDescriptorOptions options; }; std::shared_ptr impl_; }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_FUNCTION_DESCRIPTOR_H_ ================================================ FILE: common/internal/BUILD ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

# Every target in this package is visible to the whole workspace.
package(default_visibility = ["//visibility:public"])

# Header-only target exporting casting.h (no srcs).
cc_library(
    name = "casting",
    hdrs = ["casting.h"],
    deps = [
        "//common:native_type",
        "//internal:casts",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/meta:type_traits",
        "@com_google_absl//absl/types:optional",
    ],
)

# Reference-count support used by :byte_string (StrongRef/StrongUnref and
# MakeReferenceCountedString are called from byte_string.cc).
cc_library(
    name = "reference_count",
    srcs = ["reference_count.cc"],
    hdrs = ["reference_count.h"],
    deps = [
        "//common:data",
        "//internal:new",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "reference_count_test",
    srcs = ["reference_count_test.cc"],
    deps = [
        ":reference_count",
        "//common:data",
        "//internal:testing",
        "@com_google_absl//absl/base:nullability",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
    ],
)

# Header-only target exporting metadata.h (included by byte_string.cc).
cc_library(
    name = "metadata",
    hdrs = ["metadata.h"],
    deps = ["@com_google_protobuf//:protobuf"],
)

# ByteString: the small/medium/large string representation implemented in
# byte_string.h / byte_string.cc.
cc_library(
    name = "byte_string",
    srcs = ["byte_string.cc"],
    hdrs = ["byte_string.h"],
    deps = [
        ":metadata",
        ":reference_count",
        "//common:allocator",
        "//common:arena",
        "//common:memory",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:optional",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "byte_string_test",
    srcs = ["byte_string_test.cc"],
    deps = [
        ":byte_string",
        ":reference_count",
        "//common:allocator",
        "//common:memory",
        "//internal:testing",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/hash",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:cord_test_helpers",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:optional",
        "@com_google_protobuf//:protobuf",
    ],
)

# Depends on both the cel.expr and the google.api.expr.v1alpha1 protos;
# presumably converts between them and CEL values (inferred from the target
# name and deps — confirm against value_conversion.h).
cc_library(
    name = "value_conversion",
    srcs = ["value_conversion.cc"],
    hdrs = ["value_conversion.h"],
    deps = [
        "//common:any",
        "//common:value",
        "//common:value_kind",
        "//extensions/protobuf:value",
        "//internal:proto_time_encoding",
        "//internal:status_macros",
        "//internal:time",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/time",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_cel_spec//proto/cel/expr:value_cc_proto",
        "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto",
        "@com_google_googleapis//google/api/expr/v1alpha1:value_cc_proto",
        "@com_google_protobuf//:any_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
        "@com_google_protobuf//:timestamp_cc_proto",
        "@com_google_protobuf//src/google/protobuf/io",
    ],
)

cc_library(
    name = "signature",
    srcs = ["signature.cc"],
    hdrs = ["signature.h"],
    deps = [
        "//common:type",
        "//common:type_kind",
        "//internal:status_macros",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
    ],
)

cc_test(
    name = "signature_test",
    srcs = ["signature_test.cc"],
    deps = [
        ":signature",
        "//common:type",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_protobuf//:protobuf",
    ],
)
================================================
FILE: common/internal/byte_string.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): in this extract, angle-bracketed text (system header names and
// template argument lists, e.g. of `reinterpret_cast`, `static_cast`,
// `absl::optional`, `google::protobuf::Arena::Create`) appears to have been
// stripped. The code is reproduced as-is; restore from upstream before
// compiling.

#include "common/internal/byte_string.h"

#include
#include
#include
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/functional/overload.h"
#include "absl/hash/hash.h"
#include "absl/log/absl_check.h"
#include "absl/strings/cord.h"
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/allocator.h"
#include "common/internal/metadata.h"
#include "common/internal/reference_count.h"
#include "common/memory.h"
#include "google/protobuf/arena.h"

namespace cel::common_internal {

namespace {

// Copies every chunk of `cord` into `data` back to back and returns a pointer
// one past the last byte written. `data` must have room for `cord.size()`
// bytes; no bounds checking is done here.
char* CopyCordToArray(const absl::Cord& cord, char* data) {
  for (auto chunk : cord.Chunks()) {
    std::memcpy(data, chunk.data(), chunk.size());
    data += chunk.size();
  }
  return data;
}

// Moves `object` into a local, explicitly runs the destructor on the original
// storage, and returns the moved-out value. Afterwards `object` no longer
// refers to a live object. NOTE(review): no caller is visible in this file.
template T ConsumeAndDestroy(T& object) {
  T consumed = std::move(object);
  object.~T();  // NOLINT(bugprone-use-after-move)
  return consumed;
}

}  // namespace

// Returns `lhs + rhs`.
//  - If either operand is empty, a copy of the other is returned unchanged (it
//    is NOT re-homed onto `arena`).
//  - If either operand is kLarge, the result is Cord-backed.
//  - Otherwise the bytes are copied either into inline (small) storage tagged
//    with `arena`, or into a single `arena` allocation owned by the arena.
ByteString ByteString::Concat(const ByteString& lhs, const ByteString& rhs,
                              google::protobuf::Arena* absl_nonnull arena) {
  ABSL_DCHECK(arena != nullptr);
  if (lhs.empty()) {
    return rhs;
  }
  if (rhs.empty()) {
    return lhs;
  }
  if (lhs.GetKind() == ByteStringKind::kLarge ||
      rhs.GetKind() == ByteStringKind::kLarge) {
    // If either the left or right are absl::Cord, use absl::Cord.
    absl::Cord result;
    result.Append(lhs.ToCord());
    result.Append(rhs.ToCord());
    return ByteString(std::move(result));
  }
  const size_t lhs_size = lhs.size();
  const size_t rhs_size = rhs.size();
  const size_t result_size = lhs_size + rhs_size;
  ByteString result;
  if (result_size <= kSmallByteStringCapacity) {
    // If the resulting string fits in inline storage, do it.
    // The kind is not set here; this relies on a default-constructed
    // ByteString already being kSmall (assumed — confirm against the default
    // constructor in byte_string.h).
    result.rep_.small.size = result_size;
    result.rep_.small.arena = arena;
    lhs.CopyToArray(result.rep_.small.data);
    rhs.CopyToArray(result.rep_.small.data + lhs_size);
  } else {
    // Otherwise allocate on the arena.
    char* result_data =
        reinterpret_cast(arena->AllocateAligned(result_size));
    lhs.CopyToArray(result_data);
    rhs.CopyToArray(result_data + lhs_size);
    result.rep_.medium.data = result_data;
    result.rep_.medium.size = result_size;
    result.rep_.medium.owner =
        reinterpret_cast(arena) | kMetadataOwnerArenaBit;
    result.rep_.header.kind = ByteStringKind::kMedium;
  }
  return result;
}

// Copies `string`: inline when it fits in kSmallByteStringCapacity bytes,
// otherwise as a medium string owned by `allocator`'s arena, or reference
// counted when there is no arena (see SetMedium).
ByteString::ByteString(Allocator<> allocator, absl::string_view string) {
  ABSL_DCHECK_LE(string.size(), max_size());
  auto* arena = allocator.arena();
  if (string.size() <= kSmallByteStringCapacity) {
    SetSmall(arena, string);
  } else {
    SetMedium(arena, string);
  }
}

// Same as the string_view overload.
ByteString::ByteString(Allocator<> allocator, const std::string& string) {
  ABSL_DCHECK_LE(string.size(), max_size());
  auto* arena = allocator.arena();
  if (string.size() <= kSmallByteStringCapacity) {
    SetSmall(arena, string);
  } else {
    SetMedium(arena, string);
  }
}

// Same as above, but medium strings take ownership of `string`'s buffer by
// moving it into arena- or reference-counted storage instead of copying.
ByteString::ByteString(Allocator<> allocator, std::string&& string) {
  ABSL_DCHECK_LE(string.size(), max_size());
  auto* arena = allocator.arena();
  if (string.size() <= kSmallByteStringCapacity) {
    SetSmall(arena, string);
  } else {
    SetMedium(arena, std::move(string));
  }
}

// Small cords are copied inline; larger cords are flattened into the arena when
// one is available, and otherwise kept as an absl::Cord (kLarge).
ByteString::ByteString(Allocator<> allocator, const absl::Cord& cord) {
  ABSL_DCHECK_LE(cord.size(), max_size());
  auto* arena = allocator.arena();
  if (cord.size() <= kSmallByteStringCapacity) {
    SetSmall(arena, cord);
  } else if (arena != nullptr) {
    SetMedium(arena, cord);
  } else {
    SetLarge(cord);
  }
}

// Creates a ByteString that shares `string` with `borrower` where possible.
// Small strings and arena-backed borrowers always copy. Otherwise a strong
// reference on the borrower's reference count is taken, or, if it has none, a
// new reference-counted copy is made.
ByteString ByteString::Borrowed(Borrower borrower, absl::string_view string) {
  ABSL_DCHECK(borrower != Borrower::None()) << "Borrowing from Owner::None()";
  auto* arena = borrower.arena();
  if (string.size() <= kSmallByteStringCapacity || arena != nullptr) {
    return ByteString(arena, string);
  }
  const auto* refcount = BorrowerRelease(borrower);
  // A nullptr refcount indicates somebody called us to borrow something that
  // has no owner. If this is the case, we fallback to assuming operator
  // new/delete and convert it to a reference count.
  if (refcount == nullptr) {
    std::tie(refcount, string) = MakeReferenceCountedString(string);
  } else {
    StrongRef(*refcount);
  }
  return ByteString(refcount, string);
}

// Cords are never borrowed: this goes through the Allocator-based constructor
// using only the borrower's arena.
ByteString ByteString::Borrowed(Borrower borrower, const absl::Cord& cord) {
  ABSL_DCHECK(borrower != Borrower::None()) << "Borrowing from Owner::None()";
  return ByteString(borrower.arena(), cord);
}

// Adopts `refcount` as the owner of `string` (medium). Does not increment the
// count itself; callers such as Borrowed() take the reference beforehand.
ByteString::ByteString(const ReferenceCount* absl_nonnull refcount,
                       absl::string_view string) {
  ABSL_DCHECK_LE(string.size(), max_size());
  SetMedium(string, reinterpret_cast(refcount) |
                        kMetadataOwnerReferenceCountBit);
}

// Small strings are copied inline. Larger ones point at `string` without
// copying or owning it (owner 0), so the caller must keep that memory alive
// for as long as the ByteString uses it.
ByteString::ByteString(ByteString::ExternalStringTag,
                       absl::string_view string) {
  if (string.size() <= kSmallByteStringCapacity) {
    SetSmall(nullptr, string);
  } else {
    SetExternalMedium(string);
  }
}

ByteString ByteString::FromExternal(absl::string_view string) {
  return ByteString(ExternalStringTag{}, string);
}

// Arena associated with this string, or nullptr (always nullptr for kLarge).
google::protobuf::Arena* absl_nullable ByteString::GetArena() const {
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      return GetSmallArena();
    case ByteStringKind::kMedium:
      return GetMediumArena();
    case ByteStringKind::kLarge:
      return nullptr;
  }
}

bool ByteString::empty() const {
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      return rep_.small.size == 0;
    case ByteStringKind::kMedium:
      return rep_.medium.size == 0;
    case ByteStringKind::kLarge:
      return GetLarge().empty();
  }
}

size_t ByteString::size() const {
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      return rep_.small.size;
    case ByteStringKind::kMedium:
      return rep_.medium.size;
    case ByteStringKind::kLarge:
      return GetLarge().size();
  }
}

// Returns a contiguous view of the contents; a kLarge Cord is flattened in
// place (hence non-const).
absl::string_view ByteString::Flatten() {
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      return GetSmall();
    case ByteStringKind::kMedium:
      return GetMedium();
    case ByteStringKind::kLarge:
      return GetLarge().Flatten();
  }
}

// Contiguous view when one is available without flattening; for kLarge this
// defers to absl::Cord::TryFlat.
absl::optional ByteString::TryFlat() const {
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      return GetSmall();
    case ByteStringKind::kMedium:
      return GetMedium();
    case ByteStringKind::kLarge:
      return GetLarge().TryFlat();
  }
}

bool ByteString::Equals(absl::string_view rhs) const {
  return Visit(absl::Overload(
      [&rhs](absl::string_view lhs) -> bool { return lhs == rhs; },
      [&rhs](const absl::Cord& lhs) -> bool { return lhs == rhs; }));
}

bool ByteString::Equals(const absl::Cord& rhs) const {
  return Visit(absl::Overload(
      [&rhs](absl::string_view lhs) -> bool { return lhs == rhs; },
      [&rhs](const absl::Cord& lhs) -> bool { return lhs == rhs; }));
}

// Three-way comparison of `*this` against `rhs`: negative, zero or positive.
int ByteString::Compare(absl::string_view rhs) const {
  return Visit(absl::Overload(
      [&rhs](absl::string_view lhs) -> int { return lhs.compare(rhs); },
      [&rhs](const absl::Cord& lhs) -> int { return lhs.Compare(rhs); }));
}

int ByteString::Compare(const absl::Cord& rhs) const {
  return Visit(absl::Overload(
      // Only Cord has a Compare(string_view) member, so compare rhs-vs-lhs and
      // negate to keep the lhs-vs-rhs orientation.
      [&rhs](absl::string_view lhs) -> int { return -rhs.Compare(lhs); },
      [&rhs](const absl::Cord& lhs) -> int { return lhs.Compare(rhs); }));
}

bool ByteString::StartsWith(absl::string_view rhs) const {
  return Visit(absl::Overload(
      [&rhs](absl::string_view lhs) -> bool {
        return absl::StartsWith(lhs, rhs);
      },
      [&rhs](const absl::Cord& lhs) -> bool { return lhs.StartsWith(rhs); }));
}

bool ByteString::StartsWith(const absl::Cord& rhs) const {
  return Visit(absl::Overload(
      [&rhs](absl::string_view lhs) -> bool {
        return lhs.size() >= rhs.size() && lhs.substr(0, rhs.size()) == rhs;
      },
      [&rhs](const absl::Cord& lhs) -> bool { return lhs.StartsWith(rhs); }));
}

bool ByteString::EndsWith(absl::string_view rhs) const {
  return Visit(absl::Overload(
      [&rhs](absl::string_view lhs) -> bool {
        return absl::EndsWith(lhs, rhs);
      },
      [&rhs](const absl::Cord& lhs) -> bool { return lhs.EndsWith(rhs); }));
}

bool ByteString::EndsWith(const absl::Cord& rhs) const {
  return Visit(absl::Overload(
      [&rhs](absl::string_view lhs) -> bool {
        return lhs.size() >= rhs.size() &&
               lhs.substr(lhs.size() - rhs.size()) == rhs;
      },
      [&rhs](const absl::Cord& lhs) -> bool { return lhs.EndsWith(rhs); }));
}

// Returns the index of the first occurrence of `needle` at or after `pos`, or
// nullopt when there is none. `pos` must be <= size().
absl::optional ByteString::Find(absl::string_view needle,
                                size_t pos) const {
  ABSL_DCHECK_LE(pos, size());
  return Visit(absl::Overload(
      [&needle, pos](absl::string_view lhs) -> absl::optional {
        absl::string_view::size_type i = lhs.find(needle, pos);
        if (i == absl::string_view::npos) {
          return absl::nullopt;
        }
        return i;
      },
      [&needle, pos](const absl::Cord& lhs) -> absl::optional {
        // Search the suffix starting at `pos`, then translate the hit back
        // into an index of the full string.
        absl::Cord cord = lhs.Subcord(pos, lhs.size() - pos);
        absl::Cord::CharIterator it = cord.Find(needle);
        if (it == cord.char_end()) {
          return absl::nullopt;
        }
        return pos + static_cast(absl::Cord::Distance(cord.char_begin(), it));
      }));
}

// Cord-needle variant of Find above; same contract.
absl::optional ByteString::Find(const absl::Cord& needle,
                                size_t pos) const {
  ABSL_DCHECK_LE(pos, size());
  return Visit(absl::Overload(
      [&needle, pos](absl::string_view lhs) -> absl::optional {
        if (auto flat_needle = needle.TryFlat(); flat_needle) {
          absl::string_view::size_type i = lhs.find(*flat_needle, pos);
          if (i == absl::string_view::npos) {
            return absl::nullopt;
          }
          return i;
        }
        // Needle is fragmented, we have to do a linear scan.
        const size_t needle_size = needle.size();
        if (pos + needle_size > lhs.size()) {
          return absl::nullopt;
        }
        if (ABSL_PREDICT_FALSE(needle_size == 0)) {
          return pos;
        }
        // Optimization: find the first chunk of the needle, then compare the
        // rest. If the first chunk is empty, `lhs.find` will return
        // `current_pos`, which correctly degrades to a linear scan.
        absl::string_view first_chunk = *needle.Chunks().begin();
        absl::Cord rest_of_needle = needle.Subcord(
            first_chunk.size(), needle_size - first_chunk.size());
        size_t current_pos = pos;
        while (true) {
          size_t found_pos = lhs.find(first_chunk, current_pos);
          // A match starting past `lhs.size() - needle_size` cannot fit the
          // whole needle.
          if (found_pos == absl::string_view::npos ||
              found_pos > lhs.size() - needle_size) {
            return absl::nullopt;
          }
          if (lhs.substr(found_pos + first_chunk.size(),
                         rest_of_needle.size()) == rest_of_needle) {
            return found_pos;
          }
          current_pos = found_pos + 1;
        }
      },
      [&needle, pos](const absl::Cord& lhs) -> absl::optional {
        absl::Cord cord = lhs.Subcord(pos, lhs.size() - pos);
        absl::Cord::CharIterator it = cord.Find(needle);
        if (it == cord.char_end()) {
          return absl::nullopt;
        }
        return pos + static_cast(absl::Cord::Distance(cord.char_begin(), it));
      }));
}

// Returns bytes [pos, npos). NOTE: `npos` is an exclusive end index, not a
// length. A medium result shares ownership with `*this` and stays kMedium even
// when short enough to be stored inline; unlike RemovePrefix/RemoveSuffix it is
// not converted to kSmall.
ByteString ByteString::Substring(size_t pos, size_t npos) const {
  ABSL_DCHECK_LE(npos, size());
  ABSL_DCHECK_LE(pos, npos);
  switch (GetKind()) {
    case ByteStringKind::kSmall: {
      ByteString result;
      result.rep_.header.kind = ByteStringKind::kSmall;
      result.rep_.small.size = npos - pos;
      std::memcpy(result.rep_.small.data, rep_.small.data + pos,
                  result.rep_.small.size);
      result.rep_.small.arena = GetSmallArena();
      return result;
    }
    case ByteStringKind::kMedium: {
      // Copy (sharing the owner), then narrow the window.
      ByteString result(*this);
      result.rep_.medium.data += pos;
      result.rep_.medium.size = npos - pos;
      return result;
    }
    case ByteStringKind::kLarge:
      return ByteString(GetLarge().Subcord(pos, npos - pos));
  }
}

// Drops the first `n` bytes in place. Medium and large strings that shrink to
// fit inline are converted to kSmall; for medium, the reference count (if any)
// is released afterwards.
void ByteString::RemovePrefix(size_t n) {
  ABSL_DCHECK_LE(n, size());
  if (n == 0) {
    return;
  }
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      std::memmove(rep_.small.data, rep_.small.data + n, rep_.small.size - n);
      rep_.small.size -= n;
      break;
    case ByteStringKind::kMedium:
      rep_.medium.data += n;
      rep_.medium.size -= n;
      if (rep_.medium.size <= kSmallByteStringCapacity) {
        // Read the owner before SetSmall overwrites the medium rep; the bytes
        // stay alive until StrongUnref below.
        const auto* refcount = GetMediumReferenceCount();
        SetSmall(GetMediumArena(), GetMedium());
        StrongUnref(refcount);
      }
      break;
    case ByteStringKind::kLarge: {
      auto& large = GetLarge();
      const auto large_size = large.size();
      const auto new_large_pos = n;
      const auto new_large_size = large_size - n;
      large = large.Subcord(new_large_pos, new_large_size);
      if (new_large_size <= kSmallByteStringCapacity) {
        // Move the Cord out before destroying it in place, then copy inline.
        auto large_copy = std::move(large);
        DestroyLarge();
        SetSmall(nullptr, large_copy);
      }
    } break;
  }
}

// Drops the last `n` bytes in place; converts to kSmall like RemovePrefix.
void ByteString::RemoveSuffix(size_t n) {
  ABSL_DCHECK_LE(n, size());
  if (n == 0) {
    return;
  }
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      rep_.small.size -= n;
      break;
    case ByteStringKind::kMedium:
      rep_.medium.size -= n;
      if (rep_.medium.size <= kSmallByteStringCapacity) {
        const auto* refcount = GetMediumReferenceCount();
        SetSmall(GetMediumArena(), GetMedium());
        StrongUnref(refcount);
      }
      break;
    case ByteStringKind::kLarge: {
      auto& large = GetLarge();
      const auto large_size = large.size();
      const auto new_large_pos = 0;
      const auto new_large_size = large_size - n;
      large = large.Subcord(new_large_pos, new_large_size);
      if (new_large_size <= kSmallByteStringCapacity) {
        auto large_copy = std::move(large);
        DestroyLarge();
        SetSmall(nullptr, large_copy);
      }
    } break;
  }
}

// Writes all size() bytes to `out`, which must be large enough.
void ByteString::CopyToArray(char* absl_nonnull out) const {
  ABSL_DCHECK(out != nullptr);
  switch (GetKind()) {
    case ByteStringKind::kSmall: {
      absl::string_view small = GetSmall();
      std::memcpy(out, small.data(), small.size());
    } break;
    case ByteStringKind::kMedium: {
      absl::string_view medium = GetMedium();
      std::memcpy(out, medium.data(), medium.size());
    } break;
    case ByteStringKind::kLarge: {
      const absl::Cord& large = GetLarge();
      (CopyCordToArray)(large, out);
    } break;
  }
}

std::string ByteString::ToString() const {
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      return std::string(GetSmall());
    case ByteStringKind::kMedium:
      return std::string(GetMedium());
    case ByteStringKind::kLarge:
      return static_cast(GetLarge());
  }
}

// Replaces `*out` with the contents.
void ByteString::CopyToString(std::string* absl_nonnull out) const {
  ABSL_DCHECK(out != nullptr);
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      out->assign(GetSmall());
      break;
    case ByteStringKind::kMedium:
      out->assign(GetMedium());
      break;
    case ByteStringKind::kLarge:
      absl::CopyCordToString(GetLarge(), out);
      break;
  }
}

// Appends the contents to `*out`.
void ByteString::AppendToString(std::string* absl_nonnull out) const {
  ABSL_DCHECK(out != nullptr);
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      out->append(GetSmall());
      break;
    case ByteStringKind::kMedium:
      out->append(GetMedium());
      break;
    case ByteStringKind::kLarge:
      absl::AppendCordToString(GetLarge(), out);
      break;
  }
}

namespace {

// Releaser for absl::MakeCordFromExternal: drops one strong reference when the
// Cord stops using the external memory.
struct ReferenceCountReleaser {
  const ReferenceCount* absl_nonnull refcount;

  void operator()() const { StrongUnref(*refcount); }
};

}  // namespace

// Returns a Cord with the same bytes. Reference-counted medium strings are
// shared: one extra strong reference is handed to the Cord. Other medium
// strings (arena-owned or external) are copied.
absl::Cord ByteString::ToCord() const& {
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      return absl::Cord(GetSmall());
    case ByteStringKind::kMedium: {
      const auto* refcount = GetMediumReferenceCount();
      if (refcount != nullptr) {
        StrongRef(*refcount);
        return absl::MakeCordFromExternal(GetMedium(),
                                          ReferenceCountReleaser{refcount});
      }
      return absl::Cord(GetMedium());
    }
    case ByteStringKind::kLarge:
      return GetLarge();
  }
}

// Like the const& overload, but a reference-counted medium string transfers
// its existing reference to the Cord, and `*this` becomes an empty small
// string.
// NOTE(review): the kLarge branch returns a copy of the Cord rather than
// moving it out of `*this`.
absl::Cord ByteString::ToCord() && {
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      return absl::Cord(GetSmall());
    case ByteStringKind::kMedium: {
      const auto* refcount = GetMediumReferenceCount();
      if (refcount != nullptr) {
        auto medium = GetMedium();
        SetSmallEmpty(nullptr);
        return absl::MakeCordFromExternal(medium,
                                          ReferenceCountReleaser{refcount});
      }
      return absl::Cord(GetMedium());
    }
    case ByteStringKind::kLarge:
      return GetLarge();
  }
}

// Replaces `*out` with the contents, sharing reference-counted medium storage.
void ByteString::CopyToCord(absl::Cord* absl_nonnull out) const {
  ABSL_DCHECK(out != nullptr);
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      *out = absl::Cord(GetSmall());
      break;
    case ByteStringKind::kMedium: {
      const auto* refcount = GetMediumReferenceCount();
      if (refcount != nullptr) {
        StrongRef(*refcount);
        *out = absl::MakeCordFromExternal(GetMedium(),
                                          ReferenceCountReleaser{refcount});
      } else {
        *out = absl::Cord(GetMedium());
      }
    } break;
    case ByteStringKind::kLarge:
      *out = GetLarge();
      break;
  }
}

// Appends the contents to `*out`, sharing reference-counted medium storage.
void ByteString::AppendToCord(absl::Cord* absl_nonnull out) const {
  ABSL_DCHECK(out != nullptr);
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      out->Append(GetSmall());
      break;
    case ByteStringKind::kMedium: {
      const auto* refcount = GetMediumReferenceCount();
      if (refcount != nullptr) {
        StrongRef(*refcount);
        out->Append(absl::MakeCordFromExternal(
            GetMedium(), ReferenceCountReleaser{refcount}));
      } else {
        out->Append(GetMedium());
      }
    } break;
    case ByteStringKind::kLarge:
      out->Append(GetLarge());
      break;
  }
}

// Returns a view of the contents. The view points into `*this`, except for a
// non-flat kLarge Cord, which is copied into `*scratch` and viewed there.
absl::string_view ByteString::ToStringView(
    std::string* absl_nonnull scratch) const {
  ABSL_DCHECK(scratch != nullptr);
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      return GetSmall();
    case ByteStringKind::kMedium:
      return GetMedium();
    case ByteStringKind::kLarge:
      if (auto flat = GetLarge().TryFlat(); flat) {
        return *flat;
      }
      absl::CopyCordToString(GetLarge(), scratch);
      return absl::string_view(*scratch);
  }
}

// Returns a view of a small or medium string; crashes (ABSL_CHECK) on kLarge.
absl::string_view ByteString::AsStringView() const {
  const ByteStringKind kind = GetKind();
  ABSL_CHECK(kind == ByteStringKind::kSmall ||  // Crash OK
             kind == ByteStringKind::kMedium);
  switch (kind) {
    case ByteStringKind::kSmall:
      return GetSmall();
    case ByteStringKind::kMedium:
      return GetMedium();
    case ByteStringKind::kLarge:
      ABSL_UNREACHABLE();
  }
}

// Decodes `rep.owner`: the arena pointer when tagged with the arena bit,
// otherwise nullptr.
google::protobuf::Arena* absl_nullable ByteString::GetMediumArena(
    const MediumByteStringRep& rep) {
  if ((rep.owner & kMetadataOwnerBits) == kMetadataOwnerArenaBit) {
    return reinterpret_cast(rep.owner & kMetadataOwnerPointerMask);
  }
  return nullptr;
}

// Decodes `rep.owner`: the reference count when tagged with the reference
// count bit, otherwise nullptr (arena-owned or external memory).
const ReferenceCount* absl_nullable ByteString::GetMediumReferenceCount(
    const MediumByteStringRep& rep) {
  if ((rep.owner & kMetadataOwnerBits) == kMetadataOwnerReferenceCountBit) {
    return reinterpret_cast(rep.owner & kMetadataOwnerPointerMask);
  }
  return nullptr;
}

// Copy-constructs from `other`, optionally re-homing onto `allocator`.
//  - Small: copied inline; the arena tag is replaced when an allocator is
//    given.
//  - Medium: copied when the target arena differs from `other`'s; otherwise
//    ownership is shared (StrongRef).
//  - Large: flattened into the arena when one is given; otherwise the Cord is
//    copied.
void ByteString::Construct(const ByteString& other,
                           absl::optional> allocator) {
  switch (other.GetKind()) {
    case ByteStringKind::kSmall:
      rep_.small = other.rep_.small;
      if (allocator.has_value()) {
        rep_.small.arena = allocator->arena();
      }
      break;
    case ByteStringKind::kMedium:
      if (allocator.has_value() &&
          allocator->arena() != other.GetMediumArena()) {
        SetMedium(allocator->arena(), other.GetMedium());
      } else {
        rep_.medium = other.rep_.medium;
        StrongRef(GetMediumReferenceCount());
      }
      break;
    case ByteStringKind::kLarge:
      if (allocator.has_value() && allocator->arena() != nullptr) {
        SetMedium(allocator->arena(), other.GetLarge());
      } else {
        SetLarge(other.GetLarge());
      }
      break;
  }
}

// Move-construct counterpart of the above. A shared medium rep is stolen, and
// `other`'s owner is cleared to 0 so that destroying `other` releases nothing.
void ByteString::Construct(ByteString& other,
                           absl::optional> allocator) {
  switch (other.GetKind()) {
    case ByteStringKind::kSmall:
      rep_.small = other.rep_.small;
      if (allocator.has_value()) {
        rep_.small.arena = allocator->arena();
      }
      break;
    case ByteStringKind::kMedium:
      if (allocator.has_value() &&
          allocator->arena() != other.GetMediumArena()) {
        SetMedium(allocator->arena(), other.GetMedium());
      } else {
        rep_.medium = other.rep_.medium;
        other.rep_.medium.owner = 0;
      }
      break;
    case ByteStringKind::kLarge:
      if (allocator.has_value() && allocator->arena() != nullptr) {
        SetMedium(allocator->arena(), other.GetLarge());
      } else {
        SetLarge(std::move(other.GetLarge()));
      }
      break;
  }
}

// Copy-assignment from `other` (must not alias `*this`). Releases the current
// representation as needed, then adopts `other`'s, sharing medium ownership.
void ByteString::CopyFrom(const ByteString& other) {
  ABSL_DCHECK_NE(&other, this);
  switch (other.GetKind()) {
    case ByteStringKind::kSmall:
      switch (GetKind()) {
        case ByteStringKind::kSmall:
          break;
        case ByteStringKind::kMedium:
          DestroyMedium();
          break;
        case ByteStringKind::kLarge:
          DestroyLarge();
          break;
      }
      rep_.small = other.rep_.small;
      break;
    case ByteStringKind::kMedium:
      switch (GetKind()) {
        case ByteStringKind::kSmall:
          rep_.medium = other.rep_.medium;
          StrongRef(GetMediumReferenceCount());
          break;
        case ByteStringKind::kMedium:
          // Take the new reference before releasing the old one, in case both
          // share the same owner.
          StrongRef(other.GetMediumReferenceCount());
          DestroyMedium();
          rep_.medium = other.rep_.medium;
          break;
        case ByteStringKind::kLarge:
          DestroyLarge();
          rep_.medium = other.rep_.medium;
          StrongRef(GetMediumReferenceCount());
          break;
      }
      break;
    case ByteStringKind::kLarge:
      switch (GetKind()) {
        case ByteStringKind::kSmall:
          SetLarge(other.GetLarge());
          break;
        case ByteStringKind::kMedium:
          DestroyMedium();
          SetLarge(other.GetLarge());
          break;
        case ByteStringKind::kLarge:
          GetLarge() = other.GetLarge();
          break;
      }
      break;
  }
}

// Move-assignment from `other` (must not alias `*this`). A medium rep is
// stolen, and `other`'s owner is cleared to 0.
void ByteString::MoveFrom(ByteString& other) {
  ABSL_DCHECK_NE(&other, this);
  switch (other.GetKind()) {
    case ByteStringKind::kSmall:
      switch (GetKind()) {
        case ByteStringKind::kSmall:
          break;
        case ByteStringKind::kMedium:
          DestroyMedium();
          break;
        case ByteStringKind::kLarge:
          DestroyLarge();
          break;
      }
      rep_.small = other.rep_.small;
      break;
    case ByteStringKind::kMedium:
      switch (GetKind()) {
        case ByteStringKind::kSmall:
          rep_.medium = other.rep_.medium;
          break;
        case ByteStringKind::kMedium:
          DestroyMedium();
          rep_.medium = other.rep_.medium;
          break;
        case ByteStringKind::kLarge:
          DestroyLarge();
          rep_.medium = other.rep_.medium;
          break;
      }
      other.rep_.medium.owner = 0;
      break;
    case ByteStringKind::kLarge:
      switch (GetKind()) {
        case ByteStringKind::kSmall:
          SetLarge(std::move(other.GetLarge()));
          break;
        case ByteStringKind::kMedium:
          DestroyMedium();
          SetLarge(std::move(other.GetLarge()));
          break;
        case ByteStringKind::kLarge:
          GetLarge() = std::move(other.GetLarge());
          break;
      }
      break;
  }
}

// Returns a copy whose storage is tied to `arena`. A medium string already
// owned by `arena` is shared instead of copied.
// NOTE(review): `arena` is DCHECKed non-null, so in the kMedium case the code
// after the `if (arena != nullptr)` block only runs if that precondition is
// violated in a release build.
ByteString ByteString::Clone(google::protobuf::Arena* absl_nonnull arena) const {
  ABSL_DCHECK(arena != nullptr);
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      return ByteString(arena, GetSmall());
    case ByteStringKind::kMedium: {
      google::protobuf::Arena* absl_nullable other_arena = GetMediumArena();
      if (arena != nullptr) {
        if (arena == other_arena) {
          return *this;
        }
        return ByteString(arena, GetMedium());
      }
      if (other_arena != nullptr) {
        return ByteString(arena, GetMedium());
      }
      return *this;
    }
    case ByteStringKind::kLarge:
      return ByteString(arena, GetLarge());
  }
}

// Feeds the contents (a string_view, or the Cord for kLarge) into `state`.
void ByteString::HashValue(absl::HashState state) const {
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      absl::HashState::combine(std::move(state), GetSmall());
      break;
    case ByteStringKind::kMedium:
      absl::HashState::combine(std::move(state), GetMedium());
      break;
    case ByteStringKind::kLarge:
      absl::HashState::combine(std::move(state), GetLarge());
      break;
  }
}

// Exchanges contents with `other` (must not alias `*this`). Small and medium
// reps are swapped bitwise. A large rep holds a live absl::Cord, so it is moved
// out and destroyed before its storage is overwritten, then re-created on the
// other side.
void ByteString::Swap(ByteString& other) {
  ABSL_DCHECK_NE(&other, this);
  using std::swap;
  switch (other.GetKind()) {
    case ByteStringKind::kSmall:
      switch (GetKind()) {
        case ByteStringKind::kSmall:
          // small <=> small
          swap(rep_.small, other.rep_.small);
          break;
        case ByteStringKind::kMedium:
          // medium <=> small
          swap(rep_, other.rep_);
          break;
        case ByteStringKind::kLarge: {
          absl::Cord cord = std::move(GetLarge());
          DestroyLarge();
          rep_ = other.rep_;
          other.SetLarge(std::move(cord));
        } break;
      }
      break;
    case ByteStringKind::kMedium:
      switch (GetKind()) {
        case ByteStringKind::kSmall:
          swap(rep_, other.rep_);
          break;
        case ByteStringKind::kMedium:
          swap(rep_.medium, other.rep_.medium);
          break;
        case ByteStringKind::kLarge: {
          absl::Cord cord = std::move(GetLarge());
          DestroyLarge();
          rep_ = other.rep_;
          other.SetLarge(std::move(cord));
        } break;
      }
      break;
    case ByteStringKind::kLarge:
      switch (GetKind()) {
        case ByteStringKind::kSmall: {
          absl::Cord cord = std::move(other.GetLarge());
          other.DestroyLarge();
          other.rep_.small = rep_.small;
          SetLarge(std::move(cord));
        } break;
        case ByteStringKind::kMedium: {
          absl::Cord cord = std::move(other.GetLarge());
          other.DestroyLarge();
          other.rep_.medium = rep_.medium;
          SetLarge(std::move(cord));
        } break;
        case ByteStringKind::kLarge:
          swap(GetLarge(), other.GetLarge());
          break;
      }
      break;
  }
}

// Releases whatever the current representation owns (nothing for kSmall).
void ByteString::Destroy() {
  switch (GetKind()) {
    case ByteStringKind::kSmall:
      break;
    case ByteStringKind::kMedium:
      DestroyMedium();
      break;
    case ByteStringKind::kLarge:
      DestroyLarge();
      break;
  }
}

// Overwrites rep_ with an inline copy of `string` tagged with `arena`. Does not
// release any previous owner; callers handle that.
void ByteString::SetSmall(google::protobuf::Arena* absl_nullable arena,
                          absl::string_view string) {
  ABSL_DCHECK_LE(string.size(), kSmallByteStringCapacity);
  rep_.header.kind = ByteStringKind::kSmall;
  rep_.small.size = string.size();
  rep_.small.arena = arena;
  std::memcpy(rep_.small.data, string.data(), rep_.small.size);
}

// Cord overload of the above.
void ByteString::SetSmall(google::protobuf::Arena* absl_nullable arena,
                          const absl::Cord& cord) {
  ABSL_DCHECK_LE(cord.size(), kSmallByteStringCapacity);
  rep_.header.kind = ByteStringKind::kSmall;
  rep_.small.size = cord.size();
  rep_.small.arena = arena;
  (CopyCordToArray)(cord, rep_.small.data);
}

// Copies `string` into `arena` memory (arena-owned), or into a new
// reference-counted string when `arena` is null.
void ByteString::SetMedium(google::protobuf::Arena* absl_nullable arena,
                           absl::string_view string) {
  ABSL_DCHECK_GT(string.size(), kSmallByteStringCapacity);
  rep_.header.kind = ByteStringKind::kMedium;
  rep_.medium.size = string.size();
  if (arena != nullptr) {
    char* data = static_cast(
        arena->AllocateAligned(rep_.medium.size, alignof(char)));
    std::memcpy(data, string.data(), rep_.medium.size);
    rep_.medium.data = data;
    rep_.medium.owner =
        reinterpret_cast(arena) | kMetadataOwnerArenaBit;
  } else {
    auto pair = MakeReferenceCountedString(string);
    rep_.medium.data = pair.second.data();
    rep_.medium.owner = reinterpret_cast(pair.first) |
                        kMetadataOwnerReferenceCountBit;
  }
}

// Points at caller-owned memory without copying. owner 0 means nothing is
// released when this ByteString is destroyed.
void ByteString::SetExternalMedium(absl::string_view string) {
  ABSL_DCHECK_GT(string.size(), kSmallByteStringCapacity);
  rep_.header.kind = ByteStringKind::kMedium;
  rep_.medium.size = string.size();
  rep_.medium.data = string.data();
  rep_.medium.owner = 0;
}

// Moves `string` into a std::string created on `arena`, or into a
// reference-counted string when `arena` is null.
void ByteString::SetMedium(google::protobuf::Arena* absl_nullable arena,
                           std::string&& string) {
  ABSL_DCHECK_GT(string.size(), kSmallByteStringCapacity);
  rep_.header.kind = ByteStringKind::kMedium;
  // Size is captured before `string` is moved from below.
  rep_.medium.size = string.size();
  if (arena != nullptr) {
    auto* data = google::protobuf::Arena::Create(arena, std::move(string));
    rep_.medium.data = data->data();
    rep_.medium.owner =
        reinterpret_cast(arena) | kMetadataOwnerArenaBit;
  } else {
    auto pair = MakeReferenceCountedString(std::move(string));
    rep_.medium.data = pair.second.data();
    rep_.medium.owner = reinterpret_cast(pair.first) |
                        kMetadataOwnerReferenceCountBit;
  }
}

// Flattens `cord` into a single `arena` allocation. `arena` must be non-null.
void ByteString::SetMedium(google::protobuf::Arena* absl_nonnull arena,
                           const absl::Cord& cord) {
  ABSL_DCHECK_GT(cord.size(), kSmallByteStringCapacity);
  rep_.header.kind = ByteStringKind::kMedium;
  rep_.medium.size = cord.size();
  char* data = static_cast(
      arena->AllocateAligned(rep_.medium.size, alignof(char)));
  (CopyCordToArray)(cord, data);
  rep_.medium.data = data;
  rep_.medium.owner =
      reinterpret_cast(arena) | kMetadataOwnerArenaBit;
}

// Adopts an already-encoded, non-zero owner word for `string`.
void ByteString::SetMedium(absl::string_view string, uintptr_t owner) {
  ABSL_DCHECK_GT(string.size(), kSmallByteStringCapacity);
  ABSL_DCHECK_NE(owner, 0);
  rep_.header.kind = ByteStringKind::kMedium;
  rep_.medium.size = string.size();
  rep_.medium.data = string.data();
  rep_.medium.owner = owner;
}

// Placement-constructs a copy of `cord` in rep_.large.data.
void ByteString::SetLarge(const absl::Cord& cord) {
  ABSL_DCHECK_GT(cord.size(), kSmallByteStringCapacity);
  rep_.header.kind = ByteStringKind::kLarge;
  ::new (static_cast(&rep_.large.data[0])) absl::Cord(cord);
}

// Placement-constructs `cord` (moved) in rep_.large.data.
void ByteString::SetLarge(absl::Cord&& cord) {
  ABSL_DCHECK_GT(cord.size(), kSmallByteStringCapacity);
  rep_.header.kind = ByteStringKind::kLarge;
  ::new (static_cast(&rep_.large.data[0])) absl::Cord(std::move(cord));
}

// Returns a view of `string`'s bytes for legacy APIs:
//  - Medium strings owned by `arena` itself are returned without copying.
//  - When `stable` is true, small strings are returned as a view of their
//    inline storage. Presumably the caller guarantees that `string` outlives
//    the view (confirm at call sites).
//  - Everything else is copied into a std::string created on `arena`.
// NOTE(review): the outer condition already requires
// `GetMediumArena() == arena` with `arena` non-null, so the inner
// `other_arena == nullptr` alternative is unreachable. The comment inside
// suggests a broader check may have been intended; confirm before changing.
absl::string_view LegacyByteString(const ByteString& string, bool stable,
                                   google::protobuf::Arena* absl_nonnull arena) {
  ABSL_DCHECK(arena != nullptr);
  if (string.empty()) {
    return absl::string_view();
  }
  const ByteStringKind kind = string.GetKind();
  if (kind == ByteStringKind::kMedium && string.GetMediumArena() == arena) {
    google::protobuf::Arena* absl_nullable other_arena = string.GetMediumArena();
    if (other_arena == arena || other_arena == nullptr) {
      // Legacy values do not preserve arena. For speed, we assume the arena is
      // compatible.
      return string.GetMedium();
    }
  }
  if (stable && kind == ByteStringKind::kSmall) {
    return string.GetSmall();
  }
  std::string* absl_nonnull result = google::protobuf::Arena::Create(arena);
  switch (kind) {
    case ByteStringKind::kSmall:
      result->assign(string.GetSmall());
      break;
    case ByteStringKind::kMedium:
      result->assign(string.GetMedium());
      break;
    case ByteStringKind::kLarge:
      absl::CopyCordToString(string.GetLarge(), result);
      break;
  }
  return absl::string_view(*result);
}

}  // namespace cel::common_internal
================================================
FILE: common/internal/byte_string.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_BYTE_STRING_H_ #define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_BYTE_STRING_H_ #include #include #include #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/base/optimization.h" #include "absl/functional/overload.h" #include "absl/hash/hash.h" #include "absl/log/absl_check.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/allocator.h" #include "common/arena.h" #include "common/internal/reference_count.h" #include "common/memory.h" #include "google/protobuf/arena.h" namespace cel { class BytesValueInputStream; class BytesValueOutputStream; class StringValue; namespace common_internal { // absl::Cord is trivially relocatable IFF we are not using ASan or MSan. When // using ASan or MSan absl::Cord will poison/unpoison its inline storage. #if defined(ABSL_HAVE_ADDRESS_SANITIZER) || defined(ABSL_HAVE_MEMORY_SANITIZER) #define CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI #else #define CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI ABSL_ATTRIBUTE_TRIVIAL_ABI #endif class CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI [[nodiscard]] ByteString; struct ByteStringTestFriend; enum class ByteStringKind : unsigned int { kSmall = 0, kMedium, kLarge, }; inline std::ostream& operator<<(std::ostream& out, ByteStringKind kind) { switch (kind) { case ByteStringKind::kSmall: return out << "SMALL"; case ByteStringKind::kMedium: return out << "MEDIUM"; case ByteStringKind::kLarge: return out << "LARGE"; } } // Representation of small strings in ByteString, which are stored in place. 
struct CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI SmallByteStringRep final {
#ifdef _MSC_VER
#pragma pack(push, 1)
#endif
  // One-byte header: the 2-bit kind tag (overlaid by `ByteStringRep::header`)
  // followed by the 6-bit length of the inline contents.
  struct ABSL_ATTRIBUTE_PACKED CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI {
    std::uint8_t kind : 2;
    std::uint8_t size : 6;
  };
#ifdef _MSC_VER
#pragma pack(pop)
#endif
  // Inline contents. The size is chosen so that header (1 byte) + data +
  // arena pointer add up to 24 bytes.
  char data[23 - sizeof(google::protobuf::Arena*)];
  // Arena associated with this string. Nullable, e.g. presumably when it was
  // not constructed with an arena allocator.
  google::protobuf::Arena* absl_nullable arena;
};

// Largest number of bytes that fit in `SmallByteStringRep::data`.
inline constexpr size_t kSmallByteStringCapacity =
    sizeof(SmallByteStringRep::data);
// Width of the medium `size` bit-field: every bit of a size_t except the 2 bits
// used by the kind tag.
inline constexpr size_t kMediumByteStringSizeBits = sizeof(size_t) * 8 - 2;
// Largest length representable by a medium byte string.
inline constexpr size_t kMediumByteStringMaxSize =
    (size_t{1} << kMediumByteStringSizeBits) - 1;
// Size limits that keep one bit of a size_t in reserve (presumably for a tag in
// the byte string view type; TODO confirm). `ByteString::max_size()` reports
// kByteStringViewMaxSize.
inline constexpr size_t kByteStringViewSizeBits = sizeof(size_t) * 8 - 1;
inline constexpr size_t kByteStringViewMaxSize =
    (size_t{1} << kByteStringViewSizeBits) - 1;

// Representation of medium strings in ByteString. These are either owned by an
// arena or managed by a reference count. This is encoded in `owner` following
// the same semantics as `cel::Owner`.
struct CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI MediumByteStringRep final {
#ifdef _MSC_VER
#pragma pack(push, 1)
#endif
  // The 2-bit kind tag and the length share a single size_t.
  struct ABSL_ATTRIBUTE_PACKED CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI {
    size_t kind : 2;
    size_t size : kMediumByteStringSizeBits;
  };
#ifdef _MSC_VER
#pragma pack(pop)
#endif
  // Start of the (not necessarily NUL-terminated) contents.
  const char* data;
  // Encoded owner word (arena or reference count).
  uintptr_t owner;
};

// Representation of large strings in ByteString. These are stored as
// `absl::Cord` and never owned by an arena.
struct CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI LargeByteStringRep final {
#ifdef _MSC_VER
#pragma pack(push, 1)
#endif
  // The 2-bit kind tag. The remaining bits of the word are unused.
  struct ABSL_ATTRIBUTE_PACKED CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI {
    size_t kind : 2;
    size_t padding : kMediumByteStringSizeBits;
  };
#ifdef _MSC_VER
#pragma pack(pop)
#endif
  // Raw, suitably aligned storage in which an `absl::Cord` is constructed with
  // placement new and later destroyed explicitly.
  alignas(absl::Cord) std::byte data[sizeof(absl::Cord)];
};

// Representation of ByteString: a union of the three layouts above, each of
// which begins with the same 2-bit kind tag.
union CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI ByteStringRep final {
#ifdef _MSC_VER
#pragma pack(push, 1)
#endif
  // Overlays the leading 2-bit kind tag shared by every layout below. It is
  // read to decide which of `small`, `medium` or `large` is active.
  struct ABSL_ATTRIBUTE_PACKED CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI {
    ByteStringKind kind : 2;
  } header;
#ifdef _MSC_VER
#pragma pack(pop)
#endif
  SmallByteStringRep small;    // Active when header.kind == kSmall.
  MediumByteStringRep medium;  // Active when header.kind == kMedium.
  LargeByteStringRep large;    // Active when header.kind == kLarge.
};

// Returns a `absl::string_view` from `ByteString`, using `arena` to make memory
// allocations if necessary. `stable` indicates whether `cel::Value` is in a
// location where it will not be moved, so that inline string/bytes storage can
// be referenced. `arena` must not be null.
absl::string_view LegacyByteString(const ByteString& string, bool stable,
                                   google::protobuf::Arena* absl_nonnull arena);

// `ByteString` is a vocabulary type capable of representing copy-on-write
// strings efficiently for arenas and reference counting. The contents of the
// byte string are owned by an arena or managed by a reference count. All byte
// strings have an associated allocator specified at construction, once the byte
// string is constructed the allocator will not and cannot change. Copying and
// moving between different allocators is supported and dealt with
// transparently by copying.
class CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI [[nodiscard]] ByteString final { public: static ByteString Concat(const ByteString& lhs, const ByteString& rhs, google::protobuf::Arena* absl_nonnull arena); ByteString() : ByteString(NewDeleteAllocator()) {} explicit ByteString(const char* absl_nullable string) : ByteString(NewDeleteAllocator(), string) {} explicit ByteString(absl::string_view string) : ByteString(NewDeleteAllocator(), string) {} explicit ByteString(const std::string& string) : ByteString(NewDeleteAllocator(), string) {} explicit ByteString(std::string&& string) : ByteString(NewDeleteAllocator(), std::move(string)) {} explicit ByteString(const absl::Cord& cord) : ByteString(NewDeleteAllocator(), cord) {} ByteString(const ByteString& other) noexcept { Construct(other, /*allocator=*/absl::nullopt); } ByteString(ByteString&& other) noexcept { Construct(other, /*allocator=*/absl::nullopt); } explicit ByteString(Allocator<> allocator) { SetSmallEmpty(allocator.arena()); } ByteString(Allocator<> allocator, const char* absl_nullable string) : ByteString(allocator, absl::NullSafeStringView(string)) {} ByteString(Allocator<> allocator, absl::string_view string); ByteString(Allocator<> allocator, const std::string& string); ByteString(Allocator<> allocator, std::string&& string); ByteString(Allocator<> allocator, const absl::Cord& cord); ByteString(Allocator<> allocator, const ByteString& other) { Construct(other, allocator); } ByteString(Allocator<> allocator, ByteString&& other) { Construct(other, allocator); } ByteString(Borrower borrower, const char* absl_nullable string ABSL_ATTRIBUTE_LIFETIME_BOUND) : ByteString(borrower, absl::NullSafeStringView(string)) {} ByteString(Borrower borrower, absl::string_view string ABSL_ATTRIBUTE_LIFETIME_BOUND) : ByteString(Borrowed(borrower, string)) {} ByteString(Borrower borrower, const absl::Cord& cord ABSL_ATTRIBUTE_LIFETIME_BOUND) : ByteString(Borrowed(borrower, cord)) {} // Creates a medium byte string that is 
backed by an external string. Should // only be called from explicit 'Unsafe' factories. static ByteString FromExternal(absl::string_view string); ~ByteString() { Destroy(); } ByteString& operator=(const ByteString& other) noexcept { if (ABSL_PREDICT_TRUE(this != &other)) { CopyFrom(other); } return *this; } ByteString& operator=(ByteString&& other) noexcept { if (ABSL_PREDICT_TRUE(this != &other)) { MoveFrom(other); } return *this; } bool empty() const; size_t size() const; size_t max_size() const { return kByteStringViewMaxSize; } absl::string_view Flatten() ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::optional TryFlat() const ABSL_ATTRIBUTE_LIFETIME_BOUND; bool Equals(absl::string_view rhs) const; bool Equals(const absl::Cord& rhs) const; bool Equals(const ByteString& rhs) const; int Compare(absl::string_view rhs) const; int Compare(const absl::Cord& rhs) const; int Compare(const ByteString& rhs) const; bool StartsWith(absl::string_view rhs) const; bool StartsWith(const absl::Cord& rhs) const; bool StartsWith(const ByteString& rhs) const; bool EndsWith(absl::string_view rhs) const; bool EndsWith(const absl::Cord& rhs) const; bool EndsWith(const ByteString& rhs) const; // Finds the first occurrence of `needle` in this object, starting at byte // position `pos`. Returns `absl::nullopt` if `needle` is not found. // Note: Positions are byte-based, not code point based as in // `cel::StringValue`. absl::optional Find(absl::string_view needle, size_t pos = 0) const; absl::optional Find(const absl::Cord& needle, size_t pos = 0) const; absl::optional Find(const ByteString& needle, size_t pos = 0) const; // Returns a new `ByteString` that is a substring of this object, starting at // byte position `pos` and with a length of `npos` bytes. // Note: Positions are byte-based, not code point based as in // `cel::StringValue`. 
ByteString Substring(size_t pos, size_t npos) const; ByteString Substring(size_t pos) const { ABSL_DCHECK_LE(pos, size()); return Substring(pos, size()); } void RemovePrefix(size_t n); void RemoveSuffix(size_t n); std::string ToString() const; void CopyToString(std::string* absl_nonnull out) const; void AppendToString(std::string* absl_nonnull out) const; absl::Cord ToCord() const&; absl::Cord ToCord() &&; void CopyToCord(absl::Cord* absl_nonnull out) const; void AppendToCord(absl::Cord* absl_nonnull out) const; absl::string_view ToStringView( std::string* absl_nonnull scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::string_view AsStringView() const ABSL_ATTRIBUTE_LIFETIME_BOUND; google::protobuf::Arena* absl_nullable GetArena() const; ByteString Clone(google::protobuf::Arena* absl_nonnull arena) const; void HashValue(absl::HashState state) const; template decltype(auto) Visit(Visitor&& visitor) const { switch (GetKind()) { case ByteStringKind::kSmall: return std::forward(visitor)(GetSmall()); case ByteStringKind::kMedium: return std::forward(visitor)(GetMedium()); case ByteStringKind::kLarge: return std::forward(visitor)(GetLarge()); } } friend void swap(ByteString& lhs, ByteString& rhs) { if (&lhs != &rhs) { lhs.Swap(rhs); } } template friend H AbslHashValue(H state, const ByteString& byte_string) { byte_string.HashValue(absl::HashState::Create(&state)); return state; } private: friend class ByteStringView; friend struct ByteStringTestFriend; friend class cel::BytesValueInputStream; friend class cel::BytesValueOutputStream; friend class cel::StringValue; friend absl::string_view LegacyByteString(const ByteString& string, bool stable, google::protobuf::Arena* absl_nonnull arena); friend struct cel::ArenaTraits; struct ExternalStringTag {}; static ByteString Borrowed(Borrower borrower, absl::string_view string ABSL_ATTRIBUTE_LIFETIME_BOUND); static ByteString Borrowed( Borrower borrower, const absl::Cord& cord 
ABSL_ATTRIBUTE_LIFETIME_BOUND); ByteString(const ReferenceCount* absl_nonnull refcount, absl::string_view string); ByteString(ExternalStringTag, absl::string_view string); constexpr ByteStringKind GetKind() const { return rep_.header.kind; } absl::string_view GetSmall() const { ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kSmall); return GetSmall(rep_.small); } static absl::string_view GetSmall(const SmallByteStringRep& rep) { return absl::string_view(rep.data, rep.size); } absl::string_view GetMedium() const { ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); return GetMedium(rep_.medium); } static absl::string_view GetMedium(const MediumByteStringRep& rep) { return absl::string_view(rep.data, rep.size); } google::protobuf::Arena* absl_nullable GetSmallArena() const { ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kSmall); return GetSmallArena(rep_.small); } static google::protobuf::Arena* absl_nullable GetSmallArena( const SmallByteStringRep& rep) { return rep.arena; } google::protobuf::Arena* absl_nullable GetMediumArena() const { ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); return GetMediumArena(rep_.medium); } static google::protobuf::Arena* absl_nullable GetMediumArena( const MediumByteStringRep& rep); const ReferenceCount* absl_nullable GetMediumReferenceCount() const { ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); return GetMediumReferenceCount(rep_.medium); } static const ReferenceCount* absl_nullable GetMediumReferenceCount( const MediumByteStringRep& rep); uintptr_t GetMediumOwner() const { ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); return rep_.medium.owner; } absl::Cord& GetLarge() ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kLarge); return GetLarge(rep_.large); } static absl::Cord& GetLarge( LargeByteStringRep& rep ABSL_ATTRIBUTE_LIFETIME_BOUND) { return *std::launder(reinterpret_cast(&rep.data[0])); } const absl::Cord& GetLarge() const ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK_EQ(GetKind(), 
ByteStringKind::kLarge); return GetLarge(rep_.large); } static const absl::Cord& GetLarge( const LargeByteStringRep& rep ABSL_ATTRIBUTE_LIFETIME_BOUND) { return *std::launder(reinterpret_cast(&rep.data[0])); } void SetSmallEmpty(google::protobuf::Arena* absl_nullable arena) { rep_.header.kind = ByteStringKind::kSmall; rep_.small.size = 0; rep_.small.arena = arena; } void SetSmall(google::protobuf::Arena* absl_nullable arena, absl::string_view string); void SetSmall(google::protobuf::Arena* absl_nullable arena, const absl::Cord& cord); void SetMedium(google::protobuf::Arena* absl_nullable arena, absl::string_view string); // This is used to create a medium byte string that is backed by an external // string. Should only be called from explicit 'Unsafe' factories. void SetExternalMedium(absl::string_view string); void SetMedium(google::protobuf::Arena* absl_nullable arena, std::string&& string); void SetMedium(google::protobuf::Arena* absl_nonnull arena, const absl::Cord& cord); void SetMedium(absl::string_view string, uintptr_t owner); void SetLarge(const absl::Cord& cord); void SetLarge(absl::Cord&& cord); void Swap(ByteString& other); void Construct(const ByteString& other, absl::optional> allocator); void Construct(ByteString& other, absl::optional> allocator); void CopyFrom(const ByteString& other); void MoveFrom(ByteString& other); void Destroy(); void DestroyMedium() { ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); DestroyMedium(rep_.medium); } static void DestroyMedium(const MediumByteStringRep& rep) { StrongUnref(GetMediumReferenceCount(rep)); } void DestroyLarge() { ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kLarge); DestroyLarge(rep_.large); } static void DestroyLarge(LargeByteStringRep& rep) { GetLarge(rep).~Cord(); } void CopyToArray(char* absl_nonnull out) const; ByteStringRep rep_; }; inline bool ByteString::Equals(const ByteString& rhs) const { return rhs.Visit(absl::Overload( [this](absl::string_view rhs) -> bool { return Equals(rhs); }, 
[this](const absl::Cord& rhs) -> bool { return Equals(rhs); })); } inline int ByteString::Compare(const ByteString& rhs) const { return rhs.Visit(absl::Overload( [this](absl::string_view rhs) -> int { return Compare(rhs); }, [this](const absl::Cord& rhs) -> int { return Compare(rhs); })); } inline bool ByteString::StartsWith(const ByteString& rhs) const { return rhs.Visit(absl::Overload( [this](absl::string_view rhs) -> bool { return StartsWith(rhs); }, [this](const absl::Cord& rhs) -> bool { return StartsWith(rhs); })); } inline bool ByteString::EndsWith(const ByteString& rhs) const { return rhs.Visit(absl::Overload( [this](absl::string_view rhs) -> bool { return EndsWith(rhs); }, [this](const absl::Cord& rhs) -> bool { return EndsWith(rhs); })); } inline absl::optional ByteString::Find(const ByteString& needle, size_t pos) const { return needle.Visit(absl::Overload( [this, pos](absl::string_view rhs) -> absl::optional { return Find(rhs, pos); }, [this, pos](const absl::Cord& rhs) -> absl::optional { return Find(rhs, pos); })); } inline bool operator==(const ByteString& lhs, const ByteString& rhs) { return lhs.Equals(rhs); } inline bool operator==(const ByteString& lhs, absl::string_view rhs) { return lhs.Equals(rhs); } inline bool operator==(absl::string_view lhs, const ByteString& rhs) { return rhs.Equals(lhs); } inline bool operator==(const ByteString& lhs, const absl::Cord& rhs) { return lhs.Equals(rhs); } inline bool operator==(const absl::Cord& lhs, const ByteString& rhs) { return rhs.Equals(lhs); } inline bool operator!=(const ByteString& lhs, const ByteString& rhs) { return !operator==(lhs, rhs); } inline bool operator!=(const ByteString& lhs, absl::string_view rhs) { return !operator==(lhs, rhs); } inline bool operator!=(absl::string_view lhs, const ByteString& rhs) { return !operator==(lhs, rhs); } inline bool operator!=(const ByteString& lhs, const absl::Cord& rhs) { return !operator==(lhs, rhs); } inline bool operator!=(const absl::Cord& lhs, const 
ByteString& rhs) { return !operator==(lhs, rhs); } inline bool operator<(const ByteString& lhs, const ByteString& rhs) { return lhs.Compare(rhs) < 0; } inline bool operator<(const ByteString& lhs, absl::string_view rhs) { return lhs.Compare(rhs) < 0; } inline bool operator<(absl::string_view lhs, const ByteString& rhs) { return -rhs.Compare(lhs) < 0; } inline bool operator<(const ByteString& lhs, const absl::Cord& rhs) { return lhs.Compare(rhs) < 0; } inline bool operator<(const absl::Cord& lhs, const ByteString& rhs) { return -rhs.Compare(lhs) < 0; } inline bool operator<=(const ByteString& lhs, const ByteString& rhs) { return lhs.Compare(rhs) <= 0; } inline bool operator<=(const ByteString& lhs, absl::string_view rhs) { return lhs.Compare(rhs) <= 0; } inline bool operator<=(absl::string_view lhs, const ByteString& rhs) { return -rhs.Compare(lhs) <= 0; } inline bool operator<=(const ByteString& lhs, const absl::Cord& rhs) { return lhs.Compare(rhs) <= 0; } inline bool operator<=(const absl::Cord& lhs, const ByteString& rhs) { return -rhs.Compare(lhs) <= 0; } inline bool operator>(const ByteString& lhs, const ByteString& rhs) { return lhs.Compare(rhs) > 0; } inline bool operator>(const ByteString& lhs, absl::string_view rhs) { return lhs.Compare(rhs) > 0; } inline bool operator>(absl::string_view lhs, const ByteString& rhs) { return -rhs.Compare(lhs) > 0; } inline bool operator>(const ByteString& lhs, const absl::Cord& rhs) { return lhs.Compare(rhs) > 0; } inline bool operator>(const absl::Cord& lhs, const ByteString& rhs) { return -rhs.Compare(lhs) > 0; } inline bool operator>=(const ByteString& lhs, const ByteString& rhs) { return lhs.Compare(rhs) >= 0; } inline bool operator>=(const ByteString& lhs, absl::string_view rhs) { return lhs.Compare(rhs) >= 0; } inline bool operator>=(absl::string_view lhs, const ByteString& rhs) { return -rhs.Compare(lhs) >= 0; } inline bool operator>=(const ByteString& lhs, const absl::Cord& rhs) { return lhs.Compare(rhs) >= 0; } 
inline bool operator>=(const absl::Cord& lhs, const ByteString& rhs) { return -rhs.Compare(lhs) >= 0; } #undef CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI } // namespace common_internal template <> struct ArenaTraits { using constructible = std::true_type; static bool trivially_destructible( const common_internal::ByteString& byte_string) { switch (byte_string.GetKind()) { case common_internal::ByteStringKind::kSmall: return true; case common_internal::ByteStringKind::kMedium: return byte_string.GetMediumReferenceCount() == nullptr; case common_internal::ByteStringKind::kLarge: return false; } } }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_BYTE_STRING_H_ ================================================ FILE: common/internal/byte_string_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/internal/byte_string.h" #include #include #include #include #include "absl/base/no_destructor.h" #include "absl/hash/hash.h" #include "absl/strings/cord.h" #include "absl/strings/cord_test_helpers.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/allocator.h" #include "common/internal/reference_count.h" #include "common/memory.h" #include "internal/testing.h" #include "google/protobuf/arena.h" namespace cel::common_internal { struct ByteStringTestFriend { static ByteStringKind GetKind(const ByteString& byte_string) { return byte_string.GetKind(); } }; namespace { using ::testing::_; using ::testing::Eq; using ::testing::IsEmpty; using ::testing::Not; using ::testing::Optional; using ::testing::SizeIs; using ::testing::TestWithParam; TEST(ByteStringKind, Ostream) { { std::ostringstream out; out << ByteStringKind::kSmall; EXPECT_EQ(out.str(), "SMALL"); } { std::ostringstream out; out << ByteStringKind::kMedium; EXPECT_EQ(out.str(), "MEDIUM"); } { std::ostringstream out; out << ByteStringKind::kLarge; EXPECT_EQ(out.str(), "LARGE"); } } class ByteStringTest : public TestWithParam, public ByteStringTestFriend { public: Allocator<> GetAllocator() { switch (GetParam()) { case AllocatorKind::kNewDelete: return NewDeleteAllocator<>{}; case AllocatorKind::kArena: return ArenaAllocator<>(&arena_); } } private: google::protobuf::Arena arena_; }; absl::string_view GetSmallStringView() { static constexpr absl::string_view small = "A small string!"; return small.substr(0, std::min(kSmallByteStringCapacity, small.size())); } std::string GetSmallString() { return std::string(GetSmallStringView()); } absl::Cord GetSmallCord() { static const absl::NoDestructor small(GetSmallStringView()); return *small; } absl::string_view GetMediumStringView() { static constexpr absl::string_view medium = "A string that is too large for the small string optimization!"; return medium; } std::string GetMediumString() { return 
std::string(GetMediumStringView()); } const absl::Cord& GetMediumOrLargeCord() { static const absl::NoDestructor medium_or_large( GetMediumStringView()); return *medium_or_large; } const absl::Cord& GetMediumOrLargeFragmentedCord() { static const absl::NoDestructor medium_or_large( absl::MakeFragmentedCord( {GetMediumStringView().substr(0, kSmallByteStringCapacity), GetMediumStringView().substr(kSmallByteStringCapacity)})); return *medium_or_large; } TEST_P(ByteStringTest, Default) { ByteString byte_string = ByteString(GetAllocator(), ""); EXPECT_THAT(byte_string, SizeIs(0)); EXPECT_THAT(byte_string, IsEmpty()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); } TEST_P(ByteStringTest, ConstructSmallCString) { ByteString byte_string = ByteString(GetAllocator(), GetSmallString().c_str()); EXPECT_THAT(byte_string, SizeIs(GetSmallStringView().size())); EXPECT_THAT(byte_string, Not(IsEmpty())); EXPECT_EQ(byte_string, GetSmallStringView()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); } TEST_P(ByteStringTest, ConstructMediumCString) { ByteString byte_string = ByteString(GetAllocator(), GetMediumString().c_str()); EXPECT_THAT(byte_string, SizeIs(GetMediumStringView().size())); EXPECT_THAT(byte_string, Not(IsEmpty())); EXPECT_EQ(byte_string, GetMediumStringView()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); } TEST_P(ByteStringTest, ConstructSmallRValueString) { ByteString byte_string = ByteString(GetAllocator(), GetSmallString()); EXPECT_THAT(byte_string, SizeIs(GetSmallStringView().size())); EXPECT_THAT(byte_string, Not(IsEmpty())); EXPECT_EQ(byte_string, GetSmallStringView()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); } TEST_P(ByteStringTest, ConstructSmallLValueString) { ByteString byte_string = ByteString( GetAllocator(), 
static_cast(GetSmallString())); EXPECT_THAT(byte_string, SizeIs(GetSmallStringView().size())); EXPECT_THAT(byte_string, Not(IsEmpty())); EXPECT_EQ(byte_string, GetSmallStringView()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); } TEST_P(ByteStringTest, ConstructMediumRValueString) { ByteString byte_string = ByteString(GetAllocator(), GetMediumString()); EXPECT_THAT(byte_string, SizeIs(GetMediumStringView().size())); EXPECT_THAT(byte_string, Not(IsEmpty())); EXPECT_EQ(byte_string, GetMediumStringView()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); } TEST_P(ByteStringTest, ConstructMediumLValueString) { ByteString byte_string = ByteString( GetAllocator(), static_cast(GetMediumString())); EXPECT_THAT(byte_string, SizeIs(GetMediumStringView().size())); EXPECT_THAT(byte_string, Not(IsEmpty())); EXPECT_EQ(byte_string, GetMediumStringView()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); } TEST_P(ByteStringTest, ConstructSmallCord) { ByteString byte_string = ByteString(GetAllocator(), GetSmallCord()); EXPECT_THAT(byte_string, SizeIs(GetSmallStringView().size())); EXPECT_THAT(byte_string, Not(IsEmpty())); EXPECT_EQ(byte_string, GetSmallStringView()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); } TEST_P(ByteStringTest, ConstructMediumOrLargeCord) { ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); EXPECT_THAT(byte_string, SizeIs(GetMediumStringView().size())); EXPECT_THAT(byte_string, Not(IsEmpty())); EXPECT_EQ(byte_string, GetMediumStringView()); if (GetAllocator().arena() == nullptr) { EXPECT_EQ(GetKind(byte_string), ByteStringKind::kLarge); } else { EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); } EXPECT_EQ(byte_string.GetArena(), 
GetAllocator().arena()); } TEST(ByteStringTest, BorrowedUnownedString) { #ifdef NDEBUG ByteString byte_string = ByteString(Owner::None(), GetMediumStringView()); EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), ByteStringKind::kMedium); EXPECT_EQ(byte_string.GetArena(), nullptr); EXPECT_EQ(byte_string, GetMediumStringView()); #else EXPECT_DEBUG_DEATH( static_cast(ByteString(Owner::None(), GetMediumStringView())), ::testing::_); #endif } TEST(ByteStringTest, BorrowedUnownedCord) { #ifdef NDEBUG ByteString byte_string = ByteString(Owner::None(), GetMediumOrLargeCord()); EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), ByteStringKind::kLarge); EXPECT_EQ(byte_string.GetArena(), nullptr); EXPECT_EQ(byte_string, GetMediumOrLargeCord()); #else EXPECT_DEBUG_DEATH( static_cast(ByteString(Owner::None(), GetMediumOrLargeCord())), ::testing::_); #endif } TEST(ByteStringTest, BorrowedReferenceCountSmallString) { auto* refcount = new ReferenceCounted(); Owner owner = Owner::ReferenceCount(refcount); StrongUnref(refcount); ByteString byte_string = ByteString(owner, GetSmallStringView()); EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), ByteStringKind::kSmall); EXPECT_EQ(byte_string.GetArena(), nullptr); EXPECT_EQ(byte_string, GetSmallStringView()); } TEST(ByteStringTest, BorrowedReferenceCountMediumString) { auto* refcount = new ReferenceCounted(); Owner owner = Owner::ReferenceCount(refcount); StrongUnref(refcount); ByteString byte_string = ByteString(owner, GetMediumStringView()); EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), ByteStringKind::kMedium); EXPECT_EQ(byte_string.GetArena(), nullptr); EXPECT_EQ(byte_string, GetMediumStringView()); } TEST(ByteStringTest, BorrowedArenaSmallString) { google::protobuf::Arena arena; ByteString byte_string = ByteString(Owner::Arena(&arena), GetSmallStringView()); EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), ByteStringKind::kSmall); EXPECT_EQ(byte_string.GetArena(), &arena); EXPECT_EQ(byte_string, 
GetSmallStringView()); } TEST(ByteStringTest, BorrowedArenaMediumString) { google::protobuf::Arena arena; ByteString byte_string = ByteString(Owner::Arena(&arena), GetMediumStringView()); EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), ByteStringKind::kMedium); EXPECT_EQ(byte_string.GetArena(), &arena); EXPECT_EQ(byte_string, GetMediumStringView()); } TEST(ByteStringTest, BorrowedReferenceCountCord) { auto* refcount = new ReferenceCounted(); Owner owner = Owner::ReferenceCount(refcount); StrongUnref(refcount); ByteString byte_string = ByteString(owner, GetMediumOrLargeCord()); EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), ByteStringKind::kLarge); EXPECT_EQ(byte_string.GetArena(), nullptr); EXPECT_EQ(byte_string, GetMediumOrLargeCord()); } TEST(ByteStringTest, BorrowedArenaCord) { google::protobuf::Arena arena; Owner owner = Owner::Arena(&arena); ByteString byte_string = ByteString(owner, GetMediumOrLargeCord()); EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), ByteStringKind::kMedium); EXPECT_EQ(byte_string.GetArena(), &arena); EXPECT_EQ(byte_string, GetMediumOrLargeCord()); } TEST_P(ByteStringTest, CopyConstruct) { ByteString small_byte_string = ByteString(GetAllocator(), GetSmallStringView()); ByteString medium_byte_string = ByteString(GetAllocator(), GetMediumStringView()); ByteString large_byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); EXPECT_EQ(ByteString(NewDeleteAllocator(), small_byte_string), small_byte_string); EXPECT_EQ(ByteString(NewDeleteAllocator(), medium_byte_string), medium_byte_string); EXPECT_EQ(ByteString(NewDeleteAllocator(), large_byte_string), large_byte_string); google::protobuf::Arena arena; EXPECT_EQ(ByteString(ArenaAllocator(&arena), small_byte_string), small_byte_string); EXPECT_EQ(ByteString(ArenaAllocator(&arena), medium_byte_string), medium_byte_string); EXPECT_EQ(ByteString(ArenaAllocator(&arena), large_byte_string), large_byte_string); EXPECT_EQ(ByteString(GetAllocator(), small_byte_string), 
small_byte_string); EXPECT_EQ(ByteString(GetAllocator(), medium_byte_string), medium_byte_string); EXPECT_EQ(ByteString(GetAllocator(), large_byte_string), large_byte_string); EXPECT_EQ(ByteString(small_byte_string), small_byte_string); EXPECT_EQ(ByteString(medium_byte_string), medium_byte_string); EXPECT_EQ(ByteString(large_byte_string), large_byte_string); } TEST_P(ByteStringTest, CopyConstructFromExternal) { ByteString small_byte_string = ByteString::FromExternal(GetSmallStringView()); ByteString medium_byte_string = ByteString::FromExternal(GetMediumStringView()); EXPECT_EQ(ByteString(NewDeleteAllocator(), small_byte_string), small_byte_string); EXPECT_EQ(ByteString(NewDeleteAllocator(), medium_byte_string), medium_byte_string); google::protobuf::Arena arena; EXPECT_EQ(ByteString(ArenaAllocator(&arena), small_byte_string), small_byte_string); EXPECT_EQ(ByteString(ArenaAllocator(&arena), medium_byte_string), medium_byte_string); EXPECT_EQ(ByteString(GetAllocator(), small_byte_string), small_byte_string); EXPECT_EQ(ByteString(GetAllocator(), medium_byte_string), medium_byte_string); EXPECT_EQ(ByteString(small_byte_string), small_byte_string); EXPECT_EQ(ByteString(medium_byte_string), medium_byte_string); } TEST_P(ByteStringTest, MoveConstruct) { const auto& small_byte_string = [this]() { return ByteString(GetAllocator(), GetSmallStringView()); }; const auto& medium_byte_string = [this]() { return ByteString(GetAllocator(), GetMediumStringView()); }; const auto& large_byte_string = [this]() { return ByteString(GetAllocator(), GetMediumOrLargeCord()); }; EXPECT_EQ(ByteString(NewDeleteAllocator(), small_byte_string()), small_byte_string()); EXPECT_EQ(ByteString(NewDeleteAllocator(), medium_byte_string()), medium_byte_string()); EXPECT_EQ(ByteString(NewDeleteAllocator(), large_byte_string()), large_byte_string()); google::protobuf::Arena arena; EXPECT_EQ(ByteString(ArenaAllocator(&arena), small_byte_string()), small_byte_string()); 
EXPECT_EQ(ByteString(ArenaAllocator(&arena), medium_byte_string()), medium_byte_string()); EXPECT_EQ(ByteString(ArenaAllocator(&arena), large_byte_string()), large_byte_string()); EXPECT_EQ(ByteString(GetAllocator(), small_byte_string()), small_byte_string()); EXPECT_EQ(ByteString(GetAllocator(), medium_byte_string()), medium_byte_string()); EXPECT_EQ(ByteString(GetAllocator(), large_byte_string()), large_byte_string()); EXPECT_EQ(ByteString(small_byte_string()), small_byte_string()); EXPECT_EQ(ByteString(medium_byte_string()), medium_byte_string()); EXPECT_EQ(ByteString(large_byte_string()), large_byte_string()); } TEST_P(ByteStringTest, MoveConstructFromExternal) { const auto& small_byte_string = []() { return ByteString::FromExternal(GetSmallStringView()); }; const auto& medium_byte_string = []() { return ByteString::FromExternal(GetMediumStringView()); }; EXPECT_EQ(ByteString(NewDeleteAllocator(), small_byte_string()), small_byte_string()); EXPECT_EQ(ByteString(NewDeleteAllocator(), medium_byte_string()), medium_byte_string()); google::protobuf::Arena arena; EXPECT_EQ(ByteString(ArenaAllocator(&arena), small_byte_string()), small_byte_string()); EXPECT_EQ(ByteString(ArenaAllocator(&arena), medium_byte_string()), medium_byte_string()); EXPECT_EQ(ByteString(GetAllocator(), small_byte_string()), small_byte_string()); EXPECT_EQ(ByteString(GetAllocator(), medium_byte_string()), medium_byte_string()); EXPECT_EQ(ByteString(small_byte_string()), small_byte_string()); EXPECT_EQ(ByteString(medium_byte_string()), medium_byte_string()); } TEST_P(ByteStringTest, CopyFromByteString) { ByteString small_byte_string = ByteString(GetAllocator(), GetSmallStringView()); ByteString medium_byte_string = ByteString(GetAllocator(), GetMediumStringView()); ByteString large_byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); ByteString new_delete_byte_string(NewDeleteAllocator<>{}); // Small <= Small new_delete_byte_string = small_byte_string; 
EXPECT_EQ(new_delete_byte_string, small_byte_string); // Small <= Medium new_delete_byte_string = medium_byte_string; EXPECT_EQ(new_delete_byte_string, medium_byte_string); // Medium <= Medium new_delete_byte_string = medium_byte_string; EXPECT_EQ(new_delete_byte_string, medium_byte_string); // Medium <= Large new_delete_byte_string = large_byte_string; EXPECT_EQ(new_delete_byte_string, large_byte_string); // Large <= Large new_delete_byte_string = large_byte_string; EXPECT_EQ(new_delete_byte_string, large_byte_string); // Large <= Small new_delete_byte_string = small_byte_string; EXPECT_EQ(new_delete_byte_string, small_byte_string); // Small <= Large new_delete_byte_string = large_byte_string; EXPECT_EQ(new_delete_byte_string, large_byte_string); // Large <= Medium new_delete_byte_string = medium_byte_string; EXPECT_EQ(new_delete_byte_string, medium_byte_string); // Medium <= Small new_delete_byte_string = small_byte_string; EXPECT_EQ(new_delete_byte_string, small_byte_string); google::protobuf::Arena arena; ByteString arena_byte_string(ArenaAllocator<>{&arena}); // Small <= Small arena_byte_string = small_byte_string; EXPECT_EQ(arena_byte_string, small_byte_string); // Small <= Medium arena_byte_string = medium_byte_string; EXPECT_EQ(arena_byte_string, medium_byte_string); // Medium <= Medium arena_byte_string = medium_byte_string; EXPECT_EQ(arena_byte_string, medium_byte_string); // Medium <= Large arena_byte_string = large_byte_string; EXPECT_EQ(arena_byte_string, large_byte_string); // Large <= Large arena_byte_string = large_byte_string; EXPECT_EQ(arena_byte_string, large_byte_string); // Large <= Small arena_byte_string = small_byte_string; EXPECT_EQ(arena_byte_string, small_byte_string); // Small <= Large arena_byte_string = large_byte_string; EXPECT_EQ(arena_byte_string, large_byte_string); // Large <= Medium arena_byte_string = medium_byte_string; EXPECT_EQ(arena_byte_string, medium_byte_string); // Medium <= Small arena_byte_string = small_byte_string; 
EXPECT_EQ(arena_byte_string, small_byte_string); ByteString allocator_byte_string(GetAllocator()); // Small <= Small allocator_byte_string = small_byte_string; EXPECT_EQ(allocator_byte_string, small_byte_string); // Small <= Medium allocator_byte_string = medium_byte_string; EXPECT_EQ(allocator_byte_string, medium_byte_string); // Medium <= Medium allocator_byte_string = medium_byte_string; EXPECT_EQ(allocator_byte_string, medium_byte_string); // Medium <= Large allocator_byte_string = large_byte_string; EXPECT_EQ(allocator_byte_string, large_byte_string); // Large <= Large allocator_byte_string = large_byte_string; EXPECT_EQ(allocator_byte_string, large_byte_string); // Large <= Small allocator_byte_string = small_byte_string; EXPECT_EQ(allocator_byte_string, small_byte_string); // Small <= Large allocator_byte_string = large_byte_string; EXPECT_EQ(allocator_byte_string, large_byte_string); // Large <= Medium allocator_byte_string = medium_byte_string; EXPECT_EQ(allocator_byte_string, medium_byte_string); // Medium <= Small allocator_byte_string = small_byte_string; EXPECT_EQ(allocator_byte_string, small_byte_string); // Miscellaneous cases not covered above. 
// Large <= Medium Arena String ByteString large_new_delete_byte_string(NewDeleteAllocator<>{}, GetMediumOrLargeCord()); ByteString medium_arena_byte_string(ArenaAllocator<>{&arena}, GetMediumStringView()); large_new_delete_byte_string = medium_arena_byte_string; EXPECT_EQ(large_new_delete_byte_string, medium_arena_byte_string); } TEST_P(ByteStringTest, MoveFrom) { const auto& small_byte_string = [this]() { return ByteString(GetAllocator(), GetSmallStringView()); }; const auto& medium_byte_string = [this]() { return ByteString(GetAllocator(), GetMediumStringView()); }; const auto& large_byte_string = [this]() { return ByteString(GetAllocator(), GetMediumOrLargeCord()); }; ByteString new_delete_byte_string(NewDeleteAllocator<>{}); // Small <= Small new_delete_byte_string = small_byte_string(); EXPECT_EQ(new_delete_byte_string, small_byte_string()); // Small <= Medium new_delete_byte_string = medium_byte_string(); EXPECT_EQ(new_delete_byte_string, medium_byte_string()); // Medium <= Medium new_delete_byte_string = medium_byte_string(); EXPECT_EQ(new_delete_byte_string, medium_byte_string()); // Medium <= Large new_delete_byte_string = large_byte_string(); EXPECT_EQ(new_delete_byte_string, large_byte_string()); // Large <= Large new_delete_byte_string = large_byte_string(); EXPECT_EQ(new_delete_byte_string, large_byte_string()); // Large <= Small new_delete_byte_string = small_byte_string(); EXPECT_EQ(new_delete_byte_string, small_byte_string()); // Small <= Large new_delete_byte_string = large_byte_string(); EXPECT_EQ(new_delete_byte_string, large_byte_string()); // Large <= Medium new_delete_byte_string = medium_byte_string(); EXPECT_EQ(new_delete_byte_string, medium_byte_string()); // Medium <= Small new_delete_byte_string = small_byte_string(); EXPECT_EQ(new_delete_byte_string, small_byte_string()); google::protobuf::Arena arena; ByteString arena_byte_string(ArenaAllocator<>{&arena}); // Small <= Small arena_byte_string = small_byte_string(); 
EXPECT_EQ(arena_byte_string, small_byte_string()); // Small <= Medium arena_byte_string = medium_byte_string(); EXPECT_EQ(arena_byte_string, medium_byte_string()); // Medium <= Medium arena_byte_string = medium_byte_string(); EXPECT_EQ(arena_byte_string, medium_byte_string()); // Medium <= Large arena_byte_string = large_byte_string(); EXPECT_EQ(arena_byte_string, large_byte_string()); // Large <= Large arena_byte_string = large_byte_string(); EXPECT_EQ(arena_byte_string, large_byte_string()); // Large <= Small arena_byte_string = small_byte_string(); EXPECT_EQ(arena_byte_string, small_byte_string()); // Small <= Large arena_byte_string = large_byte_string(); EXPECT_EQ(arena_byte_string, large_byte_string()); // Large <= Medium arena_byte_string = medium_byte_string(); EXPECT_EQ(arena_byte_string, medium_byte_string()); // Medium <= Small arena_byte_string = small_byte_string(); EXPECT_EQ(arena_byte_string, small_byte_string()); ByteString allocator_byte_string(GetAllocator()); // Small <= Small allocator_byte_string = small_byte_string(); EXPECT_EQ(allocator_byte_string, small_byte_string()); // Small <= Medium allocator_byte_string = medium_byte_string(); EXPECT_EQ(allocator_byte_string, medium_byte_string()); // Medium <= Medium allocator_byte_string = medium_byte_string(); EXPECT_EQ(allocator_byte_string, medium_byte_string()); // Medium <= Large allocator_byte_string = large_byte_string(); EXPECT_EQ(allocator_byte_string, large_byte_string()); // Large <= Large allocator_byte_string = large_byte_string(); EXPECT_EQ(allocator_byte_string, large_byte_string()); // Large <= Small allocator_byte_string = small_byte_string(); EXPECT_EQ(allocator_byte_string, small_byte_string()); // Small <= Large allocator_byte_string = large_byte_string(); EXPECT_EQ(allocator_byte_string, large_byte_string()); // Large <= Medium allocator_byte_string = medium_byte_string(); EXPECT_EQ(allocator_byte_string, medium_byte_string()); // Medium <= Small allocator_byte_string = 
small_byte_string(); EXPECT_EQ(allocator_byte_string, small_byte_string()); // Miscellaneous cases not covered above. // Large <= Medium Arena String ByteString large_new_delete_byte_string(NewDeleteAllocator<>{}, GetMediumOrLargeCord()); ByteString medium_arena_byte_string(ArenaAllocator<>{&arena}, GetMediumStringView()); large_new_delete_byte_string = std::move(medium_arena_byte_string); EXPECT_EQ(large_new_delete_byte_string, GetMediumStringView()); } TEST_P(ByteStringTest, Swap) { using std::swap; ByteString empty_byte_string(GetAllocator()); ByteString small_byte_string = ByteString(GetAllocator(), GetSmallStringView()); ByteString medium_byte_string = ByteString(GetAllocator(), GetMediumStringView()); ByteString large_byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); // Small <=> Small swap(empty_byte_string, small_byte_string); EXPECT_EQ(empty_byte_string, GetSmallStringView()); EXPECT_EQ(small_byte_string, ""); swap(empty_byte_string, small_byte_string); EXPECT_EQ(empty_byte_string, ""); EXPECT_EQ(small_byte_string, GetSmallStringView()); // Small <=> Medium swap(small_byte_string, medium_byte_string); EXPECT_EQ(small_byte_string, GetMediumStringView()); EXPECT_EQ(medium_byte_string, GetSmallStringView()); swap(small_byte_string, medium_byte_string); EXPECT_EQ(small_byte_string, GetSmallStringView()); EXPECT_EQ(medium_byte_string, GetMediumStringView()); // Small <=> Large swap(small_byte_string, large_byte_string); EXPECT_EQ(small_byte_string, GetMediumOrLargeCord()); EXPECT_EQ(large_byte_string, GetSmallStringView()); swap(small_byte_string, large_byte_string); EXPECT_EQ(small_byte_string, GetSmallStringView()); EXPECT_EQ(large_byte_string, GetMediumOrLargeCord()); // Medium <=> Medium static constexpr absl::string_view kDifferentMediumStringView = "A different string that is too large for the small string optimization!"; ByteString other_medium_byte_string = ByteString(GetAllocator(), kDifferentMediumStringView); swap(medium_byte_string, 
other_medium_byte_string); EXPECT_EQ(medium_byte_string, kDifferentMediumStringView); EXPECT_EQ(other_medium_byte_string, GetMediumStringView()); swap(medium_byte_string, other_medium_byte_string); EXPECT_EQ(medium_byte_string, GetMediumStringView()); EXPECT_EQ(other_medium_byte_string, kDifferentMediumStringView); // Medium <=> Large swap(medium_byte_string, large_byte_string); EXPECT_EQ(medium_byte_string, GetMediumOrLargeCord()); EXPECT_EQ(large_byte_string, GetMediumStringView()); swap(medium_byte_string, large_byte_string); EXPECT_EQ(medium_byte_string, GetMediumStringView()); EXPECT_EQ(large_byte_string, GetMediumOrLargeCord()); // Large <=> Large const absl::Cord different_medium_or_large_cord = absl::Cord(kDifferentMediumStringView); ByteString other_large_byte_string = ByteString(GetAllocator(), different_medium_or_large_cord); swap(large_byte_string, other_large_byte_string); EXPECT_EQ(large_byte_string, different_medium_or_large_cord); EXPECT_EQ(other_large_byte_string, GetMediumStringView()); swap(large_byte_string, other_large_byte_string); EXPECT_EQ(large_byte_string, GetMediumStringView()); EXPECT_EQ(other_large_byte_string, different_medium_or_large_cord); // Miscellaneous cases not covered above. These do not swap a second time to // restore state, so they are destructive. 
// Small <=> Different Allocator Medium ByteString medium_new_delete_byte_string = ByteString(NewDeleteAllocator<>{}, kDifferentMediumStringView); swap(empty_byte_string, medium_new_delete_byte_string); EXPECT_EQ(empty_byte_string, kDifferentMediumStringView); EXPECT_EQ(medium_new_delete_byte_string, ""); // Small <=> Different Allocator Large ByteString large_new_delete_byte_string = ByteString(NewDeleteAllocator<>{}, GetMediumOrLargeCord()); swap(small_byte_string, large_new_delete_byte_string); EXPECT_EQ(small_byte_string, GetMediumOrLargeCord()); EXPECT_EQ(large_new_delete_byte_string, GetSmallStringView()); // Medium <=> Different Allocator Large large_new_delete_byte_string = ByteString(NewDeleteAllocator<>{}, different_medium_or_large_cord); swap(medium_byte_string, large_new_delete_byte_string); EXPECT_EQ(medium_byte_string, different_medium_or_large_cord); EXPECT_EQ(large_new_delete_byte_string, GetMediumStringView()); // Medium <=> Different Allocator Medium medium_byte_string = ByteString(GetAllocator(), GetMediumStringView()); medium_new_delete_byte_string = ByteString(NewDeleteAllocator<>{}, kDifferentMediumStringView); swap(medium_byte_string, medium_new_delete_byte_string); EXPECT_EQ(medium_byte_string, kDifferentMediumStringView); EXPECT_EQ(medium_new_delete_byte_string, GetMediumStringView()); } TEST_P(ByteStringTest, FlattenSmall) { ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); EXPECT_EQ(byte_string.Flatten(), GetSmallStringView()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); } TEST_P(ByteStringTest, FlattenMedium) { ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); EXPECT_EQ(byte_string.Flatten(), GetMediumStringView()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); } TEST_P(ByteStringTest, FlattenLarge) { if (GetAllocator().arena() != nullptr) { 
GTEST_SKIP(); } ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kLarge); EXPECT_EQ(byte_string.Flatten(), GetMediumStringView()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kLarge); } TEST_P(ByteStringTest, TryFlatSmall) { ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); EXPECT_THAT(byte_string.TryFlat(), Optional(GetSmallStringView())); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); } TEST_P(ByteStringTest, TryFlatMedium) { ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); EXPECT_THAT(byte_string.TryFlat(), Optional(GetMediumStringView())); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); } TEST_P(ByteStringTest, TryFlatLarge) { if (GetAllocator().arena() != nullptr) { GTEST_SKIP(); } ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeFragmentedCord()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kLarge); EXPECT_THAT(byte_string.TryFlat(), Eq(absl::nullopt)); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kLarge); } TEST_P(ByteStringTest, Equals) { ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); EXPECT_TRUE(byte_string.Equals(GetMediumStringView())); } TEST_P(ByteStringTest, Compare) { ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); EXPECT_EQ(byte_string.Compare(GetMediumStringView()), 0); EXPECT_EQ(byte_string.Compare(GetMediumOrLargeCord()), 0); } TEST_P(ByteStringTest, StartsWith) { ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); EXPECT_TRUE(byte_string.StartsWith( GetMediumStringView().substr(0, kSmallByteStringCapacity))); EXPECT_TRUE(byte_string.StartsWith( GetMediumOrLargeCord().Subcord(0, kSmallByteStringCapacity))); } TEST_P(ByteStringTest, EndsWith) { ByteString byte_string = 
ByteString(GetAllocator(), GetMediumOrLargeCord()); EXPECT_TRUE(byte_string.EndsWith( GetMediumStringView().substr(kSmallByteStringCapacity))); EXPECT_TRUE(byte_string.EndsWith(GetMediumOrLargeCord().Subcord( kSmallByteStringCapacity, GetMediumOrLargeCord().size() - kSmallByteStringCapacity))); } TEST_P(ByteStringTest, Find) { ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); // Find string_view EXPECT_THAT(byte_string.Find("A string"), Optional(0)); EXPECT_THAT( byte_string.Find("small string optimization!"), Optional(GetMediumStringView().find("small string optimization!"))); EXPECT_THAT(byte_string.Find("not found"), Eq(absl::nullopt)); EXPECT_THAT(byte_string.Find(""), Optional(0)); EXPECT_THAT(byte_string.Find("", 3), Optional(3)); EXPECT_THAT(byte_string.Find("A string", 1), Eq(absl::nullopt)); // Find cord EXPECT_THAT(byte_string.Find(absl::Cord("A string")), Optional(0)); EXPECT_THAT( byte_string.Find(absl::Cord("small string optimization!")), Optional(GetMediumStringView().find("small string optimization!"))); EXPECT_THAT( byte_string.Find(absl::MakeFragmentedCord( {"A string", " that is too large for the small string optimization!", " extra"})), Eq(absl::nullopt)); EXPECT_THAT(byte_string.Find(GetMediumOrLargeFragmentedCord()), Optional(0)); EXPECT_THAT(byte_string.Find(absl::Cord("not found")), Eq(absl::nullopt)); EXPECT_THAT(byte_string.Find(absl::Cord("")), Optional(0)); EXPECT_THAT(byte_string.Find(absl::Cord(""), 3), Optional(3)); } TEST_P(ByteStringTest, FindEdgeCases) { ByteString empty_byte_string(GetAllocator(), ""); EXPECT_THAT(empty_byte_string.Find("a"), Eq(absl::nullopt)); EXPECT_THAT(empty_byte_string.Find(""), Optional(0)); ByteString cord_byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); EXPECT_THAT(cord_byte_string.Find("not found"), Eq(absl::nullopt)); ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); // Needle longer than haystack. 
EXPECT_THAT(byte_string.Find(std::string(byte_string.size() + 1, 'a')), Eq(absl::nullopt)); // Needle at the end. absl::string_view suffix = "optimization!"; EXPECT_THAT(byte_string.Find(suffix), Optional(byte_string.size() - suffix.size())); // pos at the end. EXPECT_THAT(byte_string.Find("a", byte_string.size()), Eq(absl::nullopt)); EXPECT_THAT(byte_string.Find("", byte_string.size()), Optional(byte_string.size())); // Search in a cord-backed ByteString with pos > 0. EXPECT_THAT(cord_byte_string.Find("string", 1), Optional(GetMediumStringView().find("string", 1))); // Needle at the end of a cord-backed ByteString. absl::string_view suffix_sv = "optimization!"; EXPECT_THAT(cord_byte_string.Find(suffix_sv), Optional(cord_byte_string.size() - suffix_sv.size())); EXPECT_THAT(cord_byte_string.Find(absl::Cord(suffix_sv)), Optional(cord_byte_string.size() - suffix_sv.size())); // Fragmented needle with empty first chunk. absl::Cord fragmented_with_empty_chunk; fragmented_with_empty_chunk.Append(""); fragmented_with_empty_chunk.Append("A string"); EXPECT_THAT(byte_string.Find(fragmented_with_empty_chunk), Optional(0)); // Search with fragmented cord needle on string_view backed ByteString with // partial match. ByteString partial_match_haystack(GetAllocator(), "abababac"); absl::Cord partial_match_needle = absl::MakeFragmentedCord({"aba", "c"}); EXPECT_THAT(partial_match_haystack.Find(partial_match_needle), Optional(4)); // Search with fragmented cord needle where first chunk is found but not // enough space for the rest. ByteString short_haystack(GetAllocator(), "abcdefg"); absl::Cord needle_too_long = absl::MakeFragmentedCord({"ef", "gh"}); EXPECT_THAT(short_haystack.Find(needle_too_long), Eq(absl::nullopt)); // Search with a fragmented empty cord. 
absl::Cord fragmented_empty_cord = absl::MakeFragmentedCord({"", ""}); EXPECT_THAT(byte_string.Find(fragmented_empty_cord), Optional(0)); EXPECT_THAT(byte_string.Find(fragmented_empty_cord, 3), Optional(3)); // Search for suffix in a fragmented cord. ByteString fragmented_cord_byte_string(GetAllocator(), GetMediumOrLargeFragmentedCord()); EXPECT_THAT(fragmented_cord_byte_string.Find(suffix_sv), Optional(fragmented_cord_byte_string.size() - suffix_sv.size())); EXPECT_THAT(fragmented_cord_byte_string.Find(absl::Cord(suffix_sv)), Optional(fragmented_cord_byte_string.size() - suffix_sv.size())); } #ifndef NDEBUG TEST_P(ByteStringTest, FindOutOfBounds) { ByteString byte_string = ByteString(GetAllocator(), "test"); EXPECT_DEATH(byte_string.Find("t", 5), _); } #endif TEST_P(ByteStringTest, Substring) { // small byte_string substring ByteString small_byte_string = ByteString(GetAllocator(), GetSmallStringView()); EXPECT_EQ(small_byte_string.Substring(1, 5), GetSmallStringView().substr(1, 4)); EXPECT_EQ(small_byte_string.Substring(0, small_byte_string.size()), GetSmallStringView()); EXPECT_EQ(small_byte_string.Substring(1, 1), ""); // medium byte_string substring ByteString medium_byte_string = ByteString(GetAllocator(), GetMediumStringView()); EXPECT_EQ(medium_byte_string.Substring(2, 12), GetMediumStringView().substr(2, 10)); EXPECT_EQ(medium_byte_string.Substring(0, medium_byte_string.size()), GetMediumStringView()); // large byte_string substring ByteString large_byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); EXPECT_EQ(large_byte_string.Substring(3, 15), GetMediumOrLargeCord().Subcord(3, 12)); EXPECT_EQ(large_byte_string.Substring(0, large_byte_string.size()), GetMediumOrLargeCord()); // substring with one parameter ByteString tacocat_byte_string = ByteString(GetAllocator(), "tacocat"); EXPECT_EQ(tacocat_byte_string.Substring(4), "cat"); } TEST_P(ByteStringTest, SubstringEdgeCases) { ByteString byte_string = ByteString(GetAllocator(), 
GetSmallStringView()); EXPECT_EQ(byte_string.Substring(byte_string.size(), byte_string.size()), ""); EXPECT_EQ(byte_string.Substring(0, 0), ""); } #ifndef NDEBUG TEST_P(ByteStringTest, SubstringOutOfBounds) { ByteString byte_string = ByteString(GetAllocator(), "test"); EXPECT_DEATH(static_cast(byte_string.Substring(5, 5)), _); EXPECT_DEATH(static_cast(byte_string.Substring(0, 5)), _); EXPECT_DEATH(static_cast(byte_string.Substring(3, 2)), _); } #endif TEST_P(ByteStringTest, RemovePrefixSmall) { ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); byte_string.RemovePrefix(1); EXPECT_EQ(byte_string, GetSmallStringView().substr(1)); } TEST_P(ByteStringTest, RemovePrefixMedium) { ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); byte_string.RemovePrefix(byte_string.size() - kSmallByteStringCapacity); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); EXPECT_EQ(byte_string, GetMediumStringView().substr(GetMediumStringView().size() - kSmallByteStringCapacity)); } TEST_P(ByteStringTest, RemovePrefixMediumOrLarge) { ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); byte_string.RemovePrefix(byte_string.size() - kSmallByteStringCapacity); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); EXPECT_EQ(byte_string, GetMediumStringView().substr(GetMediumStringView().size() - kSmallByteStringCapacity)); } TEST_P(ByteStringTest, RemoveSuffixSmall) { ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); byte_string.RemoveSuffix(1); EXPECT_EQ(byte_string, GetSmallStringView().substr(0, GetSmallStringView().size() - 1)); } TEST_P(ByteStringTest, RemoveSuffixMedium) { ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); byte_string.RemoveSuffix(byte_string.size() - kSmallByteStringCapacity); EXPECT_EQ(GetKind(byte_string), 
ByteStringKind::kSmall); EXPECT_EQ(byte_string, GetMediumStringView().substr(0, kSmallByteStringCapacity)); } TEST_P(ByteStringTest, RemoveSuffixMediumOrLarge) { ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); byte_string.RemoveSuffix(byte_string.size() - kSmallByteStringCapacity); EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); EXPECT_EQ(byte_string, GetMediumStringView().substr(0, kSmallByteStringCapacity)); } TEST_P(ByteStringTest, ToStringSmall) { ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); EXPECT_EQ(byte_string.ToString(), byte_string); } TEST_P(ByteStringTest, ToStringMedium) { ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); EXPECT_EQ(byte_string.ToString(), byte_string); } TEST_P(ByteStringTest, ToStringLarge) { ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); EXPECT_EQ(byte_string.ToString(), byte_string); } TEST_P(ByteStringTest, ToStringViewSmall) { std::string scratch; ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); EXPECT_EQ(byte_string.ToStringView(&scratch), GetSmallStringView()); } TEST_P(ByteStringTest, ToStringViewMedium) { std::string scratch; ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); EXPECT_EQ(byte_string.ToStringView(&scratch), GetMediumStringView()); } TEST_P(ByteStringTest, ToStringViewLarge) { std::string scratch; ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); EXPECT_EQ(byte_string.ToStringView(&scratch), GetMediumOrLargeCord()); } TEST_P(ByteStringTest, AsStringViewSmall) { ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); EXPECT_EQ(byte_string.AsStringView(), GetSmallStringView()); } TEST_P(ByteStringTest, AsStringViewMedium) { ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); EXPECT_EQ(byte_string.AsStringView(), GetMediumStringView()); } TEST_P(ByteStringTest, 
AsStringViewLarge) { ByteString byte_string = ByteString(GetMediumOrLargeCord()); EXPECT_DEATH(byte_string.AsStringView(), _); } TEST_P(ByteStringTest, CopyToStringSmall) { std::string out; ByteString(GetAllocator(), GetSmallStringView()).CopyToString(&out); EXPECT_EQ(out, GetSmallStringView()); } TEST_P(ByteStringTest, CopyToStringMedium) { std::string out; ByteString(GetAllocator(), GetMediumStringView()).CopyToString(&out); EXPECT_EQ(out, GetMediumStringView()); } TEST_P(ByteStringTest, CopyToStringLarge) { std::string out; ByteString(GetAllocator(), GetMediumOrLargeCord()).CopyToString(&out); EXPECT_EQ(out, GetMediumOrLargeCord()); } TEST_P(ByteStringTest, AppendToStringSmall) { std::string out; ByteString(GetAllocator(), GetSmallStringView()).AppendToString(&out); EXPECT_EQ(out, GetSmallStringView()); } TEST_P(ByteStringTest, AppendToStringMedium) { std::string out; ByteString(GetAllocator(), GetMediumStringView()).AppendToString(&out); EXPECT_EQ(out, GetMediumStringView()); } TEST_P(ByteStringTest, AppendToStringLarge) { std::string out; ByteString(GetAllocator(), GetMediumOrLargeCord()).AppendToString(&out); EXPECT_EQ(out, GetMediumOrLargeCord()); } TEST_P(ByteStringTest, ToCordSmall) { ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); EXPECT_EQ(byte_string.ToCord(), byte_string); EXPECT_EQ(std::move(byte_string).ToCord(), GetSmallStringView()); } TEST_P(ByteStringTest, ToCordMedium) { ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); EXPECT_EQ(byte_string.ToCord(), byte_string); EXPECT_EQ(std::move(byte_string).ToCord(), GetMediumStringView()); } TEST_P(ByteStringTest, ToCordLarge) { ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); EXPECT_EQ(byte_string.ToCord(), byte_string); EXPECT_EQ(std::move(byte_string).ToCord(), GetMediumOrLargeCord()); } TEST_P(ByteStringTest, CopyToCordSmall) { absl::Cord out; ByteString(GetAllocator(), GetSmallStringView()).CopyToCord(&out); 
EXPECT_EQ(out, GetSmallStringView()); } TEST_P(ByteStringTest, CopyToCordMedium) { absl::Cord out; ByteString(GetAllocator(), GetMediumStringView()).CopyToCord(&out); EXPECT_EQ(out, GetMediumStringView()); } TEST_P(ByteStringTest, CopyToCordLarge) { absl::Cord out; ByteString(GetAllocator(), GetMediumOrLargeCord()).CopyToCord(&out); EXPECT_EQ(out, GetMediumOrLargeCord()); } TEST_P(ByteStringTest, AppendToCordSmall) { absl::Cord out; ByteString(GetAllocator(), GetSmallStringView()).AppendToCord(&out); EXPECT_EQ(out, GetSmallStringView()); } TEST_P(ByteStringTest, AppendToCordMedium) { absl::Cord out; ByteString(GetAllocator(), GetMediumStringView()).AppendToCord(&out); EXPECT_EQ(out, GetMediumStringView()); } TEST_P(ByteStringTest, AppendToCordLarge) { absl::Cord out; ByteString(GetAllocator(), GetMediumOrLargeCord()).AppendToCord(&out); EXPECT_EQ(out, GetMediumOrLargeCord()); } TEST_P(ByteStringTest, CloneSmall) { google::protobuf::Arena arena; ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); EXPECT_EQ(byte_string.Clone(&arena), byte_string); } TEST_P(ByteStringTest, CloneMedium) { google::protobuf::Arena arena; ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); EXPECT_EQ(byte_string.Clone(&arena), byte_string); } TEST_P(ByteStringTest, CloneLarge) { google::protobuf::Arena arena; ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); EXPECT_EQ(byte_string.Clone(&arena), byte_string); } TEST_P(ByteStringTest, LegacyByteStringSmall) { google::protobuf::Arena arena; ByteString byte_string = ByteString(GetAllocator(), GetSmallStringView()); EXPECT_EQ(LegacyByteString(byte_string, /*stable=*/false, &arena), GetSmallStringView()); EXPECT_EQ(LegacyByteString(byte_string, /*stable=*/true, &arena), GetSmallStringView()); } TEST_P(ByteStringTest, LegacyByteStringMedium) { google::protobuf::Arena arena; ByteString byte_string = ByteString(GetAllocator(), GetMediumStringView()); 
EXPECT_EQ(LegacyByteString(byte_string, /*stable=*/false, &arena), GetMediumStringView()); EXPECT_EQ(LegacyByteString(byte_string, /*stable=*/true, &arena), GetMediumStringView()); } TEST_P(ByteStringTest, LegacyByteStringLarge) { google::protobuf::Arena arena; ByteString byte_string = ByteString(GetAllocator(), GetMediumOrLargeCord()); EXPECT_EQ(LegacyByteString(byte_string, /*stable=*/false, &arena), GetMediumOrLargeCord()); EXPECT_EQ(LegacyByteString(byte_string, /*stable=*/true, &arena), GetMediumOrLargeCord()); } TEST_P(ByteStringTest, HashValue) { EXPECT_EQ(absl::HashOf(ByteString(GetAllocator(), GetSmallStringView())), absl::HashOf(GetSmallStringView())); EXPECT_EQ(absl::HashOf(ByteString(GetAllocator(), GetMediumStringView())), absl::HashOf(GetMediumStringView())); EXPECT_EQ(absl::HashOf(ByteString(GetAllocator(), GetMediumOrLargeCord())), absl::HashOf(GetMediumOrLargeCord())); } INSTANTIATE_TEST_SUITE_P(ByteStringTest, ByteStringTest, ::testing::Values(AllocatorKind::kNewDelete, AllocatorKind::kArena)); } // namespace } // namespace cel::common_internal ================================================ FILE: common/internal/casting.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/casting.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_CASTING_H_ #define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_CASTING_H_ #include #include #include #include "absl/base/attributes.h" #include "absl/meta/type_traits.h" #include "absl/types/optional.h" #include "internal/casts.h" namespace cel { namespace common_internal { template using propagate_const_t = std::conditional_t>, std::add_const_t, To>; template using propagate_volatile_t = std::conditional_t>, std::add_volatile_t, To>; template using propagate_reference_t = std::conditional_t, std::add_lvalue_reference_t, std::conditional_t, std::add_rvalue_reference_t, To>>; template using propagate_cvref_t = propagate_reference_t< propagate_volatile_t, From>, From>; } // namespace common_internal namespace common_internal { // Implementation of `cel::InstanceOf`. template struct ABSL_DEPRECATED("Use Is member functions instead.") InstanceOfImpl final { static_assert(!std::is_pointer_v, "To must not be a pointer"); static_assert(!std::is_array_v, "To must not be an array"); static_assert(!std::is_lvalue_reference_v, "To must not be a lvalue reference"); static_assert(!std::is_rvalue_reference_v, "To must not be a lvalue reference"); static_assert(!std::is_const_v, "To must not be const qualified"); static_assert(!std::is_volatile_v, "To must not be volatile qualified"); static_assert(std::is_class_v, "To must be a non-union class"); explicit InstanceOfImpl() = default; template ABSL_DEPRECATED("Use Is member functions instead.") ABSL_MUST_USE_RESULT bool operator()(const From& from) const { static_assert(!std::is_volatile_v, "From must not be volatile qualified"); static_assert(std::is_class_v, "From must be a non-union class"); if constexpr (std::is_same_v, To>) { // Same type. Separate from the next `else if` to work on in-complete // types. return true; } else if constexpr (std::is_polymorphic_v && std::is_polymorphic_v> && std::is_base_of_v>) { // Polymorphic upcast. 
return true; } else if constexpr (!std::is_polymorphic_v && !std::is_polymorphic_v> && (std::is_convertible_v || std::is_convertible_v || std::is_convertible_v || std::is_convertible_v)) { // Implicitly convertible. return true; } else { // Something else. return from.template Is(); } } template ABSL_DEPRECATED("Use Is member functions instead.") ABSL_MUST_USE_RESULT bool operator()(const From* from) const { static_assert(!std::is_volatile_v, "From must not be volatile qualified"); static_assert(std::is_class_v, "From must be a non-union class"); return from != nullptr && (*this)(*from); } }; // Implementation of `cel::Cast`. template struct ABSL_DEPRECATED( "Use explicit conversion functions instead through static_cast.") CastImpl final { static_assert(!std::is_pointer_v, "To must not be a pointer"); static_assert(!std::is_array_v, "To must not be an array"); static_assert(!std::is_lvalue_reference_v, "To must not be a lvalue reference"); static_assert(!std::is_rvalue_reference_v, "To must not be a lvalue reference"); static_assert(!std::is_const_v, "To must not be const qualified"); static_assert(!std::is_volatile_v, "To must not be volatile qualified"); static_assert(std::is_class_v, "To must be a non-union class"); explicit CastImpl() = default; template ABSL_DEPRECATED( "Use explicit conversion functions instead through static_cast.") ABSL_MUST_USE_RESULT decltype(auto) operator()(From&& from) const { static_assert(!std::is_volatile_v, "From must not be volatile qualified"); static_assert(std::is_class_v>, "From must be a non-union class"); if constexpr (std::is_polymorphic_v) { static_assert(std::is_lvalue_reference_v, "polymorphic casts are only possible on lvalue references"); } if constexpr (std::is_same_v, To>) { // Same type. Separate from the next `else if` to work on in-complete // types. return static_cast>(from); } else if constexpr (std::is_polymorphic_v && std::is_polymorphic_v> && std::is_base_of_v>) { // Polymorphic upcast. 
return static_cast>(from); } else if constexpr (std::is_polymorphic_v && std::is_polymorphic_v> && std::is_base_of_v, To>) { // Polymorphic downcast. return cel::internal::down_cast>( std::forward(from)); } else if constexpr (std::is_convertible_v && !std::is_polymorphic_v && !std::is_polymorphic_v>) { return static_cast(std::forward(from)); } else { // Something else. return std::forward(from).template Get(); } } template ABSL_DEPRECATED( "Use explicit conversion functions instead through static_cast.") ABSL_MUST_USE_RESULT decltype(auto) operator()(From* from) const { static_assert(!std::is_volatile_v, "From must not be volatile qualified"); static_assert(std::is_class_v, "From must be a non-union class"); using R = decltype((*this)(*from)); static_assert(std::is_lvalue_reference_v); if (from == nullptr) { return static_cast>>( nullptr); } return static_cast>>( std::addressof((*this)(*from))); } }; // Implementation of `cel::As`. template struct ABSL_DEPRECATED("Use As member functions instead.") AsImpl final { static_assert(!std::is_pointer_v, "To must not be a pointer"); static_assert(!std::is_array_v, "To must not be an array"); static_assert(!std::is_lvalue_reference_v, "To must not be a lvalue reference"); static_assert(!std::is_rvalue_reference_v, "To must not be a lvalue reference"); static_assert(!std::is_const_v, "To must not be const qualified"); static_assert(!std::is_volatile_v, "To must not be volatile qualified"); static_assert(std::is_class_v, "To must be a non-union class"); explicit AsImpl() = default; template ABSL_DEPRECATED("Use As member functions instead.") ABSL_MUST_USE_RESULT decltype(auto) operator()(From&& from) const { // Returns either `absl::optional` or `cel::optional_ref` // depending on the return type of `CastTraits::Convert`. The use of these // two types is an implementation detail. 
static_assert(!std::is_volatile_v, "From must not be volatile qualified"); static_assert(std::is_class_v>, "From must be a non-union class"); return std::forward(from).template As(); } // Returns a pointer. template ABSL_DEPRECATED("Use As member functions instead.") ABSL_MUST_USE_RESULT decltype(auto) operator()(From* from) const { // Returns either `absl::optional` or `To*` depending on the return type of // `CastTraits::Convert`. The use of these two types is an implementation // detail. static_assert(!std::is_volatile_v, "From must not be volatile qualified"); static_assert(std::is_class_v, "From must be a non-union class"); using R = decltype(from->template As()); if (from == nullptr) { return R{absl::nullopt}; } return from->template As(); } }; } // namespace common_internal } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_CASTING_H_ ================================================ FILE: common/internal/metadata.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_METADATA_H_ #define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_METADATA_H_ #include #include "google/protobuf/arena.h" namespace cel::common_internal { // `google::protobuf::Arena` has a minimum alignment of 8. `ReferenceCount` has a minimum // alignment that is guaranteed to be greater than or equal to `google::protobuf::Arena`. 
inline constexpr uintptr_t kMetadataOwnerNone = 0; inline constexpr uintptr_t kMetadataOwnerReferenceCountBit = uintptr_t{1} << 0; inline constexpr uintptr_t kMetadataOwnerArenaBit = uintptr_t{1} << 1; inline constexpr uintptr_t kMetadataOwnerBits = alignof(google::protobuf::Arena) - 1; inline constexpr uintptr_t kMetadataOwnerPointerMask = ~kMetadataOwnerBits; // Ensure kMetadataOwnerBits encompasses kMetadataOwnerReferenceCountBit and // kMetadataOwnerArenaBit. static_assert((kMetadataOwnerBits | kMetadataOwnerReferenceCountBit) == kMetadataOwnerBits); static_assert((kMetadataOwnerBits | kMetadataOwnerArenaBit) == kMetadataOwnerBits); } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_METADATA_H_ ================================================ FILE: common/internal/reference_count.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/internal/reference_count.h"

// NOTE(review): the system header names on the bare `#include` lines below,
// and the template argument lists in this file (on `std::pair`, `static_cast`,
// `reinterpret_cast`, `DeletingReferenceCount`), appear to have been stripped
// from this text; restore them from the upstream file.
#include
#include
#include
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/strings/string_view.h"
#include "common/data.h"
#include "internal/new.h"
#include "google/protobuf/message_lite.h"

namespace cel::common_internal {

// Explicit instantiation definition. reference_count.h carries the matching
// `extern template class DeletingReferenceCount` declaration.
template class DeletingReferenceCount;

namespace {

// Reference-counted holder that takes ownership of a `std::string` by moving
// it into inline storage. The string is constructed in `string_` with
// placement new and destroyed in `Finalize()`.
class ReferenceCountedStdString final : public ReferenceCounted {
 public:
  // Allocates a new holder and returns it together with a view of the owned
  // string's contents.
  static std::pair New(
      std::string&& string) {
    const auto* const refcount =
        new ReferenceCountedStdString(std::move(string));
    const auto* const refcount_string = std::launder(
        reinterpret_cast(&refcount->string_[0]));
    return std::pair{static_cast(refcount),
                     absl::string_view(*refcount_string)};
  }

  explicit ReferenceCountedStdString(std::string&& string) {
    // Move the string into inline storage, then request that excess capacity
    // be released (`shrink_to_fit` is a non-binding request).
    (::new (static_cast(&string_[0])) std::string(std::move(string)))
        ->shrink_to_fit();
  }

 private:
  void Finalize() noexcept override {
    std::destroy_at(std::launder(reinterpret_cast(&string_[0])));
  }

  alignas(std::string) char string_[sizeof(std::string)];
};

// Reference-counted copy of a string. The characters are stored inline,
// directly after the object header, in a single allocation of
// `Overhead() + size` bytes.
class ReferenceCountedString final : public ReferenceCounted {
 public:
  static std::pair New(
      absl::string_view string) {
    const auto* const refcount =
        ::new (internal::New(Overhead() + string.size()))
            ReferenceCountedString(string);
    return std::pair{static_cast(refcount),
                     absl::string_view(refcount->data_, refcount->size_)};
  }

 private:
  // ReferenceCountedString is non-standard-layout due to having virtual functions
  // from a base class. This causes compilers to warn about the use of offsetof(),
  // but it still works here, so silence the warning and proceed.
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Winvalid-offsetof"
#endif
  // Number of bytes that precede the inline character data.
  static size_t Overhead() { return offsetof(ReferenceCountedString, data_); }
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif

  explicit ReferenceCountedString(absl::string_view string)
      : size_(string.size()) {
    std::memcpy(data_, string.data(), size_);
  }

  // Replaces the default `delete this` from ReferenceCounted so that the
  // whole variable-sized block is released with its original size. `size_` is
  // copied out before the object is destroyed.
  void Delete() noexcept override {
    void* const that = this;
    const auto size = size_;
    std::destroy_at(this);
    internal::SizedDelete(that, Overhead() + size);
  }

  const size_t size_;
  // NOTE(review): a flexible array member is a compiler extension in C++.
  char data_[];
};

}  // namespace

// Copies `value` into a new single-allocation reference-counted buffer.
// `value` must be non-empty (DCHECKed).
std::pair MakeReferenceCountedString(absl::string_view value) {
  ABSL_DCHECK(!value.empty());
  return ReferenceCountedString::New(value);
}

// Takes ownership of `value` by moving it into a reference-counted holder.
// `value` must be non-empty (DCHECKed).
std::pair MakeReferenceCountedString(std::string&& value) {
  ABSL_DCHECK(!value.empty());
  return ReferenceCountedStdString::New(std::move(value));
}

}  // namespace cel::common_internal
================================================
FILE: common/internal/reference_count.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This header contains primitives for reference counting, roughly equivalent to
// the primitives used to implement `std::shared_ptr`. These primitives should
// not be used directly in most cases; `cel::Shared` should be used
// instead.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_REFERENCE_COUNT_H_
#define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_REFERENCE_COUNT_H_

// NOTE(review): the system header names on the six bare `#include` lines
// below appear to have been stripped from this text; restore them from the
// upstream file.
#include
#include
#include
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/log/absl_check.h"
#include "absl/strings/string_view.h"
#include "common/data.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/message_lite.h"

namespace cel::common_internal {

// Empty tag type with an explicit default constructor, so it can only be
// passed by naming it (e.g. `kAdoptRef`). Presumably selects overloads that
// adopt an existing reference instead of acquiring a new one — confirm
// against its users.
struct AdoptRef final {
  explicit AdoptRef() = default;
};

inline constexpr AdoptRef kAdoptRef{};

class ReferenceCount;
struct ReferenceCountFromThis;

void SetReferenceCountForThat(ReferenceCountFromThis& that,
                              ReferenceCount* absl_nullable refcount);

ReferenceCount* absl_nullable GetReferenceCountForThat(
    const ReferenceCountFromThis& that);

// `ReferenceCountFromThis` is similar to `std::enable_shared_from_this`. It
// allows the derived object to inspect its own reference count. It should not
// be used directly, but should be used through
// `cel::EnableManagedMemoryFromThis`.
struct ReferenceCountFromThis {
 private:
  friend void SetReferenceCountForThat(ReferenceCountFromThis& that,
                                       ReferenceCount* absl_nullable refcount);
  friend ReferenceCount* absl_nullable GetReferenceCountForThat(
      const ReferenceCountFromThis& that);

  static constexpr uintptr_t kNullPtr = uintptr_t{0};
  // All-ones marker meaning "no reference count recorded yet". It is distinct
  // from null because null is a legal recorded value (the accessors below take
  // and return `absl_nullable` pointers).
  static constexpr uintptr_t kSentinelPtr = ~kNullPtr;

  // NOTE(review): template arguments on the casts in this block appear to
  // have been stripped from this text; restore them from the upstream file.
  void* absl_nullable refcount = reinterpret_cast(kSentinelPtr);
};

// Records `refcount` in `that`. The DCHECK requires that nothing has been
// recorded yet, i.e. this is called at most once per object.
inline void SetReferenceCountForThat(ReferenceCountFromThis& that,
                                     ReferenceCount* absl_nullable refcount) {
  ABSL_DCHECK_EQ(that.refcount,
                 reinterpret_cast(ReferenceCountFromThis::kSentinelPtr));
  that.refcount = static_cast(refcount);
}

// Returns the reference count recorded by SetReferenceCountForThat, which may
// be null. The DCHECK requires that a value has been recorded.
inline ReferenceCount* absl_nullable GetReferenceCountForThat(
    const ReferenceCountFromThis& that) {
  ABSL_DCHECK_NE(that.refcount,
                 reinterpret_cast(ReferenceCountFromThis::kSentinelPtr));
  return static_cast(that.refcount);
}

// Reference-count primitives. Each comes as a reference overload and a
// nullable-pointer overload; the definitions follow later in this header.
void StrongRef(const ReferenceCount& refcount) noexcept;
void StrongRef(const ReferenceCount* absl_nullable refcount) noexcept;
void StrongUnref(const ReferenceCount& refcount) noexcept;
void StrongUnref(const ReferenceCount* absl_nullable refcount) noexcept;
ABSL_MUST_USE_RESULT
bool StrengthenRef(const ReferenceCount& refcount) noexcept;
ABSL_MUST_USE_RESULT
bool StrengthenRef(const ReferenceCount* absl_nullable refcount) noexcept;
void WeakRef(const ReferenceCount& refcount) noexcept;
void WeakRef(const ReferenceCount* absl_nullable refcount) noexcept;
void WeakUnref(const ReferenceCount& refcount) noexcept;
void WeakUnref(const ReferenceCount* absl_nullable refcount) noexcept;
ABSL_MUST_USE_RESULT
bool IsUniqueRef(const ReferenceCount& refcount) noexcept;
ABSL_MUST_USE_RESULT
bool IsUniqueRef(const ReferenceCount* absl_nullable refcount) noexcept;
ABSL_MUST_USE_RESULT
bool IsExpiredRef(const ReferenceCount& refcount) noexcept;
ABSL_MUST_USE_RESULT
bool IsExpiredRef(const ReferenceCount* absl_nullable refcount) noexcept;

// `ReferenceCount` is similar to the control block used by `std::shared_ptr`.
// It is not meant to be interacted with directly in most cases, instead
// `cel::Shared` should be used.
class alignas(8) ReferenceCount {
 public:
  ReferenceCount() = default;

  ReferenceCount(const ReferenceCount&) = delete;
  ReferenceCount(ReferenceCount&&) = delete;
  ReferenceCount& operator=(const ReferenceCount&) = delete;
  ReferenceCount& operator=(ReferenceCount&&) = delete;

  virtual ~ReferenceCount() = default;

 private:
  friend void StrongRef(const ReferenceCount& refcount) noexcept;
  friend void StrongUnref(const ReferenceCount& refcount) noexcept;
  friend bool StrengthenRef(const ReferenceCount& refcount) noexcept;
  friend void WeakRef(const ReferenceCount& refcount) noexcept;
  friend void WeakUnref(const ReferenceCount& refcount) noexcept;
  friend bool IsUniqueRef(const ReferenceCount& refcount) noexcept;
  friend bool IsExpiredRef(const ReferenceCount& refcount) noexcept;

  // Invoked by StrongUnref when the strong count drops to zero.
  virtual void Finalize() noexcept = 0;

  // Invoked by WeakUnref when the weak count drops to zero.
  virtual void Delete() noexcept = 0;

  // Both counts start at 1. StrongUnref releases the initial weak reference
  // when the last strong reference goes away.
  mutable std::atomic strong_refcount_ = 1;
  mutable std::atomic weak_refcount_ = 1;
};

// ReferenceCount and its derivations must be at least as aligned as
// google::protobuf::Arena. This is a requirement for the pointer tagging defined in
// common/internal/metadata.h.
static_assert(alignof(ReferenceCount) >= alignof(google::protobuf::Arena));

// `ReferenceCounted` is a base class for classes which should be reference
// counted. It provides default implementations for `Finalize()` and `Delete()`.
class ReferenceCounted : public ReferenceCount {
 private:
  // Default: nothing to tear down when the strong count reaches zero.
  void Finalize() noexcept override {}

  // Default: the object was allocated with `new`.
  void Delete() noexcept override { delete this; }
};

// `EmplacedReferenceCount` adapts `T` to make it reference countable, by
// storing `T` inside the reference count. This only works when `T` has not yet
// been allocated.
// NOTE(review): the template parameter lists and template arguments in this
// block (e.g. `template <typename T>`, `std::is_destructible_v<T>`) appear to
// have been stripped from this text; restore them from the upstream file.
template class EmplacedReferenceCount final : public ReferenceCounted {
 public:
  static_assert(std::is_destructible_v, "T must be destructible");
  static_assert(!std::is_reference_v, "T must not be a reference");
  static_assert(!std::is_volatile_v, "T must not be volatile qualified");
  static_assert(!std::is_const_v, "T must not be const qualified");
  static_assert(!std::is_array_v, "T must not be an array");

  // Constructs `T` in the inline buffer from `args` and reports its address
  // through the out-parameter `value`.
  template
  explicit EmplacedReferenceCount(T*& value, Args&&... args) noexcept(
      std::is_nothrow_constructible_v) {
    value =
        ::new (static_cast(&value_[0])) T(std::forward(args)...);
  }

 private:
  // Destroys the emplaced `T`; the buffer itself is released later by
  // `Delete()`.
  void Finalize() noexcept override {
    std::destroy_at(std::launder(reinterpret_cast(&value_[0])));
  }

  // We store the instance of `T` in a char buffer and use placement new and
  // direct calls to the destructor. The reason for this is `Finalize()` is
  // called when the strong reference count hits 0. This allows us to destroy
  // our instance of `T` once we are no longer strongly reachable and deallocate
  // the memory once we are no longer weakly reachable.
  alignas(T) char value_[sizeof(T)];
};

// `DeletingReferenceCount` adapts `T` to make it reference countable, by taking
// ownership of `T` and deleting it. This only works when `T` has already been
// allocated and is too expensive to move or copy.
// NOTE(review): template parameter lists and template arguments throughout
// this block appear to have been stripped from this text; restore them from
// the upstream file.
template class DeletingReferenceCount final : public ReferenceCounted {
 public:
  explicit DeletingReferenceCount(const T* absl_nonnull to_delete) noexcept
      : to_delete_(to_delete) {}

 private:
  // Deletes the adopted object once the strong count reaches zero.
  void Finalize() noexcept override { delete to_delete_; }

  const T* absl_nonnull const to_delete_;
};

extern template class DeletingReferenceCount;

// Wraps an already heap-allocated, non-arena object in a new
// DeletingReferenceCount that will `delete` it. Arena-constructable objects are
// DCHECKed to have no arena. In one branch the object is also told about its
// reference count via `SetDataReferenceCount` (presumably when `T` derives
// from `Data`; the condition's template arguments are stripped here).
template
const ReferenceCount* absl_nonnull MakeDeletingReferenceCount(
    const T* absl_nonnull to_delete) {
  if constexpr (google::protobuf::Arena::is_arena_constructable::value) {
    ABSL_DCHECK_EQ(to_delete->GetArena(), nullptr);
  }
  if constexpr (std::is_base_of_v) {
    return new DeletingReferenceCount(to_delete);
  } else {
    auto* refcount = new DeletingReferenceCount(to_delete);
    if constexpr (std::is_base_of_v) {
      common_internal::SetDataReferenceCount(to_delete, refcount);
    }
    return refcount;
  }
}

// Allocates an EmplacedReferenceCount holding a `T` constructed from `args`
// and returns the object together with its reference count.
template
std::pair MakeEmplacedReferenceCount(Args&&... args) {
  using U = std::remove_const_t;
  // Assigned by the EmplacedReferenceCount constructor.
  U* pointer;
  auto* const refcount =
      new EmplacedReferenceCount(pointer, std::forward(args)...);
  if constexpr (google::protobuf::Arena::is_arena_constructable::value) {
    ABSL_DCHECK_EQ(pointer->GetArena(), nullptr);
  }
  if constexpr (std::is_base_of_v) {
    common_internal::SetDataReferenceCount(pointer, refcount);
  }
  return std::pair{static_cast(pointer),
                   static_cast(refcount)};
}

// Reference count that stores a `T` inline; `value()` exposes it.
template
class InlinedReferenceCount final : public ReferenceCounted {
 public:
  template
  explicit InlinedReferenceCount(std::in_place_t, Args&&... args)
      : ReferenceCounted() {
    ::new (static_cast(value())) T(std::forward(args)...);
  }

  ABSL_ATTRIBUTE_ALWAYS_INLINE T* absl_nonnull value() {
    return reinterpret_cast(&value_[0]);
  }

  ABSL_ATTRIBUTE_ALWAYS_INLINE const T* absl_nonnull value() const {
    return reinterpret_cast(&value_[0]);
  }

 private:
  void Finalize() noexcept override { value()->~T(); }

  // We store the instance of `T` in a char buffer and use placement new and
  // direct calls to the destructor. The reason for this is `Finalize()` is
  // called when the strong reference count hits 0. This allows us to destroy
  // our instance of `T` once we are no longer strongly reachable and deallocate
  // the memory once we are no longer weakly reachable.
  alignas(T) char value_[sizeof(T)];
};

// Allocates an InlinedReferenceCount holding a `T` built from `args` and
// returns the object and its reference count. In one branch the object is
// handed its own count via SetReferenceCountForThat (presumably when `T`
// derives from ReferenceCountFromThis; the condition is stripped here).
template
std::pair MakeReferenceCount(
    Args&&... args) {
  using U = std::remove_const_t;
  auto* const refcount =
      new InlinedReferenceCount(std::in_place, std::forward(args)...);
  auto* const pointer = refcount->value();
  if constexpr (std::is_base_of_v) {
    SetReferenceCountForThat(*pointer, refcount);
  }
  return std::make_pair(static_cast(pointer),
                        static_cast(refcount));
}

// Adds a strong reference. The count must already be positive (DCHECKed).
inline void StrongRef(const ReferenceCount& refcount) noexcept {
  const auto count =
      refcount.strong_refcount_.fetch_add(1, std::memory_order_relaxed);
  ABSL_DCHECK_GT(count, 0);
}

inline void StrongRef(const ReferenceCount* absl_nullable refcount) noexcept {
  if (refcount != nullptr) {
    StrongRef(*refcount);
  }
}

// Drops a strong reference. When the previous count was 1, finalizes the
// object and then drops the weak reference that the strong side holds.
inline void StrongUnref(const ReferenceCount& refcount) noexcept {
  const auto count =
      refcount.strong_refcount_.fetch_sub(1, std::memory_order_acq_rel);
  ABSL_DCHECK_GT(count, 0);
  ABSL_ASSUME(count > 0);
  if (ABSL_PREDICT_FALSE(count == 1)) {
    const_cast(refcount).Finalize();
    WeakUnref(refcount);
  }
}

inline void StrongUnref(const ReferenceCount* absl_nullable refcount) noexcept {
  if (refcount != nullptr) {
    StrongUnref(*refcount);
  }
}

// Tries to acquire a strong reference, e.g. from a weak one. Returns false
// without modifying the count once the strong count is zero; otherwise
// increments it with a CAS loop and returns true.
// NOTE(review): confirm that `release` (rather than an acquiring order) is
// the intended success ordering for this compare-exchange.
ABSL_MUST_USE_RESULT
inline bool StrengthenRef(const ReferenceCount& refcount) noexcept {
  auto count = refcount.strong_refcount_.load(std::memory_order_relaxed);
  while (true) {
    ABSL_DCHECK_GE(count, 0);
    ABSL_ASSUME(count >= 0);
    if (count == 0) {
      return false;
    }
    if (refcount.strong_refcount_.compare_exchange_weak(
            count, count + 1, std::memory_order_release,
            std::memory_order_relaxed)) {
      return true;
    }
  }
}

ABSL_MUST_USE_RESULT
inline bool StrengthenRef(
    const ReferenceCount* absl_nullable refcount) noexcept {
  return refcount != nullptr ? StrengthenRef(*refcount) : false;
}

// Adds a weak reference. The count must already be positive (DCHECKed).
inline void WeakRef(const ReferenceCount& refcount) noexcept {
  const auto count =
      refcount.weak_refcount_.fetch_add(1, std::memory_order_relaxed);
  ABSL_DCHECK_GT(count, 0);
}

inline void WeakRef(const ReferenceCount* absl_nullable refcount) noexcept {
  if (refcount != nullptr) {
    WeakRef(*refcount);
  }
}

// Drops a weak reference; when the previous count was 1, calls `Delete()`.
inline void WeakUnref(const ReferenceCount& refcount) noexcept {
  const auto count =
      refcount.weak_refcount_.fetch_sub(1, std::memory_order_acq_rel);
  ABSL_DCHECK_GT(count, 0);
  ABSL_ASSUME(count > 0);
  if (ABSL_PREDICT_FALSE(count == 1)) {
    const_cast(refcount).Delete();
  }
}

inline void WeakUnref(const ReferenceCount* absl_nullable refcount) noexcept {
  if (refcount != nullptr) {
    WeakUnref(*refcount);
  }
}

// Returns true when exactly one strong reference exists.
ABSL_MUST_USE_RESULT
inline bool IsUniqueRef(const ReferenceCount& refcount) noexcept {
  const auto count = refcount.strong_refcount_.load(std::memory_order_acquire);
  ABSL_DCHECK_GT(count, 0);
  ABSL_ASSUME(count > 0);
  return count == 1;
}

ABSL_MUST_USE_RESULT
inline bool IsUniqueRef(const ReferenceCount* absl_nullable refcount) noexcept {
  return refcount != nullptr ? IsUniqueRef(*refcount) : false;
}

// Returns true when no strong references remain.
ABSL_MUST_USE_RESULT
inline bool IsExpiredRef(const ReferenceCount& refcount) noexcept {
  const auto count = refcount.strong_refcount_.load(std::memory_order_acquire);
  ABSL_DCHECK_GE(count, 0);
  ABSL_ASSUME(count >= 0);
  return count == 0;
}

ABSL_MUST_USE_RESULT
inline bool IsExpiredRef(
    const ReferenceCount* absl_nullable refcount) noexcept {
  return refcount != nullptr ? IsExpiredRef(*refcount) : false;
}

std::pair
MakeReferenceCountedString(absl::string_view value);

std::pair
MakeReferenceCountedString(std::string&& value);

}  // namespace cel::common_internal

#endif  // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_REFERENCE_COUNT_H_
================================================
FILE: common/internal/reference_count_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "common/internal/reference_count.h"

#include

#include "google/protobuf/struct.pb.h"
#include "absl/base/nullability.h"
#include "common/data.h"
#include "internal/testing.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/message_lite.h"

namespace cel::common_internal {
namespace {

using ::testing::NotNull;
using ::testing::WhenDynamicCastTo;

// Test object that records its own destruction in `destructed`.
class Object : public virtual ReferenceCountFromThis {
 public:
  explicit Object(bool& destructed) : destructed_(destructed) {}

  ~Object() { destructed_ = true; }

 private:
  bool& destructed_;
};

class Subobject : public Object, public virtual ReferenceCountFromThis {
 public:
  using Object::Object;
};

// The object survives until its last strong reference is released.
TEST(ReferenceCount, Strong) {
  bool destructed = false;
  Object* object;
  ReferenceCount* refcount;
  std::tie(object, refcount) = MakeReferenceCount(destructed);
  EXPECT_EQ(GetReferenceCountForThat(*object), refcount);
  EXPECT_EQ(GetReferenceCountForThat(*static_cast(object)),
            refcount);
  StrongRef(refcount);
  StrongUnref(refcount);
  EXPECT_TRUE(IsUniqueRef(refcount));
  EXPECT_FALSE(IsExpiredRef(refcount));
  EXPECT_FALSE(destructed);
  StrongUnref(refcount);
  EXPECT_TRUE(destructed);
}

// A weak reference can be strengthened while the object is alive, and
// strengthening fails once the object has been finalized.
TEST(ReferenceCount, Weak) {
  bool destructed = false;
  Object* object;
  ReferenceCount* refcount;
  std::tie(object, refcount) = MakeReferenceCount(destructed);
  EXPECT_EQ(GetReferenceCountForThat(*object), refcount);
  EXPECT_EQ(GetReferenceCountForThat(*static_cast(object)),
            refcount);
  WeakRef(refcount);
  ASSERT_TRUE(StrengthenRef(refcount));
  StrongUnref(refcount);
  EXPECT_TRUE(IsUniqueRef(refcount));
  EXPECT_FALSE(IsExpiredRef(refcount));
  EXPECT_FALSE(destructed);
  StrongUnref(refcount);
  EXPECT_TRUE(destructed);
  EXPECT_TRUE(IsExpiredRef(refcount));
  ASSERT_FALSE(StrengthenRef(refcount));
  WeakUnref(refcount);
}

class DataObject final : public Data {
 public:
  DataObject() noexcept : Data() {}

  explicit DataObject(google::protobuf::Arena* absl_nullable arena) noexcept
      : Data(arena) {}

  char member_[17];
};

struct OtherObject final {
  char data[17];
};

TEST(DeletingReferenceCount, Data) {
  auto* data = new DataObject();
  const auto* refcount = MakeDeletingReferenceCount(data);
  EXPECT_THAT(
      refcount,
      WhenDynamicCastTo*>(NotNull()));
  EXPECT_EQ(common_internal::GetDataReferenceCount(data), refcount);
  StrongUnref(refcount);
}

TEST(DeletingReferenceCount, MessageLite) {
  auto* message_lite = new google::protobuf::Value();
  const auto* refcount = MakeDeletingReferenceCount(message_lite);
  EXPECT_THAT(
      refcount,
      WhenDynamicCastTo*>(
          NotNull()));
  StrongUnref(refcount);
}

TEST(DeletingReferenceCount, Other) {
  auto* other = new OtherObject();
  const auto* refcount = MakeDeletingReferenceCount(other);
  EXPECT_THAT(
      refcount,
      WhenDynamicCastTo*>(NotNull()));
  StrongUnref(refcount);
}

TEST(EmplacedReferenceCount, Data) {
  Data* data;
  const ReferenceCount* refcount;
  std::tie(data, refcount) = MakeEmplacedReferenceCount();
  EXPECT_THAT(
      refcount,
      WhenDynamicCastTo*>(NotNull()));
  EXPECT_EQ(common_internal::GetDataReferenceCount(data), refcount);
  StrongUnref(refcount);
}

TEST(EmplacedReferenceCount, MessageLite) {
  google::protobuf::Value* message_lite;
  const ReferenceCount* refcount;
  std::tie(message_lite, refcount) =
      MakeEmplacedReferenceCount();
  EXPECT_THAT(
      refcount,
      WhenDynamicCastTo*>(
          NotNull()));
  StrongUnref(refcount);
}

TEST(EmplacedReferenceCount, Other) {
  OtherObject* other;
  const ReferenceCount* refcount;
  std::tie(other, refcount) = MakeEmplacedReferenceCount();
  EXPECT_THAT(
      refcount,
      WhenDynamicCastTo*>(NotNull()));
  StrongUnref(refcount);
}

}  // namespace
}  // namespace cel::common_internal
================================================
FILE: common/internal/signature.cc
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/internal/signature.h"

#include <cstddef>
#include <string>
#include <string_view>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "common/type.h"
#include "common/type_kind.h"
#include "internal/status_macros.h"

namespace cel::common_internal {
namespace {

// Appends `str` to `*result`, placing a backslash before every character that
// carries meaning in the signature grammar. The characters `\`, `(`, `)`, `<`,
// `>`, `"`, `,` and `~` are always escaped; `.` is escaped only when
// `escape_dot` is true. All other characters are copied unchanged.
void AppendEscaped(std::string* result, std::string_view str, bool escape_dot) {
  static constexpr std::string_view kReserved = "\\()<>\",~";
  for (const char ch : str) {
    const bool is_reserved = kReserved.find(ch) != std::string_view::npos;
    const bool is_escaped_dot = escape_dot && ch == '.';
    if (is_reserved || is_escaped_dot) {
      result->push_back('\\');
    }
    result->push_back(ch);
  }
}

absl::Status AppendTypeParameters(std::string* result, const Type& type);

// Recursively appends a string representation of the given `type` to `result`.
// Type parameters are enclosed in angle brackets and separated by commas.
// Grammar:
// TypeDesc = NamespaceIdentifier [ "<" TypeList ">" ] ;
// NamespaceIdentifier = [ "." ] Identifier { "." Identifier } ;
// TypeList = TypeElem { "," TypeElem } ;
// TypeElem = TypeDesc | TypeParam
// TypeParam = "~" Alpha ;
// Identifier = ( Alpha | "_" ) { AlphaNumeric | "_" } ;
// (* Terminals *)
// Alpha = "a"..."z" | "A"..."Z" ;
// Digit = "0"..."9" ;
// AlphaNumeric = Alpha | Digit ;
//
// For compatibility, the implementation allows unexpected characters in
// type names and parameters and escapes them with a backslash.
absl::Status AppendTypeDesc(std::string* result, const Type& type) { switch (type.kind()) { case TypeKind::kNull: absl::StrAppend(result, "null"); break; case TypeKind::kBool: absl::StrAppend(result, "bool"); break; case TypeKind::kInt: absl::StrAppend(result, "int"); break; case TypeKind::kUint: absl::StrAppend(result, "uint"); break; case TypeKind::kDouble: absl::StrAppend(result, "double"); break; case TypeKind::kString: absl::StrAppend(result, "string"); break; case TypeKind::kBytes: absl::StrAppend(result, "bytes"); break; case TypeKind::kDuration: absl::StrAppend(result, "duration"); break; case TypeKind::kTimestamp: absl::StrAppend(result, "timestamp"); break; case TypeKind::kAny: absl::StrAppend(result, "any"); break; case TypeKind::kDyn: absl::StrAppend(result, "dyn"); break; case TypeKind::kBoolWrapper: absl::StrAppend(result, "bool_wrapper"); break; case TypeKind::kIntWrapper: absl::StrAppend(result, "int_wrapper"); break; case TypeKind::kUintWrapper: absl::StrAppend(result, "uint_wrapper"); break; case TypeKind::kDoubleWrapper: absl::StrAppend(result, "double_wrapper"); break; case TypeKind::kStringWrapper: absl::StrAppend(result, "string_wrapper"); break; case TypeKind::kBytesWrapper: absl::StrAppend(result, "bytes_wrapper"); break; case TypeKind::kList: absl::StrAppend(result, "list"); CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); break; case TypeKind::kMap: absl::StrAppend(result, "map"); CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); break; case TypeKind::kFunction: absl::StrAppend(result, "function"); CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); break; case TypeKind::kType: absl::StrAppend(result, "type"); CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); break; case TypeKind::kTypeParam: absl::StrAppend(result, "~"); AppendEscaped(result, type.GetTypeParam().name(), /*escape_dot=*/true); break; case TypeKind::kOpaque: AppendEscaped(result, type.name(), /*escape_dot=*/false); 
CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); break; case TypeKind::kStruct: AppendEscaped(result, type.name(), /*escape_dot=*/false); CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); break; default: return absl::InvalidArgumentError( absl::StrFormat("Type kind: %s is not supported in CEL declarations", type.DebugString())); } return absl::OkStatus(); } absl::Status AppendTypeParameters(std::string* result, const Type& type) { const auto& parameters = type.GetParameters(); if (!parameters.empty()) { result->push_back('<'); for (size_t i = 0; i < parameters.size(); ++i) { CEL_RETURN_IF_ERROR(AppendTypeDesc(result, parameters[i])); if (i < parameters.size() - 1) { result->push_back(','); } } result->push_back('>'); } return absl::OkStatus(); } } // namespace absl::StatusOr MakeTypeSignature(const Type& type) { std::string result; CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type)); return result; } absl::StatusOr MakeOverloadSignature( std::string_view function_name, const std::vector& args, bool is_member) { std::string result; if (is_member) { if (!args.empty()) { CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, args[0])); } else { return absl::InvalidArgumentError("Member function with no receiver"); } result.push_back('.'); } AppendEscaped(&result, function_name, /*escape_dot=*/true); result.push_back('('); for (size_t i = is_member ? 1 : 0; i < args.size(); ++i) { CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, args[i])); if (i < args.size() - 1) { result.push_back(','); } } result.push_back(')'); return result; } } // namespace cel::common_internal ================================================ FILE: common/internal/signature.h ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ #include #include #include #include "absl/status/statusor.h" #include "common/type.h" namespace cel::common_internal { // Generates an signature for a `cel::Type`, which is a string representation of // the type. // // Examples: // // - `int` // - `list` // - `list>` absl::StatusOr MakeTypeSignature(const Type& type); // Generates an identifier for a function overload based on the function name // and the types of the arguments. If `is_member` is true, the first argument // type is used as the receiver and is prepended to the function name, followed // by a dollar sign. // // Examples: // // - `foo()` // - `foo(int)` // - `bar.foo(int)` // - `foo(int,string)` // - `foo(list,list)` // - `bar.foo(list,list>)` // // If the function name contains a period, it is escaped with a backslash, e.g. // `foo.bar` becomes `foo\.bar`. This allows to disambiguate between a member // function and qualified target type name. 
// absl::StatusOr MakeOverloadSignature( std::string_view function_name, const std::vector& args, bool is_member); } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ ================================================ FILE: common/internal/signature_test.cc ================================================ #include "common/internal/signature.h" // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include "absl/base/no_destructor.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "common/type.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "google/protobuf/arena.h" namespace cel::common_internal { namespace { using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::internal::GetTestingDescriptorPool; using ::testing::HasSubstr; using ::testing::ValuesIn; google::protobuf::Arena* GetTestArena() { static absl::NoDestructor arena; return &*arena; } struct TypeSignatureTestCase { Type type; std::string expected_signature; std::string expected_error; }; using TypeSignatureTest = testing::TestWithParam; TEST_P(TypeSignatureTest, TypeSignature) { const auto& param = GetParam(); absl::StatusOr signature = common_internal::MakeTypeSignature(param.type); if (!param.expected_error.empty()) { EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(param.expected_error))); } else { 
EXPECT_THAT(signature, IsOkAndHolds(param.expected_signature)); } } std::vector GetTypeSignatureTestCases() { return { { .type = StringType{}, .expected_signature = "string", }, { .type = IntType{}, .expected_signature = "int", }, { .type = ListType(GetTestArena(), StringType{}), .expected_signature = "list", }, { .type = ListType(GetTestArena(), TypeParamType("A")), .expected_signature = "list<~A>", }, { .type = MapType(GetTestArena(), IntType{}, DynType{}), .expected_signature = "map", }, { .type = MapType(GetTestArena(), TypeParamType("B"), TypeParamType("C")), .expected_signature = "map<~B,~C>", }, { .type = OpaqueType( GetTestArena(), "bar", {FunctionType(GetTestArena(), TypeParamType("D"), {})}), .expected_signature = "bar>", }, { .type = AnyType{}, .expected_signature = "any", }, { .type = DurationType{}, .expected_signature = "duration", }, { .type = TimestampType{}, .expected_signature = "timestamp", }, { .type = IntWrapperType{}, .expected_signature = "int_wrapper", }, { .type = MessageType(GetTestingDescriptorPool()->FindMessageTypeByName( "cel.expr.conformance.proto3.TestAllTypes")), .expected_signature = "cel.expr.conformance.proto3.TestAllTypes", }, { .type = ListType(GetTestArena(), TypeParamType(R"(a,b..(d)\e)")), .expected_signature = R"(list<~a\,b\.\\.\(d\)\\e>)", }, { .type = UnknownType{}, .expected_error = "Type kind: *unknown* is not supported in CEL declarations", }, { .type = ErrorType{}, .expected_error = "Type kind: *error* is not supported in CEL declarations", }, }; } INSTANTIATE_TEST_SUITE_P(TypeIdTest, TypeSignatureTest, ValuesIn(GetTypeSignatureTestCases())); struct OverloadSignatureTestCase { std::string function_name = "hello"; std::vector args; bool is_member = false; std::string expected_signature; std::string expected_error; }; using OverloadSignatureTest = testing::TestWithParam; TEST_P(OverloadSignatureTest, OverloadSignature) { const auto& param = GetParam(); absl::StatusOr signature = 
common_internal::MakeOverloadSignature(param.function_name, param.args, param.is_member); if (!param.expected_error.empty()) { EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(param.expected_error))); } else { EXPECT_THAT(signature, IsOkAndHolds(param.expected_signature)); } } std::vector GetOverloadSignatureTestCases() { return { { .args = {StringType{}}, .expected_signature = "hello(string)", }, { .args = {IntType{}, UintType{}}, .expected_signature = "hello(int,uint)", }, { .args = {ListType(GetTestArena(), StringType{})}, .expected_signature = "hello(list)", }, { .args = {ListType(GetTestArena(), TypeParamType("A"))}, .expected_signature = "hello(list<~A>)", }, { .args = {MapType(GetTestArena(), IntType{}, DynType{})}, .expected_signature = "hello(map)", }, { .args = {MapType(GetTestArena(), TypeParamType("B"), TypeParamType("C"))}, .expected_signature = "hello(map<~B,~C>)", }, { .args = {OpaqueType( GetTestArena(), "bar", {FunctionType(GetTestArena(), TypeParamType("D"), {})})}, .expected_signature = "hello(bar>)", }, { .args = {AnyType{}}, .expected_signature = "hello(any)", }, { .args = {DurationType{}}, .expected_signature = "hello(duration)", }, { .args = {TimestampType{}}, .expected_signature = "hello(timestamp)", }, { .args = {IntWrapperType{}}, .expected_signature = "hello(int_wrapper)", }, { .args = {MessageType( GetTestingDescriptorPool()->FindMessageTypeByName( "cel.expr.conformance.proto3.TestAllTypes"))}, .expected_signature = "hello(cel.expr.conformance.proto3.TestAllTypes)", }, {.args = {}, .is_member = true, .expected_error = "Member function with no receiver"}, { .args = {StringType{}}, .is_member = true, .expected_signature = "string.hello()", }, { .args = {StringType{}, ListType(GetTestArena(), BoolType{})}, .is_member = true, .expected_signature = "string.hello(list)", }, { .args = {StringType{}, BoolType{}, DynType{}}, .is_member = true, .expected_signature = "string.hello(bool,dyn)", }, { .function_name = 
R"(h.(e),l\o)", .args = {StringType{}, ListType(GetTestArena(), TypeParamType(R"(a,b..(d)\e)"))}, .is_member = true, .expected_signature = R"(string.h\.\(e\)\,l\\\o(list<~a\,b\.\\.\(d\)\\e>))", }, }; } INSTANTIATE_TEST_SUITE_P(OverloadIdTest, OverloadSignatureTest, ValuesIn(GetOverloadSignatureTestCases())); } // namespace } // namespace cel::common_internal ================================================ FILE: common/internal/value_conversion.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/internal/value_conversion.h"

#include <string>
#include <utility>

#include "cel/expr/value.pb.h"
#include "google/protobuf/any.pb.h"
#include "google/protobuf/struct.pb.h"
#include "google/protobuf/timestamp.pb.h"
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "common/any.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "extensions/protobuf/value.h"
#include "internal/proto_time_encoding.h"
#include "internal/status_macros.h"
#include "internal/time.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/message.h"

namespace cel::test {
namespace {

using ExprValueKind = cel::expr::Value::KindCase;
using ExprMapValue = cel::expr::MapValue;
using ExprListValue = cel::expr::ListValue;

// Returns the proto field name of `kind_case`, for use in error messages.
std::string ToString(ExprValueKind kind_case) {
  switch (kind_case) {
    case ExprValueKind::kBoolValue:
      return "bool_value";
    case ExprValueKind::kInt64Value:
      return "int64_value";
    case ExprValueKind::kUint64Value:
      return "uint64_value";
    case ExprValueKind::kDoubleValue:
      return "double_value";
    case ExprValueKind::kStringValue:
      return "string_value";
    case ExprValueKind::kBytesValue:
      return "bytes_value";
    case ExprValueKind::kTypeValue:
      return "type_value";
    case ExprValueKind::kEnumValue:
      return "enum_value";
    case ExprValueKind::kMapValue:
      return "map_value";
    case ExprValueKind::kListValue:
      return "list_value";
    case ExprValueKind::kNullValue:
      return "null_value";
    case ExprValueKind::kObjectValue:
      return "object_value";
    default:
      return "unknown kind case";
  }
}

// Converts a packed `object_value` into a runtime value.
//
// `google.protobuf.Duration` and `google.protobuf.Timestamp` are unpacked and
// validated into native `DurationValue` / `TimestampValue`. Every other message
// is delegated to `extensions::ProtoMessageToValue`.
absl::StatusOr<Value> FromObject(
    const google::protobuf::Any& any,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  if (any.type_url() == "type.googleapis.com/google.protobuf.Duration") {
    google::protobuf::Duration duration;
    if (!any.UnpackTo(&duration)) {
      return absl::InvalidArgumentError("invalid duration");
    }
    absl::Duration d = internal::DecodeDuration(duration);
    CEL_RETURN_IF_ERROR(cel::internal::ValidateDuration(d));
    return cel::DurationValue(d);
  }
  if (any.type_url() == "type.googleapis.com/google.protobuf.Timestamp") {
    google::protobuf::Timestamp timestamp;
    if (!any.UnpackTo(&timestamp)) {
      return absl::InvalidArgumentError("invalid timestamp");
    }
    absl::Time time = internal::DecodeTime(timestamp);
    CEL_RETURN_IF_ERROR(cel::internal::ValidateTimestamp(time));
    return cel::TimestampValue(time);
  }
  return extensions::ProtoMessageToValue(any, descriptor_pool, message_factory,
                                         arena);
}

// Builds a runtime map by converting each entry's key and value with
// `FromExprValue`. Fails on the first conversion or insertion error.
absl::StatusOr<MapValue> MapValueFromExpr(
    const ExprMapValue& map_value,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  auto builder = cel::NewMapValueBuilder(arena);
  for (const auto& entry : map_value.entries()) {
    CEL_ASSIGN_OR_RETURN(auto key, FromExprValue(entry.key(), descriptor_pool,
                                                 message_factory, arena));
    CEL_ASSIGN_OR_RETURN(auto value,
                         FromExprValue(entry.value(), descriptor_pool,
                                       message_factory, arena));
    CEL_RETURN_IF_ERROR(builder->Put(std::move(key), std::move(value)));
  }
  return std::move(*builder).Build();
}

// Builds a runtime list by converting each element with `FromExprValue`.
absl::StatusOr<ListValue> ListValueFromExpr(
    const ExprListValue& list_value,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  auto builder = cel::NewListValueBuilder(arena);
  for (const auto& elem : list_value.values()) {
    CEL_ASSIGN_OR_RETURN(
        auto value, FromExprValue(elem, descriptor_pool, message_factory, arena));
    CEL_RETURN_IF_ERROR(builder->Add(std::move(value)));
  }
  return std::move(*builder).Build();
}

// Serializes a runtime map into `cel.expr.MapValue`. Each key comes from the
// map iterator, and its value is then looked up with `MapValue::Get`.
absl::StatusOr<ExprMapValue> MapValueToExpr(
    const MapValue& map_value,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  ExprMapValue result;
  CEL_ASSIGN_OR_RETURN(auto iter, map_value.NewIterator());
  while (iter->HasNext()) {
    CEL_ASSIGN_OR_RETURN(auto key_value,
                         iter->Next(descriptor_pool, message_factory, arena));
    CEL_ASSIGN_OR_RETURN(
        auto value_value,
        map_value.Get(key_value, descriptor_pool, message_factory, arena));
    CEL_ASSIGN_OR_RETURN(
        auto key,
        ToExprValue(key_value, descriptor_pool, message_factory, arena));
    CEL_ASSIGN_OR_RETURN(
        auto value,
        ToExprValue(value_value, descriptor_pool, message_factory, arena));
    auto* entry = result.add_entries();
    *entry->mutable_key() = std::move(key);
    *entry->mutable_value() = std::move(value);
  }
  return result;
}

// Serializes a runtime list into `cel.expr.ListValue`, preserving order.
absl::StatusOr<ExprListValue> ListValueToExpr(
    const ListValue& list_value,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  ExprListValue result;
  CEL_ASSIGN_OR_RETURN(auto iter, list_value.NewIterator());
  while (iter->HasNext()) {
    CEL_ASSIGN_OR_RETURN(auto elem,
                         iter->Next(descriptor_pool, message_factory, arena));
    CEL_ASSIGN_OR_RETURN(
        *result.add_values(),
        ToExprValue(elem, descriptor_pool, message_factory, arena));
  }
  return result;
}

// Serializes a struct value and packs it into a `google.protobuf.Any` whose
// type URL is derived from the struct's type name.
absl::StatusOr<google::protobuf::Any> ToProtobufAny(
    const StructValue& struct_value,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  google::protobuf::io::CordOutputStream serialized;
  CEL_RETURN_IF_ERROR(
      struct_value.SerializeTo(descriptor_pool, message_factory, &serialized));
  google::protobuf::Any result;
  result.set_type_url(MakeTypeUrl(struct_value.GetTypeName()));
  result.set_value(std::string(std::move(serialized).Consume()));
  return result;
}

}  // namespace
absl::StatusOr FromExprValue( const cel::expr::Value& value, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { google::protobuf::LinkMessageReflection(); switch (value.kind_case()) { case ExprValueKind::kBoolValue: return cel::BoolValue(value.bool_value()); case ExprValueKind::kInt64Value: return cel::IntValue(value.int64_value()); case ExprValueKind::kUint64Value: return cel::UintValue(value.uint64_value()); case ExprValueKind::kDoubleValue: return cel::DoubleValue(value.double_value()); case ExprValueKind::kStringValue: return cel::StringValue(value.string_value()); case ExprValueKind::kBytesValue: return cel::BytesValue(value.bytes_value()); case ExprValueKind::kNullValue: return cel::NullValue(); case ExprValueKind::kObjectValue: return FromObject(value.object_value(), descriptor_pool, message_factory, arena); case ExprValueKind::kMapValue: return MapValueFromExpr(value.map_value(), descriptor_pool, message_factory, arena); case ExprValueKind::kListValue: return ListValueFromExpr(value.list_value(), descriptor_pool, message_factory, arena); default: return absl::UnimplementedError(absl::StrCat( "FromExprValue not supported ", ToString(value.kind_case()))); } } absl::StatusOr ToExprValue( const Value& value, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { cel::expr::Value result; switch (value->kind()) { case ValueKind::kBool: result.set_bool_value(value.GetBool().NativeValue()); break; case ValueKind::kInt: result.set_int64_value(value.GetInt().NativeValue()); break; case ValueKind::kUint: result.set_uint64_value(value.GetUint().NativeValue()); break; case ValueKind::kDouble: result.set_double_value(value.GetDouble().NativeValue()); break; case ValueKind::kString: 
result.set_string_value(value.GetString().ToString()); break; case ValueKind::kBytes: result.set_bytes_value(value.GetBytes().ToString()); break; case ValueKind::kType: result.set_type_value(value.GetType().name()); break; case ValueKind::kNull: result.set_null_value(google::protobuf::NullValue::NULL_VALUE); break; case ValueKind::kDuration: { google::protobuf::Duration duration; CEL_RETURN_IF_ERROR(internal::EncodeDuration( value.GetDuration().NativeValue(), &duration)); result.mutable_object_value()->PackFrom(duration); break; } case ValueKind::kTimestamp: { google::protobuf::Timestamp timestamp; CEL_RETURN_IF_ERROR( internal::EncodeTime(value.GetTimestamp().NativeValue(), ×tamp)); result.mutable_object_value()->PackFrom(timestamp); break; } case ValueKind::kMap: { CEL_ASSIGN_OR_RETURN( *result.mutable_map_value(), MapValueToExpr(value.GetMap(), descriptor_pool, message_factory, arena)); break; } case ValueKind::kList: { CEL_ASSIGN_OR_RETURN( *result.mutable_list_value(), ListValueToExpr(value.GetList(), descriptor_pool, message_factory, arena)); break; } case ValueKind::kStruct: { CEL_ASSIGN_OR_RETURN(*result.mutable_object_value(), ToProtobufAny(value.GetStruct(), descriptor_pool, message_factory, arena)); break; } default: return absl::UnimplementedError( absl::StrCat("ToExprValue not supported ", ValueKindToString(value->kind()))); } return result; } } // namespace cel::test ================================================ FILE: common/internal/value_conversion.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Converters to/from serialized Value to/from runtime values. #ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_VALUE_CONVERSION_H_ #define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_VALUE_CONVERSION_H_ #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" #include "google/api/expr/v1alpha1/checked.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/api/expr/v1alpha1/value.pb.h" #include "cel/expr/value.pb.h" #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "common/value.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" // TODO(uncreated-issue/84): Clean up and expose cel::expr::Value converters // in the common folder. 
namespace cel::test { ABSL_MUST_USE_RESULT inline bool UnsafeConvertWireCompatProto( const google::protobuf::MessageLite& src, google::protobuf::MessageLite* absl_nonnull dest) { absl::Cord serialized; return src.SerializePartialToCord(&serialized) && dest->ParsePartialFromCord(serialized); } ABSL_MUST_USE_RESULT inline bool ConvertWireCompatProto( const cel::expr::CheckedExpr& src, google::api::expr::v1alpha1::CheckedExpr* absl_nonnull dest) { return UnsafeConvertWireCompatProto(src, dest); } ABSL_MUST_USE_RESULT inline bool ConvertWireCompatProto( const google::api::expr::v1alpha1::CheckedExpr& src, cel::expr::CheckedExpr* absl_nonnull dest) { return UnsafeConvertWireCompatProto(src, dest); } ABSL_MUST_USE_RESULT inline bool ConvertWireCompatProto( const cel::expr::ParsedExpr& src, google::api::expr::v1alpha1::ParsedExpr* absl_nonnull dest) { return UnsafeConvertWireCompatProto(src, dest); } ABSL_MUST_USE_RESULT inline bool ConvertWireCompatProto( const google::api::expr::v1alpha1::ParsedExpr& src, cel::expr::ParsedExpr* absl_nonnull dest) { return UnsafeConvertWireCompatProto(src, dest); } ABSL_MUST_USE_RESULT inline bool ConvertWireCompatProto( const cel::expr::Expr& src, google::api::expr::v1alpha1::Expr* absl_nonnull dest) { return UnsafeConvertWireCompatProto(src, dest); } ABSL_MUST_USE_RESULT inline bool ConvertWireCompatProto(const google::api::expr::v1alpha1::Expr& src, cel::expr::Expr* absl_nonnull dest) { return UnsafeConvertWireCompatProto(src, dest); } ABSL_MUST_USE_RESULT inline bool ConvertWireCompatProto( const cel::expr::Value& src, google::api::expr::v1alpha1::Value* absl_nonnull dest) { return UnsafeConvertWireCompatProto(src, dest); } ABSL_MUST_USE_RESULT inline bool ConvertWireCompatProto( const google::api::expr::v1alpha1::Value& src, cel::expr::Value* absl_nonnull dest) { return UnsafeConvertWireCompatProto(src, dest); } absl::StatusOr FromExprValue( const cel::expr::Value& value, const google::protobuf::DescriptorPool* absl_nonnull 
descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena); absl::StatusOr ToExprValue( const Value& value, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena); } // namespace cel::test #endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_VALUE_CONVERSION_H_ ================================================ FILE: common/json.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_JSON_H_ #define THIRD_PARTY_CEL_CPP_COMMON_JSON_H_ #include namespace cel { // Maximum `int64_t` value that can be represented as `double` without losing // data. inline constexpr int64_t kJsonMaxInt = (int64_t{1} << 53) - 1; // Minimum `int64_t` value that can be represented as `double` without losing // data. inline constexpr int64_t kJsonMinInt = -kJsonMaxInt; // Maximum `uint64_t` value that can be represented as `double` without losing // data. 
inline constexpr uint64_t kJsonMaxUint = (uint64_t{1} << 53) - 1; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_JSON_H_ ================================================ FILE: common/kind.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/kind.h" #include "absl/strings/string_view.h" namespace cel { absl::string_view KindToString(Kind kind) { switch (kind) { case Kind::kNullType: return "null_type"; case Kind::kDyn: return "dyn"; case Kind::kAny: return "any"; case Kind::kType: return "type"; case Kind::kTypeParam: return "type_param"; case Kind::kFunction: return "function"; case Kind::kBool: return "bool"; case Kind::kInt: return "int"; case Kind::kUint: return "uint"; case Kind::kDouble: return "double"; case Kind::kString: return "string"; case Kind::kBytes: return "bytes"; case Kind::kDuration: return "duration"; case Kind::kTimestamp: return "timestamp"; case Kind::kList: return "list"; case Kind::kMap: return "map"; case Kind::kStruct: return "struct"; case Kind::kUnknown: return "*unknown*"; case Kind::kOpaque: return "*opaque*"; case Kind::kBoolWrapper: return "google.protobuf.BoolValue"; case Kind::kIntWrapper: return "google.protobuf.Int64Value"; case Kind::kUintWrapper: return "google.protobuf.UInt64Value"; case Kind::kDoubleWrapper: return "google.protobuf.DoubleValue"; case Kind::kStringWrapper: return "google.protobuf.StringValue"; case 
Kind::kBytesWrapper: return "google.protobuf.BytesValue"; case Kind::kEnum: return "enum"; default: return "*error*"; } } } // namespace cel ================================================ FILE: common/kind.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_KIND_H_ #define THIRD_PARTY_CEL_CPP_COMMON_KIND_H_ #include #include "absl/base/attributes.h" #include "absl/strings/string_view.h" namespace cel { enum class Kind : uint8_t { // Must match legacy CelValue::Type. kNull = 0, kBool, kInt, kUint, kDouble, kString, kBytes, kStruct, kDuration, kTimestamp, kList, kMap, kUnknown, kType, kError, kAny, // New kinds not present in legacy CelValue. kDyn, kOpaque, kBoolWrapper, kIntWrapper, kUintWrapper, kDoubleWrapper, kStringWrapper, kBytesWrapper, kTypeParam, kFunction, kEnum, // Legacy aliases, deprecated do not use. kNullType = kNull, kInt64 = kInt, kUint64 = kUint, kMessage = kStruct, kUnknownSet = kUnknown, kCelType = kType, // INTERNAL: Do not exceed 63. Implementation details rely on the fact that // we can store `Kind` using 6 bits. 
kNotForUseWithExhaustiveSwitchStatements = 63, }; ABSL_ATTRIBUTE_PURE_FUNCTION absl::string_view KindToString(Kind kind); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_KIND_H_ ================================================ FILE: common/kind_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/kind.h" #include #include #include "common/type_kind.h" #include "common/value_kind.h" #include "internal/testing.h" namespace cel { namespace { static_assert(std::is_same_v, std::underlying_type_t>, "TypeKind and ValueKind must have the same underlying type"); TEST(Kind, ToString) { EXPECT_EQ(KindToString(Kind::kError), "*error*"); EXPECT_EQ(KindToString(Kind::kNullType), "null_type"); EXPECT_EQ(KindToString(Kind::kDyn), "dyn"); EXPECT_EQ(KindToString(Kind::kAny), "any"); EXPECT_EQ(KindToString(Kind::kType), "type"); EXPECT_EQ(KindToString(Kind::kBool), "bool"); EXPECT_EQ(KindToString(Kind::kInt), "int"); EXPECT_EQ(KindToString(Kind::kUint), "uint"); EXPECT_EQ(KindToString(Kind::kDouble), "double"); EXPECT_EQ(KindToString(Kind::kString), "string"); EXPECT_EQ(KindToString(Kind::kBytes), "bytes"); EXPECT_EQ(KindToString(Kind::kDuration), "duration"); EXPECT_EQ(KindToString(Kind::kTimestamp), "timestamp"); EXPECT_EQ(KindToString(Kind::kList), "list"); EXPECT_EQ(KindToString(Kind::kMap), "map"); EXPECT_EQ(KindToString(Kind::kStruct), "struct"); 
EXPECT_EQ(KindToString(Kind::kUnknown), "*unknown*"); EXPECT_EQ(KindToString(Kind::kOpaque), "*opaque*"); EXPECT_EQ(KindToString(Kind::kBoolWrapper), "google.protobuf.BoolValue"); EXPECT_EQ(KindToString(Kind::kIntWrapper), "google.protobuf.Int64Value"); EXPECT_EQ(KindToString(Kind::kUintWrapper), "google.protobuf.UInt64Value"); EXPECT_EQ(KindToString(Kind::kDoubleWrapper), "google.protobuf.DoubleValue"); EXPECT_EQ(KindToString(Kind::kStringWrapper), "google.protobuf.StringValue"); EXPECT_EQ(KindToString(Kind::kBytesWrapper), "google.protobuf.BytesValue"); EXPECT_EQ(KindToString(static_cast(std::numeric_limits::max())), "*error*"); } TEST(Kind, TypeKindRoundtrip) { EXPECT_EQ(TypeKindToKind(KindToTypeKind(Kind::kBool)), Kind::kBool); } TEST(Kind, ValueKindRoundtrip) { EXPECT_EQ(ValueKindToKind(KindToValueKind(Kind::kBool)), Kind::kBool); } TEST(Kind, IsTypeKind) { EXPECT_TRUE(KindIsTypeKind(Kind::kBool)); EXPECT_TRUE(KindIsTypeKind(Kind::kAny)); EXPECT_TRUE(KindIsTypeKind(Kind::kDyn)); } TEST(Kind, IsValueKind) { EXPECT_TRUE(KindIsValueKind(Kind::kBool)); EXPECT_FALSE(KindIsValueKind(Kind::kAny)); EXPECT_FALSE(KindIsValueKind(Kind::kDyn)); } TEST(Kind, Equality) { EXPECT_EQ(Kind::kBool, TypeKind::kBool); EXPECT_EQ(TypeKind::kBool, Kind::kBool); EXPECT_EQ(Kind::kBool, ValueKind::kBool); EXPECT_EQ(ValueKind::kBool, Kind::kBool); EXPECT_NE(Kind::kBool, TypeKind::kInt); EXPECT_NE(TypeKind::kInt, Kind::kBool); EXPECT_NE(Kind::kBool, ValueKind::kInt); EXPECT_NE(ValueKind::kInt, Kind::kBool); } TEST(TypeKind, ToString) { EXPECT_EQ(TypeKindToString(TypeKind::kBool), KindToString(Kind::kBool)); } TEST(ValueKind, ToString) { EXPECT_EQ(ValueKindToString(ValueKind::kBool), KindToString(Kind::kBool)); } } // namespace } // namespace cel ================================================ FILE: common/legacy_value.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use 
// this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "common/legacy_value.h"

// NOTE(review): the header names of the following six system #include
// directives are missing in this copy of the file; restore them from the
// upstream source.
#include
#include
#include
#include
#include
#include

#include "google/protobuf/struct.pb.h"
#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/functional/overload.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "base/attribute.h"
#include "common/casting.h"
#include "common/kind.h"
#include "common/memory.h"
#include "common/type.h"
#include "common/unknown.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "common/values/list_value_builder.h"
#include "common/values/map_value_builder.h"
#include "common/values/values.h"
#include "eval/internal/cel_value_equal.h"
#include "eval/public/cel_value.h"
#include "eval/public/containers/field_backed_list_impl.h"
#include "eval/public/containers/field_backed_map_impl.h"
#include "eval/public/message_wrapper.h"
#include "eval/public/structs/cel_proto_wrap_util.h"
#include "eval/public/structs/legacy_type_adapter.h"
#include "eval/public/structs/legacy_type_info_apis.h"
#include "eval/public/structs/proto_message_type_adapter.h"
#include "internal/json.h"
#include "internal/status_macros.h"
#include "internal/well_known_types.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"
#include "google/protobuf/message_lite.h"

// TODO(uncreated-issue/76): improve coverage for JSON/Any handling

namespace cel {

namespace {

using google::api::expr::runtime::CelList;
using google::api::expr::runtime::CelMap;
using google::api::expr::runtime::CelValue;
using google::api::expr::runtime::FieldBackedListImpl;
using google::api::expr::runtime::FieldBackedMapImpl;
using google::api::expr::runtime::GetGenericProtoTypeInfoInstance;
using google::api::expr::runtime::LegacyTypeInfoApis;
using google::api::expr::runtime::MessageWrapper;
using ::google::api::expr::runtime::internal::MaybeWrapValueToMessage;

// Returns an InvalidArgument error of the form "Invalid map key type: '<kind>'".
absl::Status InvalidMapKeyTypeError(ValueKind kind) {
  return absl::InvalidArgumentError(
      absl::StrCat("Invalid map key type: '", ValueKindToString(kind), "'"));
}

// Wraps a (possibly null) message pointer and its legacy type info in a
// `MessageWrapper`.
MessageWrapper AsMessageWrapper(
    const google::protobuf::Message* absl_nullability_unknown message_ptr,
    const LegacyTypeInfoApis* absl_nullability_unknown type_info) {
  return MessageWrapper(message_ptr, type_info);
}

// Adapts a legacy `CelList` to the `ValueIterator` interface. Elements are
// fetched with `CelList::Get` and converted with `ModernValue`. The list size
// is captured once, at construction.
class CelListIterator final : public ValueIterator {
 public:
  explicit CelListIterator(const CelList* cel_list)
      : cel_list_(cel_list), size_(cel_list_->size()) {}

  bool HasNext() override { return index_ < size_; }

  // Writes the next element to `*result`. Returns FailedPrecondition once the
  // iterator is exhausted. Only `arena` is used by this implementation.
  absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                    google::protobuf::MessageFactory* absl_nonnull message_factory,
                    google::protobuf::Arena* absl_nonnull arena,
                    Value* absl_nonnull result) override {
    if (!HasNext()) {
      return absl::FailedPreconditionError(
          "ValueIterator::Next() called when ValueIterator::HasNext() returns "
          "false");
    }
    auto cel_value = cel_list_->Get(arena, index_);
    CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, *result));
    ++index_;
    return absl::OkStatus();
  }

  // Writes the next element to `*key_or_value` and returns true, or returns
  // false once the iterator is exhausted.
  // NOTE(review): the `absl::StatusOr` template argument is missing in this
  // copy; the body returns `true`/`false`, so presumably `absl::StatusOr<bool>`.
  absl::StatusOr Next1(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull key_or_value) override {
    ABSL_DCHECK(descriptor_pool != nullptr);
    ABSL_DCHECK(message_factory != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK(key_or_value != nullptr);

    if (index_ >= size_) {
      return false;
    }
    auto cel_value = cel_list_->Get(arena, index_);
    CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, *key_or_value));
    ++index_;
    return true;
  }

  // Writes the current index (as an `IntValue`) to `*key`. If `value` is
  // non-null, also writes the element to it. Returns false once the iterator
  // is exhausted.
  absl::StatusOr Next2(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key,
      Value* absl_nullable value) override {
    ABSL_DCHECK(descriptor_pool != nullptr);
    ABSL_DCHECK(message_factory != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK(key != nullptr);

    if (index_ >= size_) {
      return false;
    }
    if (value != nullptr) {
      // The element is only fetched and converted when the caller wants it.
      auto cel_value = cel_list_->Get(arena, index_);
      CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, *value));
    }
    *key = IntValue(index_);
    ++index_;
    return true;
  }

 private:
  const CelList* const cel_list_;
  const int size_;
  int index_ = 0;
};

class CelMapIterator final : public ValueIterator {
 public:
  explicit CelMapIterator(const CelMap* cel_map)
      : cel_map_(cel_map), size_(cel_map->size()) {}

  bool HasNext() override { return index_ < size_; }

  absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                    google::protobuf::MessageFactory* absl_nonnull message_factory,
                    google::protobuf::Arena* absl_nonnull arena,
                    Value* absl_nonnull result) override {
    if (!HasNext()) {
      return absl::FailedPreconditionError(
          "ValueIterator::Next() called when ValueIterator::HasNext() returns "
          "false");
    }
    CEL_RETURN_IF_ERROR(ProjectKeys(arena));
    auto cel_value = (*cel_list_)->Get(arena, index_);
    CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, *result));
    ++index_;
    return absl::OkStatus();
  }

  absl::StatusOr Next1(
      const
google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key_or_value) override { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(key_or_value != nullptr); if (index_ >= size_) { return false; } CEL_RETURN_IF_ERROR(ProjectKeys(arena)); auto cel_value = (*cel_list_)->Get(arena, index_); CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, *key_or_value)); ++index_; return true; } absl::StatusOr Next2( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, Value* absl_nullable value) override { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(key != nullptr); if (index_ >= size_) { return false; } CEL_RETURN_IF_ERROR(ProjectKeys(arena)); auto cel_key = (*cel_list_)->Get(arena, index_); if (value != nullptr) { auto cel_value = cel_map_->Get(arena, cel_key); if (!cel_value) { return absl::DataLossError( "map iterator returned key that was not present in the map"); } CEL_RETURN_IF_ERROR(ModernValue(arena, *cel_value, *value)); } CEL_RETURN_IF_ERROR(ModernValue(arena, cel_key, *key)); ++index_; return true; } private: absl::Status ProjectKeys(google::protobuf::Arena* arena) { if (cel_list_.ok() && *cel_list_ == nullptr) { cel_list_ = cel_map_->ListKeys(arena); } return cel_list_.status(); } const CelMap* const cel_map_; const int size_ = 0; absl::StatusOr cel_list_ = nullptr; int index_ = 0; }; } // namespace namespace common_internal { namespace { CelValue LegacyTrivialStructValue(google::protobuf::Arena* absl_nonnull arena, const Value& value) { if (auto legacy_struct_value = common_internal::AsLegacyStructValue(value); legacy_struct_value) { 
return CelValue::CreateMessageWrapper( AsMessageWrapper(legacy_struct_value->message_ptr(), legacy_struct_value->legacy_type_info())); } if (auto parsed_message_value = value.AsParsedMessage(); parsed_message_value) { auto maybe_cloned = parsed_message_value->Clone(arena); return CelValue::CreateMessageWrapper(MessageWrapper( cel::to_address(maybe_cloned), &GetGenericProtoTypeInfoInstance())); } return CelValue::CreateError(google::protobuf::Arena::Create( arena, absl::InvalidArgumentError(absl::StrCat( "unsupported conversion from cel::StructValue to CelValue: ", value.GetRuntimeType().DebugString())))); } CelValue LegacyTrivialListValue(google::protobuf::Arena* absl_nonnull arena, const Value& value) { if (auto legacy_list_value = common_internal::AsLegacyListValue(value); legacy_list_value) { return CelValue::CreateList(legacy_list_value->cel_list()); } if (auto parsed_repeated_field_value = value.AsParsedRepeatedField(); parsed_repeated_field_value) { auto maybe_cloned = parsed_repeated_field_value->Clone(arena); return CelValue::CreateList(google::protobuf::Arena::Create( arena, &maybe_cloned.message(), maybe_cloned.field(), arena)); } if (auto parsed_json_list_value = value.AsParsedJsonList(); parsed_json_list_value) { auto maybe_cloned = parsed_json_list_value->Clone(arena); return CelValue::CreateList(google::protobuf::Arena::Create( arena, cel::to_address(maybe_cloned), well_known_types::GetListValueReflectionOrDie( maybe_cloned->GetDescriptor()) .GetValuesDescriptor(), arena)); } if (auto custom_list_value = value.AsCustomList(); custom_list_value) { auto status_or_compat_list = common_internal::MakeCompatListValue( *custom_list_value, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), arena); if (!status_or_compat_list.ok()) { return CelValue::CreateError(google::protobuf::Arena::Create( arena, std::move(status_or_compat_list).status())); } return CelValue::CreateList(*status_or_compat_list); } 
return CelValue::CreateError(google::protobuf::Arena::Create( arena, absl::InvalidArgumentError(absl::StrCat( "unsupported conversion from cel::ListValue to CelValue: ", value.GetRuntimeType().DebugString())))); } CelValue LegacyTrivialMapValue(google::protobuf::Arena* absl_nonnull arena, const Value& value) { if (auto legacy_map_value = common_internal::AsLegacyMapValue(value); legacy_map_value) { return CelValue::CreateMap(legacy_map_value->cel_map()); } if (auto parsed_map_field_value = value.AsParsedMapField(); parsed_map_field_value) { auto maybe_cloned = parsed_map_field_value->Clone(arena); return CelValue::CreateMap(google::protobuf::Arena::Create( arena, &maybe_cloned.message(), maybe_cloned.field(), arena)); } if (auto parsed_json_map_value = value.AsParsedJsonMap(); parsed_json_map_value) { auto maybe_cloned = parsed_json_map_value->Clone(arena); return CelValue::CreateMap(google::protobuf::Arena::Create( arena, cel::to_address(maybe_cloned), well_known_types::GetStructReflectionOrDie( maybe_cloned->GetDescriptor()) .GetFieldsDescriptor(), arena)); } if (auto custom_map_value = value.AsCustomMap(); custom_map_value) { auto status_or_compat_map = common_internal::MakeCompatMapValue( *custom_map_value, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), arena); if (!status_or_compat_map.ok()) { return CelValue::CreateError(google::protobuf::Arena::Create( arena, std::move(status_or_compat_map).status())); } return CelValue::CreateMap(*status_or_compat_map); } return CelValue::CreateError(google::protobuf::Arena::Create( arena, absl::InvalidArgumentError(absl::StrCat( "unsupported conversion from cel::MapValue to CelValue: ", value.GetRuntimeType().DebugString())))); } } // namespace google::api::expr::runtime::CelValue UnsafeLegacyValue( const Value& value, bool stable, google::protobuf::Arena* absl_nonnull arena) { switch (value.kind()) { case ValueKind::kNull: return CelValue::CreateNull(); case 
ValueKind::kBool: return CelValue::CreateBool(value.GetBool()); case ValueKind::kInt: return CelValue::CreateInt64(value.GetInt()); case ValueKind::kUint: return CelValue::CreateUint64(value.GetUint()); case ValueKind::kDouble: return CelValue::CreateDouble(value.GetDouble()); case ValueKind::kString: return CelValue::CreateStringView( LegacyStringValue(value.GetString(), stable, arena)); case ValueKind::kBytes: return CelValue::CreateBytesView( LegacyBytesValue(value.GetBytes(), stable, arena)); case ValueKind::kStruct: return LegacyTrivialStructValue(arena, value); case ValueKind::kDuration: return CelValue::CreateDuration(value.GetDuration().ToDuration()); case ValueKind::kTimestamp: return CelValue::CreateTimestamp(value.GetTimestamp().ToTime()); case ValueKind::kList: return LegacyTrivialListValue(arena, value); case ValueKind::kMap: return LegacyTrivialMapValue(arena, value); case ValueKind::kType: return CelValue::CreateCelTypeView(value.GetType().name()); default: // Everything else is unsupported. return CelValue::CreateError(google::protobuf::Arena::Create( arena, absl::InvalidArgumentError(absl::StrCat( "unsupported conversion from cel::Value to CelValue: ", value->GetRuntimeType().DebugString())))); } } } // namespace common_internal namespace common_internal { std::string LegacyListValue::DebugString() const { return CelValue::CreateList(impl_).DebugString(); } // See `ValueInterface::SerializeTo`. 
absl::Status LegacyListValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); const google::protobuf::Descriptor* descriptor = descriptor_pool->FindMessageTypeByName("google.protobuf.ListValue"); if (descriptor == nullptr) { return absl::InternalError( "unable to locate descriptor for message type: " "google.protobuf.ListValue"); } google::protobuf::Arena arena; const google::protobuf::Message* wrapped = MaybeWrapValueToMessage( descriptor, message_factory, CelValue::CreateList(impl_), &arena); if (wrapped == nullptr) { return absl::UnknownError("failed to convert legacy map to JSON"); } if (!wrapped->SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( absl::StrCat("failed to serialize message: ", wrapped->GetTypeName())); } return absl::OkStatus(); } absl::Status LegacyListValue::ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); google::protobuf::Arena arena; const google::protobuf::Message* wrapped = MaybeWrapValueToMessage(json->GetDescriptor(), message_factory, CelValue::CreateList(impl_), &arena); if (wrapped == nullptr) { return absl::UnknownError("failed to convert legacy list to JSON"); } if (wrapped->GetDescriptor() == json->GetDescriptor()) { // We can directly use google::protobuf::Message::Copy(). json->CopyFrom(*wrapped); } else { // Equivalent descriptors but not identical. 
Must serialize and // deserialize. absl::Cord serialized; if (!wrapped->SerializePartialToString(&serialized)) { return absl::UnknownError(absl::StrCat("failed to serialize message: ", wrapped->GetTypeName())); } if (!json->ParsePartialFromString(serialized)) { return absl::UnknownError( absl::StrCat("failed to parsed message: ", json->GetTypeName())); } } return absl::OkStatus(); } } absl::Status LegacyListValue::ConvertToJsonArray( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); google::protobuf::Arena arena; const google::protobuf::Message* wrapped = MaybeWrapValueToMessage(json->GetDescriptor(), message_factory, CelValue::CreateList(impl_), &arena); if (wrapped == nullptr) { return absl::UnknownError("failed to convert legacy list to JSON"); } if (wrapped->GetDescriptor() == json->GetDescriptor()) { // We can directly use google::protobuf::Message::Copy(). json->CopyFrom(*wrapped); } else { // Equivalent descriptors but not identical. Must serialize and // deserialize. absl::Cord serialized; if (!wrapped->SerializePartialToString(&serialized)) { return absl::UnknownError(absl::StrCat("failed to serialize message: ", wrapped->GetTypeName())); } if (!json->ParsePartialFromString(serialized)) { return absl::UnknownError( absl::StrCat("failed to parsed message: ", json->GetTypeName())); } } return absl::OkStatus(); } } bool LegacyListValue::IsEmpty() const { return impl_->empty(); } size_t LegacyListValue::Size() const { return static_cast(impl_->size()); } // See LegacyListValueInterface::Get for documentation. 
absl::Status LegacyListValue::Get( size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { if (ABSL_PREDICT_FALSE(index < 0 || index >= impl_->size())) { *result = ErrorValue(absl::InvalidArgumentError("index out of bounds")); return absl::OkStatus(); } CEL_RETURN_IF_ERROR( ModernValue(arena, impl_->Get(arena, static_cast(index)), *result)); return absl::OkStatus(); } absl::Status LegacyListValue::ForEach( ForEachWithIndexCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { const auto size = impl_->size(); Value element; for (int index = 0; index < size; ++index) { CEL_RETURN_IF_ERROR(ModernValue(arena, impl_->Get(arena, index), element)); CEL_ASSIGN_OR_RETURN(auto ok, callback(index, Value(element))); if (!ok) { break; } } return absl::OkStatus(); } absl::StatusOr LegacyListValue::NewIterator() const { return std::make_unique(impl_); } absl::Status LegacyListValue::Contains( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { CEL_ASSIGN_OR_RETURN(auto legacy_other, LegacyValue(arena, other)); const auto* cel_list = impl_; for (int i = 0; i < cel_list->size(); ++i) { auto element = cel_list->Get(arena, i); absl::optional equal = interop_internal::CelValueEqualImpl(element, legacy_other); // Heterogeneous equality behavior is to just return false if equality // undefined. 
if (equal.has_value() && *equal) { *result = TrueValue(); return absl::OkStatus(); } } *result = FalseValue(); return absl::OkStatus(); } std::string LegacyMapValue::DebugString() const { return CelValue::CreateMap(impl_).DebugString(); } absl::Status LegacyMapValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); const google::protobuf::Descriptor* descriptor = descriptor_pool->FindMessageTypeByName("google.protobuf.Struct"); if (descriptor == nullptr) { return absl::InternalError( "unable to locate descriptor for message type: google.protobuf.Struct"); } google::protobuf::Arena arena; const google::protobuf::Message* wrapped = MaybeWrapValueToMessage( descriptor, message_factory, CelValue::CreateMap(impl_), &arena); if (wrapped == nullptr) { return absl::UnknownError("failed to convert legacy map to JSON"); } if (!wrapped->SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( absl::StrCat("failed to serialize message: ", wrapped->GetTypeName())); } return absl::OkStatus(); } absl::Status LegacyMapValue::ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); google::protobuf::Arena arena; const google::protobuf::Message* wrapped = MaybeWrapValueToMessage(json->GetDescriptor(), message_factory, CelValue::CreateMap(impl_), &arena); if (wrapped == nullptr) { return absl::UnknownError("failed to convert 
legacy map to JSON"); } if (wrapped->GetDescriptor() == json->GetDescriptor()) { // We can directly use google::protobuf::Message::Copy(). json->CopyFrom(*wrapped); } else { // Equivalent descriptors but not identical. Must serialize and deserialize. absl::Cord serialized; if (!wrapped->SerializePartialToString(&serialized)) { return absl::UnknownError(absl::StrCat("failed to serialize message: ", wrapped->GetTypeName())); } if (!json->ParsePartialFromString(serialized)) { return absl::UnknownError( absl::StrCat("failed to parsed message: ", json->GetTypeName())); } } return absl::OkStatus(); } absl::Status LegacyMapValue::ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); google::protobuf::Arena arena; const google::protobuf::Message* wrapped = MaybeWrapValueToMessage(json->GetDescriptor(), message_factory, CelValue::CreateMap(impl_), &arena); if (wrapped == nullptr) { return absl::UnknownError("failed to convert legacy map to JSON"); } if (wrapped->GetDescriptor() == json->GetDescriptor()) { // We can directly use google::protobuf::Message::Copy(). json->CopyFrom(*wrapped); } else { // Equivalent descriptors but not identical. Must serialize and deserialize. 
absl::Cord serialized; if (!wrapped->SerializePartialToString(&serialized)) { return absl::UnknownError(absl::StrCat("failed to serialize message: ", wrapped->GetTypeName())); } if (!json->ParsePartialFromString(serialized)) { return absl::UnknownError( absl::StrCat("failed to parsed message: ", json->GetTypeName())); } } return absl::OkStatus(); } bool LegacyMapValue::IsEmpty() const { return impl_->empty(); } size_t LegacyMapValue::Size() const { return static_cast(impl_->size()); } absl::Status LegacyMapValue::Get( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { switch (key.kind()) { case ValueKind::kError: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUnknown: *result = Value{key}; return absl::OkStatus(); case ValueKind::kBool: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kInt: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUint: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kString: break; default: *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); auto cel_value = impl_->Get(arena, cel_key); if (!cel_value.has_value()) { *result = NoSuchKeyError(key.DebugString()); return absl::OkStatus(); } CEL_RETURN_IF_ERROR(ModernValue(arena, *cel_value, *result)); return absl::OkStatus(); } absl::StatusOr LegacyMapValue::Find( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { switch (key.kind()) { case ValueKind::kError: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUnknown: *result = Value{key}; return false; case ValueKind::kBool: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kInt: ABSL_FALLTHROUGH_INTENDED; case 
ValueKind::kUint: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kString: break; default: *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); } CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); auto cel_value = impl_->Get(arena, cel_key); if (!cel_value.has_value()) { *result = NullValue{}; return false; } CEL_RETURN_IF_ERROR(ModernValue(arena, *cel_value, *result)); return true; } absl::Status LegacyMapValue::Has( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { switch (key.kind()) { case ValueKind::kError: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUnknown: *result = Value{key}; return absl::OkStatus(); case ValueKind::kBool: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kInt: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUint: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kString: break; default: *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); absl::StatusOr has = impl_->Has(cel_key); if (!has.ok()) { *result = ErrorValue(std::move(has).status()); return absl::OkStatus(); } *result = BoolValue(*has); return absl::OkStatus(); } absl::Status LegacyMapValue::ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const { CEL_ASSIGN_OR_RETURN(auto keys, impl_->ListKeys(arena)); *result = ListValue{common_internal::LegacyListValue(keys)}; return absl::OkStatus(); } absl::Status LegacyMapValue::ForEach( ForEachCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { 
CEL_ASSIGN_OR_RETURN(auto keys, impl_->ListKeys(arena)); const auto size = keys->size(); Value key; Value value; for (int index = 0; index < size; ++index) { auto cel_key = keys->Get(arena, index); auto cel_value = *impl_->Get(arena, cel_key); CEL_RETURN_IF_ERROR(ModernValue(arena, cel_key, key)); CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, value)); CEL_ASSIGN_OR_RETURN(auto ok, callback(key, value)); if (!ok) { break; } } return absl::OkStatus(); } absl::StatusOr LegacyMapValue::NewIterator() const { return std::make_unique(impl_); } absl::string_view LegacyStructValue::GetTypeName() const { auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); return message_wrapper.legacy_type_info()->GetTypename(message_wrapper); } std::string LegacyStructValue::DebugString() const { auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); return message_wrapper.legacy_type_info()->DebugString(message_wrapper); } absl::Status LegacyStructValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); if (ABSL_PREDICT_TRUE( message_wrapper.message_ptr()->SerializePartialToZeroCopyStream( output))) { return absl::OkStatus(); } return absl::UnknownError("failed to serialize protocol buffer message"); } absl::Status LegacyStructValue::ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); 
ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); return internal::MessageToJson( *google::protobuf::DownCastMessage(message_wrapper.message_ptr()), descriptor_pool, message_factory, json); } absl::Status LegacyStructValue::ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); return internal::MessageToJson( *google::protobuf::DownCastMessage(message_wrapper.message_ptr()), descriptor_pool, message_factory, json); } absl::Status LegacyStructValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { if (auto legacy_struct_value = common_internal::AsLegacyStructValue(other); legacy_struct_value.has_value()) { auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); const auto* access_apis = message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper); if (ABSL_PREDICT_FALSE(access_apis == nullptr)) { return absl::UnimplementedError( absl::StrCat("legacy access APIs missing for ", GetTypeName())); } auto other_message_wrapper = AsMessageWrapper(legacy_struct_value->message_ptr(), legacy_struct_value->legacy_type_info()); *result = BoolValue{ access_apis->IsEqualTo(message_wrapper, other_message_wrapper)}; return absl::OkStatus(); } if (auto struct_value = other.AsStruct(); 
struct_value.has_value()) {
    return common_internal::StructValueEqual(
        common_internal::LegacyStructValue(message_ptr_, legacy_type_info_),
        *struct_value, descriptor_pool, message_factory, arena, result);
  }
  *result = FalseValue();
  return absl::OkStatus();
}

// Returns true when the legacy access APIs report no populated fields for the
// wrapped message (`ListFields(...)` is empty). Returns false when the legacy
// type info exposes no access APIs.
bool LegacyStructValue::IsZeroValue() const {
  auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_);
  const auto* access_apis =
      message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper);
  if (ABSL_PREDICT_FALSE(access_apis == nullptr)) {
    return false;
  }
  return access_apis->ListFields(message_wrapper).empty();
}

// Reads field `name` through the legacy access APIs and converts the legacy
// `CelValue` into `*result` with `ModernValue`, using `arena`.
// `descriptor_pool` and `message_factory` are part of the interface but are
// not used by this implementation.
//
// When the legacy type info exposes no access APIs, `*result` is set to a
// `NoSuchFieldError` value and OK is returned (an error *value*, not a failed
// status). Errors from `GetField` or from the conversion are returned.
absl::Status LegacyStructValue::GetFieldByName(
    absl::string_view name, ProtoWrapperTypeOptions unboxing_options,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_);
  const auto* access_apis =
      message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper);
  if (ABSL_PREDICT_FALSE(access_apis == nullptr)) {
    *result = NoSuchFieldError(name);
    return absl::OkStatus();
  }
  CEL_ASSIGN_OR_RETURN(
      auto cel_value,
      access_apis->GetField(name, message_wrapper, unboxing_options,
                            MemoryManagerRef::Pooling(arena)));
  CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, *result));
  return absl::OkStatus();
}

// Field lookup by number is not supported for legacy structs. Always returns
// `UnimplementedError` and never writes `*result`.
absl::Status LegacyStructValue::GetFieldByNumber(
    int64_t number, ProtoWrapperTypeOptions unboxing_options,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  return absl::UnimplementedError(
      "access to fields by numbers is not available for legacy structs");
}

// Presence test for field `name`, delegated to the legacy access APIs'
// `HasField`. When no access APIs are available, the status carried by
// `NoSuchFieldError(name)` is returned instead of a boolean.
// NOTE(review): template arguments (e.g. the `StatusOr` value type) appear to
// be missing from this text; confirm against the upstream source.
absl::StatusOr LegacyStructValue::HasFieldByName(
    absl::string_view name) const {
  auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_);
  const auto* access_apis =
      message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper);
  if (ABSL_PREDICT_FALSE(access_apis == nullptr)) {
    return NoSuchFieldError(name).NativeValue();
  }
  return access_apis->HasField(name, message_wrapper);
}

// Presence test by field number is not supported for legacy structs. Always
// returns `UnimplementedError`.
absl::StatusOr LegacyStructValue::HasFieldByNumber(int64_t number) const {
  return absl::UnimplementedError(
      "access to fields by numbers is not available for legacy structs");
}

// Invokes `callback(field_name, value)` for each field name returned by the
// legacy access APIs' `ListFields`, in that order. Each field is read with
// `ProtoWrapperTypeOptions::kUnsetNull` and converted with `ModernValue`.
// Iteration stops early when the callback yields `false`. Errors from reading,
// converting, or the callback itself are returned immediately.
// Returns `UnimplementedError` when the legacy type info exposes no access
// APIs.
absl::Status LegacyStructValue::ForEachField(
    ForEachFieldCallback callback,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) const {
  auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_);
  const auto* access_apis =
      message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper);
  if (ABSL_PREDICT_FALSE(access_apis == nullptr)) {
    return absl::UnimplementedError(
        absl::StrCat("legacy access APIs missing for ", GetTypeName()));
  }
  auto field_names = access_apis->ListFields(message_wrapper);
  // Reused for every field; each `ModernValue` call overwrites it.
  Value value;
  for (const auto& field_name : field_names) {
    CEL_ASSIGN_OR_RETURN(
        auto cel_value,
        access_apis->GetField(field_name, message_wrapper,
                              ProtoWrapperTypeOptions::kUnsetNull,
                              MemoryManagerRef::Pooling(arena)));
    CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, value));
    CEL_ASSIGN_OR_RETURN(auto ok, callback(field_name, value));
    if (!ok) {
      break;
    }
  }
  return absl::OkStatus();
}

// Applies the select-qualifier path `qualifiers` to the wrapped message via the
// legacy access APIs' `Qualify`. The legacy result is converted into `*result`,
// and the number of qualifiers it consumed (`qualifier_count`) is stored in
// `*count`. An empty path is rejected with `InvalidArgumentError`.
//
// When no access APIs are available, `*result` becomes a `NoSuchFieldError`
// naming the first qualifier (a `FieldSpecifier`'s name, or an
// `AttributeQualifier`'s string key, or "" if it has none). In that case
// `*count` is set to -1 and OK is returned.
// NOTE(review): this uses `MemoryManager::Pooling` while the sibling methods use
// `MemoryManagerRef::Pooling`; consider making them consistent.
absl::Status LegacyStructValue::Qualify(
    absl::Span qualifiers, bool presence_test,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result,
    int* absl_nonnull count) const {
  if (ABSL_PREDICT_FALSE(qualifiers.empty())) {
    return absl::InvalidArgumentError("invalid select qualifier path.");
  }
  auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_);
  const auto* access_apis =
      message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper);
  if (ABSL_PREDICT_FALSE(access_apis == nullptr)) {
    absl::string_view field_name = absl::visit(
        absl::Overload(
            [](const FieldSpecifier& field) -> absl::string_view {
              return field.name;
            },
            [](const AttributeQualifier& field) -> absl::string_view {
              return field.GetStringKey().value_or("");
            }),
        qualifiers.front());
    *result = NoSuchFieldError(field_name);
    *count = -1;
    return absl::OkStatus();
  }
  CEL_ASSIGN_OR_RETURN(
      auto legacy_result,
      access_apis->Qualify(qualifiers, message_wrapper, presence_test,
                           MemoryManager::Pooling(arena)));
  CEL_RETURN_IF_ERROR(ModernValue(arena, legacy_result.value, *result));
  *count = legacy_result.qualifier_count;
  return absl::OkStatus();
}

}  // namespace common_internal

// Converts a legacy `CelValue` into a `cel::Value` written to `result`.
// - Scalars are copied by value.
// - String and bytes values are constructed with `Borrower::Arena(arena)`.
// - Message, list and map values are built from the legacy pointers.
// Returns `InvalidArgumentError` for an empty `CelType` name or an unhandled
// type, and `InternalError` for the special `kAny` type.
absl::Status ModernValue(google::protobuf::Arena* arena,
                         google::api::expr::runtime::CelValue legacy_value,
                         Value& result) {
  switch (legacy_value.type()) {
    case CelValue::Type::kNullType:
      result = NullValue{};
      return absl::OkStatus();
    case CelValue::Type::kBool:
      result = BoolValue{legacy_value.BoolOrDie()};
      return absl::OkStatus();
    case CelValue::Type::kInt64:
      result = IntValue{legacy_value.Int64OrDie()};
      return absl::OkStatus();
    case CelValue::Type::kUint64:
      result = UintValue{legacy_value.Uint64OrDie()};
      return absl::OkStatus();
    case CelValue::Type::kDouble:
      result = DoubleValue{legacy_value.DoubleOrDie()};
      return absl::OkStatus();
    case CelValue::Type::kString:
      result = StringValue(Borrower::Arena(arena),
                           legacy_value.StringOrDie().value());
      return absl::OkStatus();
    case CelValue::Type::kBytes:
      result =
          BytesValue(Borrower::Arena(arena), legacy_value.BytesOrDie().value());
      return absl::OkStatus();
    case CelValue::Type::kMessage: {
      auto message_wrapper = legacy_value.MessageWrapperOrDie();
      result = common_internal::LegacyStructValue(
          google::protobuf::DownCastMessage(
              message_wrapper.message_ptr()),
          message_wrapper.legacy_type_info());
      return absl::OkStatus();
    }
    case CelValue::Type::kDuration:
      result = UnsafeDurationValue(legacy_value.DurationOrDie());
      return absl::OkStatus();
    case CelValue::Type::kTimestamp:
      result = UnsafeTimestampValue(legacy_value.TimestampOrDie());
      return absl::OkStatus();
    case CelValue::Type::kList:
      result =
          ListValue(common_internal::LegacyListValue(legacy_value.ListOrDie()));
      return absl::OkStatus();
    case CelValue::Type::kMap:
      result =
          MapValue(common_internal::LegacyMapValue(legacy_value.MapOrDie()));
      return absl::OkStatus();
    case CelValue::Type::kUnknownSet:
      result = UnknownValue{*legacy_value.UnknownSetOrDie()};
      return absl::OkStatus();
    case CelValue::Type::kCelType: {
      auto type_name = legacy_value.CelTypeOrDie().value();
      // An empty type name has no runtime type to map to.
      if (type_name.empty()) {
        return absl::InvalidArgumentError("empty type name in CelValue");
      }
      result = TypeValue(common_internal::LegacyRuntimeType(type_name));
      return absl::OkStatus();
    }
    case CelValue::Type::kError:
      result = ErrorValue{*legacy_value.ErrorOrDie()};
      return absl::OkStatus();
    case CelValue::Type::kAny:
      return absl::InternalError(absl::StrCat(
          "illegal attempt to convert special CelValue type ",
          CelValue::TypeName(legacy_value.type()), " to cel::Value"));
    default:
      break;
  }
  return absl::InvalidArgumentError(absl::StrCat(
      "cel::Value does not support ", KindToString(legacy_value.type())));
}

// Converts a `cel::Value` into a legacy `CelValue`.
// - Strings and bytes go through `LegacyStringValue`/`LegacyBytesValue` with
//   `stable=false`.
// - Structs, lists and maps use the `LegacyTrivial*Value` helpers.
// - Unknown, type-name and error payloads are copied into objects created on
//   `arena`.
// Returns `InvalidArgumentError` for value kinds with no legacy representation.
absl::StatusOr LegacyValue(
    google::protobuf::Arena* arena, const Value& modern_value) {
  switch (modern_value.kind()) {
    case ValueKind::kNull:
      return CelValue::CreateNull();
    case ValueKind::kBool:
      return CelValue::CreateBool(Cast(modern_value).NativeValue());
    case ValueKind::kInt:
      return CelValue::CreateInt64(Cast(modern_value).NativeValue());
    case ValueKind::kUint:
      return CelValue::CreateUint64(
          Cast(modern_value).NativeValue());
    case ValueKind::kDouble:
      return CelValue::CreateDouble(
          Cast(modern_value).NativeValue());
    case ValueKind::kString:
      return CelValue::CreateStringView(common_internal::LegacyStringValue(
          modern_value.GetString(), /*stable=*/false, arena));
    case ValueKind::kBytes:
      return CelValue::CreateBytesView(common_internal::LegacyBytesValue(
          modern_value.GetBytes(), /*stable=*/false, arena));
    case ValueKind::kStruct:
      return common_internal::LegacyTrivialStructValue(arena, modern_value);
    case ValueKind::kDuration:
      return CelValue::CreateUncheckedDuration(
          modern_value.GetDuration().NativeValue());
    case ValueKind::kTimestamp:
      return CelValue::CreateTimestamp(
          modern_value.GetTimestamp().NativeValue());
    case ValueKind::kList:
      return common_internal::LegacyTrivialListValue(arena, modern_value);
    case ValueKind::kMap:
      return common_internal::LegacyTrivialMapValue(arena, modern_value);
    case ValueKind::kUnknown:
      return CelValue::CreateUnknownSet(google::protobuf::Arena::Create(
          arena, Cast(modern_value).NativeValue()));
    case ValueKind::kType:
      return CelValue::CreateCelType(
          CelValue::CelTypeHolder(google::protobuf::Arena::Create(
              arena, Cast(modern_value).NativeValue().name())));
    case ValueKind::kError:
      return CelValue::CreateError(google::protobuf::Arena::Create(
          arena, Cast(modern_value).NativeValue()));
    default:
      return absl::InvalidArgumentError(
          absl::StrCat("google::api::expr::runtime::CelValue does not support ",
                       ValueKindToString(modern_value.kind())));
  }
}

namespace interop_internal {

// Converts a legacy `CelValue` into a `cel::Value`. The trailing `bool`
// (`unchecked` in the declaration) is ignored here.
// Returns `InternalError` for `kAny` and `UnimplementedError` for any other
// unhandled type.
// NOTE(review): unlike `cel::ModernValue`, `kCelType` is converted without
// rejecting an empty type name; confirm whether the two paths should agree.
absl::StatusOr FromLegacyValue(google::protobuf::Arena* arena,
                                      const CelValue& legacy_value, bool) {
  switch (legacy_value.type()) {
    case CelValue::Type::kNullType:
      return NullValue{};
    case CelValue::Type::kBool:
      return BoolValue(legacy_value.BoolOrDie());
    case CelValue::Type::kInt64:
      return IntValue(legacy_value.Int64OrDie());
    case CelValue::Type::kUint64:
      return UintValue(legacy_value.Uint64OrDie());
    case CelValue::Type::kDouble:
      return DoubleValue(legacy_value.DoubleOrDie());
    case CelValue::Type::kString:
      return StringValue(Borrower::Arena(arena),
                         legacy_value.StringOrDie().value());
    case CelValue::Type::kBytes:
      return BytesValue(Borrower::Arena(arena),
                        legacy_value.BytesOrDie().value());
    case CelValue::Type::kMessage: {
      auto message_wrapper = legacy_value.MessageWrapperOrDie();
      return common_internal::LegacyStructValue(
          google::protobuf::DownCastMessage(
              message_wrapper.message_ptr()),
          message_wrapper.legacy_type_info());
    }
    case CelValue::Type::kDuration:
      return UnsafeDurationValue(legacy_value.DurationOrDie());
    case CelValue::Type::kTimestamp:
      return UnsafeTimestampValue(legacy_value.TimestampOrDie());
    case CelValue::Type::kList:
      return ListValue(
          common_internal::LegacyListValue(legacy_value.ListOrDie()));
    case CelValue::Type::kMap:
      return MapValue(common_internal::LegacyMapValue(legacy_value.MapOrDie()));
    case CelValue::Type::kUnknownSet:
      return UnknownValue{*legacy_value.UnknownSetOrDie()};
    case CelValue::Type::kCelType:
      return CreateTypeValueFromView(arena,
                                     legacy_value.CelTypeOrDie().value());
    case CelValue::Type::kError:
      return ErrorValue(*legacy_value.ErrorOrDie());
    case CelValue::Type::kAny:
      return absl::InternalError(absl::StrCat(
          "illegal attempt to convert special CelValue type ",
          CelValue::TypeName(legacy_value.type()), " to cel::Value"));
    default:
      break;
  }
  return absl::UnimplementedError(absl::StrCat(
      "conversion from CelValue to cel::Value for type ",
      CelValue::TypeName(legacy_value.type()), " is not yet implemented"));
}

// Converts a `cel::Value` into a legacy `CelValue`. This mirrors
// `cel::LegacyValue`; the trailing `bool` (`unchecked`) is ignored.
// Returns `InvalidArgumentError` for kinds with no legacy representation.
absl::StatusOr ToLegacyValue(
    google::protobuf::Arena* arena, const Value& value, bool) {
  switch (value.kind()) {
    case ValueKind::kNull:
      return CelValue::CreateNull();
    case ValueKind::kBool:
      return CelValue::CreateBool(Cast(value).NativeValue());
    case ValueKind::kInt:
      return CelValue::CreateInt64(Cast(value).NativeValue());
    case ValueKind::kUint:
      return CelValue::CreateUint64(Cast(value).NativeValue());
    case ValueKind::kDouble:
      return CelValue::CreateDouble(Cast(value).NativeValue());
    case ValueKind::kString:
      return CelValue::CreateStringView(common_internal::LegacyStringValue(
          value.GetString(), /*stable=*/false, arena));
    case ValueKind::kBytes:
      return CelValue::CreateBytesView(common_internal::LegacyBytesValue(
          value.GetBytes(), /*stable=*/false, arena));
    case ValueKind::kStruct:
      return common_internal::LegacyTrivialStructValue(arena, value);
    case ValueKind::kDuration:
      return CelValue::CreateUncheckedDuration(
          Cast(value).NativeValue());
    case ValueKind::kTimestamp:
      return CelValue::CreateTimestamp(
          Cast(value).NativeValue());
    case ValueKind::kList:
      return common_internal::LegacyTrivialListValue(arena, value);
    case ValueKind::kMap:
      return common_internal::LegacyTrivialMapValue(arena, value);
    case ValueKind::kUnknown:
      return CelValue::CreateUnknownSet(google::protobuf::Arena::Create(
          arena, Cast(value).NativeValue()));
    case ValueKind::kType:
      return CelValue::CreateCelType(
          CelValue::CelTypeHolder(google::protobuf::Arena::Create(
              arena, Cast(value).NativeValue().name())));
    case ValueKind::kError:
      return CelValue::CreateError(google::protobuf::Arena::Create(
          arena, Cast(value).NativeValue()));
    default:
      return absl::InvalidArgumentError(
          absl::StrCat("google::api::expr::runtime::CelValue does not support ",
                       ValueKindToString(value.kind())));
  }
}

// `FromLegacyValue`, but a failed conversion aborts the process via
// `ABSL_CHECK_OK`.
Value LegacyValueToModernValueOrDie(
    google::protobuf::Arena* arena,
    const google::api::expr::runtime::CelValue& value, bool unchecked) {
  auto status_or_value = FromLegacyValue(arena, value, unchecked);
  ABSL_CHECK_OK(status_or_value.status());  // Crash OK
  return std::move(*status_or_value);
}

// Element-wise `LegacyValueToModernValueOrDie`, preserving input order.
std::vector LegacyValueToModernValueOrDie(
    google::protobuf::Arena* arena, absl::Span values,
    bool unchecked) {
  std::vector modern_values;
  modern_values.reserve(values.size());
  for (const auto& value : values) {
    modern_values.push_back(
        LegacyValueToModernValueOrDie(arena, value, unchecked));
  }
  return modern_values;
}

// `ToLegacyValue`, but a failed conversion aborts the process via
// `ABSL_CHECK_OK`.
google::api::expr::runtime::CelValue ModernValueToLegacyValueOrDie(
    google::protobuf::Arena* arena, const Value& value, bool unchecked) {
  auto status_or_value = ToLegacyValue(arena, value, unchecked);
  ABSL_CHECK_OK(status_or_value.status());  // Crash OK
  return std::move(*status_or_value);
}

// Wraps the type name `input` in a `TypeValue` backed by
// `common_internal::LegacyRuntimeType`. `arena` is not used.
TypeValue CreateTypeValueFromView(google::protobuf::Arena* arena,
                                  absl::string_view input) {
  return TypeValue(common_internal::LegacyRuntimeType(input));
}

}  // namespace interop_internal
}  // namespace cel
================================================
FILE: common/legacy_value.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_LEGACY_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_LEGACY_VALUE_H_

#include
#include

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "common/value.h"
#include "eval/public/cel_value.h"
#include "internal/status_macros.h"
#include "google/protobuf/arena.h"

namespace cel {

// Converts the legacy `legacy_value` into a `cel::Value` written to `result`.
// `arena` is used when constructing string/bytes values. Returns a non-OK
// status when the value has no `cel::Value` representation.
absl::Status ModernValue(google::protobuf::Arena* arena,
                         google::api::expr::runtime::CelValue legacy_value,
                         Value& result);

// Convenience overload that returns the converted value, or the conversion
// error.
inline absl::StatusOr ModernValue(
    google::protobuf::Arena* arena,
    google::api::expr::runtime::CelValue legacy_value) {
  Value result;
  CEL_RETURN_IF_ERROR(ModernValue(arena, legacy_value, result));
  return result;
}

// Converts `modern_value` into a legacy `CelValue`, allocating on `arena` when
// needed.
absl::StatusOr LegacyValue(
    google::protobuf::Arena* arena, const Value& modern_value);

namespace common_internal {

// Convert a `cel::Value` to `google::api::expr::runtime::CelValue`, using
// `arena` to make memory allocations if necessary. `stable` indicates whether
// `cel::Value` is in a location where it will not be moved, so that inline
// string/bytes storage can be referenced.
google::api::expr::runtime::CelValue UnsafeLegacyValue(
    const Value& value, bool stable,
    google::protobuf::Arena* absl_nonnull arena);

}  // namespace common_internal

}  // namespace cel

namespace cel::interop_internal {

// Legacy -> modern conversion; see `cel::ModernValue` for the related API.
absl::StatusOr FromLegacyValue(
    google::protobuf::Arena* arena,
    const google::api::expr::runtime::CelValue& legacy_value,
    bool unchecked = false);

// Modern -> legacy conversion; see `cel::LegacyValue` for the related API.
absl::StatusOr ToLegacyValue(
    google::protobuf::Arena* arena, const Value& value,
    bool unchecked = false);

// Trivial constructors for scalar values.
inline NullValue CreateNullValue() { return NullValue{}; }

inline BoolValue CreateBoolValue(bool value) { return BoolValue{value}; }

inline IntValue CreateIntValue(int64_t value) { return IntValue{value}; }

inline UintValue CreateUintValue(uint64_t value) { return UintValue{value}; }

inline DoubleValue CreateDoubleValue(double value) {
  return DoubleValue{value};
}

// Wrap legacy list/map pointers; the pointers are stored, not copied.
inline ListValue CreateLegacyListValue(
    const google::api::expr::runtime::CelList* value) {
  return common_internal::LegacyListValue(value);
}

inline MapValue CreateLegacyMapValue(
    const google::api::expr::runtime::CelMap* value) {
  return common_internal::LegacyMapValue(value);
}

// `unchecked` is accepted but ignored by this implementation.
inline Value CreateDurationValue(absl::Duration value, bool unchecked = false) {
  return DurationValue{value};
}

inline TimestampValue CreateTimestampValue(absl::Time value) {
  return TimestampValue{value};
}

// Conversions that abort the process on failure instead of returning a status.
Value LegacyValueToModernValueOrDie(
    google::protobuf::Arena* arena,
    const google::api::expr::runtime::CelValue& value, bool unchecked = false);
std::vector LegacyValueToModernValueOrDie(
    google::protobuf::Arena* arena,
    absl::Span values,
    bool unchecked = false);
google::api::expr::runtime::CelValue ModernValueToLegacyValueOrDie(
    google::protobuf::Arena* arena, const Value& value, bool unchecked = false);

// Builds a `TypeValue` for the type name `input`.
TypeValue CreateTypeValueFromView(google::protobuf::Arena* arena,
                                  absl::string_view input);

}  // namespace
cel::interop_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_LEGACY_VALUE_H_ ================================================ FILE: common/memory.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/memory.h" #include #include #include #include #include "absl/base/no_destructor.h" #include "absl/log/absl_check.h" #include "absl/numeric/bits.h" #include "google/protobuf/arena.h" namespace cel { std::ostream& operator<<(std::ostream& out, MemoryManagement memory_management) { switch (memory_management) { case MemoryManagement::kPooling: return out << "POOLING"; case MemoryManagement::kReferenceCounting: return out << "REFERENCE_COUNTING"; } } void* ReferenceCountingMemoryManager::Allocate(size_t size, size_t alignment) { ABSL_DCHECK(absl::has_single_bit(alignment)) << "alignment must be a power of 2: " << alignment; if (size == 0) { return nullptr; } if (alignment <= __STDCPP_DEFAULT_NEW_ALIGNMENT__) { return ::operator new(size); } return ::operator new(size, static_cast(alignment)); } bool ReferenceCountingMemoryManager::Deallocate(void* ptr, size_t size, size_t alignment) noexcept { ABSL_DCHECK(absl::has_single_bit(alignment)) << "alignment must be a power of 2: " << alignment; if (ptr == nullptr) { ABSL_DCHECK_EQ(size, 0); return false; } ABSL_DCHECK_GT(size, 0); if (alignment <= __STDCPP_DEFAULT_NEW_ALIGNMENT__) { #if defined(__cpp_sized_deallocation) && 
__cpp_sized_deallocation >= 201309L ::operator delete(ptr, size); #else ::operator delete(ptr); #endif } else { #if defined(__cpp_sized_deallocation) && __cpp_sized_deallocation >= 201309L ::operator delete(ptr, size, static_cast(alignment)); #else ::operator delete(ptr, static_cast(alignment)); #endif } return true; } MemoryManager MemoryManager::Unmanaged() { // A static singleton arena, using `absl::NoDestructor` to avoid warnings // related static variables without trivial destructors. static absl::NoDestructor arena; return MemoryManager::Pooling(&*arena); } } // namespace cel ================================================ FILE: common/memory.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_COMMON_MEMORY_H_
#define THIRD_PARTY_CEL_CPP_COMMON_MEMORY_H_

// NOTE(review): the standard-library header names below (and template
// parameter lists further down) appear to be missing from this text; confirm
// against the upstream source.
#include
#include
#include
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/macros.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/log/absl_check.h"
#include "absl/numeric/bits.h"
#include "common/allocator.h"
#include "common/arena.h"
#include "common/data.h"
#include "common/internal/metadata.h"
#include "common/internal/reference_count.h"
#include "common/reference_count.h"
#include "internal/exceptions.h"
#include "internal/to_address.h"  // IWYU pragma: keep
#include "google/protobuf/arena.h"

namespace cel {

// Obtain the address of the underlying element from a raw pointer or "fancy"
// pointer.
using internal::to_address;

// MemoryManagement is an enumeration of supported memory management forms
// underlying `cel::MemoryManager`.
enum class MemoryManagement {
  // Region-based (a.k.a. arena). Memory is allocated in fixed size blocks and
  // deallocated all at once upon destruction of the `cel::MemoryManager`.
  kPooling = 1,
  // Reference counting. Memory is allocated with an associated reference
  // counter. When the reference counter hits 0, it is deallocated.
  kReferenceCounting,
};

// Streams a textual name for `memory_management`.
std::ostream& operator<<(std::ostream& out, MemoryManagement memory_management);

// Forward declarations used by the definitions that follow.
class ABSL_ATTRIBUTE_TRIVIAL_ABI [[nodiscard]] Owner;
class Borrower;
template
class ABSL_ATTRIBUTE_TRIVIAL_ABI [[nodiscard]] Unique;
template
class ABSL_ATTRIBUTE_TRIVIAL_ABI [[nodiscard]] Owned;
template
class Borrowed;
template
struct Ownable;
template
struct Borrowable;

class MemoryManager;
class ReferenceCountingMemoryManager;
class PoolingMemoryManager;

namespace common_internal {

// Compile-time trait helpers built from `std::conjunction_v`/`std::negation`
// and `std::is_convertible_v`.
template
inline constexpr bool kNotMessageLiteAndNotData =
    std::conjunction_v>, std::negation>>;

template
inline constexpr bool kIsPointerConvertible = std::is_convertible_v;

template
inline constexpr bool kNotSameAndIsPointerConvertible =
    std::conjunction_v>, std::bool_constant>>;

// Clears the contents of `owner`, and returns the reference count if in use.
const ReferenceCount* absl_nullable OwnerRelease(Owner owner) noexcept;

// Returns the reference count `borrower` refers to, if any.
const ReferenceCount* absl_nullable BorrowerRelease(Borrower borrower) noexcept;

template
Owned WrapEternal(const T* value);

// Pointer tag used by `cel::Unique` to indicate that the destructor needs to be
// registered with the arena, but it has not been done yet. Must be done when
// releasing.
inline constexpr uintptr_t kUniqueArenaUnownedBit = uintptr_t{1} << 0;
// All tag bits `cel::Unique` may set in its arena word.
inline constexpr uintptr_t kUniqueArenaBits = kUniqueArenaUnownedBit;
// Mask that strips the tag bits, leaving the `google::protobuf::Arena*` address.
inline constexpr uintptr_t kUniqueArenaPointerMask = ~kUniqueArenaBits;

}  // namespace common_internal

// Declared ahead of `Owner`, which befriends them.
template
Owned AllocateShared(Allocator<> allocator, Args&&... args);

template
Owned WrapShared(T* object, Allocator<> allocator);

// `Owner` represents a reference to some co-owned data, of which this owner is
// one of the co-owners. When using reference counting, `Owner` performs
// increment/decrement where appropriate similar to `std::shared_ptr`.
// `Borrower` is similar to `Owner`, except that it is always trivially
// copyable/destructible. In that sense, `Borrower` is similar to
// `std::reference_wrapper`.
class ABSL_ATTRIBUTE_TRIVIAL_ABI [[nodiscard]] Owner final {
 private:
  // Tag values and masks for `ptr_`. They mirror the
  // `common_internal::kMetadataOwner*` constants.
  static constexpr uintptr_t kNone = common_internal::kMetadataOwnerNone;
  static constexpr uintptr_t kReferenceCountBit =
      common_internal::kMetadataOwnerReferenceCountBit;
  static constexpr uintptr_t kArenaBit =
      common_internal::kMetadataOwnerArenaBit;
  static constexpr uintptr_t kBits = common_internal::kMetadataOwnerBits;
  static constexpr uintptr_t kPointerMask =
      common_internal::kMetadataOwnerPointerMask;

 public:
  // An owner that owns nothing; `operator bool` returns false.
  static Owner None() noexcept { return Owner(); }

  // An arena owner when `allocator` is backed by an arena, otherwise `None()`.
  static Owner Allocator(Allocator<> allocator) noexcept {
    auto* arena = allocator.arena();
    return arena != nullptr ? Arena(arena) : None();
  }

  // Wraps a non-null `arena`. Arena owners never touch reference counts, and
  // `arena` must outlive the returned owner.
  static Owner Arena(google::protobuf::Arena* absl_nonnull arena
                         ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept {
    ABSL_DCHECK(arena != nullptr);
    return Owner(reinterpret_cast(arena) | kArenaBit);
  }

  static Owner Arena(std::nullptr_t) = delete;

  // Acquires a new strong reference on `reference_count` (via `StrongRef`) and
  // wraps it. The reference is dropped when this owner is destroyed or reset.
  static Owner ReferenceCount(const ReferenceCount* absl_nonnull reference_count
                                  ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept {
    ABSL_DCHECK(reference_count != nullptr);
    common_internal::StrongRef(*reference_count);
    return Owner(reinterpret_cast(reference_count) | kReferenceCountBit);
  }

  static Owner ReferenceCount(std::nullptr_t) = delete;

  Owner() = default;

  // Copying acquires an extra strong reference when reference counted.
  Owner(const Owner& other) noexcept : Owner(CopyFrom(other.ptr_)) {}

  // Moving transfers ownership and leaves `other` empty (`kNone`).
  Owner(Owner&& other) noexcept : Owner(MoveFrom(other.ptr_)) {}

  template
  // NOLINTNEXTLINE(google-explicit-constructor)
  Owner(const Owned& owned) noexcept;

  template
  // NOLINTNEXTLINE(google-explicit-constructor)
  Owner(Owned&& owned) noexcept;

  explicit Owner(Borrower borrower) noexcept;

  template
  explicit Owner(Borrowed borrowed) noexcept;

  // Drops the strong reference, if any.
  ~Owner() { Destroy(ptr_); }

  // Assigning an owner with the same tagged value is a no-op. This avoids
  // releasing and then re-acquiring the same reference count.
  Owner& operator=(const Owner& other) noexcept {
    if (ptr_ != other.ptr_) {
      Destroy(ptr_);
      ptr_ = CopyFrom(other.ptr_);
    }
    return *this;
  }

  // Releases the current reference (if any) and takes over `other`'s.
  // Self-move is a no-op.
  Owner& operator=(Owner&& other) noexcept {
    if (ABSL_PREDICT_TRUE(this != &other)) {
      Destroy(ptr_);
      ptr_ = MoveFrom(other.ptr_);
    }
    return *this;
  }

  template
  // NOLINTNEXTLINE(google-explicit-constructor)
  Owner& operator=(const Owned& owned) noexcept;

  template
  // NOLINTNEXTLINE(google-explicit-constructor)
  Owner& operator=(Owned&& owned) noexcept;

  explicit operator bool() const noexcept { return !IsNone(ptr_); }

  // The arena this owner refers to, or nullptr if it is not an arena owner.
  google::protobuf::Arena* absl_nullable arena() const noexcept {
    return (ptr_ & Owner::kBits) == Owner::kArenaBit
               ? reinterpret_cast(ptr_ & Owner::kPointerMask)
               : nullptr;
  }

  // Drops the strong reference, if any, and becomes empty.
  // NOTE(review): this writes a literal 0 while `IsNone` compares against
  // `kNone`, so it assumes `common_internal::kMetadataOwnerNone == 0`; confirm.
  void reset() noexcept {
    Destroy(ptr_);
    ptr_ = 0;
  }

  // Tests whether two owners have ownership over the same data, that is they
  // are co-owners.
  friend bool operator==(const Owner& lhs, const Owner& rhs) noexcept {
    // A reference count and arena can never occupy the same memory address, so
    // we can compare for equality without masking off the bits.
    return lhs.ptr_ == rhs.ptr_;
  }

 private:
  template
  friend class Unique;
  friend class Borrower;
  template
  friend Owned AllocateShared(cel::Allocator<> allocator, Args&&... args);
  template
  friend Owned WrapShared(T* object, cel::Allocator<> allocator);
  template
  friend struct Ownable;
  friend const common_internal::ReferenceCount* absl_nullable
  common_internal::OwnerRelease(Owner owner) noexcept;
  friend const common_internal::ReferenceCount* absl_nullable
  common_internal::BorrowerRelease(Borrower borrower) noexcept;
  friend struct ArenaTraits;

  // Adopts an already tagged value without touching reference counts.
  constexpr explicit Owner(uintptr_t ptr) noexcept : ptr_(ptr) {}

  static constexpr bool IsNone(uintptr_t ptr) noexcept { return ptr == kNone; }

  static constexpr bool IsArena(uintptr_t ptr) noexcept {
    return (ptr & kArenaBit) != kNone;
  }

  static constexpr bool IsReferenceCount(uintptr_t ptr) noexcept {
    return (ptr & kReferenceCountBit) != kNone;
  }

  // Strips the tag bits from an arena-tagged value.
  ABSL_ATTRIBUTE_RETURNS_NONNULL
  static google::protobuf::Arena* absl_nonnull AsArena(uintptr_t ptr) noexcept {
    ABSL_ASSERT(IsArena(ptr));
    return reinterpret_cast(ptr & kPointerMask);
  }

  // Strips the tag bits from a reference-count-tagged value.
  ABSL_ATTRIBUTE_RETURNS_NONNULL
  static const common_internal::ReferenceCount* absl_nonnull AsReferenceCount(
      uintptr_t ptr) noexcept {
    ABSL_ASSERT(IsReferenceCount(ptr));
    return reinterpret_cast(
        ptr & kPointerMask);
  }

  // Copy helper: acquires a strong reference if `other` is reference counted.
  static uintptr_t CopyFrom(uintptr_t other) noexcept { return Own(other); }

  // Move helper: takes `other`'s value and resets it to `kNone`. Counts are
  // left unchanged.
  static uintptr_t MoveFrom(uintptr_t& other) noexcept {
    return std::exchange(other, kNone);
  }

  static void Destroy(uintptr_t ptr) noexcept { Unown(ptr); }

  // Increments the strong count for reference-counted values. Arena and empty
  // values pass through unchanged.
  static uintptr_t Own(uintptr_t ptr) noexcept {
    if (IsReferenceCount(ptr)) {
      const auto* refcount = Owner::AsReferenceCount(ptr);
      ABSL_ASSUME(refcount != nullptr);
      common_internal::StrongRef(refcount);
    }
    return ptr;
  }

  // Decrements the strong count for reference-counted values; no-op otherwise.
  static void Unown(uintptr_t ptr) noexcept {
    if (IsReferenceCount(ptr)) {
      const auto* reference_count = AsReferenceCount(ptr);
      ABSL_ASSUME(reference_count != nullptr);
      common_internal::StrongUnref(reference_count);
    }
  }

  // Either `kNone`, or an arena / reference-count address OR'd with its tag
  // bit.
  uintptr_t ptr_ = kNone;
};

inline bool operator!=(const Owner& lhs, const Owner& rhs) noexcept {
  return !operator==(lhs, rhs);
}

namespace common_internal {

// Hands the strong reference held by `owner` (if any) to the caller. The owner
// is cleared without decrementing, so the caller becomes responsible for the
// returned reference count. Returns nullptr for empty and arena owners.
inline const ReferenceCount* absl_nullable OwnerRelease(Owner owner) noexcept {
  uintptr_t ptr = std::exchange(owner.ptr_, kMetadataOwnerNone);
  if (Owner::IsReferenceCount(ptr)) {
    return Owner::AsReferenceCount(ptr);
  }
  return nullptr;
}

}  // namespace common_internal

// An `Owner` only needs its destructor run when it holds a reference count.
template <>
struct ArenaTraits {
  static bool trivially_destructible(const Owner& owner) {
    return !Owner::IsReferenceCount(owner.ptr_);
  }
};

// `Borrower` represents a reference to some borrowed data, where the data has
// at least one owner. When using reference counting, `Borrower` does not
// participate in incrementing/decrementing the reference count. Thus `Borrower`
// will not keep the underlying data alive.
class Borrower final {
 public:
  static Borrower None() noexcept { return Borrower(); }

  // A borrow of the allocator's arena if it has one, otherwise `None()`.
  static Borrower Allocator(Allocator<> allocator) noexcept {
    auto* arena = allocator.arena();
    return arena != nullptr ? Arena(arena) : None();
  }

  static Borrower Arena(google::protobuf::Arena* absl_nonnull arena
                            ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept {
    ABSL_DCHECK(arena != nullptr);
    return Borrower(reinterpret_cast(arena) | Owner::kArenaBit);
  }

  static Borrower Arena(std::nullptr_t) = delete;

  // Unlike `Owner::ReferenceCount`, this does not acquire a strong reference.
  static Borrower ReferenceCount(
      const ReferenceCount* absl_nonnull reference_count
          ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept {
    ABSL_DCHECK(reference_count != nullptr);
    return Borrower(reinterpret_cast(reference_count) |
                    Owner::kReferenceCountBit);
  }

  static Borrower ReferenceCount(std::nullptr_t) = delete;

  Borrower() = default;
  Borrower(const Borrower&) = default;
  Borrower(Borrower&&) = default;
  Borrower& operator=(const Borrower&) = default;
  Borrower& operator=(Borrower&&) = default;

  template
  // NOLINTNEXTLINE(google-explicit-constructor)
  Borrower(const Owned& owned ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept;

  template
  // NOLINTNEXTLINE(google-explicit-constructor)
  Borrower(Borrowed borrowed) noexcept;

  // Copies `owner`'s tagged value without touching its reference count.
  // NOLINTNEXTLINE(google-explicit-constructor)
  Borrower(const Owner& owner ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept
      : ptr_(owner.ptr_) {}

  // NOLINTNEXTLINE(google-explicit-constructor)
  Borrower& operator=(
      const Owner& owner ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept {
    ptr_ = owner.ptr_;
    return *this;
  }

  // Borrowing from an rvalue owner is rejected, presumably because the borrow
  // would dangle once that temporary is destroyed.
  Borrower& operator=(Owner&&) = delete;

  template
  Borrower& operator=(
      const Owned& owned ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept;

  template
  Borrower& operator=(Owned&&) = delete;

  template
  // NOLINTNEXTLINE(google-explicit-constructor)
  Borrower& operator=(Borrowed borrowed) noexcept;

  explicit operator bool() const noexcept { return !Owner::IsNone(ptr_); }

  // The borrowed arena, or nullptr if this does not borrow an arena.
  google::protobuf::Arena* absl_nullable arena() const noexcept {
    return (ptr_ & Owner::kBits) == Owner::kArenaBit
               ? reinterpret_cast(ptr_ & Owner::kPointerMask)
               : nullptr;
  }

  // Becomes empty. Nothing is released, because a borrower holds no reference.
  void reset() noexcept { ptr_ = 0; }

  // Tests whether two borrowers are borrowing the same data.
  friend bool operator==(Borrower lhs, Borrower rhs) noexcept {
    // A reference count and arena can never occupy the same memory address, so
    // we can compare for equality without masking off the bits.
    return lhs.ptr_ == rhs.ptr_;
  }

 private:
  friend class Owner;
  template
  friend struct Borrowable;
  friend const common_internal::ReferenceCount* absl_nullable
  common_internal::BorrowerRelease(Borrower borrower) noexcept;

  constexpr explicit Borrower(uintptr_t ptr) noexcept : ptr_(ptr) {}

  // Same tagged encoding as `Owner::ptr_`.
  uintptr_t ptr_ = Owner::kNone;
};

inline bool operator!=(Borrower lhs, Borrower rhs) noexcept {
  return !operator==(lhs, rhs);
}

// Mixed comparisons convert the `Owner` to a `Borrower` (no count change).
inline bool operator==(Borrower lhs, const Owner& rhs) noexcept {
  return operator==(lhs, Borrower(rhs));
}

inline bool operator==(const Owner& lhs, Borrower rhs) noexcept {
  return operator==(Borrower(lhs), rhs);
}

inline bool operator!=(Borrower lhs, const Owner& rhs) noexcept {
  return !operator==(lhs, rhs);
}

inline bool operator!=(const Owner& lhs, Borrower rhs) noexcept {
  return !operator==(lhs, rhs);
}

// Promotes a borrower to an owner, acquiring a new strong reference when the
// borrowed data is reference counted.
inline Owner::Owner(Borrower borrower) noexcept
    : ptr_(Owner::Own(borrower.ptr_)) {}

namespace common_internal {

// Returns the reference count `borrower` refers to, if any, without changing
// the count.
inline const ReferenceCount* absl_nullable BorrowerRelease(
    Borrower borrower) noexcept {
  uintptr_t ptr = borrower.ptr_;
  if (Owner::IsReferenceCount(ptr)) {
    return Owner::AsReferenceCount(ptr);
  }
  return nullptr;
}

}  // namespace common_internal

template
Unique AllocateUnique(Allocator<> allocator, Args&&... args);

// Wrap an already created `T` in `Unique`. Requires that `T` is not const,
// otherwise `GetArena()` may return slightly unexpected results depending on if
// it is the default value.
template
std::enable_if_t, Unique> WrapUnique(T* object);

template
Unique WrapUnique(T* object, Allocator<> allocator);

// `Unique` points to an object which was allocated using `Allocator<>` or
// `Allocator`. It has ownership over the object, and will perform any
// destruction and deallocation required. `Unique` must not outlive the
`Unique` must not outlive the // underlying arena, if any. Unlike `Owned` and `Borrowed`, `Unique` supports // arena incompatible objects. It is very similar to `std::unique_ptr` when // using a custom deleter. // // IMPLEMENTATION NOTES: // When utilizing arenas, we optionally perform a risky optimization via // `AllocateUnique`. We do not use `Arena::Create`, instead we directly allocate // the bytes and construct it in place ourselves. This avoids registering the // destructor when required. Instead we register the destructor ourselves, if // required, during `Unique::release`. This allows us to avoid deferring // destruction of the object until the arena is destroyed, avoiding the cost // involved in doing so. template class ABSL_ATTRIBUTE_TRIVIAL_ABI [[nodiscard]] Unique final { public: using element_type = T; static_assert(!std::is_array_v, "T must not be an array"); static_assert(!std::is_reference_v, "T must not be a reference"); static_assert(!std::is_volatile_v, "T must not be volatile qualified"); Unique() = default; Unique(const Unique&) = delete; Unique& operator=(const Unique&) = delete; explicit Unique(T* ptr) noexcept : Unique(ptr, common_internal::GetArena(ptr)) {} // NOLINTNEXTLINE(google-explicit-constructor) Unique(std::nullptr_t) noexcept : Unique() {} Unique(Unique&& other) noexcept : Unique(other.ptr_, other.arena_) { other.ptr_ = nullptr; } template >> // NOLINTNEXTLINE(google-explicit-constructor) Unique(Unique&& other) noexcept : Unique(other.ptr_, other.arena_) { other.ptr_ = nullptr; } ~Unique() { Delete(); } Unique& operator=(Unique&& other) noexcept { if (ABSL_PREDICT_TRUE(this != &other)) { Delete(); ptr_ = other.ptr_; arena_ = other.arena_; other.ptr_ = nullptr; } return *this; } template >> // NOLINTNEXTLINE(google-explicit-constructor) Unique& operator=(U* other) noexcept { reset(other); return *this; } template >> // NOLINTNEXTLINE(google-explicit-constructor) Unique& operator=(Unique&& other) noexcept { Delete(); ptr_ = 
other.ptr_; arena_ = other.arena_; other.ptr_ = nullptr; return *this; } // NOLINTNEXTLINE(google-explicit-constructor) Unique& operator=(std::nullptr_t) noexcept { reset(); return *this; } T& operator*() const noexcept ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK(static_cast(*this)); return *get(); } T* absl_nonnull operator->() const noexcept ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK(static_cast(*this)); return get(); } // Relinquishes ownership of `T*`, returning it. If `T` was allocated and // constructed using an arena, no further action is required. If `T` was // allocated and constructed without an arena, the caller must eventually call // `delete`. ABSL_MUST_USE_RESULT T* release() noexcept { PreRelease(); return std::exchange(ptr_, nullptr); } void reset() noexcept { reset(nullptr); } void reset(T* ptr) noexcept { Delete(); ptr_ = ptr; arena_ = reinterpret_cast(common_internal::GetArena(ptr)); } void reset(std::nullptr_t) noexcept { Delete(); ptr_ = nullptr; arena_ = 0; } explicit operator bool() const noexcept { return get() != nullptr; } google::protobuf::Arena* absl_nullable arena() const noexcept { return reinterpret_cast( arena_ & common_internal::kUniqueArenaPointerMask); } friend void swap(Unique& lhs, Unique& rhs) noexcept { using std::swap; swap(lhs.ptr_, rhs.ptr_); swap(lhs.arena_, rhs.arena_); } private: template friend class Unique; template friend class Owned; template friend Unique AllocateUnique(Allocator<> allocator, Args&&... args); template friend Unique WrapUnique(U* object, Allocator<> allocator); friend class ReferenceCountingMemoryManager; friend class PoolingMemoryManager; friend struct std::pointer_traits>; friend struct ArenaTraits>; Unique(T* ptr, uintptr_t arena) noexcept : ptr_(ptr), arena_(arena) {} Unique(T* ptr, google::protobuf::Arena* arena, bool unowned = false) noexcept : Unique(ptr, reinterpret_cast(arena) | (unowned ? 
common_internal::kUniqueArenaUnownedBit : 0)) { ABSL_ASSERT(!unowned || (unowned && arena != nullptr)); } Unique(google::protobuf::Arena* arena, T* ptr, bool unowned = false) noexcept : Unique(ptr, arena, unowned) {} T* get() const noexcept { return ptr_; } void Delete() const noexcept { if (static_cast(*this)) { if (arena_ != 0) { if ((arena_ & common_internal::kUniqueArenaBits) == common_internal::kUniqueArenaUnownedBit) { // We never registered the destructor, call it if necessary. if constexpr (!std::is_trivially_destructible_v && !google::protobuf::Arena::is_destructor_skippable::value) { std::destroy_at(ptr_); } } } else { delete ptr_; } } } void PreRelease() noexcept { if constexpr (!std::is_trivially_destructible_v && !google::protobuf::Arena::is_destructor_skippable::value) { if (static_cast(*this) && (arena_ & common_internal::kUniqueArenaBits) == common_internal::kUniqueArenaUnownedBit) { // We never registered the destructor, call it if necessary. arena()->OwnDestructor(const_cast*>(ptr_)); arena_ &= common_internal::kUniqueArenaPointerMask; } } } void Release(T** ptr, Owner* owner) noexcept { if (ptr_ == nullptr) { *ptr = nullptr; return; } PreRelease(); *ptr = std::exchange(ptr_, nullptr); if (arena_ == 0) { owner->ptr_ = reinterpret_cast( common_internal::MakeDeletingReferenceCount(*ptr)) | common_internal::kMetadataOwnerReferenceCountBit; } else { owner->ptr_ = reinterpret_cast(arena()) | common_internal::kMetadataOwnerArenaBit; } } T* ptr_ = nullptr; // Potentially tagged pointer to `google::protobuf::Arena`. The tag is used to determine // whether we still need to register the destructor with the `google::protobuf::Arena`. uintptr_t arena_ = 0; }; template Unique(T*) -> Unique; template Unique AllocateUnique(Allocator<> allocator, Args&&... 
args) { using U = std::remove_cv_t; static_assert(!std::is_reference_v, "T must not be a reference"); static_assert(!std::is_array_v, "T must not be an array"); U* object; google::protobuf::Arena* absl_nullable arena = allocator.arena(); bool unowned; if constexpr (google::protobuf::Arena::is_arena_constructable::value) { object = google::protobuf::Arena::Create(arena, std::forward(args)...); // For arena-compatible proto types, let the Arena::Create handle // registering the destructor call. // Otherwise, Unique retains a pointer to the owning arena so it may // conditionally register T::~T depending on usage. unowned = false; } else { void* p = allocator.allocate_bytes(sizeof(U), alignof(U)); CEL_INTERNAL_TRY { if constexpr (ArenaTraits<>::constructible()) { object = ::new (p) U(arena, std::forward(args)...); } else { object = ::new (p) U(std::forward(args)...); } } CEL_INTERNAL_CATCH_ANY { allocator.deallocate_bytes(p, sizeof(U), alignof(U)); CEL_INTERNAL_RETHROW; } unowned = arena != nullptr && !ArenaTraits<>::trivially_destructible(*object); } return Unique(object, arena, unowned); } template std::enable_if_t, Unique> WrapUnique(T* object) { return Unique(object); } template Unique WrapUnique(T* object, Allocator<> allocator) { return Unique(object, allocator.arena()); } template inline bool operator==(const Unique& lhs, std::nullptr_t) { return !static_cast(lhs); } template inline bool operator==(std::nullptr_t, const Unique& rhs) { return !static_cast(rhs); } template inline bool operator!=(const Unique& lhs, std::nullptr_t) { return static_cast(lhs); } template inline bool operator!=(std::nullptr_t, const Unique& rhs) { return static_cast(rhs); } } // namespace cel namespace std { template struct pointer_traits> { using pointer = cel::Unique; using element_type = typename cel::Unique::element_type; using difference_type = ptrdiff_t; template using rebind = cel::Unique; static element_type* to_address(const pointer& p) noexcept { return p.ptr_; } }; } // 
namespace std namespace cel { template struct ArenaTraits> { static bool trivially_destructible(const Unique& unique) { return unique.arena_ != 0 && (unique.arena_ & common_internal::kUniqueArenaBits) == 0; } }; // `Owned` points to an object which was allocated using `Allocator<>` or // `Allocator`. It has co-ownership over the object. `T` must meet the named // requirement `ArenaConstructable`. template class ABSL_ATTRIBUTE_TRIVIAL_ABI [[nodiscard]] Owned final { public: using element_type = T; static_assert(!std::is_array_v, "T must not be an array"); static_assert(!std::is_reference_v, "T must not be a reference"); static_assert(!std::is_volatile_v, "T must not be volatile qualified"); static_assert(!std::is_void_v, "T must not be void"); Owned() = default; Owned(const Owned&) = default; Owned& operator=(const Owned&) = default; Owned(Owned&& other) noexcept : Owned(std::exchange(other.value_, nullptr), std::move(other.owner_)) {} template >> // NOLINTNEXTLINE(google-explicit-constructor) Owned(const Owned& other) noexcept : Owned(other.value_, other.owner_) {} template >> // NOLINTNEXTLINE(google-explicit-constructor) Owned(Owned&& other) noexcept : Owned(std::exchange(other.value_, nullptr), std::move(other.owner_)) {} template >> explicit Owned(Borrowed other) noexcept; template >> // NOLINTNEXTLINE(google-explicit-constructor) Owned(Unique&& other) : Owned() { other.Release(&value_, &owner_); } Owned(Owner owner, T* value ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept : Owned(value, std::move(owner)) {} // NOLINTNEXTLINE(google-explicit-constructor) Owned(std::nullptr_t) noexcept : Owned() {} Owned& operator=(Owned&& other) noexcept { if (ABSL_PREDICT_TRUE(this != &other)) { value_ = std::exchange(other.value_, nullptr); owner_ = std::move(other.owner_); } return *this; } template >> // NOLINTNEXTLINE(google-explicit-constructor) Owned& operator=(const Owned& other) noexcept { value_ = other.value_; owner_ = other.owner_; return *this; } template >> // 
NOLINTNEXTLINE(google-explicit-constructor) Owned& operator=(Owned&& other) noexcept { value_ = std::exchange(other.value_, nullptr); owner_ = std::move(other.owner_); return *this; } template >> // NOLINTNEXTLINE(google-explicit-constructor) Owned& operator=(Borrowed other) noexcept; template >> // NOLINTNEXTLINE(google-explicit-constructor) Owned& operator=(Unique&& other) { owner_.reset(); other.Release(&value_, &owner_); return *this; } // NOLINTNEXTLINE(google-explicit-constructor) Owned& operator=(std::nullptr_t) noexcept { reset(); return *this; } T& operator*() const noexcept ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK(static_cast(*this)); return *get(); } T* absl_nonnull operator->() const noexcept ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK(static_cast(*this)); return get(); } void reset() noexcept { value_ = nullptr; owner_.reset(); } google::protobuf::Arena* absl_nullable arena() const noexcept { return owner_.arena(); } explicit operator bool() const noexcept { return get() != nullptr; } friend void swap(Owned& lhs, Owned& rhs) noexcept { using std::swap; swap(lhs.value_, rhs.value_); swap(lhs.owner_, rhs.owner_); } private: friend class Owner; friend class Borrower; template friend class Owned; template friend class Borrowed; template friend struct Ownable; template friend Owned AllocateShared(Allocator<> allocator, Args&&... 
args); template friend Owned WrapShared(U* object, Allocator<> allocator); template friend Owned common_internal::WrapEternal(const U* value); friend struct std::pointer_traits>; friend struct ArenaTraits>; Owned(T* value, Owner owner) noexcept : value_(value), owner_(std::move(owner)) {} T* get() const noexcept { return value_; } T* value_ = nullptr; Owner owner_; }; template Owned(T*) -> Owned; template Owned(Unique) -> Owned; template Owned(Owner, T*) -> Owned; template Owned(Borrowed) -> Owned; } // namespace cel namespace std { template struct pointer_traits> { using pointer = cel::Owned; using element_type = typename cel::Owned::element_type; using difference_type = ptrdiff_t; template using rebind = cel::Owned; static element_type* to_address(const pointer& p) noexcept { return p.value_; } }; } // namespace std namespace cel { template struct ArenaTraits> { static bool trivially_destructible(const Owned& owned) { return ArenaTraits<>::trivially_destructible(owned.owner_); } }; template Owner::Owner(const Owned& owned) noexcept : Owner(owned.owner_) {} template Owner::Owner(Owned&& owned) noexcept : Owner(std::move(owned.owner_)) { owned.value_ = nullptr; } template Owner& Owner::operator=(const Owned& owned) noexcept { *this = owned.owner_; return *this; } template Owner& Owner::operator=(Owned&& owned) noexcept { *this = std::move(owned.owner_); owned.value_ = nullptr; return *this; } template bool operator==(const Owned& lhs, std::nullptr_t) noexcept { return !static_cast(lhs); } template bool operator==(std::nullptr_t, const Owned& rhs) noexcept { return rhs == nullptr; } template bool operator!=(const Owned& lhs, std::nullptr_t) noexcept { return !operator==(lhs, nullptr); } template bool operator!=(std::nullptr_t, const Owned& rhs) noexcept { return !operator==(nullptr, rhs); } template Owned AllocateShared(Allocator<> allocator, Args&&... 
args) { using U = std::remove_cv_t; static_assert(!std::is_reference_v, "T must not be a reference"); static_assert(!std::is_array_v, "T must not be an array"); U* object; Owner owner; if (google::protobuf::Arena* absl_nullable arena = allocator.arena(); arena != nullptr) { object = ArenaAllocator(arena).template new_object( std::forward(args)...); owner.ptr_ = reinterpret_cast(arena) | common_internal::kMetadataOwnerArenaBit; } else { const common_internal::ReferenceCount* refcount; std::tie(object, refcount) = common_internal::MakeEmplacedReferenceCount( std::forward(args)...); owner.ptr_ = reinterpret_cast(refcount) | common_internal::kMetadataOwnerReferenceCountBit; } return Owned(object, std::move(owner)); } template Owned WrapShared(T* object, Allocator<> allocator) { Owner owner; if (object == nullptr) { } else if (allocator.arena() != nullptr) { owner.ptr_ = reinterpret_cast( static_cast(allocator.arena())) | common_internal::kMetadataOwnerArenaBit; } else { owner.ptr_ = reinterpret_cast( common_internal::MakeDeletingReferenceCount(object)) | common_internal::kMetadataOwnerReferenceCountBit; } return Owned(object, std::move(owner)); } template std::enable_if_t, Owned> WrapShared(T* object) { return WrapShared(object, object->GetArena()); } namespace common_internal { template Owned WrapEternal(const T* value) { return Owned(value, Owner::None()); } } // namespace common_internal // `Borrowed` points to an object which was allocated using `Allocator<>` or // `Allocator`. It has no ownership over the object, and is only valid so // long as one or more owners of the object exist. `T` must meet the named // requirement `ArenaConstructable`. 
template class Borrowed final { public: using element_type = T; static_assert(!std::is_array_v, "T must not be an array"); static_assert(!std::is_reference_v, "T must not be a reference"); static_assert(!std::is_volatile_v, "T must not be volatile qualified"); static_assert(!std::is_void_v, "T must not be void"); Borrowed() = default; Borrowed(const Borrowed&) = default; Borrowed(Borrowed&&) = default; Borrowed& operator=(const Borrowed&) = default; Borrowed& operator=(Borrowed&&) = default; template >> // NOLINTNEXTLINE(google-explicit-constructor) Borrowed(const Borrowed& other) noexcept : Borrowed(other.value_, other.borrower_) {} template >> // NOLINTNEXTLINE(google-explicit-constructor) Borrowed(Borrowed&& other) noexcept : Borrowed(other.value_, other.borrower_) {} template >> // NOLINTNEXTLINE(google-explicit-constructor) Borrowed(const Owned& other ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept : Borrowed(other.value_, other.owner_) {} Borrowed(Borrower borrower, T* ptr) noexcept : Borrowed(ptr, borrower) {} // NOLINTNEXTLINE(google-explicit-constructor) Borrowed(std::nullptr_t) noexcept : Borrowed() {} template >> // NOLINTNEXTLINE(google-explicit-constructor) Borrowed& operator=(const Borrowed& other) noexcept { value_ = other.value_; borrower_ = other.borrower_; return *this; } template >> // NOLINTNEXTLINE(google-explicit-constructor) Borrowed& operator=(Borrowed&& other) noexcept { value_ = other.value_; borrower_ = other.borrower_; return *this; } template >> // NOLINTNEXTLINE(google-explicit-constructor) Borrowed& operator=( const Owned& other ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { value_ = other.value_; borrower_ = other.borrower_; return *this; } template >> // NOLINTNEXTLINE(google-explicit-constructor) Borrowed& operator=(Owned&&) = delete; // NOLINTNEXTLINE(google-explicit-constructor) Borrowed& operator=(std::nullptr_t) noexcept { reset(); return *this; } T& operator*() const noexcept { ABSL_DCHECK(static_cast(*this)); return *get(); } T* 
absl_nonnull operator->() const noexcept { ABSL_DCHECK(static_cast(*this)); return get(); } void reset() noexcept { value_ = nullptr; borrower_.reset(); } google::protobuf::Arena* absl_nullable arena() const noexcept { return borrower_.arena(); } explicit operator bool() const noexcept { return get() != nullptr; } private: friend class Owner; friend class Borrower; template friend class Owned; template friend class Borrowed; template friend struct Borrowable; friend struct std::pointer_traits>; constexpr Borrowed(T* value, Borrower borrower) noexcept : value_(value), borrower_(borrower) {} T* get() const noexcept { return value_; } T* value_ = nullptr; Borrower borrower_; }; template Borrowed(T*) -> Borrowed; template Borrowed(Borrower, T*) -> Borrowed; template Borrowed(Owned) -> Borrowed; } // namespace cel namespace std { template struct pointer_traits> { using pointer = cel::Borrowed; using element_type = typename cel::Borrowed::element_type; using difference_type = ptrdiff_t; template using rebind = cel::Borrowed; static element_type* to_address(pointer p) noexcept { return p.value_; } }; } // namespace std namespace cel { template Owner::Owner(Borrowed borrowed) noexcept : Owner(borrowed.borrower_) {} template Borrower::Borrower(const Owned& owned ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept : Borrower(owned.owner_) {} template Borrower::Borrower(Borrowed borrowed) noexcept : Borrower(borrowed.borrower_) {} template Borrower& Borrower::operator=( const Owned& owned ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { *this = owned.owner_; return *this; } template Borrower& Borrower::operator=(Borrowed borrowed) noexcept { *this = borrowed.borrower_; return *this; } template bool operator==(Borrowed lhs, std::nullptr_t) noexcept { return !static_cast(lhs); } template bool operator==(std::nullptr_t, Borrowed rhs) noexcept { return rhs == nullptr; } template bool operator!=(Borrowed lhs, std::nullptr_t) noexcept { return !operator==(lhs, nullptr); } template bool 
operator!=(std::nullptr_t, Borrowed rhs) noexcept { return !operator==(nullptr, rhs); } template template Owned::Owned(Borrowed other) noexcept : Owned(other.value_, Owner(other.borrower_)) {} template template Owned& Owned::operator=(Borrowed other) noexcept { value_ = other.value_; owner_ = Owner(other.borrower_); return *this; } // `Ownable` is a mixin for enabling the ability to get `Owned` that refer to // this. template struct Ownable { protected: Owned Own() const noexcept { static_assert(std::is_base_of_v, "T must be derived from Data"); const T* const that = static_cast(this); return Owned( Owner(Owner::Own(static_cast(that)->owner_)), that); } Owned Own() noexcept { static_assert(std::is_base_of_v, "T must be derived from Data"); T* const that = static_cast(this); return Owned(Owner(Owner::Own(static_cast(that)->owner_)), that); } ABSL_DEPRECATED("Use Own") Owned shared_from_this() const noexcept { return Own(); } ABSL_DEPRECATED("Use Own") Owned shared_from_this() noexcept { return Own(); } }; // `Borrowable` is a mixin for enabling the ability to get `Borrowed` that // refer to this. template struct Borrowable { protected: Borrowed Borrow() const noexcept { static_assert(std::is_base_of_v, "T must be derived from Data"); const T* const that = static_cast(this); return Borrowed(Borrower(static_cast(that)->owner_), that); } Borrowed Borrow() noexcept { static_assert(std::is_base_of_v, "T must be derived from Data"); T* const that = static_cast(this); return Borrowed(Borrower(static_cast(that)->owner_), that); } }; // `ReferenceCountingMemoryManager` is a `MemoryManager` which employs automatic // memory management through reference counting. 
class ReferenceCountingMemoryManager final {
 public:
  ReferenceCountingMemoryManager(const ReferenceCountingMemoryManager&) = delete;
  ReferenceCountingMemoryManager(ReferenceCountingMemoryManager&&) = delete;
  ReferenceCountingMemoryManager& operator=(
      const ReferenceCountingMemoryManager&) = delete;
  ReferenceCountingMemoryManager& operator=(ReferenceCountingMemoryManager&&) =
      delete;

 private:
  // Declared only here (defined elsewhere). `MemoryManager` forwards to these
  // when it has no arena.
  static void* Allocate(size_t size, size_t alignment);
  static bool Deallocate(void* ptr, size_t size, size_t alignment) noexcept;

  explicit ReferenceCountingMemoryManager() = default;

  friend class MemoryManager;
};

// `PoolingMemoryManager` is a `MemoryManager` which employs automatic
// memory management through memory pooling.
class PoolingMemoryManager final {
 public:
  PoolingMemoryManager(const PoolingMemoryManager&) = delete;
  PoolingMemoryManager(PoolingMemoryManager&&) = delete;
  PoolingMemoryManager& operator=(const PoolingMemoryManager&) = delete;
  PoolingMemoryManager& operator=(PoolingMemoryManager&&) = delete;

 private:
  // Allocates `size` bytes aligned to `alignment` (which must be a power of 2)
  // from `arena` via `AllocateAligned`. Returns `nullptr` when `size` is zero.
  // There is no per-allocation free for pooled memory: `Deallocate` never
  // releases anything.
  ABSL_MUST_USE_RESULT static void* Allocate(google::protobuf::Arena* absl_nonnull arena,
                                             size_t size, size_t alignment) {
    ABSL_DCHECK(absl::has_single_bit(alignment))
        << "alignment must be a power of 2";
    if (size == 0) {
      return nullptr;
    }
    return arena->AllocateAligned(size, alignment);
  }

  // Counterpart of `Allocate`. Pooled memory cannot be returned individually,
  // so this only validates `alignment` (in debug builds) and always returns
  // `false`, meaning the memory will not be reused by later `Allocate` calls.
  static bool Deallocate(google::protobuf::Arena* absl_nonnull, void*, size_t,
                         size_t alignment) noexcept {
    ABSL_DCHECK(absl::has_single_bit(alignment))
        << "alignment must be a power of 2";
    return false;
  }

  // Registers a custom destructor to be run upon destruction of the memory
  // management implementation. Return value is always `true`, indicating that
  // the destructor may be called at some point in the future.
  static bool OwnCustomDestructor(google::protobuf::Arena* absl_nonnull arena,
                                  void* object,
                                  void (*absl_nonnull destruct)(void*)) {
    ABSL_DCHECK(destruct != nullptr);
    arena->OwnCustomDestructor(object, destruct);
    return true;
  }

  // Type-erased destructor suitable for `OwnCustomDestructor`; only valid for
  // non-trivially-destructible types (enforced by the static_assert).
  template static void DefaultDestructor(void* ptr) {
    static_assert(!std::is_trivially_destructible_v);
    static_cast(ptr)->~T();
  }

  explicit PoolingMemoryManager() = default;

  friend class MemoryManager;
};

// `MemoryManager` is an abstraction for supporting automatic memory management.
// All objects created by the `MemoryManager` have a lifetime governed by the
// underlying memory management strategy. Currently `MemoryManager` holds a
// nullable `google::protobuf::Arena*`: `nullptr` selects reference counting
// (forwarding to `ReferenceCountingMemoryManager`) and a non-null arena selects
// pooling (forwarding to `PoolingMemoryManager`).
//
// ============================ Reference Counting ============================
// `Unique`: The object is valid until destruction of the `Unique`.
//
// `Shared`: The object is valid so long as one or more `Shared` managing the
// object exist.
//
// ================================= Pooling ==================================
// `Unique`: The object is valid until destruction of the underlying memory
// resources or of the `Unique`.
//
// `Shared`: The object is valid until destruction of the underlying memory
// resources.
class MemoryManager final {
 public:
  // Returns a `MemoryManager` which utilizes an arena but never frees its
  // memory. It is effectively a memory leak and should only be used for limited
  // use cases, such as initializing singletons which live for the life of the
  // program.
  static MemoryManager Unmanaged();

  // Returns a `MemoryManager` which utilizes reference counting.
  ABSL_MUST_USE_RESULT static MemoryManager ReferenceCounting() {
    return MemoryManager(nullptr);
  }

  // Returns a `MemoryManager` which utilizes an arena.
  ABSL_MUST_USE_RESULT static MemoryManager Pooling(
      google::protobuf::Arena* absl_nonnull arena) {
    return MemoryManager(arena);
  }

  // Uses the allocator's arena; a null arena means reference counting.
  explicit MemoryManager(Allocator<> allocator) : arena_(allocator.arena()) {}

  MemoryManager() = delete;
  MemoryManager(const MemoryManager&) = default;
  MemoryManager& operator=(const MemoryManager&) = default;

  MemoryManagement memory_management() const noexcept {
    return arena_ == nullptr ? MemoryManagement::kReferenceCounting
                             : MemoryManagement::kPooling;
  }

  // Allocates memory directly from the allocator used by this memory manager.
  // If `memory_management()` returns `MemoryManagement::kReferenceCounting`,
  // this allocation *must* be explicitly deallocated at some point via
  // `Deallocate`. Otherwise deallocation is optional.
  ABSL_MUST_USE_RESULT void* Allocate(size_t size, size_t alignment) {
    if (arena_ == nullptr) {
      return ReferenceCountingMemoryManager::Allocate(size, alignment);
    } else {
      return PoolingMemoryManager::Allocate(arena_, size, alignment);
    }
  }

  // Attempts to deallocate memory previously allocated via `Allocate`, `size`
  // and `alignment` must match the values from the previous call to `Allocate`.
  // Returns `true` if the deallocation was successful and additional calls to
  // `Allocate` may re-use the memory, `false` otherwise. Returns `false` if
  // given `nullptr`.
  bool Deallocate(void* ptr, size_t size, size_t alignment) noexcept {
    if (arena_ == nullptr) {
      return ReferenceCountingMemoryManager::Deallocate(ptr, size, alignment);
    } else {
      return PoolingMemoryManager::Deallocate(arena_, ptr, size, alignment);
    }
  }

  // Registers a custom destructor to be run upon destruction of the memory
  // management implementation. A return of `true` indicates the destructor may
  // be called at some point in the future, `false` that it will definitely not
  // be called. All pooling memory managers return `true` while the reference
  // counting memory manager returns `false`.
  bool OwnCustomDestructor(void* object, void (*absl_nonnull destruct)(void*)) {
    ABSL_DCHECK(destruct != nullptr);
    if (arena_ == nullptr) {
      return false;
    } else {
      return PoolingMemoryManager::OwnCustomDestructor(arena_, object,
                                                       destruct);
    }
  }

  google::protobuf::Arena* absl_nullable arena() const noexcept { return arena_; }

  // Implicitly converts to an allocator backed by `arena()`.
  template
  // NOLINTNEXTLINE(google-explicit-constructor)
  operator Allocator() const {
    return arena();
  }

  friend void swap(MemoryManager& lhs, MemoryManager& rhs) noexcept {
    using std::swap;
    swap(lhs.arena_, rhs.arena_);
  }

 private:
  friend class PoolingMemoryManager;

  explicit MemoryManager(std::nullptr_t) : arena_(nullptr) {}

  explicit MemoryManager(google::protobuf::Arena* absl_nonnull arena) : arena_(arena) {}

  // If `nullptr`, we are using reference counting. Otherwise we are using
  // pooling backed by this arena. `MemoryManager` is trivially copyable in
  // effect: copies simply share the same arena pointer.
  google::protobuf::Arena* absl_nullable arena_;
};

using MemoryManagerRef = MemoryManager;

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_MEMORY_H_
================================================
FILE: common/memory_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Tests for the ownership primitives declared in "common/memory.h" (`Owner`,
// `Borrower`, `Unique`, `Owned`, and `Borrowed`), exercised with both arena
// and new/delete allocators.
#include "common/memory.h"

#include 

#include "google/protobuf/struct.pb.h"
#include "absl/base/nullability.h"
#include "common/allocator.h"
#include "common/data.h"
#include "common/internal/reference_count.h"
#include "internal/testing.h"
#include "google/protobuf/arena.h"

#ifdef ABSL_HAVE_EXCEPTIONS
#include 
#endif

namespace cel {
namespace {

using ::testing::IsFalse;
using ::testing::IsNull;
using ::testing::IsTrue;
using ::testing::TestParamInfo;
using ::testing::TestWithParam;

// --- Owner: construction via each factory, truthiness, arena, equality. ---

TEST(Owner, None) {
  EXPECT_THAT(Owner::None(), IsFalse());
  EXPECT_THAT(Owner::None().arena(), IsNull());
}

TEST(Owner, Allocator) {
  google::protobuf::Arena arena;
  EXPECT_THAT(Owner::Allocator(NewDeleteAllocator<>{}), IsFalse());
  EXPECT_THAT(Owner::Allocator(ArenaAllocator<>{&arena}), IsTrue());
}

TEST(Owner, Arena) {
  google::protobuf::Arena arena;
  EXPECT_THAT(Owner::Arena(&arena), IsTrue());
  EXPECT_EQ(Owner::Arena(&arena).arena(), &arena);
}

TEST(Owner, ReferenceCount) {
  auto* refcount = new common_internal::ReferenceCounted();
  EXPECT_THAT(Owner::ReferenceCount(refcount), IsTrue());
  EXPECT_THAT(Owner::ReferenceCount(refcount).arena(), IsNull());
  // Drop the reference created by `new`.
  common_internal::StrongUnref(refcount);
}

TEST(Owner, Equality) {
  google::protobuf::Arena arena1;
  google::protobuf::Arena arena2;
  EXPECT_EQ(Owner::None(), Owner::None());
  EXPECT_EQ(Owner::Allocator(NewDeleteAllocator<>{}), Owner::None());
  EXPECT_EQ(Owner::Arena(&arena1), Owner::Arena(&arena1));
  EXPECT_NE(Owner::Arena(&arena1), Owner::None());
  EXPECT_NE(Owner::None(), Owner::Arena(&arena1));
  EXPECT_NE(Owner::Arena(&arena1), Owner::Arena(&arena2));
  EXPECT_EQ(Owner::Allocator(ArenaAllocator<>{&arena1}), Owner::Arena(&arena1));
}

// --- Borrower: mirrors the Owner tests above. ---

TEST(Borrower, None) {
  EXPECT_THAT(Borrower::None(), IsFalse());
  EXPECT_THAT(Borrower::None().arena(), IsNull());
}

TEST(Borrower, Allocator) {
  google::protobuf::Arena arena;
  EXPECT_THAT(Borrower::Allocator(NewDeleteAllocator<>{}), IsFalse());
  EXPECT_THAT(Borrower::Allocator(ArenaAllocator<>{&arena}), IsTrue());
}

TEST(Borrower, Arena) {
  google::protobuf::Arena arena;
  EXPECT_THAT(Borrower::Arena(&arena), IsTrue());
  EXPECT_EQ(Borrower::Arena(&arena).arena(), &arena);
}

TEST(Borrower, ReferenceCount) {
  auto* refcount = new common_internal::ReferenceCounted();
  EXPECT_THAT(Borrower::ReferenceCount(refcount), IsTrue());
  EXPECT_THAT(Borrower::ReferenceCount(refcount).arena(), IsNull());
  common_internal::StrongUnref(refcount);
}

TEST(Borrower, Equality) {
  google::protobuf::Arena arena1;
  google::protobuf::Arena arena2;
  EXPECT_EQ(Borrower::None(), Borrower::None());
  EXPECT_EQ(Borrower::Allocator(NewDeleteAllocator<>{}), Borrower::None());
  EXPECT_EQ(Borrower::Arena(&arena1), Borrower::Arena(&arena1));
  EXPECT_NE(Borrower::Arena(&arena1), Borrower::None());
  EXPECT_NE(Borrower::None(), Borrower::Arena(&arena1));
  EXPECT_NE(Borrower::Arena(&arena1), Borrower::Arena(&arena2));
  EXPECT_EQ(Borrower::Allocator(ArenaAllocator<>{&arena1}),
            Borrower::Arena(&arena1));
}

// --- Owner <-> Borrower interop. Each test releases the initial reference
// right after creating `owner1`, so `owner1` holds the remaining one. ---

TEST(OwnerBorrower, CopyConstruct) {
  auto* refcount = new common_internal::ReferenceCounted();
  Owner owner1 = Owner::ReferenceCount(refcount);
  common_internal::StrongUnref(refcount);
  Owner owner2(owner1);
  Borrower borrower(owner1);
  EXPECT_EQ(owner1, owner2);
  EXPECT_EQ(owner1, borrower);
  EXPECT_EQ(borrower, owner1);
}

TEST(OwnerBorrower, MoveConstruct) {
  auto* refcount = new common_internal::ReferenceCounted();
  Owner owner1 = Owner::ReferenceCount(refcount);
  common_internal::StrongUnref(refcount);
  Owner owner2(std::move(owner1));
  Borrower borrower(owner2);
  EXPECT_EQ(owner2, borrower);
  EXPECT_EQ(borrower, owner2);
}

TEST(OwnerBorrower, CopyAssign) {
  auto* refcount = new common_internal::ReferenceCounted();
  Owner owner1 = Owner::ReferenceCount(refcount);
  common_internal::StrongUnref(refcount);
  Owner owner2;
  owner2 = owner1;
  Borrower borrower(owner1);
  EXPECT_EQ(owner1, owner2);
  EXPECT_EQ(owner1, borrower);
  EXPECT_EQ(borrower, owner1);
}

TEST(OwnerBorrower, MoveAssign) {
  auto* refcount = new common_internal::ReferenceCounted();
  Owner owner1 = Owner::ReferenceCount(refcount);
  common_internal::StrongUnref(refcount);
  Owner owner2;
  owner2 = std::move(owner1);
  Borrower borrower(owner2);
  EXPECT_EQ(owner2, borrower);
  EXPECT_EQ(borrower, owner2);
}

// `cel::to_address` yields null for an empty `Unique` and the managed pointer
// otherwise.
TEST(Unique, ToAddress) {
  Unique unique;
  EXPECT_EQ(cel::to_address(unique), nullptr);
  unique = AllocateUnique(NewDeleteAllocator<>{});
  EXPECT_EQ(cel::to_address(unique), unique.operator->());
}

// Parameterized over `AllocatorKind` so every `Owned` test runs with both an
// arena allocator and the new/delete allocator.
class OwnedTest : public TestWithParam {
 public:
  // Maps the test parameter to an allocator.
  // NOTE(review): there is no return after the switch, so compilers may warn
  // (-Wreturn-type) even though both enumerators are handled.
  Allocator<> GetAllocator() {
    switch (GetParam()) {
      case AllocatorKind::kArena:
        return ArenaAllocator<>{&arena_};
      case AllocatorKind::kNewDelete:
        return NewDeleteAllocator<>{};
    }
  }

 private:
  google::protobuf::Arena arena_;
};

TEST_P(OwnedTest, Default) {
  Owned owned;
  EXPECT_FALSE(owned);
  EXPECT_EQ(cel::to_address(owned), nullptr);
  EXPECT_FALSE(owned != nullptr);
  EXPECT_FALSE(nullptr != owned);
}

// Minimal `Data` subclass used as an arena-aware test payload. The two member
// typedefs are marker types (presumably the ones protobuf's arena traits look
// for: arena constructable, destructor skippable) — verify against protobuf.
class TestData final : public Data {
 public:
  using InternalArenaConstructable_ = void;
  using DestructorSkippable_ = void;

  TestData() noexcept : Data() {}

  explicit TestData(google::protobuf::Arena* absl_nullable arena) noexcept
      : Data(arena) {}
};

// The tests below check that the arena reported by the object, by an `Owner`
// and by a `Borrower` derived from it all match the allocator's arena.

TEST_P(OwnedTest, AllocateSharedData) {
  auto owned = AllocateShared(GetAllocator());
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  EXPECT_EQ(Owner(owned).arena(), GetAllocator().arena());
  EXPECT_EQ(Borrower(owned).arena(), GetAllocator().arena());
}

TEST_P(OwnedTest, AllocateSharedMessageLite) {
  auto owned = AllocateShared(GetAllocator());
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  EXPECT_EQ(Owner(owned).arena(), GetAllocator().arena());
  EXPECT_EQ(Borrower(owned).arena(), GetAllocator().arena());
}

TEST_P(OwnedTest, WrapSharedData) {
  auto owned =
      WrapShared(google::protobuf::Arena::Create(GetAllocator().arena()));
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  EXPECT_EQ(Owner(owned).arena(), GetAllocator().arena());
  EXPECT_EQ(Borrower(owned).arena(), GetAllocator().arena());
}

TEST_P(OwnedTest, WrapSharedMessageLite) {
  auto owned = WrapShared(
      google::protobuf::Arena::Create(GetAllocator().arena()));
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  EXPECT_EQ(Owner(owned).arena(), GetAllocator().arena());
  EXPECT_EQ(Borrower(owned).arena(), GetAllocator().arena());
}

TEST_P(OwnedTest, SharedFromUniqueData) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  EXPECT_EQ(Owner(owned).arena(), GetAllocator().arena());
  EXPECT_EQ(Borrower(owned).arena(), GetAllocator().arena());
}

TEST_P(OwnedTest, SharedFromUniqueMessageLite) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  EXPECT_EQ(Owner(owned).arena(), GetAllocator().arena());
  EXPECT_EQ(Borrower(owned).arena(), GetAllocator().arena());
}

// Construction / assignment paths: each result must still see the original
// allocator's arena.

TEST_P(OwnedTest, CopyConstruct) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  Owned copied_owned(owned);
  EXPECT_EQ(copied_owned->GetArena(), GetAllocator().arena());
}

TEST_P(OwnedTest, MoveConstruct) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  Owned moved_owned(std::move(owned));
  EXPECT_EQ(moved_owned->GetArena(), GetAllocator().arena());
}

TEST_P(OwnedTest, CopyConstructOther) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  Owned copied_owned(owned);
  EXPECT_EQ(copied_owned->GetArena(), GetAllocator().arena());
}

TEST_P(OwnedTest, MoveConstructOther) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  Owned moved_owned(std::move(owned));
  EXPECT_EQ(moved_owned->GetArena(), GetAllocator().arena());
}

TEST_P(OwnedTest, ConstructBorrowed) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  Owned borrowed_owned(Borrowed{owned});
  EXPECT_EQ(borrowed_owned->GetArena(), GetAllocator().arena());
}

TEST_P(OwnedTest, ConstructOwner) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  Owned owner_owned(Owner(owned), cel::to_address(owned));
  EXPECT_EQ(owner_owned->GetArena(), GetAllocator().arena());
}

TEST_P(OwnedTest, ConstructNullPtr) {
  Owned owned(nullptr);
  EXPECT_EQ(owned, nullptr);
}

TEST_P(OwnedTest, CopyAssign) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  Owned copied_owned;
  copied_owned = owned;
  EXPECT_EQ(copied_owned->GetArena(), GetAllocator().arena());
}

TEST_P(OwnedTest, MoveAssign) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  Owned moved_owned;
  moved_owned = std::move(owned);
  EXPECT_EQ(moved_owned->GetArena(), GetAllocator().arena());
}

TEST_P(OwnedTest, CopyAssignOther) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  Owned copied_owned;
  copied_owned = owned;
  EXPECT_EQ(copied_owned->GetArena(), GetAllocator().arena());
}

TEST_P(OwnedTest, MoveAssignOther) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  Owned moved_owned;
  moved_owned = std::move(owned);
  EXPECT_EQ(moved_owned->GetArena(), GetAllocator().arena());
}

TEST_P(OwnedTest, AssignBorrowed) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  Owned borrowed_owned;
  borrowed_owned = Borrowed{owned};
  EXPECT_EQ(borrowed_owned->GetArena(), GetAllocator().arena());
}

TEST_P(OwnedTest, AssignUnique) {
  Owned owned;
  owned = AllocateUnique(GetAllocator());
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
}

TEST_P(OwnedTest, AssignNullPtr) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  EXPECT_EQ(owned->GetArena(), GetAllocator().arena());
  EXPECT_TRUE(owned);
  owned = nullptr;
  EXPECT_FALSE(owned);
}

INSTANTIATE_TEST_SUITE_P(OwnedTest, OwnedTest,
                         ::testing::Values(AllocatorKind::kArena,
                                           AllocatorKind::kNewDelete));

// Same parameterization as `OwnedTest`, for the non-owning `Borrowed`.
class BorrowedTest : public TestWithParam {
 public:
  // NOTE(review): as in `OwnedTest`, no return after the switch.
  Allocator<> GetAllocator() {
    switch (GetParam()) {
      case AllocatorKind::kArena:
        return ArenaAllocator<>{&arena_};
      case AllocatorKind::kNewDelete:
        return NewDeleteAllocator<>{};
    }
  }

 private:
  google::protobuf::Arena arena_;
};

TEST_P(BorrowedTest, Default) {
  Borrowed borrowed;
  EXPECT_FALSE(borrowed);
  EXPECT_EQ(cel::to_address(borrowed), nullptr);
  EXPECT_FALSE(borrowed != nullptr);
  EXPECT_FALSE(nullptr != borrowed);
}

// Each test keeps `owned` alive for as long as the `Borrowed` values are used.

TEST_P(BorrowedTest, CopyConstruct) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  auto borrowed = Borrowed(owned);
  EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena());
  Borrowed copied_borrowed(borrowed);
  EXPECT_EQ(copied_borrowed->GetArena(), GetAllocator().arena());
}

TEST_P(BorrowedTest, MoveConstruct) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  auto borrowed = Borrowed(owned);
  EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena());
  Borrowed moved_borrowed(std::move(borrowed));
  EXPECT_EQ(moved_borrowed->GetArena(), GetAllocator().arena());
}

TEST_P(BorrowedTest, CopyConstructOther) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  auto borrowed = Borrowed(owned);
  EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena());
  Borrowed copied_borrowed(borrowed);
  EXPECT_EQ(copied_borrowed->GetArena(), GetAllocator().arena());
}

TEST_P(BorrowedTest, MoveConstructOther) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  auto borrowed = Borrowed(owned);
  EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena());
  Borrowed moved_borrowed(std::move(borrowed));
  EXPECT_EQ(moved_borrowed->GetArena(), GetAllocator().arena());
}

TEST_P(BorrowedTest, ConstructNullPtr) {
  Borrowed borrowed(nullptr);
  EXPECT_FALSE(borrowed);
}

TEST_P(BorrowedTest, CopyAssign) {
  auto owned = Owned(AllocateUnique(GetAllocator()));
  auto borrowed = Borrowed(owned);
  EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena());
  Borrowed copied_borrowed;
copied_borrowed = borrowed; EXPECT_EQ(copied_borrowed->GetArena(), GetAllocator().arena()); } TEST_P(BorrowedTest, MoveAssign) { auto owned = Owned(AllocateUnique(GetAllocator())); auto borrowed = Borrowed(owned); EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena()); Borrowed moved_borrowed; moved_borrowed = std::move(borrowed); EXPECT_EQ(moved_borrowed->GetArena(), GetAllocator().arena()); } TEST_P(BorrowedTest, CopyAssignOther) { auto owned = Owned(AllocateUnique(GetAllocator())); auto borrowed = Borrowed(owned); EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena()); Borrowed copied_borrowed; copied_borrowed = borrowed; EXPECT_EQ(copied_borrowed->GetArena(), GetAllocator().arena()); } TEST_P(BorrowedTest, MoveAssignOther) { auto owned = Owned(AllocateUnique(GetAllocator())); auto borrowed = Borrowed(owned); EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena()); Borrowed moved_borrowed; moved_borrowed = std::move(borrowed); EXPECT_EQ(moved_borrowed->GetArena(), GetAllocator().arena()); } TEST_P(BorrowedTest, AssignOwned) { auto owned = Owned(AllocateUnique(GetAllocator())); EXPECT_EQ(owned->GetArena(), GetAllocator().arena()); Borrowed borrowed = owned; EXPECT_EQ(borrowed->GetArena(), GetAllocator().arena()); } TEST_P(BorrowedTest, AssignNullPtr) { Borrowed borrowed; borrowed = nullptr; EXPECT_FALSE(borrowed); } INSTANTIATE_TEST_SUITE_P(BorrowedTest, BorrowedTest, ::testing::Values(AllocatorKind::kArena, AllocatorKind::kNewDelete)); } // namespace } // namespace cel ================================================ FILE: common/memory_testing.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_MEMORY_TESTING_H_ #define THIRD_PARTY_CEL_CPP_COMMON_MEMORY_TESTING_H_ #include #include #include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "common/memory.h" #include "internal/testing.h" #include "google/protobuf/arena.h" namespace cel::common_internal { template class ThreadCompatibleMemoryTest : public ::testing::TestWithParam> { public: void SetUp() override {} void TearDown() override { Finish(); } MemoryManagement memory_management() { return std::get<0>(this->GetParam()); } MemoryManagerRef memory_manager() { switch (memory_management()) { case MemoryManagement::kReferenceCounting: return MemoryManager::ReferenceCounting(); break; case MemoryManagement::kPooling: if (!arena_) { arena_.emplace(); } return MemoryManager::Pooling(&*arena_); break; } } void Finish() { arena_.reset(); } static std::string ToString( ::testing::TestParamInfo> param) { return absl::StrJoin(param.param, "_", absl::StreamFormatter()); } protected: virtual MemoryManager NewThreadCompatiblePoolingMemoryManager() { return MemoryManager::Pooling(&*arena_); } private: absl::optional arena_; }; } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_MEMORY_TESTING_H_ ================================================ FILE: common/minimal_descriptor_database.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/minimal_descriptor_database.h" #include "absl/base/nullability.h" #include "internal/minimal_descriptor_database.h" #include "google/protobuf/descriptor_database.h" namespace cel { google::protobuf::DescriptorDatabase* absl_nonnull GetMinimalDescriptorDatabase() { return internal::GetMinimalDescriptorDatabase(); } } // namespace cel ================================================ FILE: common/minimal_descriptor_database.h ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_MINIMAL_DESCRIPTOR_DATABASE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_MINIMAL_DESCRIPTOR_DATABASE_H_ #include "absl/base/nullability.h" #include "google/protobuf/descriptor_database.h" namespace cel { // GetMinimalDescriptorDatabase returns a pointer to a // `google::protobuf::DescriptorDatabase` which includes has the minimally necessary // descriptors required by the Common Expression Language. 
The returned // `google::protobuf::DescriptorDatabase` is valid for the lifetime of the process and // should not be deleted. google::protobuf::DescriptorDatabase* absl_nonnull GetMinimalDescriptorDatabase(); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_MINIMAL_DESCRIPTOR_DATABASE_H_ ================================================ FILE: common/minimal_descriptor_database_test.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/minimal_descriptor_database.h"

#include "google/protobuf/descriptor.pb.h"
#include "internal/testing.h"
#include "google/protobuf/descriptor.h"

namespace cel {
namespace {

using ::testing::IsTrue;

// Reports whether the minimal descriptor database can locate the file that
// declares the fully-qualified `symbol`.
bool MinimalDatabaseHasSymbol(const char* symbol) {
  google::protobuf::FileDescriptorProto file;
  return GetMinimalDescriptorDatabase()->FindFileContainingSymbol(symbol,
                                                                  &file);
}

TEST(GetMinimalDescriptorDatabase, NullValue) {
  EXPECT_THAT(MinimalDatabaseHasSymbol("google.protobuf.NullValue"), IsTrue());
}

TEST(GetMinimalDescriptorDatabase, BoolValue) {
  EXPECT_THAT(MinimalDatabaseHasSymbol("google.protobuf.BoolValue"), IsTrue());
}

TEST(GetMinimalDescriptorDatabase, Int32Value) {
  EXPECT_THAT(MinimalDatabaseHasSymbol("google.protobuf.Int32Value"), IsTrue());
}

TEST(GetMinimalDescriptorDatabase, Int64Value) {
  EXPECT_THAT(MinimalDatabaseHasSymbol("google.protobuf.Int64Value"), IsTrue());
}

TEST(GetMinimalDescriptorDatabase, UInt32Value) {
  EXPECT_THAT(MinimalDatabaseHasSymbol("google.protobuf.UInt32Value"),
              IsTrue());
}

TEST(GetMinimalDescriptorDatabase, UInt64Value) {
  EXPECT_THAT(MinimalDatabaseHasSymbol("google.protobuf.UInt64Value"),
              IsTrue());
}

TEST(GetMinimalDescriptorDatabase, FloatValue) {
  EXPECT_THAT(MinimalDatabaseHasSymbol("google.protobuf.FloatValue"), IsTrue());
}

TEST(GetMinimalDescriptorDatabase, DoubleValue) {
  EXPECT_THAT(MinimalDatabaseHasSymbol("google.protobuf.DoubleValue"),
              IsTrue());
}

TEST(GetMinimalDescriptorDatabase, BytesValue) {
  EXPECT_THAT(MinimalDatabaseHasSymbol("google.protobuf.BytesValue"), IsTrue());
}

TEST(GetMinimalDescriptorDatabase, StringValue) {
  EXPECT_THAT(MinimalDatabaseHasSymbol("google.protobuf.StringValue"),
              IsTrue());
}

TEST(GetMinimalDescriptorDatabase, Any) {
  EXPECT_THAT(MinimalDatabaseHasSymbol("google.protobuf.Any"), IsTrue());
}

TEST(GetMinimalDescriptorDatabase, Duration) {
  EXPECT_THAT(MinimalDatabaseHasSymbol("google.protobuf.Duration"), IsTrue());
}

TEST(GetMinimalDescriptorDatabase, Timestamp) {
  EXPECT_THAT(MinimalDatabaseHasSymbol("google.protobuf.Timestamp"), IsTrue());
}

TEST(GetMinimalDescriptorDatabase, Value) {
  EXPECT_THAT(MinimalDatabaseHasSymbol("google.protobuf.Value"), IsTrue());
}

TEST(GetMinimalDescriptorDatabase, ListValue) {
  EXPECT_THAT(MinimalDatabaseHasSymbol("google.protobuf.ListValue"), IsTrue());
}

TEST(GetMinimalDescriptorDatabase, Struct) {
  EXPECT_THAT(MinimalDatabaseHasSymbol("google.protobuf.Struct"), IsTrue());
}

}  // namespace
}  // namespace cel
================================================
FILE: common/minimal_descriptor_pool.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/minimal_descriptor_pool.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "internal/minimal_descriptor_pool.h" #include "google/protobuf/descriptor.h" namespace cel { const google::protobuf::DescriptorPool* absl_nonnull GetMinimalDescriptorPool() { return internal::GetMinimalDescriptorPool(); } // If required, adds the minimally required descriptors to the pool. absl::Status AddMinimumRequiredDescriptorsToPool( google::protobuf::DescriptorPool* absl_nonnull pool) { return internal::AddMinimumRequiredDescriptorsToPool(pool); } } // namespace cel ================================================ FILE: common/minimal_descriptor_pool.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_COMMON_MINIMAL_DESCRIPTOR_POOL_H_ #define THIRD_PARTY_CEL_CPP_COMMON_MINIMAL_DESCRIPTOR_POOL_H_ #include "absl/base/nullability.h" #include "absl/status/status.h" #include "google/protobuf/descriptor.h" namespace cel { // GetMinimalDescriptorPool returns a pointer to a `google::protobuf::DescriptorPool` // which includes has the minimally necessary descriptors required by the Common // Expression Language. The returned `google::protobuf::DescriptorPool` is valid for the // lifetime of the process and should not be deleted. const google::protobuf::DescriptorPool* absl_nonnull GetMinimalDescriptorPool(); // If required, adds the minimally required descriptors to the pool. absl::Status AddMinimumRequiredDescriptorsToPool( google::protobuf::DescriptorPool* absl_nonnull pool); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_MINIMAL_DESCRIPTOR_POOL_H_ ================================================ FILE: common/minimal_descriptor_pool_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/minimal_descriptor_pool.h"

#include "absl/status/status_matchers.h"
#include "internal/testing.h"
#include "google/protobuf/descriptor.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::testing::NotNull;

// Looks up the message `full_name` in the minimal pool and checks that it
// exists and reports the `expected` well-known type.
void ExpectWellKnownMessage(
    const char* full_name,
    google::protobuf::Descriptor::WellKnownType expected) {
  const google::protobuf::Descriptor* descriptor =
      GetMinimalDescriptorPool()->FindMessageTypeByName(full_name);
  ASSERT_THAT(descriptor, NotNull());
  EXPECT_EQ(descriptor->well_known_type(), expected);
}

TEST(GetMinimalDescriptorPool, NullValue) {
  ASSERT_THAT(GetMinimalDescriptorPool()->FindEnumTypeByName(
                  "google.protobuf.NullValue"),
              NotNull());
}

TEST(GetMinimalDescriptorPool, BoolValue) {
  ExpectWellKnownMessage("google.protobuf.BoolValue",
                         google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE);
}

TEST(GetMinimalDescriptorPool, Int32Value) {
  ExpectWellKnownMessage("google.protobuf.Int32Value",
                         google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE);
}

TEST(GetMinimalDescriptorPool, Int64Value) {
  ExpectWellKnownMessage("google.protobuf.Int64Value",
                         google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE);
}

TEST(GetMinimalDescriptorPool, UInt32Value) {
  ExpectWellKnownMessage("google.protobuf.UInt32Value",
                         google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE);
}

TEST(GetMinimalDescriptorPool, UInt64Value) {
  ExpectWellKnownMessage("google.protobuf.UInt64Value",
                         google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE);
}

TEST(GetMinimalDescriptorPool, FloatValue) {
  ExpectWellKnownMessage("google.protobuf.FloatValue",
                         google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE);
}

TEST(GetMinimalDescriptorPool, DoubleValue) {
  ExpectWellKnownMessage("google.protobuf.DoubleValue",
                         google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE);
}

TEST(GetMinimalDescriptorPool, BytesValue) {
  ExpectWellKnownMessage("google.protobuf.BytesValue",
                         google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE);
}

TEST(GetMinimalDescriptorPool, StringValue) {
  ExpectWellKnownMessage("google.protobuf.StringValue",
                         google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE);
}

TEST(GetMinimalDescriptorPool, Any) {
  ExpectWellKnownMessage("google.protobuf.Any",
                         google::protobuf::Descriptor::WELLKNOWNTYPE_ANY);
}

TEST(GetMinimalDescriptorPool, Duration) {
  ExpectWellKnownMessage("google.protobuf.Duration",
                         google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION);
}

TEST(GetMinimalDescriptorPool, Timestamp) {
  ExpectWellKnownMessage("google.protobuf.Timestamp",
                         google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP);
}

TEST(GetMinimalDescriptorPool, Value) {
  ExpectWellKnownMessage("google.protobuf.Value",
                         google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE);
}

TEST(GetMinimalDescriptorPool, ListValue) {
  ExpectWellKnownMessage("google.protobuf.ListValue",
                         google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE);
}

TEST(GetMinimalDescriptorPool, Struct) {
  ExpectWellKnownMessage("google.protobuf.Struct",
                         google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT);
}

TEST(AddMinimumRequiredDescriptorsToPool, Adds) {
  google::protobuf::DescriptorPool pool;
  ASSERT_THAT(AddMinimumRequiredDescriptorsToPool(&pool), IsOk());
  EXPECT_THAT(pool.FindEnumTypeByName("google.protobuf.NullValue"), NotNull());
  for (const char* message_name :
       {"google.protobuf.BoolValue", "google.protobuf.Int32Value",
        "google.protobuf.Int64Value", "google.protobuf.UInt32Value",
        "google.protobuf.UInt64Value", "google.protobuf.FloatValue",
        "google.protobuf.DoubleValue", "google.protobuf.BytesValue",
        "google.protobuf.StringValue", "google.protobuf.Any",
        "google.protobuf.Duration", "google.protobuf.Timestamp",
        "google.protobuf.Value", "google.protobuf.ListValue",
        "google.protobuf.Struct"}) {
    EXPECT_THAT(pool.FindMessageTypeByName(message_name), NotNull())
        << message_name;
  }
}

}  // namespace
}  // namespace cel
================================================
FILE:
common/native_type.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_NATIVE_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_NATIVE_TYPE_H_ #include "common/typeinfo.h" namespace cel { using NativeTypeId = TypeInfo; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_NATIVE_TYPE_H_ ================================================ FILE: common/navigable_ast.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/navigable_ast.h"

// NOTE(review): the five directives below have lost their angle-bracket header
// names (this text appears to be an extraction of the upstream file); restore
// them from upstream before compiling.
#include
#include
#include
#include
#include
#include "absl/container/flat_hash_map.h"
#include "absl/functional/any_invocable.h"
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "common/ast/navigable_ast_internal.h"
#include "common/ast_traverse.h"
#include "common/ast_visitor.h"
#include "common/ast_visitor_base.h"
#include "common/expr.h"

namespace cel {
namespace {

// Short local names for the internal per-node data and shared metadata types.
// NOTE(review): if the upstream templates are parameterized, their template
// arguments may also have been stripped here — confirm against upstream.
using NavigableAstNodeData = common_internal::NavigableAstNodeData;
using NavigableAstMetadata = common_internal::NavigableAstMetadata;

// Maps the kind of `expr` (its `kind_case()`) to the corresponding `NodeKind`.
// `kUnspecifiedExpr` and any case not listed fall through to
// `NodeKind::kUnspecified`.
NodeKind GetNodeKind(const Expr& expr) {
  switch (expr.kind_case()) {
    case ExprKindCase::kConstant:
      return NodeKind::kConstant;
    case ExprKindCase::kIdentExpr:
      return NodeKind::kIdent;
    case ExprKindCase::kSelectExpr:
      return NodeKind::kSelect;
    case ExprKindCase::kCallExpr:
      return NodeKind::kCall;
    case ExprKindCase::kListExpr:
      return NodeKind::kList;
    case ExprKindCase::kStructExpr:
      return NodeKind::kStruct;
    case ExprKindCase::kMapExpr:
      return NodeKind::kMap;
    case ExprKindCase::kComprehensionExpr:
      return NodeKind::kComprehension;
    case ExprKindCase::kUnspecifiedExpr:
    default:
      return NodeKind::kUnspecified;
  }
}

// Get the traversal relationship from parent to the given node.
// Note: these depend on the ast_visitor utility's traversal ordering.
ChildKind GetChildKind(const NavigableAstNodeData& parent_node, size_t child_index, absl::optional comprehension_arg) { switch (parent_node.node_kind) { case NodeKind::kStruct: return ChildKind::kStructValue; case NodeKind::kMap: if (child_index % 2 == 0) { return ChildKind::kMapKey; } return ChildKind::kMapValue; case NodeKind::kList: return ChildKind::kListElem; case NodeKind::kSelect: return ChildKind::kSelectOperand; case NodeKind::kCall: if (child_index == 0 && parent_node.expr->call_expr().has_target()) { return ChildKind::kCallReceiver; } return ChildKind::kCallArg; case NodeKind::kComprehension: if (!comprehension_arg.has_value()) { return ChildKind::kUnspecified; } switch (*comprehension_arg) { case ComprehensionArg::ITER_RANGE: return ChildKind::kComprehensionRange; case ComprehensionArg::ACCU_INIT: return ChildKind::kComprehensionInit; case ComprehensionArg::LOOP_CONDITION: return ChildKind::kComprehensionCondition; case ComprehensionArg::LOOP_STEP: return ChildKind::kComprehensionLoopStep; case ComprehensionArg::RESULT: return ChildKind::kComprensionResult; default: return ChildKind::kUnspecified; } default: return ChildKind::kUnspecified; } } class NavigableExprBuilderVisitor : public cel::AstVisitorBase { public: NavigableExprBuilderVisitor( absl::AnyInvocable()> node_factory, absl::AnyInvocable node_data_accessor) : node_factory_(std::move(node_factory)), node_data_accessor_(std::move(node_data_accessor)), metadata_(std::make_unique()) {} NavigableAstNodeData& NodeDataAt(size_t index) { return node_data_accessor_(*metadata_->nodes[index]); } void PreVisitExpr(const Expr& expr) override { NavigableAstNode* parent = parent_stack_.empty() ? 
nullptr : metadata_->nodes[parent_stack_.back()].get(); size_t index = metadata_->nodes.size(); metadata_->nodes.push_back(node_factory_()); NavigableAstNode* node = metadata_->nodes[index].get(); auto& node_data = NodeDataAt(index); node_data.parent = parent; node_data.expr = &expr; node_data.parent_relation = ChildKind::kUnspecified; node_data.node_kind = GetNodeKind(expr); node_data.tree_size = 1; node_data.height = 1; node_data.index = index; node_data.child_index = -1; node_data.metadata = metadata_.get(); metadata_->id_to_node.insert({expr.id(), node}); metadata_->expr_to_node.insert({&expr, node}); if (!parent_stack_.empty()) { auto& parent_node_data = NodeDataAt(parent_stack_.back()); size_t child_index = parent_node_data.children.size(); parent_node_data.children.push_back(node); node_data.parent_relation = GetChildKind(parent_node_data, child_index, comprehension_arg_); node_data.child_index = child_index; } parent_stack_.push_back(index); } void PreVisitComprehensionSubexpression( const Expr& expr, const ComprehensionExpr& comprehension, ComprehensionArg comprehension_arg) override { comprehension_arg_ = comprehension_arg; } void PostVisitExpr(const Expr& expr) override { size_t idx = parent_stack_.back(); parent_stack_.pop_back(); metadata_->postorder.push_back(metadata_->nodes[idx].get()); NavigableAstNodeData& node = NodeDataAt(idx); if (!parent_stack_.empty()) { auto& parent_node_data = NodeDataAt(parent_stack_.back()); parent_node_data.tree_size += node.tree_size; parent_node_data.height = std::max(parent_node_data.height, node.height + 1); } } std::unique_ptr Consume() && { return std::move(metadata_); } private: absl::AnyInvocable()> node_factory_; absl::AnyInvocable node_data_accessor_; std::unique_ptr metadata_; std::vector parent_stack_; absl::optional comprehension_arg_; }; } // namespace NavigableAst NavigableAst::Build(const Expr& expr) { cel::TraversalOptions opts; opts.use_comprehension_callbacks = true; NavigableExprBuilderVisitor 
visitor( []() { return absl::WrapUnique(new NavigableAstNode()); }, [](NavigableAstNode& node) -> NavigableAstNodeData& { return node.data_; }); AstTraverse(expr, visitor, opts); return NavigableAst(std::move(visitor).Consume()); } } // namespace cel ================================================ FILE: common/navigable_ast.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_NAVIGABLE_AST_H_ #define THIRD_PARTY_CEL_CPP_COMMON_NAVIGABLE_AST_H_ #include "common/ast/navigable_ast_internal.h" #include "common/ast/navigable_ast_kinds.h" // IWYU pragma: export #include "common/expr.h" namespace cel { class NavigableAst; class NavigableAstNode; namespace common_internal { struct NativeAstTraits { using ExprType = Expr; using AstType = NavigableAst; using NodeType = NavigableAstNode; }; } // namespace common_internal // Wrapper around a CEL AST node that exposes traversal information. class NavigableAstNode : public common_internal::NavigableAstNodeBase< common_internal::NativeAstTraits> { private: using Base = common_internal::NavigableAstNodeBase; public: // A const Span like type that provides pre-order traversal for a sub tree. // provides .begin() and .end() returning bidirectional iterators to // const AstNode&. using PreorderRange = Base::PreorderRange; // A const Span like type that provides post-order traversal for a sub tree. 
// provides .begin() and .end() returning bidirectional iterators to // const AstNode&. using PostorderRange = Base::PostorderRange; // The parent of this node or nullptr if it is a root. using Base::parent; // The ptr to the backing Expr in the source AST. // // This may dangle if the source AST is mutated or destroyed. using Base::expr; // The index of this node in the parent's children. -1 if this is a root. using Base::child_index; // The type of traversal from parent to this node. using Base::parent_relation; // The type of this node, analogous to Expr::ExprKindCase. using Base::node_kind; // The number of nodes in the tree rooted at this node (including self). using Base::tree_size; // The height of this node in the tree (the number of descendants including // self on the longest path). using Base::height; // The children of this node in their natural order. using Base::children; // Range over the descendants of this node (including self) using preorder // semantics. Each node is visited immediately before all of its descendants. // // example: // for (const cel::NavigableAstNode& node : // ast.Root().DescendantsPreorder()) { // ... // } // // Children are traversed in their natural order: // - call arguments are traversed in order (receiver if present is first) // - list elements are traversed in order // - maps are traversed in order (alternating key, value per entry) // - comprehensions are traversed in the order: range, accu_init, condition, // step, result using Base::DescendantsPreorder; // Range over the descendants of this node (including self) using postorder // semantics. Each node is visited immediately after all of its descendants. using Base::DescendantsPostorder; private: friend class NavigableAst; NavigableAstNode() = default; }; // NavigableExpr provides a view over a CEL AST that allows for generalized // traversal. The traversal structures are eagerly built on construction, // requiring a full traversal of the AST. 
This is intended for use in tools that // might require random access or multiple passes over the AST, amortizing the // cost of building the traversal structures. // // Pointers to AstNodes are owned by this instance and must not outlive it. // // `NavigableAst` and Navigable nodes are independent of the input Expr and may // outlive it, but may contain dangling pointers if the input Expr is modified // or destroyed. class NavigableAst : public common_internal::NavigableAstBase< common_internal::NativeAstTraits> { private: using Base = common_internal::NavigableAstBase; public: static NavigableAst Build(const Expr& expr); // Default constructor creates an empty instance. // // Operations other than equality are undefined on an empty instance. // // This is intended for composed object construction, a new NavigableAst // should be obtained from the Build factory function. NavigableAst() = default; // Move only. NavigableAst(const NavigableAst&) = delete; NavigableAst& operator=(const NavigableAst&) = delete; NavigableAst(NavigableAst&&) = default; NavigableAst& operator=(NavigableAst&&) = default; // Return ptr to the AST node with id if present. Otherwise returns nullptr. // // If ids are non-unique, the first pre-order node encountered with id is // returned. using Base::FindId; // Return ptr to the AST node representing the given Expr node. using Base::FindExpr; // Returns the root of the AST. using Base::Root; // Return whether the source AST used unique IDs for each node. // // This is typically the case, but older versions of the parsers didn't // guarantee uniqueness for nodes generated by some macros and ASTs modified // outside of CEL's parse/type check may not have unique IDs. 
using Base::IdsAreUnique; private: using Base::Base; }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_NAVIGABLE_AST_H_ ================================================ FILE: common/navigable_ast_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/navigable_ast.h" #include #include #include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "common/ast.h" #include "common/expr.h" #include "common/source.h" #include "common/standard_definitions.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/parser.h" namespace cel { namespace { using ::testing::ElementsAre; using ::testing::IsEmpty; using ::testing::Pair; using ::testing::SizeIs; absl::StatusOr> Parse(absl::string_view expr) { static const auto* parser = cel::NewParserBuilder()->Build()->release(); CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(expr)); return parser->Parse(*source); } TEST(NavigableAst, Basic) { Expr const_node; const_node.set_id(1); const_node.mutable_const_expr().set_int_value(42); NavigableAst ast = NavigableAst::Build(const_node); EXPECT_TRUE(ast.IdsAreUnique()); const NavigableAstNode& root = ast.Root(); EXPECT_EQ(root.expr(), &const_node); EXPECT_THAT(root.children(), IsEmpty()); EXPECT_TRUE(root.parent() == nullptr); EXPECT_EQ(root.child_index(), -1); EXPECT_EQ(root.node_kind(), NodeKind::kConstant); 
EXPECT_EQ(root.parent_relation(), ChildKind::kUnspecified); } TEST(NavigableAst, DefaultCtorEmpty) { Expr const_node; const_node.set_id(1); const_node.mutable_const_expr().set_int_value(42); NavigableAst ast = NavigableAst::Build(const_node); EXPECT_EQ(ast, ast); NavigableAst empty; EXPECT_NE(ast, empty); EXPECT_EQ(empty, empty); EXPECT_TRUE(static_cast(ast)); EXPECT_FALSE(static_cast(empty)); NavigableAst moved = std::move(ast); EXPECT_EQ(ast, empty); EXPECT_FALSE(static_cast(ast)); EXPECT_TRUE(static_cast(moved)); } TEST(NavigableAst, FindById) { Expr const_node; const_node.set_id(1); const_node.mutable_const_expr().set_int_value(42); NavigableAst ast = NavigableAst::Build(const_node); const NavigableAstNode& root = ast.Root(); EXPECT_EQ(ast.FindId(const_node.id()), &root); EXPECT_EQ(ast.FindId(-1), nullptr); } MATCHER_P(AstNodeWrapping, expr, "") { const NavigableAstNode* ptr = arg; return ptr != nullptr && ptr->expr() == expr; } TEST(NavigableAst, ToleratesNonUnique) { Expr call_node; call_node.set_id(1); call_node.mutable_call_expr().set_function(cel::StandardFunctions::kNot); Expr* const_node = &call_node.mutable_call_expr().mutable_args().emplace_back(); const_node->mutable_const_expr().set_bool_value(false); const_node->set_id(1); NavigableAst ast = NavigableAst::Build(call_node); const NavigableAstNode& root = ast.Root(); EXPECT_EQ(ast.FindId(1), &root); EXPECT_EQ(ast.FindExpr(&call_node), &root); EXPECT_FALSE(ast.IdsAreUnique()); EXPECT_THAT(ast.FindExpr(const_node), AstNodeWrapping(const_node)); } TEST(NavigableAst, FindByExprPtr) { Expr const_node; const_node.set_id(1); const_node.mutable_const_expr().set_int_value(42); NavigableAst ast = NavigableAst::Build(const_node); const NavigableAstNode& root = ast.Root(); Expr other_expr; EXPECT_EQ(ast.FindExpr(&const_node), &root); EXPECT_EQ(ast.FindExpr(&other_expr), nullptr); } TEST(NavigableAst, Children) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("1 + 2")); NavigableAst ast = 
NavigableAst::Build(parsed_expr->root_expr()); const NavigableAstNode& root = ast.Root(); EXPECT_EQ(root.expr(), &parsed_expr->root_expr()); EXPECT_THAT(root.children(), SizeIs(2)); EXPECT_TRUE(root.parent() == nullptr); EXPECT_EQ(root.child_index(), -1); EXPECT_EQ(root.parent_relation(), ChildKind::kUnspecified); EXPECT_EQ(root.node_kind(), NodeKind::kCall); EXPECT_THAT( root.children(), ElementsAre( AstNodeWrapping(&parsed_expr->root_expr().call_expr().args().at(0)), AstNodeWrapping(&parsed_expr->root_expr().call_expr().args().at(1)))); ASSERT_THAT(root.children(), SizeIs(2)); const auto* child1 = root.children()[0]; EXPECT_EQ(child1->child_index(), 0); EXPECT_EQ(child1->parent(), &root); EXPECT_EQ(child1->parent_relation(), ChildKind::kCallArg); EXPECT_EQ(child1->node_kind(), NodeKind::kConstant); EXPECT_THAT(child1->children(), IsEmpty()); const auto* child2 = root.children()[1]; EXPECT_EQ(child2->child_index(), 1); } TEST(NavigableAst, UnspecifiedExpr) { Expr expr; expr.set_id(1); NavigableAst ast = NavigableAst::Build(expr); const NavigableAstNode& root = ast.Root(); EXPECT_EQ(root.expr(), &expr); EXPECT_THAT(root.children(), SizeIs(0)); EXPECT_TRUE(root.parent() == nullptr); EXPECT_EQ(root.child_index(), -1); EXPECT_EQ(root.node_kind(), NodeKind::kUnspecified); } TEST(NavigableAst, ParentRelationSelect) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("a.b")); NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); const NavigableAstNode& root = ast.Root(); ASSERT_THAT(root.children(), SizeIs(1)); const auto* child = root.children()[0]; EXPECT_EQ(child->parent_relation(), ChildKind::kSelectOperand); EXPECT_EQ(child->node_kind(), NodeKind::kIdent); } TEST(NavigableAst, ParentRelationCallReceiver) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("a.b()")); NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); const NavigableAstNode& root = ast.Root(); ASSERT_THAT(root.children(), SizeIs(1)); const auto* child = root.children()[0]; 
EXPECT_EQ(child->parent_relation(), ChildKind::kCallReceiver); EXPECT_EQ(child->node_kind(), NodeKind::kIdent); } TEST(NavigableAst, ParentRelationCreateStruct) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("com.example.Type{field: '123'}")); NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); const NavigableAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kStruct); ASSERT_THAT(root.children(), SizeIs(1)); const auto* child = root.children()[0]; EXPECT_EQ(child->parent_relation(), ChildKind::kStructValue); EXPECT_EQ(child->node_kind(), NodeKind::kConstant); } TEST(NavigableAst, ParentRelationCreateMap) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("{'a': 123}")); NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); const NavigableAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kMap); ASSERT_THAT(root.children(), SizeIs(2)); const auto* key = root.children()[0]; const auto* value = root.children()[1]; EXPECT_EQ(key->parent_relation(), ChildKind::kMapKey); EXPECT_EQ(key->node_kind(), NodeKind::kConstant); EXPECT_EQ(value->parent_relation(), ChildKind::kMapValue); EXPECT_EQ(value->node_kind(), NodeKind::kConstant); } TEST(NavigableAst, ParentRelationCreateList) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[123]")); NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); const NavigableAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kList); ASSERT_THAT(root.children(), SizeIs(1)); const auto* child = root.children()[0]; EXPECT_EQ(child->parent_relation(), ChildKind::kListElem); EXPECT_EQ(child->node_kind(), NodeKind::kConstant); } TEST(NavigableAst, ParentRelationComprehension) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1].all(x, x < 2)")); NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); const NavigableAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); ASSERT_THAT(root.children(), SizeIs(5)); const auto* range = 
root.children()[0]; const auto* init = root.children()[1]; const auto* condition = root.children()[2]; const auto* step = root.children()[3]; const auto* finish = root.children()[4]; EXPECT_EQ(range->parent_relation(), ChildKind::kComprehensionRange); EXPECT_EQ(init->parent_relation(), ChildKind::kComprehensionInit); EXPECT_EQ(condition->parent_relation(), ChildKind::kComprehensionCondition); EXPECT_EQ(step->parent_relation(), ChildKind::kComprehensionLoopStep); EXPECT_EQ(finish->parent_relation(), ChildKind::kComprensionResult); } TEST(NavigableAst, DescendantsPostorder) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("1 + (x * 3)")); NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); const NavigableAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kCall); std::vector constants; std::vector node_kinds; for (const NavigableAstNode& node : root.DescendantsPostorder()) { if (node.node_kind() == NodeKind::kConstant) { constants.push_back(node.expr()->const_expr().int64_value()); } node_kinds.push_back(node.node_kind()); } EXPECT_THAT(node_kinds, ElementsAre(NodeKind::kConstant, NodeKind::kIdent, NodeKind::kConstant, NodeKind::kCall, NodeKind::kCall)); EXPECT_THAT(constants, ElementsAre(1, 3)); } TEST(NavigableAst, DescendantsPreorder) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("1 + (x * 3)")); NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); const NavigableAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kCall); std::vector constants; std::vector node_kinds; for (const NavigableAstNode& node : root.DescendantsPreorder()) { if (node.node_kind() == NodeKind::kConstant) { constants.push_back(node.expr()->const_expr().int64_value()); } node_kinds.push_back(node.node_kind()); } EXPECT_THAT(node_kinds, ElementsAre(NodeKind::kCall, NodeKind::kConstant, NodeKind::kCall, NodeKind::kIdent, NodeKind::kConstant)); EXPECT_THAT(constants, ElementsAre(1, 3)); } TEST(NavigableAst, DescendantsPreorderComprehension) { 
ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1, 2, 3].map(x, x + 1)")); NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); const NavigableAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); std::vector> node_kinds; for (const NavigableAstNode& node : root.DescendantsPreorder()) { node_kinds.push_back( std::make_pair(node.node_kind(), node.parent_relation())); } EXPECT_THAT( node_kinds, ElementsAre(Pair(NodeKind::kComprehension, ChildKind::kUnspecified), Pair(NodeKind::kList, ChildKind::kComprehensionRange), Pair(NodeKind::kConstant, ChildKind::kListElem), Pair(NodeKind::kConstant, ChildKind::kListElem), Pair(NodeKind::kConstant, ChildKind::kListElem), Pair(NodeKind::kList, ChildKind::kComprehensionInit), Pair(NodeKind::kConstant, ChildKind::kComprehensionCondition), Pair(NodeKind::kCall, ChildKind::kComprehensionLoopStep), Pair(NodeKind::kIdent, ChildKind::kCallArg), Pair(NodeKind::kList, ChildKind::kCallArg), Pair(NodeKind::kCall, ChildKind::kListElem), Pair(NodeKind::kIdent, ChildKind::kCallArg), Pair(NodeKind::kConstant, ChildKind::kCallArg), Pair(NodeKind::kIdent, ChildKind::kComprensionResult))); } TEST(NavigableAst, TreeSize) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1, 2, 3].map(x, x + 1)")); NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); const NavigableAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); std::vector> node_kinds; EXPECT_EQ(root.tree_size(), 14); auto it = root.DescendantsPostorder().begin(); EXPECT_EQ(it->tree_size(), 1); } TEST(NavigableAst, Height) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1, 2, 3].map(x, x + 1)")); NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); const NavigableAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); std::vector> node_kinds; EXPECT_EQ(root.height(), 5); auto it = root.DescendantsPostorder().begin(); EXPECT_EQ(it->height(), 1); } TEST(NavigableAst, 
DescendantsPreorderCreateMap) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("{'key1': 1, 'key2': 2}")); NavigableAst ast = NavigableAst::Build(parsed_expr->root_expr()); const NavigableAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kMap); std::vector> node_kinds; for (const NavigableAstNode& node : root.DescendantsPreorder()) { node_kinds.push_back( std::make_pair(node.node_kind(), node.parent_relation())); } EXPECT_THAT(node_kinds, ElementsAre(Pair(NodeKind::kMap, ChildKind::kUnspecified), Pair(NodeKind::kConstant, ChildKind::kMapKey), Pair(NodeKind::kConstant, ChildKind::kMapValue), Pair(NodeKind::kConstant, ChildKind::kMapKey), Pair(NodeKind::kConstant, ChildKind::kMapValue))); } } // namespace } // namespace cel ================================================ FILE: common/operators.cc ================================================ // Copyright 2019 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/operators.h" #include #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #undef IN namespace google::api::expr::common { namespace { // These functions provide access to reverse mappings for operators. // Functions generally map from text expression to Expr representation, // e.g., from "&&" to "_&&_". Reverse operators provides a mapping from // Expr to textual mapping, e.g., from "_&&_" to "&&". 
const absl::flat_hash_map& UnaryOperators() { static auto* unaries_map = new absl::flat_hash_map{ {CelOperator::NEGATE, "-"}, {CelOperator::LOGICAL_NOT, "!"}}; return *unaries_map; } const absl::flat_hash_map& BinaryOperators() { static auto* binops_map = new absl::flat_hash_map{ {CelOperator::LOGICAL_OR, "||"}, {CelOperator::LOGICAL_AND, "&&"}, {CelOperator::LESS_EQUALS, "<="}, {CelOperator::LESS, "<"}, {CelOperator::GREATER_EQUALS, ">="}, {CelOperator::GREATER, ">"}, {CelOperator::EQUALS, "=="}, {CelOperator::NOT_EQUALS, "!="}, {CelOperator::IN_DEPRECATED, "in"}, {CelOperator::IN, "in"}, {CelOperator::ADD, "+"}, {CelOperator::SUBTRACT, "-"}, {CelOperator::MULTIPLY, "*"}, {CelOperator::DIVIDE, "/"}, {CelOperator::MODULO, "%"}}; return *binops_map; } const absl::flat_hash_map& ReverseOperators() { static auto* operators_map = new absl::flat_hash_map{ {"+", CelOperator::ADD}, {"-", CelOperator::SUBTRACT}, {"*", CelOperator::MULTIPLY}, {"/", CelOperator::DIVIDE}, {"%", CelOperator::MODULO}, {"==", CelOperator::EQUALS}, {"!=", CelOperator::NOT_EQUALS}, {">", CelOperator::GREATER}, {">=", CelOperator::GREATER_EQUALS}, {"<", CelOperator::LESS}, {"<=", CelOperator::LESS_EQUALS}, {"&&", CelOperator::LOGICAL_AND}, {"!", CelOperator::LOGICAL_NOT}, {"||", CelOperator::LOGICAL_OR}, {"in", CelOperator::IN}, }; return *operators_map; } const absl::flat_hash_map& Operators() { static auto* operators_map = new absl::flat_hash_map{ {CelOperator::ADD, "+"}, {CelOperator::SUBTRACT, "-"}, {CelOperator::MULTIPLY, "*"}, {CelOperator::DIVIDE, "/"}, {CelOperator::MODULO, "%"}, {CelOperator::EQUALS, "=="}, {CelOperator::NOT_EQUALS, "!="}, {CelOperator::GREATER, ">"}, {CelOperator::GREATER_EQUALS, ">="}, {CelOperator::LESS, "<"}, {CelOperator::LESS_EQUALS, "<="}, {CelOperator::LOGICAL_AND, "&&"}, {CelOperator::LOGICAL_NOT, "!"}, {CelOperator::LOGICAL_OR, "||"}, {CelOperator::IN, "in"}, {CelOperator::IN_DEPRECATED, "in"}, {CelOperator::NEGATE, "-"}}; return *operators_map; } // precedence 
of the operator, where the higher value means higher. const absl::flat_hash_map& Precedences() { static auto* precedence_map = new absl::flat_hash_map{ {CelOperator::CONDITIONAL, 8}, {CelOperator::LOGICAL_OR, 7}, {CelOperator::LOGICAL_AND, 6}, {CelOperator::EQUALS, 5}, {CelOperator::GREATER, 5}, {CelOperator::GREATER_EQUALS, 5}, {CelOperator::IN, 5}, {CelOperator::LESS, 5}, {CelOperator::LESS_EQUALS, 5}, {CelOperator::NOT_EQUALS, 5}, {CelOperator::IN_DEPRECATED, 5}, {CelOperator::ADD, 4}, {CelOperator::SUBTRACT, 4}, {CelOperator::DIVIDE, 3}, {CelOperator::MODULO, 3}, {CelOperator::MULTIPLY, 3}, {CelOperator::LOGICAL_NOT, 2}, {CelOperator::NEGATE, 2}, {CelOperator::INDEX, 1}}; return *precedence_map; } } // namespace const char* CelOperator::CONDITIONAL = "_?_:_"; const char* CelOperator::LOGICAL_AND = "_&&_"; const char* CelOperator::LOGICAL_OR = "_||_"; const char* CelOperator::LOGICAL_NOT = "!_"; const char* CelOperator::IN_DEPRECATED = "_in_"; const char* CelOperator::EQUALS = "_==_"; const char* CelOperator::NOT_EQUALS = "_!=_"; const char* CelOperator::LESS = "_<_"; const char* CelOperator::LESS_EQUALS = "_<=_"; const char* CelOperator::GREATER = "_>_"; const char* CelOperator::GREATER_EQUALS = "_>=_"; const char* CelOperator::ADD = "_+_"; const char* CelOperator::SUBTRACT = "_-_"; const char* CelOperator::MULTIPLY = "_*_"; const char* CelOperator::DIVIDE = "_/_"; const char* CelOperator::MODULO = "_%_"; const char* CelOperator::NEGATE = "-_"; const char* CelOperator::INDEX = "_[_]"; const char* CelOperator::HAS = "has"; const char* CelOperator::ALL = "all"; const char* CelOperator::EXISTS = "exists"; const char* CelOperator::EXISTS_ONE = "exists_one"; const char* CelOperator::MAP = "map"; const char* CelOperator::FILTER = "filter"; const char* CelOperator::NOT_STRICTLY_FALSE = "@not_strictly_false"; const char* CelOperator::IN = "@in"; const absl::string_view CelOperator::OPT_INDEX = "_[?_]"; const absl::string_view CelOperator::OPT_SELECT = "_?._"; int 
LookupPrecedence(absl::string_view op) { const auto& precs = Precedences(); auto p = precs.find(op); if (p != precs.end()) { return p->second; } return 0; } absl::optional LookupUnaryOperator(absl::string_view op) { const auto& unary_ops = UnaryOperators(); auto o = unary_ops.find(op); if (o == unary_ops.end()) { return absl::optional(); } return o->second; } absl::optional LookupBinaryOperator(absl::string_view op) { const auto& bin_ops = BinaryOperators(); auto o = bin_ops.find(op); if (o == bin_ops.end()) { return absl::optional(); } return o->second; } absl::optional LookupOperator(absl::string_view op) { const auto& ops = Operators(); auto o = ops.find(op); if (o == ops.end()) { return absl::optional(); } return o->second; } absl::optional ReverseLookupOperator(absl::string_view op) { const auto& rev_ops = ReverseOperators(); auto o = rev_ops.find(op); if (o == rev_ops.end()) { return absl::optional(); } return o->second; } bool IsOperatorSamePrecedence(absl::string_view op, const cel::expr::Expr& expr) { if (!expr.has_call_expr()) { return false; } return LookupPrecedence(op) == LookupPrecedence(expr.call_expr().function()); } bool IsOperatorLowerPrecedence(absl::string_view op, const cel::expr::Expr& expr) { if (!expr.has_call_expr()) { return false; } return LookupPrecedence(op) < LookupPrecedence(expr.call_expr().function()); } bool IsOperatorLeftRecursive(absl::string_view op) { return op != CelOperator::LOGICAL_AND && op != CelOperator::LOGICAL_OR; } } // namespace google::api::expr::common ================================================ FILE: common/operators.h ================================================ // Copyright 2019 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_OPERATORS_H_
#define THIRD_PARTY_CEL_CPP_COMMON_OPERATORS_H_

#include
#include

#include "cel/expr/syntax.pb.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"

namespace google::api::expr::common {

// Operator function names: the names that identify built-in operators,
// macros and named operators in expressions. The string values are defined
// in operators.cc (e.g. LOGICAL_AND is "_&&_").
struct CelOperator {
  static const char* CONDITIONAL;
  static const char* LOGICAL_AND;
  static const char* LOGICAL_OR;
  static const char* LOGICAL_NOT;
  static const char* IN_DEPRECATED;
  static const char* EQUALS;
  static const char* NOT_EQUALS;
  static const char* LESS;
  static const char* LESS_EQUALS;
  static const char* GREATER;
  static const char* GREATER_EQUALS;
  static const char* ADD;
  static const char* SUBTRACT;
  static const char* MULTIPLY;
  static const char* DIVIDE;
  static const char* MODULO;
  static const char* NEGATE;
  static const char* INDEX;

  // Macros
  static const char* HAS;
  static const char* ALL;
  static const char* EXISTS;
  static const char* EXISTS_ONE;
  static const char* MAP;
  static const char* FILTER;

  // Named operators; these must not be valid identifiers (the values in
  // operators.cc start with '@').
  static const char* NOT_STRICTLY_FALSE;
  // Save and undefine any macro named `IN` so the member below can use that
  // name; the macro (if one existed) is restored right after.
#pragma push_macro("IN")
#undef IN
  static const char* IN;
#pragma pop_macro("IN")

  // Optional index / optional select operators.
  static const absl::string_view OPT_INDEX;
  static const absl::string_view OPT_SELECT;
};

// Returns the precedence value of `op`, or 0 when `op` is not a built-in
// operator (i.e. a custom function).
// NOTE(review): the table in operators.cc assigns INDEX 1 and CONDITIONAL 8,
// so larger values go to operators that conventionally bind more loosely;
// confirm this against the "higher value means higher precedence" reading.
int LookupPrecedence(absl::string_view op);

// Maps a unary operator function name to its source text (e.g. "!_" -> "!"),
// or returns nullopt if `op` is not a unary operator.
absl::optional LookupUnaryOperator(absl::string_view op);
// Maps a binary operator function name to its source text (e.g. "_+_" ->
// "+"), or returns nullopt if `op` is not a binary operator.
absl::optional LookupBinaryOperator(absl::string_view op);
// Maps any operator function name (unary or binary) to its source text, or
// returns nullopt if `op` is not an operator.
absl::optional LookupOperator(absl::string_view op);
// Inverse mapping: operator source text to function name (e.g. "&&" ->
// "_&&_"), or nullopt if the text is not a known operator.
absl::optional ReverseLookupOperator(absl::string_view op);

// Returns true if `expr` is a call and LookupPrecedence(op) is numerically
// smaller than the precedence value of the called function; returns false
// when `expr` is not a call.
bool IsOperatorLowerPrecedence(absl::string_view op,
                               const cel::expr::Expr& expr);

// Returns true if `expr` is a call and LookupPrecedence(op) equals the
// precedence value of the called function; returns false when `expr` is not
// a call.
bool IsOperatorSamePrecedence(absl::string_view op,
                              const cel::expr::Expr& expr);

// Returns true if the operator is left recursive, i.e. it is neither
// LOGICAL_AND (&&) nor LOGICAL_OR (||).
bool IsOperatorLeftRecursive(absl::string_view op);

}  // namespace google::api::expr::common

#endif  // THIRD_PARTY_CEL_CPP_COMMON_OPERATORS_H_
================================================
FILE: common/optional_ref.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_OPTIONAL_REF_H_
#define THIRD_PARTY_CEL_CPP_OPTIONAL_REF_H_

#include
#include

#include "absl/base/attributes.h"
#include "absl/base/macros.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/types/optional.h"
#include "absl/utility/utility.h"

namespace cel {

// `optional_ref` looks and feels like `absl::optional`, but instead of
// owning the underlying value, it retains a reference to the value it accepts
// in its constructor.
//
// Internally it stores a single pointer (`value_`): an empty `optional_ref`
// holds nullptr. Copies refer to the same object. Copy assignment is deleted
// and `value_` is const, so an instance cannot be re-seated after
// construction. Constructor parameters are marked
// ABSL_ATTRIBUTE_LIFETIME_BOUND: the referenced object must outlive the
// `optional_ref`.
template class optional_ref final {
 public:
  static_assert(!std::is_reference_v, "T must not be a reference.");
  static_assert(!std::is_same_v>, "optional_ref is not allowed.");
  static_assert(!std::is_same_v>, "optional_ref is not allowed.");

  using value_type = T;

  // Constructs an empty reference.
  optional_ref() = default;

  // Implicit conversion from `absl::nullopt`; produces an empty reference.
  // NOLINTNEXTLINE(google-explicit-constructor)
  constexpr optional_ref(absl::nullopt_t) : optional_ref() {}

  // Refers to `value`.
  // NOLINTNEXTLINE(google-explicit-constructor)
  constexpr optional_ref(T& value ABSL_ATTRIBUTE_LIFETIME_BOUND)
      : value_(std::addressof(value)) {}

  // Refers to the value held by a const `absl::optional`, or is empty when
  // that optional is empty. Only enabled for the `U` admitted by the
  // enable_if constraint below.
  template <
      typename U,
      typename = std::enable_if_t, std::is_same, std::decay_t>>>>
  // NOLINTNEXTLINE(google-explicit-constructor)
  constexpr optional_ref(
      const absl::optional& value ABSL_ATTRIBUTE_LIFETIME_BOUND)
      : value_(value.has_value() ? std::addressof(*value) : nullptr) {}

  // Same as above for a mutable `absl::optional`.
  template , std::decay_t>>>
  // NOLINTNEXTLINE(google-explicit-constructor)
  constexpr optional_ref(absl::optional& value ABSL_ATTRIBUTE_LIFETIME_BOUND)
      : value_(value.has_value() ? std::addressof(*value) : nullptr) {}

  // Converting copy from an `optional_ref` of another type whose pointer
  // converts to ours (see the std::is_convertible constraint); the result
  // refers to the same object as `other`.
  template <
      typename U,
      typename = std::enable_if_t>, std::is_convertible, std::add_pointer_t>>>>
  // NOLINTNEXTLINE(google-explicit-constructor)
  constexpr optional_ref(const optional_ref& other) : value_(other.value_) {}

  // Copyable but not assignable: the stored pointer is const.
  optional_ref(const optional_ref&) = default;
  optional_ref& operator=(const optional_ref&) = delete;

  constexpr bool has_value() const { return value_ != nullptr; }

  constexpr explicit operator bool() const { return has_value(); }

  // Checked access: when empty, runs `absl::optional`'s `value()` failure
  // path instead of dereferencing a null pointer.
  constexpr T& value() const {
    return ABSL_PREDICT_TRUE(has_value())
               ? *value_
               // Replicate the same error logic as in `absl::optional`'s
               // `value()`. It either throws an exception or aborts the
               // program. We intentionally ignore the return value of
               // the constructed optional's value as we only need to run
               // the code for error checking.
               : ((void)absl::optional().value(), *value_);
  }

  // Unchecked access; only ABSL_ASSERT guards against an empty reference.
  constexpr T& operator*() const {
    ABSL_ASSERT(has_value());
    return *value_;
  }

  // Unchecked member access; only ABSL_ASSERT guards against an empty
  // reference.
  constexpr T* absl_nonnull operator->() const {
    ABSL_ASSERT(has_value());
    return value_;
  }

 private:
  // Lets the converting constructor read `other.value_` from other
  // instantiations.
  template
  friend class optional_ref;

  T* const value_ = nullptr;
};

// Deduction guides: `optional_ref(x)` deduces the referenced type from a plain
// value or from an `absl::optional`.
template
optional_ref(const T&) -> optional_ref;
template
optional_ref(T&) -> optional_ref;
template
optional_ref(const absl::optional&) -> optional_ref;
template
optional_ref(absl::optional&) -> optional_ref;

// Comparisons against `absl::nullopt` test for emptiness.
template
constexpr bool operator==(const optional_ref& lhs, absl::nullopt_t) {
  return !lhs.has_value();
}

template
constexpr bool operator==(absl::nullopt_t, const optional_ref& rhs) {
  return !rhs.has_value();
}

template
constexpr bool operator!=(const optional_ref& lhs, absl::nullopt_t) {
  return !operator==(lhs, absl::nullopt);
}

template
constexpr bool operator!=(absl::nullopt_t, const optional_ref& rhs) {
  return !operator==(absl::nullopt, rhs);
}

namespace common_internal {

// Copies the referenced value into an owning `absl::optional`, or returns
// nullopt when `ref` is empty.
template
absl::optional> AsOptional(optional_ref ref) {
  if (ref) {
    return *ref;
  }
  return absl::nullopt;
}

// Pass-through overload for values that are already an `absl::optional`.
template
absl::optional AsOptional(absl::optional opt) {
  return opt;
}

}  // namespace common_internal

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_OPTIONAL_REF_H_
================================================
FILE: common/reference.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "common/reference.h"

#include "absl/base/no_destructor.h"

namespace cel {

// Returns a default-constructed instance created on first use; wrapping it in
// absl::NoDestructor means it is never destroyed at program exit.
const VariableReference& VariableReference::default_instance() {
  static const absl::NoDestructor instance;
  return *instance;
}

// Same pattern as VariableReference::default_instance().
const FunctionReference& FunctionReference::default_instance() {
  static const absl::NoDestructor instance;
  return *instance;
}

}  // namespace cel
================================================
FILE: common/reference.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_REFERENCE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_REFERENCE_H_ #include #include #include #include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "common/constant.h" namespace cel { class Reference; class VariableReference; class FunctionReference; using ReferenceKind = absl::variant; // `VariableReference` is a resolved reference to a `VariableDecl`. class VariableReference final { public: bool has_value() const { return value_.has_value(); } void set_value(Constant value) { value_ = std::move(value); } const Constant& value() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return value_; } Constant& mutable_value() ABSL_ATTRIBUTE_LIFETIME_BOUND { return value_; } ABSL_MUST_USE_RESULT Constant release_value() { using std::swap; Constant value; swap(mutable_value(), value); return value; } friend void swap(VariableReference& lhs, VariableReference& rhs) noexcept { using std::swap; swap(lhs.value_, rhs.value_); } private: friend class Reference; static const VariableReference& default_instance(); Constant value_; }; inline bool operator==(const VariableReference& lhs, const VariableReference& rhs) { return lhs.value() == rhs.value(); } inline bool operator!=(const VariableReference& lhs, const VariableReference& rhs) { return !operator==(lhs, rhs); } // `FunctionReference` is a resolved reference to a `FunctionDecl`. 
class FunctionReference final { public: const std::vector& overloads() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return overloads_; } void set_overloads(std::vector overloads) { mutable_overloads() = std::move(overloads); } std::vector& mutable_overloads() ABSL_ATTRIBUTE_LIFETIME_BOUND { return overloads_; } ABSL_MUST_USE_RESULT std::vector release_overloads() { std::vector overloads; overloads.swap(mutable_overloads()); return overloads; } friend void swap(FunctionReference& lhs, FunctionReference& rhs) noexcept { using std::swap; swap(lhs.overloads_, rhs.overloads_); } private: friend class Reference; static const FunctionReference& default_instance(); std::vector overloads_; }; inline bool operator==(const FunctionReference& lhs, const FunctionReference& rhs) { return absl::c_equal(lhs.overloads(), rhs.overloads()); } inline bool operator!=(const FunctionReference& lhs, const FunctionReference& rhs) { return !operator==(lhs, rhs); } // `Reference` is a resolved reference to a `VariableDecl` or `FunctionDecl`. By // default `Reference` is a `VariableReference`. 
class Reference final { public: const std::string& name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return name_; } void set_name(std::string name) { name_ = std::move(name); } void set_name(absl::string_view name) { name_.assign(name.data(), name.size()); } void set_name(const char* name) { set_name(absl::NullSafeStringView(name)); } ABSL_MUST_USE_RESULT std::string release_name() { std::string name; name.swap(name_); return name; } void set_kind(ReferenceKind kind) { kind_ = std::move(kind); } const ReferenceKind& kind() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return kind_; } ReferenceKind& mutable_kind() ABSL_ATTRIBUTE_LIFETIME_BOUND { return kind_; } ABSL_MUST_USE_RESULT ReferenceKind release_kind() { using std::swap; ReferenceKind kind; swap(kind, kind_); return kind; } ABSL_MUST_USE_RESULT bool has_variable() const { return absl::holds_alternative(kind()); } ABSL_MUST_USE_RESULT const VariableReference& variable() const ABSL_ATTRIBUTE_LIFETIME_BOUND { if (const auto* alt = absl::get_if(&kind()); alt) { return *alt; } return VariableReference::default_instance(); } void set_variable(VariableReference variable) { mutable_variable() = std::move(variable); } VariableReference& mutable_variable() ABSL_ATTRIBUTE_LIFETIME_BOUND { if (!has_variable()) { mutable_kind().emplace(); } return absl::get(mutable_kind()); } ABSL_MUST_USE_RESULT VariableReference release_variable() { VariableReference variable_reference; if (auto* alt = absl::get_if(&mutable_kind()); alt) { variable_reference = std::move(*alt); } mutable_kind().emplace(); return variable_reference; } ABSL_MUST_USE_RESULT bool has_function() const { return absl::holds_alternative(kind()); } ABSL_MUST_USE_RESULT const FunctionReference& function() const ABSL_ATTRIBUTE_LIFETIME_BOUND { if (const auto* alt = absl::get_if(&kind()); alt) { return *alt; } return FunctionReference::default_instance(); } void set_function(FunctionReference function) { mutable_function() = std::move(function); } FunctionReference& 
mutable_function() ABSL_ATTRIBUTE_LIFETIME_BOUND { if (!has_function()) { mutable_kind().emplace(); } return absl::get(mutable_kind()); } ABSL_MUST_USE_RESULT FunctionReference release_function() { FunctionReference function_reference; if (auto* alt = absl::get_if(&mutable_kind()); alt) { function_reference = std::move(*alt); } mutable_kind().emplace(); return function_reference; } friend void swap(Reference& lhs, Reference& rhs) noexcept { using std::swap; swap(lhs.name_, rhs.name_); swap(lhs.kind_, rhs.kind_); } private: std::string name_; ReferenceKind kind_; }; inline bool operator==(const Reference& lhs, const Reference& rhs) { return lhs.name() == rhs.name() && lhs.kind() == rhs.kind(); } inline bool operator!=(const Reference& lhs, const Reference& rhs) { return !operator==(lhs, rhs); } inline Reference MakeVariableReference(std::string name) { Reference reference; reference.set_name(std::move(name)); reference.mutable_kind().emplace(); return reference; } inline Reference MakeConstantVariableReference(std::string name, Constant constant) { Reference reference; reference.set_name(std::move(name)); reference.mutable_kind().emplace().set_value( std::move(constant)); return reference; } inline Reference MakeFunctionReference(std::string name, std::vector overloads) { Reference reference; reference.set_name(std::move(name)); reference.mutable_kind().emplace().set_overloads( std::move(overloads)); return reference; } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_REFERENCE_H_ ================================================ FILE: common/reference_count.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_REFERENCE_COUNT_H_ #define THIRD_PARTY_CEL_CPP_COMMON_REFERENCE_COUNT_H_ #include "common/internal/reference_count.h" namespace cel { using ReferenceCount = common_internal::ReferenceCount; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_REFERENCE_COUNT_H_ ================================================ FILE: common/reference_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/reference.h"

#include <cstdint>
#include <string>
#include <vector>

#include "common/constant.h"
#include "internal/testing.h"

namespace cel {
namespace {

using ::testing::_;
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::IsEmpty;
using ::testing::VariantWith;

TEST(VariableReference, Value) {
  VariableReference variable_reference;
  EXPECT_FALSE(variable_reference.has_value());
  EXPECT_EQ(variable_reference.value(), Constant{});
  Constant value;
  value.set_bool_value(true);
  variable_reference.set_value(value);
  EXPECT_TRUE(variable_reference.has_value());
  EXPECT_EQ(variable_reference.value(), value);
  EXPECT_EQ(variable_reference.release_value(), value);
  EXPECT_EQ(variable_reference.value(), Constant{});
}

TEST(VariableReference, Equality) {
  VariableReference variable_reference;
  EXPECT_EQ(variable_reference, VariableReference{});
  variable_reference.mutable_value().set_bool_value(true);
  EXPECT_NE(variable_reference, VariableReference{});
}

TEST(FunctionReference, Overloads) {
  FunctionReference function_reference;
  EXPECT_THAT(function_reference.overloads(), IsEmpty());
  function_reference.mutable_overloads().reserve(2);
  function_reference.mutable_overloads().push_back("foo");
  function_reference.mutable_overloads().push_back("bar");
  EXPECT_THAT(function_reference.release_overloads(),
              ElementsAre("foo", "bar"));
  EXPECT_THAT(function_reference.overloads(), IsEmpty());
}

TEST(FunctionReference, Equality) {
  FunctionReference function_reference;
  EXPECT_EQ(function_reference, FunctionReference{});
  function_reference.mutable_overloads().push_back("foo");
  EXPECT_NE(function_reference, FunctionReference{});
}

TEST(Reference, Name) {
  Reference reference;
  EXPECT_THAT(reference.name(), IsEmpty());
  reference.set_name("foo");
  EXPECT_EQ(reference.name(), "foo");
  EXPECT_EQ(reference.release_name(), "foo");
  EXPECT_THAT(reference.name(), IsEmpty());
}

// `set_name(const char*)` must treat a null pointer as the empty string.
TEST(Reference, NullName) {
  Reference reference;
  reference.set_name("foo");
  reference.set_name(static_cast<const char*>(nullptr));
  EXPECT_THAT(reference.name(), IsEmpty());
}

TEST(Reference, Variable) {
  Reference reference;
  EXPECT_THAT(reference.kind(), VariantWith<VariableReference>(_));
  EXPECT_TRUE(reference.has_variable());
  EXPECT_THAT(reference.release_variable(), Eq(VariableReference{}));
  EXPECT_TRUE(reference.has_variable());
}

TEST(Reference, Function) {
  Reference reference;
  EXPECT_FALSE(reference.has_function());
  EXPECT_THAT(reference.function(), Eq(FunctionReference{}));
  reference.mutable_function();
  EXPECT_TRUE(reference.has_function());
  EXPECT_THAT(reference.variable(), Eq(VariableReference{}));
  EXPECT_THAT(reference.kind(), VariantWith<FunctionReference>(_));
  EXPECT_THAT(reference.release_function(), Eq(FunctionReference{}));
  EXPECT_FALSE(reference.has_function());
}

// `set_function` / `set_variable` switch the held kind.
TEST(Reference, SetKind) {
  Reference reference;
  FunctionReference function_reference;
  function_reference.mutable_overloads().push_back("foo");
  reference.set_function(function_reference);
  EXPECT_TRUE(reference.has_function());
  EXPECT_THAT(reference.function().overloads(), ElementsAre("foo"));

  reference.set_variable(VariableReference{});
  EXPECT_TRUE(reference.has_variable());
  EXPECT_FALSE(reference.has_function());
}

// The `Make*` helpers populate both the name and the kind payload.
TEST(Reference, MakeHelpers) {
  Reference variable =
      MakeConstantVariableReference("foo", Constant(int64_t{1}));
  EXPECT_EQ(variable.name(), "foo");
  ASSERT_TRUE(variable.has_variable());
  EXPECT_EQ(variable.variable().value(), Constant(int64_t{1}));

  Reference function =
      MakeFunctionReference("bar", std::vector<std::string>{"baz"});
  EXPECT_EQ(function.name(), "bar");
  ASSERT_TRUE(function.has_function());
  FunctionReference released = function.release_function();
  EXPECT_THAT(released.overloads(), ElementsAre("baz"));
  // Releasing the function leaves an empty variable reference behind.
  EXPECT_TRUE(function.has_variable());
}

TEST(Reference, Equality) {
  EXPECT_EQ(MakeVariableReference("foo"), MakeVariableReference("foo"));
  EXPECT_NE(MakeVariableReference("foo"),
            MakeConstantVariableReference("foo", Constant(int64_t{1})));
  EXPECT_EQ(
      MakeFunctionReference("foo", std::vector<std::string>{"bar", "baz"}),
      MakeFunctionReference("foo", std::vector<std::string>{"bar", "baz"}));
  EXPECT_NE(
      MakeFunctionReference("foo", std::vector<std::string>{"bar", "baz"}),
      MakeFunctionReference("foo", std::vector<std::string>{"bar"}));
}

}  // namespace
}  // namespace cel
================================================
FILE: common/source.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/source.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/overload.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "internal/unicode.h"
#include "internal/utf8.h"

namespace cel {

// Number of code points in the view (each span element is one code point).
SourcePosition SourceContentView::size() const {
  return static_cast<SourcePosition>(absl::visit(
      absl::Overload(
          [](absl::Span<const char> view) { return view.size(); },
          [](absl::Span<const uint8_t> view) { return view.size(); },
          [](absl::Span<const char16_t> view) { return view.size(); },
          [](absl::Span<const char32_t> view) { return view.size(); }),
      view_));
}

bool SourceContentView::empty() const {
  return absl::visit(
      absl::Overload(
          [](absl::Span<const char> view) { return view.empty(); },
          [](absl::Span<const uint8_t> view) { return view.empty(); },
          [](absl::Span<const char16_t> view) { return view.empty(); },
          [](absl::Span<const char32_t> view) { return view.empty(); }),
      view_);
}

// Returns the code point at `position`; `position` must be in [0, size()).
char32_t SourceContentView::at(SourcePosition position) const {
  ABSL_DCHECK_GE(position, 0);
  ABSL_DCHECK_LT(position, size());
  return absl::visit(
      absl::Overload(
          [position =
               static_cast<size_t>(position)](absl::Span<const char> view) {
            // Go through `uint8_t` so that a signed `char` is not
            // sign-extended.
            return static_cast<char32_t>(static_cast<uint8_t>(view[position]));
          },
          [position =
               static_cast<size_t>(position)](absl::Span<const uint8_t> view) {
            return static_cast<char32_t>(view[position]);
          },
          [position =
               static_cast<size_t>(position)](absl::Span<const char16_t> view) {
            return static_cast<char32_t>(view[position]);
          },
          [position =
               static_cast<size_t>(position)](absl::Span<const char32_t> view) {
            return static_cast<char32_t>(view[position]);
          }),
      view_);
}

// Returns the code points in [begin, end) encoded as UTF-8.
std::string SourceContentView::ToString(SourcePosition begin,
                                        SourcePosition end) const {
  ABSL_DCHECK_GE(begin, 0);
  ABSL_DCHECK_LE(end, size());
  ABSL_DCHECK_LE(begin, end);
  return absl::visit(
      absl::Overload(
          [begin = static_cast<size_t>(begin),
           end = static_cast<size_t>(end)](absl::Span<const char> view) {
            view = view.subspan(begin, end - begin);
            return std::string(view.data(), view.size());
          },
          [begin = static_cast<size_t>(begin),
           end = static_cast<size_t>(end)](absl::Span<const uint8_t> view) {
            view = view.subspan(begin, end - begin);
            std::string result;
            // A code point <= U+00FF needs at most 2 UTF-8 bytes.
            result.reserve(view.size() * 2);
            for (const auto& code_point : view) {
              internal::Utf8Encode(result, code_point);
            }
            result.shrink_to_fit();
            return result;
          },
          [begin = static_cast<size_t>(begin),
           end = static_cast<size_t>(end)](absl::Span<const char16_t> view) {
            view = view.subspan(begin, end - begin);
            std::string result;
            // A code point <= U+FFFF needs at most 3 UTF-8 bytes.
            result.reserve(view.size() * 3);
            for (const auto& code_point : view) {
              internal::Utf8Encode(result, code_point);
            }
            result.shrink_to_fit();
            return result;
          },
          [begin = static_cast<size_t>(begin),
           end = static_cast<size_t>(end)](absl::Span<const char32_t> view) {
            view = view.subspan(begin, end - begin);
            std::string result;
            // Any code point needs at most 4 UTF-8 bytes.
            result.reserve(view.size() * 4);
            for (const auto& code_point : view) {
              internal::Utf8Encode(result, code_point);
            }
            result.shrink_to_fit();
            return result;
          }),
      view_);
}

// Appends the whole view to `dest`, encoded as UTF-8.
void SourceContentView::AppendToString(std::string& dest) const {
  absl::visit(absl::Overload(
                  [&dest](absl::Span<const char> view) {
                    dest.append(view.data(), view.size());
                  },
                  [&dest](absl::Span<const uint8_t> view) {
                    for (const auto& code_point : view) {
                      internal::Utf8Encode(dest, code_point);
                    }
                  },
                  [&dest](absl::Span<const char16_t> view) {
                    for (const auto& code_point : view) {
                      internal::Utf8Encode(dest, code_point);
                    }
                  },
                  [&dest](absl::Span<const char32_t> view) {
                    for (const auto& code_point : view) {
                      internal::Utf8Encode(dest, code_point);
                    }
                  }),
              view_);
}

namespace common_internal {

// Common base of the concrete sources below. It owns the description and the
// line offsets; subclasses own the text and implement `content()`.
class SourceImpl : public Source {
 public:
  SourceImpl(std::string description,
             absl::InlinedVector<SourcePosition, 1> line_offsets)
      : description_(std::move(description)),
        line_offsets_(std::move(line_offsets)) {}

  absl::string_view description() const final { return description_; }

  absl::Span<const SourcePosition> line_offsets() const final {
    return absl::MakeConstSpan(line_offsets_);
  }

 private:
  const std::string description_;
  const absl::InlinedVector<SourcePosition, 1> line_offsets_;
};

namespace {

// Text whose code points are all <= U+007F, stored one byte per code point.
class AsciiSource final : public SourceImpl {
 public:
  AsciiSource(std::string description,
              absl::InlinedVector<SourcePosition, 1> line_offsets,
              std::vector<char> text)
      : SourceImpl(std::move(description), std::move(line_offsets)),
        text_(std::move(text)) {}

  ContentView content() const override {
    return MakeContentView(absl::MakeConstSpan(text_));
  }

 private:
  const std::vector<char> text_;
};

// Text whose code points are all <= U+00FF, stored one byte per code point.
class Latin1Source final : public SourceImpl {
 public:
  Latin1Source(std::string description,
               absl::InlinedVector<SourcePosition, 1> line_offsets,
               std::vector<uint8_t> text)
      : SourceImpl(std::move(description), std::move(line_offsets)),
        text_(std::move(text)) {}

  ContentView content() const override {
    return MakeContentView(absl::MakeConstSpan(text_));
  }

 private:
  const std::vector<uint8_t> text_;
};

// Text whose code points are all <= U+FFFF, stored as `char16_t`.
class BasicPlaneSource final : public SourceImpl {
 public:
  BasicPlaneSource(std::string description,
                   absl::InlinedVector<SourcePosition, 1> line_offsets,
                   std::vector<char16_t> text)
      : SourceImpl(std::move(description), std::move(line_offsets)),
        text_(std::move(text)) {}

  ContentView content() const override {
    return MakeContentView(absl::MakeConstSpan(text_));
  }

 private:
  const std::vector<char16_t> text_;
};

// Text containing at least one code point above U+FFFF, stored as `char32_t`.
class SupplementalPlaneSource final : public SourceImpl {
 public:
  SupplementalPlaneSource(std::string description,
                          absl::InlinedVector<SourcePosition, 1> line_offsets,
                          std::vector<char32_t> text)
      : SourceImpl(std::move(description), std::move(line_offsets)),
        text_(std::move(text)) {}

  ContentView content() const override {
    return MakeContentView(absl::MakeConstSpan(text_));
  }

 private:
  const std::vector<char32_t> text_;
};

// Adapts the supported input text types (`absl::string_view`, `absl::Cord`)
// for `NewSourceImpl`.
template <typename T>
struct SourceTextTraits;

template <>
struct SourceTextTraits<absl::string_view> {
  using iterator_type = absl::string_view;

  static iterator_type Begin(absl::string_view text) { return text; }

  static void Advance(iterator_type& it, size_t n) { it.remove_prefix(n); }

  // Appends the first `n` bytes of `text` to `out`.
  static void AppendTo(std::vector<uint8_t>& out, absl::string_view text,
                       size_t n) {
    const auto* in = reinterpret_cast<const uint8_t*>(text.data());
    out.insert(out.end(), in, in + n);
  }

  static std::vector<char> ToVector(absl::string_view in) {
    std::vector<char> out;
    out.reserve(in.size());
    out.insert(out.end(), in.begin(), in.end());
    return out;
  }
};

template <>
struct SourceTextTraits<absl::Cord> {
  using iterator_type = absl::Cord::CharIterator;

  static iterator_type Begin(const absl::Cord& text) {
    return text.char_begin();
  }

  static void Advance(iterator_type& it, size_t n) {
    absl::Cord::Advance(&it, n);
  }

  // Appends the first `n` bytes of `text` to `out`, chunk by chunk.
  static void AppendTo(std::vector<uint8_t>& out, const absl::Cord& text,
                       size_t n) {
    auto it = text.char_begin();
    while (n > 0) {
      auto str = absl::Cord::ChunkRemaining(it);
      size_t to_append = std::min(n, str.size());
      const auto* in = reinterpret_cast<const uint8_t*>(str.data());
      out.insert(out.end(), in, in + to_append);
      n -= to_append;
      absl::Cord::Advance(&it, to_append);
    }
  }

  static std::vector<char> ToVector(const absl::Cord& in) {
    std::vector<char> out;
    out.reserve(in.size());
    for (const auto& chunk : in.Chunks()) {
      out.insert(out.end(), chunk.begin(), chunk.end());
    }
    return out;
  }
};

// Decodes the UTF-8 `text` (`text_size` bytes) into the narrowest `Source`
// representation able to hold all of its code points. The position just after
// each '\n' is recorded in `line_offsets`, followed by a final entry equal to
// the number of code points plus one.
//
// Decoding starts in the ASCII state. The first code point that does not fit
// the current state converts what has been decoded so far to a wider element
// type and jumps (`goto`) to the loop for that wider state. States only ever
// widen: ASCII -> Latin-1 -> basic plane -> supplemental planes.
//
// Returns `InvalidArgumentError` for inputs larger than INT32_MAX bytes or for
// malformed UTF-8.
template <typename T>
absl::StatusOr<absl_nonnull SourcePtr> NewSourceImpl(std::string description,
                                                     const T& text,
                                                     const size_t text_size) {
  if (ABSL_PREDICT_FALSE(
          text_size >
          static_cast<size_t>(std::numeric_limits<int32_t>::max()))) {
    return absl::InvalidArgumentError("expression larger than 2GiB limit");
  }
  using Traits = SourceTextTraits<T>;
  size_t index = 0;  // Byte offset into `text`.
  typename Traits::iterator_type it = Traits::Begin(text);
  SourcePosition offset = 0;  // Code point offset into `text`.
  char32_t code_point;
  size_t code_units;
  std::vector<uint8_t> data8;
  std::vector<char16_t> data16;
  std::vector<char32_t> data32;
  absl::InlinedVector<SourcePosition, 1> line_offsets;
  while (index < text_size) {
    std::tie(code_point, code_units) = cel::internal::Utf8Decode(it);
    if (ABSL_PREDICT_FALSE(code_point ==
                               cel::internal::kUnicodeReplacementCharacter &&
                           code_units == 1)) {
      // That's an invalid UTF-8 encoding.
      return absl::InvalidArgumentError("cannot parse malformed UTF-8 input");
    }
    if (code_point == '\n') {
      line_offsets.push_back(offset + 1);
    }
    if (code_point <= 0x7f) {
      Traits::Advance(it, code_units);
      index += code_units;
      ++offset;
      continue;
    }
    if (code_point <= 0xff) {
      data8.reserve(text_size);
      // Every byte before `index` is ASCII, i.e. one code point per byte.
      Traits::AppendTo(data8, text, index);
      data8.push_back(static_cast<uint8_t>(code_point));
      Traits::Advance(it, code_units);
      index += code_units;
      ++offset;
      goto latin1;
    }
    if (code_point <= 0xffff) {
      data16.reserve(text_size);
      // Every byte before `index` is ASCII, i.e. one code point per byte.
      // `i` is a byte index and deliberately does not reuse the name of the
      // code point counter `offset`.
      for (size_t i = 0; i < index; ++i) {
        data16.push_back(static_cast<char16_t>(text[i]));
      }
      data16.push_back(static_cast<char16_t>(code_point));
      Traits::Advance(it, code_units);
      index += code_units;
      ++offset;
      goto basic;
    }
    data32.reserve(text_size);
    // Every byte before `index` is ASCII, i.e. one code point per byte.
    for (size_t i = 0; i < index; ++i) {
      data32.push_back(static_cast<char32_t>(text[i]));
    }
    data32.push_back(code_point);
    Traits::Advance(it, code_units);
    index += code_units;
    ++offset;
    goto supplemental;
  }
  line_offsets.push_back(offset + 1);
  return std::make_unique<AsciiSource>(
      std::move(description), std::move(line_offsets), Traits::ToVector(text));
latin1:
  while (index < text_size) {
    std::tie(code_point, code_units) = internal::Utf8Decode(it);
    if (ABSL_PREDICT_FALSE(code_point ==
                               internal::kUnicodeReplacementCharacter &&
                           code_units == 1)) {
      // That's an invalid UTF-8 encoding.
      return absl::InvalidArgumentError("cannot parse malformed UTF-8 input");
    }
    if (code_point == '\n') {
      line_offsets.push_back(offset + 1);
    }
    if (code_point <= 0xff) {
      data8.push_back(static_cast<uint8_t>(code_point));
      Traits::Advance(it, code_units);
      index += code_units;
      ++offset;
      continue;
    }
    if (code_point <= 0xffff) {
      data16.reserve(text_size);
      for (const auto& value : data8) {
        data16.push_back(value);
      }
      std::vector<uint8_t>().swap(data8);  // Release the narrow buffer.
      data16.push_back(static_cast<char16_t>(code_point));
      Traits::Advance(it, code_units);
      index += code_units;
      ++offset;
      goto basic;
    }
    data32.reserve(text_size);
    for (const auto& value : data8) {
      data32.push_back(value);
    }
    std::vector<uint8_t>().swap(data8);  // Release the narrow buffer.
    data32.push_back(code_point);
    Traits::Advance(it, code_units);
    index += code_units;
    ++offset;
    goto supplemental;
  }
  line_offsets.push_back(offset + 1);
  return std::make_unique<Latin1Source>(
      std::move(description), std::move(line_offsets), std::move(data8));
basic:
  while (index < text_size) {
    std::tie(code_point, code_units) = internal::Utf8Decode(it);
    if (ABSL_PREDICT_FALSE(code_point ==
                               internal::kUnicodeReplacementCharacter &&
                           code_units == 1)) {
      // That's an invalid UTF-8 encoding.
      return absl::InvalidArgumentError("cannot parse malformed UTF-8 input");
    }
    if (code_point == '\n') {
      line_offsets.push_back(offset + 1);
    }
    if (code_point <= 0xffff) {
      data16.push_back(static_cast<char16_t>(code_point));
      Traits::Advance(it, code_units);
      index += code_units;
      ++offset;
      continue;
    }
    data32.reserve(text_size);
    for (const auto& value : data16) {
      data32.push_back(static_cast<char32_t>(value));
    }
    std::vector<char16_t>().swap(data16);  // Release the narrow buffer.
    data32.push_back(code_point);
    Traits::Advance(it, code_units);
    index += code_units;
    ++offset;
    goto supplemental;
  }
  line_offsets.push_back(offset + 1);
  return std::make_unique<BasicPlaneSource>(
      std::move(description), std::move(line_offsets), std::move(data16));
supplemental:
  while (index < text_size) {
    std::tie(code_point, code_units) = internal::Utf8Decode(it);
    if (ABSL_PREDICT_FALSE(code_point ==
                               internal::kUnicodeReplacementCharacter &&
                           code_units == 1)) {
      // That's an invalid UTF-8 encoding.
      return absl::InvalidArgumentError("cannot parse malformed UTF-8 input");
    }
    if (code_point == '\n') {
      line_offsets.push_back(offset + 1);
    }
    data32.push_back(code_point);
    Traits::Advance(it, code_units);
    index += code_units;
    ++offset;
  }
  line_offsets.push_back(offset + 1);
  return std::make_unique<SupplementalPlaneSource>(
      std::move(description), std::move(line_offsets), std::move(data32));
}

}  // namespace

}  // namespace common_internal

absl::optional<SourceLocation> Source::GetLocation(
    SourcePosition position) const {
  if (auto line_and_offset = FindLine(position);
      ABSL_PREDICT_TRUE(line_and_offset.has_value())) {
    return SourceLocation{line_and_offset->first,
                          position - line_and_offset->second};
  }
  return absl::nullopt;
}

absl::optional<SourcePosition> Source::GetPosition(
    const SourceLocation& location) const {
  if (ABSL_PREDICT_FALSE(location.line < 1 || location.column < 0)) {
    return absl::nullopt;
  }
  if (auto position = FindLinePosition(location.line);
      ABSL_PREDICT_TRUE(position.has_value())) {
    return *position + location.column;
  }
  return absl::nullopt;
}

// Returns the text of the 1-based `line` without its trailing '\n', or nullopt
// when the line does not exist or the content is empty.
absl::optional<std::string> Source::Snippet(int32_t line) const {
  auto content = this->content();
  auto start = FindLinePosition(line);
  if (ABSL_PREDICT_FALSE(!start.has_value() || content.empty())) {
    return absl::nullopt;
  }
  auto end = FindLinePosition(line + 1);
  if (end.has_value()) {
    // `*end` is the position just after this line's '\n'; drop the '\n'.
    return content.ToString(*start, *end - 1);
  }
  return content.ToString(*start);
}

std::string Source::DisplayErrorLocation(SourceLocation location) const {
  constexpr char32_t kDot = '.';
  constexpr char32_t kHat = '^';

  // Full-width variants, used under multi-byte characters.
  constexpr char32_t kWideDot = 0xff0e;
  constexpr char32_t kWideHat = 0xff3e;

  absl::optional<std::string> snippet = Snippet(location.line);
  if (!snippet || snippet->empty()) {
    return "";
  }

  *snippet = absl::StrReplaceAll(*snippet, {{"\t", " "}});
  absl::string_view snippet_view(*snippet);
  std::string result;
  absl::StrAppend(&result, "\n | ", *snippet);
  absl::StrAppend(&result, "\n | ");

  std::string index_line;
  for (int32_t i = 0; i < location.column && !snippet_view.empty(); ++i) {
    size_t count;
    std::tie(std::ignore, count) = internal::Utf8Decode(snippet_view);
    snippet_view.remove_prefix(count);
    if (count > 1) {
      internal::Utf8Encode(index_line, kWideDot);
    } else {
      internal::Utf8Encode(index_line, kDot);
    }
  }
  size_t count = 0;
  if (!snippet_view.empty()) {
    std::tie(std::ignore, count) = internal::Utf8Decode(snippet_view);
  }
  if (count > 1) {
    internal::Utf8Encode(index_line, kWideHat);
  } else {
    internal::Utf8Encode(index_line, kHat);
  }
  absl::StrAppend(&result, index_line);
  return result;
}

// Returns the position at which the 1-based `line` starts. Line 1 always
// starts at 0; line N >= 2 starts at `line_offsets()[N - 2]`.
absl::optional<SourcePosition> Source::FindLinePosition(int32_t line) const {
  if (ABSL_PREDICT_FALSE(line < 1)) {
    return absl::nullopt;
  }
  if (line == 1) {
    return SourcePosition{0};
  }
  const auto line_offsets = this->line_offsets();
  if (ABSL_PREDICT_TRUE(line <= static_cast<int32_t>(line_offsets.size()))) {
    return line_offsets[static_cast<size_t>(line - 2)];
  }
  return absl::nullopt;
}

// Returns the 1-based line containing `position` together with that line's
// starting position.
absl::optional<std::pair<int32_t, SourcePosition>> Source::FindLine(
    SourcePosition position) const {
  if (ABSL_PREDICT_FALSE(position < 0)) {
    return absl::nullopt;
  }
  const auto line_offsets = this->line_offsets();
  // `line_offsets` is ascending. The number of line starts <= `position` is
  // therefore the index of the first one greater than it, which a binary
  // search finds in O(log n).
  const auto next_line_start =
      std::upper_bound(line_offsets.begin(), line_offsets.end(), position);
  const int32_t line =
      1 + static_cast<int32_t>(next_line_start - line_offsets.begin());
  if (line == 1) {
    return std::make_pair(line, SourcePosition{0});
  }
  return std::make_pair(line, line_offsets[static_cast<size_t>(line) - 2]);
}

absl::StatusOr<absl_nonnull SourcePtr> NewSource(absl::string_view content,
                                                 std::string description) {
  return common_internal::NewSourceImpl(std::move(description), content,
                                        content.size());
}

absl::StatusOr<absl_nonnull SourcePtr> NewSource(const absl::Cord& content,
                                                 std::string description) {
  return common_internal::NewSourceImpl(std::move(description), content,
                                        content.size());
}

}  // namespace cel
================================================
FILE: common/source.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_SOURCE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_SOURCE_H_

#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"

namespace cel {

namespace common_internal {
class SourceImpl;
}  // namespace common_internal

class Source;

// SourcePosition represents an offset in source text.
using SourcePosition = int32_t;

// SourceRange represents a range of positions, where `begin` is inclusive and
// `end` is exclusive.
struct SourceRange final { SourcePosition begin = -1; SourcePosition end = -1; }; inline bool operator==(const SourceRange& lhs, const SourceRange& rhs) { return lhs.begin == rhs.begin && lhs.end == rhs.end; } inline bool operator!=(const SourceRange& lhs, const SourceRange& rhs) { return !operator==(lhs, rhs); } // `SourceLocation` is a representation of a line and column in source text. struct SourceLocation final { int32_t line = -1; // 1-based line number. int32_t column = -1; // 0-based column number. }; inline bool operator==(const SourceLocation& lhs, const SourceLocation& rhs) { return lhs.line == rhs.line && lhs.column == rhs.column; } inline bool operator!=(const SourceLocation& lhs, const SourceLocation& rhs) { return !operator==(lhs, rhs); } // `SourceContentView` is a view of the content owned by `Source`, which is a // sequence of Unicode code points. class SourceContentView final { public: SourceContentView(const SourceContentView&) = default; SourceContentView(SourceContentView&&) = default; SourceContentView& operator=(const SourceContentView&) = default; SourceContentView& operator=(SourceContentView&&) = default; SourcePosition size() const; bool empty() const; char32_t at(SourcePosition position) const; std::string ToString(SourcePosition begin, SourcePosition end) const; std::string ToString(SourcePosition begin) const { return ToString(begin, size()); } std::string ToString() const { return ToString(0); } void AppendToString(std::string& dest) const; private: friend class Source; constexpr SourceContentView() = default; constexpr explicit SourceContentView(absl::Span view) : view_(view) {} constexpr explicit SourceContentView(absl::Span view) : view_(view) {} constexpr explicit SourceContentView(absl::Span view) : view_(view) {} constexpr explicit SourceContentView(absl::Span view) : view_(view) {} absl::variant, absl::Span, absl::Span, absl::Span> view_; }; // `Source` represents the source expression. 
class Source { public: using ContentView = SourceContentView; Source(const Source&) = delete; Source(Source&&) = delete; virtual ~Source() = default; Source& operator=(const Source&) = delete; Source& operator=(Source&&) = delete; virtual absl::string_view description() const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; // Maps a `SourcePosition` to a `SourceLocation`. Returns an empty // `absl::optional` when `SourcePosition` is invalid or the information // required to perform the mapping is not present. absl::optional GetLocation(SourcePosition position) const; // Maps a `SourceLocation` to a `SourcePosition`. Returns an empty // `absl::optional` when `SourceLocation` is invalid or the information // required to perform the mapping is not present. absl::optional GetPosition( const SourceLocation& location) const; absl::optional Snippet(int32_t line) const; // Formats an annotated snippet highlighting an error at location, e.g. // // "\n | $SOURCE_SNIPPET" + // "\n | .......^" // // Returns an empty string if location is not a valid location in this source. std::string DisplayErrorLocation(SourceLocation location) const; // Returns a view of the underlying expression text, if present. virtual ContentView content() const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; // Returns a `absl::Span` of `SourcePosition` which represent the positions // where new lines occur. 
virtual absl::Span line_offsets() const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; protected: static constexpr ContentView EmptyContentView() { return ContentView(); } static constexpr ContentView MakeContentView(absl::Span view) { return ContentView(view); } static constexpr ContentView MakeContentView(absl::Span view) { return ContentView(view); } static constexpr ContentView MakeContentView( absl::Span view) { return ContentView(view); } static constexpr ContentView MakeContentView( absl::Span view) { return ContentView(view); } private: friend class common_internal::SourceImpl; Source() = default; absl::optional FindLinePosition(int32_t line) const; absl::optional> FindLine( SourcePosition position) const; }; using SourcePtr = std::unique_ptr; absl::StatusOr NewSource( absl::string_view content, std::string description = ""); absl::StatusOr NewSource( const absl::Cord& content, std::string description = ""); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_SOURCE_H_ ================================================ FILE: common/source_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// Unit tests for `cel::Source` (string- and Cord-backed variants) and for the
// `SourceRange` / `SourceLocation` value types.
#include "common/source.h"

#include "absl/strings/cord.h"
#include "absl/types/optional.h"
#include "internal/testing.h"

namespace cel {
namespace {

using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::Ne;
using ::testing::Optional;

// Default-constructed ranges and locations use -1 as an "unset" marker.
TEST(SourceRange, Default) {
  SourceRange range;
  EXPECT_EQ(range.begin, -1);
  EXPECT_EQ(range.end, -1);
}

TEST(SourceRange, Equality) {
  EXPECT_THAT((SourceRange{}), (Eq(SourceRange{})));
  EXPECT_THAT((SourceRange{0, 1}), (Ne(SourceRange{0, 0})));
}

TEST(SourceLocation, Default) {
  SourceLocation location;
  EXPECT_EQ(location.line, -1);
  EXPECT_EQ(location.column, -1);
}

TEST(SourceLocation, Equality) {
  EXPECT_THAT((SourceLocation{}), (Eq(SourceLocation{})));
  EXPECT_THAT((SourceLocation{1, 1}), (Ne(SourceLocation{1, 0})));
}

TEST(StringSource, Description) {
  ASSERT_OK_AND_ASSIGN(
      auto source,
      NewSource("c.d &&\n\t b.c.arg(10) &&\n\t test(10)", "offset-test"));

  EXPECT_THAT(source->description(), Eq("offset-test"));
}

TEST(StringSource, Content) {
  ASSERT_OK_AND_ASSIGN(
      auto source,
      NewSource("c.d &&\n\t b.c.arg(10) &&\n\t test(10)", "offset-test"));

  EXPECT_THAT(source->content().ToString(),
              Eq("c.d &&\n\t b.c.arg(10) &&\n\t test(10)"));
}

TEST(StringSource, PositionAndLocation) {
  ASSERT_OK_AND_ASSIGN(
      auto source,
      NewSource("c.d &&\n\t b.c.arg(10) &&\n\t test(10)", "offset-test"));

  // The newlines sit at offsets 6 and 23 and the text is 34 characters long.
  // The expected offsets are therefore one past each newline plus one past
  // the end.
  EXPECT_THAT(source->line_offsets(), ElementsAre(7, 24, 35));

  auto start = source->GetPosition(SourceLocation{int32_t{1}, int32_t{2}});
  auto end = source->GetPosition(SourceLocation{int32_t{3}, int32_t{2}});
  ASSERT_TRUE(start.has_value());
  ASSERT_TRUE(end.has_value());
  // Round trip: position -> location must give back the original location.
  EXPECT_THAT(source->GetLocation(*start),
              Optional(Eq(SourceLocation{int32_t{1}, int32_t{2}})));
  EXPECT_THAT(source->GetLocation(*end),
              Optional(Eq(SourceLocation{int32_t{3}, int32_t{2}})));
  EXPECT_THAT(source->GetLocation(-1), Eq(absl::nullopt));

  EXPECT_THAT(source->content().ToString(*start, *end),
              Eq("d &&\n\t b.c.arg(10) &&\n\t "));

  // Line 0, a negative column, and a line past the end are all rejected.
  EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{0}, int32_t{0}}),
              Eq(absl::nullopt));
  EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{1}, int32_t{-1}}),
              Eq(absl::nullopt));
  EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{4}, int32_t{0}}),
              Eq(absl::nullopt));
}

TEST(StringSource, SnippetSingle) {
  ASSERT_OK_AND_ASSIGN(auto source, NewSource("hello, world", "one-line-test"));

  EXPECT_THAT(source->Snippet(1), Optional(Eq("hello, world")));
  EXPECT_THAT(source->Snippet(2), Eq(absl::nullopt));
}

TEST(StringSource, SnippetMulti) {
  ASSERT_OK_AND_ASSIGN(auto source,
                       NewSource("hello\nworld\nmy\nbub\n", "four-line-test"));

  // Lines are numbered from 1. The trailing newline yields an empty fifth line.
  EXPECT_THAT(source->Snippet(0), Eq(absl::nullopt));
  EXPECT_THAT(source->Snippet(1), Optional(Eq("hello")));
  EXPECT_THAT(source->Snippet(2), Optional(Eq("world")));
  EXPECT_THAT(source->Snippet(3), Optional(Eq("my")));
  EXPECT_THAT(source->Snippet(4), Optional(Eq("bub")));
  EXPECT_THAT(source->Snippet(5), Optional(Eq("")));
  EXPECT_THAT(source->Snippet(6), Eq(absl::nullopt));
}

// The Cord-backed tests mirror the StringSource ones above.
TEST(CordSource, Description) {
  ASSERT_OK_AND_ASSIGN(
      auto source,
      NewSource(absl::Cord("c.d &&\n\t b.c.arg(10) &&\n\t test(10)"),
                "offset-test"));

  EXPECT_THAT(source->description(), Eq("offset-test"));
}

TEST(CordSource, Content) {
  ASSERT_OK_AND_ASSIGN(
      auto source,
      NewSource(absl::Cord("c.d &&\n\t b.c.arg(10) &&\n\t test(10)"),
                "offset-test"));

  EXPECT_THAT(source->content().ToString(),
              Eq("c.d &&\n\t b.c.arg(10) &&\n\t test(10)"));
}

TEST(CordSource, PositionAndLocation) {
  ASSERT_OK_AND_ASSIGN(
      auto source,
      NewSource(absl::Cord("c.d &&\n\t b.c.arg(10) &&\n\t test(10)"),
                "offset-test"));

  EXPECT_THAT(source->line_offsets(), ElementsAre(7, 24, 35));

  auto start = source->GetPosition(SourceLocation{int32_t{1}, int32_t{2}});
  auto end = source->GetPosition(SourceLocation{int32_t{3}, int32_t{2}});
  ASSERT_TRUE(start.has_value());
  ASSERT_TRUE(end.has_value());
  EXPECT_THAT(source->GetLocation(*start),
              Optional(Eq(SourceLocation{int32_t{1}, int32_t{2}})));
  EXPECT_THAT(source->GetLocation(*end),
              Optional(Eq(SourceLocation{int32_t{3}, int32_t{2}})));
  EXPECT_THAT(source->GetLocation(-1), Eq(absl::nullopt));

  EXPECT_THAT(source->content().ToString(*start, *end),
              Eq("d &&\n\t b.c.arg(10) &&\n\t "));

  EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{0}, int32_t{0}}),
              Eq(absl::nullopt));
  EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{1}, int32_t{-1}}),
              Eq(absl::nullopt));
  EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{4}, int32_t{0}}),
              Eq(absl::nullopt));
}

TEST(CordSource, SnippetSingle) {
  ASSERT_OK_AND_ASSIGN(auto source,
                       NewSource(absl::Cord("hello, world"), "one-line-test"));

  EXPECT_THAT(source->Snippet(1), Optional(Eq("hello, world")));
  EXPECT_THAT(source->Snippet(2), Eq(absl::nullopt));
}

TEST(CordSource, SnippetMulti) {
  ASSERT_OK_AND_ASSIGN(
      auto source,
      NewSource(absl::Cord("hello\nworld\nmy\nbub\n"), "four-line-test"));

  EXPECT_THAT(source->Snippet(0), Eq(absl::nullopt));
  EXPECT_THAT(source->Snippet(1), Optional(Eq("hello")));
  EXPECT_THAT(source->Snippet(2), Optional(Eq("world")));
  EXPECT_THAT(source->Snippet(3), Optional(Eq("my")));
  EXPECT_THAT(source->Snippet(4), Optional(Eq("bub")));
  EXPECT_THAT(source->Snippet(5), Optional(Eq("")));
  EXPECT_THAT(source->Snippet(6), Eq(absl::nullopt));
}

TEST(Source, DisplayErrorLocationBasic) {
  ASSERT_OK_AND_ASSIGN(auto source, NewSource("'Hello' +\n 'world'"));

  SourceLocation location{/*line=*/2, /*column=*/3};

  EXPECT_EQ(source->DisplayErrorLocation(location),
            "\n | 'world'"
            "\n | ...^");
}

// A location beyond the last line produces an empty string.
TEST(Source, DisplayErrorLocationOutOfRange) {
  ASSERT_OK_AND_ASSIGN(auto source, NewSource("'Hello world!'"));

  SourceLocation location{/*line=*/3, /*column=*/3};

  EXPECT_EQ(source->DisplayErrorLocation(location), "");
}

TEST(Source, DisplayErrorLocationTabsShortened) {
  ASSERT_OK_AND_ASSIGN(auto source, NewSource("'Hello' +\n\t\t'world!'"));

  SourceLocation location{/*line=*/2, /*column=*/4};

  EXPECT_EQ(source->DisplayErrorLocation(location),
            "\n | 'world!'"
            "\n | ....^");
}

TEST(Source, DisplayErrorLocationFullWidth) {
  ASSERT_OK_AND_ASSIGN(auto source, NewSource("'Hello'"));

  SourceLocation location{/*line=*/1, /*column=*/2};

  EXPECT_EQ(source->DisplayErrorLocation(location),
            "\n | 'Hello'"
            "\n | ..^");
}

} // namespace
} // namespace cel
================================================
FILE: common/standard_definitions.h
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Constants used for standard definitions for CEL.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_STANDARD_DEFINITIONS_H_
#define THIRD_PARTY_CEL_CPP_COMMON_STANDARD_DEFINITIONS_H_

#include "absl/strings/string_view.h"

namespace cel {

// Standard function names as represented in an AST.
// Operator names spell out their operand positions with `_`: `_+_` is binary
// infix, `-_` and `!_` are unary prefix, and `_?_:_` is the ternary.
// TODO(uncreated-issue/71): use a namespace instead of a class.
struct StandardFunctions {
  // Comparison
  static constexpr absl::string_view kEqual = "_==_";
  static constexpr absl::string_view kInequal = "_!=_";
  static constexpr absl::string_view kLess = "_<_";
  static constexpr absl::string_view kLessOrEqual = "_<=_";
  static constexpr absl::string_view kGreater = "_>_";
  static constexpr absl::string_view kGreaterOrEqual = "_>=_";

  // Logical
  static constexpr absl::string_view kAnd = "_&&_";
  static constexpr absl::string_view kOr = "_||_";
  static constexpr absl::string_view kNot = "!_";

  // Strictness
  static constexpr absl::string_view kNotStrictlyFalse = "@not_strictly_false";
  // Deprecated '__not_strictly_false__' function. Preserved for backwards
  // compatibility with stored expressions.
  static constexpr absl::string_view kNotStrictlyFalseDeprecated =
      "__not_strictly_false__";

  // Arithmetical
  static constexpr absl::string_view kAdd = "_+_";
  static constexpr absl::string_view kSubtract = "_-_";
  static constexpr absl::string_view kNeg = "-_";
  static constexpr absl::string_view kMultiply = "_*_";
  static constexpr absl::string_view kDivide = "_/_";
  static constexpr absl::string_view kModulo = "_%_";

  // String operations
  static constexpr absl::string_view kRegexMatch = "matches";
  static constexpr absl::string_view kStringContains = "contains";
  static constexpr absl::string_view kStringEndsWith = "endsWith";
  static constexpr absl::string_view kStringStartsWith = "startsWith";

  // Container operations
  static constexpr absl::string_view kIn = "@in";
  // Deprecated '_in_' operator. Preserved for backwards compatibility with
  // stored expressions.
  static constexpr absl::string_view kInDeprecated = "_in_";
  // Deprecated 'in()' function. Preserved for backwards compatibility with
  // stored expressions.
  static constexpr absl::string_view kInFunction = "in";
  static constexpr absl::string_view kIndex = "_[_]";
  static constexpr absl::string_view kSize = "size";

  static constexpr absl::string_view kTernary = "_?_:_";

  // Timestamp and Duration
  static constexpr absl::string_view kDuration = "duration";
  static constexpr absl::string_view kTimestamp = "timestamp";
  static constexpr absl::string_view kFullYear = "getFullYear";
  static constexpr absl::string_view kMonth = "getMonth";
  static constexpr absl::string_view kDayOfYear = "getDayOfYear";
  static constexpr absl::string_view kDayOfMonth = "getDayOfMonth";
  static constexpr absl::string_view kDate = "getDate";
  static constexpr absl::string_view kDayOfWeek = "getDayOfWeek";
  static constexpr absl::string_view kHours = "getHours";
  static constexpr absl::string_view kMinutes = "getMinutes";
  static constexpr absl::string_view kSeconds = "getSeconds";
  static constexpr absl::string_view kMilliseconds = "getMilliseconds";

  // Type conversions
  static constexpr absl::string_view kBool = "bool";
  static constexpr absl::string_view kBytes = "bytes";
  static constexpr absl::string_view kDouble = "double";
  static constexpr absl::string_view kDyn = "dyn";
  static constexpr absl::string_view kInt = "int";
  static constexpr absl::string_view kString = "string";
  static constexpr absl::string_view kType = "type";
  static constexpr absl::string_view kUint = "uint";

  // Runtime-only functions.
  // The convention for runtime-only functions where only the runtime needs to
  // differentiate behavior is to prefix the function with `#`.
  // Note, this is a different convention from CEL internal functions where the
  // whole stack needs to be aware of the function id.
  static constexpr absl::string_view kRuntimeListAppend = "#list_append";
};

// Standard overload IDs used by type checkers.
// TODO(uncreated-issue/71): use a namespace instead of a class.
// Each overload ID identifies one typed signature of a standard function. IDs
// are mostly spelled as <operation>_<operand types>, with `int64`/`uint64` for
// CEL `int`/`uint` operands.
struct StandardOverloadIds {
  // Add operator _+_
  static constexpr absl::string_view kAddInt = "add_int64";
  static constexpr absl::string_view kAddUint = "add_uint64";
  static constexpr absl::string_view kAddDouble = "add_double";
  static constexpr absl::string_view kAddDurationDuration =
      "add_duration_duration";
  static constexpr absl::string_view kAddDurationTimestamp =
      "add_duration_timestamp";
  static constexpr absl::string_view kAddTimestampDuration =
      "add_timestamp_duration";
  static constexpr absl::string_view kAddString = "add_string";
  static constexpr absl::string_view kAddBytes = "add_bytes";
  static constexpr absl::string_view kAddList = "add_list";
  // Subtract operator _-_
  static constexpr absl::string_view kSubtractInt = "subtract_int64";
  static constexpr absl::string_view kSubtractUint = "subtract_uint64";
  static constexpr absl::string_view kSubtractDouble = "subtract_double";
  static constexpr absl::string_view kSubtractDurationDuration =
      "subtract_duration_duration";
  static constexpr absl::string_view kSubtractTimestampDuration =
      "subtract_timestamp_duration";
  static constexpr absl::string_view kSubtractTimestampTimestamp =
      "subtract_timestamp_timestamp";
  // Multiply operator _*_
  static constexpr absl::string_view kMultiplyInt = "multiply_int64";
  static constexpr absl::string_view kMultiplyUint = "multiply_uint64";
  static constexpr absl::string_view kMultiplyDouble = "multiply_double";
  // Division operator _/_
  static constexpr absl::string_view kDivideInt = "divide_int64";
  static constexpr absl::string_view kDivideUint = "divide_uint64";
  static constexpr absl::string_view kDivideDouble = "divide_double";
  // Modulo operator _%_
  static constexpr absl::string_view kModuloInt = "modulo_int64";
  static constexpr absl::string_view kModuloUint = "modulo_uint64";
  // Negation operator -_
  static constexpr absl::string_view kNegateInt = "negate_int64";
  static constexpr absl::string_view kNegateDouble = "negate_double";
  // Logical operators
  static constexpr absl::string_view kNot = "logical_not";
  static constexpr absl::string_view kAnd = "logical_and";
  static constexpr absl::string_view kOr = "logical_or";
  static constexpr absl::string_view kConditional = "conditional";
  // Comprehension logic
  static constexpr absl::string_view kNotStrictlyFalse = "not_strictly_false";
  static constexpr absl::string_view kNotStrictlyFalseDeprecated =
      "__not_strictly_false__";
  // Equality operators
  static constexpr absl::string_view kEquals = "equals";
  static constexpr absl::string_view kNotEquals = "not_equals";
  // Relational operators
  static constexpr absl::string_view kLessBool = "less_bool";
  static constexpr absl::string_view kLessString = "less_string";
  static constexpr absl::string_view kLessBytes = "less_bytes";
  static constexpr absl::string_view kLessDuration = "less_duration";
  static constexpr absl::string_view kLessTimestamp = "less_timestamp";
  static constexpr absl::string_view kLessInt = "less_int64";
  static constexpr absl::string_view kLessIntUint = "less_int64_uint64";
  static constexpr absl::string_view kLessIntDouble = "less_int64_double";
  static constexpr absl::string_view kLessDouble = "less_double";
  static constexpr absl::string_view kLessDoubleInt = "less_double_int64";
  static constexpr absl::string_view kLessDoubleUint = "less_double_uint64";
  static constexpr absl::string_view kLessUint = "less_uint64";
  static constexpr absl::string_view kLessUintInt = "less_uint64_int64";
  static constexpr absl::string_view kLessUintDouble = "less_uint64_double";
  static constexpr absl::string_view kGreaterBool = "greater_bool";
  static constexpr absl::string_view kGreaterString = "greater_string";
  static constexpr absl::string_view kGreaterBytes = "greater_bytes";
  static constexpr absl::string_view kGreaterDuration = "greater_duration";
  static constexpr absl::string_view kGreaterTimestamp = "greater_timestamp";
  static constexpr absl::string_view kGreaterInt = "greater_int64";
  static constexpr absl::string_view kGreaterIntUint =
      "greater_int64_uint64";
  static constexpr absl::string_view kGreaterIntDouble =
      "greater_int64_double";
  static constexpr absl::string_view kGreaterDouble = "greater_double";
  static constexpr absl::string_view kGreaterDoubleInt =
      "greater_double_int64";
  static constexpr absl::string_view kGreaterDoubleUint =
      "greater_double_uint64";
  static constexpr absl::string_view kGreaterUint = "greater_uint64";
  static constexpr absl::string_view kGreaterUintInt = "greater_uint64_int64";
  static constexpr absl::string_view kGreaterUintDouble =
      "greater_uint64_double";
  static constexpr absl::string_view kGreaterEqualsBool = "greater_equals_bool";
  static constexpr absl::string_view kGreaterEqualsString =
      "greater_equals_string";
  static constexpr absl::string_view kGreaterEqualsBytes =
      "greater_equals_bytes";
  static constexpr absl::string_view kGreaterEqualsDuration =
      "greater_equals_duration";
  static constexpr absl::string_view kGreaterEqualsTimestamp =
      "greater_equals_timestamp";
  static constexpr absl::string_view kGreaterEqualsInt = "greater_equals_int64";
  static constexpr absl::string_view kGreaterEqualsIntUint =
      "greater_equals_int64_uint64";
  static constexpr absl::string_view kGreaterEqualsIntDouble =
      "greater_equals_int64_double";
  static constexpr absl::string_view kGreaterEqualsDouble =
      "greater_equals_double";
  static constexpr absl::string_view kGreaterEqualsDoubleInt =
      "greater_equals_double_int64";
  static constexpr absl::string_view kGreaterEqualsDoubleUint =
      "greater_equals_double_uint64";
  static constexpr absl::string_view kGreaterEqualsUint =
      "greater_equals_uint64";
  static constexpr absl::string_view kGreaterEqualsUintInt =
      "greater_equals_uint64_int64";
  // NOTE(review): this ID spells `uint` where every sibling spells `uint64`
  // (compare kLessEqualsUintDouble = "less_equals_uint64_double"). Overload IDs
  // may be persisted in checked ASTs, so the value is left as-is. Confirm
  // whether "greater_equals_uint64_double" was intended.
  static constexpr absl::string_view kGreaterEqualsUintDouble =
      "greater_equals_uint_double";
  static constexpr absl::string_view kLessEqualsBool = "less_equals_bool";
  static constexpr absl::string_view kLessEqualsString = "less_equals_string";
  static constexpr absl::string_view kLessEqualsBytes = "less_equals_bytes";
  static constexpr absl::string_view kLessEqualsDuration =
      "less_equals_duration";
  static constexpr absl::string_view kLessEqualsTimestamp =
      "less_equals_timestamp";
  static constexpr absl::string_view kLessEqualsInt = "less_equals_int64";
  static constexpr absl::string_view kLessEqualsIntUint =
      "less_equals_int64_uint64";
  static constexpr absl::string_view kLessEqualsIntDouble =
      "less_equals_int64_double";
  static constexpr absl::string_view kLessEqualsDouble = "less_equals_double";
  static constexpr absl::string_view kLessEqualsDoubleInt =
      "less_equals_double_int64";
  static constexpr absl::string_view kLessEqualsDoubleUint =
      "less_equals_double_uint64";
  static constexpr absl::string_view kLessEqualsUint = "less_equals_uint64";
  static constexpr absl::string_view kLessEqualsUintInt =
      "less_equals_uint64_int64";
  static constexpr absl::string_view kLessEqualsUintDouble =
      "less_equals_uint64_double";
  // Container operators
  static constexpr absl::string_view kIndexList = "index_list";
  static constexpr absl::string_view kIndexMap = "index_map";
  static constexpr absl::string_view kInList = "in_list";
  static constexpr absl::string_view kInMap = "in_map";
  static constexpr absl::string_view kSizeBytes = "size_bytes";
  static constexpr absl::string_view kSizeList = "size_list";
  static constexpr absl::string_view kSizeMap = "size_map";
  static constexpr absl::string_view kSizeString = "size_string";
  // Receiver-style (member call) variants of `size`.
  static constexpr absl::string_view kSizeBytesMember = "bytes_size";
  static constexpr absl::string_view kSizeListMember = "list_size";
  static constexpr absl::string_view kSizeMapMember = "map_size";
  static constexpr absl::string_view kSizeStringMember = "string_size";
  // String functions
  static constexpr absl::string_view kContainsString = "contains_string";
  static constexpr absl::string_view kEndsWithString = "ends_with_string";
  static constexpr absl::string_view kStartsWithString = "starts_with_string";
  // String RE2 functions
  static constexpr absl::string_view kMatches = "matches";
  static constexpr absl::string_view kMatchesMember = "matches_string";
  // Timestamp / duration accessors
  static constexpr absl::string_view kTimestampToYear = "timestamp_to_year";
  static constexpr absl::string_view kTimestampToYearWithTz =
      "timestamp_to_year_with_tz";
  static constexpr absl::string_view kTimestampToMonth = "timestamp_to_month";
  static constexpr absl::string_view kTimestampToMonthWithTz =
      "timestamp_to_month_with_tz";
  static constexpr absl::string_view kTimestampToDayOfYear =
      "timestamp_to_day_of_year";
  static constexpr absl::string_view kTimestampToDayOfYearWithTz =
      "timestamp_to_day_of_year_with_tz";
  static constexpr absl::string_view kTimestampToDayOfMonth =
      "timestamp_to_day_of_month";
  static constexpr absl::string_view kTimestampToDayOfMonthWithTz =
      "timestamp_to_day_of_month_with_tz";
  static constexpr absl::string_view kTimestampToDayOfWeek =
      "timestamp_to_day_of_week";
  static constexpr absl::string_view kTimestampToDayOfWeekWithTz =
      "timestamp_to_day_of_week_with_tz";
  static constexpr absl::string_view kTimestampToDate =
      "timestamp_to_day_of_month_1_based";
  static constexpr absl::string_view kTimestampToDateWithTz =
      "timestamp_to_day_of_month_1_based_with_tz";
  static constexpr absl::string_view kTimestampToHours = "timestamp_to_hours";
  static constexpr absl::string_view kTimestampToHoursWithTz =
      "timestamp_to_hours_with_tz";
  static constexpr absl::string_view kDurationToHours = "duration_to_hours";
  static constexpr absl::string_view kTimestampToMinutes =
      "timestamp_to_minutes";
  static constexpr absl::string_view kTimestampToMinutesWithTz =
      "timestamp_to_minutes_with_tz";
  static constexpr absl::string_view kDurationToMinutes = "duration_to_minutes";
  static constexpr absl::string_view kTimestampToSeconds =
      "timestamp_to_seconds";
  // NOTE(review): every other `*WithTz` ID ends in "_with_tz", but this one
  // ends in "_tz". It is left unchanged because IDs may be persisted. Confirm
  // whether "timestamp_to_seconds_with_tz" was intended.
  static constexpr absl::string_view kTimestampToSecondsWithTz =
      "timestamp_to_seconds_tz";
  static constexpr absl::string_view kDurationToSeconds = "duration_to_seconds";
  static constexpr absl::string_view kTimestampToMilliseconds =
      "timestamp_to_milliseconds";
  static constexpr absl::string_view kTimestampToMillisecondsWithTz =
      "timestamp_to_milliseconds_with_tz";
  static constexpr absl::string_view kDurationToMilliseconds =
      "duration_to_milliseconds";
  // Type conversions
  static constexpr absl::string_view kToDyn = "to_dyn";
  // to_uint
  static constexpr absl::string_view kUintToUint = "uint64_to_uint64";
  static constexpr absl::string_view kDoubleToUint = "double_to_uint64";
  static constexpr absl::string_view kIntToUint = "int64_to_uint64";
  static constexpr absl::string_view kStringToUint = "string_to_uint64";
  // to_int
  static constexpr absl::string_view kUintToInt = "uint64_to_int64";
  static constexpr absl::string_view kDoubleToInt = "double_to_int64";
  static constexpr absl::string_view kIntToInt = "int64_to_int64";
  static constexpr absl::string_view kStringToInt = "string_to_int64";
  static constexpr absl::string_view kTimestampToInt = "timestamp_to_int64";
  static constexpr absl::string_view kDurationToInt = "duration_to_int64";
  // to_double
  static constexpr absl::string_view kDoubleToDouble = "double_to_double";
  static constexpr absl::string_view kUintToDouble = "uint64_to_double";
  static constexpr absl::string_view kIntToDouble = "int64_to_double";
  static constexpr absl::string_view kStringToDouble = "string_to_double";
  // to_bool
  static constexpr absl::string_view kBoolToBool = "bool_to_bool";
  static constexpr absl::string_view kStringToBool = "string_to_bool";
  // to_bytes
  static constexpr absl::string_view kBytesToBytes = "bytes_to_bytes";
  static constexpr absl::string_view kStringToBytes = "string_to_bytes";
  // to_string
  static constexpr absl::string_view kStringToString = "string_to_string";
  static constexpr absl::string_view kBytesToString = "bytes_to_string";
  static constexpr absl::string_view kBoolToString = "bool_to_string";
  static constexpr absl::string_view kDoubleToString = "double_to_string";
  static constexpr absl::string_view kIntToString = "int64_to_string";
  static constexpr absl::string_view kUintToString = "uint64_to_string";
  static constexpr absl::string_view kDurationToString = "duration_to_string";
  static constexpr absl::string_view kTimestampToString = "timestamp_to_string";
  // to_timestamp
  static constexpr absl::string_view kTimestampToTimestamp =
      "timestamp_to_timestamp";
  static constexpr absl::string_view kIntToTimestamp = "int64_to_timestamp";
  static constexpr absl::string_view kStringToTimestamp = "string_to_timestamp";
  // to_duration
  static constexpr absl::string_view kDurationToDuration =
      "duration_to_duration";
  static constexpr absl::string_view kIntToDuration = "int64_to_duration";
  static constexpr absl::string_view kStringToDuration = "string_to_duration";
  // to_type
  static constexpr absl::string_view kToType = "type";
};

} // namespace cel

#endif // THIRD_PARTY_CEL_CPP_COMMON_STANDARD_DEFINITIONS_H_
================================================
FILE: common/type.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/type.h"

#include <array>
#include <cstdint>
#include <string>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "common/type_kind.h"
#include "common/types/types.h"
#include "google/protobuf/descriptor.h"

namespace cel {

using ::google::protobuf::Descriptor;
using ::google::protobuf::FieldDescriptor;

// Returns the CEL type for the protobuf message described by `descriptor`.
//
// The mapping is:
// - Well-known wrapper messages map to the CEL wrapper types. The 32-bit and
//   64-bit integer wrappers share one wrapper type, as do the float and double
//   wrappers.
// - `Any`, `Duration` and `Timestamp` map to their dedicated CEL types.
// - The JSON messages `Value`, `ListValue` and `Struct` map to `DynType`, a
//   default `ListType` and `JsonMapType` respectively.
// - Every other message becomes a `MessageType` wrapping the descriptor.
Type Type::Message(const Descriptor* absl_nonnull descriptor) {
  switch (descriptor->well_known_type()) {
    case Descriptor::WELLKNOWNTYPE_BOOLVALUE:
      return BoolWrapperType();
    case Descriptor::WELLKNOWNTYPE_INT32VALUE:
      ABSL_FALLTHROUGH_INTENDED;
    case Descriptor::WELLKNOWNTYPE_INT64VALUE:
      return IntWrapperType();
    case Descriptor::WELLKNOWNTYPE_UINT32VALUE:
      ABSL_FALLTHROUGH_INTENDED;
    case Descriptor::WELLKNOWNTYPE_UINT64VALUE:
      return UintWrapperType();
    case Descriptor::WELLKNOWNTYPE_FLOATVALUE:
      ABSL_FALLTHROUGH_INTENDED;
    case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE:
      return DoubleWrapperType();
    case Descriptor::WELLKNOWNTYPE_BYTESVALUE:
      return BytesWrapperType();
    case Descriptor::WELLKNOWNTYPE_STRINGVALUE:
      return StringWrapperType();
    case Descriptor::WELLKNOWNTYPE_ANY:
      return AnyType();
    case Descriptor::WELLKNOWNTYPE_DURATION:
      return DurationType();
    case Descriptor::WELLKNOWNTYPE_TIMESTAMP:
      return TimestampType();
    case Descriptor::WELLKNOWNTYPE_VALUE:
      return DynType();
    case Descriptor::WELLKNOWNTYPE_LISTVALUE:
      return ListType();
    case Descriptor::WELLKNOWNTYPE_STRUCT:
      return JsonMapType();
    default:
      return MessageType(descriptor);
  }
}

// Returns the CEL type for the protobuf enum described by `descriptor`.
// `google.protobuf.NullValue` is CEL's `null_type`. Every other enum becomes
// an `EnumType` wrapping the descriptor.
Type Type::Enum(const google::protobuf::EnumDescriptor* absl_nonnull descriptor) {
  if (descriptor->full_name() == "google.protobuf.NullValue") {
    // Special case NullValue to prevent the embedder providing a different
    // descriptor for it and it leaking. Per the CEL spec, NullValue denotes
    // `null_type`, not `int`, so values and types agree on its meaning.
    return NullType();
  }
  return EnumType(descriptor);
}

namespace {

// Maps each alternative index of `common_internal::TypeVariant` to its
// `TypeKind`. Entries must follow the variant's declaration order; the
// static_assert below at least guarantees the counts match. Both struct
// representations (message and basic struct) report `kStruct`.
static constexpr std::array kTypeToKindArray = {
    TypeKind::kDyn,         TypeKind::kAny,           TypeKind::kBool,
    TypeKind::kBoolWrapper, TypeKind::kBytes,         TypeKind::kBytesWrapper,
    TypeKind::kDouble,      TypeKind::kDoubleWrapper, TypeKind::kDuration,
    TypeKind::kEnum,        TypeKind::kError,         TypeKind::kFunction,
    TypeKind::kInt,         TypeKind::kIntWrapper,    TypeKind::kList,
    TypeKind::kMap,         TypeKind::kNull,          TypeKind::kOpaque,
    TypeKind::kString,      TypeKind::kStringWrapper, TypeKind::kStruct,
    TypeKind::kStruct,      TypeKind::kTimestamp,     TypeKind::kTypeParam,
    TypeKind::kType,        TypeKind::kUint,          TypeKind::kUintWrapper,
    TypeKind::kUnknown};

static_assert(kTypeToKindArray.size() ==
                  absl::variant_size<common_internal::TypeVariant>(),
              "Kind indexer must match variant declaration for cel::Type.");

}  // namespace

// Returns the kind of the active alternative via a constant-time table lookup.
TypeKind Type::kind() const { return kTypeToKindArray[variant_.index()]; }

// Returns the name of the active alternative. The view is bound to the
// lifetime of this `Type`.
absl::string_view Type::name() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return absl::visit(
      [](const auto& alternative) -> absl::string_view {
        return alternative.name();
      },
      variant_);
}

// Returns the debug representation of the active alternative.
std::string Type::DebugString() const {
  return absl::visit(
      [](const auto& alternative) -> std::string {
        return alternative.DebugString();
      },
      variant_);
}

// Returns the type parameters of the active alternative.
TypeParameters Type::GetParameters() const {
  return absl::visit(
      [](const auto& alternative) -> TypeParameters {
        return alternative.GetParameters();
      },
      variant_);
}

// Struct types can be stored as either of two variant alternatives, so two
// structs are compared through `GetStruct()` rather than by variant index. A
// struct never equals a non-struct. All other types compare the variants
// directly.
bool operator==(const Type& lhs, const Type& rhs) {
  if (lhs.IsStruct() && rhs.IsStruct()) {
    return lhs.GetStruct() == rhs.GetStruct();
  } else if (lhs.IsStruct() || rhs.IsStruct()) {
    return false;
  } else {
    return lhs.variant_ == rhs.variant_;
  }
}

// Returns the active struct alternative (message or basic struct) wrapped in a
// `StructTypeVariant`. If this type is not a struct, returns a
// default-constructed `StructTypeVariant`.
common_internal::StructTypeVariant Type::ToStructTypeVariant() const {
  if (const auto* other = absl::get_if<MessageType>(&variant_);
      other != nullptr) {
    return common_internal::StructTypeVariant(*other);
  }
  if (const auto* other =
          absl::get_if<common_internal::BasicStructType>(&variant_);
      other != nullptr) {
    return common_internal::StructTypeVariant(*other);
  }
  return common_internal::StructTypeVariant();
}
namespace { template absl::optional GetOrNullopt(const common_internal::TypeVariant& variant) { if (const auto* alt = absl::get_if(&variant); alt != nullptr) { return *alt; } return absl::nullopt; } } // namespace absl::optional Type::AsAny() const { return GetOrNullopt(variant_); } absl::optional Type::AsBool() const { return GetOrNullopt(variant_); } absl::optional Type::AsBoolWrapper() const { return GetOrNullopt(variant_); } absl::optional Type::AsBytes() const { return GetOrNullopt(variant_); } absl::optional Type::AsBytesWrapper() const { return GetOrNullopt(variant_); } absl::optional Type::AsDouble() const { return GetOrNullopt(variant_); } absl::optional Type::AsDoubleWrapper() const { return GetOrNullopt(variant_); } absl::optional Type::AsDuration() const { return GetOrNullopt(variant_); } absl::optional Type::AsDyn() const { return GetOrNullopt(variant_); } absl::optional Type::AsEnum() const { return GetOrNullopt(variant_); } absl::optional Type::AsError() const { return GetOrNullopt(variant_); } absl::optional Type::AsFunction() const { return GetOrNullopt(variant_); } absl::optional Type::AsInt() const { return GetOrNullopt(variant_); } absl::optional Type::AsIntWrapper() const { return GetOrNullopt(variant_); } absl::optional Type::AsList() const { return GetOrNullopt(variant_); } absl::optional Type::AsMap() const { return GetOrNullopt(variant_); } absl::optional Type::AsMessage() const { return GetOrNullopt(variant_); } absl::optional Type::AsNull() const { return GetOrNullopt(variant_); } absl::optional Type::AsOpaque() const { return GetOrNullopt(variant_); } absl::optional Type::AsOptional() const { if (auto maybe_opaque = AsOpaque(); maybe_opaque.has_value()) { return maybe_opaque->AsOptional(); } return absl::nullopt; } absl::optional Type::AsString() const { return GetOrNullopt(variant_); } absl::optional Type::AsStringWrapper() const { return GetOrNullopt(variant_); } absl::optional Type::AsStruct() const { if (const auto* alt = 
absl::get_if(&variant_); alt != nullptr) { return *alt; } if (const auto* alt = absl::get_if(&variant_); alt != nullptr) { return *alt; } return absl::nullopt; } absl::optional Type::AsTimestamp() const { return GetOrNullopt(variant_); } absl::optional Type::AsTypeParam() const { return GetOrNullopt(variant_); } absl::optional Type::AsType() const { return GetOrNullopt(variant_); } absl::optional Type::AsUint() const { return GetOrNullopt(variant_); } absl::optional Type::AsUintWrapper() const { return GetOrNullopt(variant_); } absl::optional Type::AsUnknown() const { return GetOrNullopt(variant_); } namespace { template T GetOrDie(const common_internal::TypeVariant& variant) { return absl::get(variant); } } // namespace AnyType Type::GetAny() const { ABSL_DCHECK(IsAny()) << DebugString(); return GetOrDie(variant_); } BoolType Type::GetBool() const { ABSL_DCHECK(IsBool()) << DebugString(); return GetOrDie(variant_); } BoolWrapperType Type::GetBoolWrapper() const { ABSL_DCHECK(IsBoolWrapper()) << DebugString(); return GetOrDie(variant_); } BytesType Type::GetBytes() const { ABSL_DCHECK(IsBytes()) << DebugString(); return GetOrDie(variant_); } BytesWrapperType Type::GetBytesWrapper() const { ABSL_DCHECK(IsBytesWrapper()) << DebugString(); return GetOrDie(variant_); } DoubleType Type::GetDouble() const { ABSL_DCHECK(IsDouble()) << DebugString(); return GetOrDie(variant_); } DoubleWrapperType Type::GetDoubleWrapper() const { ABSL_DCHECK(IsDoubleWrapper()) << DebugString(); return GetOrDie(variant_); } DurationType Type::GetDuration() const { ABSL_DCHECK(IsDuration()) << DebugString(); return GetOrDie(variant_); } DynType Type::GetDyn() const { ABSL_DCHECK(IsDyn()) << DebugString(); return GetOrDie(variant_); } EnumType Type::GetEnum() const { ABSL_DCHECK(IsEnum()) << DebugString(); return GetOrDie(variant_); } ErrorType Type::GetError() const { ABSL_DCHECK(IsError()) << DebugString(); return GetOrDie(variant_); } FunctionType Type::GetFunction() const { 
ABSL_DCHECK(IsFunction()) << DebugString(); return GetOrDie(variant_); } IntType Type::GetInt() const { ABSL_DCHECK(IsInt()) << DebugString(); return GetOrDie(variant_); } IntWrapperType Type::GetIntWrapper() const { ABSL_DCHECK(IsIntWrapper()) << DebugString(); return GetOrDie(variant_); } ListType Type::GetList() const { ABSL_DCHECK(IsList()) << DebugString(); return GetOrDie(variant_); } MapType Type::GetMap() const { ABSL_DCHECK(IsMap()) << DebugString(); return GetOrDie(variant_); } MessageType Type::GetMessage() const { ABSL_DCHECK(IsMessage()) << DebugString(); return GetOrDie(variant_); } NullType Type::GetNull() const { ABSL_DCHECK(IsNull()) << DebugString(); return GetOrDie(variant_); } OpaqueType Type::GetOpaque() const { ABSL_DCHECK(IsOpaque()) << DebugString(); return GetOrDie(variant_); } OptionalType Type::GetOptional() const { ABSL_DCHECK(IsOptional()) << DebugString(); return GetOrDie(variant_).GetOptional(); } StringType Type::GetString() const { ABSL_DCHECK(IsString()) << DebugString(); return GetOrDie(variant_); } StringWrapperType Type::GetStringWrapper() const { ABSL_DCHECK(IsStringWrapper()) << DebugString(); return GetOrDie(variant_); } StructType Type::GetStruct() const { ABSL_DCHECK(IsStruct()) << DebugString(); if (const auto* alt = absl::get_if(&variant_); alt != nullptr) { return *alt; } if (const auto* alt = absl::get_if(&variant_); alt != nullptr) { return *alt; } return StructType(); } TimestampType Type::GetTimestamp() const { ABSL_DCHECK(IsTimestamp()) << DebugString(); return GetOrDie(variant_); } TypeParamType Type::GetTypeParam() const { ABSL_DCHECK(IsTypeParam()) << DebugString(); return GetOrDie(variant_); } TypeType Type::GetType() const { ABSL_DCHECK(IsType()) << DebugString(); return GetOrDie(variant_); } UintType Type::GetUint() const { ABSL_DCHECK(IsUint()) << DebugString(); return GetOrDie(variant_); } UintWrapperType Type::GetUintWrapper() const { ABSL_DCHECK(IsUintWrapper()) << DebugString(); return GetOrDie(variant_); 
} UnknownType Type::GetUnknown() const { ABSL_DCHECK(IsUnknown()) << DebugString(); return GetOrDie(variant_); } Type Type::Unwrap() const { switch (kind()) { case TypeKind::kBoolWrapper: return BoolType(); case TypeKind::kIntWrapper: return IntType(); case TypeKind::kUintWrapper: return UintType(); case TypeKind::kDoubleWrapper: return DoubleType(); case TypeKind::kBytesWrapper: return BytesType(); case TypeKind::kStringWrapper: return StringType(); default: return *this; } } Type Type::Wrap() const { switch (kind()) { case TypeKind::kBool: return BoolWrapperType(); case TypeKind::kInt: return IntWrapperType(); case TypeKind::kUint: return UintWrapperType(); case TypeKind::kDouble: return DoubleWrapperType(); case TypeKind::kBytes: return BytesWrapperType(); case TypeKind::kString: return StringWrapperType(); default: return *this; } } namespace common_internal { Type SingularMessageFieldType( const google::protobuf::FieldDescriptor* absl_nonnull descriptor) { ABSL_DCHECK(!descriptor->is_map()); switch (descriptor->type()) { case FieldDescriptor::TYPE_BOOL: return BoolType(); case FieldDescriptor::TYPE_SFIXED32: ABSL_FALLTHROUGH_INTENDED; case FieldDescriptor::TYPE_SINT32: ABSL_FALLTHROUGH_INTENDED; case FieldDescriptor::TYPE_INT32: ABSL_FALLTHROUGH_INTENDED; case FieldDescriptor::TYPE_SFIXED64: ABSL_FALLTHROUGH_INTENDED; case FieldDescriptor::TYPE_SINT64: ABSL_FALLTHROUGH_INTENDED; case FieldDescriptor::TYPE_INT64: return IntType(); case FieldDescriptor::TYPE_FIXED32: ABSL_FALLTHROUGH_INTENDED; case FieldDescriptor::TYPE_UINT32: ABSL_FALLTHROUGH_INTENDED; case FieldDescriptor::TYPE_FIXED64: ABSL_FALLTHROUGH_INTENDED; case FieldDescriptor::TYPE_UINT64: return UintType(); case FieldDescriptor::TYPE_FLOAT: ABSL_FALLTHROUGH_INTENDED; case FieldDescriptor::TYPE_DOUBLE: return DoubleType(); case FieldDescriptor::TYPE_BYTES: return BytesType(); case FieldDescriptor::TYPE_STRING: return StringType(); case FieldDescriptor::TYPE_GROUP: ABSL_FALLTHROUGH_INTENDED; case 
FieldDescriptor::TYPE_MESSAGE: return Type::Message(descriptor->message_type()); case FieldDescriptor::TYPE_ENUM: return Type::Enum(descriptor->enum_type()); default: return Type(); } } std::string BasicStructTypeField::DebugString() const { if (!name().empty() && number() >= 1) { return absl::StrCat("[", number(), "]", name()); } if (!name().empty()) { return std::string(name()); } if (number() >= 1) { return absl::StrCat(number()); } return std::string(); } } // namespace common_internal Type Type::Field(const google::protobuf::FieldDescriptor* absl_nonnull descriptor) { if (descriptor->is_map()) { return MapType(descriptor->message_type()); } if (descriptor->is_repeated()) { return ListType(descriptor); } return common_internal::SingularMessageFieldType(descriptor); } std::string StructTypeField::DebugString() const { return absl::visit( [](const auto& alternative) -> std::string { return alternative.DebugString(); }, variant_); } absl::string_view StructTypeField::name() const { return absl::visit( [](const auto& alternative) -> absl::string_view { return alternative.name(); }, variant_); } int32_t StructTypeField::number() const { return absl::visit( [](const auto& alternative) -> int32_t { return alternative.number(); }, variant_); } Type StructTypeField::GetType() const { return absl::visit( [](const auto& alternative) -> Type { return alternative.GetType(); }, variant_); } StructTypeField::operator bool() const { return absl::visit( [](const auto& alternative) -> bool { return static_cast(alternative); }, variant_); } absl::optional StructTypeField::AsMessage() const { if (const auto* alternative = absl::get_if(&variant_); alternative != nullptr) { return *alternative; } return absl::nullopt; } StructTypeField::operator MessageTypeField() const { ABSL_DCHECK(IsMessage()); return absl::get(variant_); } TypeParameters::TypeParameters(absl::Span types) : size_(types.size()) { if (size_ <= 2) { std::memcpy(&internal_[0], types.data(), size_ * sizeof(Type)); } 
else { external_ = types.data(); } } TypeParameters::TypeParameters(const Type& element) : size_(1) { std::memcpy(&internal_[0], &element, sizeof(element)); } TypeParameters::TypeParameters(const Type& key, const Type& value) : size_(2) { std::memcpy(&internal_[0], &key, sizeof(key)); std::memcpy(&internal_[0] + sizeof(key), &value, sizeof(value)); } namespace common_internal { namespace { constexpr absl::string_view kNullTypeName = "null_type"; constexpr absl::string_view kBoolTypeName = "bool"; constexpr absl::string_view kInt64TypeName = "int"; constexpr absl::string_view kUInt64TypeName = "uint"; constexpr absl::string_view kDoubleTypeName = "double"; constexpr absl::string_view kStringTypeName = "string"; constexpr absl::string_view kBytesTypeName = "bytes"; constexpr absl::string_view kDurationTypeName = "google.protobuf.Duration"; constexpr absl::string_view kTimestampTypeName = "google.protobuf.Timestamp"; constexpr absl::string_view kListTypeName = "list"; constexpr absl::string_view kMapTypeName = "map"; constexpr absl::string_view kCelTypeTypeName = "type"; } // namespace Type LegacyRuntimeType(absl::string_view name) { if (name == kNullTypeName) { return NullType{}; } if (name == kBoolTypeName) { return BoolType{}; } if (name == kInt64TypeName) { return IntType{}; } if (name == kUInt64TypeName) { return UintType{}; } if (name == kDoubleTypeName) { return DoubleType{}; } if (name == kStringTypeName) { return StringType{}; } if (name == kBytesTypeName) { return BytesType{}; } if (name == kDurationTypeName) { return DurationType{}; } if (name == kTimestampTypeName) { return TimestampType{}; } if (name == kListTypeName) { return ListType{}; } if (name == kMapTypeName) { return MapType{}; } if (name == kCelTypeTypeName) { return TypeType{}; } return common_internal::MakeBasicStructType(name); } } // namespace common_internal } // namespace cel ================================================ FILE: common/type.h 
================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPE_H_ #include #include #include #include #include #include #include #include #include #include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/meta/type_traits.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "absl/types/variant.h" #include "absl/utility/utility.h" #include "common/type_kind.h" #include "common/types/any_type.h" // IWYU pragma: export #include "common/types/bool_type.h" // IWYU pragma: export #include "common/types/bool_wrapper_type.h" // IWYU pragma: export #include "common/types/bytes_type.h" // IWYU pragma: export #include "common/types/bytes_wrapper_type.h" // IWYU pragma: export #include "common/types/double_type.h" // IWYU pragma: export #include "common/types/double_wrapper_type.h" // IWYU pragma: export #include "common/types/duration_type.h" // IWYU pragma: export #include "common/types/dyn_type.h" // IWYU pragma: export #include "common/types/enum_type.h" // IWYU pragma: export #include "common/types/error_type.h" // IWYU pragma: export #include "common/types/function_type.h" // IWYU pragma: export #include "common/types/int_type.h" // IWYU pragma: export #include 
"common/types/int_wrapper_type.h" // IWYU pragma: export #include "common/types/list_type.h" // IWYU pragma: export #include "common/types/map_type.h" // IWYU pragma: export #include "common/types/message_type.h" // IWYU pragma: export #include "common/types/null_type.h" // IWYU pragma: export #include "common/types/opaque_type.h" // IWYU pragma: export #include "common/types/optional_type.h" // IWYU pragma: export #include "common/types/string_type.h" // IWYU pragma: export #include "common/types/string_wrapper_type.h" // IWYU pragma: export #include "common/types/struct_type.h" // IWYU pragma: export #include "common/types/timestamp_type.h" // IWYU pragma: export #include "common/types/type_param_type.h" // IWYU pragma: export #include "common/types/type_type.h" // IWYU pragma: export #include "common/types/types.h" #include "common/types/uint_type.h" // IWYU pragma: export #include "common/types/uint_wrapper_type.h" // IWYU pragma: export #include "common/types/unknown_type.h" // IWYU pragma: export #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { class Type; class TypeParameters; // `Type` is a composition type which encompasses all types supported by the // Common Expression Language. When default constructed, `Type` is in a // known but invalid state. Any attempt to use it from then on, without // assigning another type, is undefined behavior. In debug builds, we do our // best to fail. // // The data underlying `Type` is either static or owned by `google::protobuf::Arena`. As // such, care must be taken to ensure types remain valid throughout their use. class Type final { public: // Returns an appropriate `Type` for the dynamic protobuf message. For well // known message types, the appropriate `Type` is returned. All others return // `MessageType`. 
static Type Message(const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); // Returns an appropriate `Type` for the dynamic protobuf message field. static Type Field(const google::protobuf::FieldDescriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); // Returns an appropriate `Type` for the dynamic protobuf enum. For well // known enum types, the appropriate `Type` is returned. All others return // `EnumType`. static Type Enum(const google::protobuf::EnumDescriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); using Parameters = TypeParameters; // The default constructor results in Type being DynType. Type() = default; Type(const Type&) = default; Type(Type&&) = default; Type& operator=(const Type&) = default; Type& operator=(Type&&) = default; template >>> // NOLINTNEXTLINE(google-explicit-constructor) constexpr Type(T&& alternative) noexcept : variant_(absl::in_place_type>, std::forward(alternative)) {} template >>> // NOLINTNEXTLINE(google-explicit-constructor) Type& operator=(T&& type) noexcept { variant_.emplace>(std::forward(type)); return *this; } // NOLINTNEXTLINE(google-explicit-constructor) Type(StructType alternative) : variant_(alternative.ToTypeVariant()) {} // NOLINTNEXTLINE(google-explicit-constructor) Type& operator=(StructType alternative) { variant_ = alternative.ToTypeVariant(); return *this; } // NOLINTNEXTLINE(google-explicit-constructor) Type(OptionalType alternative) : Type(OpaqueType(std::move(alternative))) {} // NOLINTNEXTLINE(google-explicit-constructor) Type& operator=(OptionalType alternative) { return *this = OpaqueType(std::move(alternative)); } TypeKind kind() const; absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND; // Returns a debug string for the type. Not suitable for user-facing error // messages. 
std::string DebugString() const; Parameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND; template friend H AbslHashValue(H state, const Type& type) { return absl::visit( [state = std::move(state)](const auto& alternative) mutable -> H { return H::combine(std::move(state), alternative, alternative.kind()); }, type.variant_); } friend bool operator==(const Type& lhs, const Type& rhs); friend std::ostream& operator<<(std::ostream& out, const Type& type) { return absl::visit( [&out](const auto& alternative) -> std::ostream& { return out << alternative; }, type.variant_); } bool IsAny() const { return absl::holds_alternative(variant_); } bool IsBool() const { return absl::holds_alternative(variant_); } bool IsBoolWrapper() const { return absl::holds_alternative(variant_); } bool IsBytes() const { return absl::holds_alternative(variant_); } bool IsBytesWrapper() const { return absl::holds_alternative(variant_); } bool IsDouble() const { return absl::holds_alternative(variant_); } bool IsDoubleWrapper() const { return absl::holds_alternative(variant_); } bool IsDuration() const { return absl::holds_alternative(variant_); } bool IsDyn() const { return absl::holds_alternative(variant_); } bool IsEnum() const { return absl::holds_alternative(variant_); } bool IsError() const { return absl::holds_alternative(variant_); } bool IsFunction() const { return absl::holds_alternative(variant_); } bool IsInt() const { return absl::holds_alternative(variant_); } bool IsIntWrapper() const { return absl::holds_alternative(variant_); } bool IsList() const { return absl::holds_alternative(variant_); } bool IsMap() const { return absl::holds_alternative(variant_); } bool IsMessage() const { return absl::holds_alternative(variant_); } bool IsNull() const { return absl::holds_alternative(variant_); } bool IsOpaque() const { return absl::holds_alternative(variant_); } bool IsOptional() const { return IsOpaque() && GetOpaque().IsOptional(); } bool IsString() const { return 
absl::holds_alternative(variant_); } bool IsStringWrapper() const { return absl::holds_alternative(variant_); } bool IsStruct() const { return absl::holds_alternative( variant_) || absl::holds_alternative(variant_); } bool IsTimestamp() const { return absl::holds_alternative(variant_); } bool IsTypeParam() const { return absl::holds_alternative(variant_); } bool IsType() const { return absl::holds_alternative(variant_); } bool IsUint() const { return absl::holds_alternative(variant_); } bool IsUintWrapper() const { return absl::holds_alternative(variant_); } bool IsUnknown() const { return absl::holds_alternative(variant_); } bool IsWrapper() const { return IsBoolWrapper() || IsIntWrapper() || IsUintWrapper() || IsDoubleWrapper() || IsBytesWrapper() || IsStringWrapper(); } template std::enable_if_t, bool> Is() const { return IsAny(); } template std::enable_if_t, bool> Is() const { return IsBool(); } template std::enable_if_t, bool> Is() const { return IsBoolWrapper(); } template std::enable_if_t, bool> Is() const { return IsBytes(); } template std::enable_if_t, bool> Is() const { return IsBytesWrapper(); } template std::enable_if_t, bool> Is() const { return IsDouble(); } template std::enable_if_t, bool> Is() const { return IsDoubleWrapper(); } template std::enable_if_t, bool> Is() const { return IsDuration(); } template std::enable_if_t, bool> Is() const { return IsDyn(); } template std::enable_if_t, bool> Is() const { return IsEnum(); } template std::enable_if_t, bool> Is() const { return IsError(); } template std::enable_if_t, bool> Is() const { return IsFunction(); } template std::enable_if_t, bool> Is() const { return IsInt(); } template std::enable_if_t, bool> Is() const { return IsIntWrapper(); } template std::enable_if_t, bool> Is() const { return IsList(); } template std::enable_if_t, bool> Is() const { return IsMap(); } template std::enable_if_t, bool> Is() const { return IsMessage(); } template std::enable_if_t, bool> Is() const { return IsNull(); } 
template std::enable_if_t, bool> Is() const { return IsOpaque(); } template std::enable_if_t, bool> Is() const { return IsOptional(); } template std::enable_if_t, bool> Is() const { return IsString(); } template std::enable_if_t, bool> Is() const { return IsStringWrapper(); } template std::enable_if_t, bool> Is() const { return IsStruct(); } template std::enable_if_t, bool> Is() const { return IsTimestamp(); } template std::enable_if_t, bool> Is() const { return IsTypeParam(); } template std::enable_if_t, bool> Is() const { return IsType(); } template std::enable_if_t, bool> Is() const { return IsUint(); } template std::enable_if_t, bool> Is() const { return IsUintWrapper(); } template std::enable_if_t, bool> Is() const { return IsUnknown(); } absl::optional AsAny() const; absl::optional AsBool() const; absl::optional AsBoolWrapper() const; absl::optional AsBytes() const; absl::optional AsBytesWrapper() const; absl::optional AsDouble() const; absl::optional AsDoubleWrapper() const; absl::optional AsDuration() const; absl::optional AsDyn() const; absl::optional AsEnum() const; absl::optional AsError() const; absl::optional AsFunction() const; absl::optional AsInt() const; absl::optional AsIntWrapper() const; absl::optional AsList() const; absl::optional AsMap() const; // AsMessage performs a checked cast, returning `MessageType` if this type is // both a struct and a message or `absl::nullopt` otherwise. If you have // already called `IsMessage()` it is more performant to perform to do // `static_cast(type)`. absl::optional AsMessage() const; absl::optional AsNull() const; absl::optional AsOpaque() const; absl::optional AsOptional() const; absl::optional AsString() const; absl::optional AsStringWrapper() const; // AsStruct performs a checked cast, returning `StructType` if this type is a // struct or `absl::nullopt` otherwise. If you have already called // `IsStruct()` it is more performant to perform to do // `static_cast(type)`. 
absl::optional AsStruct() const; absl::optional AsTimestamp() const; absl::optional AsTypeParam() const; absl::optional AsType() const; absl::optional AsUint() const; absl::optional AsUintWrapper() const; absl::optional AsUnknown() const; template std::enable_if_t, absl::optional> As() const { return AsAny(); } template std::enable_if_t, absl::optional> As() const { return AsBool(); } template std::enable_if_t, absl::optional> As() const { return AsBoolWrapper(); } template std::enable_if_t, absl::optional> As() const { return AsBytes(); } template std::enable_if_t, absl::optional> As() const { return AsBytesWrapper(); } template std::enable_if_t, absl::optional> As() const { return AsDouble(); } template std::enable_if_t, absl::optional> As() const { return AsDoubleWrapper(); } template std::enable_if_t, absl::optional> As() const { return AsDuration(); } template std::enable_if_t, absl::optional> As() const { return AsDyn(); } template std::enable_if_t, absl::optional> As() const { return AsEnum(); } template std::enable_if_t, absl::optional> As() const { return AsError(); } template std::enable_if_t, absl::optional> As() const { return AsFunction(); } template std::enable_if_t, absl::optional> As() const { return AsInt(); } template std::enable_if_t, absl::optional> As() const { return AsIntWrapper(); } template std::enable_if_t, absl::optional> As() const { return AsList(); } template std::enable_if_t, absl::optional> As() const { return AsMap(); } template std::enable_if_t, absl::optional> As() const { return AsMessage(); } template std::enable_if_t, absl::optional> As() const { return AsNull(); } template std::enable_if_t, absl::optional> As() const { return AsOpaque(); } template std::enable_if_t, absl::optional> As() const { return AsOptional(); } template std::enable_if_t, absl::optional> As() const { return AsString(); } template std::enable_if_t, absl::optional> As() const { return AsStringWrapper(); } template std::enable_if_t, absl::optional> As() 
const { return AsStruct(); } template std::enable_if_t, absl::optional> As() const { return AsTimestamp(); } template std::enable_if_t, absl::optional> As() const { return AsTypeParam(); } template std::enable_if_t, absl::optional> As() const { return AsType(); } template std::enable_if_t, absl::optional> As() const { return AsUint(); } template std::enable_if_t, absl::optional> As() const { return AsUintWrapper(); } template std::enable_if_t, absl::optional> As() const { return AsUnknown(); } AnyType GetAny() const; BoolType GetBool() const; BoolWrapperType GetBoolWrapper() const; BytesType GetBytes() const; BytesWrapperType GetBytesWrapper() const; DoubleType GetDouble() const; DoubleWrapperType GetDoubleWrapper() const; DurationType GetDuration() const; DynType GetDyn() const; EnumType GetEnum() const; ErrorType GetError() const; FunctionType GetFunction() const; IntType GetInt() const; IntWrapperType GetIntWrapper() const; ListType GetList() const; MapType GetMap() const; MessageType GetMessage() const; NullType GetNull() const; OpaqueType GetOpaque() const; OptionalType GetOptional() const; StringType GetString() const; StringWrapperType GetStringWrapper() const; StructType GetStruct() const; TimestampType GetTimestamp() const; TypeParamType GetTypeParam() const; TypeType GetType() const; UintType GetUint() const; UintWrapperType GetUintWrapper() const; UnknownType GetUnknown() const; template std::enable_if_t, AnyType> Get() const { return GetAny(); } template std::enable_if_t, BoolType> Get() const { return GetBool(); } template std::enable_if_t, BoolWrapperType> Get() const { return GetBoolWrapper(); } template std::enable_if_t, BytesType> Get() const { return GetBytes(); } template std::enable_if_t, BytesWrapperType> Get() const { return GetBytesWrapper(); } template std::enable_if_t, DoubleType> Get() const { return GetDouble(); } template std::enable_if_t, DoubleWrapperType> Get() const { return GetDoubleWrapper(); } template std::enable_if_t, 
DurationType> Get() const { return GetDuration(); } template std::enable_if_t, DynType> Get() const { return GetDyn(); } template std::enable_if_t, EnumType> Get() const { return GetEnum(); } template std::enable_if_t, ErrorType> Get() const { return GetError(); } template std::enable_if_t, FunctionType> Get() const { return GetFunction(); } template std::enable_if_t, IntType> Get() const { return GetInt(); } template std::enable_if_t, IntWrapperType> Get() const { return GetIntWrapper(); } template std::enable_if_t, ListType> Get() const { return GetList(); } template std::enable_if_t, MapType> Get() const { return GetMap(); } template std::enable_if_t, MessageType> Get() const { return GetMessage(); } template std::enable_if_t, NullType> Get() const { return GetNull(); } template std::enable_if_t, OpaqueType> Get() const { return GetOpaque(); } template std::enable_if_t, OptionalType> Get() const { return GetOptional(); } template std::enable_if_t, StringType> Get() const { return GetString(); } template std::enable_if_t, StringWrapperType> Get() const { return GetStringWrapper(); } template std::enable_if_t, StructType> Get() const { return GetStruct(); } template std::enable_if_t, TimestampType> Get() const { return GetTimestamp(); } template std::enable_if_t, TypeParamType> Get() const { return GetTypeParam(); } template std::enable_if_t, TypeType> Get() const { return GetType(); } template std::enable_if_t, UintType> Get() const { return GetUint(); } template std::enable_if_t, UintWrapperType> Get() const { return GetUintWrapper(); } template std::enable_if_t, UnknownType> Get() const { return GetUnknown(); } // Returns an unwrapped `Type` for a wrapped type, otherwise just returns // this. Type Unwrap() const; // Returns an wrapped `Type` for a primitive type, otherwise just returns // this. 
Type Wrap() const; private: friend class StructType; friend class MessageType; friend class common_internal::BasicStructType; common_internal::StructTypeVariant ToStructTypeVariant() const; common_internal::TypeVariant variant_; }; inline bool operator!=(const Type& lhs, const Type& rhs) { return !operator==(lhs, rhs); } inline Type JsonType() { return DynType(); } // Statically assert some expectations. static_assert(std::is_default_constructible_v); static_assert(std::is_copy_constructible_v); static_assert(std::is_copy_assignable_v); static_assert(std::is_nothrow_move_constructible_v); static_assert(std::is_nothrow_move_assignable_v); // TypeParameters is a specialized view of a contiguous list of `Type`. It is // very similar to `absl::Span`, except that it has a small amount // of inline storage. Thus the pointers and references returned by // TypeParameters are invalidated upon copying or moving. // // We store up to 2 types inline. This is done to accommodate list and map types // which correspond to protocol buffer message fields. We launder around their // descriptors and would have to allocate to return the type parameters. We want // to avoid this, as types are supposed to be constant after creation. 
class TypeParameters final {
 public:
  // Read-only container member types. Elements are always exposed as
  // `const Type`, so this type can only be used to observe parameters.
  using element_type = const Type;
  using value_type = Type;
  using pointer = element_type*;
  using const_pointer = const element_type*;
  using reference = element_type&;
  using const_reference = const element_type&;
  using iterator = pointer;
  using const_iterator = const_pointer;
  using reverse_iterator = std::reverse_iterator;
  using const_reverse_iterator = std::reverse_iterator;
  using size_type = size_t;
  using difference_type = ptrdiff_t;

  // Builds a view over `types`. Defined out of line. Because `data()` reads
  // from the inline buffer whenever `size() <= 2`, an input of at most two
  // types must end up held in `internal_`; larger inputs are read through
  // `external_`, so their storage must outlive this object.
  explicit TypeParameters(absl::Span types);

  // An empty parameter list (`size_ == 0`, `external_ == nullptr`).
  TypeParameters() = default;

  // Copies and moves are defaulted. The union members are a pointer and a char
  // buffer, so this copies the inline storage bytewise.
  // NOTE(review): this presumes `Type` is safe to copy bytewise — confirm.
  // Pointers, references and iterators obtained from an instance may point
  // into that instance's own `internal_` buffer, so they must not outlive it.
  TypeParameters(const TypeParameters&) = default;
  TypeParameters(TypeParameters&&) = default;
  TypeParameters& operator=(const TypeParameters&) = default;
  TypeParameters& operator=(TypeParameters&&) = default;

  // Number of type parameters.
  size_type size() const { return size_; }

  // True when there are no type parameters.
  bool empty() const { return size() == 0; }

  // First parameter. Requires `!empty()` (checked only by `ABSL_DCHECK`).
  const_reference front() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(!empty());
    return data()[0];
  }

  // Last parameter. Requires `!empty()` (checked only by `ABSL_DCHECK`).
  const_reference back() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(!empty());
    return data()[size() - 1];
  }

  // Parameter at `index`. Requires `index < size()` (checked only by
  // `ABSL_DCHECK_LT`).
  const_reference operator[](size_type index) const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK_LT(index, size());
    return data()[index];
  }

  // Pointer to the first parameter: the inline buffer for up to two
  // parameters (including zero), otherwise the external pointer.
  const_pointer data() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return size() <= 2 ? reinterpret_cast(&internal_[0]) : external_;
  }

  // Contiguous iteration over `[data(), data() + size())`.
  const_iterator begin() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return data(); }

  const_iterator cbegin() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return begin();
  }

  const_iterator end() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return data() + size();
  }

  const_iterator cend() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return end(); }

  const_reverse_iterator rbegin() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return std::make_reverse_iterator(end());
  }

  const_reverse_iterator crbegin() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return rbegin();
  }

  const_reverse_iterator rend() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return std::make_reverse_iterator(begin());
  }

  const_reverse_iterator crend() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return rend();
  }

 private:
  // `ListType` and `MapType` use the private one- and two-element
  // constructors below; presumably these place the element (or key and value)
  // into the inline buffer — defined out of line, confirm there.
  friend class ListType;
  friend class MapType;

  explicit TypeParameters(const Type& element);

  explicit TypeParameters(const Type& key, const Type& value);

  // When size_ <= 2, elements are stored directly in `internal_`. Otherwise we
  // store a pointer to the elements in `external_`.
  size_t size_ = 0;
  union {
    const Type* external_ = nullptr;
    // Old versions of GCC do not like `Type internal_[2]`, so we cheat: a raw,
    // `Type`-aligned byte buffer large enough for two `Type` objects.
    alignas(Type) char internal_[sizeof(Type) * 2];
  };
};

// Now that TypeParameters is defined, we can define `GetParameters()` for most
// types.
// Types that take no type parameters return an empty `TypeParameters`.
inline TypeParameters AnyType::GetParameters() { return {}; }
inline TypeParameters BoolType::GetParameters() { return {}; }
inline TypeParameters BoolWrapperType::GetParameters() { return {}; }
inline TypeParameters BytesType::GetParameters() { return {}; }
inline TypeParameters BytesWrapperType::GetParameters() { return {}; }
inline TypeParameters DoubleType::GetParameters() { return {}; }
inline TypeParameters DoubleWrapperType::GetParameters() { return {}; }
inline TypeParameters DurationType::GetParameters() { return {}; }
inline TypeParameters DynType::GetParameters() { return {}; }
inline TypeParameters EnumType::GetParameters() { return {}; }
inline TypeParameters ErrorType::GetParameters() { return {}; }
inline TypeParameters IntType::GetParameters() { return {}; }
inline TypeParameters IntWrapperType::GetParameters() { return {}; }
inline TypeParameters MessageType::GetParameters() { return {}; }
inline TypeParameters NullType::GetParameters() { return {}; }

// `OptionalType` forwards to the opaque type it wraps; the returned view is
// lifetime-bound to this `OptionalType`.
inline TypeParameters OptionalType::GetParameters() const
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return opaque_.GetParameters();
}

inline TypeParameters StringType::GetParameters() { return {}; }
inline TypeParameters StringWrapperType::GetParameters() { return {}; }
inline TypeParameters TimestampType::GetParameters() { return {}; }
inline TypeParameters TypeParamType::GetParameters() { return {}; }
inline TypeParameters UintType::GetParameters() { return {}; }
inline TypeParameters UintWrapperType::GetParameters() { return {}; }
inline TypeParameters UnknownType::GetParameters() { return {}; }

namespace common_internal {

inline TypeParameters BasicStructType::GetParameters() { return {}; }

// Presumably maps the singular (non-repeated) form of the protobuf field
// `descriptor` to its CEL `Type`. Defined out of line — confirm there.
Type SingularMessageFieldType(
    const google::protobuf::FieldDescriptor* absl_nonnull descriptor);

// A struct field described directly by name, number and type, without a
// protobuf descriptor.
class BasicStructTypeField final {
 public:
  // `name` is stored as a non-owning `absl::string_view`; the characters it
  // refers to must outlive this field and all of its copies.
  BasicStructTypeField(absl::string_view name, int32_t number, Type type)
      : name_(name), number_(number), type_(type) {}

  BasicStructTypeField(const BasicStructTypeField&) = default;
  BasicStructTypeField(BasicStructTypeField&&) = default;
  BasicStructTypeField& operator=(const BasicStructTypeField&) = default;
  BasicStructTypeField& operator=(BasicStructTypeField&&) = default;

  std::string DebugString() const;

  absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return name_;
  }

  int32_t number() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return number_; }

  Type GetType() const { return type_; }

  // True when the field has a non-empty name or a positive field number.
  explicit operator bool() const { return !name_.empty() || number_ >= 1; }

 private:
  absl::string_view name_;
  int32_t number_ = 0;
  Type type_;
};

// Fields are equal when name, number and type all match.
inline bool operator==(const BasicStructTypeField& lhs,
                       const BasicStructTypeField& rhs) {
  return lhs.name() == rhs.name() && lhs.number() == rhs.number() &&
         lhs.GetType() == rhs.GetType();
}

inline bool operator!=(const BasicStructTypeField& lhs,
                       const BasicStructTypeField& rhs) {
  return !operator==(lhs, rhs);
}

}  // namespace common_internal

// A struct field that holds either a `common_internal::BasicStructTypeField`
// or a `MessageTypeField` (see the two converting constructors).
class StructTypeField final {
 public:
  // NOLINTNEXTLINE(google-explicit-constructor)
  StructTypeField(common_internal::BasicStructTypeField field)
      : variant_(absl::in_place_type, field) {}

  // NOLINTNEXTLINE(google-explicit-constructor)
  StructTypeField(MessageTypeField field)
      : variant_(absl::in_place_type, field) {}

  StructTypeField() = delete;
  StructTypeField(const StructTypeField&) = default;
  StructTypeField(StructTypeField&&) = default;
  StructTypeField& operator=(const StructTypeField&) = default;
  StructTypeField& operator=(StructTypeField&&) = default;

  std::string DebugString() const;

  absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND;

  int32_t number() const ABSL_ATTRIBUTE_LIFETIME_BOUND;

  Type GetType() const;

  explicit operator bool() const;

  // Presumably true when the held alternative is the `MessageTypeField` (the
  // template argument of `holds_alternative` is not visible here).
  bool IsMessage() const { return absl::holds_alternative(variant_); }

  absl::optional AsMessage() const;

  explicit operator MessageTypeField() const;

 private:
  absl::variant variant_;
};

// Fields are equal when name, number and type all match, regardless of which
// alternative each one holds.
inline bool operator==(const StructTypeField& lhs, const StructTypeField& rhs) {
  return lhs.name() == rhs.name() && lhs.number() == rhs.number() &&
         lhs.GetType() == rhs.GetType();
}

inline bool operator!=(const StructTypeField& lhs, const StructTypeField& rhs) {
  return !operator==(lhs, rhs);
}

// Now that Type is defined, we can define everything else.

namespace common_internal {

// Arena-allocated payload of a `ListType`.
struct ListTypeData final {
  static ListTypeData* absl_nonnull Create(google::protobuf::Arena* absl_nonnull arena,
                                           const Type& element);

  ListTypeData() = default;
  ListTypeData(const ListTypeData&) = delete;
  ListTypeData(ListTypeData&&) = delete;
  ListTypeData& operator=(const ListTypeData&) = delete;
  ListTypeData& operator=(ListTypeData&&) = delete;

  // Element type; a default-constructed payload describes a list of `dyn`.
  Type element = DynType();

 private:
  explicit ListTypeData(const Type& element);
};

// Arena-allocated payload of a `MapType`: `key_and_value[0]` is the key type
// and `key_and_value[1]` the value type (order per the `Create` parameters).
struct MapTypeData final {
  static MapTypeData* absl_nonnull Create(google::protobuf::Arena* absl_nonnull arena,
                                          const Type& key, const Type& value);

  Type key_and_value[2];
};

// Arena-allocated, variable-length payload of a `FunctionType`.
struct FunctionTypeData final {
  static FunctionTypeData* absl_nonnull Create(
      google::protobuf::Arena* absl_nonnull arena, const Type& result,
      absl::Span args);

  FunctionTypeData() = delete;
  FunctionTypeData(const FunctionTypeData&) = delete;
  FunctionTypeData(FunctionTypeData&&) = delete;
  FunctionTypeData& operator=(const FunctionTypeData&) = delete;
  FunctionTypeData& operator=(FunctionTypeData&&) = delete;

  const size_t args_size;
  // Flexible array, has `args_size` elements, with the first element being the
  // return type. FunctionTypeData has a variable length size, which includes
  // this flexible array.
  Type args[];

 private:
  FunctionTypeData(const Type& result, absl::Span args);
};

// Arena-allocated, variable-length payload of an `OpaqueType`.
struct OpaqueTypeData final {
  static OpaqueTypeData* absl_nonnull Create(google::protobuf::Arena* absl_nonnull arena,
                                             absl::string_view name,
                                             absl::Span parameters);

  OpaqueTypeData() = delete;
  OpaqueTypeData(const OpaqueTypeData&) = delete;
  OpaqueTypeData(OpaqueTypeData&&) = delete;
  OpaqueTypeData& operator=(const OpaqueTypeData&) = delete;
  OpaqueTypeData& operator=(OpaqueTypeData&&) = delete;

  const absl::string_view name;
  const size_t parameters_size;
  // Flexible array, has `parameters_size` elements. OpaqueTypeData has a
  // variable length size, which includes this flexible array.
  Type parameters[];

 private:
  OpaqueTypeData(absl::string_view name, absl::Span parameters);
};

}  // namespace common_internal

// Message fields are equal when name, number and type all match.
inline bool operator==(const MessageTypeField& lhs,
                       const MessageTypeField& rhs) {
  return lhs.name() == rhs.name() && lhs.number() == rhs.number() &&
         lhs.GetType() == rhs.GetType();
}

inline bool operator!=(const MessageTypeField& lhs,
                       const MessageTypeField& rhs) {
  return !operator==(lhs, rhs);
}

// List types are equal when they are the same object or have equal element
// types.
inline bool operator==(const ListType& lhs, const ListType& rhs) {
  return &lhs == &rhs || lhs.GetElement() == rhs.GetElement();
}

// Hashes the element type followed by the parameter count (1).
template
inline H AbslHashValue(H state, const ListType& type) {
  return H::combine(std::move(state), type.GetElement(), size_t{1});
}

// Map types are equal when they are the same object or have equal key and
// value types.
inline bool operator==(const MapType& lhs, const MapType& rhs) {
  return &lhs == &rhs ||
         (lhs.GetKey() == rhs.GetKey() && lhs.GetValue() == rhs.GetValue());
}

// Hashes the key and value types followed by the parameter count (2).
template
inline H AbslHashValue(H state, const MapType& type) {
  return H::combine(std::move(state), type.GetKey(), type.GetValue(),
                    size_t{2});
}

// Opaque types are equal when names match and parameters are element-wise
// equal.
inline bool operator==(const OpaqueType& lhs, const OpaqueType& rhs) {
  return lhs.name() == rhs.name() &&
         absl::c_equal(lhs.GetParameters(), rhs.GetParameters());
}

// Hashes the name, each parameter in order, then the parameter count.
template
inline H AbslHashValue(H state, const OpaqueType& type) {
  state = H::combine(std::move(state), type.name());
  auto parameters = type.GetParameters();
  for (const auto& parameter : parameters) {
    state = H::combine(std::move(state), parameter);
  }
  return H::combine(std::move(state), parameters.size());
}

// Function types are equal when result types match and argument types are
// element-wise equal.
inline bool operator==(const FunctionType& lhs, const FunctionType& rhs) {
  return lhs.result() == rhs.result() && absl::c_equal(lhs.args(), rhs.args());
}

// Hashes the result type, each argument type in order, then the argument
// count.
template
inline H AbslHashValue(H state, const FunctionType& type) {
  state = H::combine(std::move(state), type.result());
  auto args = type.args();
  for (const auto& arg : args) {
    state = H::combine(std::move(state), arg);
  }
  return H::combine(std::move(state), args.size());
}

namespace common_internal {

// Converts the string returned from `CelValue::CelTypeHolder` to `cel::Type`.
// The underlying content of `name` must outlive the resulting type and any of
// its shallow copies.
Type LegacyRuntimeType(absl::string_view name);

}  // namespace common_internal

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_TYPE_H_
================================================
FILE: common/type_introspector.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/type_introspector.h"

#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <vector>

#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/type.h"

namespace cel {

namespace {

// Adapts the (name, type, number) argument order used by the table below to
// `BasicStructTypeField`'s (name, number, type) constructor.
common_internal::BasicStructTypeField MakeBasicStructTypeField(
    absl::string_view name, Type type, int32_t number) {
  return common_internal::BasicStructTypeField(name, number, type);
}

// Transparent "less than" on field names. Usable both for sorting fields and
// for `std::lower_bound` lookups by a bare `absl::string_view`.
struct FieldNameComparer {
  using is_transparent = void;

  bool operator()(const common_internal::BasicStructTypeField& lhs,
                  const common_internal::BasicStructTypeField& rhs) const {
    return (*this)(lhs.name(), rhs.name());
  }

  bool operator()(const common_internal::BasicStructTypeField& lhs,
                  absl::string_view rhs) const {
    return (*this)(lhs.name(), rhs);
  }

  bool operator()(absl::string_view lhs,
                  const common_internal::BasicStructTypeField& rhs) const {
    return (*this)(lhs, rhs.name());
  }

  bool operator()(absl::string_view lhs, absl::string_view rhs) const {
    return lhs < rhs;
  }
};

// Transparent "less than" on field numbers. Usable both for sorting fields and
// for `std::lower_bound` lookups by a bare number.
struct FieldNumberComparer {
  using is_transparent = void;

  bool operator()(const common_internal::BasicStructTypeField& lhs,
                  const common_internal::BasicStructTypeField& rhs) const {
    return (*this)(lhs.number(), rhs.number());
  }

  bool operator()(const common_internal::BasicStructTypeField& lhs,
                  int64_t rhs) const {
    return (*this)(lhs.number(), rhs);
  }

  bool operator()(int64_t lhs,
                  const common_internal::BasicStructTypeField& rhs) const {
    return (*this)(lhs, rhs.number());
  }

  bool operator()(int64_t lhs, int64_t rhs) const { return lhs < rhs; }
};

// A well-known type plus its fields. The fields are kept twice, once sorted by
// name and once by number, so that both lookups can binary search.
struct WellKnownType {
  WellKnownType(
      const Type& type,
      std::initializer_list<common_internal::BasicStructTypeField> fields)
      : type(type), fields_by_name(fields), fields_by_number(fields) {
    std::sort(fields_by_name.begin(), fields_by_name.end(),
              FieldNameComparer{});
    std::sort(fields_by_number.begin(), fields_by_number.end(),
              FieldNumberComparer{});
  }

  // A well-known type without any fields.
  explicit WellKnownType(const Type& type) : WellKnownType(type, {}) {}

  Type type;
  // We use `2` as that accommodates most well known types.
  absl::InlinedVector<common_internal::BasicStructTypeField, 2> fields_by_name;
  absl::InlinedVector<common_internal::BasicStructTypeField, 2>
      fields_by_number;

  // Returns the field called `name`, or `absl::nullopt` if there is none.
  absl::optional<StructTypeField> FieldByName(absl::string_view name) const {
    // Basically `std::binary_search`.
    auto it = std::lower_bound(fields_by_name.begin(), fields_by_name.end(),
                               name, FieldNameComparer{});
    if (it == fields_by_name.end() || it->name() != name) {
      return absl::nullopt;
    }
    return *it;
  }

  // Returns the field numbered `number`, or `absl::nullopt` if there is none.
  absl::optional<StructTypeField> FieldByNumber(int64_t number) const {
    // Basically `std::binary_search`.
    auto it = std::lower_bound(fields_by_number.begin(), fields_by_number.end(),
                               number, FieldNumberComparer{});
    if (it == fields_by_number.end() || it->number() != number) {
      return absl::nullopt;
    }
    return *it;
  }
};

using WellKnownTypesMap = absl::flat_hash_map<absl::string_view, WellKnownType>;

// Returns the table of well-known types. Keys are the fully qualified protobuf
// names plus the CEL builtin type names such as `int` and `list`. The table is
// built on first use and intentionally never destroyed.
const WellKnownTypesMap& GetWellKnownTypesMap() {
  static const WellKnownTypesMap* types = []() -> WellKnownTypesMap* {
    WellKnownTypesMap* types = new WellKnownTypesMap();
    types->insert_or_assign(
        "google.protobuf.BoolValue",
        WellKnownType{BoolWrapperType{},
                      {MakeBasicStructTypeField("value", BoolType{}, 1)}});
    types->insert_or_assign(
        "google.protobuf.Int32Value",
        WellKnownType{IntWrapperType{},
                      {MakeBasicStructTypeField("value", IntType{}, 1)}});
    types->insert_or_assign(
        "google.protobuf.Int64Value",
        WellKnownType{IntWrapperType{},
                      {MakeBasicStructTypeField("value", IntType{}, 1)}});
    types->insert_or_assign(
        "google.protobuf.UInt32Value",
        WellKnownType{UintWrapperType{},
                      {MakeBasicStructTypeField("value", UintType{}, 1)}});
    types->insert_or_assign(
        "google.protobuf.UInt64Value",
        WellKnownType{UintWrapperType{},
                      {MakeBasicStructTypeField("value", UintType{}, 1)}});
    types->insert_or_assign(
        "google.protobuf.FloatValue",
        WellKnownType{DoubleWrapperType{},
                      {MakeBasicStructTypeField("value", DoubleType{}, 1)}});
    types->insert_or_assign(
        "google.protobuf.DoubleValue",
        WellKnownType{DoubleWrapperType{},
                      {MakeBasicStructTypeField("value", DoubleType{}, 1)}});
    types->insert_or_assign(
        "google.protobuf.StringValue",
        WellKnownType{StringWrapperType{},
                      {MakeBasicStructTypeField("value", StringType{}, 1)}});
    types->insert_or_assign(
        "google.protobuf.BytesValue",
        WellKnownType{BytesWrapperType{},
                      {MakeBasicStructTypeField("value", BytesType{}, 1)}});
    types->insert_or_assign(
        "google.protobuf.Duration",
        WellKnownType{DurationType{},
                      {MakeBasicStructTypeField("seconds", IntType{}, 1),
                       MakeBasicStructTypeField("nanos", IntType{}, 2)}});
    types->insert_or_assign(
        "google.protobuf.Timestamp",
        WellKnownType{TimestampType{},
                      {MakeBasicStructTypeField("seconds", IntType{}, 1),
                       MakeBasicStructTypeField("nanos", IntType{}, 2)}});
    types->insert_or_assign(
        "google.protobuf.Value",
        WellKnownType{
            DynType{},
            {// NullValue enum is an int. Not normally referenced directly.
             MakeBasicStructTypeField("null_value", IntType{}, 1),
             MakeBasicStructTypeField("number_value", DoubleType{}, 2),
             MakeBasicStructTypeField("string_value", StringType{}, 3),
             MakeBasicStructTypeField("bool_value", BoolType{}, 4),
             MakeBasicStructTypeField("struct_value", JsonMapType(), 5),
             MakeBasicStructTypeField("list_value", ListType{}, 6)}});
    types->insert_or_assign(
        "google.protobuf.ListValue",
        WellKnownType{ListType{},
                      {MakeBasicStructTypeField("values", ListType{}, 1)}});
    types->insert_or_assign(
        "google.protobuf.Struct",
        WellKnownType{JsonMapType(),
                      {MakeBasicStructTypeField("fields", JsonMapType(), 1)}});
    types->insert_or_assign(
        "google.protobuf.Any",
        WellKnownType{AnyType{},
                      {MakeBasicStructTypeField("type_url", StringType{}, 1),
                       MakeBasicStructTypeField("value", BytesType{}, 2)}});
    types->insert_or_assign("null_type", WellKnownType{NullType{}});
    types->insert_or_assign("google.protobuf.NullValue",
                            WellKnownType{NullType{}});
    types->insert_or_assign("bool", WellKnownType{BoolType{}});
    types->insert_or_assign("int", WellKnownType{IntType{}});
    types->insert_or_assign("uint", WellKnownType{UintType{}});
    types->insert_or_assign("double", WellKnownType{DoubleType{}});
    types->insert_or_assign("bytes", WellKnownType{BytesType{}});
    types->insert_or_assign("string", WellKnownType{StringType{}});
    types->insert_or_assign("list", WellKnownType{ListType{}});
    types->insert_or_assign("map", WellKnownType{MapType{}});
    types->insert_or_assign("type", WellKnownType{TypeType{}});
    return types;
  }();
  return *types;
}

}  // namespace

// The base-class hooks know no types at all and report "not found".

absl::StatusOr<absl::optional<Type>> TypeIntrospector::FindTypeImpl(
    absl::string_view) const {
  return absl::nullopt;
}

absl::StatusOr<absl::optional<TypeIntrospector::EnumConstant>>
TypeIntrospector::FindEnumConstantImpl(absl::string_view,
                                       absl::string_view) const {
  return absl::nullopt;
}

absl::StatusOr<absl::optional<StructTypeField>>
TypeIntrospector::FindStructTypeFieldByNameImpl(absl::string_view,
                                                absl::string_view) const {
  return absl::nullopt;
}

absl::StatusOr<
    absl::optional<std::vector<TypeIntrospector::StructTypeFieldListing>>>
TypeIntrospector::ListFieldsForStructTypeImpl(absl::string_view) const {
  return absl::nullopt;
}

// Returns the type registered under `name`, or `absl::nullopt`.
absl::optional<Type> FindWellKnownType(absl::string_view name) {
  const auto& well_known_types = GetWellKnownTypesMap();
  if (auto it = well_known_types.find(name); it != well_known_types.end()) {
    return it->second.type;
  }
  return absl::nullopt;
}

// Only `google.protobuf.NullValue.NULL_VALUE` is recognized. It is reported
// with `IntType` and number 0.
absl::optional<TypeIntrospector::EnumConstant> FindWellKnownTypeEnumConstant(
    absl::string_view type, absl::string_view value) {
  if (type == "google.protobuf.NullValue" && value == "NULL_VALUE") {
    return TypeIntrospector::EnumConstant{
        IntType{}, "google.protobuf.NullValue", "NULL_VALUE", 0};
  }
  return absl::nullopt;
}

// Returns field `name` of well-known type `type`, or `absl::nullopt` when
// either the type or the field is unknown.
absl::optional<StructTypeField> FindWellKnownTypeFieldByName(
    absl::string_view type, absl::string_view name) {
  const auto& well_known_types = GetWellKnownTypesMap();
  if (auto it = well_known_types.find(type); it != well_known_types.end()) {
    return it->second.FieldByName(name);
  }
  return absl::nullopt;
}

// Lists the fields of well-known type `type`.
//
// Returns `absl::nullopt` when `type` is not a well-known type. The fields of
// well-known types are not normally gettable, so a known type yields an empty
// (but engaged) list. This matches the contract documented on
// `TypeIntrospector::ListFieldsForStructType`.
absl::optional<std::vector<TypeIntrospector::StructTypeFieldListing>>
ListFieldsForWellKnownType(absl::string_view type) {
  const auto& well_known_types = GetWellKnownTypesMap();
  if (!well_known_types.contains(type)) {
    return absl::nullopt;
  }
  // Spell out the vector: a bare `return {};` value-initializes the *optional*
  // (i.e. yields `absl::nullopt`) and makes known and unknown types
  // indistinguishable.
  return std::vector<TypeIntrospector::StructTypeFieldListing>();
}

}  // namespace cel
================================================
FILE: common/type_introspector.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_INTROSPECTOR_H_
#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_INTROSPECTOR_H_

#include <cstdint>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/type.h"

namespace cel {

// `TypeIntrospector` is an interface which allows querying type-related
// information. It handles type introspection, but not type reflection. That is,
// it is not capable of instantiating new values or understanding values. Its
// primary usage is for type checking, and a subset of that shared functionality
// is used by the runtime.
class TypeIntrospector {
 public:
  struct EnumConstant {
    // The type of the enum. For JSON null, this may be a specific type rather
    // than an enum type.
    Type type;
    absl::string_view type_full_name;
    absl::string_view value_name;
    int32_t number;
  };

  struct StructTypeFieldListing {
    // The name used to access the field in source CEL.
    // This is assumed owned by the TypeIntrospector or a dependency that
    // outlives it.
    absl::string_view name;
    // The field description.
    StructTypeField field;
  };

  virtual ~TypeIntrospector() = default;

  // `FindType` finds the type corresponding to name `name`.
  absl::StatusOr<absl::optional<Type>> FindType(absl::string_view name) const {
    return FindTypeImpl(name);
  }

  // `FindEnumConstant` finds the enumerator named `value` in the enum type
  // named `type`.
  absl::StatusOr<absl::optional<EnumConstant>> FindEnumConstant(
      absl::string_view type, absl::string_view value) const {
    return FindEnumConstantImpl(type, value);
  }

  // `FindStructTypeFieldByName` finds the name, number, and type of the field
  // `name` in type `type`.
  absl::StatusOr<absl::optional<StructTypeField>> FindStructTypeFieldByName(
      absl::string_view type, absl::string_view name) const {
    return FindStructTypeFieldByNameImpl(type, name);
  }

  // `ListFieldsForStructType` returns the fields of struct type `type`.
  //
  // This is used when the struct is declared as a context type.
  //
  // If the type is not found, returns `absl::nullopt`.
  // If the type exists but is not a struct or has no fields, returns an empty
  // vector.
  absl::StatusOr<absl::optional<std::vector<StructTypeFieldListing>>>
  ListFieldsForStructType(absl::string_view type) const {
    return ListFieldsForStructTypeImpl(type);
  }

  // `FindStructTypeFieldByName` finds the name, number, and type of the field
  // `name` in struct type `type`.
  absl::StatusOr<absl::optional<StructTypeField>> FindStructTypeFieldByName(
      const StructType& type, absl::string_view name) const {
    return FindStructTypeFieldByName(type.name(), name);
  }

 protected:
  virtual absl::StatusOr<absl::optional<Type>> FindTypeImpl(
      absl::string_view name) const;

  virtual absl::StatusOr<absl::optional<EnumConstant>> FindEnumConstantImpl(
      absl::string_view type, absl::string_view value) const;

  virtual absl::StatusOr<absl::optional<StructTypeField>>
  FindStructTypeFieldByNameImpl(absl::string_view type,
                                absl::string_view name) const;

  virtual absl::StatusOr<absl::optional<std::vector<StructTypeFieldListing>>>
  ListFieldsForStructTypeImpl(absl::string_view type) const;
};

// Looks up a well-known type by name.
absl::optional<Type> FindWellKnownType(absl::string_view name);

// Looks up a well-known enum constant by type and value.
absl::optional<TypeIntrospector::EnumConstant> FindWellKnownTypeEnumConstant(
    absl::string_view type, absl::string_view value);

// Looks up a well-known struct type field by type and field name.
absl::optional FindWellKnownTypeFieldByName(
    absl::string_view type, absl::string_view name);

// Lists the fields of well-known type `type`. Returns `absl::nullopt` when
// `type` is not a well-known type; the fields of well-known types are not
// exposed through this listing.
absl::optional> ListFieldsForWellKnownType(absl::string_view type);

// `WellKnownTypeIntrospector` is an implementation of `TypeIntrospector` which
// handles well-known types that are treated specially by CEL.
//
// This also serves as a minimal implementation of a TypeIntrospector when no
// custom types are present.
//
// This class has no mutable state, so trivially thread-safe.
class WellKnownTypeIntrospector : public virtual TypeIntrospector {
 public:
  WellKnownTypeIntrospector() = default;

 private:
  // Each hook forwards to the matching free well-known-type lookup function
  // declared above. The hooks are `final`, so subclasses cannot change them.
  absl::StatusOr> FindTypeImpl(
      absl::string_view name) const final {
    return FindWellKnownType(name);
  }

  absl::StatusOr> FindEnumConstantImpl(
      absl::string_view type, absl::string_view value) const final {
    return FindWellKnownTypeEnumConstant(type, value);
  }

  absl::StatusOr> FindStructTypeFieldByNameImpl(
      absl::string_view type, absl::string_view name) const final {
    return FindWellKnownTypeFieldByName(type, name);
  }

  absl::StatusOr>>
  ListFieldsForStructTypeImpl(absl::string_view type) const final {
    return ListFieldsForWellKnownType(type);
  }
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_TYPE_INTROSPECTOR_H_
================================================
FILE: common/type_kind.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_KIND_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPE_KIND_H_ #include #include #include "absl/base/attributes.h" #include "absl/base/macros.h" #include "absl/strings/string_view.h" #include "common/kind.h" namespace cel { // `TypeKind` is a subset of `Kind`, representing all valid `Kind` for `Type`. // All `TypeKind` are valid `Kind`, but it is not guaranteed that all `Kind` are // valid `TypeKind`. enum class TypeKind : std::underlying_type_t { kNull = static_cast(Kind::kNull), kBool = static_cast(Kind::kBool), kInt = static_cast(Kind::kInt), kUint = static_cast(Kind::kUint), kDouble = static_cast(Kind::kDouble), kString = static_cast(Kind::kString), kBytes = static_cast(Kind::kBytes), kStruct = static_cast(Kind::kStruct), kDuration = static_cast(Kind::kDuration), kTimestamp = static_cast(Kind::kTimestamp), kList = static_cast(Kind::kList), kMap = static_cast(Kind::kMap), kUnknown = static_cast(Kind::kUnknown), kType = static_cast(Kind::kType), kError = static_cast(Kind::kError), kAny = static_cast(Kind::kAny), kDyn = static_cast(Kind::kDyn), kOpaque = static_cast(Kind::kOpaque), kBoolWrapper = static_cast(Kind::kBoolWrapper), kIntWrapper = static_cast(Kind::kIntWrapper), kUintWrapper = static_cast(Kind::kUintWrapper), kDoubleWrapper = static_cast(Kind::kDoubleWrapper), kStringWrapper = static_cast(Kind::kStringWrapper), kBytesWrapper = static_cast(Kind::kBytesWrapper), kTypeParam = static_cast(Kind::kTypeParam), kFunction = static_cast(Kind::kFunction), kEnum = static_cast(Kind::kEnum), // Legacy aliases, deprecated do not use. kNullType = kNull, kInt64 = kInt, kUint64 = kUint, kMessage = kStruct, kUnknownSet = kUnknown, kCelType = kType, // INTERNAL: Do not exceed 63. Implementation details rely on the fact that // we can store `Kind` using 6 bits. 
kNotForUseWithExhaustiveSwitchStatements = static_cast(Kind::kNotForUseWithExhaustiveSwitchStatements), }; constexpr Kind TypeKindToKind(TypeKind kind) { return static_cast(static_cast>(kind)); } constexpr bool KindIsTypeKind(Kind kind ABSL_ATTRIBUTE_UNUSED) { // Currently all Kind are valid TypeKind. return true; } constexpr bool operator==(Kind lhs, TypeKind rhs) { return lhs == TypeKindToKind(rhs); } constexpr bool operator==(TypeKind lhs, Kind rhs) { return TypeKindToKind(lhs) == rhs; } constexpr bool operator!=(Kind lhs, TypeKind rhs) { return !operator==(lhs, rhs); } constexpr bool operator!=(TypeKind lhs, Kind rhs) { return !operator==(lhs, rhs); } inline absl::string_view TypeKindToString(TypeKind kind) { // All TypeKind are valid Kind. return KindToString(TypeKindToKind(kind)); } constexpr TypeKind KindToTypeKind(Kind kind) { ABSL_ASSERT(KindIsTypeKind(kind)); return static_cast(static_cast>(kind)); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_KIND_H_ ================================================ FILE: common/type_proto.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/type_proto.h"

#include <string>
#include <utility>
#include <vector>

#include "google/protobuf/struct.pb.h"
#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/type.h"
#include "common/type_kind.h"
#include "internal/status_macros.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"

namespace cel {

namespace {

using ::google::protobuf::NullValue;

using TypePb = cel::expr::Type;

// Returns the CEL type that a protobuf well-known message type is treated as
// (for example `google.protobuf.Timestamp` -> `TimestampType`). Returns
// `absl::nullopt` when `type_name` is not one of the special-cased messages;
// the caller then resolves it as an ordinary message type.
//
// The 32-bit and float wrappers share the 64-bit / double wrapper types.
absl::optional<Type> MaybeWellKnownType(absl::string_view type_name) {
  // Built once on first use and never freed, so the table stays valid for the
  // remainder of the process.
  static const absl::flat_hash_map<absl::string_view, Type>* kWellKnownTypes =
      []() {
        auto* instance = new absl::flat_hash_map<absl::string_view, Type>{
            // keep-sorted start
            {"google.protobuf.Any", AnyType()},
            {"google.protobuf.BoolValue", BoolWrapperType()},
            {"google.protobuf.BytesValue", BytesWrapperType()},
            {"google.protobuf.DoubleValue", DoubleWrapperType()},
            {"google.protobuf.Duration", DurationType()},
            {"google.protobuf.FloatValue", DoubleWrapperType()},
            {"google.protobuf.Int32Value", IntWrapperType()},
            {"google.protobuf.Int64Value", IntWrapperType()},
            {"google.protobuf.ListValue", ListType()},
            {"google.protobuf.StringValue", StringWrapperType()},
            {"google.protobuf.Struct", JsonMapType()},
            {"google.protobuf.Timestamp", TimestampType()},
            {"google.protobuf.UInt32Value", UintWrapperType()},
            {"google.protobuf.UInt64Value", UintWrapperType()},
            {"google.protobuf.Value", DynType()},
            // keep-sorted end
        };
        return instance;
      }();

  if (auto it = kWellKnownTypes->find(type_name);
      it != kWellKnownTypes->end()) {
    return it->second;
  }
  return absl::nullopt;
}

absl::Status TypeToProtoInternal(const cel::Type& type,
                                 TypePb* absl_nonnull type_pb);

// Encodes an opaque type as `abstract_type`: its name followed by each type
// parameter, converted in order.
absl::Status ToProtoAbstractType(const cel::OpaqueType& type,
                                 TypePb* absl_nonnull type_pb) {
  auto* abstract_type = type_pb->mutable_abstract_type();
  abstract_type->set_name(type.name());
  abstract_type->mutable_parameter_types()->Reserve(
      type.GetParameters().size());
  for (const auto& param : type.GetParameters()) {
    CEL_RETURN_IF_ERROR(
        TypeToProtoInternal(param, abstract_type->add_parameter_types()));
  }
  return absl::OkStatus();
}

// Encodes a map type as `map_type` with converted key and value types.
absl::Status ToProtoMapType(const cel::MapType& type,
                            TypePb* absl_nonnull type_pb) {
  auto* map_type = type_pb->mutable_map_type();
  CEL_RETURN_IF_ERROR(
      TypeToProtoInternal(type.key(), map_type->mutable_key_type()));
  CEL_RETURN_IF_ERROR(
      TypeToProtoInternal(type.value(), map_type->mutable_value_type()));
  return absl::OkStatus();
}

// Encodes a list type as `list_type` with a converted element type.
absl::Status ToProtoListType(const cel::ListType& type,
                             TypePb* absl_nonnull type_pb) {
  auto* list_type = type_pb->mutable_list_type();
  CEL_RETURN_IF_ERROR(
      TypeToProtoInternal(type.element(), list_type->mutable_elem_type()));
  return absl::OkStatus();
}

// Encodes a `type` type. The proto form holds at most one parameter:
//   - An unparameterized `type` is written as an empty nested `type {}`.
//   - More than one parameter is rejected with an InternalError.
absl::Status ToProtoTypeType(const cel::TypeType& type,
                             TypePb* absl_nonnull type_pb) {
  if (type.GetParameters().size() > 1) {
    return absl::InternalError(
        absl::StrCat("unsupported type: ", type.DebugString()));
  }
  auto* type_type = type_pb->mutable_type();
  if (type.GetParameters().empty()) {
    return absl::OkStatus();
  }
  CEL_RETURN_IF_ERROR(TypeToProtoInternal(type.GetParameters()[0], type_type));
  return absl::OkStatus();
}

// Writes the proto encoding of `type` into `*type_pb`. Kinds that have no
// encoding here (e.g. unknown, function) produce an InternalError.
absl::Status TypeToProtoInternal(const cel::Type& type,
                                 TypePb* absl_nonnull type_pb) {
  switch (type.kind()) {
    case TypeKind::kDyn:
      type_pb->mutable_dyn();
      return absl::OkStatus();
    case TypeKind::kError:
      type_pb->mutable_error();
      return absl::OkStatus();
    case TypeKind::kNull:
      type_pb->set_null(NullValue::NULL_VALUE);
      return absl::OkStatus();
    case TypeKind::kBool:
      type_pb->set_primitive(TypePb::BOOL);
      return absl::OkStatus();
    case TypeKind::kInt:
      type_pb->set_primitive(TypePb::INT64);
      return absl::OkStatus();
    case TypeKind::kUint:
      type_pb->set_primitive(TypePb::UINT64);
      return absl::OkStatus();
    case TypeKind::kDouble:
      type_pb->set_primitive(TypePb::DOUBLE);
      return absl::OkStatus();
    case TypeKind::kString:
      type_pb->set_primitive(TypePb::STRING);
      return absl::OkStatus();
    case TypeKind::kBytes:
      type_pb->set_primitive(TypePb::BYTES);
      return absl::OkStatus();
    case TypeKind::kEnum:
      // Enums are written as plain int64; the enum name is not preserved.
      type_pb->set_primitive(TypePb::INT64);
      return absl::OkStatus();
    case TypeKind::kDuration:
      type_pb->set_well_known(TypePb::DURATION);
      return absl::OkStatus();
    case TypeKind::kTimestamp:
      type_pb->set_well_known(TypePb::TIMESTAMP);
      return absl::OkStatus();
    case TypeKind::kStruct:
      type_pb->set_message_type(type.GetStruct().name());
      return absl::OkStatus();
    case TypeKind::kList:
      return ToProtoListType(type.GetList(), type_pb);
    case TypeKind::kMap:
      return ToProtoMapType(type.GetMap(), type_pb);
    case TypeKind::kOpaque:
      return ToProtoAbstractType(type.GetOpaque(), type_pb);
    case TypeKind::kBoolWrapper:
      type_pb->set_wrapper(TypePb::BOOL);
      return absl::OkStatus();
    case TypeKind::kIntWrapper:
      type_pb->set_wrapper(TypePb::INT64);
      return absl::OkStatus();
    case TypeKind::kUintWrapper:
      type_pb->set_wrapper(TypePb::UINT64);
      return absl::OkStatus();
    case TypeKind::kDoubleWrapper:
      type_pb->set_wrapper(TypePb::DOUBLE);
      return absl::OkStatus();
    case TypeKind::kStringWrapper:
      type_pb->set_wrapper(TypePb::STRING);
      return absl::OkStatus();
    case TypeKind::kBytesWrapper:
      type_pb->set_wrapper(TypePb::BYTES);
      return absl::OkStatus();
    case TypeKind::kTypeParam:
      type_pb->set_type_param(type.GetTypeParam().name());
      return absl::OkStatus();
    case TypeKind::kType:
      return ToProtoTypeType(type.GetType(), type_pb);
    case TypeKind::kAny:
      type_pb->set_well_known(TypePb::ANY);
      return absl::OkStatus();
    default:
      return absl::InternalError(
          absl::StrCat("unsupported type: ", type.DebugString()));
  }
}

}  // namespace

// Converts a `cel.expr.Type` proto into a `cel::Type`.
//
// Message names are first checked against the well-known-type table and then
// looked up in `descriptor_pool`. Names of abstract types and type parameters
// are copied onto `arena`. Unknown messages, unspecified primitive,
// well-known or wrapper kinds, and unsupported kinds (e.g. function) yield
// InvalidArgumentError.
absl::StatusOr<Type> TypeFromProto(
    const cel::expr::Type& type_pb,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::Arena* absl_nonnull arena) {
  switch (type_pb.type_kind_case()) {
    case TypePb::kAbstractType: {
      const auto& abstract_type = type_pb.abstract_type();
      // Copied onto the arena, presumably so the returned type does not
      // depend on the lifetime of `type_pb`.
      auto* name = google::protobuf::Arena::Create<std::string>(
          arena, abstract_type.name());
      std::vector<Type> params;
      params.reserve(abstract_type.parameter_types_size());
      for (const auto& p : abstract_type.parameter_types()) {
        CEL_ASSIGN_OR_RETURN(Type param,
                             TypeFromProto(p, descriptor_pool, arena));
        params.push_back(std::move(param));
      }
      return OpaqueType(arena, *name, params);
    }
    case TypePb::kDyn:
      return DynType();
    case TypePb::kError:
      return ErrorType();
    case TypePb::kListType: {
      CEL_ASSIGN_OR_RETURN(Type element,
                           TypeFromProto(type_pb.list_type().elem_type(),
                                         descriptor_pool, arena));
      return ListType(arena, element);
    }
    case TypePb::kMapType: {
      CEL_ASSIGN_OR_RETURN(
          Type key,
          TypeFromProto(type_pb.map_type().key_type(), descriptor_pool, arena));
      CEL_ASSIGN_OR_RETURN(Type value,
                           TypeFromProto(type_pb.map_type().value_type(),
                                         descriptor_pool, arena));
      return MapType(arena, key, value);
    }
    case TypePb::kMessageType: {
      if (auto well_known = MaybeWellKnownType(type_pb.message_type());
          well_known.has_value()) {
        return *well_known;
      }
      const auto* descriptor =
          descriptor_pool->FindMessageTypeByName(type_pb.message_type());
      if (descriptor == nullptr) {
        return absl::InvalidArgumentError(
            absl::StrCat("unknown message type: ", type_pb.message_type()));
      }
      return MessageType(descriptor);
    }
    case TypePb::kNull:
      return NullType();
    case TypePb::kPrimitive:
      switch (type_pb.primitive()) {
        case TypePb::BOOL:
          return BoolType();
        case TypePb::BYTES:
          return BytesType();
        case TypePb::DOUBLE:
          return DoubleType();
        case TypePb::INT64:
          return IntType();
        case TypePb::STRING:
          return StringType();
        case TypePb::UINT64:
          return UintType();
        default:
          return absl::InvalidArgumentError("unknown primitive kind");
      }
    case TypePb::kType: {
      // TypeToProto writes an unparameterized `type` as `type {}`, i.e. an
      // empty nested Type. Recursing into that empty message would hit the
      // default case below and fail. Map it back to the unparameterized
      // TypeType so that the two conversions round-trip.
      if (type_pb.type().type_kind_case() == TypePb::TYPE_KIND_NOT_SET) {
        return TypeType();
      }
      CEL_ASSIGN_OR_RETURN(
          Type nested, TypeFromProto(type_pb.type(), descriptor_pool, arena));
      return TypeType(arena, nested);
    }
    case TypePb::kTypeParam: {
      auto* name =
          google::protobuf::Arena::Create<std::string>(arena, type_pb.type_param());
      return TypeParamType(*name);
    }
    case TypePb::kWellKnown:
      switch (type_pb.well_known()) {
        case TypePb::ANY:
          return AnyType();
        case TypePb::DURATION:
          return DurationType();
        case TypePb::TIMESTAMP:
          return TimestampType();
        default:
          break;
      }
      return absl::InvalidArgumentError("unknown well known type.");
    case TypePb::kWrapper: {
      switch (type_pb.wrapper()) {
        case TypePb::BOOL:
          return BoolWrapperType();
        case TypePb::BYTES:
          return BytesWrapperType();
        case TypePb::DOUBLE:
          return DoubleWrapperType();
        case TypePb::INT64:
          return IntWrapperType();
        case TypePb::STRING:
          return StringWrapperType();
        case TypePb::UINT64:
          return UintWrapperType();
        default:
          return absl::InvalidArgumentError("unknown primitive wrapper kind");
      }
    }
    // Function types are not supported in the C++ type checker.
    case TypePb::kFunction:
    default:
      return absl::InvalidArgumentError(
          absl::StrCat("unsupported type kind: ", type_pb.type_kind_case()));
  }
}

// Converts `type` into its `cel.expr.Type` proto form, writing into
// `*type_pb`.
absl::Status TypeToProto(const Type& type, TypePb* absl_nonnull type_pb) {
  return TypeToProtoInternal(type, type_pb);
}

}  // namespace cel

================================================
FILE: common/type_proto.h
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_H_ #include "cel/expr/checked.pb.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "common/type.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { // Creates a Type from a google.api.expr.Type proto. absl::StatusOr TypeFromProto( const cel::expr::Type& type_pb, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena); absl::Status TypeToProto(const Type& type, cel::expr::Type* absl_nonnull type_pb); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_PROTO_H_ ================================================ FILE: common/type_proto_test.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/type_proto.h"

#include

#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "common/type.h"
#include "common/type_kind.h"
#include "internal/proto_matchers.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/text_format.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::internal::test::EqualsProto;

// Whether a case is expected to survive TypeFromProto -> TypeToProto
// unchanged.
enum class RoundTrip {
  kYes,
  kNo,
};

// One parameterized case.
//   type_pb:    text-format `cel.expr.Type` to convert.
//   type_kind:  expected kind of the converted type, or a status whose code
//               TypeFromProto is expected to fail with.
//   round_trip: whether TypeToProto must reproduce `type_pb` exactly. kNo is
//               used for inputs such as well-known message names, which
//               convert to a non-message type.
struct TestCase {
  std::string type_pb;
  absl::StatusOr type_kind;
  RoundTrip round_trip = RoundTrip::kYes;
};

class TypeFromProtoTest : public ::testing::TestWithParam {};

// Parses the case's proto, converts it, and checks either the resulting kind
// or the returned status code.
TEST_P(TypeFromProtoTest, FromProtoWorks) {
  const google::protobuf::DescriptorPool* descriptor_pool =
      internal::GetTestingDescriptorPool();
  google::protobuf::Arena arena;
  const TestCase& test_case = GetParam();
  cel::expr::Type type_pb;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(test_case.type_pb,
                                                  &type_pb));
  absl::StatusOr result = TypeFromProto(type_pb, descriptor_pool, &arena);
  if (test_case.type_kind.ok()) {
    ASSERT_OK_AND_ASSIGN(Type type, result);
    EXPECT_EQ(type.kind(), *test_case.type_kind)
        << absl::StrCat("got: ", type.DebugString(),
                        " want: ", TypeKindToString(*test_case.type_kind));
  } else {
    EXPECT_THAT(result, StatusIs(test_case.type_kind.status().code()));
  }
}

// For successful cases marked RoundTrip::kYes, converts the proto to a Type
// and back, and expects the original proto.
TEST_P(TypeFromProtoTest, RoundTripProtoWorks) {
  const google::protobuf::DescriptorPool* descriptor_pool =
      internal::GetTestingDescriptorPool();
  google::protobuf::Arena arena;
  const TestCase& test_case = GetParam();
  if (!test_case.type_kind.ok() || test_case.round_trip == RoundTrip::kNo) {
    return GTEST_SUCCEED();
  }
  cel::expr::Type type_pb;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(test_case.type_pb,
                                                  &type_pb));
  absl::StatusOr result = TypeFromProto(type_pb, descriptor_pool, &arena);
  ASSERT_THAT(test_case.type_kind, IsOk());
  ASSERT_OK_AND_ASSIGN(Type type, result);
  EXPECT_EQ(type.kind(), *test_case.type_kind)
      << absl::StrCat("got: ", type.DebugString(),
                      " want: ", TypeKindToString(*test_case.type_kind));
  cel::expr::Type round_trip_pb;
  ASSERT_THAT(TypeToProto(type, &round_trip_pb), IsOk());
  EXPECT_THAT(round_trip_pb, EqualsProto(type_pb));
}

// The cases cover every `type_kind` oneof field:
//   - the unspecified enum value of each enum-valued field (an error);
//   - an unknown message name (an error);
//   - an unsupported function type (an error).
INSTANTIATE_TEST_SUITE_P(
    TypeFromProtoTest, TypeFromProtoTest,
    testing::Values(
        TestCase{
            R"pb( abstract_type { name: "foo" parameter_types { primitive: INT64 } parameter_types { primitive: STRING } } )pb",
            TypeKind::kOpaque},
        TestCase{R"pb( dyn {} )pb", TypeKind::kDyn},
        TestCase{R"pb( error {} )pb", TypeKind::kError},
        TestCase{R"pb( list_type { elem_type { primitive: INT64 } } )pb",
                 TypeKind::kList},
        TestCase{
            R"pb( map_type { key_type { primitive: INT64 } value_type { primitive: STRING } } )pb",
            TypeKind::kMap},
        TestCase{
            R"pb( message_type: "google.api.expr.runtime.TestExtensions" )pb",
            TypeKind::kMessage},
        TestCase{R"pb( message_type: "com.example.UnknownMessage" )pb",
                 absl::InvalidArgumentError("")},
        // Special-case well known types referenced by
        // equivalent proto message types.
        TestCase{R"pb( message_type: "google.protobuf.Any" )pb",
                 TypeKind::kAny, RoundTrip::kNo},
        TestCase{R"pb( message_type: "google.protobuf.Timestamp" )pb",
                 TypeKind::kTimestamp, RoundTrip::kNo},
        TestCase{R"pb( message_type: "google.protobuf.Duration" )pb",
                 TypeKind::kDuration, RoundTrip::kNo},
        TestCase{R"pb( message_type: "google.protobuf.Struct" )pb",
                 TypeKind::kMap, RoundTrip::kNo},
        TestCase{R"pb( message_type: "google.protobuf.ListValue" )pb",
                 TypeKind::kList, RoundTrip::kNo},
        TestCase{R"pb( message_type: "google.protobuf.Value" )pb",
                 TypeKind::kDyn, RoundTrip::kNo},
        TestCase{R"pb( message_type: "google.protobuf.Int64Value" )pb",
                 TypeKind::kIntWrapper, RoundTrip::kNo},
        TestCase{R"pb( null: 0 )pb", TypeKind::kNull},
        TestCase{
            R"pb( primitive: BOOL)pb",
            TypeKind::kBool},
        TestCase{
            R"pb( primitive: BYTES)pb",
            TypeKind::kBytes},
        TestCase{
            R"pb( primitive: DOUBLE)pb",
            TypeKind::kDouble},
        TestCase{
            R"pb( primitive: INT64)pb",
            TypeKind::kInt},
        TestCase{
            R"pb( primitive: STRING)pb",
            TypeKind::kString},
        TestCase{
            R"pb( primitive: UINT64)pb",
            TypeKind::kUint},
        TestCase{
            R"pb( primitive: PRIMITIVE_TYPE_UNSPECIFIED)pb",
            absl::InvalidArgumentError("")},
        TestCase{
            R"pb( type { type { primitive: UINT64 } })pb",
            TypeKind::kType},
        TestCase{
            R"pb( type_param: "T")pb",
            TypeKind::kTypeParam},
        TestCase{
            R"pb( well_known: ANY)pb",
            TypeKind::kAny},
        TestCase{
            R"pb( well_known: TIMESTAMP)pb",
            TypeKind::kTimestamp},
        TestCase{
            R"pb( well_known: DURATION)pb",
            TypeKind::kDuration},
        TestCase{
            R"pb( well_known: WELL_KNOWN_TYPE_UNSPECIFIED)pb",
            absl::InvalidArgumentError("")},
        TestCase{
            R"pb( wrapper: BOOL )pb",
            TypeKind::kBoolWrapper},
        TestCase{
            R"pb( wrapper: BYTES )pb",
            TypeKind::kBytesWrapper},
        TestCase{
            R"pb( wrapper: DOUBLE )pb",
            TypeKind::kDoubleWrapper},
        TestCase{
            R"pb( wrapper: INT64 )pb",
            TypeKind::kIntWrapper},
        TestCase{
            R"pb( wrapper: STRING )pb",
            TypeKind::kStringWrapper},
        TestCase{
            R"pb( wrapper: UINT64 )pb",
            TypeKind::kUintWrapper},
        TestCase{
            R"pb( wrapper: PRIMITIVE_TYPE_UNSPECIFIED )pb",
            absl::InvalidArgumentError("")},
        TestCase{
            R"pb( function { result_type { primitive: BOOL } arg_types { primitive: INT64 } arg_types { primitive: STRING } })pb",
            absl::InvalidArgumentError("")}));

}  // namespace
}  // namespace cel

================================================
FILE: common/type_reflector.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_REFLECTOR_H_
#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_REFLECTOR_H_

#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "common/type_introspector.h"
#include "common/value.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/message.h"

namespace cel {

// `TypeReflector` is an interface for constructing new instances of types at
// runtime. It handles type reflection, and extends `TypeIntrospector`
// (inherited virtually).
class TypeReflector : public virtual TypeIntrospector {
 public:
  // `NewValueBuilder` returns a new `ValueBuilder` for the corresponding type
  // `name`. It is primarily used to handle wrapper types which sometimes show
  // up literally in expressions. `message_factory` and `arena` must be
  // non-null.
  virtual absl::StatusOr NewValueBuilder(
      absl::string_view name,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const = 0;
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_TYPE_REFLECTOR_H_

================================================
FILE: common/type_reflector_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include
#include
#include

#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "common/casting.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "common/values/list_value.h"
#include "common/values/value_builder.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "internal/testing_message_factory.h"

namespace cel {
namespace {

using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::cel::test::ErrorValueIs;
using ::testing::Eq;
using ::testing::IsEmpty;
using ::testing::Not;
using ::testing::NotNull;
using ::testing::Optional;

using TypeReflectorTest = common_internal::ValueTest<>;

// Defines a test that a freshly created list value builder is empty and
// builds an empty list whose debug string is "[]". `element_type` is pasted
// into the test name.
#define TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(element_type) \
  TEST_F(TypeReflectorTest, NewListValueBuilder_##element_type) { \
    auto list_value_builder = NewListValueBuilder(arena()); \
    EXPECT_TRUE(list_value_builder->IsEmpty()); \
EXPECT_EQ(list_value_builder->Size(), 0); \
    auto list_value = std::move(*list_value_builder).Build(); \
    EXPECT_THAT(list_value.IsEmpty(), IsOkAndHolds(true)); \
    EXPECT_THAT(list_value.Size(), IsOkAndHolds(0)); \
    EXPECT_EQ(list_value.DebugString(), "[]"); \
  }

// One empty-list test per type name. As written here, the argument only
// varies the test name.
TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(BoolType)
TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(BytesType)
TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(DoubleType)
TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(DurationType)
TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(IntType)
TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(ListType)
TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(MapType)
TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(NullType)
TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(OptionalType)
TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(StringType)
TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(TimestampType)
TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(TypeType)
TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(UintType)
TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST(DynType)

#undef TYPE_REFLECTOR_NEW_LIST_VALUE_BUILDER_TEST

// Defines a test that a freshly created map value builder is empty and builds
// an empty map whose debug string is "{}". `key_type` and `value_type` are
// pasted into the test name.
#define TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(key_type, value_type) \
  TEST_F(TypeReflectorTest, NewMapValueBuilder_##key_type##_##value_type) { \
    auto map_value_builder = NewMapValueBuilder(arena()); \
    EXPECT_TRUE(map_value_builder->IsEmpty()); \
    EXPECT_EQ(map_value_builder->Size(), 0); \
    auto map_value = std::move(*map_value_builder).Build(); \
    EXPECT_THAT(map_value.IsEmpty(), IsOkAndHolds(true)); \
    EXPECT_THAT(map_value.Size(), IsOkAndHolds(0)); \
    EXPECT_EQ(map_value.DebugString(), "{}"); \
  }

// Cross product of five key type names and fourteen value type names.
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, BoolType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, BytesType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, DoubleType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, DurationType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, IntType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, ListType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, MapType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, NullType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, OptionalType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, StringType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, TimestampType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, TypeType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, UintType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(BoolType, DynType)

TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, BoolType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, BytesType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, DoubleType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, DurationType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, IntType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, ListType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, MapType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, NullType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, OptionalType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, StringType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, TimestampType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, TypeType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, UintType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(IntType, DynType)

TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, BoolType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, BytesType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, DoubleType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, DurationType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, IntType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, ListType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, MapType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, NullType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, OptionalType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, StringType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, TimestampType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, TypeType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, UintType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(UintType, DynType)

TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, BoolType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, BytesType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, DoubleType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, DurationType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, IntType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, ListType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, MapType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, NullType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, OptionalType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, StringType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, TimestampType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, TypeType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, UintType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(StringType, DynType)

TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, BoolType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, BytesType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, DoubleType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, DurationType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, IntType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, ListType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, MapType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, NullType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, OptionalType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, StringType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, TimestampType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, TypeType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, UintType)
TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST(DynType, DynType)

#undef TYPE_REFLECTOR_NEW_MAP_VALUE_BUILDER_TEST

// Adds three ints to a list builder and checks size, emptiness, and the built
// list's debug string.
TEST_F(TypeReflectorTest, NewListValueBuilderCoverage_Dynamic) {
  auto builder = NewListValueBuilder(arena());
  EXPECT_OK(builder->Add(IntValue(0)));
  EXPECT_OK(builder->Add(IntValue(1)));
  EXPECT_OK(builder->Add(IntValue(2)));
  EXPECT_EQ(builder->Size(), 3);
  EXPECT_FALSE(builder->IsEmpty());
  auto value = std::move(*builder).Build();
  EXPECT_EQ(value.DebugString(), "[0, 1, 2]");
}

// Puts bool, int, uint and string keys into one map builder. Only checks
// that the debug string is non-empty, since its entry order is not asserted.
TEST_F(TypeReflectorTest, NewMapValueBuilderCoverage_DynamicDynamic) {
  auto builder = NewMapValueBuilder(arena());
  EXPECT_OK(builder->Put(BoolValue(false), IntValue(1)));
  EXPECT_OK(builder->Put(BoolValue(true), IntValue(2)));
  EXPECT_OK(builder->Put(IntValue(0), IntValue(3)));
  EXPECT_OK(builder->Put(IntValue(1), IntValue(4)));
  EXPECT_OK(builder->Put(UintValue(0), IntValue(5)));
  EXPECT_OK(builder->Put(UintValue(1), IntValue(6)));
  EXPECT_OK(builder->Put(StringValue("a"), IntValue(7)));
  EXPECT_OK(builder->Put(StringValue("b"), IntValue(8)));
  EXPECT_EQ(builder->Size(), 8);
  EXPECT_FALSE(builder->IsEmpty());
  auto value = std::move(*builder).Build();
  EXPECT_THAT(value.DebugString(), Not(IsEmpty()));
}

// Single-entry map builder. NOTE(review): as it appears here, this body is
// identical to NewMapValueBuilderCoverage_DynamicStatic.
TEST_F(TypeReflectorTest, NewMapValueBuilderCoverage_StaticDynamic) {
  auto builder = NewMapValueBuilder(arena());
  EXPECT_OK(builder->Put(BoolValue(true), IntValue(0)));
  EXPECT_EQ(builder->Size(), 1);
  EXPECT_FALSE(builder->IsEmpty());
  auto value = std::move(*builder).Build();
  EXPECT_EQ(value.DebugString(), "{true: 0}");
}

// Single-entry map builder; see NewMapValueBuilderCoverage_StaticDynamic.
TEST_F(TypeReflectorTest, NewMapValueBuilderCoverage_DynamicStatic) {
  auto builder = NewMapValueBuilder(arena());
  EXPECT_OK(builder->Put(BoolValue(true), IntValue(0)));
  EXPECT_EQ(builder->Size(), 1);
  EXPECT_FALSE(builder->IsEmpty());
  auto value = std::move(*builder).Build();
  EXPECT_EQ(value.DebugString(), "{true: 0}");
}

// The NewValueBuilder_* tests below share a pattern. Each test builds a
// wrapper message through common_internal::NewValueBuilder and checks that:
//   - field "value" (number 1) accepts the matching CEL value, returning
//     nullopt;
//   - unknown field names/numbers yield an ErrorValue with kNotFound;
//   - a value of the wrong kind yields an ErrorValue with kInvalidArgument;
//   - the built value holds the last value that was set.
TEST_F(TypeReflectorTest, NewValueBuilder_BoolValue) {
  auto builder = common_internal::NewValueBuilder(
      arena(), internal::GetTestingDescriptorPool(),
      internal::GetTestingMessageFactory(), "google.protobuf.BoolValue");
  ASSERT_THAT(builder, NotNull());
  EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)),
              IsOkAndHolds(Eq(absl::nullopt)));
  EXPECT_THAT(builder->SetFieldByName("does_not_exist", BoolValue(true)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))));
  EXPECT_THAT(builder->SetFieldByName("value", IntValue(1)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))));
  EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)),
              IsOkAndHolds(Eq(absl::nullopt)));
  EXPECT_THAT(builder->SetFieldByNumber(2, BoolValue(true)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))));
  EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))));
  ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build());
  EXPECT_TRUE(InstanceOf(value));
  EXPECT_EQ(Cast(value).NativeValue(), true);
}

// Int32Value additionally rejects an IntValue at the numeric_limits maximum
// (presumably int64) with kOutOfRange.
TEST_F(TypeReflectorTest, NewValueBuilder_Int32Value) {
  auto builder = common_internal::NewValueBuilder(
      arena(), internal::GetTestingDescriptorPool(),
      internal::GetTestingMessageFactory(), "google.protobuf.Int32Value");
  ASSERT_THAT(builder, NotNull());
  EXPECT_THAT(builder->SetFieldByName("value", IntValue(1)),
              IsOkAndHolds(Eq(absl::nullopt)));
  EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))));
  EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))));
  EXPECT_THAT(builder->SetFieldByName(
                  "value", IntValue(std::numeric_limits::max())),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange)))));
  EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)),
              IsOkAndHolds(Eq(absl::nullopt)));
  EXPECT_THAT(builder->SetFieldByNumber(2, IntValue(1)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))));
  EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))));
  EXPECT_THAT(builder->SetFieldByNumber(
                  1, IntValue(std::numeric_limits::max())),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange)))));
  ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build());
  EXPECT_TRUE(InstanceOf(value));
  EXPECT_EQ(Cast(value).NativeValue(), 1);
}

TEST_F(TypeReflectorTest, NewValueBuilder_Int64Value) {
  auto builder = common_internal::NewValueBuilder(
      arena(), internal::GetTestingDescriptorPool(),
      internal::GetTestingMessageFactory(), "google.protobuf.Int64Value");
  ASSERT_THAT(builder, NotNull());
  EXPECT_THAT(builder->SetFieldByName("value", IntValue(1)),
              IsOkAndHolds(Eq(absl::nullopt)));
  EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))));
  EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))));
  EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)),
              IsOkAndHolds(Eq(absl::nullopt)));
  EXPECT_THAT(builder->SetFieldByNumber(2, IntValue(1)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))));
  EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))));
  ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build());
  EXPECT_TRUE(InstanceOf(value));
  EXPECT_EQ(Cast(value).NativeValue(), 1);
}

// UInt32Value additionally rejects a UintValue at the numeric_limits maximum
// (presumably uint64) with kOutOfRange.
TEST_F(TypeReflectorTest, NewValueBuilder_UInt32Value) {
  auto builder = common_internal::NewValueBuilder(
      arena(), internal::GetTestingDescriptorPool(),
      internal::GetTestingMessageFactory(), "google.protobuf.UInt32Value");
  ASSERT_THAT(builder, NotNull());
  EXPECT_THAT(builder->SetFieldByName("value", UintValue(1)),
              IsOkAndHolds(Eq(absl::nullopt)));
  EXPECT_THAT(builder->SetFieldByName("does_not_exist", UintValue(1)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))));
  EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))));
  EXPECT_THAT(builder->SetFieldByName(
                  "value", UintValue(std::numeric_limits::max())),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange)))));
  EXPECT_THAT(builder->SetFieldByNumber(1, UintValue(1)),
              IsOkAndHolds(Eq(absl::nullopt)));
  EXPECT_THAT(builder->SetFieldByNumber(2, UintValue(1)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))));
  EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))));
  EXPECT_THAT(builder->SetFieldByNumber(
                  1, UintValue(std::numeric_limits::max())),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange)))));
  ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build());
  EXPECT_TRUE(InstanceOf(value));
  EXPECT_EQ(Cast(value).NativeValue(), 1);
}

TEST_F(TypeReflectorTest, NewValueBuilder_UInt64Value) {
  auto builder = common_internal::NewValueBuilder(
      arena(), internal::GetTestingDescriptorPool(),
      internal::GetTestingMessageFactory(), "google.protobuf.UInt64Value");
  ASSERT_THAT(builder, NotNull());
  EXPECT_THAT(builder->SetFieldByName("value", UintValue(1)),
              IsOkAndHolds(Eq(absl::nullopt)));
  EXPECT_THAT(builder->SetFieldByName("does_not_exist", UintValue(1)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))));
  EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))));
  EXPECT_THAT(builder->SetFieldByNumber(1, UintValue(1)),
              IsOkAndHolds(Eq(absl::nullopt)));
  EXPECT_THAT(builder->SetFieldByNumber(2, UintValue(1)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))));
  EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))));
  ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build());
  EXPECT_TRUE(InstanceOf(value));
  EXPECT_EQ(Cast(value).NativeValue(), 1);
}

// FloatValue is populated from DoubleValue inputs.
TEST_F(TypeReflectorTest, NewValueBuilder_FloatValue) {
  auto builder = common_internal::NewValueBuilder(
      arena(), internal::GetTestingDescriptorPool(),
      internal::GetTestingMessageFactory(), "google.protobuf.FloatValue");
  ASSERT_THAT(builder, NotNull());
  EXPECT_THAT(builder->SetFieldByName("value", DoubleValue(1)),
              IsOkAndHolds(Eq(absl::nullopt)));
  EXPECT_THAT(builder->SetFieldByName("does_not_exist", DoubleValue(1)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))));
  EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))));
  EXPECT_THAT(builder->SetFieldByNumber(1, DoubleValue(1)),
              IsOkAndHolds(Eq(absl::nullopt)));
  EXPECT_THAT(builder->SetFieldByNumber(2, DoubleValue(1)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))));
  EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))));
  ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build());
  EXPECT_TRUE(InstanceOf(value));
  EXPECT_EQ(Cast(value).NativeValue(), 1);
}

TEST_F(TypeReflectorTest, NewValueBuilder_DoubleValue) {
  auto builder = common_internal::NewValueBuilder(
      arena(), internal::GetTestingDescriptorPool(),
      internal::GetTestingMessageFactory(), "google.protobuf.DoubleValue");
  ASSERT_THAT(builder, NotNull());
  EXPECT_THAT(builder->SetFieldByName("value", DoubleValue(1)),
              IsOkAndHolds(Eq(absl::nullopt)));
  EXPECT_THAT(builder->SetFieldByName("does_not_exist", DoubleValue(1)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))));
  EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))));
  EXPECT_THAT(builder->SetFieldByNumber(1, DoubleValue(1)),
              IsOkAndHolds(Eq(absl::nullopt)));
  EXPECT_THAT(builder->SetFieldByNumber(2, DoubleValue(1)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))));
  EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))));
  ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build());
  EXPECT_TRUE(InstanceOf(value));
  EXPECT_EQ(Cast(value).NativeValue(), 1);
}

// StringValue checks the built value through NativeString().
TEST_F(TypeReflectorTest, NewValueBuilder_StringValue) {
  auto builder = common_internal::NewValueBuilder(
      arena(), internal::GetTestingDescriptorPool(),
      internal::GetTestingMessageFactory(), "google.protobuf.StringValue");
  ASSERT_THAT(builder, NotNull());
  EXPECT_THAT(builder->SetFieldByName("value", StringValue("foo")),
              IsOkAndHolds(Eq(absl::nullopt)));
  EXPECT_THAT(builder->SetFieldByName("does_not_exist", StringValue("foo")),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))));
  EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))));
  EXPECT_THAT(builder->SetFieldByNumber(1, StringValue("foo")),
              IsOkAndHolds(Eq(absl::nullopt)));
  EXPECT_THAT(builder->SetFieldByNumber(2, StringValue("foo")),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))));
  EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)),
              IsOkAndHolds(Optional(
                  ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))));
  ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build());
  EXPECT_TRUE(InstanceOf(value));
  EXPECT_EQ(Cast(value).NativeString(), "foo");
}

TEST_F(TypeReflectorTest, NewValueBuilder_BytesValue) {
  auto builder =
common_internal::NewValueBuilder( arena(), internal::GetTestingDescriptorPool(), internal::GetTestingMessageFactory(), "google.protobuf.BytesValue"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", BytesValue("foo")), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", BytesValue("foo")), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, BytesValue("foo")), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, BytesValue("foo")), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); EXPECT_TRUE(InstanceOf(value)); EXPECT_EQ(Cast(value).NativeString(), "foo"); } TEST_F(TypeReflectorTest, NewValueBuilder_Duration) { auto builder = common_internal::NewValueBuilder( arena(), internal::GetTestingDescriptorPool(), internal::GetTestingMessageFactory(), "google.protobuf.Duration"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("seconds", IntValue(1)), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); EXPECT_THAT(builder->SetFieldByName("seconds", BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByName("nanos", IntValue(1)), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(builder->SetFieldByName( "nanos", IntValue(std::numeric_limits::max())), IsOkAndHolds(Optional( 
ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange))))); EXPECT_THAT(builder->SetFieldByName("nanos", BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(3, IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(2, IntValue(1)), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(builder->SetFieldByNumber( 2, IntValue(std::numeric_limits::max())), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange))))); EXPECT_THAT(builder->SetFieldByNumber(2, BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); EXPECT_TRUE(InstanceOf(value)); EXPECT_EQ(Cast(value).NativeValue(), absl::Seconds(1) + absl::Nanoseconds(1)); } TEST_F(TypeReflectorTest, NewValueBuilder_Timestamp) { auto builder = common_internal::NewValueBuilder( arena(), internal::GetTestingDescriptorPool(), internal::GetTestingMessageFactory(), "google.protobuf.Timestamp"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("seconds", IntValue(1)), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); EXPECT_THAT(builder->SetFieldByName("seconds", BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByName("nanos", IntValue(1)), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(builder->SetFieldByName( "nanos", IntValue(std::numeric_limits::max())), IsOkAndHolds(Optional( 
ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange))))); EXPECT_THAT(builder->SetFieldByName("nanos", BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(3, IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(2, IntValue(1)), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(builder->SetFieldByNumber( 2, IntValue(std::numeric_limits::max())), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange))))); EXPECT_THAT(builder->SetFieldByNumber(2, BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); EXPECT_TRUE(InstanceOf(value)); EXPECT_EQ(Cast(value).NativeValue(), absl::UnixEpoch() + absl::Seconds(1) + absl::Nanoseconds(1)); } TEST_F(TypeReflectorTest, NewValueBuilder_Any) { auto builder = common_internal::NewValueBuilder( arena(), internal::GetTestingDescriptorPool(), internal::GetTestingMessageFactory(), "google.protobuf.Any"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName( "type_url", StringValue("type.googleapis.com/google.protobuf.BoolValue")), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); EXPECT_THAT(builder->SetFieldByName("type_url", BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByName("value", BytesValue()), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(builder->SetFieldByName("value", 
BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT( builder->SetFieldByNumber( 1, StringValue("type.googleapis.com/google.protobuf.BoolValue")), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(3, IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(2, BytesValue()), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); ASSERT_OK_AND_ASSIGN(auto value, std::move(*builder).Build()); EXPECT_TRUE(InstanceOf(value)); EXPECT_EQ(Cast(value).NativeValue(), false); } } // namespace } // namespace cel ================================================ FILE: common/type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/type.h"

#include "absl/hash/hash.h"
#include "absl/hash/hash_testing.h"
#include "absl/log/die_if_null.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "google/protobuf/arena.h"

// Unit tests for cel::Type: construction from protobuf descriptors
// (Type::Enum, Type::Field), kind(), GetParameters(), the Is/As/Get
// accessors, hashing/equality, and Wrap()/Unwrap().

namespace cel {
namespace {

using ::cel::internal::GetTestingDescriptorPool;
using ::testing::An;
using ::testing::ElementsAre;
using ::testing::IsEmpty;
using ::testing::Optional;

// A default-constructed Type is dyn.
TEST(Type, Default) {
  EXPECT_EQ(Type(), DynType());
  EXPECT_TRUE(Type().IsDyn());
}

// Type::Enum wraps a proto enum descriptor in an EnumType, except for
// google.protobuf.NullValue, which is exposed as int.
TEST(Type, Enum) {
  EXPECT_EQ(
      Type::Enum(
          ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName(
              "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"))),
      EnumType(ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName(
          "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"))));
  EXPECT_EQ(Type::Enum(
                ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName(
                    "google.protobuf.NullValue"))),
            IntType());
}

// Type::Field maps each TestAllTypes field to the CEL type it is exposed as:
// signed integral scalars (and null_value) to int, unsigned ones to uint,
// float/double to double, well-known types (Any, Duration, Timestamp, Struct,
// ListValue, Value) to their dedicated CEL types, wrapper messages to wrapper
// types, nested enums to EnumType, and repeated/map fields to list/map types.
TEST(Type, Field) {
  google::protobuf::Arena arena;
  const auto* descriptor =
      ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindMessageTypeByName(
          "cel.expr.conformance.proto3.TestAllTypes"));
  EXPECT_EQ(
      Type::Field(ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_bool"))),
      BoolType());
  EXPECT_EQ(
      Type::Field(ABSL_DIE_IF_NULL(descriptor->FindFieldByName("null_value"))),
      IntType());
  EXPECT_EQ(Type::Field(
                ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_int32"))),
            IntType());
  EXPECT_EQ(Type::Field(
                ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_sint32"))),
            IntType());
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("single_sfixed32"))),
            IntType());
  EXPECT_EQ(Type::Field(
                ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_int64"))),
            IntType());
  EXPECT_EQ(Type::Field(
                ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_sint64"))),
            IntType());
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("single_sfixed64"))),
            IntType());
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("single_fixed32"))),
            UintType());
  EXPECT_EQ(Type::Field(
                ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_uint32"))),
            UintType());
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("single_fixed64"))),
            UintType());
  EXPECT_EQ(Type::Field(
                ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_uint64"))),
            UintType());
  EXPECT_EQ(Type::Field(
                ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_float"))),
            DoubleType());
  EXPECT_EQ(Type::Field(
                ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_double"))),
            DoubleType());
  EXPECT_EQ(Type::Field(
                ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_bytes"))),
            BytesType());
  EXPECT_EQ(Type::Field(
                ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_string"))),
            StringType());
  EXPECT_EQ(
      Type::Field(ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_any"))),
      AnyType());
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("single_duration"))),
            DurationType());
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("single_timestamp"))),
            TimestampType());
  EXPECT_EQ(Type::Field(
                ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_struct"))),
            JsonMapType());
  EXPECT_EQ(
      Type::Field(ABSL_DIE_IF_NULL(descriptor->FindFieldByName("list_value"))),
      JsonListType());
  EXPECT_EQ(Type::Field(
                ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_value"))),
            JsonType());
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("single_bool_wrapper"))),
            BoolWrapperType());
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("single_int32_wrapper"))),
            IntWrapperType());
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("single_int64_wrapper"))),
            IntWrapperType());
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("single_uint32_wrapper"))),
            UintWrapperType());
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("single_uint64_wrapper"))),
            UintWrapperType());
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("single_float_wrapper"))),
            DoubleWrapperType());
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("single_double_wrapper"))),
            DoubleWrapperType());
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("single_bytes_wrapper"))),
            BytesWrapperType());
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("single_string_wrapper"))),
            StringWrapperType());
  EXPECT_EQ(
      Type::Field(
          ABSL_DIE_IF_NULL(descriptor->FindFieldByName("standalone_enum"))),
      EnumType(ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName(
          "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"))));
  // Repeated and map fields become parameterized list/map types allocated on
  // the test arena.
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("repeated_int32"))),
            ListType(&arena, IntType()));
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                descriptor->FindFieldByName("map_int32_int32"))),
            MapType(&arena, IntType(), IntType()));
}

// kind() reports the kKind of the alternative held by the Type; an optional
// type reports the opaque kind.
TEST(Type, Kind) {
  google::protobuf::Arena arena;

  EXPECT_EQ(Type(AnyType()).kind(), AnyType::kKind);

  EXPECT_EQ(Type(BoolType()).kind(), BoolType::kKind);

  EXPECT_EQ(Type(BoolWrapperType()).kind(), BoolWrapperType::kKind);

  EXPECT_EQ(Type(BytesType()).kind(), BytesType::kKind);

  EXPECT_EQ(Type(BytesWrapperType()).kind(), BytesWrapperType::kKind);

  EXPECT_EQ(Type(DoubleType()).kind(), DoubleType::kKind);

  EXPECT_EQ(Type(DoubleWrapperType()).kind(), DoubleWrapperType::kKind);

  EXPECT_EQ(Type(DurationType()).kind(), DurationType::kKind);

  EXPECT_EQ(Type(DynType()).kind(), DynType::kKind);

  EXPECT_EQ(
      Type(EnumType(
               ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName(
                   "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"))))
          .kind(),
      EnumType::kKind);

  EXPECT_EQ(Type(ErrorType()).kind(), ErrorType::kKind);

  EXPECT_EQ(Type(FunctionType(&arena, DynType(), {})).kind(),
            FunctionType::kKind);

  EXPECT_EQ(Type(IntType()).kind(), IntType::kKind);

  EXPECT_EQ(Type(IntWrapperType()).kind(), IntWrapperType::kKind);

  EXPECT_EQ(Type(ListType()).kind(), ListType::kKind);

  EXPECT_EQ(Type(MapType()).kind(), MapType::kKind);

  // NOTE(review): this MessageType assertion appears twice verbatim; the
  // other accessor tests pair it with a distinct second check (IsStruct vs
  // IsMessage). Confirm whether a different case was intended here.
  EXPECT_EQ(Type(MessageType(ABSL_DIE_IF_NULL(
                     GetTestingDescriptorPool()->FindMessageTypeByName(
                         "cel.expr.conformance.proto3.TestAllTypes"))))
                .kind(),
            MessageType::kKind);
  EXPECT_EQ(Type(MessageType(ABSL_DIE_IF_NULL(
                     GetTestingDescriptorPool()->FindMessageTypeByName(
                         "cel.expr.conformance.proto3.TestAllTypes"))))
                .kind(),
            MessageType::kKind);

  EXPECT_EQ(Type(NullType()).kind(), NullType::kKind);

  EXPECT_EQ(Type(OptionalType()).kind(), OpaqueType::kKind);

  EXPECT_EQ(Type(StringType()).kind(), StringType::kKind);

  EXPECT_EQ(Type(StringWrapperType()).kind(), StringWrapperType::kKind);

  EXPECT_EQ(Type(TimestampType()).kind(), TimestampType::kKind);

  EXPECT_EQ(Type(UintType()).kind(), UintType::kKind);

  EXPECT_EQ(Type(UintWrapperType()).kind(), UintWrapperType::kKind);

  EXPECT_EQ(Type(UnknownType()).kind(), UnknownType::kKind);
}

// Only parameterized types report parameters. Function parameters are the
// leading DynType() constructor argument followed by the argument types.
// Default-constructed list/map/optional types are parameterized by dyn.
TEST(Type, GetParameters) {
  google::protobuf::Arena arena;
  EXPECT_THAT(Type(AnyType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(BoolType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(BoolWrapperType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(BytesType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(BytesWrapperType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(DoubleType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(DoubleWrapperType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(DurationType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(DynType()).GetParameters(), IsEmpty());
  EXPECT_THAT(
      Type(EnumType(
               ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName(
                   "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"))))
          .GetParameters(),
      IsEmpty());
  EXPECT_THAT(Type(ErrorType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(FunctionType(&arena, DynType(),
                                {IntType(), StringType(), DynType()}))
                  .GetParameters(),
              ElementsAre(DynType(), IntType(), StringType(), DynType()));
  EXPECT_THAT(Type(IntType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(IntWrapperType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(ListType()).GetParameters(), ElementsAre(DynType()));
  EXPECT_THAT(Type(MapType()).GetParameters(),
              ElementsAre(DynType(), DynType()));
  EXPECT_THAT(Type(MessageType(ABSL_DIE_IF_NULL(
                       GetTestingDescriptorPool()->FindMessageTypeByName(
                           "cel.expr.conformance.proto3.TestAllTypes"))))
                  .GetParameters(),
              IsEmpty());
  EXPECT_THAT(Type(NullType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(OptionalType()).GetParameters(), ElementsAre(DynType()));
  EXPECT_THAT(Type(StringType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(StringWrapperType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(TimestampType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(UintType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(UintWrapperType()).GetParameters(), IsEmpty());
  EXPECT_THAT(Type(UnknownType()).GetParameters(), IsEmpty());
}

// Is<T>() holds for the alternative stored in the Type. Wrapper types also
// satisfy IsWrapper(), and message types satisfy both IsStruct() and
// IsMessage(). OptionalType is checked twice, presumably once as the opaque
// category and once as OptionalType itself (compare the Kind test).
TEST(Type, Is) {
  google::protobuf::Arena arena;

  EXPECT_TRUE(Type(AnyType()).Is());

  EXPECT_TRUE(Type(BoolType()).Is());

  EXPECT_TRUE(Type(BoolWrapperType()).Is());
  EXPECT_TRUE(Type(BoolWrapperType()).IsWrapper());

  EXPECT_TRUE(Type(BytesType()).Is());

  EXPECT_TRUE(Type(BytesWrapperType()).Is());
  EXPECT_TRUE(Type(BytesWrapperType()).IsWrapper());

  EXPECT_TRUE(Type(DoubleType()).Is());

  EXPECT_TRUE(Type(DoubleWrapperType()).Is());
  EXPECT_TRUE(Type(DoubleWrapperType()).IsWrapper());

  EXPECT_TRUE(Type(DurationType()).Is());

  EXPECT_TRUE(Type(DynType()).Is());

  EXPECT_TRUE(
      Type(EnumType(
               ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName(
                   "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"))))
          .Is());

  EXPECT_TRUE(Type(ErrorType()).Is());

  EXPECT_TRUE(Type(FunctionType(&arena, DynType(), {})).Is());

  EXPECT_TRUE(Type(IntType()).Is());

  EXPECT_TRUE(Type(IntWrapperType()).Is());
  EXPECT_TRUE(Type(IntWrapperType()).IsWrapper());

  EXPECT_TRUE(Type(ListType()).Is());

  EXPECT_TRUE(Type(MapType()).Is());

  EXPECT_TRUE(Type(MessageType(ABSL_DIE_IF_NULL(
                       GetTestingDescriptorPool()->FindMessageTypeByName(
                           "cel.expr.conformance.proto3.TestAllTypes"))))
                  .IsStruct());
  EXPECT_TRUE(Type(MessageType(ABSL_DIE_IF_NULL(
                       GetTestingDescriptorPool()->FindMessageTypeByName(
                           "cel.expr.conformance.proto3.TestAllTypes"))))
                  .IsMessage());

  EXPECT_TRUE(Type(NullType()).Is());

  EXPECT_TRUE(Type(OptionalType()).Is());
  EXPECT_TRUE(Type(OptionalType()).Is());

  EXPECT_TRUE(Type(StringType()).Is());

  EXPECT_TRUE(Type(StringWrapperType()).Is());
  EXPECT_TRUE(Type(StringWrapperType()).IsWrapper());

  EXPECT_TRUE(Type(TimestampType()).Is());

  EXPECT_TRUE(Type(TypeType()).Is());

  EXPECT_TRUE(Type(TypeParamType("T")).Is());

  EXPECT_TRUE(Type(UintType()).Is());

  EXPECT_TRUE(Type(UintWrapperType()).Is());
  EXPECT_TRUE(Type(UintWrapperType()).IsWrapper());

  EXPECT_TRUE(Type(UnknownType()).Is());
}

// As<T>() returns an engaged optional holding the matching alternative.
TEST(Type, As) {
  google::protobuf::Arena arena;

  EXPECT_THAT(Type(AnyType()).As(), Optional(An()));

  EXPECT_THAT(Type(BoolType()).As(), Optional(An()));

  EXPECT_THAT(Type(BoolWrapperType()).As(), Optional(An()));

  EXPECT_THAT(Type(BytesType()).As(), Optional(An()));

  EXPECT_THAT(Type(BytesWrapperType()).As(), Optional(An()));

  EXPECT_THAT(Type(DoubleType()).As(), Optional(An()));

  EXPECT_THAT(Type(DoubleWrapperType()).As(), Optional(An()));

  EXPECT_THAT(Type(DurationType()).As(), Optional(An()));

  EXPECT_THAT(Type(DynType()).As(), Optional(An()));

  EXPECT_THAT(
      Type(EnumType(
               ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName(
                   "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"))))
          .As(),
      Optional(An()));

  EXPECT_THAT(Type(ErrorType()).As(), Optional(An()));

  // NOTE(review): unlike every other case in this test, FunctionType is
  // checked with Is rather than As, so As for FunctionType is not covered
  // here. Confirm whether that is intentional.
  EXPECT_TRUE(Type(FunctionType(&arena, DynType(), {})).Is());

  EXPECT_THAT(Type(IntType()).As(), Optional(An()));

  EXPECT_THAT(Type(IntWrapperType()).As(), Optional(An()));

  EXPECT_THAT(Type(ListType()).As(), Optional(An()));

  EXPECT_THAT(Type(MapType()).As(), Optional(An()));

  EXPECT_THAT(Type(MessageType(ABSL_DIE_IF_NULL(
                       GetTestingDescriptorPool()->FindMessageTypeByName(
                           "cel.expr.conformance.proto3.TestAllTypes"))))
                  .As(),
              Optional(An()));
  EXPECT_THAT(Type(MessageType(ABSL_DIE_IF_NULL(
                       GetTestingDescriptorPool()->FindMessageTypeByName(
                           "cel.expr.conformance.proto3.TestAllTypes"))))
                  .As(),
              Optional(An()));

  EXPECT_THAT(Type(NullType()).As(), Optional(An()));

  EXPECT_THAT(Type(OptionalType()).As(), Optional(An()));
  EXPECT_THAT(Type(OptionalType()).As(), Optional(An()));

  EXPECT_THAT(Type(StringType()).As(), Optional(An()));

  EXPECT_THAT(Type(StringWrapperType()).As(), Optional(An()));

  EXPECT_THAT(Type(TimestampType()).As(), Optional(An()));

  EXPECT_THAT(Type(TypeType()).As(), Optional(An()));

  EXPECT_THAT(Type(TypeParamType("T")).As(), Optional(An()));

  EXPECT_THAT(Type(UintType()).As(), Optional(An()));

  EXPECT_THAT(Type(UintWrapperType()).As(), Optional(An()));

  EXPECT_THAT(Type(UnknownType()).As(), Optional(An()));
}

// Calls Type::Get from a dependent context, which requires the `template`
// disambiguator before the member template name.
template T DoGet(const Type& type) {
  return type.template Get();
}

// Get<T>() returns the held alternative. Wrapper, message and optional types
// are each fetched twice, presumably once as the concrete type and once as a
// broader category. The template arguments distinguish the two.
TEST(Type, Get) {
  google::protobuf::Arena arena;

  EXPECT_THAT(DoGet(Type(AnyType())), An());

  EXPECT_THAT(DoGet(Type(BoolType())), An());

  EXPECT_THAT(DoGet(Type(BoolWrapperType())), An());
  EXPECT_THAT(DoGet(Type(BoolWrapperType())), An());

  EXPECT_THAT(DoGet(Type(BytesType())), An());

  EXPECT_THAT(DoGet(Type(BytesWrapperType())), An());
  EXPECT_THAT(DoGet(Type(BytesWrapperType())), An());

  EXPECT_THAT(DoGet(Type(DoubleType())), An());

  EXPECT_THAT(DoGet(Type(DoubleWrapperType())), An());
  EXPECT_THAT(DoGet(Type(DoubleWrapperType())), An());

  EXPECT_THAT(DoGet(Type(DurationType())), An());

  EXPECT_THAT(DoGet(Type(DynType())), An());

  EXPECT_THAT(
      DoGet(Type(EnumType(
          ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName(
              "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"))))),
      An());

  EXPECT_THAT(DoGet(Type(ErrorType())), An());

  EXPECT_THAT(DoGet(Type(FunctionType(&arena, DynType(), {}))), An());

  EXPECT_THAT(DoGet(Type(IntType())), An());

  EXPECT_THAT(DoGet(Type(IntWrapperType())), An());
  EXPECT_THAT(DoGet(Type(IntWrapperType())), An());

  EXPECT_THAT(DoGet(Type(ListType())), An());

  EXPECT_THAT(DoGet(Type(MapType())), An());

  EXPECT_THAT(DoGet(Type(MessageType(ABSL_DIE_IF_NULL(
                  GetTestingDescriptorPool()->FindMessageTypeByName(
                      "cel.expr.conformance.proto3.TestAllTypes"))))),
              An());
  EXPECT_THAT(DoGet(Type(MessageType(ABSL_DIE_IF_NULL(
                  GetTestingDescriptorPool()->FindMessageTypeByName(
                      "cel.expr.conformance.proto3.TestAllTypes"))))),
              An());

  EXPECT_THAT(DoGet(Type(NullType())), An());

  EXPECT_THAT(DoGet(Type(OptionalType())), An());
  EXPECT_THAT(DoGet(Type(OptionalType())), An());

  EXPECT_THAT(DoGet(Type(StringType())), An());

  EXPECT_THAT(DoGet(Type(StringWrapperType())), An());
  EXPECT_THAT(DoGet(Type(StringWrapperType())), An());

  EXPECT_THAT(DoGet(Type(TimestampType())), An());

  EXPECT_THAT(DoGet(Type(TypeType())), An());

  EXPECT_THAT(DoGet(Type(TypeParamType("T"))), An());

  EXPECT_THAT(DoGet(Type(UintType())), An());

  EXPECT_THAT(DoGet(Type(UintWrapperType())), An());
  EXPECT_THAT(DoGet(Type(UintWrapperType())), An());

  EXPECT_THAT(DoGet(Type(UnknownType())), An());
}

// Type satisfies the absl hashing contract. Types derived from descriptors
// hash and compare equal to equivalent hand-constructed types: list/map
// fields vs ListType/MapType, and a MessageType vs a basic StructType with
// the same name.
TEST(Type, VerifyTypeImplementsAbslHashCorrectly) {
  google::protobuf::Arena arena;
  EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly(
      {Type(AnyType()),
       Type(BoolType()),
       Type(BoolWrapperType()),
       Type(BytesType()),
       Type(BytesWrapperType()),
       Type(DoubleType()),
       Type(DoubleWrapperType()),
       Type(DurationType()),
       Type(DynType()),
       Type(ErrorType()),
       Type(FunctionType(&arena, DynType(), {DynType()})),
       Type(IntType()),
       Type(IntWrapperType()),
       Type(ListType(&arena, DynType())),
       Type(MapType(&arena, DynType(), DynType())),
       Type(NullType()),
       Type(OptionalType(&arena, DynType())),
       Type(StringType()),
       Type(StringWrapperType()),
       Type(StructType(common_internal::MakeBasicStructType("test.Struct"))),
       Type(TimestampType()),
       Type(TypeParamType("T")),
       Type(TypeType()),
       Type(UintType()),
       Type(UintWrapperType()),
       Type(UnknownType())}));

  EXPECT_EQ(
      absl::HashOf(Type::Field(
          ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindMessageTypeByName(
                               "cel.expr.conformance.proto3.TestAllTypes"))
              ->FindFieldByName("repeated_int64"))),
      absl::HashOf(Type(ListType(&arena, IntType()))));
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                            GetTestingDescriptorPool()->FindMessageTypeByName(
                                "cel.expr.conformance.proto3.TestAllTypes"))
                            ->FindFieldByName("repeated_int64")),
            Type(ListType(&arena, IntType())));

  EXPECT_EQ(
      absl::HashOf(Type::Field(
          ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindMessageTypeByName(
                               "cel.expr.conformance.proto3.TestAllTypes"))
              ->FindFieldByName("map_int64_int64"))),
      absl::HashOf(Type(MapType(&arena, IntType(), IntType()))));
  EXPECT_EQ(Type::Field(ABSL_DIE_IF_NULL(
                            GetTestingDescriptorPool()->FindMessageTypeByName(
                                "cel.expr.conformance.proto3.TestAllTypes"))
                            ->FindFieldByName("map_int64_int64")),
            Type(MapType(&arena, IntType(), IntType())));

  EXPECT_EQ(absl::HashOf(Type(MessageType(ABSL_DIE_IF_NULL(
                GetTestingDescriptorPool()->FindMessageTypeByName(
                    "cel.expr.conformance.proto3.TestAllTypes"))))),
            absl::HashOf(Type(StructType(common_internal::MakeBasicStructType(
                "cel.expr.conformance.proto3.TestAllTypes")))));
  EXPECT_EQ(Type(MessageType(ABSL_DIE_IF_NULL(
                GetTestingDescriptorPool()->FindMessageTypeByName(
                    "cel.expr.conformance.proto3.TestAllTypes")))),
            Type(StructType(common_internal::MakeBasicStructType(
                "cel.expr.conformance.proto3.TestAllTypes"))));
}

// Unwrap() maps each wrapper type to its primitive type. Non-wrapper types
// (here: any) are returned unchanged.
TEST(Type, Unwrap) {
  EXPECT_EQ(Type(BoolWrapperType()).Unwrap(), BoolType());
  EXPECT_EQ(Type(IntWrapperType()).Unwrap(), IntType());
  EXPECT_EQ(Type(UintWrapperType()).Unwrap(), UintType());
  EXPECT_EQ(Type(DoubleWrapperType()).Unwrap(), DoubleType());
  EXPECT_EQ(Type(BytesWrapperType()).Unwrap(), BytesType());
  EXPECT_EQ(Type(StringWrapperType()).Unwrap(), StringType());
  EXPECT_EQ(Type(AnyType()).Unwrap(), AnyType());
}

// Wrap() is the inverse of Unwrap(): primitives map to their wrapper types.
// Types without a wrapper (here: any) are returned unchanged.
TEST(Type, Wrap) {
  EXPECT_EQ(Type(BoolType()).Wrap(), BoolWrapperType());
  EXPECT_EQ(Type(IntType()).Wrap(), IntWrapperType());
  EXPECT_EQ(Type(UintType()).Wrap(), UintWrapperType());
  EXPECT_EQ(Type(DoubleType()).Wrap(), DoubleWrapperType());
  EXPECT_EQ(Type(BytesType()).Wrap(), BytesWrapperType());
  EXPECT_EQ(Type(StringType()).Wrap(), StringWrapperType());
  EXPECT_EQ(Type(AnyType()).Wrap(), AnyType());
}

}  // namespace
}  // namespace cel

================================================
FILE: common/type_testing.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_TESTING_H_
#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_TESTING_H_

namespace cel::common_internal {

// Empty for now.

}  // namespace cel::common_internal

#endif  // THIRD_PARTY_CEL_CPP_COMMON_TYPE_TESTING_H_

================================================
FILE: common/typeinfo.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/typeinfo.h"

#include <cstddef>
#include <cstdint>  // IWYU pragma: keep
#include <cstdlib>
#include <memory>
#include <string>

#include "absl/base/casts.h"  // IWYU pragma: keep
#include "absl/strings/str_cat.h"  // IWYU pragma: keep

#ifdef CEL_INTERNAL_HAVE_RTTI
#ifdef _WIN32
// Demangler exported by the MSVC runtime for `type_info::raw_name()` strings.
// The result is allocated with the supplied allocation function (std::malloc
// below) and must be released with the matching deallocator (FreeDeleter).
extern "C" char* __unDName(char*, const char*, int, void* (*)(size_t),
                           void (*)(void*), int);
#else
#include <cxxabi.h>
#endif
#endif

namespace cel {

namespace {

#ifdef CEL_INTERNAL_HAVE_RTTI
// Releases buffers returned by the platform demanglers, which allocate with
// malloc.
struct FreeDeleter {
  void operator()(char* ptr) const { std::free(ptr); }
};
#endif

}  // namespace

// Returns a human-readable name for the identified type.
//
// - A default-constructed (empty) TypeInfo yields an empty string.
// - With RTTI, returns the demangled type name when demangling succeeds and
//   falls back to `std::type_info::name()` otherwise.
// - Without RTTI, returns the address of the per-type tag in hex.
std::string TypeInfo::DebugString() const {
  if (rep_ == nullptr) {
    return std::string();
  }
#ifdef CEL_INTERNAL_HAVE_RTTI
#ifdef _WIN32
  std::unique_ptr<char, FreeDeleter> demangled(
      __unDName(nullptr, rep_->raw_name(), 0, std::malloc, std::free, 0x2800));
  if (demangled == nullptr) {
    return std::string(rep_->name());
  }
  return std::string(demangled.get());
#else
  int status = 0;
  // `length` is deliberately null. When `output_buffer` is null,
  // `__cxa_demangle` allocates the result and reports the size of that
  // *buffer* through `length`. In libstdc++ the buffer can be larger than the
  // string, leaving trailing bytes uninitialized. The result is always
  // NUL-terminated, so size the std::string from the C string instead.
  std::unique_ptr<char, FreeDeleter> demangled(
      abi::__cxa_demangle(rep_->name(), nullptr, nullptr, &status));
  if (status != 0 || demangled == nullptr) {
    return std::string(rep_->name());
  }
  return std::string(demangled.get());
#endif
#else
  return absl::StrCat("0x", absl::Hex(absl::bit_cast<uintptr_t>(rep_)));
#endif
}

}  // namespace cel

================================================
FILE: common/typeinfo.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPEINFO_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPEINFO_H_ #include #include #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/casts.h" // IWYU pragma: keep #include "absl/base/config.h" #include "absl/base/nullability.h" #include "absl/meta/type_traits.h" #if ABSL_HAVE_FEATURE(cxx_rtti) #define CEL_INTERNAL_HAVE_RTTI 1 #elif defined(__GNUC__) && defined(__GXX_RTTI) #define CEL_INTERNAL_HAVE_RTTI 1 #elif defined(_MSC_VER) && defined(_CPPRTTI) #define CEL_INTERNAL_HAVE_RTTI 1 #elif !defined(__GNUC__) && !defined(_MSC_VER) #define CEL_INTERNAL_HAVE_RTTI 1 #endif #ifdef CEL_INTERNAL_HAVE_RTTI #include #endif namespace cel { class TypeInfo; template struct NativeTypeTraits; namespace common_internal { template struct HasNativeTypeTraitsId : std::false_type {}; template struct HasNativeTypeTraitsId< T, std::void_t::Id(std::declval()))>> : std::true_type {}; template static constexpr bool HasNativeTypeTraitsIdV = HasNativeTypeTraitsId::value; template struct HasCelTypeId : std::false_type {}; template struct HasCelTypeId< T, std::enable_if_t()))>, TypeInfo>>> : std::true_type {}; } // namespace common_internal template TypeInfo TypeId(); template std::enable_if_t< std::conjunction_v, std::negation>>, TypeInfo> TypeId(const T& t [[maybe_unused]]) { return NativeTypeTraits>::Id(t); } template std::enable_if_t< std::conjunction_v>, std::negation>, std::is_final>, TypeInfo> TypeId(const T& t [[maybe_unused]]) { return cel::TypeId>(); } template std::enable_if_t< std::conjunction_v>, common_internal::HasCelTypeId>, TypeInfo> TypeId(const T& t [[maybe_unused]]) { return CelTypeId(t); } class TypeInfo final { public: template ABSL_DEPRECATED("Use cel::TypeId() instead") static TypeInfo For() { return cel::TypeId(); } template ABSL_DEPRECATED("Use cel::TypeId(...) 
instead") static TypeInfo Of(const T& type) { return cel::TypeId(type); } TypeInfo() = default; TypeInfo(const TypeInfo&) = default; TypeInfo& operator=(const TypeInfo&) = default; std::string DebugString() const; template friend void AbslStringify(S& sink, TypeInfo type_info) { sink.Append(type_info.DebugString()); } friend constexpr bool operator==(TypeInfo lhs, TypeInfo rhs) noexcept { #ifdef CEL_INTERNAL_HAVE_RTTI return lhs.rep_ == rhs.rep_ || (lhs.rep_ != nullptr && rhs.rep_ != nullptr && *lhs.rep_ == *rhs.rep_); #else return lhs.rep_ == rhs.rep_; #endif } template friend H AbslHashValue(H state, TypeInfo id) { #ifdef CEL_INTERNAL_HAVE_RTTI return H::combine(std::move(state), id.rep_ != nullptr ? id.rep_->hash_code() : size_t{0}); #else return H::combine(std::move(state), absl::bit_cast(id.rep_)); #endif } private: template friend TypeInfo TypeId(); #ifdef CEL_INTERNAL_HAVE_RTTI constexpr explicit TypeInfo(const std::type_info* absl_nullable rep) : rep_(rep) {} const std::type_info* absl_nullable rep_ = nullptr; #else constexpr explicit TypeInfo(const void* absl_nullable rep) : rep_(rep) {} const void* absl_nullable rep_ = nullptr; #endif }; #ifndef CEL_INTERNAL_HAVE_RTTI namespace common_internal { template struct TypeTag final { static constexpr char value = 0; }; } // namespace common_internal #endif template TypeInfo TypeId() { static_assert(std::is_same_v>); static_assert(!std::is_same_v>); #ifdef CEL_INTERNAL_HAVE_RTTI return TypeInfo(&typeid(T)); #else return TypeInfo(&common_internal::TypeTag::value); #endif } inline constexpr bool operator!=(TypeInfo lhs, TypeInfo rhs) noexcept { return !operator==(lhs, rhs); } inline std::ostream& operator<<(std::ostream& out, TypeInfo id) { return out << id.DebugString(); } // Helper class for adapting a type to an index in a tuple or array. // Scope is an arbitrary type used as a namespace for the index. 
template class TypeIdInSet { public: template static size_t IndexFor() { static size_t index = type_id_set_index_.fetch_add(1, std::memory_order_relaxed); return index; } static size_t Size() { return type_id_set_index_.load(std::memory_order_relaxed); } private: static std::atomic type_id_set_index_; }; template std::atomic TypeIdInSet::type_id_set_index_ = 0; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPEINFO_H_ ================================================ FILE: common/typeinfo_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/typeinfo.h"

#include <cstring>
#include <sstream>

#include "absl/hash/hash_testing.h"
#include "absl/strings/str_cat.h"
#include "internal/testing.h"

namespace cel {
namespace {

using ::testing::IsEmpty;
using ::testing::Not;
using ::testing::SizeIs;

struct Type1 {};

struct Type2 {};

struct Type3 {};

TEST(TypeInfo, ImplementsAbslHashCorrectly) {
  EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly(
      {TypeInfo(), cel::TypeId<Type1>(), cel::TypeId<Type2>(),
       cel::TypeId<Type3>()}));
}

TEST(TypeInfo, Ostream) {
  std::ostringstream out;
  out << TypeInfo();
  EXPECT_THAT(out.str(), IsEmpty());
  out << cel::TypeId<Type1>();
  auto string = out.str();
  EXPECT_THAT(string, Not(IsEmpty()));
  // The rendered name must not contain embedded NUL characters.
  EXPECT_THAT(string, SizeIs(std::strlen(string.c_str())));
}

TEST(TypeInfo, AbslStringify) {
  EXPECT_THAT(absl::StrCat(TypeInfo()), IsEmpty());
  EXPECT_THAT(absl::StrCat(cel::TypeId<Type1>()), Not(IsEmpty()));
}

struct TestType {};

}  // namespace

template <>
struct NativeTypeTraits<TestType> final {
  static TypeInfo Id(const TestType&) { return cel::TypeId<TestType>(); }
};

namespace {

TEST(TypeInfo, Of) {
  EXPECT_EQ(cel::TypeId(TestType()), cel::TypeId<TestType>());
}

}  // namespace
}  // namespace cel

================================================
FILE: common/types/any_type.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_ANY_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_ANY_TYPE_H_ #include #include #include #include "absl/strings/string_view.h" #include "common/type_kind.h" namespace cel { class Type; class TypeParameters; // `AnyType` is a special type which has no direct value representation. It is // used to represent `google.protobuf.Any`, which never exists at runtime as // a value. Its primary usage is for type checking and unpacking at runtime. class AnyType final { public: static constexpr TypeKind kKind = TypeKind::kAny; static constexpr absl::string_view kName = "google.protobuf.Any"; AnyType() = default; AnyType(const AnyType&) = default; AnyType(AnyType&&) = default; AnyType& operator=(const AnyType&) = default; AnyType& operator=(AnyType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } static TypeParameters GetParameters(); static std::string DebugString() { return std::string(name()); } }; inline constexpr bool operator==(AnyType, AnyType) { return true; } inline constexpr bool operator!=(AnyType lhs, AnyType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, AnyType) { // AnyType is really a singleton and all instances are equal. Nothing to hash. return std::move(state); } inline std::ostream& operator<<(std::ostream& out, const AnyType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_ANY_TYPE_H_ ================================================ FILE: common/types/any_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(AnyType, Kind) { EXPECT_EQ(AnyType().kind(), AnyType::kKind); EXPECT_EQ(Type(AnyType()).kind(), AnyType::kKind); } TEST(AnyType, Name) { EXPECT_EQ(AnyType().name(), AnyType::kName); EXPECT_EQ(Type(AnyType()).name(), AnyType::kName); } TEST(AnyType, DebugString) { { std::ostringstream out; out << AnyType(); EXPECT_EQ(out.str(), AnyType::kName); } { std::ostringstream out; out << Type(AnyType()); EXPECT_EQ(out.str(), AnyType::kName); } } TEST(AnyType, Hash) { EXPECT_EQ(absl::HashOf(AnyType()), absl::HashOf(AnyType())); } TEST(AnyType, Equal) { EXPECT_EQ(AnyType(), AnyType()); EXPECT_EQ(Type(AnyType()), AnyType()); EXPECT_EQ(AnyType(), Type(AnyType())); EXPECT_EQ(Type(AnyType()), Type(AnyType())); } } // namespace } // namespace cel ================================================ FILE: common/types/basic_struct_type.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/algorithm/container.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "common/type.h" namespace cel { bool IsWellKnownMessageType(absl::string_view name) { static constexpr absl::string_view kPrefix = "google.protobuf."; static constexpr std::array kNames = { // clang-format off // keep-sorted start "Any", "BoolValue", "BytesValue", "DoubleValue", "Duration", "FloatValue", "Int32Value", "Int64Value", "ListValue", "StringValue", "Struct", "Timestamp", "UInt32Value", "UInt64Value", "Value", // keep-sorted end // clang-format on }; if (!absl::ConsumePrefix(&name, kPrefix)) { return false; } return absl::c_binary_search(kNames, name); } } // namespace cel ================================================ FILE: common/types/basic_struct_type.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/types/struct_type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_BASIC_STRUCT_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_BASIC_STRUCT_TYPE_H_ #include #include #include #include "absl/base/attributes.h" #include "absl/log/absl_check.h" #include "absl/strings/string_view.h" #include "common/type_kind.h" namespace cel { class Type; class TypeParameters; // Returns true if the given type name is one of the well known message types // that CEL treats specially. // // For familiarity with textproto, these types may be created using the struct // creation syntax, even though they are not considered a struct type in CEL. bool IsWellKnownMessageType(absl::string_view name); namespace common_internal { class BasicStructType; class BasicStructTypeField; // Constructs `BasicStructType` from a type name. The type name must not be one // of the well known message types we treat specially, if it is behavior is // undefined. The name must also outlive the resulting type. BasicStructType MakeBasicStructType( absl::string_view name ABSL_ATTRIBUTE_LIFETIME_BOUND); class BasicStructType final { public: static constexpr TypeKind kKind = TypeKind::kStruct; BasicStructType() = default; BasicStructType(const BasicStructType&) = default; BasicStructType(BasicStructType&&) = default; BasicStructType& operator=(const BasicStructType&) = default; BasicStructType& operator=(BasicStructType&&) = default; static TypeKind kind() { return kKind; } absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK(*this); return name_; } static TypeParameters GetParameters(); std::string DebugString() const { return std::string(static_cast(*this) ? 
name() : absl::string_view()); } explicit operator bool() const { return !name_.empty(); } private: friend BasicStructType MakeBasicStructType( absl::string_view name ABSL_ATTRIBUTE_LIFETIME_BOUND); explicit BasicStructType(absl::string_view name ABSL_ATTRIBUTE_LIFETIME_BOUND) : name_(name) {} absl::string_view name_; }; inline bool operator==(BasicStructType lhs, BasicStructType rhs) { return static_cast(lhs) == static_cast(rhs) && (!static_cast(lhs) || lhs.name() == rhs.name()); } inline bool operator!=(BasicStructType lhs, BasicStructType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, BasicStructType type) { ABSL_DCHECK(type); return H::combine(std::move(state), static_cast(type) ? type.name() : absl::string_view()); } inline std::ostream& operator<<(std::ostream& out, BasicStructType type) { return out << type.DebugString(); } inline BasicStructType MakeBasicStructType( absl::string_view name ABSL_ATTRIBUTE_LIFETIME_BOUND) { ABSL_DCHECK(!IsWellKnownMessageType(name)) << name; return BasicStructType(name); } } // namespace common_internal } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_BASIC_STRUCT_TYPE_H_ ================================================ FILE: common/types/basic_struct_type_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/type.h"
#include "common/type_kind.h"
#include "internal/testing.h"

namespace cel::common_internal {
namespace {

using ::testing::Eq;
using ::testing::IsEmpty;

TEST(BasicStructType, Kind) {
  EXPECT_EQ(BasicStructType::kind(), TypeKind::kStruct);
}

TEST(BasicStructType, Default) {
  BasicStructType type;
  EXPECT_FALSE(type);
  EXPECT_THAT(type.DebugString(), Eq(""));
  EXPECT_EQ(type, BasicStructType());
}

TEST(BasicStructType, Name) {
  BasicStructType type = MakeBasicStructType("test.Struct");
  EXPECT_TRUE(type);
  EXPECT_THAT(type.name(), Eq("test.Struct"));
  EXPECT_THAT(type.DebugString(), Eq("test.Struct"));
  EXPECT_THAT(type.GetParameters(), IsEmpty());
  EXPECT_NE(type, BasicStructType());
  EXPECT_NE(BasicStructType(), type);
}

}  // namespace
}  // namespace cel::common_internal

================================================
FILE: common/types/bool_type.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private, include "common/type.h"
// IWYU pragma: friend "common/type.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_BOOL_TYPE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_BOOL_TYPE_H_

#include <ostream>
#include <string>
#include <utility>

#include "absl/strings/string_view.h"
#include "common/type_kind.h"

namespace cel {

class Type;
class TypeParameters;

// `BoolType` represents the primitive `bool` type.
class BoolType final {
 public:
  static constexpr TypeKind kKind = TypeKind::kBool;
  static constexpr absl::string_view kName = "bool";

  BoolType() = default;
  BoolType(const BoolType&) = default;
  BoolType(BoolType&&) = default;
  BoolType& operator=(const BoolType&) = default;
  BoolType& operator=(BoolType&&) = default;

  static TypeKind kind() { return kKind; }

  static absl::string_view name() { return kName; }

  static TypeParameters GetParameters();

  static std::string DebugString() { return std::string(name()); }
};

inline constexpr bool operator==(BoolType, BoolType) { return true; }

inline constexpr bool operator!=(BoolType lhs, BoolType rhs) {
  return !operator==(lhs, rhs);
}

template <typename H>
H AbslHashValue(H state, BoolType) {
  // BoolType is really a singleton and all instances are equal. Nothing to
  // hash.
  return std::move(state);
}

inline std::ostream& operator<<(std::ostream& out, const BoolType& type) {
  return out << type.DebugString();
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_TYPES_BOOL_TYPE_H_

================================================
FILE: common/types/bool_type_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(BoolType, Kind) { EXPECT_EQ(BoolType().kind(), BoolType::kKind); EXPECT_EQ(Type(BoolType()).kind(), BoolType::kKind); } TEST(BoolType, Name) { EXPECT_EQ(BoolType().name(), BoolType::kName); EXPECT_EQ(Type(BoolType()).name(), BoolType::kName); } TEST(BoolType, DebugString) { { std::ostringstream out; out << BoolType(); EXPECT_EQ(out.str(), BoolType::kName); } { std::ostringstream out; out << Type(BoolType()); EXPECT_EQ(out.str(), BoolType::kName); } } TEST(BoolType, Hash) { EXPECT_EQ(absl::HashOf(BoolType()), absl::HashOf(BoolType())); } TEST(BoolType, Equal) { EXPECT_EQ(BoolType(), BoolType()); EXPECT_EQ(Type(BoolType()), BoolType()); EXPECT_EQ(BoolType(), Type(BoolType())); EXPECT_EQ(Type(BoolType()), Type(BoolType())); } } // namespace } // namespace cel ================================================ FILE: common/types/bool_wrapper_type.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_BOOL_WRAPPER_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_BOOL_WRAPPER_TYPE_H_ #include #include #include #include "absl/strings/string_view.h" #include "common/type_kind.h" namespace cel { class Type; class TypeParameters; // `BoolWrapperType` is a special type which has no direct value representation. // It is used to represent `google.protobuf.BoolValue`, which never exists at // runtime as a value. Its primary usage is for type checking and unpacking at // runtime. class BoolWrapperType final { public: static constexpr TypeKind kKind = TypeKind::kBoolWrapper; static constexpr absl::string_view kName = "google.protobuf.BoolValue"; BoolWrapperType() = default; BoolWrapperType(const BoolWrapperType&) = default; BoolWrapperType(BoolWrapperType&&) = default; BoolWrapperType& operator=(const BoolWrapperType&) = default; BoolWrapperType& operator=(BoolWrapperType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } static TypeParameters GetParameters(); static std::string DebugString() { return std::string(name()); } }; inline constexpr bool operator==(BoolWrapperType, BoolWrapperType) { return true; } inline constexpr bool operator!=(BoolWrapperType lhs, BoolWrapperType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, BoolWrapperType) { // BoolWrapperType is really a singleton and all instances are equal. Nothing // to hash. 
return std::move(state); } inline std::ostream& operator<<(std::ostream& out, const BoolWrapperType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_BOOL_WRAPPER_TYPE_H_ ================================================ FILE: common/types/bool_wrapper_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(BoolWrapperType, Kind) { EXPECT_EQ(BoolWrapperType().kind(), BoolWrapperType::kKind); EXPECT_EQ(Type(BoolWrapperType()).kind(), BoolWrapperType::kKind); } TEST(BoolWrapperType, Name) { EXPECT_EQ(BoolWrapperType().name(), BoolWrapperType::kName); EXPECT_EQ(Type(BoolWrapperType()).name(), BoolWrapperType::kName); } TEST(BoolWrapperType, DebugString) { { std::ostringstream out; out << BoolWrapperType(); EXPECT_EQ(out.str(), BoolWrapperType::kName); } { std::ostringstream out; out << Type(BoolWrapperType()); EXPECT_EQ(out.str(), BoolWrapperType::kName); } } TEST(BoolWrapperType, Hash) { EXPECT_EQ(absl::HashOf(BoolWrapperType()), absl::HashOf(BoolWrapperType())); } TEST(BoolWrapperType, Equal) { EXPECT_EQ(BoolWrapperType(), BoolWrapperType()); EXPECT_EQ(Type(BoolWrapperType()), BoolWrapperType()); EXPECT_EQ(BoolWrapperType(), Type(BoolWrapperType())); EXPECT_EQ(Type(BoolWrapperType()), Type(BoolWrapperType())); 
} } // namespace } // namespace cel ================================================ FILE: common/types/bytes_type.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_BYTES_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_BYTES_TYPE_H_ #include #include #include #include "absl/strings/string_view.h" #include "common/type_kind.h" namespace cel { class Type; class TypeParameters; // `BoolType` represents the primitive `bytes` type. class BytesType final { public: static constexpr TypeKind kKind = TypeKind::kBytes; static constexpr absl::string_view kName = "bytes"; BytesType() = default; BytesType(const BytesType&) = default; BytesType(BytesType&&) = default; BytesType& operator=(const BytesType&) = default; BytesType& operator=(BytesType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } static TypeParameters GetParameters(); static std::string DebugString() { return std::string(name()); } }; inline constexpr bool operator==(BytesType, BytesType) { return true; } inline constexpr bool operator!=(BytesType lhs, BytesType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, BytesType) { // BytesType is really a singleton and all instances are equal. Nothing to // hash. 
return std::move(state); } inline std::ostream& operator<<(std::ostream& out, const BytesType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_BYTES_TYPE_H_ ================================================ FILE: common/types/bytes_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(BytesType, Kind) { EXPECT_EQ(BytesType().kind(), BytesType::kKind); EXPECT_EQ(Type(BytesType()).kind(), BytesType::kKind); } TEST(BytesType, Name) { EXPECT_EQ(BytesType().name(), BytesType::kName); EXPECT_EQ(Type(BytesType()).name(), BytesType::kName); } TEST(BytesType, DebugString) { { std::ostringstream out; out << BytesType(); EXPECT_EQ(out.str(), BytesType::kName); } { std::ostringstream out; out << Type(BytesType()); EXPECT_EQ(out.str(), BytesType::kName); } } TEST(BytesType, Hash) { EXPECT_EQ(absl::HashOf(BytesType()), absl::HashOf(BytesType())); } TEST(BytesType, Equal) { EXPECT_EQ(BytesType(), BytesType()); EXPECT_EQ(Type(BytesType()), BytesType()); EXPECT_EQ(BytesType(), Type(BytesType())); EXPECT_EQ(Type(BytesType()), Type(BytesType())); } } // namespace } // namespace cel ================================================ FILE: common/types/bytes_wrapper_type.h ================================================ // 
Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_BYTES_WRAPPER_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_BYTES_WRAPPER_TYPE_H_ #include #include #include #include "absl/strings/string_view.h" #include "common/type_kind.h" namespace cel { class Type; class TypeParameters; // `BytesWrapperType` is a special type which has no direct value // representation. It is used to represent `google.protobuf.BytesValue`, which // never exists at runtime as a value. Its primary usage is for type checking // and unpacking at runtime. 
class BytesWrapperType final { public: static constexpr TypeKind kKind = TypeKind::kBytesWrapper; static constexpr absl::string_view kName = "google.protobuf.BytesValue"; BytesWrapperType() = default; BytesWrapperType(const BytesWrapperType&) = default; BytesWrapperType(BytesWrapperType&&) = default; BytesWrapperType& operator=(const BytesWrapperType&) = default; BytesWrapperType& operator=(BytesWrapperType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } static TypeParameters GetParameters(); static std::string DebugString() { return std::string(name()); } }; inline constexpr bool operator==(BytesWrapperType, BytesWrapperType) { return true; } inline constexpr bool operator!=(BytesWrapperType lhs, BytesWrapperType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, BytesWrapperType) { // BytesWrapperType is really a singleton and all instances are equal. Nothing // to hash. return std::move(state); } inline std::ostream& operator<<(std::ostream& out, const BytesWrapperType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_BYTES_WRAPPER_TYPE_H_ ================================================ FILE: common/types/bytes_wrapper_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(BytesWrapperType, Kind) { EXPECT_EQ(BytesWrapperType().kind(), BytesWrapperType::kKind); EXPECT_EQ(Type(BytesWrapperType()).kind(), BytesWrapperType::kKind); } TEST(BytesWrapperType, Name) { EXPECT_EQ(BytesWrapperType().name(), BytesWrapperType::kName); EXPECT_EQ(Type(BytesWrapperType()).name(), BytesWrapperType::kName); } TEST(BytesWrapperType, DebugString) { { std::ostringstream out; out << BytesWrapperType(); EXPECT_EQ(out.str(), BytesWrapperType::kName); } { std::ostringstream out; out << Type(BytesWrapperType()); EXPECT_EQ(out.str(), BytesWrapperType::kName); } } TEST(BytesWrapperType, Hash) { EXPECT_EQ(absl::HashOf(BytesWrapperType()), absl::HashOf(BytesWrapperType())); } TEST(BytesWrapperType, Equal) { EXPECT_EQ(BytesWrapperType(), BytesWrapperType()); EXPECT_EQ(Type(BytesWrapperType()), BytesWrapperType()); EXPECT_EQ(BytesWrapperType(), Type(BytesWrapperType())); EXPECT_EQ(Type(BytesWrapperType()), Type(BytesWrapperType())); } } // namespace } // namespace cel ================================================ FILE: common/types/double_type.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_DOUBLE_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_DOUBLE_TYPE_H_ #include #include #include #include "absl/strings/string_view.h" #include "common/type_kind.h" namespace cel { class Type; class TypeParameters; // `BoolType` represents the primitive `double` type. class DoubleType final { public: static constexpr TypeKind kKind = TypeKind::kDouble; static constexpr absl::string_view kName = "double"; DoubleType() = default; DoubleType(const DoubleType&) = default; DoubleType(DoubleType&&) = default; DoubleType& operator=(const DoubleType&) = default; DoubleType& operator=(DoubleType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } static TypeParameters GetParameters(); static std::string DebugString() { return std::string(name()); } }; inline constexpr bool operator==(DoubleType, DoubleType) { return true; } inline constexpr bool operator!=(DoubleType lhs, DoubleType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, DoubleType) { // DoubleType is really a singleton and all instances are equal. Nothing to // hash. return std::move(state); } inline std::ostream& operator<<(std::ostream& out, const DoubleType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_DOUBLE_TYPE_H_ ================================================ FILE: common/types/double_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(DoubleType, Kind) { EXPECT_EQ(DoubleType().kind(), DoubleType::kKind); EXPECT_EQ(Type(DoubleType()).kind(), DoubleType::kKind); } TEST(DoubleType, Name) { EXPECT_EQ(DoubleType().name(), DoubleType::kName); EXPECT_EQ(Type(DoubleType()).name(), DoubleType::kName); } TEST(DoubleType, DebugString) { { std::ostringstream out; out << DoubleType(); EXPECT_EQ(out.str(), DoubleType::kName); } { std::ostringstream out; out << Type(DoubleType()); EXPECT_EQ(out.str(), DoubleType::kName); } } TEST(DoubleType, Hash) { EXPECT_EQ(absl::HashOf(DoubleType()), absl::HashOf(DoubleType())); } TEST(DoubleType, Equal) { EXPECT_EQ(DoubleType(), DoubleType()); EXPECT_EQ(Type(DoubleType()), DoubleType()); EXPECT_EQ(DoubleType(), Type(DoubleType())); EXPECT_EQ(Type(DoubleType()), Type(DoubleType())); } } // namespace } // namespace cel ================================================ FILE: common/types/double_wrapper_type.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_DOUBLE_WRAPPER_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_DOUBLE_WRAPPER_TYPE_H_ #include #include #include #include "absl/strings/string_view.h" #include "common/type_kind.h" namespace cel { class Type; class TypeParameters; // `DoubleWrapperType` is a special type which has no direct value // representation. It is used to represent `google.protobuf.DoubleValue`, which // never exists at runtime as a value. Its primary usage is for type checking // and unpacking at runtime. class DoubleWrapperType final { public: static constexpr TypeKind kKind = TypeKind::kDoubleWrapper; static constexpr absl::string_view kName = "google.protobuf.DoubleValue"; DoubleWrapperType() = default; DoubleWrapperType(const DoubleWrapperType&) = default; DoubleWrapperType(DoubleWrapperType&&) = default; DoubleWrapperType& operator=(const DoubleWrapperType&) = default; DoubleWrapperType& operator=(DoubleWrapperType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } static TypeParameters GetParameters(); static std::string DebugString() { return std::string(name()); } }; inline constexpr bool operator==(DoubleWrapperType, DoubleWrapperType) { return true; } inline constexpr bool operator!=(DoubleWrapperType lhs, DoubleWrapperType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, DoubleWrapperType) { // DoubleWrapperType is really a singleton and all instances are equal. // Nothing to hash. 
return std::move(state); } inline std::ostream& operator<<(std::ostream& out, const DoubleWrapperType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_DOUBLE_WRAPPER_TYPE_H_ ================================================ FILE: common/types/double_wrapper_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(DoubleWrapperType, Kind) { EXPECT_EQ(DoubleWrapperType().kind(), DoubleWrapperType::kKind); EXPECT_EQ(Type(DoubleWrapperType()).kind(), DoubleWrapperType::kKind); } TEST(DoubleWrapperType, Name) { EXPECT_EQ(DoubleWrapperType().name(), DoubleWrapperType::kName); EXPECT_EQ(Type(DoubleWrapperType()).name(), DoubleWrapperType::kName); } TEST(DoubleWrapperType, DebugString) { { std::ostringstream out; out << DoubleWrapperType(); EXPECT_EQ(out.str(), DoubleWrapperType::kName); } { std::ostringstream out; out << Type(DoubleWrapperType()); EXPECT_EQ(out.str(), DoubleWrapperType::kName); } } TEST(DoubleWrapperType, Hash) { EXPECT_EQ(absl::HashOf(DoubleWrapperType()), absl::HashOf(DoubleWrapperType())); } TEST(DoubleWrapperType, Equal) { EXPECT_EQ(DoubleWrapperType(), DoubleWrapperType()); EXPECT_EQ(Type(DoubleWrapperType()), DoubleWrapperType()); EXPECT_EQ(DoubleWrapperType(), Type(DoubleWrapperType())); 
EXPECT_EQ(Type(DoubleWrapperType()), Type(DoubleWrapperType())); } } // namespace } // namespace cel ================================================ FILE: common/types/duration_type.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_DURATION_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_DURATION_TYPE_H_ #include #include #include #include "absl/strings/string_view.h" #include "common/type_kind.h" namespace cel { class Type; class TypeParameters; // `DurationType` represents the primitive `duration` type. 
class DurationType final { public: static constexpr TypeKind kKind = TypeKind::kDuration; static constexpr absl::string_view kName = "google.protobuf.Duration"; DurationType() = default; DurationType(const DurationType&) = default; DurationType(DurationType&&) = default; DurationType& operator=(const DurationType&) = default; DurationType& operator=(DurationType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } static TypeParameters GetParameters(); static std::string DebugString() { return std::string(name()); } }; inline constexpr bool operator==(DurationType, DurationType) { return true; } inline constexpr bool operator!=(DurationType lhs, DurationType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, DurationType) { // DurationType is really a singleton and all instances are equal. // Nothing to hash. return std::move(state); } inline std::ostream& operator<<(std::ostream& out, const DurationType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_DURATION_TYPE_H_ ================================================ FILE: common/types/duration_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(DurationType, Kind) { EXPECT_EQ(DurationType().kind(), DurationType::kKind); EXPECT_EQ(Type(DurationType()).kind(), DurationType::kKind); } TEST(DurationType, Name) { EXPECT_EQ(DurationType().name(), DurationType::kName); EXPECT_EQ(Type(DurationType()).name(), DurationType::kName); } TEST(DurationType, DebugString) { { std::ostringstream out; out << DurationType(); EXPECT_EQ(out.str(), DurationType::kName); } { std::ostringstream out; out << Type(DurationType()); EXPECT_EQ(out.str(), DurationType::kName); } } TEST(DurationType, Hash) { EXPECT_EQ(absl::HashOf(DurationType()), absl::HashOf(DurationType())); } TEST(DurationType, Equal) { EXPECT_EQ(DurationType(), DurationType()); EXPECT_EQ(Type(DurationType()), DurationType()); EXPECT_EQ(DurationType(), Type(DurationType())); EXPECT_EQ(Type(DurationType()), Type(DurationType())); } } // namespace } // namespace cel ================================================ FILE: common/types/dyn_type.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_DYN_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_DYN_TYPE_H_ #include #include #include #include "absl/strings/string_view.h" #include "common/type_kind.h" namespace cel { class Type; class TypeParameters; // `DynType` is a special type which represents any type and has no direct value // representation. class DynType final { public: static constexpr TypeKind kKind = TypeKind::kDyn; static constexpr absl::string_view kName = "dyn"; DynType() = default; DynType(const DynType&) = default; DynType(DynType&&) = default; DynType& operator=(const DynType&) = default; DynType& operator=(DynType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } static TypeParameters GetParameters(); static std::string DebugString() { return std::string(name()); } }; inline constexpr bool operator==(DynType, DynType) { return true; } inline constexpr bool operator!=(DynType lhs, DynType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, DynType) { // DynType is really a singleton and all instances are equal. Nothing to hash. return std::move(state); } inline std::ostream& operator<<(std::ostream& out, const DynType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_DYN_TYPE_H_ ================================================ FILE: common/types/dyn_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(DynType, Kind) { EXPECT_EQ(DynType().kind(), DynType::kKind); EXPECT_EQ(Type(DynType()).kind(), DynType::kKind); } TEST(DynType, Name) { EXPECT_EQ(DynType().name(), DynType::kName); EXPECT_EQ(Type(DynType()).name(), DynType::kName); } TEST(DynType, DebugString) { { std::ostringstream out; out << DynType(); EXPECT_EQ(out.str(), DynType::kName); } { std::ostringstream out; out << Type(DynType()); EXPECT_EQ(out.str(), DynType::kName); } } TEST(DynType, Hash) { EXPECT_EQ(absl::HashOf(DynType()), absl::HashOf(DynType())); } TEST(DynType, Equal) { EXPECT_EQ(DynType(), DynType()); EXPECT_EQ(Type(DynType()), DynType()); EXPECT_EQ(DynType(), Type(DynType())); EXPECT_EQ(Type(DynType()), Type(DynType())); } } // namespace } // namespace cel ================================================ FILE: common/types/enum_type.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "absl/base/nullability.h" #include "absl/base/optimization.h" #include "absl/strings/str_cat.h" #include "common/type.h" #include "google/protobuf/descriptor.h" namespace cel { using google::protobuf::EnumDescriptor; bool IsWellKnownEnumType(const EnumDescriptor* absl_nonnull descriptor) { return descriptor->full_name() == "google.protobuf.NullValue"; } std::string EnumType::DebugString() const { if (ABSL_PREDICT_TRUE(static_cast(*this))) { static_assert(sizeof(descriptor_) == 8 || sizeof(descriptor_) == 4, "sizeof(void*) is neither 8 nor 4"); return absl::StrCat(name(), "@0x", absl::Hex(descriptor_, sizeof(descriptor_) == 8 ? absl::PadSpec::kZeroPad16 : absl::PadSpec::kZeroPad8)); } return std::string(); } } // namespace cel ================================================ FILE: common/types/enum_type.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_ENUM_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_ENUM_TYPE_H_ #include #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/strings/string_view.h" #include "common/type_kind.h" #include "google/protobuf/descriptor.h" namespace cel { class Type; class TypeParameters; bool IsWellKnownEnumType(const google::protobuf::EnumDescriptor* absl_nonnull descriptor); class EnumType final { public: using element_type = const google::protobuf::EnumDescriptor; static constexpr TypeKind kKind = TypeKind::kEnum; // Constructs `EnumType` from a pointer to `google::protobuf::EnumDescriptor`. The // `google::protobuf::EnumDescriptor` must not be one of the well known enum types we // treat specially, if it is behavior is undefined. If you are unsure, you // should use `Type::Enum`. 
explicit EnumType(const google::protobuf::EnumDescriptor* absl_nullable descriptor) : descriptor_(descriptor) { ABSL_DCHECK(descriptor == nullptr || !IsWellKnownEnumType(descriptor)) << descriptor->full_name(); } EnumType() = default; EnumType(const EnumType&) = default; EnumType(EnumType&&) = default; EnumType& operator=(const EnumType&) = default; EnumType& operator=(EnumType&&) = default; static TypeKind kind() { return kKind; } absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return (*this)->full_name(); } std::string DebugString() const; static TypeParameters GetParameters(); const google::protobuf::EnumDescriptor& operator*() const { ABSL_DCHECK(*this); return *descriptor_; } const google::protobuf::EnumDescriptor* absl_nonnull operator->() const { ABSL_DCHECK(*this); return descriptor_; } explicit operator bool() const { return descriptor_ != nullptr; } private: friend struct std::pointer_traits; const google::protobuf::EnumDescriptor* absl_nullable descriptor_ = nullptr; }; inline bool operator==(EnumType lhs, EnumType rhs) { return static_cast(lhs) == static_cast(rhs) && (!static_cast(lhs) || lhs.name() == rhs.name()); } inline bool operator!=(EnumType lhs, EnumType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, EnumType enum_type) { return H::combine(std::move(state), static_cast(enum_type) ? 
enum_type.name() : absl::string_view()); } inline std::ostream& operator<<(std::ostream& out, EnumType type) { return out << type.DebugString(); } } // namespace cel namespace std { template <> struct pointer_traits { using pointer = cel::EnumType; using element_type = typename cel::EnumType::element_type; using difference_type = ptrdiff_t; static element_type* to_address(const pointer& p) noexcept { return p.descriptor_; } }; } // namespace std #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_ENUM_TYPE_H_ ================================================ FILE: common/types/enum_type_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "google/protobuf/descriptor.pb.h"
#include "common/memory.h"
#include "common/type.h"
#include "common/type_kind.h"
#include "internal/testing.h"
#include "google/protobuf/descriptor.h"

namespace cel {
namespace {

using ::testing::Eq;
using ::testing::IsEmpty;
using ::testing::NotNull;
using ::testing::StartsWith;

// `kind()` is static and always reports `TypeKind::kEnum`.
TEST(EnumType, Kind) { EXPECT_EQ(EnumType::kind(), TypeKind::kEnum); }

// A default-constructed `EnumType` wraps no descriptor.
TEST(EnumType, Default) {
  EnumType empty_type;
  EXPECT_FALSE(empty_type);
  EXPECT_THAT(empty_type.DebugString(), Eq(""));
  EXPECT_EQ(empty_type, EnumType());
}

// An `EnumType` built from a descriptor in a freshly populated pool.
TEST(EnumType, Descriptor) {
  google::protobuf::DescriptorPool descriptor_pool;
  {
    google::protobuf::FileDescriptorProto file_proto;
    file_proto.set_name("test/enum.proto");
    file_proto.set_package("test");
    file_proto.set_syntax("proto3");
    auto* enum_proto = file_proto.add_enum_type();
    enum_proto->set_name("Enum");
    auto* value_proto = enum_proto->add_value();
    value_proto->set_number(0);
    value_proto->set_name("VALUE");
    ASSERT_THAT(descriptor_pool.BuildFile(file_proto), NotNull());
  }
  const google::protobuf::EnumDescriptor* enum_descriptor =
      descriptor_pool.FindEnumTypeByName("test.Enum");
  ASSERT_THAT(enum_descriptor, NotNull());

  EnumType wrapped(enum_descriptor);
  EXPECT_TRUE(wrapped);
  EXPECT_THAT(wrapped.name(), Eq("test.Enum"));
  EXPECT_THAT(wrapped.DebugString(), StartsWith("test.Enum@0x"));
  EXPECT_THAT(wrapped.GetParameters(), IsEmpty());
  EXPECT_NE(wrapped, EnumType());
  EXPECT_NE(EnumType(), wrapped);
  EXPECT_EQ(cel::to_address(wrapped), enum_descriptor);
}

}  // namespace
}  // namespace cel

================================================
FILE: common/types/error_type.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_ERROR_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_ERROR_TYPE_H_ #include #include #include #include "absl/strings/string_view.h" #include "common/type_kind.h" namespace cel { class Type; class TypeParameters; // `ErrorType` is a special type which represents an error during type checking // or an error value at runtime. See // https://github.com/google/cel-spec/blob/master/doc/langdef.md#runtime-errors. class ErrorType final { public: static constexpr TypeKind kKind = TypeKind::kError; static constexpr absl::string_view kName = "*error*"; ErrorType() = default; ErrorType(const ErrorType&) = default; ErrorType(ErrorType&&) = default; ErrorType& operator=(const ErrorType&) = default; ErrorType& operator=(ErrorType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } static TypeParameters GetParameters(); static std::string DebugString() { return std::string(name()); } }; inline constexpr bool operator==(ErrorType, ErrorType) { return true; } inline constexpr bool operator!=(ErrorType lhs, ErrorType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, ErrorType) { // ErrorType is really a singleton and all instances are equal. Nothing to // hash. 
return std::move(state); } inline std::ostream& operator<<(std::ostream& out, const ErrorType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_ERROR_TYPE_H_ ================================================ FILE: common/types/error_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(ErrorType, Kind) { EXPECT_EQ(ErrorType().kind(), ErrorType::kKind); EXPECT_EQ(Type(ErrorType()).kind(), ErrorType::kKind); } TEST(ErrorType, Name) { EXPECT_EQ(ErrorType().name(), ErrorType::kName); EXPECT_EQ(Type(ErrorType()).name(), ErrorType::kName); } TEST(ErrorType, DebugString) { { std::ostringstream out; out << ErrorType(); EXPECT_EQ(out.str(), ErrorType::kName); } { std::ostringstream out; out << Type(ErrorType()); EXPECT_EQ(out.str(), ErrorType::kName); } } TEST(ErrorType, Hash) { EXPECT_EQ(absl::HashOf(ErrorType()), absl::HashOf(ErrorType())); } TEST(ErrorType, Equal) { EXPECT_EQ(ErrorType(), ErrorType()); EXPECT_EQ(Type(ErrorType()), ErrorType()); EXPECT_EQ(ErrorType(), Type(ErrorType())); EXPECT_EQ(Type(ErrorType()), Type(ErrorType())); } } // namespace } // namespace cel ================================================ FILE: common/types/function_type.cc ================================================ // 
Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "common/type.h" #include "google/protobuf/arena.h" namespace cel { namespace { struct TypeFormatter { void operator()(std::string* out, const Type& type) const { out->append(type.DebugString()); } }; std::string FunctionDebugString(const Type& result, absl::Span args) { return absl::StrCat("(", absl::StrJoin(args, ", ", TypeFormatter{}), ") -> ", result.DebugString()); } } // namespace namespace common_internal { FunctionTypeData* absl_nonnull FunctionTypeData::Create( google::protobuf::Arena* absl_nonnull arena, const Type& result, absl::Span args) { return ::new (arena->AllocateAligned( offsetof(FunctionTypeData, args) + ((1 + args.size()) * sizeof(Type)), alignof(FunctionTypeData))) FunctionTypeData(result, args); } FunctionTypeData::FunctionTypeData(const Type& result, absl::Span args) : args_size(1 + args.size()) { this->args[0] = result; std::memcpy(this->args + 1, args.data(), args.size() * sizeof(Type)); } } // namespace common_internal FunctionType::FunctionType(google::protobuf::Arena* absl_nonnull arena, const Type& result, absl::Span args) : FunctionType( common_internal::FunctionTypeData::Create(arena, result, args)) {} std::string 
FunctionType::DebugString() const { return FunctionDebugString(result(), args()); } TypeParameters FunctionType::GetParameters() const { ABSL_DCHECK(*this); return TypeParameters(absl::MakeConstSpan(data_->args, data_->args_size)); } const Type& FunctionType::result() const { ABSL_DCHECK(*this); return data_->args[0]; } absl::Span FunctionType::args() const { ABSL_DCHECK(*this); return absl::MakeConstSpan(data_->args + 1, data_->args_size - 1); } } // namespace cel ================================================ FILE: common/types/function_type.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_FUNCTION_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_FUNCTION_TYPE_H_ #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "common/type_kind.h" #include "google/protobuf/arena.h" namespace cel { class Type; class TypeParameters; namespace common_internal { struct FunctionTypeData; } // namespace common_internal class FunctionType final { public: static constexpr TypeKind kKind = TypeKind::kFunction; static constexpr absl::string_view kName = "function"; FunctionType(google::protobuf::Arena* absl_nonnull arena, const Type& result, absl::Span args); FunctionType() = default; FunctionType(const FunctionType&) = default; FunctionType(FunctionType&&) = default; FunctionType& operator=(const FunctionType&) = default; FunctionType& operator=(FunctionType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } TypeParameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND; std::string DebugString() const; const Type& result() const ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::Span args() const ABSL_ATTRIBUTE_LIFETIME_BOUND; explicit operator bool() const { return data_ != nullptr; } private: explicit FunctionType( const common_internal::FunctionTypeData* absl_nullable data) : data_(data) {} const common_internal::FunctionTypeData* absl_nullable data_ = nullptr; }; bool operator==(const FunctionType& lhs, const FunctionType& rhs); inline bool operator!=(const FunctionType& lhs, const FunctionType& rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, const FunctionType& type); inline std::ostream& operator<<(std::ostream& out, const FunctionType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_FUNCTION_TYPE_H_ 
================================================ FILE: common/types/function_type_pool.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/types/function_type_pool.h" #include "absl/types/span.h" #include "common/type.h" namespace cel::common_internal { FunctionType FunctionTypePool::InternFunctionType(const Type& result, absl::Span args) { return *function_types_.lazy_emplace( AsTuple(result, args), [&](const auto& ctor) { ctor(FunctionType(arena_, result, args)); }); } } // namespace cel::common_internal ================================================ FILE: common/types/function_type_pool.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_FUNCTION_TYPE_POOL_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_FUNCTION_TYPE_POOL_H_ #include #include #include #include #include "absl/algorithm/container.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_set.h" #include "absl/hash/hash.h" #include "absl/log/die_if_null.h" #include "absl/types/span.h" #include "common/type.h" #include "google/protobuf/arena.h" namespace cel::common_internal { // `FunctionTypePool` is a thread unsafe interning factory for `FunctionType`. class FunctionTypePool final { public: explicit FunctionTypePool(google::protobuf::Arena* absl_nonnull arena) : arena_(ABSL_DIE_IF_NULL(arena)) {} // Crash OK // Returns a `FunctionType` which has the provided parameters, interning as // necessary. FunctionType InternFunctionType(const Type& result, absl::Span args); private: using FunctionTypeTuple = std::tuple, absl::Span>; static FunctionTypeTuple AsTuple(const FunctionType& function_type) { return AsTuple(function_type.result(), function_type.args()); } static FunctionTypeTuple AsTuple(const Type& result, absl::Span args) { return FunctionTypeTuple{std::cref(result), args}; } struct Hasher { using is_transparent = void; size_t operator()(const FunctionType& data) const { return (*this)(AsTuple(data)); } size_t operator()(const FunctionTypeTuple& tuple) const { return absl::Hash{}(tuple); } }; struct Equaler { using is_transparent = void; bool operator()(const FunctionType& lhs, const FunctionType& rhs) const { return (*this)(AsTuple(lhs), AsTuple(rhs)); } bool operator()(const FunctionType& lhs, const FunctionTypeTuple& rhs) const { return (*this)(AsTuple(lhs), rhs); } bool operator()(const FunctionTypeTuple& lhs, const FunctionType& rhs) const { return (*this)(lhs, AsTuple(rhs)); } bool operator()(const FunctionTypeTuple& lhs, const FunctionTypeTuple& rhs) const { return std::get<0>(lhs) == std::get<0>(rhs) && absl::c_equal(std::get<1>(lhs), 
std::get<1>(rhs)); } }; google::protobuf::Arena* absl_nonnull const arena_; absl::flat_hash_set function_types_; }; } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_FUNCTION_TYPE_POOL_H_ ================================================ FILE: common/types/function_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" #include "google/protobuf/arena.h" namespace cel { namespace { TEST(FunctionType, Kind) { google::protobuf::Arena arena; EXPECT_EQ(FunctionType(&arena, DynType{}, {BytesType()}).kind(), FunctionType::kKind); EXPECT_EQ(Type(FunctionType(&arena, DynType{}, {BytesType()})).kind(), FunctionType::kKind); } TEST(FunctionType, Name) { google::protobuf::Arena arena; EXPECT_EQ(FunctionType(&arena, DynType{}, {BytesType()}).name(), "function"); EXPECT_EQ(Type(FunctionType(&arena, DynType{}, {BytesType()})).name(), "function"); } TEST(FunctionType, DebugString) { google::protobuf::Arena arena; { std::ostringstream out; out << FunctionType(&arena, DynType{}, {BytesType()}); EXPECT_EQ(out.str(), "(bytes) -> dyn"); } { std::ostringstream out; out << Type(FunctionType(&arena, DynType{}, {BytesType()})); EXPECT_EQ(out.str(), "(bytes) -> dyn"); } } TEST(FunctionType, Hash) { google::protobuf::Arena arena; EXPECT_EQ(absl::HashOf(FunctionType(&arena, DynType{}, 
{BytesType()})), absl::HashOf(FunctionType(&arena, DynType{}, {BytesType()}))); } TEST(FunctionType, Equal) { google::protobuf::Arena arena; EXPECT_EQ(FunctionType(&arena, DynType{}, {BytesType()}), FunctionType(&arena, DynType{}, {BytesType()})); EXPECT_EQ(Type(FunctionType(&arena, DynType{}, {BytesType()})), FunctionType(&arena, DynType{}, {BytesType()})); EXPECT_EQ(FunctionType(&arena, DynType{}, {BytesType()}), Type(FunctionType(&arena, DynType{}, {BytesType()}))); EXPECT_EQ(Type(FunctionType(&arena, DynType{}, {BytesType()})), Type(FunctionType(&arena, DynType{}, {BytesType()}))); } } // namespace } // namespace cel ================================================ FILE: common/types/int_type.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_INT_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_INT_TYPE_H_ #include #include #include #include "absl/strings/string_view.h" #include "common/type_kind.h" namespace cel { class Type; class TypeParameters; // `IntType` represents the primitive `int` type. 
class IntType final { public: static constexpr TypeKind kKind = TypeKind::kInt; static constexpr absl::string_view kName = "int"; IntType() = default; IntType(const IntType&) = default; IntType(IntType&&) = default; IntType& operator=(const IntType&) = default; IntType& operator=(IntType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } static TypeParameters GetParameters(); static std::string DebugString() { return std::string(name()); } }; inline constexpr bool operator==(IntType, IntType) { return true; } inline constexpr bool operator!=(IntType lhs, IntType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, IntType) { // IntType is really a singleton and all instances are equal. Nothing to hash. return std::move(state); } inline std::ostream& operator<<(std::ostream& out, const IntType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_INT_TYPE_H_ ================================================ FILE: common/types/int_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(IntType, Kind) { EXPECT_EQ(IntType().kind(), IntType::kKind); EXPECT_EQ(Type(IntType()).kind(), IntType::kKind); } TEST(IntType, Name) { EXPECT_EQ(IntType().name(), IntType::kName); EXPECT_EQ(Type(IntType()).name(), IntType::kName); } TEST(IntType, DebugString) { { std::ostringstream out; out << IntType(); EXPECT_EQ(out.str(), IntType::kName); } { std::ostringstream out; out << Type(IntType()); EXPECT_EQ(out.str(), IntType::kName); } } TEST(IntType, Hash) { EXPECT_EQ(absl::HashOf(IntType()), absl::HashOf(IntType())); } TEST(IntType, Equal) { EXPECT_EQ(IntType(), IntType()); EXPECT_EQ(Type(IntType()), IntType()); EXPECT_EQ(IntType(), Type(IntType())); EXPECT_EQ(Type(IntType()), Type(IntType())); } } // namespace } // namespace cel ================================================ FILE: common/types/int_wrapper_type.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_INT_WRAPPER_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_INT_WRAPPER_TYPE_H_ #include #include #include #include "absl/strings/string_view.h" #include "common/type_kind.h" namespace cel { class Type; class TypeParameters; // `IntWrapperType` is a special type which has no direct value // representation. It is used to represent `google.protobuf.Int64Value`, which // never exists at runtime as a value. Its primary usage is for type checking // and unpacking at runtime. class IntWrapperType final { public: static constexpr TypeKind kKind = TypeKind::kIntWrapper; static constexpr absl::string_view kName = "google.protobuf.Int64Value"; IntWrapperType() = default; IntWrapperType(const IntWrapperType&) = default; IntWrapperType(IntWrapperType&&) = default; IntWrapperType& operator=(const IntWrapperType&) = default; IntWrapperType& operator=(IntWrapperType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } static TypeParameters GetParameters(); static std::string DebugString() { return std::string(name()); } }; inline constexpr bool operator==(IntWrapperType, IntWrapperType) { return true; } inline constexpr bool operator!=(IntWrapperType lhs, IntWrapperType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, IntWrapperType) { // IntWrapperType is really a singleton and all instances are equal. Nothing // to hash. 
return std::move(state); } inline std::ostream& operator<<(std::ostream& out, const IntWrapperType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_INT_WRAPPER_TYPE_H_ ================================================ FILE: common/types/int_wrapper_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(IntWrapperType, Kind) { EXPECT_EQ(IntWrapperType().kind(), IntWrapperType::kKind); EXPECT_EQ(Type(IntWrapperType()).kind(), IntWrapperType::kKind); } TEST(IntWrapperType, Name) { EXPECT_EQ(IntWrapperType().name(), IntWrapperType::kName); EXPECT_EQ(Type(IntWrapperType()).name(), IntWrapperType::kName); } TEST(IntWrapperType, DebugString) { { std::ostringstream out; out << IntWrapperType(); EXPECT_EQ(out.str(), IntWrapperType::kName); } { std::ostringstream out; out << Type(IntWrapperType()); EXPECT_EQ(out.str(), IntWrapperType::kName); } } TEST(IntWrapperType, Hash) { EXPECT_EQ(absl::HashOf(IntWrapperType()), absl::HashOf(IntWrapperType())); } TEST(IntWrapperType, Equal) { EXPECT_EQ(IntWrapperType(), IntWrapperType()); EXPECT_EQ(Type(IntWrapperType()), IntWrapperType()); EXPECT_EQ(IntWrapperType(), Type(IntWrapperType())); EXPECT_EQ(Type(IntWrapperType()), Type(IntWrapperType())); } } // namespace } // 
namespace cel ================================================ FILE: common/types/legacy_type_introspector.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_LEGACY_TYPE_INTROSPECTOR_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_LEGACY_TYPE_INTROSPECTOR_H_ #include "common/type_introspector.h" namespace cel::common_internal { // `LegacyTypeIntrospector` is an implementation which should be used when // converting between `cel::Value` and `google::api::expr::runtime::CelValue` // and only then. class LegacyTypeIntrospector : public virtual TypeIntrospector { public: LegacyTypeIntrospector() = default; }; } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_LEGACY_TYPE_INTROSPECTOR_H_ ================================================ FILE: common/types/list_type.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "common/type.h" #include "common/type_kind.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { namespace common_internal { namespace { ABSL_CONST_INIT const ListTypeData kDynListTypeData; } // namespace ListTypeData* absl_nonnull ListTypeData::Create( google::protobuf::Arena* absl_nonnull arena, const Type& element) { return ::new (arena->AllocateAligned( sizeof(ListTypeData), alignof(ListTypeData))) ListTypeData(element); } ListTypeData::ListTypeData(const Type& element) : element(element) {} } // namespace common_internal ListType::ListType() : ListType(&common_internal::kDynListTypeData) {} ListType::ListType(google::protobuf::Arena* absl_nonnull arena, const Type& element) : ListType(element.IsDyn() ? 
&common_internal::kDynListTypeData : common_internal::ListTypeData::Create(arena, element)) {} std::string ListType::DebugString() const { return absl::StrCat("list<", TypeKindToString(GetElement().kind()), ">"); } TypeParameters ListType::GetParameters() const { return TypeParameters(GetElement()); } Type ListType::GetElement() const { ABSL_DCHECK_NE(data_, 0); if ((data_ & kBasicBit) == kBasicBit) { return reinterpret_cast(data_ & kPointerMask) ->element; } if ((data_ & kProtoBit) == kProtoBit) { return common_internal::SingularMessageFieldType( reinterpret_cast(data_ & kPointerMask)); } return Type(); } Type ListType::element() const { return GetElement(); } } // namespace cel ================================================ FILE: common/types/list_type.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_LIST_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_LIST_TYPE_H_ #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/numeric/bits.h" #include "absl/strings/string_view.h" #include "common/type_kind.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { class Type; class TypeParameters; namespace common_internal { struct ListTypeData; } // namespace common_internal class ListType final { private: static constexpr uintptr_t kBasicBit = 1; static constexpr uintptr_t kProtoBit = 2; static constexpr uintptr_t kBits = kBasicBit | kProtoBit; static constexpr uintptr_t kPointerMask = ~kBits; public: static constexpr TypeKind kKind = TypeKind::kList; static constexpr absl::string_view kName = "list"; ListType(google::protobuf::Arena* absl_nonnull arena, const Type& element); // By default, this type is `list(dyn)`. Unless you can help it, you should // use a more specific list type. 
ListType(); ListType(const ListType&) = default; ListType(ListType&&) = default; ListType& operator=(const ListType&) = default; ListType& operator=(ListType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } std::string DebugString() const; TypeParameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND; ABSL_DEPRECATED("Use GetElement") Type element() const; Type GetElement() const; private: friend class Type; explicit ListType(const common_internal::ListTypeData* absl_nonnull data) : data_(reinterpret_cast(data) | kBasicBit) { ABSL_DCHECK_GE(absl::countr_zero(reinterpret_cast(data)), 2) << "alignment must be greater than 2"; } explicit ListType(const google::protobuf::FieldDescriptor* absl_nonnull descriptor) : data_(reinterpret_cast(descriptor) | kProtoBit) { ABSL_DCHECK_GE(absl::countr_zero(reinterpret_cast(descriptor)), 2) << "alignment must be greater than 2"; ABSL_DCHECK(descriptor->is_repeated()); ABSL_DCHECK(!descriptor->is_map()); } uintptr_t data_; }; bool operator==(const ListType& lhs, const ListType& rhs); inline bool operator!=(const ListType& lhs, const ListType& rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, const ListType& type); inline std::ostream& operator<<(std::ostream& out, const ListType& type) { return out << type.DebugString(); } inline ListType JsonListType() { return ListType(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_LIST_TYPE_H_ ================================================ FILE: common/types/list_type_pool.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/types/list_type_pool.h" #include "common/type.h" namespace cel::common_internal { ListType ListTypePool::InternListType(const Type& element) { if (element.IsDyn()) { return ListType(); } return *list_types_.lazy_emplace( element, [&](const auto& ctor) { ctor(ListType(arena_, element)); }); } } // namespace cel::common_internal ================================================ FILE: common/types/list_type_pool.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_LIST_TYPE_POOL_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_LIST_TYPE_POOL_H_ #include #include "absl/base/nullability.h" #include "absl/container/flat_hash_set.h" #include "absl/hash/hash.h" #include "absl/log/die_if_null.h" #include "common/type.h" #include "google/protobuf/arena.h" namespace cel::common_internal { // `ListTypePool` is a thread unsafe interning factory for `ListType`. 
class ListTypePool final { public: explicit ListTypePool(google::protobuf::Arena* absl_nonnull arena) : arena_(ABSL_DIE_IF_NULL(arena)) {} // Crash OK // Returns a `ListType` which has the provided parameters, interning as // necessary. ListType InternListType(const Type& element); private: struct Hasher { using is_transparent = void; size_t operator()(const ListType& list_type) const { return (*this)(list_type.element()); } size_t operator()(const Type& type) const { return absl::Hash{}(type); } }; struct Equaler { using is_transparent = void; bool operator()(const ListType& lhs, const ListType& rhs) const { return (*this)(lhs.element(), rhs.element()); } bool operator()(const ListType& lhs, const Type& rhs) const { return (*this)(lhs.element(), rhs); } bool operator()(const Type& lhs, const ListType& rhs) const { return (*this)(lhs, rhs.element()); } bool operator()(const Type& lhs, const Type& rhs) const { return lhs == rhs; } }; google::protobuf::Arena* absl_nonnull const arena_; absl::flat_hash_set list_types_; }; } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_LIST_TYPE_POOL_H_ ================================================ FILE: common/types/list_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" #include "google/protobuf/arena.h" namespace cel { namespace { TEST(ListType, Default) { ListType list_type; EXPECT_EQ(list_type.element(), DynType()); } TEST(ListType, Kind) { google::protobuf::Arena arena; EXPECT_EQ(ListType(&arena, BoolType()).kind(), ListType::kKind); EXPECT_EQ(Type(ListType(&arena, BoolType())).kind(), ListType::kKind); } TEST(ListType, Name) { google::protobuf::Arena arena; EXPECT_EQ(ListType(&arena, BoolType()).name(), ListType::kName); EXPECT_EQ(Type(ListType(&arena, BoolType())).name(), ListType::kName); } TEST(ListType, DebugString) { google::protobuf::Arena arena; { std::ostringstream out; out << ListType(&arena, BoolType()); EXPECT_EQ(out.str(), "list"); } { std::ostringstream out; out << Type(ListType(&arena, BoolType())); EXPECT_EQ(out.str(), "list"); } } TEST(ListType, Hash) { google::protobuf::Arena arena; EXPECT_EQ(absl::HashOf(ListType(&arena, BoolType())), absl::HashOf(ListType(&arena, BoolType()))); } TEST(ListType, Equal) { google::protobuf::Arena arena; EXPECT_EQ(ListType(&arena, BoolType()), ListType(&arena, BoolType())); EXPECT_EQ(Type(ListType(&arena, BoolType())), ListType(&arena, BoolType())); EXPECT_EQ(ListType(&arena, BoolType()), Type(ListType(&arena, BoolType()))); EXPECT_EQ(Type(ListType(&arena, BoolType())), Type(ListType(&arena, BoolType()))); } } // namespace } // namespace cel ================================================ FILE: common/types/map_type.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "common/type.h" #include "common/type_kind.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { namespace common_internal { namespace { ABSL_CONST_INIT const MapTypeData kDynDynMapTypeData = { .key_and_value = {DynType(), DynType()}, }; ABSL_CONST_INIT const MapTypeData kStringDynMapTypeData = { .key_and_value = {StringType(), DynType()}, }; } // namespace MapTypeData* absl_nonnull MapTypeData::Create(google::protobuf::Arena* absl_nonnull arena, const Type& key, const Type& value) { MapTypeData* data = ::new (arena->AllocateAligned(sizeof(MapTypeData), alignof(MapTypeData))) MapTypeData; data->key_and_value[0] = key; data->key_and_value[1] = value; return data; } } // namespace common_internal MapType::MapType() : MapType(&common_internal::kDynDynMapTypeData) {} MapType::MapType(google::protobuf::Arena* absl_nonnull arena, const Type& key, const Type& value) : MapType(key.IsDyn() && value.IsDyn() ? 
&common_internal::kDynDynMapTypeData : common_internal::MapTypeData::Create(arena, key, value)) {} std::string MapType::DebugString() const { return absl::StrCat("map<", TypeKindToString(key().kind()), ", ", TypeKindToString(value().kind()), ">"); } TypeParameters MapType::GetParameters() const { ABSL_DCHECK_NE(data_, 0); if ((data_ & kBasicBit) == kBasicBit) { const auto* data = reinterpret_cast( data_ & kPointerMask); return TypeParameters(data->key_and_value[0], data->key_and_value[1]); } if ((data_ & kProtoBit) == kProtoBit) { const auto* descriptor = reinterpret_cast(data_ & kPointerMask); return TypeParameters(Type::Field(descriptor->map_key()), Type::Field(descriptor->map_value())); } return TypeParameters(Type(), Type()); } Type MapType::GetKey() const { ABSL_DCHECK_NE(data_, 0); if ((data_ & kBasicBit) == kBasicBit) { return reinterpret_cast(data_ & kPointerMask) ->key_and_value[0]; } if ((data_ & kProtoBit) == kProtoBit) { return Type::Field( reinterpret_cast(data_ & kPointerMask) ->map_key()); } return Type(); } Type MapType::key() const { return GetKey(); } Type MapType::GetValue() const { ABSL_DCHECK_NE(data_, 0); if ((data_ & kBasicBit) == kBasicBit) { return reinterpret_cast(data_ & kPointerMask) ->key_and_value[1]; } if ((data_ & kProtoBit) == kProtoBit) { return Type::Field( reinterpret_cast(data_ & kPointerMask) ->map_value()); } return Type(); } Type MapType::value() const { return GetValue(); } MapType JsonMapType() { return MapType(&common_internal::kStringDynMapTypeData); } } // namespace cel ================================================ FILE: common/types/map_type.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_MAP_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_MAP_TYPE_H_ #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/numeric/bits.h" #include "absl/strings/string_view.h" #include "common/type_kind.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { class Type; class TypeParameters; namespace common_internal { struct MapTypeData; } // namespace common_internal class MapType; MapType JsonMapType(); class MapType final { private: static constexpr uintptr_t kBasicBit = 1; static constexpr uintptr_t kProtoBit = 2; static constexpr uintptr_t kBits = kBasicBit | kProtoBit; static constexpr uintptr_t kPointerMask = ~kBits; public: static constexpr TypeKind kKind = TypeKind::kMap; static constexpr absl::string_view kName = "map"; MapType(google::protobuf::Arena* absl_nonnull arena, const Type& key, const Type& value); // By default, this type is `map(dyn, dyn)`. Unless you can help it, you // should use a more specific map type. 
MapType(); MapType(const MapType&) = default; MapType(MapType&&) = default; MapType& operator=(const MapType&) = default; MapType& operator=(MapType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } std::string DebugString() const; TypeParameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND; ABSL_DEPRECATED("Use GetKey") Type key() const; Type GetKey() const; ABSL_DEPRECATED("Use GetValue") Type value() const; Type GetValue() const; private: friend class Type; friend MapType JsonMapType(); explicit MapType(const common_internal::MapTypeData* absl_nonnull data) : data_(reinterpret_cast(data) | kBasicBit) { ABSL_DCHECK_GE(absl::countr_zero(reinterpret_cast(data)), 2) << "alignment must be greater than 2"; } explicit MapType(const google::protobuf::Descriptor* absl_nonnull descriptor) : data_(reinterpret_cast(descriptor) | kProtoBit) { ABSL_DCHECK_GE(absl::countr_zero(reinterpret_cast(descriptor)), 2) << "alignment must be greater than 2"; ABSL_DCHECK(descriptor->map_key() != nullptr); ABSL_DCHECK(descriptor->map_value() != nullptr); } uintptr_t data_; }; bool operator==(const MapType& lhs, const MapType& rhs); inline bool operator!=(const MapType& lhs, const MapType& rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, const MapType& type); inline std::ostream& operator<<(std::ostream& out, const MapType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_MAP_TYPE_H_ ================================================ FILE: common/types/map_type_pool.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/types/map_type_pool.h" #include "common/type.h" namespace cel::common_internal { MapType MapTypePool::InternMapType(const Type& key, const Type& value) { if (key.IsDyn() && value.IsDyn()) { return MapType(); } return *map_types_.lazy_emplace(AsTuple(key, value), [&](const auto& ctor) { ctor(MapType(arena_, key, value)); }); } } // namespace cel::common_internal ================================================ FILE: common/types/map_type_pool.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_MAP_TYPE_POOL_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_MAP_TYPE_POOL_H_ #include #include #include #include "absl/base/nullability.h" #include "absl/container/flat_hash_set.h" #include "absl/hash/hash.h" #include "absl/log/die_if_null.h" #include "common/type.h" #include "google/protobuf/arena.h" namespace cel::common_internal { // `MapTypePool` is a thread unsafe interning factory for `MapType`. 
class MapTypePool final { public: explicit MapTypePool(google::protobuf::Arena* absl_nonnull arena) : arena_(ABSL_DIE_IF_NULL(arena)) {} // Crash OK // Returns a `MapType` which has the provided parameters, interning as // necessary. MapType InternMapType(const Type& key, const Type& value); private: using MapTypeTuple = std::tuple, std::reference_wrapper>; static MapTypeTuple AsTuple(const MapType& map_type) { return AsTuple(map_type.key(), map_type.value()); } static MapTypeTuple AsTuple(const Type& key, const Type& value) { return MapTypeTuple{std::cref(key), std::cref(value)}; } struct Hasher { using is_transparent = void; size_t operator()(const MapType& map_type) const { return (*this)(AsTuple(map_type)); } size_t operator()(const MapTypeTuple& tuple) const { return absl::Hash{}(tuple); } }; struct Equaler { using is_transparent = void; bool operator()(const MapType& lhs, const MapType& rhs) const { return (*this)(AsTuple(lhs), AsTuple(rhs)); } bool operator()(const MapType& lhs, const MapTypeTuple& rhs) const { return (*this)(AsTuple(lhs), rhs); } bool operator()(const MapTypeTuple& lhs, const MapType& rhs) const { return (*this)(lhs, AsTuple(rhs)); } bool operator()(const MapTypeTuple& lhs, const MapTypeTuple& rhs) const { return lhs == rhs; } }; google::protobuf::Arena* absl_nonnull const arena_; absl::flat_hash_set map_types_; }; } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_MAP_TYPE_POOL_H_ ================================================ FILE: common/types/map_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" #include "google/protobuf/arena.h" namespace cel { namespace { TEST(MapType, Default) { MapType map_type; EXPECT_EQ(map_type.key(), DynType()); EXPECT_EQ(map_type.value(), DynType()); } TEST(MapType, Kind) { google::protobuf::Arena arena; EXPECT_EQ(MapType(&arena, StringType(), BytesType()).kind(), MapType::kKind); EXPECT_EQ(Type(MapType(&arena, StringType(), BytesType())).kind(), MapType::kKind); } TEST(MapType, Name) { google::protobuf::Arena arena; EXPECT_EQ(MapType(&arena, StringType(), BytesType()).name(), MapType::kName); EXPECT_EQ(Type(MapType(&arena, StringType(), BytesType())).name(), MapType::kName); } TEST(MapType, DebugString) { google::protobuf::Arena arena; { std::ostringstream out; out << MapType(&arena, StringType(), BytesType()); EXPECT_EQ(out.str(), "map"); } { std::ostringstream out; out << Type(MapType(&arena, StringType(), BytesType())); EXPECT_EQ(out.str(), "map"); } } TEST(MapType, Hash) { google::protobuf::Arena arena; EXPECT_EQ(absl::HashOf(MapType(&arena, StringType(), BytesType())), absl::HashOf(MapType(&arena, StringType(), BytesType()))); } TEST(MapType, Equal) { google::protobuf::Arena arena; EXPECT_EQ(MapType(&arena, StringType(), BytesType()), MapType(&arena, StringType(), BytesType())); EXPECT_EQ(Type(MapType(&arena, StringType(), BytesType())), MapType(&arena, StringType(), BytesType())); EXPECT_EQ(MapType(&arena, StringType(), BytesType()), Type(MapType(&arena, StringType(), BytesType()))); EXPECT_EQ(Type(MapType(&arena, 
StringType(), BytesType())), Type(MapType(&arena, StringType(), BytesType()))); } } // namespace } // namespace cel ================================================ FILE: common/types/message_type.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/base/optimization.h" #include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "common/type.h" #include "google/protobuf/descriptor.h" namespace cel { using google::protobuf::Descriptor; bool IsWellKnownMessageType(const Descriptor* absl_nonnull descriptor) { switch (descriptor->well_known_type()) { case Descriptor::WELLKNOWNTYPE_BOOLVALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_INT32VALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_INT64VALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_UINT32VALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_UINT64VALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_FLOATVALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_BYTESVALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_STRINGVALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_ANY: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_DURATION: 
ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_TIMESTAMP: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_VALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_LISTVALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_STRUCT: return true; default: return false; } } std::string MessageType::DebugString() const { if (ABSL_PREDICT_TRUE(static_cast(*this))) { static_assert(sizeof(descriptor_) == 8 || sizeof(descriptor_) == 4, "sizeof(void*) is neither 8 nor 4"); return absl::StrCat(name(), "@0x", absl::Hex(descriptor_, sizeof(descriptor_) == 8 ? absl::PadSpec::kZeroPad16 : absl::PadSpec::kZeroPad8)); } return std::string(); } std::string MessageTypeField::DebugString() const { if (ABSL_PREDICT_TRUE(static_cast(*this))) { static_assert(sizeof(descriptor_) == 8 || sizeof(descriptor_) == 4, "sizeof(void*) is neither 8 nor 4"); return absl::StrCat("[", (*this)->number(), "]", (*this)->name(), "@0x", absl::Hex(descriptor_, sizeof(descriptor_) == 8 ? absl::PadSpec::kZeroPad16 : absl::PadSpec::kZeroPad8)); } return std::string(); } Type MessageTypeField::GetType() const { ABSL_DCHECK(*this); return Type::Field(descriptor_); } } // namespace cel ================================================ FILE: common/types/message_type.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/type.h"
// IWYU pragma: friend "common/types/struct_type.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_MESSAGE_TYPE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_MESSAGE_TYPE_H_

// NOTE(review): include targets and template arguments (casts,
// `pointer_traits` specializations) are missing in this extracted copy; they
// are reproduced unchanged.
#include
#include
#include
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/strings/string_view.h"
#include "common/type_kind.h"
#include "google/protobuf/descriptor.h"

namespace cel {

class Type;
class TypeParameters;

// True for the well-known message types CEL treats specially (defined in
// message_type.cc).
bool IsWellKnownMessageType(
    const google::protobuf::Descriptor* absl_nonnull descriptor);

class MessageTypeField;

// `MessageType` is a non-owning handle to a protobuf message `Descriptor`,
// with `kKind == TypeKind::kStruct`. It may be empty (null descriptor).
class MessageType final {
 public:
  using element_type = const google::protobuf::Descriptor;

  static constexpr TypeKind kKind = TypeKind::kStruct;

  // Constructs `MessageType` from a pointer to `google::protobuf::Descriptor`.
  // The `google::protobuf::Descriptor` must not be one of the well known
  // message types we treat specially; if it is, behavior is undefined. If you
  // are unsure, you should use `Type::Message`.
  explicit MessageType(
      const google::protobuf::Descriptor* absl_nullable descriptor)
      : descriptor_(descriptor) {
    // The streamed name is only evaluated when the check fails, and then
    // `descriptor` is non-null.
    ABSL_DCHECK(descriptor == nullptr || !IsWellKnownMessageType(descriptor))
        << descriptor->full_name();
  }

  // Constructs a `MessageType` in an empty state.
  //
  // Most operations on an empty `MessageType` result in undefined behavior.
  // Use `operator bool` to test if a `MessageType` is empty.
  MessageType() = default;
  MessageType(const MessageType&) = default;
  MessageType(MessageType&&) = default;
  MessageType& operator=(const MessageType&) = default;
  MessageType& operator=(MessageType&&) = default;

  static TypeKind kind() { return kKind; }

  // The descriptor's full name. Requires a non-empty type.
  absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return (*this)->full_name();
  }

  std::string DebugString() const;

  static TypeParameters GetParameters();

  // Accessors to the wrapped descriptor; DCHECK that the type is non-empty.
  const google::protobuf::Descriptor& operator*() const {
    ABSL_DCHECK(*this);
    return *descriptor_;
  }

  const google::protobuf::Descriptor* absl_nonnull operator->() const {
    ABSL_DCHECK(*this);
    return descriptor_;
  }

  // True iff a descriptor is set.
  explicit operator bool() const { return descriptor_ != nullptr; }

 private:
  friend struct std::pointer_traits;

  const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr;
};

// Two empty types are equal. Otherwise both must be non-empty and have the
// same full name.
inline bool operator==(MessageType lhs, MessageType rhs) {
  return static_cast(lhs) == static_cast(rhs) &&
         (!static_cast(lhs) || lhs.name() == rhs.name());
}

inline bool operator!=(MessageType lhs, MessageType rhs) {
  return !operator==(lhs, rhs);
}

// Hashes the full name, or an empty string view for an empty type.
template
H AbslHashValue(H state, MessageType message_type) {
  return H::combine(std::move(state), static_cast(message_type)
                                          ? message_type.name()
                                          : absl::string_view());
}

inline std::ostream& operator<<(std::ostream& out, MessageType type) {
  return out << type.DebugString();
}

}  // namespace cel

namespace std {

// Lets `to_address` return the wrapped descriptor pointer.
template <>
struct pointer_traits {
  using pointer = cel::MessageType;
  using element_type = typename cel::MessageType::element_type;
  using difference_type = ptrdiff_t;

  static element_type* to_address(const pointer& p) noexcept {
    return p.descriptor_;
  }
};

}  // namespace std

namespace cel {

// `MessageTypeField` is a non-owning handle to a protobuf `FieldDescriptor`.
// It may be empty (null descriptor).
class MessageTypeField final {
 public:
  using element_type = const google::protobuf::FieldDescriptor;

  explicit MessageTypeField(
      const google::protobuf::FieldDescriptor* absl_nullable descriptor)
      : descriptor_(descriptor) {}

  MessageTypeField() = default;
  MessageTypeField(const MessageTypeField&) = default;
  MessageTypeField(MessageTypeField&&) = default;
  MessageTypeField& operator=(const MessageTypeField&) = default;
  MessageTypeField& operator=(MessageTypeField&&) = default;

  std::string DebugString() const;

  // Field name; requires a non-empty field.
  absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return (*this)->name();
  }

  // Field number; requires a non-empty field.
  int32_t number() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return (*this)->number();
  }

  Type GetType() const;

  const google::protobuf::FieldDescriptor& operator*() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(*this);
    return *descriptor_;
  }

  const google::protobuf::FieldDescriptor* absl_nonnull operator->() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(*this);
    return descriptor_;
  }

  // True iff a descriptor is set.
  explicit operator bool() const { return descriptor_ != nullptr; }

 private:
  friend struct std::pointer_traits;

  const google::protobuf::FieldDescriptor* absl_nullable descriptor_ = nullptr;
};

}  // namespace cel

namespace std {

// Lets `to_address` return the wrapped field descriptor pointer.
template <>
struct pointer_traits {
  using pointer = cel::MessageTypeField;
  using element_type = typename cel::MessageTypeField::element_type;
  using difference_type = ptrdiff_t;

  static element_type* to_address(const pointer& p) noexcept {
    return p.descriptor_;
  }
};

}  // namespace std

#endif  //
THIRD_PARTY_CEL_CPP_COMMON_TYPES_MESSAGE_TYPE_H_ ================================================ FILE: common/types/message_type_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "google/protobuf/descriptor.pb.h" #include "common/memory.h" #include "common/type.h" #include "common/type_kind.h" #include "internal/testing.h" #include "google/protobuf/descriptor.h" namespace cel { namespace { using ::testing::An; using ::testing::Eq; using ::testing::IsEmpty; using ::testing::NotNull; using ::testing::Optional; using ::testing::StartsWith; TEST(MessageType, Kind) { EXPECT_EQ(MessageType::kind(), TypeKind::kStruct); } TEST(MessageType, Default) { MessageType type; EXPECT_FALSE(type); EXPECT_THAT(type.DebugString(), Eq("")); EXPECT_EQ(type, MessageType()); } TEST(MessageType, Descriptor) { google::protobuf::DescriptorPool pool; { google::protobuf::FileDescriptorProto file_desc_proto; file_desc_proto.set_syntax("proto3"); file_desc_proto.set_package("test"); file_desc_proto.set_name("test/struct.proto"); file_desc_proto.add_message_type()->set_name("Struct"); ASSERT_THAT(pool.BuildFile(file_desc_proto), NotNull()); } const google::protobuf::Descriptor* desc = pool.FindMessageTypeByName("test.Struct"); ASSERT_THAT(desc, NotNull()); MessageType type(desc); EXPECT_TRUE(type); EXPECT_THAT(type.name(), Eq("test.Struct")); EXPECT_THAT(type.DebugString(), StartsWith("test.Struct@0x")); 
EXPECT_THAT(type.GetParameters(), IsEmpty()); EXPECT_NE(type, MessageType()); EXPECT_NE(MessageType(), type); EXPECT_EQ(cel::to_address(type), desc); } TEST(MessageTypeField, Descriptor) { google::protobuf::DescriptorPool pool; { google::protobuf::FileDescriptorProto file_desc_proto; file_desc_proto.set_syntax("proto3"); file_desc_proto.set_package("test"); file_desc_proto.set_name("test/struct.proto"); auto* message_type = file_desc_proto.add_message_type(); message_type->set_name("Struct"); auto* field = message_type->add_field(); field->set_name("foo"); field->set_json_name("foo"); field->set_number(1); field->set_type(google::protobuf::FieldDescriptorProto::TYPE_INT64); field->set_label(google::protobuf::FieldDescriptorProto::LABEL_OPTIONAL); ASSERT_THAT(pool.BuildFile(file_desc_proto), NotNull()); } const google::protobuf::Descriptor* desc = pool.FindMessageTypeByName("test.Struct"); ASSERT_THAT(desc, NotNull()); const google::protobuf::FieldDescriptor* field_desc = desc->FindFieldByName("foo"); ASSERT_THAT(desc, NotNull()); MessageTypeField message_type_field(field_desc); EXPECT_TRUE(message_type_field); EXPECT_THAT(message_type_field.name(), Eq("foo")); EXPECT_THAT(message_type_field.DebugString(), StartsWith("[1]foo@0x")); EXPECT_THAT(message_type_field.number(), Eq(1)); EXPECT_THAT(message_type_field.GetType(), IntType()); EXPECT_EQ(cel::to_address(message_type_field), field_desc); StructTypeField struct_type_field = message_type_field; EXPECT_TRUE(struct_type_field.IsMessage()); EXPECT_THAT(struct_type_field.AsMessage(), Optional(An())); EXPECT_THAT(static_cast(struct_type_field), An()); EXPECT_EQ(struct_type_field.name(), message_type_field.name()); EXPECT_EQ(struct_type_field.number(), message_type_field.number()); EXPECT_EQ(struct_type_field.GetType(), message_type_field.GetType()); } } // namespace } // namespace cel ================================================ FILE: common/types/null_type.h ================================================ // 
Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_NULL_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_NULL_TYPE_H_ #include #include #include #include "absl/strings/string_view.h" #include "common/type_kind.h" namespace cel { class Type; class TypeParameters; // `NullType` represents the primitive `null_type` type. class NullType final { public: static constexpr TypeKind kKind = TypeKind::kNull; static constexpr absl::string_view kName = "null_type"; NullType() = default; NullType(const NullType&) = default; NullType(NullType&&) = default; NullType& operator=(const NullType&) = default; NullType& operator=(NullType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } static TypeParameters GetParameters(); static std::string DebugString() { return std::string(name()); } }; inline constexpr bool operator==(NullType, NullType) { return true; } inline constexpr bool operator!=(NullType lhs, NullType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, NullType) { // NullType is really a singleton and all instances are equal. Nothing to // hash. 
return std::move(state); } inline std::ostream& operator<<(std::ostream& out, const NullType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_NULL_TYPE_H_ ================================================ FILE: common/types/null_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(NullType, Kind) { EXPECT_EQ(NullType().kind(), NullType::kKind); EXPECT_EQ(Type(NullType()).kind(), NullType::kKind); } TEST(NullType, Name) { EXPECT_EQ(NullType().name(), NullType::kName); EXPECT_EQ(Type(NullType()).name(), NullType::kName); } TEST(NullType, DebugString) { { std::ostringstream out; out << NullType(); EXPECT_EQ(out.str(), NullType::kName); } { std::ostringstream out; out << Type(NullType()); EXPECT_EQ(out.str(), NullType::kName); } } TEST(NullType, Hash) { EXPECT_EQ(absl::HashOf(NullType()), absl::HashOf(NullType())); } TEST(NullType, Equal) { EXPECT_EQ(NullType(), NullType()); EXPECT_EQ(Type(NullType()), NullType()); EXPECT_EQ(NullType(), Type(NullType())); EXPECT_EQ(Type(NullType()), Type(NullType())); } } // namespace } // namespace cel ================================================ FILE: common/types/opaque_type.cc ================================================ // Copyright 2023 Google LLC // // Licensed 
under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/utility/utility.h"
#include "common/type.h"
#include "common/type_kind.h"
#include "google/protobuf/arena.h"

namespace cel {

namespace {

// Renders `name` alone when there are no parameters. Otherwise renders `name`
// followed by the parameters in angle brackets, separated by ", ". Each
// parameter is rendered by its kind name (`TypeKindToString`), not by its full
// debug string.
std::string OpaqueDebugString(absl::string_view name,
                              absl::Span parameters) {
  if (parameters.empty()) {
    return std::string(name);
  }
  return absl::StrCat(
      name, "<",
      absl::StrJoin(parameters, ", ",
                    [](std::string* out, const Type& type) {
                      absl::StrAppend(out, TypeKindToString(type.kind()));
                    }),
      ">");
}

}  // namespace

namespace common_internal {

// Allocates the `OpaqueTypeData` header plus `parameters.size()` trailing
// `Type` slots in a single arena allocation, then placement-constructs the
// data in it.
OpaqueTypeData* absl_nonnull OpaqueTypeData::Create(
    google::protobuf::Arena* absl_nonnull arena, absl::string_view name,
    absl::Span parameters) {
  return ::new (arena->AllocateAligned(
      offsetof(OpaqueTypeData, parameters) + (parameters.size() * sizeof(Type)),
      alignof(OpaqueTypeData))) OpaqueTypeData(name, parameters);
}

// Copies the parameters byte-wise into the trailing array.
// NOTE(review): `memcpy` of `Type` objects assumes `Type` is trivially
// copyable — confirm.
OpaqueTypeData::OpaqueTypeData(absl::string_view name,
                               absl::Span parameters)
    : name(name), parameters_size(parameters.size()) {
  std::memcpy(this->parameters, parameters.data(),
              parameters_size * sizeof(Type));
}

}  // namespace common_internal

OpaqueType::OpaqueType(google::protobuf::Arena* absl_nonnull arena,
                       absl::string_view name,
                       absl::Span parameters)
    : OpaqueType(
          common_internal::OpaqueTypeData::Create(arena, name, parameters)) {}

std::string OpaqueType::DebugString() const {
  ABSL_DCHECK(*this);
  return OpaqueDebugString(name(), GetParameters());
}

absl::string_view OpaqueType::name() const {
  ABSL_DCHECK(*this);
  return data_->name;
}

// Returns a view over the trailing parameter array.
TypeParameters OpaqueType::GetParameters() const {
  ABSL_DCHECK(*this);
  return TypeParameters(
      absl::MakeConstSpan(data_->parameters, data_->parameters_size));
}

// True when the name is `OptionalType::kName` and there is exactly one
// parameter. Calls `name()`, so it requires a non-empty `OpaqueType`.
bool OpaqueType::IsOptional() const {
  return name() == OptionalType::kName && GetParameters().size() == 1;
}

absl::optional OpaqueType::AsOptional() const {
  if (IsOptional()) {
    return OptionalType(absl::in_place, *this);
  }
  return absl::nullopt;
}

// Unchecked conversion; callers must ensure `IsOptional()` (DCHECKed).
OptionalType OpaqueType::GetOptional() const {
  ABSL_DCHECK(IsOptional()) << DebugString();
  return OptionalType(absl::in_place, *this);
}

}  // namespace cel
================================================
FILE: common/types/opaque_type.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// IWYU pragma: private, include "common/type.h"
// IWYU pragma: friend "common/type.h"
// IWYU pragma: friend "common/types/optional_type.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPAQUE_TYPE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPAQUE_TYPE_H_

// NOTE(review): include targets and template arguments (`Span`,
// `enable_if_t`, `optional`, ...) are missing in this extracted copy; they
// are reproduced unchanged.
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "common/type_kind.h"
#include "google/protobuf/arena.h"

namespace cel {

class Type;
class OptionalType;
class TypeParameters;

namespace common_internal {
struct OpaqueTypeData;
}  // namespace common_internal

// `OpaqueType` is a named, parameterized type (`kKind == TypeKind::kOpaque`)
// backed by a pointer to `common_internal::OpaqueTypeData`. It may be empty
// (null data).
class OpaqueType final {
 public:
  static constexpr TypeKind kKind = TypeKind::kOpaque;

  // `name` must outlive the instance.
  OpaqueType(google::protobuf::Arena* absl_nonnull arena, absl::string_view name,
             absl::Span parameters);

  // NOLINTNEXTLINE(google-explicit-constructor)
  OpaqueType(OptionalType type);

  // NOLINTNEXTLINE(google-explicit-constructor)
  OpaqueType& operator=(OptionalType type);

  OpaqueType() = default;
  OpaqueType(const OpaqueType&) = default;
  OpaqueType(OpaqueType&&) = default;
  OpaqueType& operator=(const OpaqueType&) = default;
  OpaqueType& operator=(OpaqueType&&) = default;

  static TypeKind kind() { return kKind; }

  absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND;

  std::string DebugString() const;

  TypeParameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND;

  // True iff this handle refers to data.
  explicit operator bool() const { return data_ != nullptr; }

  bool IsOptional() const;

  template
  std::enable_if_t, bool> Is() const {
    return IsOptional();
  }

  absl::optional AsOptional() const;

  template
  std::enable_if_t, absl::optional> As() const;

  OptionalType GetOptional() const;

  template
  std::enable_if_t, OptionalType> Get() const;

 private:
  friend class OptionalType;

  constexpr explicit OpaqueType(
      const common_internal::OpaqueTypeData* absl_nullable data)
      : data_(data) {}

  const common_internal::OpaqueTypeData* absl_nullable data_ = nullptr;
};

bool operator==(const OpaqueType& lhs, const OpaqueType& rhs);

inline bool operator!=(const OpaqueType& lhs, const OpaqueType& rhs) {
  return !operator==(lhs, rhs);
}

template H AbslHashValue(H state, const OpaqueType& type);

inline std::ostream& operator<<(std::ostream& out, const OpaqueType& type) {
  return out << type.DebugString();
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPAQUE_TYPE_H_
================================================
FILE: common/types/opaque_type_pool.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "common/types/opaque_type_pool.h"

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "common/type.h"

namespace cel::common_internal {

// Returns the pooled `OpaqueType` for (`name`, `parameters`), constructing it
// with `arena_` only when no equal entry exists. An empty name with no
// parameters yields the empty `OpaqueType` without touching the pool.
// NOTE(review): `OpaqueType` requires `name` to outlive the instance, so an
// interned entry relies on the caller's `name` storage outliving this pool —
// confirm callers pass long-lived names.
OpaqueType OpaqueTypePool::InternOpaqueType(absl::string_view name,
                                            absl::Span parameters) {
  if (name.empty() && parameters.empty()) {
    return OpaqueType();
  }
  return *opaque_types_.lazy_emplace(
      AsTuple(name, parameters), [&](const auto& ctor) {
        ctor(OpaqueType(arena_, name, parameters));
      });
}

}  // namespace cel::common_internal
================================================
FILE: common/types/opaque_type_pool.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private

#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPAQUE_TYPE_POOL_H_
#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPAQUE_TYPE_POOL_H_

#include
#include
#include

#include "absl/algorithm/container.h"
#include "absl/base/nullability.h"
#include "absl/container/flat_hash_set.h"
#include "absl/hash/hash.h"
#include "absl/log/die_if_null.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "common/type.h"
#include "google/protobuf/arena.h"

namespace cel::common_internal {

// `OpaqueTypePool` is a thread-unsafe interning factory for `OpaqueType`.
class OpaqueTypePool final {
 public:
  // `arena` must be non-null (enforced with `ABSL_DIE_IF_NULL`).
  explicit OpaqueTypePool(google::protobuf::Arena* absl_nonnull arena)
      : arena_(ABSL_DIE_IF_NULL(arena)) {}  // Crash OK

  // Returns an `OpaqueType` which has the provided parameters, interning as
  // necessary.
  OpaqueType InternOpaqueType(absl::string_view name,
                              absl::Span parameters);

 private:
  // Heterogeneous lookup key: (name, parameters) as non-owning views.
  // NOTE(review): template arguments here, on `absl::Hash`, and on
  // `absl::flat_hash_set` are missing in this extracted copy.
  using OpaqueTypeTuple = std::tuple>;

  // NOTE(review): `GetParameters()` returns `TypeParameters` by value. The
  // span stored in the tuple is only safe if it points at the type's own
  // storage rather than into that temporary — confirm for spans built via
  // `absl::MakeConstSpan` in `OpaqueType::GetParameters`.
  static OpaqueTypeTuple AsTuple(const OpaqueType& opaque_type) {
    return AsTuple(opaque_type.name(), opaque_type.GetParameters());
  }

  static OpaqueTypeTuple AsTuple(absl::string_view name,
                                 absl::Span parameters) {
    return OpaqueTypeTuple{name, parameters};
  }

  // Hashes a stored `OpaqueType` and a lookup tuple identically.
  struct Hasher {
    using is_transparent = void;

    size_t operator()(const OpaqueType& data) const {
      return (*this)(AsTuple(data));
    }

    size_t operator()(const OpaqueTypeTuple& tuple) const {
      return absl::Hash{}(tuple);
    }
  };

  // Equal when the names match and the parameter sequences are element-wise
  // equal.
  struct Equaler {
    using is_transparent = void;

    bool operator()(const OpaqueType& lhs, const OpaqueType& rhs) const {
      return (*this)(AsTuple(lhs), AsTuple(rhs));
    }

    bool operator()(const OpaqueType& lhs, const OpaqueTypeTuple& rhs) const {
      return (*this)(AsTuple(lhs), rhs);
    }

    bool operator()(const OpaqueTypeTuple& lhs, const OpaqueType& rhs) const {
      return (*this)(lhs, AsTuple(rhs));
    }

    bool operator()(const OpaqueTypeTuple& lhs,
                    const OpaqueTypeTuple& rhs) const {
      return std::get<0>(lhs) == std::get<0>(rhs) &&
             absl::c_equal(std::get<1>(lhs), std::get<1>(rhs));
    }
  };

  google::protobuf::Arena* absl_nonnull const arena_;
  absl::flat_hash_set opaque_types_;
};

}  // namespace cel::common_internal

#endif  // THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPAQUE_TYPE_POOL_H_
================================================
FILE: common/types/opaque_type_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" #include "google/protobuf/arena.h" namespace cel { namespace { TEST(OpaqueType, Kind) { google::protobuf::Arena arena; EXPECT_EQ(OpaqueType(&arena, "test.Opaque", {BytesType()}).kind(), OpaqueType::kKind); EXPECT_EQ(Type(OpaqueType(&arena, "test.Opaque", {BytesType()})).kind(), OpaqueType::kKind); } TEST(OpaqueType, Name) { google::protobuf::Arena arena; EXPECT_EQ(OpaqueType(&arena, "test.Opaque", {BytesType()}).name(), "test.Opaque"); EXPECT_EQ(Type(OpaqueType(&arena, "test.Opaque", {BytesType()})).name(), "test.Opaque"); } TEST(OpaqueType, DebugString) { google::protobuf::Arena arena; { std::ostringstream out; out << OpaqueType(&arena, "test.Opaque", {BytesType()}); EXPECT_EQ(out.str(), "test.Opaque"); } { std::ostringstream out; out << Type(OpaqueType(&arena, "test.Opaque", {BytesType()})); EXPECT_EQ(out.str(), "test.Opaque"); } { std::ostringstream out; out << OpaqueType(&arena, "test.Opaque", {}); EXPECT_EQ(out.str(), "test.Opaque"); } } TEST(OpaqueType, Hash) { google::protobuf::Arena arena; EXPECT_EQ(absl::HashOf(OpaqueType(&arena, "test.Opaque", {BytesType()})), absl::HashOf(OpaqueType(&arena, "test.Opaque", {BytesType()}))); } TEST(OpaqueType, Equal) { google::protobuf::Arena arena; EXPECT_EQ(OpaqueType(&arena, "test.Opaque", {BytesType()}), OpaqueType(&arena, "test.Opaque", {BytesType()})); EXPECT_EQ(Type(OpaqueType(&arena, "test.Opaque", {BytesType()})), OpaqueType(&arena, "test.Opaque", {BytesType()})); EXPECT_EQ(OpaqueType(&arena, "test.Opaque", {BytesType()}), Type(OpaqueType(&arena, "test.Opaque", {BytesType()}))); EXPECT_EQ(Type(OpaqueType(&arena, "test.Opaque", {BytesType()})), Type(OpaqueType(&arena, "test.Opaque", {BytesType()}))); } } // namespace } // namespace cel ================================================ FILE: common/types/optional_type.cc 
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the header name of this include is missing from this copy;
// presumably it provides `offsetof`/`size_t` used below — confirm upstream.
#include

#include "absl/base/attributes.h"
#include "absl/strings/string_view.h"
#include "common/type.h"

namespace cel {

namespace common_internal {

namespace {

// Storage for the default `optional(dyn)` type. It mirrors the leading fields
// of `OpaqueTypeData`, but holds exactly one parameter inline.
struct OptionalTypeData final {
  const absl::string_view name;
  const size_t parameters_size;
  const Type parameter;
};

// Here be dragons. In order to make `OptionalType` default constructible
// without some sort of dynamic static initializer, we perform some
// type-punning. `OptionalTypeData` and `OpaqueTypeData` must have the same
// layout, with the only exception being that `OptionalTypeData` has a single
// `Type` where `OpaqueTypeData` has a flexible array.
union DynOptionalTypeData final {
  OptionalTypeData optional;
  OpaqueTypeData opaque;
};

// Compile-time guards for the layout assumption above: each field of
// `OptionalTypeData` must sit at the same offset as its `OpaqueTypeData`
// counterpart.
static_assert(offsetof(OptionalTypeData, name) ==
              offsetof(OpaqueTypeData, name));
static_assert(offsetof(OptionalTypeData, parameters_size) ==
              offsetof(OpaqueTypeData, parameters_size));
static_assert(offsetof(OptionalTypeData, parameter) ==
              offsetof(OpaqueTypeData, parameters));

// Constant-initialized data for `optional(dyn)`. The `optional` member is the
// one initialized here; `OptionalType::OptionalType()` reads it through the
// `opaque` member.
ABSL_CONST_INIT const DynOptionalTypeData kDynOptionalTypeData = {
    .optional =
        {
            .name = OptionalType::kName,
            .parameters_size = 1,
            .parameter = DynType(),
        },
};

}  // namespace

}  // namespace common_internal

// The default `OptionalType` refers to the shared, statically initialized
// `optional(dyn)` data rather than building a new `OpaqueType`.
OptionalType::OptionalType()
    : opaque_(&common_internal::kDynOptionalTypeData.opaque) {}

// Returns the first entry of `GetParameters()`.
Type OptionalType::GetParameter() const { return GetParameters().front(); }

}  // namespace cel

================================================
FILE: common/types/optional_type.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPTIONAL_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPTIONAL_TYPE_H_ #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "absl/utility/utility.h" #include "common/type_kind.h" #include "common/types/opaque_type.h" #include "google/protobuf/arena.h" namespace cel { class Type; class TypeParameters; class OptionalType final { public: static constexpr TypeKind kKind = TypeKind::kOpaque; static constexpr absl::string_view kName = "optional_type"; // By default, this type is `optional(dyn)`. Unless you can help it, you // should choose a more specific optional type. OptionalType(); OptionalType(google::protobuf::Arena* absl_nonnull arena, const Type& parameter) : OptionalType( absl::in_place, OpaqueType(arena, kName, absl::MakeConstSpan(¶meter, 1))) {} static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } std::string DebugString() const { return opaque_.DebugString(); } TypeParameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND; Type GetParameter() const ABSL_ATTRIBUTE_LIFETIME_BOUND; explicit operator bool() const { return static_cast(opaque_); } template friend H AbslHashValue(H state, const OptionalType& type) { return H::combine(std::move(state), type.opaque_); } friend bool operator==(const OptionalType& lhs, const OptionalType& rhs) { return lhs.opaque_ == rhs.opaque_; } private: friend class OpaqueType; OptionalType(absl::in_place_t, OpaqueType type) : opaque_(std::move(type)) {} OpaqueType opaque_; }; inline bool operator!=(const OptionalType& lhs, const OptionalType& rhs) { return !operator==(lhs, rhs); } inline std::ostream& operator<<(std::ostream& out, const OptionalType& type) { return out << type.DebugString(); } inline 
OpaqueType::OpaqueType(OptionalType type) : OpaqueType(std::move(type.opaque_)) {} inline OpaqueType& OpaqueType::operator=(OptionalType type) { return *this = std::move(type.opaque_); } template inline std::enable_if_t, absl::optional> OpaqueType::As() const { return AsOptional(); } template inline std::enable_if_t, OptionalType> OpaqueType::Get() const { return GetOptional(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_OPTIONAL_TYPE_H_ ================================================ FILE: common/types/optional_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// Tests for `OptionalType`: the default `optional(dyn)` instance, kind, name,
// stream output, parameter access, hashing and equality.

#include

#include "absl/hash/hash.h"
#include "common/type.h"
#include "internal/testing.h"
#include "google/protobuf/arena.h"

namespace cel {
namespace {

// A default-constructed `OptionalType` has `dyn` as its parameter.
TEST(OptionalType, Default) {
  OptionalType optional_type;
  EXPECT_EQ(optional_type.GetParameter(), DynType());
}

TEST(OptionalType, Kind) {
  google::protobuf::Arena arena;
  EXPECT_EQ(OptionalType(&arena, BoolType()).kind(), OptionalType::kKind);
  EXPECT_EQ(Type(OptionalType(&arena, BoolType())).kind(),
            OptionalType::kKind);
}

TEST(OptionalType, Name) {
  google::protobuf::Arena arena;
  EXPECT_EQ(OptionalType(&arena, BoolType()).name(), OptionalType::kName);
  EXPECT_EQ(Type(OptionalType(&arena, BoolType())).name(),
            OptionalType::kName);
}

// NOTE(review): angle-bracketed text appears to have been stripped from this
// copy (see the empty `#include` above); the expected strings may originally
// have included the parameter list — confirm upstream.
TEST(OptionalType, DebugString) {
  google::protobuf::Arena arena;
  {
    std::ostringstream out;
    out << OptionalType(&arena, BoolType());
    EXPECT_EQ(out.str(), "optional_type");
  }
  {
    std::ostringstream out;
    out << Type(OptionalType(&arena, BoolType()));
    EXPECT_EQ(out.str(), "optional_type");
  }
}

TEST(OptionalType, Parameter) {
  google::protobuf::Arena arena;
  EXPECT_EQ(OptionalType(&arena, BoolType()).GetParameter(), BoolType());
}

// Two independently constructed but identical types must hash the same.
TEST(OptionalType, Hash) {
  google::protobuf::Arena arena;
  EXPECT_EQ(absl::HashOf(OptionalType(&arena, BoolType())),
            absl::HashOf(OptionalType(&arena, BoolType())));
}

// Equality must hold for every combination of bare and `Type`-wrapped operands.
TEST(OptionalType, Equal) {
  google::protobuf::Arena arena;
  EXPECT_EQ(OptionalType(&arena, BoolType()), OptionalType(&arena, BoolType()));
  EXPECT_EQ(Type(OptionalType(&arena, BoolType())),
            OptionalType(&arena, BoolType()));
  EXPECT_EQ(OptionalType(&arena, BoolType()),
            Type(OptionalType(&arena, BoolType())));
  EXPECT_EQ(Type(OptionalType(&arena, BoolType())),
            Type(OptionalType(&arena, BoolType())));
}

}  // namespace
}  // namespace cel

================================================
FILE: common/types/string_type.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private, include "common/type.h"
// IWYU pragma: friend "common/type.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRING_TYPE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRING_TYPE_H_

#include
#include
#include

#include "absl/strings/string_view.h"
#include "common/type_kind.h"

namespace cel {

class Type;
class TypeParameters;

// `StringType` represents the primitive `string` type.
//
// The class has no data members: every instance compares equal, and hashing
// contributes nothing to the hash state.
class StringType final {
 public:
  static constexpr TypeKind kKind = TypeKind::kString;
  static constexpr absl::string_view kName = "string";

  StringType() = default;
  StringType(const StringType&) = default;
  StringType(StringType&&) = default;
  StringType& operator=(const StringType&) = default;
  StringType& operator=(StringType&&) = default;

  static TypeKind kind() { return kKind; }

  static absl::string_view name() { return kName; }

  static TypeParameters GetParameters();

  // Returns `kName` as an owned string.
  std::string DebugString() const { return std::string(name()); }
};

inline constexpr bool operator==(StringType, StringType) { return true; }

inline constexpr bool operator!=(StringType lhs, StringType rhs) {
  return !operator==(lhs, rhs);
}

template H AbslHashValue(H state, StringType) {
  // StringType is really a singleton and all instances are equal. Nothing to
  // hash.
  return std::move(state);
}

inline std::ostream& operator<<(std::ostream& out, const StringType& type) {
  return out << type.DebugString();
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRING_TYPE_H_

================================================
FILE: common/types/string_type_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Tests for `StringType`: kind, name, stream output, hashing and equality,
// both bare and wrapped in `Type`.

#include

#include "absl/hash/hash.h"
#include "common/type.h"
#include "internal/testing.h"

namespace cel {
namespace {

TEST(StringType, Kind) {
  EXPECT_EQ(StringType().kind(), StringType::kKind);
  EXPECT_EQ(Type(StringType()).kind(), StringType::kKind);
}

TEST(StringType, Name) {
  EXPECT_EQ(StringType().name(), StringType::kName);
  EXPECT_EQ(Type(StringType()).name(), StringType::kName);
}

TEST(StringType, DebugString) {
  {
    std::ostringstream out;
    out << StringType();
    EXPECT_EQ(out.str(), StringType::kName);
  }
  {
    std::ostringstream out;
    out << Type(StringType());
    EXPECT_EQ(out.str(), StringType::kName);
  }
}

TEST(StringType, Hash) {
  EXPECT_EQ(absl::HashOf(StringType()), absl::HashOf(StringType()));
}

TEST(StringType, Equal) {
  EXPECT_EQ(StringType(), StringType());
  EXPECT_EQ(Type(StringType()), StringType());
  EXPECT_EQ(StringType(), Type(StringType()));
  EXPECT_EQ(Type(StringType()), Type(StringType()));
}

}  // namespace
}  // namespace cel

================================================
FILE: common/types/string_wrapper_type.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private, include "common/type.h"
// IWYU pragma: friend "common/type.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRING_WRAPPER_TYPE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRING_WRAPPER_TYPE_H_

#include
#include
#include

#include "absl/strings/string_view.h"
#include "common/type_kind.h"

namespace cel {

class Type;
class TypeParameters;

// `StringWrapperType` is a special type which has no direct value
// representation. It is used to represent `google.protobuf.StringValue`, which
// never exists at runtime as a value. Its primary usage is for type checking
// and unpacking at runtime.
//
// The class has no data members, so all instances are interchangeable: they
// compare equal, contribute nothing when hashed, and `swap` is a no-op.
class StringWrapperType final {
 public:
  static constexpr TypeKind kKind = TypeKind::kStringWrapper;
  static constexpr absl::string_view kName = "google.protobuf.StringValue";

  StringWrapperType() = default;
  StringWrapperType(const StringWrapperType&) = default;
  StringWrapperType(StringWrapperType&&) = default;
  StringWrapperType& operator=(const StringWrapperType&) = default;
  StringWrapperType& operator=(StringWrapperType&&) = default;

  static TypeKind kind() { return kKind; }

  static absl::string_view name() { return kName; }

  static TypeParameters GetParameters();

  // Returns `kName` as an owned string.
  static std::string DebugString() { return std::string(name()); }

  // Nothing to exchange: the type carries no state.
  constexpr void swap(StringWrapperType&) noexcept {}
};

inline constexpr void swap(StringWrapperType& lhs,
                           StringWrapperType& rhs) noexcept {
  lhs.swap(rhs);
}

inline constexpr bool operator==(StringWrapperType, StringWrapperType) {
  return true;
}

inline constexpr bool operator!=(StringWrapperType lhs, StringWrapperType rhs) {
  return !operator==(lhs, rhs);
}

template H AbslHashValue(H state, StringWrapperType) {
  // StringWrapperType is really a singleton and all instances are equal.
  // Nothing to hash.
  return std::move(state);
}

inline std::ostream& operator<<(std::ostream& out,
                                const StringWrapperType& type) {
  return out << type.DebugString();
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRING_WRAPPER_TYPE_H_

================================================
FILE: common/types/string_wrapper_type_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(StringWrapperType, Kind) { EXPECT_EQ(StringWrapperType().kind(), StringWrapperType::kKind); EXPECT_EQ(Type(StringWrapperType()).kind(), StringWrapperType::kKind); } TEST(StringWrapperType, Name) { EXPECT_EQ(StringWrapperType().name(), StringWrapperType::kName); EXPECT_EQ(Type(StringWrapperType()).name(), StringWrapperType::kName); } TEST(StringWrapperType, DebugString) { { std::ostringstream out; out << StringWrapperType(); EXPECT_EQ(out.str(), StringWrapperType::kName); } { std::ostringstream out; out << Type(StringWrapperType()); EXPECT_EQ(out.str(), StringWrapperType::kName); } } TEST(StringWrapperType, Hash) { EXPECT_EQ(absl::HashOf(StringWrapperType()), absl::HashOf(StringWrapperType())); } TEST(StringWrapperType, Equal) { EXPECT_EQ(StringWrapperType(), StringWrapperType()); EXPECT_EQ(Type(StringWrapperType()), StringWrapperType()); EXPECT_EQ(StringWrapperType(), Type(StringWrapperType())); EXPECT_EQ(Type(StringWrapperType()), Type(StringWrapperType())); } } // namespace } // namespace cel ================================================ FILE: common/types/struct_type.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "absl/functional/overload.h" #include "absl/log/absl_check.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/variant.h" #include "common/type.h" #include "common/types/types.h" namespace cel { absl::string_view StructType::name() const { ABSL_DCHECK(*this); return absl::visit( absl::Overload([](absl::monostate) { return absl::string_view(); }, [](const common_internal::BasicStructType& alt) { return alt.name(); }, [](const MessageType& alt) { return alt.name(); }), variant_); } TypeParameters StructType::GetParameters() const { ABSL_DCHECK(*this); return absl::visit( absl::Overload( [](absl::monostate) { return TypeParameters(); }, [](const common_internal::BasicStructType& alt) { return alt.GetParameters(); }, [](const MessageType& alt) { return alt.GetParameters(); }), variant_); } std::string StructType::DebugString() const { return absl::visit( absl::Overload([](absl::monostate) { return std::string(); }, [](common_internal::BasicStructType alt) { return alt.DebugString(); }, [](MessageType alt) { return alt.DebugString(); }), variant_); } absl::optional StructType::AsMessage() const { if (const auto* alt = absl::get_if(&variant_); alt != nullptr) { return *alt; } return absl::nullopt; } MessageType StructType::GetMessage() const { ABSL_DCHECK(IsMessage()) << DebugString(); return absl::get(variant_); } common_internal::TypeVariant StructType::ToTypeVariant() const { return absl::visit( absl::Overload( [](absl::monostate) { return common_internal::TypeVariant(); }, [](common_internal::BasicStructType alt) { return static_cast(alt) ? common_internal::TypeVariant(alt) : common_internal::TypeVariant(); }, [](MessageType alt) { return static_cast(alt) ? 
common_internal::TypeVariant(alt) : common_internal::TypeVariant(); }), variant_); } } // namespace cel ================================================ FILE: common/types/struct_type.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRUCT_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRUCT_TYPE_H_ #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/optimization.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/variant.h" #include "common/type_kind.h" #include "common/types/basic_struct_type.h" #include "common/types/message_type.h" #include "common/types/types.h" namespace cel { class Type; class TypeParameters; class StructType final { public: static constexpr TypeKind kKind = TypeKind::kStruct; // NOLINTNEXTLINE(google-explicit-constructor) StructType(MessageType other) : StructType() { if (ABSL_PREDICT_TRUE(other)) { variant_.emplace(other); } } // NOLINTNEXTLINE(google-explicit-constructor) StructType(common_internal::BasicStructType other) : StructType() { if (ABSL_PREDICT_TRUE(other)) { variant_.emplace(other); } } // NOLINTNEXTLINE(google-explicit-constructor) StructType& operator=(MessageType other) { if (ABSL_PREDICT_TRUE(other)) { variant_.emplace(other); } else { 
variant_.emplace(); } return *this; } // NOLINTNEXTLINE(google-explicit-constructor) StructType& operator=(common_internal::BasicStructType other) { if (ABSL_PREDICT_TRUE(other)) { variant_.emplace(other); } else { variant_.emplace(); } return *this; } StructType() = default; StructType(const StructType&) = default; StructType(StructType&&) = default; StructType& operator=(const StructType&) = default; StructType& operator=(StructType&&) = default; static TypeKind kind() { return kKind; } absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND; TypeParameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND; std::string DebugString() const; bool IsMessage() const { return absl::holds_alternative(variant_); } template std::enable_if_t, bool> Is() const { return IsMessage(); } absl::optional AsMessage() const; template std::enable_if_t, absl::optional> As() const { return AsMessage(); } MessageType GetMessage() const; template std::enable_if_t, MessageType> Get() const { return GetMessage(); } explicit operator bool() const { return !absl::holds_alternative(variant_); } private: friend class Type; friend class MessageType; friend class common_internal::BasicStructType; common_internal::TypeVariant ToTypeVariant() const; // The default state is well formed but invalid. It can be checked by using // the explicit bool operator. This is to allow cases where you want to // construct the type and later assign to it before using it. It is required // that any instance returned from a function call or passed to a function // call must not be in the default state. 
common_internal::StructTypeVariant variant_; }; inline bool operator==(const StructType& lhs, const StructType& rhs) { return static_cast(lhs) == static_cast(rhs) && (!static_cast(lhs) || lhs.name() == rhs.name()); } inline bool operator!=(const StructType& lhs, const StructType& rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, const StructType& type) { return H::combine(std::move(state), static_cast(type) ? type.name() : absl::string_view()); } inline std::ostream& operator<<(std::ostream& out, const StructType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_STRUCT_TYPE_H_ ================================================ FILE: common/types/struct_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "google/protobuf/descriptor.pb.h"
#include "absl/base/nullability.h"
#include "absl/hash/hash.h"
#include "absl/log/absl_check.h"
#include "absl/log/die_if_null.h"
#include "common/type.h"
#include "common/type_kind.h"
#include "internal/testing.h"
#include "google/protobuf/descriptor.h"

namespace cel {
namespace {

using ::testing::Test;

// Fixture that registers a single proto3 file declaring message `test.Struct`
// in a private descriptor pool. It exposes that message both as a
// `MessageType` and as a name-only `BasicStructType`, so the tests can check
// that the two alternatives of `StructType` agree.
class StructTypeTest : public Test {
 public:
  void SetUp() override {
    google::protobuf::FileDescriptorProto file_desc_proto;
    file_desc_proto.set_name("test/struct.proto");
    file_desc_proto.set_package("test");
    file_desc_proto.set_syntax("proto3");
    file_desc_proto.add_message_type()->set_name("Struct");
    ABSL_CHECK(pool_.BuildFile(file_desc_proto) != nullptr);
  }

  // Looks up the fixture message; aborts if it was not registered.
  const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const {
    return ABSL_DIE_IF_NULL(pool_.FindMessageTypeByName("test.Struct"));
  }

  MessageType GetMessageType() const { return MessageType(GetDescriptor()); }

  common_internal::BasicStructType GetBasicStructType() const {
    return common_internal::MakeBasicStructType("test.Struct");
  }

 private:
  google::protobuf::DescriptorPool pool_;
};

TEST(StructType, Kind) { EXPECT_EQ(StructType::kind(), TypeKind::kStruct); }

TEST_F(StructTypeTest, Name) {
  EXPECT_EQ(StructType(GetMessageType()).name(), GetMessageType().name());
  EXPECT_EQ(StructType(GetBasicStructType()).name(),
            GetBasicStructType().name());
}

TEST_F(StructTypeTest, DebugString) {
  EXPECT_EQ(StructType(GetMessageType()).DebugString(),
            GetMessageType().DebugString());
  EXPECT_EQ(StructType(GetBasicStructType()).DebugString(),
            GetBasicStructType().DebugString());
}

// Hashing ignores the backing alternative: same name, same hash.
TEST_F(StructTypeTest, Hash) {
  EXPECT_EQ(absl::HashOf(StructType(GetMessageType())),
            absl::HashOf(StructType(GetBasicStructType())));
}

// Equality ignores the backing alternative: same name, equal.
TEST_F(StructTypeTest, Equal) {
  EXPECT_EQ(StructType(GetMessageType()), StructType(GetBasicStructType()));
}

}  // namespace
}  // namespace cel

================================================
FILE: common/types/timestamp_type.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private, include "common/type.h"
// IWYU pragma: friend "common/type.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_TIMESTAMP_TYPE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_TIMESTAMP_TYPE_H_

#include
#include
#include

#include "absl/strings/string_view.h"
#include "common/type_kind.h"

namespace cel {

class Type;
class TypeParameters;

// `TimestampType` represents the primitive `timestamp` type.
//
// Its name is the protobuf well-known type `google.protobuf.Timestamp`. The
// class has no data members: every instance compares equal, and hashing
// contributes nothing to the hash state.
class TimestampType final {
 public:
  static constexpr TypeKind kKind = TypeKind::kTimestamp;
  static constexpr absl::string_view kName = "google.protobuf.Timestamp";

  TimestampType() = default;
  TimestampType(const TimestampType&) = default;
  TimestampType(TimestampType&&) = default;
  TimestampType& operator=(const TimestampType&) = default;
  TimestampType& operator=(TimestampType&&) = default;

  static TypeKind kind() { return kKind; }

  static absl::string_view name() { return kName; }

  static TypeParameters GetParameters();

  // Returns `kName` as an owned string.
  static std::string DebugString() { return std::string(name()); }
};

inline constexpr bool operator==(TimestampType, TimestampType) { return true; }

inline constexpr bool operator!=(TimestampType lhs, TimestampType rhs) {
  return !operator==(lhs, rhs);
}

template H AbslHashValue(H state, TimestampType) {
  // TimestampType is really a singleton and all instances are equal. Nothing to
  // hash.
  return std::move(state);
}

inline std::ostream& operator<<(std::ostream& out, const TimestampType& type) {
  return out << type.DebugString();
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_TYPES_TIMESTAMP_TYPE_H_

================================================
FILE: common/types/timestamp_type_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Tests for `TimestampType`: kind, name, stream output, hashing and equality,
// both bare and wrapped in `Type`.

#include

#include "absl/hash/hash.h"
#include "common/type.h"
#include "internal/testing.h"

namespace cel {
namespace {

TEST(TimestampType, Kind) {
  EXPECT_EQ(TimestampType().kind(), TimestampType::kKind);
  EXPECT_EQ(Type(TimestampType()).kind(), TimestampType::kKind);
}

TEST(TimestampType, Name) {
  EXPECT_EQ(TimestampType().name(), TimestampType::kName);
  EXPECT_EQ(Type(TimestampType()).name(), TimestampType::kName);
}

TEST(TimestampType, DebugString) {
  {
    std::ostringstream out;
    out << TimestampType();
    EXPECT_EQ(out.str(), TimestampType::kName);
  }
  {
    std::ostringstream out;
    out << Type(TimestampType());
    EXPECT_EQ(out.str(), TimestampType::kName);
  }
}

TEST(TimestampType, Hash) {
  EXPECT_EQ(absl::HashOf(TimestampType()), absl::HashOf(TimestampType()));
}

TEST(TimestampType, Equal) {
  EXPECT_EQ(TimestampType(), TimestampType());
  EXPECT_EQ(Type(TimestampType()), TimestampType());
  EXPECT_EQ(TimestampType(), Type(TimestampType()));
  EXPECT_EQ(Type(TimestampType()), Type(TimestampType()));
}

}  // namespace
}  // namespace cel
================================================
FILE: common/types/type_param_type.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private, include "common/type.h"
// IWYU pragma: friend "common/type.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_PARAM_TYPE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_PARAM_TYPE_H_

#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/strings/string_view.h"
#include "common/type_kind.h"

namespace cel {

class Type;
class TypeParameters;

// `TypeParamType` is a named type parameter. It stores only a view of its
// name; equality and hashing are based solely on that name.
class TypeParamType final {
 public:
  static constexpr TypeKind kKind = TypeKind::kTypeParam;

  // Stores `name` by view, not by copy: the referenced characters must outlive
  // this object (hence `ABSL_ATTRIBUTE_LIFETIME_BOUND`).
  explicit TypeParamType(absl::string_view name ABSL_ATTRIBUTE_LIFETIME_BOUND)
      : name_(name) {}

  // A default-constructed instance has an empty name.
  TypeParamType() = default;
  TypeParamType(const TypeParamType&) = default;
  TypeParamType(TypeParamType&&) = default;
  TypeParamType& operator=(const TypeParamType&) = default;
  TypeParamType& operator=(TypeParamType&&) = default;

  static TypeKind kind() { return kKind; }

  absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return name_; }

  static TypeParameters GetParameters();

  // Returns a copy of the name.
  std::string DebugString() const { return std::string(name()); }

 private:
  absl::string_view name_;
};

inline bool operator==(const TypeParamType& lhs, const TypeParamType& rhs) {
  return lhs.name() == rhs.name();
}

inline bool operator!=(const TypeParamType& lhs, const TypeParamType& rhs) {
  return !operator==(lhs, rhs);
}

template H AbslHashValue(H state, const TypeParamType& type) {
  return H::combine(std::move(state), type.name());
}

inline std::ostream& operator<<(std::ostream& out, const TypeParamType& type) {
  return out << type.DebugString();
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_PARAM_TYPE_H_

================================================
FILE: common/types/type_param_type_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/type.h"

#include <sstream>

#include "absl/hash/hash.h"
#include "internal/testing.h"

// Tests for `TypeParamType`. Each property (kind, name, debug string, hash,
// equality) is checked both on the concrete class and through the
// type-erased `Type` wrapper.

namespace cel {
namespace {

TEST(TypeParamType, Kind) {
  EXPECT_EQ(TypeParamType("T").kind(), TypeParamType::kKind);
  EXPECT_EQ(Type(TypeParamType("T")).kind(), TypeParamType::kKind);
}

TEST(TypeParamType, Name) {
  EXPECT_EQ(TypeParamType("T").name(), "T");
  EXPECT_EQ(Type(TypeParamType("T")).name(), "T");
}

TEST(TypeParamType, DebugString) {
  {
    std::ostringstream out;
    out << TypeParamType("T");
    EXPECT_EQ(out.str(), "T");
  }
  {
    std::ostringstream out;
    out << Type(TypeParamType("T"));
    EXPECT_EQ(out.str(), "T");
  }
}

TEST(TypeParamType, Hash) {
  EXPECT_EQ(absl::HashOf(TypeParamType("T")), absl::HashOf(TypeParamType("T")));
}

TEST(TypeParamType, Equal) {
  EXPECT_EQ(TypeParamType("T"), TypeParamType("T"));
  EXPECT_EQ(Type(TypeParamType("T")), TypeParamType("T"));
  EXPECT_EQ(TypeParamType("T"), Type(TypeParamType("T")));
  EXPECT_EQ(Type(TypeParamType("T")), Type(TypeParamType("T")));
}

}  // namespace
}  // namespace cel

================================================
FILE: common/types/type_pool.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/types/type_pool.h"

#include "absl/base/optimization.h"
#include "absl/log/absl_check.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "common/type.h"

namespace cel::common_internal {

// Resolves `name` against the descriptor pool first and yields a
// `MessageType` on a hit. An empty name yields a default `StructType`. Any
// other name becomes a basic struct type whose name is interned in this pool.
// Well-known message types must not be requested here (debug-checked).
StructType TypePool::MakeStructType(absl::string_view name) {
  ABSL_DCHECK(!IsWellKnownMessageType(name)) << name;
  if (ABSL_PREDICT_FALSE(name.empty())) {
    return StructType();
  }
  if (const auto* descriptor = descriptors_->FindMessageTypeByName(name);
      descriptor != nullptr) {
    return MessageType(descriptor);
  }
  return MakeBasicStructType(InternString(name));
}

// Interns a function type under `functions_mutex_`.
FunctionType TypePool::MakeFunctionType(const Type& result,
                                        absl::Span<const Type> args) {
  absl::MutexLock lock(functions_mutex_);
  return functions_.InternFunctionType(result, args);
}

// `list(dyn)` is returned as the default `ListType` without touching the pool.
ListType TypePool::MakeListType(const Type& element) {
  if (element.IsDyn()) {
    return ListType();
  }
  absl::MutexLock lock(lists_mutex_);
  return lists_.InternListType(element);
}

// `map(dyn, dyn)` and `map(string, dyn)` are answered with the predefined
// `MapType()` and `JsonMapType()` respectively; everything else is interned.
MapType TypePool::MakeMapType(const Type& key, const Type& value) {
  if (key.IsDyn() && value.IsDyn()) {
    return MapType();
  }
  if (key.IsString() && value.IsDyn()) {
    return JsonMapType();
  }
  absl::MutexLock lock(maps_mutex_);
  return maps_.InternMapType(key, value);
}

// The optional type's name is the static `OptionalType::kName`, so it is used
// directly instead of being copied into the string pool; `optional(dyn)`
// short-circuits to the default `OptionalType`. Other names are interned so
// the resulting type does not borrow the caller's string.
OpaqueType TypePool::MakeOpaqueType(absl::string_view name,
                                    absl::Span<const Type> parameters) {
  if (name == OptionalType::kName) {
    if (parameters.size() == 1 && parameters.front().IsDyn()) {
      return OptionalType();
    }
    name = OptionalType::kName;
  } else {
    name = InternString(name);
  }
  absl::MutexLock lock(opaques_mutex_);
  return opaques_.InternOpaqueType(name, parameters);
}

// Builds `optional(parameter)` through `MakeOpaqueType`, passing `parameter`
// as a one-element span.
OptionalType TypePool::MakeOptionalType(const Type& parameter) {
  return MakeOpaqueType(OptionalType::kName, absl::MakeConstSpan(&parameter, 1))
      .GetOptional();
}

// `TypeParamType` only borrows its name, so the name is interned first.
TypeParamType TypePool::MakeTypeParamType(absl::string_view name) {
  return TypeParamType(InternString(name));
}

// Interns a `type(type)` under `types_mutex_`.
TypeType TypePool::MakeTypeType(const Type& type) {
  absl::MutexLock lock(types_mutex_);
  return types_.InternTypeType(type);
}

// Interns `string` in `strings_` under `strings_mutex_` and returns the pooled
// view.
absl::string_view TypePool::InternString(absl::string_view string) {
  absl::MutexLock lock(strings_mutex_);
  return strings_.InternString(string);
}

}  // namespace cel::common_internal

================================================
FILE: common/types/type_pool.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private

#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_POOL_H_
#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_POOL_H_

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/base/thread_annotations.h"
#include "absl/log/die_if_null.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "common/type.h"
#include "common/types/function_type_pool.h"
#include "common/types/list_type_pool.h"
#include "common/types/map_type_pool.h"
#include "common/types/opaque_type_pool.h"
#include "common/types/type_type_pool.h"
#include "internal/string_pool.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"

namespace cel::common_internal {

// `TypePool` is a thread safe interning factory for complex types. All types
// are allocated using the provided `google::protobuf::Arena`.
//
// Both `descriptors` and `arena` are borrowed (lifetime-bound) and must be
// non-null; the constructor crashes on null.
class TypePool final {
 public:
  TypePool(const google::protobuf::DescriptorPool* absl_nonnull descriptors
               ABSL_ATTRIBUTE_LIFETIME_BOUND,
           google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND)
      : descriptors_(ABSL_DIE_IF_NULL(descriptors)),  // Crash OK
        arena_(ABSL_DIE_IF_NULL(arena)),              // Crash OK
        strings_(arena_),
        functions_(arena_),
        lists_(arena_),
        maps_(arena_),
        opaques_(arena_),
        types_(arena_) {}

  TypePool(const TypePool&) = delete;
  TypePool(TypePool&&) = delete;
  TypePool& operator=(const TypePool&) = delete;
  TypePool& operator=(TypePool&&) = delete;

  StructType MakeStructType(absl::string_view name);

  FunctionType MakeFunctionType(const Type& result,
                                absl::Span<const Type> args);

  ListType MakeListType(const Type& element);

  MapType MakeMapType(const Type& key, const Type& value);

  OpaqueType MakeOpaqueType(absl::string_view name,
                            absl::Span<const Type> parameters);

  OptionalType MakeOptionalType(const Type& parameter);

  TypeParamType MakeTypeParamType(absl::string_view name);

  TypeType MakeTypeType(const Type& type);

 private:
  absl::string_view InternString(absl::string_view string);

  const google::protobuf::DescriptorPool* absl_nonnull const descriptors_;
  google::protobuf::Arena* absl_nonnull const arena_;

  // Each sub-pool has its own mutex, so interning different kinds of types
  // does not contend on a single lock.
  absl::Mutex strings_mutex_;
  internal::StringPool strings_ ABSL_GUARDED_BY(strings_mutex_);
  absl::Mutex functions_mutex_;
  FunctionTypePool functions_ ABSL_GUARDED_BY(functions_mutex_);
  absl::Mutex lists_mutex_;
  ListTypePool lists_ ABSL_GUARDED_BY(lists_mutex_);
  absl::Mutex maps_mutex_;
  MapTypePool maps_ ABSL_GUARDED_BY(maps_mutex_);
  absl::Mutex opaques_mutex_;
  OpaqueTypePool opaques_ ABSL_GUARDED_BY(opaques_mutex_);
  absl::Mutex types_mutex_;
  TypeTypePool types_ ABSL_GUARDED_BY(types_mutex_);
};

}  // namespace cel::common_internal

#endif  // THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_POOL_H_

================================================
FILE: common/types/type_pool_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the
Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/types/type_pool.h" #include "common/type.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "google/protobuf/arena.h" namespace cel::common_internal { namespace { using ::cel::internal::GetTestingDescriptorPool; using ::testing::_; TEST(TypePool, MakeStructType) { google::protobuf::Arena arena; TypePool type_pool(GetTestingDescriptorPool(), &arena); EXPECT_EQ(type_pool.MakeStructType("foo.Bar"), MakeBasicStructType("foo.Bar")); EXPECT_TRUE( type_pool.MakeStructType("cel.expr.conformance.proto3.TestAllTypes") .IsMessage()); EXPECT_DEBUG_DEATH( static_cast(type_pool.MakeStructType("google.protobuf.BoolValue")), _); } TEST(TypePool, MakeFunctionType) { google::protobuf::Arena arena; TypePool type_pool(GetTestingDescriptorPool(), &arena); EXPECT_EQ(type_pool.MakeFunctionType(BoolType(), {IntType(), IntType()}), FunctionType(&arena, BoolType(), {IntType(), IntType()})); } TEST(TypePool, MakeListType) { google::protobuf::Arena arena; TypePool type_pool(GetTestingDescriptorPool(), &arena); EXPECT_EQ(type_pool.MakeListType(DynType()), ListType()); EXPECT_EQ(type_pool.MakeListType(DynType()), JsonListType()); EXPECT_EQ(type_pool.MakeListType(StringType()), ListType(&arena, StringType())); } TEST(TypePool, MakeMapType) { google::protobuf::Arena arena; TypePool type_pool(GetTestingDescriptorPool(), &arena); EXPECT_EQ(type_pool.MakeMapType(DynType(), DynType()), MapType()); 
EXPECT_EQ(type_pool.MakeMapType(StringType(), DynType()), JsonMapType()); EXPECT_EQ(type_pool.MakeMapType(StringType(), StringType()), MapType(&arena, StringType(), StringType())); } TEST(TypePool, MakeOpaqueType) { google::protobuf::Arena arena; TypePool type_pool(GetTestingDescriptorPool(), &arena); EXPECT_EQ(type_pool.MakeOpaqueType("custom_type", {DynType(), DynType()}), OpaqueType(&arena, "custom_type", {DynType(), DynType()})); } TEST(TypePool, MakeOptionalType) { google::protobuf::Arena arena; TypePool type_pool(GetTestingDescriptorPool(), &arena); EXPECT_EQ(type_pool.MakeOptionalType(DynType()), OptionalType()); EXPECT_EQ(type_pool.MakeOptionalType(StringType()), OptionalType(&arena, StringType())); } TEST(TypePool, MakeTypeParamType) { google::protobuf::Arena arena; TypePool type_pool(GetTestingDescriptorPool(), &arena); EXPECT_EQ(type_pool.MakeTypeParamType("T"), TypeParamType("T")); } TEST(TypePool, MakeTypeType) { google::protobuf::Arena arena; TypePool type_pool(GetTestingDescriptorPool(), &arena); EXPECT_EQ(type_pool.MakeTypeType(BoolType()), TypeType(&arena, BoolType())); } } // namespace } // namespace cel::common_internal ================================================ FILE: common/types/type_type.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/type.h"

#include <string>

#include "absl/base/nullability.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "common/type_kind.h"
#include "google/protobuf/arena.h"

namespace cel {

namespace common_internal {

// Arena-allocated payload of a parameterized `TypeType`: holds the single
// parameter type. Non-copyable and non-movable; instances are only created
// through `Create`.
struct TypeTypeData final {
  static TypeTypeData* Create(google::protobuf::Arena* absl_nonnull arena,
                              const Type& type) {
    return google::protobuf::Arena::Create<TypeTypeData>(arena, type);
  }

  explicit TypeTypeData(const Type& type) : type(type) {}

  TypeTypeData() = delete;
  TypeTypeData(const TypeTypeData&) = delete;
  TypeTypeData(TypeTypeData&&) = delete;
  TypeTypeData& operator=(const TypeTypeData&) = delete;
  TypeTypeData& operator=(TypeTypeData&&) = delete;

  const Type type;
};

}  // namespace common_internal

// Renders `type`, or `type(<kind>)` when parameterized. The suffix is the
// parameter's kind name (via `TypeKindToString`), not its full debug string.
std::string TypeType::DebugString() const {
  std::string s(name());
  if (!GetParameters().empty()) {
    absl::StrAppend(&s, "(", TypeKindToString(GetParameters().front().kind()),
                    ")");
  }
  return s;
}

TypeType::TypeType(google::protobuf::Arena* absl_nonnull arena, const Type& parameter)
    : TypeType(common_internal::TypeTypeData::Create(arena, parameter)) {}

// Exposes the stored parameter as a one-element span; empty when this
// instance is unparameterized (`data_` is null).
TypeParameters TypeType::GetParameters() const {
  if (data_) {
    return TypeParameters(absl::MakeConstSpan(&data_->type, 1));
  }
  return {};
}

// Returns the parameter type, or a default-constructed `Type` when this
// instance is unparameterized.
Type TypeType::GetType() const {
  if (data_) {
    return data_->type;
  }
  return Type();
}

}  // namespace cel

================================================
FILE: common/types/type_type.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_TYPE_H_ #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/strings/string_view.h" #include "common/type_kind.h" #include "google/protobuf/arena.h" namespace cel { class Type; class TypeParameters; namespace common_internal { struct TypeTypeData; } // namespace common_internal // `TypeType` is a special type which represents the type of a type. class TypeType final { public: static constexpr TypeKind kKind = TypeKind::kType; static constexpr absl::string_view kName = "type"; TypeType(google::protobuf::Arena* absl_nonnull arena, const Type& parameter); TypeType() = default; TypeType(const TypeType&) = default; TypeType(TypeType&&) = default; TypeType& operator=(const TypeType&) = default; TypeType& operator=(TypeType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } TypeParameters GetParameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND; std::string DebugString() const; Type GetType() const; private: explicit TypeType(const common_internal::TypeTypeData* absl_nullable data) : data_(data) {} const common_internal::TypeTypeData* absl_nullable data_ = nullptr; }; inline constexpr bool operator==(const TypeType&, const TypeType&) { return true; } inline constexpr bool operator!=(const TypeType& lhs, const TypeType& rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, const TypeType&) { // TypeType is really a singleton and all instances are equal. Nothing to // hash. 
return std::move(state); } inline std::ostream& operator<<(std::ostream& out, const TypeType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_TYPE_H_ ================================================ FILE: common/types/type_type_pool.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/types/type_type_pool.h" #include "common/type.h" namespace cel::common_internal { TypeType TypeTypePool::InternTypeType(const Type& type) { return *type_types_.lazy_emplace( type, [&](const auto& ctor) { ctor(TypeType(arena_, type)); }); } } // namespace cel::common_internal ================================================ FILE: common/types/type_type_pool.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_TYPE_POOL_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_TYPE_POOL_H_ #include #include "absl/base/nullability.h" #include "absl/container/flat_hash_set.h" #include "absl/hash/hash.h" #include "absl/log/absl_check.h" #include "absl/log/die_if_null.h" #include "common/type.h" #include "google/protobuf/arena.h" namespace cel::common_internal { // `TypeTypePool` is a thread unsafe interning factory for `TypeType`. class TypeTypePool final { public: explicit TypeTypePool(google::protobuf::Arena* absl_nonnull arena) : arena_(ABSL_DIE_IF_NULL(arena)) {} // Crash OK // Returns a `TypeType` which has the provided parameters, interning as // necessary. TypeType InternTypeType(const Type& type); private: struct Hasher { using is_transparent = void; size_t operator()(const TypeType& type_type) const { ABSL_DCHECK_EQ(type_type.GetParameters().size(), 1); return (*this)(type_type.GetParameters().front()); } size_t operator()(const Type& type) const { return absl::Hash{}(type); } }; struct Equaler { using is_transparent = void; bool operator()(const TypeType& lhs, const TypeType& rhs) const { ABSL_DCHECK_EQ(lhs.GetParameters().size(), 1); ABSL_DCHECK_EQ(rhs.GetParameters().size(), 1); return (*this)(lhs.GetParameters().front(), rhs.GetParameters().front()); } bool operator()(const TypeType& lhs, const Type& rhs) const { ABSL_DCHECK_EQ(lhs.GetParameters().size(), 1); return (*this)(lhs.GetParameters().front(), rhs); } bool operator()(const Type& lhs, const TypeType& rhs) const { ABSL_DCHECK_EQ(rhs.GetParameters().size(), 1); return (*this)(lhs, rhs.GetParameters().front()); } bool operator()(const Type& lhs, const Type& rhs) const { return lhs == rhs; } }; google::protobuf::Arena* absl_nonnull const arena_; absl::flat_hash_set type_types_; }; } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPE_TYPE_POOL_H_ ================================================ FILE: 
common/types/type_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/type.h" #include #include "absl/hash/hash.h" #include "internal/testing.h" namespace cel { namespace { TEST(TypeType, Kind) { EXPECT_EQ(TypeType().kind(), TypeType::kKind); EXPECT_EQ(Type(TypeType()).kind(), TypeType::kKind); } TEST(TypeType, Name) { EXPECT_EQ(TypeType().name(), TypeType::kName); EXPECT_EQ(Type(TypeType()).name(), TypeType::kName); } TEST(TypeType, DebugString) { { std::ostringstream out; out << TypeType(); EXPECT_EQ(out.str(), TypeType::kName); } { std::ostringstream out; out << Type(TypeType()); EXPECT_EQ(out.str(), TypeType::kName); } } TEST(TypeType, Hash) { EXPECT_EQ(absl::HashOf(TypeType()), absl::HashOf(TypeType())); } TEST(TypeType, Equal) { EXPECT_EQ(TypeType(), TypeType()); EXPECT_EQ(Type(TypeType()), TypeType()); EXPECT_EQ(TypeType(), Type(TypeType())); EXPECT_EQ(Type(TypeType()), Type(TypeType())); } } // namespace } // namespace cel ================================================ FILE: common/types/types.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPES_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPES_H_ #include #include "absl/meta/type_traits.h" #include "absl/types/variant.h" namespace cel { class Type; class AnyType; class BoolType; class BoolWrapperType; class BytesType; class BytesWrapperType; class DoubleType; class DoubleWrapperType; class DurationType; class DynType; class EnumType; class ErrorType; class FunctionType; class IntType; class IntWrapperType; class ListType; class MapType; class NullType; class OpaqueType; class OptionalType; class StringType; class StringWrapperType; class StructType; class MessageType; class TimestampType; class TypeParamType; class TypeType; class UintType; class UintWrapperType; class UnknownType; namespace common_internal { class BasicStructType; template > struct IsTypeAlternative : std::bool_constant, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same>> {}; template inline constexpr bool IsTypeAlternativeV = IsTypeAlternative::value; using TypeVariant = absl::variant; using StructTypeVariant = absl::variant; } // namespace common_internal } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_TYPES_H_ ================================================ FILE: 
common/types/uint_type.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_UINT_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_UINT_TYPE_H_ #include #include #include #include "absl/strings/string_view.h" #include "common/type_kind.h" namespace cel { class Type; class TypeParameters; // `UintType` represents the primitive `uint` type. class UintType final { public: static constexpr TypeKind kKind = TypeKind::kUint; static constexpr absl::string_view kName = "uint"; UintType() = default; UintType(const UintType&) = default; UintType(UintType&&) = default; UintType& operator=(const UintType&) = default; UintType& operator=(UintType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } static TypeParameters GetParameters(); static std::string DebugString() { return std::string(name()); } }; inline constexpr bool operator==(UintType, UintType) { return true; } inline constexpr bool operator!=(UintType lhs, UintType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, UintType) { // UintType is really a singleton and all instances are equal. Nothing to // hash. 
return std::move(state); } inline std::ostream& operator<<(std::ostream& out, const UintType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_UINT_TYPE_H_ ================================================ FILE: common/types/uint_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(UintType, Kind) { EXPECT_EQ(UintType().kind(), UintType::kKind); EXPECT_EQ(Type(UintType()).kind(), UintType::kKind); } TEST(UintType, Name) { EXPECT_EQ(UintType().name(), UintType::kName); EXPECT_EQ(Type(UintType()).name(), UintType::kName); } TEST(UintType, DebugString) { { std::ostringstream out; out << UintType(); EXPECT_EQ(out.str(), UintType::kName); } { std::ostringstream out; out << Type(UintType()); EXPECT_EQ(out.str(), UintType::kName); } } TEST(UintType, Hash) { EXPECT_EQ(absl::HashOf(UintType()), absl::HashOf(UintType())); } TEST(UintType, Equal) { EXPECT_EQ(UintType(), UintType()); EXPECT_EQ(Type(UintType()), UintType()); EXPECT_EQ(UintType(), Type(UintType())); EXPECT_EQ(Type(UintType()), Type(UintType())); } } // namespace } // namespace cel ================================================ FILE: common/types/uint_wrapper_type.h ================================================ // Copyright 2023 Google LLC // // 
Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_UINT_WRAPPER_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_UINT_WRAPPER_TYPE_H_ #include #include #include #include "absl/strings/string_view.h" #include "common/type_kind.h" namespace cel { class Type; class TypeParameters; // `UintWrapperType` is a special type which has no direct value // representation. It is used to represent `google.protobuf.UInt64Value`, which // never exists at runtime as a value. Its primary usage is for type checking // and unpacking at runtime. 
class UintWrapperType final { public: static constexpr TypeKind kKind = TypeKind::kUintWrapper; static constexpr absl::string_view kName = "google.protobuf.UInt64Value"; UintWrapperType() = default; UintWrapperType(const UintWrapperType&) = default; UintWrapperType(UintWrapperType&&) = default; UintWrapperType& operator=(const UintWrapperType&) = default; UintWrapperType& operator=(UintWrapperType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } static TypeParameters GetParameters(); static std::string DebugString() { return std::string(name()); } }; inline constexpr bool operator==(UintWrapperType, UintWrapperType) { return true; } inline constexpr bool operator!=(UintWrapperType lhs, UintWrapperType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, UintWrapperType) { // UintWrapperType is really a singleton and all instances are equal. Nothing // to hash. return std::move(state); } inline std::ostream& operator<<(std::ostream& out, const UintWrapperType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_UINT_WRAPPER_TYPE_H_ ================================================ FILE: common/types/uint_wrapper_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(UintWrapperType, Kind) { EXPECT_EQ(UintWrapperType().kind(), UintWrapperType::kKind); EXPECT_EQ(Type(UintWrapperType()).kind(), UintWrapperType::kKind); } TEST(UintWrapperType, Name) { EXPECT_EQ(UintWrapperType().name(), UintWrapperType::kName); EXPECT_EQ(Type(UintWrapperType()).name(), UintWrapperType::kName); } TEST(UintWrapperType, DebugString) { { std::ostringstream out; out << UintWrapperType(); EXPECT_EQ(out.str(), UintWrapperType::kName); } { std::ostringstream out; out << Type(UintWrapperType()); EXPECT_EQ(out.str(), UintWrapperType::kName); } } TEST(UintWrapperType, Hash) { EXPECT_EQ(absl::HashOf(UintWrapperType()), absl::HashOf(UintWrapperType())); } TEST(UintWrapperType, Equal) { EXPECT_EQ(UintWrapperType(), UintWrapperType()); EXPECT_EQ(Type(UintWrapperType()), UintWrapperType()); EXPECT_EQ(UintWrapperType(), Type(UintWrapperType())); EXPECT_EQ(Type(UintWrapperType()), Type(UintWrapperType())); } } // namespace } // namespace cel ================================================ FILE: common/types/unknown_type.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/type.h" // IWYU pragma: friend "common/type.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_UNKNOWN_TYPE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPES_UNKNOWN_TYPE_H_ #include #include #include #include "absl/strings/string_view.h" #include "common/type_kind.h" namespace cel { class Type; class TypeParameters; // `UnknownType` is a special type which represents an unknown at runtime. It // has no in-language representation. class UnknownType final { public: static constexpr TypeKind kKind = TypeKind::kUnknown; static constexpr absl::string_view kName = "*unknown*"; UnknownType() = default; UnknownType(const UnknownType&) = default; UnknownType(UnknownType&&) = default; UnknownType& operator=(const UnknownType&) = default; UnknownType& operator=(UnknownType&&) = default; static TypeKind kind() { return kKind; } static absl::string_view name() { return kName; } static TypeParameters GetParameters(); static std::string DebugString() { return std::string(name()); } }; inline constexpr bool operator==(UnknownType, UnknownType) { return true; } inline constexpr bool operator!=(UnknownType lhs, UnknownType rhs) { return !operator==(lhs, rhs); } template H AbslHashValue(H state, UnknownType) { // UnknownType is really a singleton and all instances are equal. Nothing to // hash. return std::move(state); } inline std::ostream& operator<<(std::ostream& out, const UnknownType& type) { return out << type.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_UNKNOWN_TYPE_H_ ================================================ FILE: common/types/unknown_type_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/hash/hash.h" #include "common/type.h" #include "internal/testing.h" namespace cel { namespace { TEST(UnknownType, Kind) { EXPECT_EQ(UnknownType().kind(), UnknownType::kKind); EXPECT_EQ(Type(UnknownType()).kind(), UnknownType::kKind); } TEST(UnknownType, Name) { EXPECT_EQ(UnknownType().name(), UnknownType::kName); EXPECT_EQ(Type(UnknownType()).name(), UnknownType::kName); } TEST(UnknownType, DebugString) { { std::ostringstream out; out << UnknownType(); EXPECT_EQ(out.str(), UnknownType::kName); } { std::ostringstream out; out << Type(UnknownType()); EXPECT_EQ(out.str(), UnknownType::kName); } } TEST(UnknownType, Hash) { EXPECT_EQ(absl::HashOf(UnknownType()), absl::HashOf(UnknownType())); } TEST(UnknownType, Equal) { EXPECT_EQ(UnknownType(), UnknownType()); EXPECT_EQ(Type(UnknownType()), UnknownType()); EXPECT_EQ(UnknownType(), Type(UnknownType())); EXPECT_EQ(Type(UnknownType()), Type(UnknownType())); } } // namespace } // namespace cel ================================================ FILE: common/unknown.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_UNKNOWN_H_ #define THIRD_PARTY_CEL_CPP_COMMON_UNKNOWN_H_ #include "base/internal/unknown_set.h" namespace cel { // `Unknown` is a collection of unknown attributes and function results. using Unknown = base_internal::UnknownSet; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_UNKNOWN_H_ ================================================ FILE: common/value.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/value.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>

#include "google/protobuf/struct.pb.h"
#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/functional/overload.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/meta/type_traits.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "common/allocator.h"
#include "common/memory.h"
#include "common/optional_ref.h"
#include "common/type.h"
#include "common/value_kind.h"
#include "common/values/list_value_builder.h"
#include "common/values/map_value_builder.h"
#include "common/values/struct_value_builder.h"
#include "common/values/values.h"
#include "internal/number.h"
#include "internal/protobuf_runtime_version.h"
#include "internal/status_macros.h"
#include "internal/well_known_types.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

#undef GetMessage

namespace cel {

namespace {

// Returns the arena `message` was allocated on, falling back to `or_arena`
// when `message->GetArena()` is null (heap-allocated message).
google::protobuf::Arena* absl_nonnull MessageArenaOr(
    const google::protobuf::Message* absl_nonnull message,
    google::protobuf::Arena* absl_nonnull or_arena) {
  google::protobuf::Arena* absl_nullable arena = message->GetArena();
  if (arena == nullptr) {
    arena = or_arena;
  }
  return arena;
}

}  // namespace

// Maps the value's kind to its runtime type. Struct and opaque values
// delegate to the held alternative; any kind not listed yields a
// default-constructed `cel::Type`.
Type Value::GetRuntimeType() const {
  switch (kind()) {
    case ValueKind::kNull:
      return NullType();
    case ValueKind::kBool:
      return BoolType();
    case ValueKind::kInt:
      return IntType();
    case ValueKind::kUint:
      return UintType();
    case ValueKind::kDouble:
      return DoubleType();
    case ValueKind::kString:
      return StringType();
    case ValueKind::kBytes:
      return BytesType();
    case ValueKind::kStruct:
      return this->GetStruct().GetRuntimeType();
    case ValueKind::kDuration:
      return DurationType();
    case ValueKind::kTimestamp:
      return TimestampType();
    case ValueKind::kList:
      return ListType();
    case ValueKind::kMap:
      return MapType();
    case ValueKind::kUnknown:
      return UnknownType();
    case ValueKind::kType:
      return TypeType();
    case ValueKind::kError:
      return ErrorType();
    case ValueKind::kOpaque:
      return this->GetOpaque().GetRuntimeType();
    default:
      return cel::Type();
  }
}

namespace {

// `IsMonostate<T>::value` is true when `T`, after stripping cv- and
// reference-qualifiers, is `absl::monostate`. This lets generic visitors
// accept `const auto&` alternatives.
template <typename T>
struct IsMonostate : std::is_same<absl::remove_cvref_t<T>, absl::monostate> {};

}  // namespace

// Forwards to `GetTypeName()` of whichever alternative `variant_` holds.
absl::string_view Value::GetTypeName() const {
  return variant_.Visit([](const auto& alternative) -> absl::string_view {
    return alternative.GetTypeName();
  });
}

// Forwards to `DebugString()` of whichever alternative `variant_` holds.
std::string Value::DebugString() const {
  return variant_.Visit([](const auto& alternative) -> std::string {
    return alternative.DebugString();
  });
}

// Serializes the held alternative to `output`. The pointer arguments must be
// non-null; this is enforced only by debug checks.
absl::Status Value::SerializeTo(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(output != nullptr);

  return variant_.Visit([&](const auto& alternative) -> absl::Status {
    return alternative.SerializeTo(descriptor_pool, message_factory, output);
  });
}

// Converts the held alternative into `json`. In debug builds, `json` must be a
// `google.protobuf.Value` message (checked via its well-known type).
absl::Status Value::ConvertToJson(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(json != nullptr);
  ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
                 google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE);

  return variant_.Visit([descriptor_pool, message_factory,
                         json](const auto& alternative) -> absl::Status {
    return alternative.ConvertToJson(descriptor_pool, message_factory, json);
  });
}
absl::Status Value::ConvertToJsonArray( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); return variant_.Visit(absl::Overload( [](absl::monostate) -> absl::Status { return absl::InternalError("use of invalid Value"); }, [descriptor_pool, message_factory, json]( const common_internal::LegacyListValue& alternative) -> absl::Status { return alternative.ConvertToJsonArray(descriptor_pool, message_factory, json); }, [descriptor_pool, message_factory, json](const CustomListValue& alternative) -> absl::Status { return alternative.ConvertToJsonArray(descriptor_pool, message_factory, json); }, [descriptor_pool, message_factory, json](const ParsedRepeatedFieldValue& alternative) -> absl::Status { return alternative.ConvertToJsonArray(descriptor_pool, message_factory, json); }, [descriptor_pool, message_factory, json](const ParsedJsonListValue& alternative) -> absl::Status { return alternative.ConvertToJsonArray(descriptor_pool, message_factory, json); }, [](const auto& alternative) -> absl::Status { return TypeConversionError(alternative.GetTypeName(), "google.protobuf.ListValue") .NativeValue(); })); } absl::Status Value::ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); return variant_.Visit(absl::Overload( [](absl::monostate) -> absl::Status 
{ return absl::InternalError("use of invalid Value"); }, [descriptor_pool, message_factory, json]( const common_internal::LegacyMapValue& alternative) -> absl::Status { return alternative.ConvertToJsonObject(descriptor_pool, message_factory, json); }, [descriptor_pool, message_factory, json](const CustomMapValue& alternative) -> absl::Status { return alternative.ConvertToJsonObject(descriptor_pool, message_factory, json); }, [descriptor_pool, message_factory, json](const ParsedMapFieldValue& alternative) -> absl::Status { return alternative.ConvertToJsonObject(descriptor_pool, message_factory, json); }, [descriptor_pool, message_factory, json](const ParsedJsonMapValue& alternative) -> absl::Status { return alternative.ConvertToJsonObject(descriptor_pool, message_factory, json); }, [descriptor_pool, message_factory, json](const common_internal::LegacyStructValue& alternative) -> absl::Status { return alternative.ConvertToJsonObject(descriptor_pool, message_factory, json); }, [descriptor_pool, message_factory, json](const CustomStructValue& alternative) -> absl::Status { return alternative.ConvertToJsonObject(descriptor_pool, message_factory, json); }, [descriptor_pool, message_factory, json](const ParsedMessageValue& alternative) -> absl::Status { return alternative.ConvertToJsonObject(descriptor_pool, message_factory, json); }, [](const auto& alternative) -> absl::Status { return TypeConversionError(alternative.GetTypeName(), "google.protobuf.Struct") .NativeValue(); })); } absl::Status Value::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); return variant_.Visit([&other, descriptor_pool, message_factory, arena, result](const auto& 
alternative) -> absl::Status { return alternative.Equal(other, descriptor_pool, message_factory, arena, result); }); } bool Value::IsZeroValue() const { return variant_.Visit([](const auto& alternative) -> bool { return alternative.IsZeroValue(); }); } namespace { template struct HasCloneMethod : std::false_type {}; template struct HasCloneMethod().Clone( std::declval()))>> : std::true_type {}; } // namespace Value Value::Clone(google::protobuf::Arena* absl_nonnull arena) const { return variant_.Visit([arena](const auto& alternative) -> Value { if constexpr (IsMonostate::value) { return Value(); } else if constexpr (HasCloneMethod>::value) { return alternative.Clone(arena); } else { return alternative; } }); } std::ostream& operator<<(std::ostream& out, const Value& value) { return value.variant_.Visit([&out](const auto& alternative) -> std::ostream& { return out << alternative; }); } namespace { Value NonNullEnumValue(const google::protobuf::EnumValueDescriptor* absl_nonnull value) { ABSL_DCHECK(value != nullptr); return IntValue(value->number()); } Value NonNullEnumValue(const google::protobuf::EnumDescriptor* absl_nonnull type, int32_t number) { ABSL_DCHECK(type != nullptr); if (type->is_closed()) { if (ABSL_PREDICT_FALSE(type->FindValueByNumber(number) == nullptr)) { return ErrorValue(absl::InvalidArgumentError(absl::StrCat( "closed enum has no such value: ", type->full_name(), ".", number))); } } return IntValue(number); } } // namespace Value Value::Enum(const google::protobuf::EnumValueDescriptor* absl_nonnull value) { ABSL_DCHECK(value != nullptr); if (value->type()->full_name() == "google.protobuf.NullValue") { ABSL_DCHECK_EQ(value->number(), 0); return NullValue(); } return NonNullEnumValue(value); } Value Value::Enum(const google::protobuf::EnumDescriptor* absl_nonnull type, int32_t number) { ABSL_DCHECK(type != nullptr); if (type->full_name() == "google.protobuf.NullValue") { ABSL_DCHECK_EQ(number, 0); return NullValue(); } return NonNullEnumValue(type, 
number); } namespace common_internal { namespace { void BoolMapFieldKeyAccessor(const google::protobuf::MapKey& key, const google::protobuf::Message* absl_nonnull message, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); *result = BoolValue(key.GetBoolValue()); } void Int32MapFieldKeyAccessor(const google::protobuf::MapKey& key, const google::protobuf::Message* absl_nonnull message, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); *result = IntValue(key.GetInt32Value()); } void Int64MapFieldKeyAccessor(const google::protobuf::MapKey& key, const google::protobuf::Message* absl_nonnull message, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); *result = IntValue(key.GetInt64Value()); } void UInt32MapFieldKeyAccessor(const google::protobuf::MapKey& key, const google::protobuf::Message* absl_nonnull message, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); *result = UintValue(key.GetUInt32Value()); } void UInt64MapFieldKeyAccessor(const google::protobuf::MapKey& key, const google::protobuf::Message* absl_nonnull message, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); *result = UintValue(key.GetUInt64Value()); } void StringMapFieldKeyAccessor(const google::protobuf::MapKey& key, const google::protobuf::Message* absl_nonnull message, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(arena != nullptr); 
ABSL_DCHECK(result != nullptr); #if CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(5, 30, 0) *result = StringValue(Borrower::Arena(MessageArenaOr(message, arena)), key.GetStringValue()); #else *result = StringValue(arena, key.GetStringValue()); #endif } } // namespace absl::StatusOr MapFieldKeyAccessorFor( const google::protobuf::FieldDescriptor* absl_nonnull field) { switch (field->cpp_type()) { case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: return &BoolMapFieldKeyAccessor; case google::protobuf::FieldDescriptor::CPPTYPE_INT32: return &Int32MapFieldKeyAccessor; case google::protobuf::FieldDescriptor::CPPTYPE_INT64: return &Int64MapFieldKeyAccessor; case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: return &UInt32MapFieldKeyAccessor; case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: return &UInt64MapFieldKeyAccessor; case google::protobuf::FieldDescriptor::CPPTYPE_STRING: return &StringMapFieldKeyAccessor; default: return absl::InvalidArgumentError( absl::StrCat("unexpected map key type: ", field->cpp_type_name())); } } namespace { void DoubleMapFieldValueAccessor( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK(!field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE); *result = DoubleValue(value.GetDoubleValue()); } void FloatMapFieldValueAccessor( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull message, const 
google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK(!field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_FLOAT); *result = DoubleValue(value.GetFloatValue()); } void Int64MapFieldValueAccessor( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK(!field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_INT64); *result = IntValue(value.GetInt64Value()); } void UInt64MapFieldValueAccessor( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != 
nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK(!field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_UINT64); *result = UintValue(value.GetUInt64Value()); } void Int32MapFieldValueAccessor( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK(!field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_INT32); *result = IntValue(value.GetInt32Value()); } void UInt32MapFieldValueAccessor( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK(!field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_UINT32); *result = UintValue(value.GetUInt32Value()); } void BoolMapFieldValueAccessor( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const 
google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK(!field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_BOOL); *result = BoolValue(value.GetBoolValue()); } void StringMapFieldValueAccessor( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK(!field->is_repeated()); ABSL_DCHECK_EQ(field->type(), google::protobuf::FieldDescriptor::TYPE_STRING); if (message->GetArena() == nullptr) { *result = StringValue(arena, value.GetStringValue()); } else { *result = StringValue(Borrower::Arena(arena), value.GetStringValue()); } } void MessageMapFieldValueAccessor( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); 
ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK(!field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE); *result = Value::WrapMessage(&value.GetMessageValue(), descriptor_pool, message_factory, arena); } void BytesMapFieldValueAccessor( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK(!field->is_repeated()); ABSL_DCHECK_EQ(field->type(), google::protobuf::FieldDescriptor::TYPE_BYTES); if (message->GetArena() == nullptr) { *result = BytesValue(arena, value.GetStringValue()); } else { *result = BytesValue(Borrower::Arena(arena), value.GetStringValue()); } } void EnumMapFieldValueAccessor( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK(!field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_ENUM); *result = 
NonNullEnumValue(field->enum_type(), value.GetEnumValue()); } void NullMapFieldValueAccessor( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK(!field->is_repeated()); ABSL_DCHECK(field->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_ENUM && field->enum_type()->full_name() == "google.protobuf.NullValue"); *result = NullValue(); } } // namespace absl::StatusOr MapFieldValueAccessorFor( const google::protobuf::FieldDescriptor* absl_nonnull field) { switch (field->type()) { case google::protobuf::FieldDescriptor::TYPE_DOUBLE: return &DoubleMapFieldValueAccessor; case google::protobuf::FieldDescriptor::TYPE_FLOAT: return &FloatMapFieldValueAccessor; case google::protobuf::FieldDescriptor::TYPE_SFIXED64: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_SINT64: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_INT64: return &Int64MapFieldValueAccessor; case google::protobuf::FieldDescriptor::TYPE_FIXED64: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_UINT64: return &UInt64MapFieldValueAccessor; case google::protobuf::FieldDescriptor::TYPE_SFIXED32: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_SINT32: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_INT32: return &Int32MapFieldValueAccessor; case google::protobuf::FieldDescriptor::TYPE_BOOL: return &BoolMapFieldValueAccessor; case 
google::protobuf::FieldDescriptor::TYPE_STRING: return &StringMapFieldValueAccessor; case google::protobuf::FieldDescriptor::TYPE_GROUP: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_MESSAGE: return &MessageMapFieldValueAccessor; case google::protobuf::FieldDescriptor::TYPE_BYTES: return &BytesMapFieldValueAccessor; case google::protobuf::FieldDescriptor::TYPE_FIXED32: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_UINT32: return &UInt32MapFieldValueAccessor; case google::protobuf::FieldDescriptor::TYPE_ENUM: if (field->enum_type()->full_name() == "google.protobuf.NullValue") { return &NullMapFieldValueAccessor; } return &EnumMapFieldValueAccessor; default: return absl::InvalidArgumentError( absl::StrCat("unexpected protocol buffer message field type: ", field->type_name())); } } namespace { void DoubleRepeatedFieldAccessor( int index, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::Reflection* absl_nonnull reflection, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(reflection != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK_EQ(reflection, message->GetReflection()); ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); ABSL_DCHECK(field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE); ABSL_DCHECK_GE(index, 0); ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); *result = DoubleValue(reflection->GetRepeatedDouble(*message, field, index)); } void FloatRepeatedFieldAccessor( int index, const 
google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::Reflection* absl_nonnull reflection, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(reflection != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK_EQ(reflection, message->GetReflection()); ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); ABSL_DCHECK(field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_FLOAT); ABSL_DCHECK_GE(index, 0); ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); *result = DoubleValue(reflection->GetRepeatedFloat(*message, field, index)); } void Int64RepeatedFieldAccessor( int index, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::Reflection* absl_nonnull reflection, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(reflection != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK_EQ(reflection, message->GetReflection()); ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); ABSL_DCHECK(field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_INT64); ABSL_DCHECK_GE(index, 0); ABSL_DCHECK_LT(index, 
reflection->FieldSize(*message, field)); *result = IntValue(reflection->GetRepeatedInt64(*message, field, index)); } void UInt64RepeatedFieldAccessor( int index, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::Reflection* absl_nonnull reflection, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(reflection != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK_EQ(reflection, message->GetReflection()); ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); ABSL_DCHECK(field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_UINT64); ABSL_DCHECK_GE(index, 0); ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); *result = UintValue(reflection->GetRepeatedUInt64(*message, field, index)); } void Int32RepeatedFieldAccessor( int index, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::Reflection* absl_nonnull reflection, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(reflection != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK_EQ(reflection, message->GetReflection()); ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); 
ABSL_DCHECK(field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_INT32); ABSL_DCHECK_GE(index, 0); ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); *result = IntValue(reflection->GetRepeatedInt32(*message, field, index)); } void UInt32RepeatedFieldAccessor( int index, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::Reflection* absl_nonnull reflection, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(reflection != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK_EQ(reflection, message->GetReflection()); ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); ABSL_DCHECK(field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_UINT32); ABSL_DCHECK_GE(index, 0); ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); *result = UintValue(reflection->GetRepeatedUInt32(*message, field, index)); } void BoolRepeatedFieldAccessor( int index, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::Reflection* absl_nonnull reflection, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(reflection != nullptr); ABSL_DCHECK(arena 
!= nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK_EQ(reflection, message->GetReflection()); ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); ABSL_DCHECK(field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_BOOL); ABSL_DCHECK_GE(index, 0); ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); *result = BoolValue(reflection->GetRepeatedBool(*message, field, index)); } void StringRepeatedFieldAccessor( int index, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::Reflection* absl_nonnull reflection, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(reflection != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK_EQ(reflection, message->GetReflection()); ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); ABSL_DCHECK(field->is_repeated()); ABSL_DCHECK_EQ(field->type(), google::protobuf::FieldDescriptor::TYPE_STRING); ABSL_DCHECK_GE(index, 0); ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); std::string scratch; absl::visit( absl::Overload( [&](absl::string_view string) { if (string.data() == scratch.data() && string.size() == scratch.size()) { *result = StringValue(arena, std::move(scratch)); } else { if (message->GetArena() == nullptr) { *result = StringValue(arena, string); } else { *result = StringValue(Borrower::Arena(arena), string); } } }, [&](absl::Cord&& cord) { *result = StringValue(std::move(cord)); }), well_known_types::AsVariant(well_known_types::GetRepeatedStringField( *message, field, index, scratch))); } void 
MessageRepeatedFieldAccessor( int index, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::Reflection* absl_nonnull reflection, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(reflection != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK_EQ(reflection, message->GetReflection()); ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); ABSL_DCHECK(field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE); ABSL_DCHECK_GE(index, 0); ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); *result = Value::WrapMessage( &reflection->GetRepeatedMessage(*message, field, index), descriptor_pool, message_factory, arena); } void BytesRepeatedFieldAccessor( int index, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::Reflection* absl_nonnull reflection, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(reflection != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK_EQ(reflection, message->GetReflection()); ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); ABSL_DCHECK(field->is_repeated()); ABSL_DCHECK_EQ(field->type(), 
google::protobuf::FieldDescriptor::TYPE_BYTES); ABSL_DCHECK_GE(index, 0); ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); std::string scratch; absl::visit( absl::Overload( [&](absl::string_view string) { if (string.data() == scratch.data() && string.size() == scratch.size()) { *result = BytesValue(arena, std::move(scratch)); } else { if (message->GetArena() == nullptr) { *result = BytesValue(arena, string); } else { *result = BytesValue(Borrower::Arena(arena), string); } } }, [&](absl::Cord&& cord) { *result = BytesValue(std::move(cord)); }), well_known_types::AsVariant(well_known_types::GetRepeatedBytesField( *message, field, index, scratch))); } void EnumRepeatedFieldAccessor( int index, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::Reflection* absl_nonnull reflection, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(reflection != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK_EQ(reflection, message->GetReflection()); ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); ABSL_DCHECK(field->is_repeated()); ABSL_DCHECK_EQ(field->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_ENUM); ABSL_DCHECK_GE(index, 0); ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); *result = NonNullEnumValue( field->enum_type(), reflection->GetRepeatedEnumValue(*message, field, index)); } void NullRepeatedFieldAccessor( int index, const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const google::protobuf::Reflection* absl_nonnull 
reflection, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(reflection != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK_EQ(reflection, message->GetReflection()); ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); ABSL_DCHECK(field->is_repeated()); ABSL_DCHECK(field->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_ENUM && field->enum_type()->full_name() == "google.protobuf.NullValue"); ABSL_DCHECK_GE(index, 0); ABSL_DCHECK_LT(index, reflection->FieldSize(*message, field)); *result = NullValue(); } } // namespace absl::StatusOr RepeatedFieldAccessorFor( const google::protobuf::FieldDescriptor* absl_nonnull field) { switch (field->type()) { case google::protobuf::FieldDescriptor::TYPE_DOUBLE: return &DoubleRepeatedFieldAccessor; case google::protobuf::FieldDescriptor::TYPE_FLOAT: return &FloatRepeatedFieldAccessor; case google::protobuf::FieldDescriptor::TYPE_SFIXED64: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_SINT64: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_INT64: return &Int64RepeatedFieldAccessor; case google::protobuf::FieldDescriptor::TYPE_FIXED64: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_UINT64: return &UInt64RepeatedFieldAccessor; case google::protobuf::FieldDescriptor::TYPE_SFIXED32: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_SINT32: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_INT32: return &Int32RepeatedFieldAccessor; case google::protobuf::FieldDescriptor::TYPE_BOOL: return &BoolRepeatedFieldAccessor; case 
google::protobuf::FieldDescriptor::TYPE_STRING: return &StringRepeatedFieldAccessor; case google::protobuf::FieldDescriptor::TYPE_GROUP: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_MESSAGE: return &MessageRepeatedFieldAccessor; case google::protobuf::FieldDescriptor::TYPE_BYTES: return &BytesRepeatedFieldAccessor; case google::protobuf::FieldDescriptor::TYPE_FIXED32: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_UINT32: return &UInt32RepeatedFieldAccessor; case google::protobuf::FieldDescriptor::TYPE_ENUM: if (field->enum_type()->full_name() == "google.protobuf.NullValue") { return &NullRepeatedFieldAccessor; } return &EnumRepeatedFieldAccessor; default: return absl::InvalidArgumentError( absl::StrCat("unexpected protocol buffer message field type: ", field->type_name())); } } } // namespace common_internal namespace { // Overloads for `well_known_types::Value` which handles the primitive values // which require no special handling based on allocators. 
Value VistWellKnownTypeValue(std::nullptr_t) { return NullValue(); } Value VistWellKnownTypeValue(bool value) { return BoolValue(value); } Value VistWellKnownTypeValue(int32_t value) { return IntValue(value); } Value VistWellKnownTypeValue(int64_t value) { return IntValue(value); } Value VistWellKnownTypeValue(uint32_t value) { return UintValue(value); } Value VistWellKnownTypeValue(uint64_t value) { return UintValue(value); } Value VistWellKnownTypeValue(float value) { return DoubleValue(value); } Value VistWellKnownTypeValue(double value) { return DoubleValue(value); } Value VistWellKnownTypeValue(absl::Duration value) { return DurationValue(value); } Value VistWellKnownTypeValue(absl::Time value) { return TimestampValue(value); } struct OwningWellKnownTypesValueVisitor { google::protobuf::Arena* absl_nullable arena; std::string* absl_nonnull scratch; Value operator()(well_known_types::BytesValue&& value) const { return absl::visit(absl::Overload( [&](absl::string_view string) -> BytesValue { if (string.empty()) { return BytesValue(); } if (scratch->data() == string.data() && scratch->size() == string.size()) { return BytesValue(arena, std::move(*scratch)); } return BytesValue(arena, string); }, [&](absl::Cord&& cord) -> BytesValue { if (cord.empty()) { return BytesValue(); } return BytesValue(arena, cord); }), well_known_types::AsVariant(std::move(value))); } Value operator()(well_known_types::StringValue&& value) const { return absl::visit(absl::Overload( [&](absl::string_view string) -> StringValue { if (string.empty()) { return StringValue(); } if (scratch->data() == string.data() && scratch->size() == string.size()) { return StringValue(arena, std::move(*scratch)); } return StringValue(arena, string); }, [&](absl::Cord&& cord) -> StringValue { if (cord.empty()) { return StringValue(); } return StringValue(arena, cord); }), well_known_types::AsVariant(std::move(value))); } Value operator()(well_known_types::ListValue&& value) const { return absl::visit( 
absl::Overload( [&](well_known_types::ListValueConstRef value) -> ListValue { auto* cloned = value.get().New(arena); cloned->CopyFrom(value.get()); return ParsedJsonListValue(cloned, arena); }, [&](well_known_types::ListValuePtr value) -> ListValue { if (value->GetArena() != arena) { auto* cloned = value->New(arena); cloned->CopyFrom(*value); return ParsedJsonListValue(cloned, arena); } return ParsedJsonListValue(value.release(), arena); }), well_known_types::AsVariant(std::move(value))); } Value operator()(well_known_types::Struct&& value) const { return absl::visit( absl::Overload( [&](well_known_types::StructConstRef value) -> MapValue { auto* cloned = value.get().New(arena); cloned->CopyFrom(value.get()); return ParsedJsonMapValue(cloned, arena); }, [&](well_known_types::StructPtr value) -> MapValue { if (value.arena() != arena) { auto* cloned = value->New(arena); cloned->CopyFrom(*value); return ParsedJsonMapValue(cloned, arena); } return ParsedJsonMapValue(value.release(), arena); }), well_known_types::AsVariant(std::move(value))); } Value operator()(Unique value) const { if (value->GetArena() != arena) { auto* cloned = value->New(arena); cloned->CopyFrom(*value); return ParsedMessageValue(cloned, arena); } return ParsedMessageValue(value.release(), arena); } template Value operator()(T t) const { return VistWellKnownTypeValue(t); } }; struct BorrowingWellKnownTypesValueVisitor { const google::protobuf::Message* absl_nonnull message; google::protobuf::Arena* absl_nonnull arena; std::string* absl_nonnull scratch; Value operator()(well_known_types::BytesValue&& value) const { return absl::visit( absl::Overload( [&](absl::string_view string) -> BytesValue { if (string.data() == scratch->data() && string.size() == scratch->size()) { return BytesValue(arena, std::move(*scratch)); } else { return BytesValue( Borrower::Arena(MessageArenaOr(message, arena)), string); } }, [&](absl::Cord&& cord) -> BytesValue { return BytesValue(std::move(cord)); }), 
well_known_types::AsVariant(std::move(value))); } Value operator()(well_known_types::StringValue&& value) const { return absl::visit( absl::Overload( [&](absl::string_view string) -> StringValue { if (string.data() == scratch->data() && string.size() == scratch->size()) { return StringValue(arena, std::move(*scratch)); } else { return StringValue( Borrower::Arena(MessageArenaOr(message, arena)), string); } }, [&](absl::Cord&& cord) -> StringValue { return StringValue(std::move(cord)); }), well_known_types::AsVariant(std::move(value))); } Value operator()(well_known_types::ListValue&& value) const { return absl::visit( absl::Overload( [&](well_known_types::ListValueConstRef value) -> ParsedJsonListValue { return ParsedJsonListValue(&value.get(), MessageArenaOr(&value.get(), arena)); }, [&](well_known_types::ListValuePtr value) -> ParsedJsonListValue { if (value->GetArena() != arena) { auto* cloned = value->New(arena); cloned->CopyFrom(*value); return ParsedJsonListValue(cloned, arena); } return ParsedJsonListValue(value.release(), arena); }), well_known_types::AsVariant(std::move(value))); } Value operator()(well_known_types::Struct&& value) const { return absl::visit( absl::Overload( [&](well_known_types::StructConstRef value) -> ParsedJsonMapValue { return ParsedJsonMapValue(&value.get(), MessageArenaOr(&value.get(), arena)); }, [&](well_known_types::StructPtr value) -> ParsedJsonMapValue { if (value->GetArena() != arena) { auto* cloned = value->New(arena); cloned->CopyFrom(*value); return ParsedJsonMapValue(cloned, arena); } return ParsedJsonMapValue(value.release(), arena); }), well_known_types::AsVariant(std::move(value))); } Value operator()(Unique&& value) const { if (value->GetArena() != arena) { auto* cloned = value->New(arena); cloned->CopyFrom(*value); return ParsedMessageValue(cloned, arena); } return ParsedMessageValue(value.release(), arena); } template Value operator()(T t) const { return VistWellKnownTypeValue(t); } }; } // namespace Value 
Value::FromMessage( const google::protobuf::Message& message, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); std::string scratch; auto status_or_adapted = well_known_types::AdaptFromMessage( arena, message, descriptor_pool, message_factory, scratch); if (ABSL_PREDICT_FALSE(!status_or_adapted.ok())) { return ErrorValue(std::move(status_or_adapted).status()); } return absl::visit( absl::Overload(OwningWellKnownTypesValueVisitor{ /* .arena = */ arena, /* .scratch = */ &scratch}, [&](absl::monostate) -> Value { auto* cloned = message.New(arena); cloned->CopyFrom(message); return ParsedMessageValue(cloned, arena); }), std::move(status_or_adapted).value()); } Value Value::FromMessage( google::protobuf::Message&& message, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); std::string scratch; auto status_or_adapted = well_known_types::AdaptFromMessage( arena, message, descriptor_pool, message_factory, scratch); if (ABSL_PREDICT_FALSE(!status_or_adapted.ok())) { return ErrorValue(std::move(status_or_adapted).status()); } return absl::visit( absl::Overload(OwningWellKnownTypesValueVisitor{ /* .arena = */ arena, /* .scratch = */ &scratch}, [&](absl::monostate) -> Value { auto* cloned = message.New(arena); cloned->GetReflection()->Swap(cloned, &message); return ParsedMessageValue(cloned, arena); }), 
std::move(status_or_adapted).value()); } Value Value::WrapMessage( const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); std::string scratch; absl::StatusOr adapted_value = well_known_types::AdaptFromMessage(arena, *message, descriptor_pool, message_factory, scratch); if (ABSL_PREDICT_FALSE(!adapted_value.ok())) { return ErrorValue(std::move(adapted_value).status()); } return absl::visit( absl::Overload(BorrowingWellKnownTypesValueVisitor{ /* .message = */ message, /* .arena = */ arena, /* .scratch = */ &scratch}, [&](absl::monostate) -> Value { if (message->GetArena() != arena) { auto* cloned = message->New(arena); cloned->CopyFrom(*message); return ParsedMessageValue(cloned, arena); } return ParsedMessageValue(message, arena); }), std::move(adapted_value).value()); } Value Value::WrapMessageUnsafe( const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); std::string scratch; absl::StatusOr adapted_value = well_known_types::AdaptFromMessage(arena, *message, descriptor_pool, message_factory, scratch); if (ABSL_PREDICT_FALSE(!adapted_value.ok())) { return 
ErrorValue(std::move(adapted_value).status()); } return absl::visit( absl::Overload(BorrowingWellKnownTypesValueVisitor{ /* .message = */ message, /* .arena = */ arena, /* .scratch = */ &scratch}, [&](absl::monostate) -> Value { if (message->GetArena() != arena) { return UnsafeParsedMessageValue(message); } return ParsedMessageValue(message, arena); }), std::move(adapted_value).value()); } namespace { bool IsWellKnownMessageWrapperType( const google::protobuf::Descriptor* absl_nonnull descriptor) { switch (descriptor->well_known_type()) { case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: return true; default: return false; } } template Value WrapFieldImpl( ProtoWrapperTypeOptions wrapper_type_options, const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { ABSL_DCHECK(field != nullptr); ABSL_DCHECK_EQ(message->GetDescriptor(), field->containing_type()); 
ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(!IsWellKnownMessageType(message->GetDescriptor())); const auto* reflection = message->GetReflection(); if (field->is_map()) { if (reflection->FieldSize(*message, field) == 0) { return MapValue(); } if constexpr (Unsafe::value) { return UnsafeParsedMapFieldValue(message, field); } else { return ParsedMapFieldValue(message, field, MessageArenaOr(message, arena)); } } if (field->is_repeated()) { if (reflection->FieldSize(*message, field) == 0) { return ListValue(); } if constexpr (Unsafe::value) { return UnsafeParsedRepeatedFieldValue(message, field); } else { return ParsedRepeatedFieldValue(message, field, MessageArenaOr(message, arena)); } } switch (field->type()) { case google::protobuf::FieldDescriptor::TYPE_DOUBLE: return DoubleValue(reflection->GetDouble(*message, field)); case google::protobuf::FieldDescriptor::TYPE_FLOAT: return DoubleValue(reflection->GetFloat(*message, field)); case google::protobuf::FieldDescriptor::TYPE_INT64: return IntValue(reflection->GetInt64(*message, field)); case google::protobuf::FieldDescriptor::TYPE_UINT64: return UintValue(reflection->GetUInt64(*message, field)); case google::protobuf::FieldDescriptor::TYPE_INT32: return IntValue(reflection->GetInt32(*message, field)); case google::protobuf::FieldDescriptor::TYPE_FIXED64: return UintValue(reflection->GetUInt64(*message, field)); case google::protobuf::FieldDescriptor::TYPE_FIXED32: return UintValue(reflection->GetUInt32(*message, field)); case google::protobuf::FieldDescriptor::TYPE_BOOL: return BoolValue(reflection->GetBool(*message, field)); case google::protobuf::FieldDescriptor::TYPE_STRING: { std::string scratch; return absl::visit( absl::Overload( [&](absl::string_view string) -> StringValue { if (string.data() == scratch.data() && string.size() == scratch.size()) { return StringValue(arena, std::move(scratch)); } if constexpr (Unsafe::value) { return 
StringValue::WrapUnsafe(string); } else { return StringValue( Borrower::Arena(MessageArenaOr(message, arena)), string); } }, [&](absl::Cord&& cord) -> StringValue { return StringValue(std::move(cord)); }), well_known_types::AsVariant( well_known_types::GetStringField(*message, field, scratch))); } case google::protobuf::FieldDescriptor::TYPE_GROUP: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_MESSAGE: if (wrapper_type_options == ProtoWrapperTypeOptions::kUnsetNull && IsWellKnownMessageWrapperType(field->message_type()) && !reflection->HasField(*message, field)) { return NullValue(); } if constexpr (Unsafe::value) { return Value::WrapMessageUnsafe( &reflection->GetMessage(*message, field), descriptor_pool, message_factory, arena); } else { return Value::WrapMessage(&reflection->GetMessage(*message, field), descriptor_pool, message_factory, arena); } case google::protobuf::FieldDescriptor::TYPE_BYTES: { std::string scratch; return absl::visit( absl::Overload( [&](absl::string_view string) -> BytesValue { if (string.data() == scratch.data() && string.size() == scratch.size()) { return BytesValue(arena, std::move(scratch)); } if constexpr (Unsafe::value) { return BytesValue::WrapUnsafe(string); } else { return BytesValue( Borrower::Arena(MessageArenaOr(message, arena)), string); } }, [&](absl::Cord&& cord) -> BytesValue { return BytesValue(std::move(cord)); }), well_known_types::AsVariant( well_known_types::GetBytesField(*message, field, scratch))); } case google::protobuf::FieldDescriptor::TYPE_UINT32: return UintValue(reflection->GetUInt32(*message, field)); case google::protobuf::FieldDescriptor::TYPE_ENUM: return Value::Enum(field->enum_type(), reflection->GetEnumValue(*message, field)); case google::protobuf::FieldDescriptor::TYPE_SFIXED32: return IntValue(reflection->GetInt32(*message, field)); case google::protobuf::FieldDescriptor::TYPE_SFIXED64: return IntValue(reflection->GetInt64(*message, field)); case 
google::protobuf::FieldDescriptor::TYPE_SINT32: return IntValue(reflection->GetInt32(*message, field)); case google::protobuf::FieldDescriptor::TYPE_SINT64: return IntValue(reflection->GetInt64(*message, field)); default: return ErrorValue(absl::InvalidArgumentError( absl::StrCat("unexpected protocol buffer message field type: ", field->type_name()))); } } template Value WrapRepeatedFieldImpl( int index, const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { ABSL_DCHECK(field != nullptr); ABSL_DCHECK_EQ(field->containing_type(), message->GetDescriptor()); ABSL_DCHECK(!field->is_map() && field->is_repeated()); ABSL_DCHECK_GE(index, 0); ABSL_DCHECK(message != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); const auto* reflection = message->GetReflection(); const int size = reflection->FieldSize(*message, field); if (ABSL_PREDICT_FALSE(index < 0 || index >= size)) { return ErrorValue(absl::InvalidArgumentError( absl::StrCat("index out of bounds: ", index))); } switch (field->type()) { case google::protobuf::FieldDescriptor::TYPE_DOUBLE: return DoubleValue(reflection->GetRepeatedDouble(*message, field, index)); case google::protobuf::FieldDescriptor::TYPE_FLOAT: return DoubleValue(reflection->GetRepeatedFloat(*message, field, index)); case google::protobuf::FieldDescriptor::TYPE_SFIXED64: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_SINT64: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_INT64: return IntValue(reflection->GetRepeatedInt64(*message, field, 
index)); case google::protobuf::FieldDescriptor::TYPE_FIXED64: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_UINT64: return UintValue(reflection->GetRepeatedUInt64(*message, field, index)); case google::protobuf::FieldDescriptor::TYPE_SFIXED32: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_SINT32: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_INT32: return IntValue(reflection->GetRepeatedInt32(*message, field, index)); case google::protobuf::FieldDescriptor::TYPE_BOOL: return BoolValue(reflection->GetRepeatedBool(*message, field, index)); case google::protobuf::FieldDescriptor::TYPE_STRING: { std::string scratch; return absl::visit( absl::Overload( [&](absl::string_view string) -> StringValue { if (string.data() == scratch.data() && string.size() == scratch.size()) { return StringValue(arena, std::move(scratch)); } if constexpr (Unsafe::value) { return StringValue::WrapUnsafe(string); } else { return StringValue( Borrower::Arena(MessageArenaOr(message, arena)), string); } }, [&](absl::Cord&& cord) -> StringValue { return StringValue(std::move(cord)); }), well_known_types::AsVariant(well_known_types::GetRepeatedStringField( reflection, *message, field, index, scratch))); } case google::protobuf::FieldDescriptor::TYPE_GROUP: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_MESSAGE: if constexpr (Unsafe::value) { return Value::WrapMessageUnsafe( &reflection->GetRepeatedMessage(*message, field, index), descriptor_pool, message_factory, arena); } else { return Value::WrapMessage( &reflection->GetRepeatedMessage(*message, field, index), descriptor_pool, message_factory, arena); } case google::protobuf::FieldDescriptor::TYPE_BYTES: { std::string scratch; return absl::visit( absl::Overload( [&](absl::string_view string) -> BytesValue { if (string.data() == scratch.data() && string.size() == scratch.size()) { return BytesValue(arena, std::move(scratch)); } if constexpr 
(Unsafe::value) { return BytesValue::WrapUnsafe(string); } else { return BytesValue( Borrower::Arena(MessageArenaOr(message, arena)), string); } }, [&](absl::Cord&& cord) -> BytesValue { return BytesValue(std::move(cord)); }), well_known_types::AsVariant(well_known_types::GetRepeatedBytesField( reflection, *message, field, index, scratch))); } case google::protobuf::FieldDescriptor::TYPE_FIXED32: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_UINT32: return UintValue(reflection->GetRepeatedUInt32(*message, field, index)); case google::protobuf::FieldDescriptor::TYPE_ENUM: return Value::Enum(field->enum_type(), reflection->GetRepeatedEnumValue( *message, field, index)); default: return ErrorValue(absl::InvalidArgumentError( absl::StrCat("unexpected message field type: ", field->type_name()))); } } template Value WrapMapFieldValueImpl( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { ABSL_DCHECK(field != nullptr); ABSL_DCHECK_EQ(field->containing_type()->containing_type(), message->GetDescriptor()); ABSL_DCHECK(!field->is_map() && !field->is_repeated()); ABSL_DCHECK_EQ(value.type(), field->cpp_type()); ABSL_DCHECK(message != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); switch (field->type()) { case google::protobuf::FieldDescriptor::TYPE_DOUBLE: return DoubleValue(value.GetDoubleValue()); case google::protobuf::FieldDescriptor::TYPE_FLOAT: return DoubleValue(value.GetFloatValue()); case 
google::protobuf::FieldDescriptor::TYPE_SFIXED64: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_SINT64: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_INT64: return IntValue(value.GetInt64Value()); case google::protobuf::FieldDescriptor::TYPE_FIXED64: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_UINT64: return UintValue(value.GetUInt64Value()); case google::protobuf::FieldDescriptor::TYPE_SFIXED32: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_SINT32: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_INT32: return IntValue(value.GetInt32Value()); case google::protobuf::FieldDescriptor::TYPE_BOOL: return BoolValue(value.GetBoolValue()); case google::protobuf::FieldDescriptor::TYPE_STRING: if constexpr (Unsafe::value) { return StringValue::WrapUnsafe(value.GetStringValue()); } else { return StringValue(Borrower::Arena(MessageArenaOr(message, arena)), value.GetStringValue()); } case google::protobuf::FieldDescriptor::TYPE_GROUP: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_MESSAGE: if constexpr (Unsafe::value) { return Value::WrapMessageUnsafe( &value.GetMessageValue(), descriptor_pool, message_factory, arena); } else { return Value::WrapMessage(&value.GetMessageValue(), descriptor_pool, message_factory, arena); } case google::protobuf::FieldDescriptor::TYPE_BYTES: if constexpr (Unsafe::value) { return BytesValue::WrapUnsafe(value.GetStringValue()); } else { return BytesValue(Borrower::Arena(MessageArenaOr(message, arena)), value.GetStringValue()); } case google::protobuf::FieldDescriptor::TYPE_FIXED32: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::FieldDescriptor::TYPE_UINT32: return UintValue(value.GetUInt32Value()); case google::protobuf::FieldDescriptor::TYPE_ENUM: return Value::Enum(field->enum_type(), value.GetEnumValue()); default: return ErrorValue(absl::InvalidArgumentError( absl::StrCat("unexpected 
message field type: ", field->type_name()))); } } } // namespace Value Value::WrapField( ProtoWrapperTypeOptions wrapper_type_options, const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { using Unsafe = std::false_type; return WrapFieldImpl(wrapper_type_options, message, field, descriptor_pool, message_factory, arena); } Value Value::WrapFieldUnsafe( ProtoWrapperTypeOptions wrapper_type_options, const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { using Unsafe = std::true_type; return WrapFieldImpl(wrapper_type_options, message, field, descriptor_pool, message_factory, arena); } Value Value::WrapRepeatedField( int index, const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { using Unsafe = std::false_type; return WrapRepeatedFieldImpl(index, message, field, descriptor_pool, message_factory, arena); } Value 
Value::WrapRepeatedFieldUnsafe( int index, const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { using Unsafe = std::true_type; return WrapRepeatedFieldImpl(index, message, field, descriptor_pool, message_factory, arena); } StringValue Value::WrapMapFieldKeyString( const google::protobuf::MapKey& key, const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK_EQ(key.type(), google::protobuf::FieldDescriptor::CPPTYPE_STRING); #if CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(5, 30, 0) return StringValue(Borrower::Arena(MessageArenaOr(message, arena)), key.GetStringValue()); #else return StringValue(arena, key.GetStringValue()); #endif } Value Value::WrapMapFieldValue( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { using Unsafe = std::false_type; return WrapMapFieldValueImpl(value, message, field, descriptor_pool, message_factory, arena); } Value Value::WrapMapFieldValueUnsafe( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull 
message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { using Unsafe = std::true_type; return WrapMapFieldValueImpl(value, message, field, descriptor_pool, message_factory, arena); } optional_ref Value::AsBytes() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsBytes() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } absl::optional Value::AsDouble() const { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsDuration() const { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } optional_ref Value::AsError() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsError() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } absl::optional Value::AsInt() const { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsList() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } if (const auto* alternative = 
variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsList() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } absl::optional Value::AsMap() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsMap() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } absl::optional Value::AsMessage() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsMessage() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } absl::optional Value::AsNull() const { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } optional_ref Value::AsOpaque() const& { if (const auto* alternative = 
variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsOpaque() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } optional_ref Value::AsOptional() const& { if (const auto* alternative = variant_.As(); alternative != nullptr && alternative->IsOptional()) { return static_cast(*alternative); } return absl::nullopt; } absl::optional Value::AsOptional() && { if (auto* alternative = variant_.As(); alternative != nullptr && alternative->IsOptional()) { return static_cast(*alternative); } return absl::nullopt; } optional_ref Value::AsParsedJsonList() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsParsedJsonList() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } optional_ref Value::AsParsedJsonMap() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsParsedJsonMap() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } optional_ref Value::AsCustomList() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsCustomList() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } optional_ref Value::AsCustomMap() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsCustomMap() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } optional_ref 
Value::AsParsedMapField() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsParsedMapField() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } optional_ref Value::AsParsedMessage() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsParsedMessage() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } optional_ref Value::AsParsedRepeatedField() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsParsedRepeatedField() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } optional_ref Value::AsCustomStruct() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsCustomStruct() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } optional_ref Value::AsString() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional Value::AsString() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } absl::optional Value::AsStruct() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } 
// Rvalue overload: moves whichever of the three struct representations is
// held into the umbrella struct type.
absl::optional Value::AsStruct() && {
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  return absl::nullopt;
}

absl::optional Value::AsTimestamp() const {
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  return absl::nullopt;
}

optional_ref Value::AsType() const& {
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  return absl::nullopt;
}

absl::optional Value::AsType() && {
  if (auto* alternative = variant_.As(); alternative != nullptr) {
    return std::move(*alternative);
  }
  return absl::nullopt;
}

absl::optional Value::AsUint() const {
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  return absl::nullopt;
}

optional_ref Value::AsUnknown() const& {
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  return absl::nullopt;
}

absl::optional Value::AsUnknown() && {
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  return absl::nullopt;
}

// Unchecked accessors. `GetX()` requires that the value currently holds the
// corresponding alternative. This precondition is only verified by
// ABSL_DCHECK (debug builds); `<< *this` streams the offending value into the
// failure message. `const&` overloads return a reference or copy, `&&`
// overloads move out of the variant.

const BytesValue& Value::GetBytes() const& {
  ABSL_DCHECK(IsBytes()) << *this;
  return variant_.Get();
}

BytesValue Value::GetBytes() && {
  ABSL_DCHECK(IsBytes()) << *this;
  return std::move(variant_).Get();
}

DoubleValue Value::GetDouble() const {
  ABSL_DCHECK(IsDouble()) << *this;
  return variant_.Get();
}

DurationValue Value::GetDuration() const {
  ABSL_DCHECK(IsDuration()) << *this;
  return variant_.Get();
}

const ErrorValue& Value::GetError() const& {
  ABSL_DCHECK(IsError()) << *this;
  return variant_.Get();
}

ErrorValue Value::GetError() && {
  ABSL_DCHECK(IsError()) << *this;
  return std::move(variant_).Get();
}

IntValue Value::GetInt() const {
  ABSL_DCHECK(IsInt()) << *this;
  return variant_.Get();
}

// Terminates the getters below for umbrella types (list, map, struct) that
// probe several alternatives: when none matches, this throws
// `absl::bad_variant_access` if exceptions are enabled, and otherwise logs
// FATAL with the same message.
#ifdef ABSL_HAVE_EXCEPTIONS
#define CEL_VALUE_THROW_BAD_VARIANT_ACCESS() throw absl::bad_variant_access()
#else
#define CEL_VALUE_THROW_BAD_VARIANT_ACCESS() \
  ABSL_LOG(FATAL) << absl::bad_variant_access().what() /* Crash OK */
#endif

// Returns whichever list representation is held as the umbrella list type.
ListValue Value::GetList() const& {
  ABSL_DCHECK(IsList()) << *this;
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  CEL_VALUE_THROW_BAD_VARIANT_ACCESS();
}

ListValue Value::GetList() && {
  ABSL_DCHECK(IsList()) << *this;
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  CEL_VALUE_THROW_BAD_VARIANT_ACCESS();
}

// Returns whichever map representation is held as the umbrella map type.
MapValue Value::GetMap() const& {
  ABSL_DCHECK(IsMap()) << *this;
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  CEL_VALUE_THROW_BAD_VARIANT_ACCESS();
}

MapValue Value::GetMap() && {
  ABSL_DCHECK(IsMap()) << *this;
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  CEL_VALUE_THROW_BAD_VARIANT_ACCESS();
}

MessageValue Value::GetMessage() const& {
  ABSL_DCHECK(IsMessage()) << *this;
  return variant_.Get();
}

MessageValue Value::GetMessage() && {
  ABSL_DCHECK(IsMessage()) << *this;
  return std::move(variant_).Get();
}

NullValue Value::GetNull() const {
  ABSL_DCHECK(IsNull()) << *this;
  return variant_.Get();
}

const OpaqueValue& Value::GetOpaque() const& {
  ABSL_DCHECK(IsOpaque()) << *this;
  return variant_.Get();
}

OpaqueValue Value::GetOpaque() && {
  ABSL_DCHECK(IsOpaque()) << *this;
  return std::move(variant_).Get();
}

// The optional is extracted from the held alternative and downcast with
// `static_cast`; `IsOptional()` is the (debug-only) guard for that cast.
const OptionalValue& Value::GetOptional() const& {
  ABSL_DCHECK(IsOptional()) << *this;
  return static_cast(variant_.Get());
}

OptionalValue Value::GetOptional() && {
  ABSL_DCHECK(IsOptional()) << *this;
  return static_cast(std::move(variant_).Get());
}

const ParsedJsonListValue& Value::GetParsedJsonList() const& {
  ABSL_DCHECK(IsParsedJsonList()) << *this;
  return variant_.Get();
}

ParsedJsonListValue Value::GetParsedJsonList() && {
  ABSL_DCHECK(IsParsedJsonList()) << *this;
  return std::move(variant_).Get();
}

const ParsedJsonMapValue& Value::GetParsedJsonMap() const& {
  ABSL_DCHECK(IsParsedJsonMap()) << *this;
  return variant_.Get();
}

ParsedJsonMapValue Value::GetParsedJsonMap() && {
  ABSL_DCHECK(IsParsedJsonMap()) << *this;
  return std::move(variant_).Get();
}

const CustomListValue& Value::GetCustomList() const& {
  ABSL_DCHECK(IsCustomList()) << *this;
  return variant_.Get();
}

CustomListValue Value::GetCustomList() && {
  ABSL_DCHECK(IsCustomList()) << *this;
  return std::move(variant_).Get();
}

const CustomMapValue& Value::GetCustomMap() const& {
  ABSL_DCHECK(IsCustomMap()) << *this;
  return variant_.Get();
}

CustomMapValue Value::GetCustomMap() && {
  ABSL_DCHECK(IsCustomMap()) << *this;
  return std::move(variant_).Get();
}

const ParsedMapFieldValue& Value::GetParsedMapField() const& {
  ABSL_DCHECK(IsParsedMapField()) << *this;
  return variant_.Get();
}

ParsedMapFieldValue Value::GetParsedMapField() && {
  ABSL_DCHECK(IsParsedMapField()) << *this;
  return std::move(variant_).Get();
}

const ParsedMessageValue& Value::GetParsedMessage() const& {
  ABSL_DCHECK(IsParsedMessage()) << *this;
  return variant_.Get();
}

ParsedMessageValue Value::GetParsedMessage() && {
  ABSL_DCHECK(IsParsedMessage()) << *this;
  return std::move(variant_).Get();
}

const ParsedRepeatedFieldValue& Value::GetParsedRepeatedField() const& {
  ABSL_DCHECK(IsParsedRepeatedField()) << *this;
  return variant_.Get();
}

ParsedRepeatedFieldValue Value::GetParsedRepeatedField() && {
  ABSL_DCHECK(IsParsedRepeatedField()) << *this;
  return std::move(variant_).Get();
}

const CustomStructValue& Value::GetCustomStruct() const& {
  ABSL_DCHECK(IsCustomStruct()) << *this;
  return variant_.Get();
}

CustomStructValue Value::GetCustomStruct() && {
  ABSL_DCHECK(IsCustomStruct()) << *this;
  return std::move(variant_).Get();
}

const StringValue& Value::GetString() const& {
  ABSL_DCHECK(IsString()) << *this;
  return variant_.Get();
}

StringValue Value::GetString() && {
  ABSL_DCHECK(IsString()) << *this;
  return std::move(variant_).Get();
}

// Returns whichever struct representation is held as the umbrella struct
// type.
StructValue Value::GetStruct() const& {
  ABSL_DCHECK(IsStruct()) << *this;
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  CEL_VALUE_THROW_BAD_VARIANT_ACCESS();
}

StructValue Value::GetStruct() && {
  ABSL_DCHECK(IsStruct()) << *this;
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  CEL_VALUE_THROW_BAD_VARIANT_ACCESS();
}

TimestampValue Value::GetTimestamp() const {
  ABSL_DCHECK(IsTimestamp()) << *this;
  return variant_.Get();
}

const TypeValue& Value::GetType() const& {
  ABSL_DCHECK(IsType()) << *this;
  return variant_.Get();
}

TypeValue Value::GetType() && {
  ABSL_DCHECK(IsType()) << *this;
  return std::move(variant_).Get();
}

UintValue Value::GetUint() const {
  ABSL_DCHECK(IsUint()) << *this;
  return variant_.Get();
}

const UnknownValue& Value::GetUnknown() const& {
  ABSL_DCHECK(IsUnknown()) << *this;
  return variant_.Get();
}

UnknownValue Value::GetUnknown() && {
  ABSL_DCHECK(IsUnknown()) << *this;
  return std::move(variant_).Get();
}

namespace {

// A `ValueIterator` that never yields an element: `HasNext()` is always
// false.
class EmptyValueIterator final : public ValueIterator {
 public:
  bool HasNext() override { return false; }

  // Calling `Next` on an exhausted iterator is a usage error, reported as
  // FailedPrecondition; `result` is left untouched.
  absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                    google::protobuf::MessageFactory* absl_nonnull message_factory,
                    google::protobuf::Arena* absl_nonnull arena,
                    Value* absl_nonnull result) override {
    ABSL_DCHECK(descriptor_pool != nullptr);
    ABSL_DCHECK(message_factory != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK(result != nullptr);

    return absl::FailedPreconditionError(
        "`ValueIterator::Next` called after `ValueIterator::HasNext` returned "
        "false");
  }

  // Always reports exhaustion (`false`) without writing `key_or_value`.
  absl::StatusOr Next1(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull key_or_value) override {
    ABSL_DCHECK(descriptor_pool != nullptr);
    ABSL_DCHECK(message_factory != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK(key_or_value != nullptr);

    return false;
  }

  // Always reports exhaustion (`false`). `value` is declared nullable, so it
  // is intentionally not DCHECKed.
  absl::StatusOr Next2(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key,
      Value* absl_nullable value) override {
    ABSL_DCHECK(descriptor_pool != nullptr);
    ABSL_DCHECK(message_factory != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK(key != nullptr);

    return false;
  }
};

}  // namespace

// Returns a freshly allocated iterator over zero elements.
absl_nonnull std::unique_ptr NewEmptyValueIterator() {
  return std::make_unique();
}

// Thin public wrappers over the `common_internal` builder factories.

absl_nonnull ListValueBuilderPtr
NewListValueBuilder(google::protobuf::Arena* absl_nonnull arena) {
  ABSL_DCHECK(arena != nullptr);
  return common_internal::NewListValueBuilder(arena);
}

absl_nonnull MapValueBuilderPtr
NewMapValueBuilder(google::protobuf::Arena* absl_nonnull arena) {
  ABSL_DCHECK(arena != nullptr);
  return common_internal::NewMapValueBuilder(arena);
}

// The result is `absl_nullable`; presumably null when `name` does not resolve
// to a buildable struct type — confirm against
// `common_internal::NewStructValueBuilder`.
absl_nullable StructValueBuilderPtr NewStructValueBuilder(
    google::protobuf::Arena* absl_nonnull arena,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    absl::string_view name) {
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  return common_internal::NewStructValueBuilder(arena, descriptor_pool,
                                                message_factory, name);
}

// Heterogeneous numeric equality. Both operands are lifted into
// `internal::Number` and compared there, rather than relying on built-in C++
// conversions between int64, uint64 and double.

bool operator==(IntValue lhs, UintValue rhs) {
  return internal::Number::FromInt64(lhs.NativeValue()) ==
         internal::Number::FromUint64(rhs.NativeValue());
}

bool operator==(UintValue lhs, IntValue rhs) {
  return internal::Number::FromUint64(lhs.NativeValue()) ==
         internal::Number::FromInt64(rhs.NativeValue());
}

bool operator==(IntValue lhs, DoubleValue rhs) {
  return internal::Number::FromInt64(lhs.NativeValue()) ==
         internal::Number::FromDouble(rhs.NativeValue());
}

bool operator==(DoubleValue lhs, IntValue rhs) {
  return internal::Number::FromDouble(lhs.NativeValue()) ==
         internal::Number::FromInt64(rhs.NativeValue());
}

bool operator==(UintValue lhs, DoubleValue rhs) {
  return internal::Number::FromUint64(lhs.NativeValue()) ==
         internal::Number::FromDouble(rhs.NativeValue());
}

bool operator==(DoubleValue lhs, UintValue rhs) {
  return internal::Number::FromDouble(lhs.NativeValue()) ==
         internal::Number::FromUint64(rhs.NativeValue());
}

// Default `Next1` built on `HasNext()` / `Next()`: returns false when the
// iterator is exhausted; otherwise writes the next element into `value` and
// returns true. Any error from `Next()` is propagated.
absl::StatusOr ValueIterator::Next1(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull value) {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(value != nullptr);

  if (HasNext()) {
    CEL_RETURN_IF_ERROR(Next(descriptor_pool, message_factory, arena, value));
    return true;
  }
  return false;
}

}  // namespace cel
================================================
FILE: common/value.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUE_H_ #include #include #include #include #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/meta/type_traits.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "absl/utility/utility.h" #include "base/attribute.h" #include "common/arena.h" #include "common/native_type.h" #include "common/optional_ref.h" #include "common/type.h" #include "common/typeinfo.h" #include "common/value_kind.h" #include "common/values/bool_value.h" // IWYU pragma: export #include "common/values/bytes_value.h" // IWYU pragma: export #include "common/values/bytes_value_input_stream.h" // IWYU pragma: export #include "common/values/bytes_value_output_stream.h" // IWYU pragma: export #include "common/values/custom_list_value.h" // IWYU pragma: export #include "common/values/custom_map_value.h" // IWYU pragma: export #include "common/values/custom_struct_value.h" // IWYU pragma: export #include "common/values/double_value.h" // IWYU pragma: export #include "common/values/duration_value.h" // IWYU pragma: export #include "common/values/enum_value.h" // IWYU pragma: export #include "common/values/error_value.h" // IWYU pragma: export #include "common/values/int_value.h" // IWYU pragma: export #include "common/values/list_value.h" // IWYU pragma: export #include "common/values/map_value.h" // IWYU pragma: export #include "common/values/message_value.h" // IWYU pragma: export #include "common/values/null_value.h" // IWYU pragma: export #include "common/values/opaque_value.h" // IWYU pragma: export #include "common/values/optional_value.h" // IWYU pragma: export #include "common/values/parsed_json_list_value.h" // IWYU pragma: export #include "common/values/parsed_json_map_value.h" // IWYU 
pragma: export #include "common/values/parsed_map_field_value.h" // IWYU pragma: export #include "common/values/parsed_message_value.h" // IWYU pragma: export #include "common/values/parsed_repeated_field_value.h" // IWYU pragma: export #include "common/values/string_value.h" // IWYU pragma: export #include "common/values/struct_value.h" // IWYU pragma: export #include "common/values/timestamp_value.h" // IWYU pragma: export #include "common/values/type_value.h" // IWYU pragma: export #include "common/values/uint_value.h" // IWYU pragma: export #include "common/values/unknown_value.h" // IWYU pragma: export #include "common/values/value_variant.h" #include "common/values/values.h" #include "internal/status_macros.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/generated_enum_reflection.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/map_field.h" #include "google/protobuf/message.h" #pragma push_macro("GetMessage") #ifdef GetMessage // GetMessage in windows API headers might be defined as a macro. Depending on // ordering, might cause issues with Value::GetMessage or // google::protobuf::Reflection::GetMessage. #undef GetMessage #endif namespace cel { // `Value` is a composition type which encompasses all values supported by the // Common Expression Language. When default constructed or moved, `Value` is in // a known but invalid state. Any attempt to use it from then on, without // assigning another type, is undefined behavior. In debug builds, we do our // best to fail. class Value final : private common_internal::ValueMixin { public: // Returns an appropriate `Value` for the dynamic protobuf enum. For open // enums, returns `cel::IntValue`. For closed enums, returns `cel::ErrorValue` // if the value is not present in the enum otherwise returns `cel::IntValue`. 
static Value Enum(const google::protobuf::EnumValueDescriptor* absl_nonnull value); static Value Enum(const google::protobuf::EnumDescriptor* absl_nonnull type, int32_t number); // SFINAE overload for generated protobuf enums which are not well-known. // Always returns `cel::IntValue`. // NOTE(review): the template parameter lists of the two SFINAE `Enum` // overloads (and the template arguments of their return-type aliases) appear // to have been stripped from this copy of the header; consult the upstream // source for the exact constraints. template static common_internal::EnableIfGeneratedEnum Enum(T value) { return IntValue(value); } // SFINAE overload for google::protobuf::NullValue. Always returns // `cel::NullValue`; the enum argument itself is ignored. template static common_internal::EnableIfWellKnownEnum Enum(T) { return NullValue(); } // Returns an appropriate `Value` for the dynamic protobuf message. If // `message` is the well known type `google.protobuf.Any`, `descriptor_pool` // and `message_factory` will be used to unpack the value. Both must outlive // the resulting value and any of its shallow copies. Otherwise the message is // copied using `arena`. static Value FromMessage( const google::protobuf::Message& message, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); // Same contract as the overload above, but accepts `message` as an rvalue // reference. static Value FromMessage( google::protobuf::Message&& message, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); // Returns an appropriate `Value` for the dynamic protobuf message. If // `message` is the well known type `google.protobuf.Any`, `descriptor_pool` // and `message_factory` will be used to unpack the value. Both must outlive // the resulting value and any of its shallow copies. Otherwise the message is // borrowed (no copying). If the message is on an arena, that arena will be // attributed as the owner. Otherwise `arena` is used.
static Value WrapMessage( const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); // Returns an appropriate `Value` for the dynamic protobuf message. If // `message` is the well known type `google.protobuf.Any`, `descriptor_pool` // and `message_factory` will be used to unpack the value. Both must outlive // the resulting value and any of its shallow copies. Otherwise the message is // borrowed (no copying). This function does not attempt to validate arena // ownership of a dynamic message that was not unpacked from a well known // type. The caller is responsible for ensuring the resulting value and any // derived values do not outlive the input message. static Value WrapMessageUnsafe( const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); // Returns an appropriate `Value` for the dynamic protobuf message field. If // `field` in `message` is the well known type `google.protobuf.Any`, // `descriptor_pool` and `message_factory` will be used to unpack the value. // Both must outlive the resulting value and any of its shallow copies. // Otherwise the field is borrowed (no copying). If the message is on an // arena, that arena will be attributed as the owner. Otherwise `arena` is // used. `wrapper_type_options` presumably selects how unset well-known // wrapper fields are represented -- TODO confirm; the overload without it // uses `ProtoWrapperTypeOptions::kUnsetNull`.
static Value WrapField( ProtoWrapperTypeOptions wrapper_type_options, const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); /* Convenience overload: equivalent to calling the overload above with ProtoWrapperTypeOptions::kUnsetNull. */ static Value WrapField( const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { return WrapField(ProtoWrapperTypeOptions::kUnsetNull, message, field, descriptor_pool, message_factory, arena); } // Returns an appropriate `Value` for the dynamic protobuf message field. If // `field` in `message` is the well known type `google.protobuf.Any`, // `descriptor_pool` and `message_factory` will be used to unpack the value. // Both must outlive the resulting value and any of its shallow copies. // Otherwise the field is borrowed (no copying). The caller is responsible for // ensuring the resulting value and any derived values do not outlive the // input message.
static Value WrapFieldUnsafe( ProtoWrapperTypeOptions wrapper_type_options, const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); // Returns an appropriate `Value` for the dynamic protobuf message repeated // field. If `field` in `message` is the well known type // `google.protobuf.Any`, `descriptor_pool` and `message_factory` will be used // to unpack the value. Both must outlive the resulting value and any of its // shallow copies. `index` presumably selects the element of the repeated // `field` to wrap. static Value WrapRepeatedField( int index, const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); // Returns an appropriate `Value` for the dynamic protobuf message repeated // field. If `field` in `message` is the well known type // `google.protobuf.Any`, `descriptor_pool` and `message_factory` will be used // to unpack the value. Both must outlive the resulting value and any of its // shallow copies. `index` presumably selects the element of the repeated // `field` to wrap. NOTE(review): unlike the other `*Unsafe` factories, this // comment does not state the ownership caveat; presumably the caller is // likewise responsible for ensuring the resulting value and any derived // values do not outlive the input message -- confirm against the // implementation.
static Value WrapRepeatedFieldUnsafe( int index, const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); // Returns an appropriate `StringValue` for the dynamic protobuf message map // field key. The map field key must be a string or the behavior is undefined. static StringValue WrapMapFieldKeyString( const google::protobuf::MapKey& key, const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); // Returns an appropriate `Value` for the dynamic protobuf message map // field value. If `field` in `message`, which is `value`, is the well known // type `google.protobuf.Any`, `descriptor_pool` and `message_factory` will be // used to unpack the value. Both must outlive the resulting value and any of // its shallow copies. static Value WrapMapFieldValue( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); // Returns an appropriate `Value` for the dynamic protobuf message map // field value.
If `field` in `message`, which is `value`, is the well known // type `google.protobuf.Any`, `descriptor_pool` and `message_factory` will be // used to unpack the value. Both must outlive the resulting value and any of // its shallow copies. The caller is responsible for ensuring the resulting // value and any derived values do not outlive the input message. static Value WrapMapFieldValueUnsafe( const google::protobuf::MapValueConstRef& value, const google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND); // Special members are defaulted. A default-constructed or moved-from `Value` // is in a known but invalid state (see the class comment) and must be // assigned before any other use. Value() = default; Value(const Value&) = default; Value& operator=(const Value&) = default; Value(Value&& other) = default; Value& operator=(Value&&) = default; // Implicit conversions from `ListValue` and `MapValue` store the wrapper's // `ToValueVariant()` result directly in `variant_`. // NOLINTNEXTLINE(google-explicit-constructor) Value(const ListValue& value) : variant_(value.ToValueVariant()) {} // NOLINTNEXTLINE(google-explicit-constructor) Value(ListValue&& value) : variant_(std::move(value).ToValueVariant()) {} // NOLINTNEXTLINE(google-explicit-constructor) Value& operator=(const ListValue& value) { variant_ = value.ToValueVariant(); return *this; } // NOLINTNEXTLINE(google-explicit-constructor) Value& operator=(ListValue&& value) { variant_ = std::move(value).ToValueVariant(); return *this; } // NOLINTNEXTLINE(google-explicit-constructor) Value(const MapValue& value) : variant_(value.ToValueVariant()) {} // NOLINTNEXTLINE(google-explicit-constructor) Value(MapValue&& value) : variant_(std::move(value).ToValueVariant()) {} // NOLINTNEXTLINE(google-explicit-constructor) Value& operator=(const MapValue& value) { variant_ = value.ToValueVariant(); return *this; } // NOLINTNEXTLINE(google-explicit-constructor)
Value& operator=(MapValue&& value) { variant_ = std::move(value).ToValueVariant(); return *this; } // NOLINTNEXTLINE(google-explicit-constructor) Value(const StructValue& value) : variant_(value.ToValueVariant()) {} // NOLINTNEXTLINE(google-explicit-constructor) Value(StructValue&& value) : variant_(std::move(value).ToValueVariant()) {} // NOLINTNEXTLINE(google-explicit-constructor) Value& operator=(const StructValue& value) { variant_ = value.ToValueVariant(); return *this; } // NOLINTNEXTLINE(google-explicit-constructor) Value& operator=(StructValue&& value) { variant_ = std::move(value).ToValueVariant(); return *this; } // NOLINTNEXTLINE(google-explicit-constructor) Value(const MessageValue& value) : variant_(value.ToValueVariant()) {} // NOLINTNEXTLINE(google-explicit-constructor) Value(MessageValue&& value) : variant_(std::move(value).ToValueVariant()) {} // NOLINTNEXTLINE(google-explicit-constructor) Value& operator=(const MessageValue& value) { variant_ = value.ToValueVariant(); return *this; } // NOLINTNEXTLINE(google-explicit-constructor) Value& operator=(MessageValue&& value) { variant_ = std::move(value).ToValueVariant(); return *this; } // `OptionalValue` is not converted via `ToValueVariant()`; it is stored // through a `static_cast` to another alternative. Per the documentation of // `IsOptional()`, an optional value also satisfies `IsOpaque()`. // NOLINTNEXTLINE(google-explicit-constructor) Value(const OptionalValue& value) : variant_(absl::in_place_type, static_cast(value)) {} // NOLINTNEXTLINE(google-explicit-constructor) Value(OptionalValue&& value) : variant_(absl::in_place_type, static_cast(value)) {} // NOLINTNEXTLINE(google-explicit-constructor) Value& operator=(const OptionalValue& value) { variant_.Assign(static_cast(value)); return *this; } // NOLINTNEXTLINE(google-explicit-constructor) Value& operator=(OptionalValue&& value) { variant_.Assign(static_cast(value)); return *this; } // Generic `noexcept` converting constructor and assignment, presumably for // the remaining value alternatives. NOTE(review): the template parameter // lists (including the SFINAE constraint) appear to have been stripped from // this copy of the header. template >>> // NOLINTNEXTLINE(google-explicit-constructor) Value(T&& alternative) noexcept : variant_(absl::in_place_type>, std::forward(alternative)) {} template >>> // NOLINTNEXTLINE(google-explicit-constructor) Value& operator=(T&& alternative) noexcept {
variant_.Assign(std::forward(alternative)); return *this; } /* Returns variant_.kind(), the kind of the held value. */ ValueKind kind() const { return variant_.kind(); } Type GetRuntimeType() const; absl::string_view GetTypeName() const; std::string DebugString() const; // `SerializeTo` serializes this value to `output`. If an error is returned, // `output` is in a valid but unspecified state. If this value does not // support serialization, `FAILED_PRECONDITION` is returned. absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; // `ConvertToJson` converts this value to its JSON representation. The // argument `json` **MUST** be an instance of `google.protobuf.Value`, which // can be either the generated message or a dynamic message. The descriptor // pool `descriptor_pool` and message factory `message_factory` are used to // deal with serialized messages and a few corner cases. absl::Status ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; // `ConvertToJsonArray` converts this value to its JSON representation if and // only if it can be represented as an array. The argument `json` **MUST** be // an instance of `google.protobuf.ListValue`, which can be either the // generated message or a dynamic message. The descriptor pool // `descriptor_pool` and message factory `message_factory` are used to deal // with serialized messages and a few corner cases. absl::Status ConvertToJsonArray( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; // `ConvertToJsonObject` converts this value to its JSON representation if and // only if it can be represented as an object.
The argument `json` **MUST** be // an instance of `google.protobuf.Struct`, which can be either the // generated message or a dynamic message. The descriptor pool // `descriptor_pool` and message factory `message_factory` are used to deal // with serialized messages and a few corner cases. absl::Status ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; // Compares this value with `other` for equality, storing the outcome in // `result`. NOTE(review): the exact form of `result` and the meaning of a // non-OK status are not visible here -- confirm against the implementation. absl::Status Equal(const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; // Re-exposes the `Equal` overloads inherited from `ValueMixin`, which the // declaration above would otherwise hide. using ValueMixin::Equal; // NOTE(review): presumably reports whether this value is the zero value of // its type -- confirm. bool IsZeroValue() const; // Clones the value to another arena, if necessary, such that the lifetime of // the value is tied to the arena. Value Clone(google::protobuf::Arena* absl_nonnull arena) const; // Swaps the underlying `variant_` of `lhs` and `rhs`. friend void swap(Value& lhs, Value& rhs) noexcept { using std::swap; swap(lhs.variant_, rhs.variant_); } // Streams `value` to `out`. friend std::ostream& operator<<(std::ostream& out, const Value& value); // Deprecated pointer-style access kept for source compatibility; both // overloads simply return `this`. ABSL_DEPRECATED("Just use operator.()") Value* operator->() { return this; } ABSL_DEPRECATED("Just use operator.()") const Value* operator->() const { return this; } // Returns `true` if this value is an instance of a bool value. bool IsBool() const { return variant_.Is(); } // Returns `true` if this value is an instance of a bool value and true. bool IsTrue() const { return IsBool() && GetBool().NativeValue(); } // Returns `true` if this value is an instance of a bool value and false. bool IsFalse() const { return IsBool() && !GetBool().NativeValue(); } // Returns `true` if this value is an instance of a bytes value. bool IsBytes() const { return variant_.Is(); } // Returns `true` if this value is an instance of a double value.
bool IsDouble() const { return variant_.Is(); } // NOTE(review): the template arguments of the `variant_.Is<...>()` and // `variant_.As<...>()` calls in these predicates appear to have been stripped // from this copy of the header. // Returns `true` if this value is an instance of a duration value. bool IsDuration() const { return variant_.Is(); } // Returns `true` if this value is an instance of an error value. bool IsError() const { return variant_.Is(); } // Returns `true` if this value is an instance of an int value. bool IsInt() const { return variant_.Is(); } // Returns `true` if this value is an instance of a list value. Several list // representations are checked; `IsParsedJsonList()`, `IsCustomList()` and // `IsParsedRepeatedField()` each imply `IsList()`. bool IsList() const { return variant_.Is() || variant_.Is() || variant_.Is() || variant_.Is(); } // Returns `true` if this value is an instance of a map value. Several map // representations are checked; `IsParsedJsonMap()`, `IsCustomMap()` and // `IsParsedMapField()` each imply `IsMap()`. bool IsMap() const { return variant_.Is() || variant_.Is() || variant_.Is() || variant_.Is(); } // Returns `true` if this value is an instance of a message value. If `true` // is returned, it is implied that `IsStruct()` would also return true. bool IsMessage() const { return variant_.Is(); } // Returns `true` if this value is an instance of a null value. bool IsNull() const { return variant_.Is(); } // Returns `true` if this value is an instance of an opaque value. bool IsOpaque() const { return variant_.Is(); } // Returns `true` if this value is an instance of an optional value. If `true` // is returned, it is implied that `IsOpaque()` would also return true. // Implemented by looking up a single alternative (presumably the opaque one) // and delegating to its own `IsOptional()`; returns `false` if that // alternative is not held. bool IsOptional() const { if (const auto* alternative = variant_.As(); alternative != nullptr) { return alternative->IsOptional(); } return false; } // Returns `true` if this value is an instance of a parsed JSON list value. If // `true` is returned, it is implied that `IsList()` would also return // true. bool IsParsedJsonList() const { return variant_.Is(); } // Returns `true` if this value is an instance of a parsed JSON map value. If // `true` is returned, it is implied that `IsMap()` would also return // true. bool IsParsedJsonMap() const { return variant_.Is(); } // Returns `true` if this value is an instance of a custom list value.
If // `true` is returned, it is implied that `IsList()` would also return // true. bool IsCustomList() const { return variant_.Is(); } // Returns `true` if this value is an instance of a custom map value. If // `true` is returned, it is implied that `IsMap()` would also return // true. bool IsCustomMap() const { return variant_.Is(); } // Returns `true` if this value is an instance of a parsed map field value. If // `true` is returned, it is implied that `IsMap()` would also return // true. bool IsParsedMapField() const { return variant_.Is(); } // Returns `true` if this value is an instance of a parsed message value. If // `true` is returned, it is implied that `IsMessage()` would also return // true. bool IsParsedMessage() const { return variant_.Is(); } // Returns `true` if this value is an instance of a parsed repeated field // value. If `true` is returned, it is implied that `IsList()` would also // return true. bool IsParsedRepeatedField() const { return variant_.Is(); } // Returns `true` if this value is an instance of a custom struct value. If // `true` is returned, it is implied that `IsStruct()` would also return // true. bool IsCustomStruct() const { return variant_.Is(); } // Returns `true` if this value is an instance of a string value. bool IsString() const { return variant_.Is(); } // Returns `true` if this value is an instance of a struct value. // `IsMessage()` and `IsCustomStruct()` each imply `IsStruct()`. bool IsStruct() const { return variant_.Is() || variant_.Is() || variant_.Is(); } // Returns `true` if this value is an instance of a timestamp value. bool IsTimestamp() const { return variant_.Is(); } // Returns `true` if this value is an instance of a type value. bool IsType() const { return variant_.Is(); } // Returns `true` if this value is an instance of a uint value. bool IsUint() const { return variant_.Is(); } // Returns `true` if this value is an instance of an unknown value. bool IsUnknown() const { return variant_.Is(); } // Convenience method for use with template metaprogramming.
See // `IsBool()`. // NOTE(review): the template parameter lists and `std::enable_if_t` // conditions of the `Is<T>()` overloads appear to have been stripped from // this copy of the header; each overload presumably matches `T` to the value // type named in its See reference. template std::enable_if_t, bool> Is() const { return IsBool(); } // Convenience method for use with template metaprogramming. See // `IsBytes()`. template std::enable_if_t, bool> Is() const { return IsBytes(); } // Convenience method for use with template metaprogramming. See // `IsDouble()`. template std::enable_if_t, bool> Is() const { return IsDouble(); } // Convenience method for use with template metaprogramming. See // `IsDuration()`. template std::enable_if_t, bool> Is() const { return IsDuration(); } // Convenience method for use with template metaprogramming. See // `IsError()`. template std::enable_if_t, bool> Is() const { return IsError(); } // Convenience method for use with template metaprogramming. See // `IsInt()`. template std::enable_if_t, bool> Is() const { return IsInt(); } // Convenience method for use with template metaprogramming. See // `IsList()`. template std::enable_if_t, bool> Is() const { return IsList(); } // Convenience method for use with template metaprogramming. See // `IsMap()`. template std::enable_if_t, bool> Is() const { return IsMap(); } // Convenience method for use with template metaprogramming. See // `IsMessage()`. template std::enable_if_t, bool> Is() const { return IsMessage(); } // Convenience method for use with template metaprogramming. See // `IsNull()`. template std::enable_if_t, bool> Is() const { return IsNull(); } // Convenience method for use with template metaprogramming. See // `IsOpaque()`. template std::enable_if_t, bool> Is() const { return IsOpaque(); } // Convenience method for use with template metaprogramming. See // `IsOptional()`. template std::enable_if_t, bool> Is() const { return IsOptional(); } // Convenience method for use with template metaprogramming. See // `IsParsedJsonList()`. template std::enable_if_t, bool> Is() const { return IsParsedJsonList(); } // Convenience method for use with template metaprogramming. See // `IsParsedJsonMap()`.
template std::enable_if_t, bool> Is() const { return IsParsedJsonMap(); } // Convenience method for use with template metaprogramming. See // `IsCustomList()`. template std::enable_if_t, bool> Is() const { return IsCustomList(); } // Convenience method for use with template metaprogramming. See // `IsCustomMap()`. template std::enable_if_t, bool> Is() const { return IsCustomMap(); } // Convenience method for use with template metaprogramming. See // `IsParsedMapField()`. template std::enable_if_t, bool> Is() const { return IsParsedMapField(); } // Convenience method for use with template metaprogramming. See // `IsParsedMessage()`. template std::enable_if_t, bool> Is() const { return IsParsedMessage(); } // Convenience method for use with template metaprogramming. See // `IsParsedRepeatedField()`. template std::enable_if_t, bool> Is() const { return IsParsedRepeatedField(); } // Convenience method for use with template metaprogramming. See // `IsCustomStruct()`. template std::enable_if_t, bool> Is() const { return IsCustomStruct(); } // Convenience method for use with template metaprogramming. See // `IsString()`. template std::enable_if_t, bool> Is() const { return IsString(); } // Convenience method for use with template metaprogramming. See // `IsStruct()`. template std::enable_if_t, bool> Is() const { return IsStruct(); } // Convenience method for use with template metaprogramming. See // `IsTimestamp()`. template std::enable_if_t, bool> Is() const { return IsTimestamp(); } // Convenience method for use with template metaprogramming. See // `IsType()`. template std::enable_if_t, bool> Is() const { return IsType(); } // Convenience method for use with template metaprogramming. See // `IsUint()`. template std::enable_if_t, bool> Is() const { return IsUint(); } // Convenience method for use with template metaprogramming. See // `IsUnknown()`.
template std::enable_if_t, bool> Is() const { return IsUnknown(); } // Performs a checked cast from a value to a bool value, // returning a non-empty optional with either a value or reference to the // bool value. Otherwise an empty optional is returned. absl::optional AsBool() const { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } // Performs a checked cast from a value to a bytes value, // returning a non-empty optional with either a value or reference to the // bytes value. Otherwise an empty optional is returned. // The lvalue overloads return an `optional_ref` marked // `ABSL_ATTRIBUTE_LIFETIME_BOUND`, so the result must not outlive this // `Value`; the rvalue overloads return an `absl::optional` by value, and the // `const&&` overload converts the `const&` result via // `common_internal::AsOptional`. optional_ref AsBytes() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::as_const(*this).AsBytes(); } optional_ref AsBytes() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::optional AsBytes() &&; absl::optional AsBytes() const&& { return common_internal::AsOptional(AsBytes()); } // Performs a checked cast from a value to a double value, // returning a non-empty optional with either a value or reference to the // double value. Otherwise an empty optional is returned. absl::optional AsDouble() const; // Performs a checked cast from a value to a duration value, // returning a non-empty optional with either a value or reference to the // duration value. Otherwise an empty optional is returned. absl::optional AsDuration() const; // Performs a checked cast from a value to an error value, // returning a non-empty optional with either a value or reference to the // error value. Otherwise an empty optional is returned. optional_ref AsError() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::as_const(*this).AsError(); } optional_ref AsError() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::optional AsError() &&; absl::optional AsError() const&& { return common_internal::AsOptional(AsError()); } // Performs a checked cast from a value to an int value, // returning a non-empty optional with either a value or reference to the // int value. Otherwise an empty optional is returned.
absl::optional AsInt() const; // Performs a checked cast from a value to a list value, // returning a non-empty optional with either a value or reference to the // list value. Otherwise an empty optional is returned. // Unlike the single-alternative accessors, every `AsList()` overload returns // an `absl::optional` by value rather than an `optional_ref`; presumably // because a `ListValue` is built from one of several list alternatives (see // `IsList()`). The same applies to `AsMap()` and `AsMessage()` below. absl::optional AsList() & { return std::as_const(*this).AsList(); } absl::optional AsList() const&; absl::optional AsList() &&; absl::optional AsList() const&& { return common_internal::AsOptional(AsList()); } // Performs a checked cast from a value to a map value, // returning a non-empty optional with either a value or reference to the // map value. Otherwise an empty optional is returned. absl::optional AsMap() & { return std::as_const(*this).AsMap(); } absl::optional AsMap() const&; absl::optional AsMap() &&; absl::optional AsMap() const&& { return common_internal::AsOptional(AsMap()); } // Performs a checked cast from a value to a message value, // returning a non-empty optional with either a value or reference to the // message value. Otherwise an empty optional is returned. absl::optional AsMessage() & { return std::as_const(*this).AsMessage(); } absl::optional AsMessage() const&; absl::optional AsMessage() &&; absl::optional AsMessage() const&& { return common_internal::AsOptional(AsMessage()); } // Performs a checked cast from a value to a null value, // returning a non-empty optional with either a value or reference to the // null value. Otherwise an empty optional is returned. absl::optional AsNull() const; // Performs a checked cast from a value to an opaque value, // returning a non-empty optional with either a value or reference to the // opaque value. Otherwise an empty optional is returned.
optional_ref AsOpaque() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::as_const(*this).AsOpaque(); } optional_ref AsOpaque() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::optional AsOpaque() &&; absl::optional AsOpaque() const&& { return common_internal::AsOptional(AsOpaque()); } // Performs a checked cast from a value to an optional value, // returning a non-empty optional with either a value or reference to the // optional value. Otherwise an empty optional is returned. optional_ref AsOptional() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::as_const(*this).AsOptional(); } optional_ref AsOptional() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::optional AsOptional() &&; absl::optional AsOptional() const&& { return common_internal::AsOptional(AsOptional()); } // Performs a checked cast from a value to a parsed JSON list value, // returning a non-empty optional with either a value or reference to the // parsed JSON list value. Otherwise an empty optional is returned. optional_ref AsParsedJsonList() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::as_const(*this).AsParsedJsonList(); } optional_ref AsParsedJsonList() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::optional AsParsedJsonList() &&; absl::optional AsParsedJsonList() const&& { return common_internal::AsOptional(AsParsedJsonList()); } // Performs a checked cast from a value to a parsed JSON map value, // returning a non-empty optional with either a value or reference to the // parsed JSON map value. Otherwise an empty optional is returned. optional_ref AsParsedJsonMap() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::as_const(*this).AsParsedJsonMap(); } optional_ref AsParsedJsonMap() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::optional AsParsedJsonMap() &&; absl::optional AsParsedJsonMap() const&& { return common_internal::AsOptional(AsParsedJsonMap()); } // Performs a checked cast from a value to a custom list value, // returning a non-empty optional with either a value or reference to the // custom list value.
Otherwise an empty optional is returned. optional_ref AsCustomList() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::as_const(*this).AsCustomList(); } optional_ref AsCustomList() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::optional AsCustomList() &&; absl::optional AsCustomList() const&& { return common_internal::AsOptional(AsCustomList()); } // Performs a checked cast from a value to a custom map value, // returning a non-empty optional with either a value or reference to the // custom map value. Otherwise an empty optional is returned. optional_ref AsCustomMap() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::as_const(*this).AsCustomMap(); } optional_ref AsCustomMap() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::optional AsCustomMap() &&; absl::optional AsCustomMap() const&& { return common_internal::AsOptional(AsCustomMap()); } // Performs a checked cast from a value to a parsed map field value, // returning a non-empty optional with either a value or reference to the // parsed map field value. Otherwise an empty optional is returned. optional_ref AsParsedMapField() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::as_const(*this).AsParsedMapField(); } optional_ref AsParsedMapField() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::optional AsParsedMapField() &&; absl::optional AsParsedMapField() const&& { return common_internal::AsOptional(AsParsedMapField()); } // Performs a checked cast from a value to a parsed message value, // returning a non-empty optional with either a value or reference to the // parsed message value. Otherwise an empty optional is returned.
optional_ref AsParsedMessage() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::as_const(*this).AsParsedMessage(); } optional_ref AsParsedMessage() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::optional AsParsedMessage() &&; absl::optional AsParsedMessage() const&& { return common_internal::AsOptional(AsParsedMessage()); } // Performs a checked cast from a value to a parsed repeated field value, // returning a non-empty optional with either a value or reference to the // parsed repeated field value. Otherwise an empty optional is returned. optional_ref AsParsedRepeatedField() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::as_const(*this).AsParsedRepeatedField(); } optional_ref AsParsedRepeatedField() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::optional AsParsedRepeatedField() &&; absl::optional AsParsedRepeatedField() const&& { return common_internal::AsOptional(AsParsedRepeatedField()); } // Performs a checked cast from a value to a custom struct value, // returning a non-empty optional with either a value or reference to the // custom struct value. Otherwise an empty optional is returned. optional_ref AsCustomStruct() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::as_const(*this).AsCustomStruct(); } optional_ref AsCustomStruct() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::optional AsCustomStruct() &&; absl::optional AsCustomStruct() const&& { return common_internal::AsOptional(AsCustomStruct()); } // Performs a checked cast from a value to a string value, // returning a non-empty optional with either a value or reference to the // string value. Otherwise an empty optional is returned. 
// Overloads: `&` forwards to `const&` via `std::as_const(*this)`; the
// `const&` result borrows from `*this` (ABSL_ATTRIBUTE_LIFETIME_BOUND);
// `&&` and `const&&` return an owning `absl::optional`, the latter by
// converting the `const&` result with `common_internal::AsOptional`.
optional_ref AsString() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).AsString();
}
optional_ref AsString() const& ABSL_ATTRIBUTE_LIFETIME_BOUND;
absl::optional AsString() &&;
absl::optional AsString() const&& {
  return common_internal::AsOptional(AsString());
}

// Performs a checked cast from a value to a struct value,
// returning a non-empty optional with either a value or reference to the
// struct value. Otherwise an empty optional is returned.
// Every overload returns an owning `absl::optional`, so none of them is
// ABSL_ATTRIBUTE_LIFETIME_BOUND; `&` still forwards to `const&`.
absl::optional AsStruct() & { return std::as_const(*this).AsStruct(); }
absl::optional AsStruct() const&;
absl::optional AsStruct() &&;
absl::optional AsStruct() const&& {
  return common_internal::AsOptional(AsStruct());
}

// Performs a checked cast from a value to a timestamp value, returning a
// non-empty optional holding the timestamp value. Otherwise an empty optional
// is returned. A single `const` overload serves lvalues and rvalues alike.
absl::optional AsTimestamp() const;

// Performs a checked cast from a value to a type value,
// returning a non-empty optional with either a value or reference to the
// type value. Otherwise an empty optional is returned.
// Same overload pattern as `AsString()`.
optional_ref AsType() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).AsType();
}
optional_ref AsType() const& ABSL_ATTRIBUTE_LIFETIME_BOUND;
absl::optional AsType() &&;
absl::optional AsType() const&& {
  return common_internal::AsOptional(AsType());
}

// Performs a checked cast from a value to a uint value, returning a
// non-empty optional holding the uint value. Otherwise an empty optional is
// returned. A single `const` overload serves lvalues and rvalues alike.
absl::optional AsUint() const;

// Performs a checked cast from a value to an unknown value,
// returning a non-empty optional with either a value or reference to the
// unknown value. Otherwise an empty optional is returned.
// Overloads: `&` forwards to `const&` via `std::as_const(*this)`; the
// `const&` result borrows from `*this` (ABSL_ATTRIBUTE_LIFETIME_BOUND);
// `&&` and `const&&` return an owning `absl::optional`, the latter by
// converting the `const&` result with `common_internal::AsOptional`.
optional_ref AsUnknown() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).AsUnknown();
}
optional_ref AsUnknown() const& ABSL_ATTRIBUTE_LIFETIME_BOUND;
absl::optional AsUnknown() &&;
absl::optional AsUnknown() const&& {
  return common_internal::AsOptional(AsUnknown());
}

// Convenience method for use with template metaprogramming. See
// `AsBool()`. All four overloads simply call `AsBool()`.
//
// NOTE(review): the `std::enable_if_t` conditions and return-type template
// arguments of the As<T>() overloads appear to be missing from this copy;
// presumably each overload set is enabled only for its matching value type.
// Confirm against the upstream header.
template std::enable_if_t, absl::optional> As() & {
  return AsBool();
}
template std::enable_if_t, absl::optional> As() const& {
  return AsBool();
}
template std::enable_if_t, absl::optional> As() && {
  return AsBool();
}
template std::enable_if_t, absl::optional> As() const&& {
  return AsBool();
}

// Convenience method for use with template metaprogramming. See
// `AsBytes()`. The lvalue overloads return an `optional_ref` that borrows
// from `*this`; the rvalue overloads call `AsBytes()` on `std::move(*this)`
// so the value category is preserved and an owning `absl::optional` results.
template std::enable_if_t, optional_ref> As() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsBytes();
}
template std::enable_if_t, optional_ref> As() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsBytes();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsBytes();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsBytes();
}

// Convenience method for use with template metaprogramming. See
// `AsDouble()`. All four overloads simply call `AsDouble()`.
template std::enable_if_t, absl::optional> As() & {
  return AsDouble();
}
template std::enable_if_t, absl::optional> As() const& {
  return AsDouble();
}
template std::enable_if_t, absl::optional> As() && {
  return AsDouble();
}
template std::enable_if_t, absl::optional> As() const&& {
  return AsDouble();
}

// Convenience method for use with template metaprogramming. See
// `AsDuration()`. All four overloads simply call `AsDuration()`.
template std::enable_if_t, absl::optional> As() & {
  return AsDuration();
}
template std::enable_if_t, absl::optional> As() const& {
  return AsDuration();
}
template std::enable_if_t, absl::optional> As() && {
  return AsDuration();
}
template std::enable_if_t, absl::optional> As() const&& {
  return AsDuration();
}

// Convenience method for use with template metaprogramming. See
// `AsError()`.
// The lvalue overloads return an `optional_ref` that borrows from `*this`;
// the rvalue overloads call `AsError()` on `std::move(*this)` so the value
// category is preserved and an owning `absl::optional` results.
template std::enable_if_t, optional_ref> As() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsError();
}
template std::enable_if_t, optional_ref> As() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsError();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsError();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsError();
}

// Convenience method for use with template metaprogramming. See
// `AsInt()`. All four overloads simply call `AsInt()`.
template std::enable_if_t, absl::optional> As() & {
  return AsInt();
}
template std::enable_if_t, absl::optional> As() const& {
  return AsInt();
}
template std::enable_if_t, absl::optional> As() && {
  return AsInt();
}
template std::enable_if_t, absl::optional> As() const&& {
  return AsInt();
}

// Convenience method for use with template metaprogramming. See
// `AsList()`. Every overload returns an owning `absl::optional`; the rvalue
// overloads call `AsList()` on `std::move(*this)` to keep the value category.
template std::enable_if_t, absl::optional> As() & {
  return AsList();
}
template std::enable_if_t, absl::optional> As() const& {
  return AsList();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsList();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsList();
}

// Convenience method for use with template metaprogramming. See
// `AsMap()`. Every overload returns an owning `absl::optional`; the rvalue
// overloads call `AsMap()` on `std::move(*this)` to keep the value category.
template std::enable_if_t, absl::optional> As() & {
  return AsMap();
}
template std::enable_if_t, absl::optional> As() const& {
  return AsMap();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsMap();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsMap();
}

// Convenience method for use with template metaprogramming. See
// `AsMessage()`.
// Every overload returns an owning `absl::optional`; the rvalue overloads
// call `AsMessage()` on `std::move(*this)` to keep the value category.
template std::enable_if_t, absl::optional> As() & {
  return AsMessage();
}
template std::enable_if_t, absl::optional> As() const& {
  return AsMessage();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsMessage();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsMessage();
}

// Convenience method for use with template metaprogramming. See
// `AsNull()`. All four overloads simply call `AsNull()`.
template std::enable_if_t, absl::optional> As() & {
  return AsNull();
}
template std::enable_if_t, absl::optional> As() const& {
  return AsNull();
}
template std::enable_if_t, absl::optional> As() && {
  return AsNull();
}
template std::enable_if_t, absl::optional> As() const&& {
  return AsNull();
}

// Convenience method for use with template metaprogramming. See
// `AsOpaque()`. The lvalue overloads return an `optional_ref` that borrows
// from `*this`; the rvalue overloads call `AsOpaque()` on `std::move(*this)`
// so the value category is preserved and an owning `absl::optional` results.
template std::enable_if_t, optional_ref> As() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsOpaque();
}
template std::enable_if_t, optional_ref> As() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsOpaque();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsOpaque();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsOpaque();
}

// Convenience method for use with template metaprogramming. See
// `AsOptional()`. The lvalue overloads return an `optional_ref` that borrows
// from `*this`; the rvalue overloads call `AsOptional()` on
// `std::move(*this)` and return an owning `absl::optional`.
template std::enable_if_t, optional_ref> As() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsOptional();
}
template std::enable_if_t, optional_ref> As() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsOptional();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsOptional();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsOptional();
}

// Convenience method for use with template metaprogramming. See
// `AsParsedJsonList()`.
// The lvalue overloads return an `optional_ref` that borrows from `*this`;
// the rvalue overloads call `AsParsedJsonList()` on `std::move(*this)` so the
// value category is preserved and an owning `absl::optional` results.
template std::enable_if_t, optional_ref> As() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsParsedJsonList();
}
template std::enable_if_t, optional_ref> As() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsParsedJsonList();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsParsedJsonList();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsParsedJsonList();
}

// Convenience method for use with template metaprogramming. See
// `AsParsedJsonMap()`. Same borrowed/owning split as the
// `AsParsedJsonList()` overloads above.
template std::enable_if_t, optional_ref> As() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsParsedJsonMap();
}
template std::enable_if_t, optional_ref> As() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsParsedJsonMap();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsParsedJsonMap();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsParsedJsonMap();
}

// Convenience method for use with template metaprogramming. See
// `AsCustomList()`. Same borrowed/owning split as the
// `AsParsedJsonList()` overloads above.
template std::enable_if_t, optional_ref> As() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsCustomList();
}
template std::enable_if_t, optional_ref> As() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsCustomList();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsCustomList();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsCustomList();
}

// Convenience method for use with template metaprogramming. See
// `AsCustomMap()`. Same borrowed/owning split as the
// `AsParsedJsonList()` overloads above.
template std::enable_if_t, optional_ref> As() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsCustomMap();
}
template std::enable_if_t, optional_ref> As() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsCustomMap();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsCustomMap();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsCustomMap();
}

// Convenience method for use with template metaprogramming.
// See `AsParsedMapField()`. The lvalue overloads return an `optional_ref`
// that borrows from `*this`; the rvalue overloads call `AsParsedMapField()`
// on `std::move(*this)` and return an owning `absl::optional`.
template std::enable_if_t, optional_ref> As() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsParsedMapField();
}
template std::enable_if_t, optional_ref> As() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsParsedMapField();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsParsedMapField();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsParsedMapField();
}

// Convenience method for use with template metaprogramming. See
// `AsParsedMessage()`. Same borrowed/owning split as the
// `AsParsedMapField()` overloads above.
template std::enable_if_t, optional_ref> As() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsParsedMessage();
}
template std::enable_if_t, optional_ref> As() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsParsedMessage();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsParsedMessage();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsParsedMessage();
}

// Convenience method for use with template metaprogramming. See
// `AsParsedRepeatedField()`. Same borrowed/owning split as the
// `AsParsedMapField()` overloads above.
template std::enable_if_t, optional_ref> As() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsParsedRepeatedField();
}
template std::enable_if_t, optional_ref> As() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsParsedRepeatedField();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsParsedRepeatedField();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsParsedRepeatedField();
}

// Convenience method for use with template metaprogramming. See
// `AsCustomStruct()`.
// The lvalue overloads return an `optional_ref` that borrows from `*this`;
// the rvalue overloads call `AsCustomStruct()` on `std::move(*this)` so the
// value category is preserved and an owning `absl::optional` results.
template std::enable_if_t, optional_ref> As() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsCustomStruct();
}
template std::enable_if_t, optional_ref> As() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsCustomStruct();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsCustomStruct();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsCustomStruct();
}

// Convenience method for use with template metaprogramming. See
// `AsString()`. Same borrowed/owning split as the `AsCustomStruct()`
// overloads above.
template std::enable_if_t, optional_ref> As() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsString();
}
template std::enable_if_t, optional_ref> As() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsString();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsString();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsString();
}

// Convenience method for use with template metaprogramming. See
// `AsStruct()`. Every overload returns an owning `absl::optional`; the
// rvalue overloads call `AsStruct()` on `std::move(*this)`.
template std::enable_if_t, absl::optional> As() & {
  return AsStruct();
}
template std::enable_if_t, absl::optional> As() const& {
  return AsStruct();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsStruct();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsStruct();
}

// Convenience method for use with template metaprogramming. See
// `AsTimestamp()`. All four overloads simply call `AsTimestamp()`.
template std::enable_if_t, absl::optional> As() & {
  return AsTimestamp();
}
template std::enable_if_t, absl::optional> As() const& {
  return AsTimestamp();
}
template std::enable_if_t, absl::optional> As() && {
  return AsTimestamp();
}
template std::enable_if_t, absl::optional> As() const&& {
  return AsTimestamp();
}

// Convenience method for use with template metaprogramming. See
// `AsType()`.
// The lvalue overloads return an `optional_ref` that borrows from `*this`;
// the rvalue overloads call `AsType()` on `std::move(*this)` so the value
// category is preserved and an owning `absl::optional` results.
template std::enable_if_t, optional_ref> As() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsType();
}
template std::enable_if_t, optional_ref> As() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsType();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsType();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsType();
}

// Convenience method for use with template metaprogramming. See
// `AsUint()`. All four overloads simply call `AsUint()`.
template std::enable_if_t, absl::optional> As() & {
  return AsUint();
}
template std::enable_if_t, absl::optional> As() const& {
  return AsUint();
}
template std::enable_if_t, absl::optional> As() && {
  return AsUint();
}
template std::enable_if_t, absl::optional> As() const&& {
  return AsUint();
}

// Convenience method for use with template metaprogramming. See
// `AsUnknown()`. Same borrowed/owning split as the `AsType()` overloads
// above.
template std::enable_if_t, optional_ref> As() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsUnknown();
}
template std::enable_if_t, optional_ref> As() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsUnknown();
}
template std::enable_if_t, absl::optional> As() && {
  return std::move(*this).AsUnknown();
}
template std::enable_if_t, absl::optional> As() const&& {
  return std::move(*this).AsUnknown();
}

// Performs an unchecked cast from a value to a bool value. In
// debug builds a best effort is made to crash. If `IsBool()` would return
// false, calling this method is undefined behavior. The debug check streams
// `*this` into its failure message so the offending value is reported.
//
// NOTE(review): the template argument of `variant_.Get()` appears to be
// missing from this copy; confirm against the upstream header.
BoolValue GetBool() const {
  ABSL_DCHECK(IsBool()) << *this;
  return variant_.Get();
}

// Performs an unchecked cast from a value to a bytes value. In
// debug builds a best effort is made to crash. If `IsBytes()` would return
// false, calling this method is undefined behavior.
// Overloads: `&` forwards to `const&` via `std::as_const(*this)`; the
// `const&` reference borrows from `*this` (ABSL_ATTRIBUTE_LIFETIME_BOUND);
// `&&` returns by value; `const&&` returns a copy of the `const&` result
// (inside its body `*this` is an lvalue, so `GetBytes()` binds to `const&`).
const BytesValue& GetBytes() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).GetBytes();
}
const BytesValue& GetBytes() const& ABSL_ATTRIBUTE_LIFETIME_BOUND;
BytesValue GetBytes() &&;
BytesValue GetBytes() const&& { return GetBytes(); }

// Performs an unchecked cast from a value to a double value. In
// debug builds a best effort is made to crash. If `IsDouble()` would return
// false, calling this method is undefined behavior.
DoubleValue GetDouble() const;

// Performs an unchecked cast from a value to a duration value. In
// debug builds a best effort is made to crash. If `IsDuration()` would return
// false, calling this method is undefined behavior.
DurationValue GetDuration() const;

// Performs an unchecked cast from a value to an error value. In
// debug builds a best effort is made to crash. If `IsError()` would return
// false, calling this method is undefined behavior.
// Same overload pattern as `GetBytes()`.
const ErrorValue& GetError() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).GetError();
}
const ErrorValue& GetError() const& ABSL_ATTRIBUTE_LIFETIME_BOUND;
ErrorValue GetError() &&;
ErrorValue GetError() const&& { return GetError(); }

// Performs an unchecked cast from a value to an int value. In
// debug builds a best effort is made to crash. If `IsInt()` would return
// false, calling this method is undefined behavior.
IntValue GetInt() const;

// Performs an unchecked cast from a value to a list value. In
// debug builds a best effort is made to crash. If `IsList()` would return
// false, calling this method is undefined behavior.
// Every overload returns `ListValue` by value, so none is lifetime-bound;
// `&` forwards to `const&`, and `const&&` reuses the `const&` overload.
ListValue GetList() & { return std::as_const(*this).GetList(); }
ListValue GetList() const&;
ListValue GetList() &&;
ListValue GetList() const&& { return GetList(); }

// Performs an unchecked cast from a value to a map value. In
// debug builds a best effort is made to crash. If `IsMap()` would return
// false, calling this method is undefined behavior.
// Every overload returns `MapValue` by value, so none is lifetime-bound;
// `&` forwards to `const&`, and `const&&` reuses the `const&` overload
// (inside its body `*this` is an lvalue).
MapValue GetMap() & { return std::as_const(*this).GetMap(); }
MapValue GetMap() const&;
MapValue GetMap() &&;
MapValue GetMap() const&& { return GetMap(); }

// Performs an unchecked cast from a value to a message value. In
// debug builds a best effort is made to crash. If `IsMessage()` would return
// false, calling this method is undefined behavior.
// Same by-value overload pattern as `GetMap()`.
MessageValue GetMessage() & { return std::as_const(*this).GetMessage(); }
MessageValue GetMessage() const&;
MessageValue GetMessage() &&;
MessageValue GetMessage() const&& { return GetMessage(); }

// Performs an unchecked cast from a value to a null value. In
// debug builds a best effort is made to crash. If `IsNull()` would return
// false, calling this method is undefined behavior.
NullValue GetNull() const;

// Performs an unchecked cast from a value to an opaque value. In
// debug builds a best effort is made to crash. If `IsOpaque()` would return
// false, calling this method is undefined behavior.
// The `&`/`const&` overloads return a reference that borrows from `*this`
// (ABSL_ATTRIBUTE_LIFETIME_BOUND); the rvalue overloads return by value.
const OpaqueValue& GetOpaque() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).GetOpaque();
}
const OpaqueValue& GetOpaque() const& ABSL_ATTRIBUTE_LIFETIME_BOUND;
OpaqueValue GetOpaque() &&;
OpaqueValue GetOpaque() const&& { return GetOpaque(); }

// Performs an unchecked cast from a value to an optional value. In
// debug builds a best effort is made to crash. If `IsOptional()` would return
// false, calling this method is undefined behavior.
// Same borrowed/by-value split as `GetOpaque()`.
const OptionalValue& GetOptional() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).GetOptional();
}
const OptionalValue& GetOptional() const& ABSL_ATTRIBUTE_LIFETIME_BOUND;
OptionalValue GetOptional() &&;
OptionalValue GetOptional() const&& { return GetOptional(); }

// Performs an unchecked cast from a value to a parsed JSON list value. In
// debug builds a best effort is made to crash. If `IsParsedJsonList()` would
// return false, calling this method is undefined behavior.
// Overloads: `&` forwards to `const&` via `std::as_const(*this)`; the
// `const&` reference borrows from `*this` (ABSL_ATTRIBUTE_LIFETIME_BOUND);
// `&&` returns by value; `const&&` returns a copy of the `const&` result.
const ParsedJsonListValue& GetParsedJsonList() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).GetParsedJsonList();
}
const ParsedJsonListValue& GetParsedJsonList() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND;
ParsedJsonListValue GetParsedJsonList() &&;
ParsedJsonListValue GetParsedJsonList() const&& {
  return GetParsedJsonList();
}

// Performs an unchecked cast from a value to a parsed JSON map value. In
// debug builds a best effort is made to crash. If `IsParsedJsonMap()` would
// return false, calling this method is undefined behavior.
// Same overload pattern as `GetParsedJsonList()`.
const ParsedJsonMapValue& GetParsedJsonMap() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).GetParsedJsonMap();
}
const ParsedJsonMapValue& GetParsedJsonMap() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND;
ParsedJsonMapValue GetParsedJsonMap() &&;
ParsedJsonMapValue GetParsedJsonMap() const&& { return GetParsedJsonMap(); }

// Performs an unchecked cast from a value to a custom list value. In
// debug builds a best effort is made to crash. If `IsCustomList()` would
// return false, calling this method is undefined behavior.
// Same overload pattern as `GetParsedJsonList()`.
const CustomListValue& GetCustomList() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).GetCustomList();
}
const CustomListValue& GetCustomList() const& ABSL_ATTRIBUTE_LIFETIME_BOUND;
CustomListValue GetCustomList() &&;
CustomListValue GetCustomList() const&& { return GetCustomList(); }

// Performs an unchecked cast from a value to a custom map value. In
// debug builds a best effort is made to crash. If `IsCustomMap()` would
// return false, calling this method is undefined behavior.
// Same overload pattern as `GetParsedJsonList()`.
const CustomMapValue& GetCustomMap() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).GetCustomMap();
}
const CustomMapValue& GetCustomMap() const& ABSL_ATTRIBUTE_LIFETIME_BOUND;
CustomMapValue GetCustomMap() &&;
CustomMapValue GetCustomMap() const&& { return GetCustomMap(); }

// Performs an unchecked cast from a value to a parsed map field value.
// In debug builds a best effort is made to crash. If `IsParsedMapField()`
// would return false, calling this method is undefined behavior.
// Overloads: `&` forwards to `const&` via `std::as_const(*this)`; the
// `const&` reference borrows from `*this` (ABSL_ATTRIBUTE_LIFETIME_BOUND);
// `&&` returns by value; `const&&` returns a copy of the `const&` result.
const ParsedMapFieldValue& GetParsedMapField() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).GetParsedMapField();
}
const ParsedMapFieldValue& GetParsedMapField() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND;
ParsedMapFieldValue GetParsedMapField() &&;
ParsedMapFieldValue GetParsedMapField() const&& {
  return GetParsedMapField();
}

// Performs an unchecked cast from a value to a parsed message value. In
// debug builds a best effort is made to crash. If `IsParsedMessage()` would
// return false, calling this method is undefined behavior.
// Same overload pattern as `GetParsedMapField()`.
const ParsedMessageValue& GetParsedMessage() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).GetParsedMessage();
}
const ParsedMessageValue& GetParsedMessage() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND;
ParsedMessageValue GetParsedMessage() &&;
ParsedMessageValue GetParsedMessage() const&& { return GetParsedMessage(); }

// Performs an unchecked cast from a value to a parsed repeated field value.
// In debug builds a best effort is made to crash. If
// `IsParsedRepeatedField()` would return false, calling this method is
// undefined behavior. Same overload pattern as `GetParsedMapField()`.
const ParsedRepeatedFieldValue& GetParsedRepeatedField() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).GetParsedRepeatedField();
}
const ParsedRepeatedFieldValue& GetParsedRepeatedField() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND;
ParsedRepeatedFieldValue GetParsedRepeatedField() &&;
ParsedRepeatedFieldValue GetParsedRepeatedField() const&& {
  return GetParsedRepeatedField();
}

// Performs an unchecked cast from a value to a custom struct value. In
// debug builds a best effort is made to crash. If `IsCustomStruct()` would
// return false, calling this method is undefined behavior.
// Overloads: `&` forwards to `const&` via `std::as_const(*this)`; the
// `const&` reference borrows from `*this` (ABSL_ATTRIBUTE_LIFETIME_BOUND);
// `&&` returns by value; `const&&` returns a copy of the `const&` result.
const CustomStructValue& GetCustomStruct() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).GetCustomStruct();
}
const CustomStructValue& GetCustomStruct() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND;
CustomStructValue GetCustomStruct() &&;
CustomStructValue GetCustomStruct() const&& { return GetCustomStruct(); }

// Performs an unchecked cast from a value to a string value. In
// debug builds a best effort is made to crash. If `IsString()` would return
// false, calling this method is undefined behavior.
// Same overload pattern as `GetCustomStruct()`.
const StringValue& GetString() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).GetString();
}
const StringValue& GetString() const& ABSL_ATTRIBUTE_LIFETIME_BOUND;
StringValue GetString() &&;
StringValue GetString() const&& { return GetString(); }

// Performs an unchecked cast from a value to a struct value. In
// debug builds a best effort is made to crash. If `IsStruct()` would return
// false, calling this method is undefined behavior.
// Every overload returns `StructValue` by value, so none is lifetime-bound.
StructValue GetStruct() & { return std::as_const(*this).GetStruct(); }
StructValue GetStruct() const&;
StructValue GetStruct() &&;
StructValue GetStruct() const&& { return GetStruct(); }

// Performs an unchecked cast from a value to a timestamp value. In
// debug builds a best effort is made to crash. If `IsTimestamp()` would
// return false, calling this method is undefined behavior.
TimestampValue GetTimestamp() const;

// Performs an unchecked cast from a value to a type value. In
// debug builds a best effort is made to crash. If `IsType()` would return
// false, calling this method is undefined behavior.
// Same overload pattern as `GetCustomStruct()`.
const TypeValue& GetType() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).GetType();
}
const TypeValue& GetType() const& ABSL_ATTRIBUTE_LIFETIME_BOUND;
TypeValue GetType() &&;
TypeValue GetType() const&& { return GetType(); }

// Performs an unchecked cast from a value to a uint value. In
// debug builds a best effort is made to crash.
// If `IsUint()` would return false, calling this method is undefined
// behavior.
UintValue GetUint() const;

// Performs an unchecked cast from a value to an unknown value. In
// debug builds a best effort is made to crash. If `IsUnknown()` would return
// false, calling this method is undefined behavior.
// Overloads: `&` forwards to `const&` via `std::as_const(*this)`; the
// `const&` reference borrows from `*this` (ABSL_ATTRIBUTE_LIFETIME_BOUND);
// `&&` returns by value; `const&&` returns a copy of the `const&` result.
const UnknownValue& GetUnknown() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).GetUnknown();
}
const UnknownValue& GetUnknown() const& ABSL_ATTRIBUTE_LIFETIME_BOUND;
UnknownValue GetUnknown() &&;
UnknownValue GetUnknown() const&& { return GetUnknown(); }

// Convenience method for use with template metaprogramming. See
// `GetBool()`. All four overloads simply call `GetBool()`.
//
// NOTE(review): the `std::enable_if_t` conditions of the Get<T>() overloads
// appear to be missing from this copy; presumably each overload set is
// enabled only for its matching value type. Confirm against upstream.
template std::enable_if_t, BoolValue> Get() & { return GetBool(); }
template std::enable_if_t, BoolValue> Get() const& {
  return GetBool();
}
template std::enable_if_t, BoolValue> Get() && { return GetBool(); }
template std::enable_if_t, BoolValue> Get() const&& {
  return GetBool();
}

// Convenience method for use with template metaprogramming. See
// `GetBytes()`. The lvalue overloads return a `const` reference that borrows
// from `*this`; the rvalue overloads call `GetBytes()` on `std::move(*this)`
// and return by value.
template std::enable_if_t, const BytesValue&> Get() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetBytes();
}
template std::enable_if_t, const BytesValue&> Get() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetBytes();
}
template std::enable_if_t, BytesValue> Get() && {
  return std::move(*this).GetBytes();
}
template std::enable_if_t, BytesValue> Get() const&& {
  return std::move(*this).GetBytes();
}

// Convenience method for use with template metaprogramming. See
// `GetDouble()`. All four overloads simply call `GetDouble()`.
template std::enable_if_t, DoubleValue> Get() & {
  return GetDouble();
}
template std::enable_if_t, DoubleValue> Get() const& {
  return GetDouble();
}
template std::enable_if_t, DoubleValue> Get() && {
  return GetDouble();
}
template std::enable_if_t, DoubleValue> Get() const&& {
  return GetDouble();
}

// Convenience method for use with template metaprogramming. See
// `GetDuration()`.
// All four overloads simply call `GetDuration()`.
template std::enable_if_t, DurationValue> Get() & {
  return GetDuration();
}
template std::enable_if_t, DurationValue> Get() const& {
  return GetDuration();
}
template std::enable_if_t, DurationValue> Get() && {
  return GetDuration();
}
template std::enable_if_t, DurationValue> Get() const&& {
  return GetDuration();
}

// Convenience method for use with template metaprogramming. See
// `GetError()`. The lvalue overloads return a `const` reference that borrows
// from `*this`; the rvalue overloads call `GetError()` on `std::move(*this)`
// and return by value.
template std::enable_if_t, const ErrorValue&> Get() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetError();
}
template std::enable_if_t, const ErrorValue&> Get() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetError();
}
template std::enable_if_t, ErrorValue> Get() && {
  return std::move(*this).GetError();
}
template std::enable_if_t, ErrorValue> Get() const&& {
  return std::move(*this).GetError();
}

// Convenience method for use with template metaprogramming. See
// `GetInt()`. All four overloads simply call `GetInt()`.
template std::enable_if_t, IntValue> Get() & { return GetInt(); }
template std::enable_if_t, IntValue> Get() const& { return GetInt(); }
template std::enable_if_t, IntValue> Get() && { return GetInt(); }
template std::enable_if_t, IntValue> Get() const&& { return GetInt(); }

// Convenience method for use with template metaprogramming. See
// `GetList()`. Every overload returns `ListValue` by value; the rvalue
// overloads call `GetList()` on `std::move(*this)` to keep the value category.
template std::enable_if_t, ListValue> Get() & { return GetList(); }
template std::enable_if_t, ListValue> Get() const& {
  return GetList();
}
template std::enable_if_t, ListValue> Get() && {
  return std::move(*this).GetList();
}
template std::enable_if_t, ListValue> Get() const&& {
  return std::move(*this).GetList();
}

// Convenience method for use with template metaprogramming. See
// `GetMap()`.
// Every overload returns `MapValue` by value; the rvalue overloads call
// `GetMap()` on `std::move(*this)` to keep the value category.
template std::enable_if_t, MapValue> Get() & { return GetMap(); }
template std::enable_if_t, MapValue> Get() const& { return GetMap(); }
template std::enable_if_t, MapValue> Get() && {
  return std::move(*this).GetMap();
}
template std::enable_if_t, MapValue> Get() const&& {
  return std::move(*this).GetMap();
}

// Convenience method for use with template metaprogramming. See
// `GetMessage()`. Every overload returns `MessageValue` by value; the rvalue
// overloads call `GetMessage()` on `std::move(*this)`.
template std::enable_if_t, MessageValue> Get() & {
  return GetMessage();
}
template std::enable_if_t, MessageValue> Get() const& {
  return GetMessage();
}
template std::enable_if_t, MessageValue> Get() && {
  return std::move(*this).GetMessage();
}
template std::enable_if_t, MessageValue> Get() const&& {
  return std::move(*this).GetMessage();
}

// Convenience method for use with template metaprogramming. See
// `GetNull()`. All four overloads simply call `GetNull()`.
template std::enable_if_t, NullValue> Get() & { return GetNull(); }
template std::enable_if_t, NullValue> Get() const& {
  return GetNull();
}
template std::enable_if_t, NullValue> Get() && { return GetNull(); }
template std::enable_if_t, NullValue> Get() const&& {
  return GetNull();
}

// Convenience method for use with template metaprogramming. See
// `GetOpaque()`. The lvalue overloads return a `const` reference that
// borrows from `*this`; the rvalue overloads call `GetOpaque()` on
// `std::move(*this)` and return by value.
template std::enable_if_t, const OpaqueValue&> Get() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetOpaque();
}
template std::enable_if_t, const OpaqueValue&> Get() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetOpaque();
}
template std::enable_if_t, OpaqueValue> Get() && {
  return std::move(*this).GetOpaque();
}
template std::enable_if_t, OpaqueValue> Get() const&& {
  return std::move(*this).GetOpaque();
}

// Convenience method for use with template metaprogramming. See
// `GetOptional()`.
// The lvalue overloads return a `const` reference that borrows from `*this`;
// the rvalue overloads call `GetOptional()` on `std::move(*this)` and return
// by value.
template std::enable_if_t, const OptionalValue&> Get() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetOptional();
}
template std::enable_if_t, const OptionalValue&> Get() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetOptional();
}
template std::enable_if_t, OptionalValue> Get() && {
  return std::move(*this).GetOptional();
}
template std::enable_if_t, OptionalValue> Get() const&& {
  return std::move(*this).GetOptional();
}

// Convenience method for use with template metaprogramming. See
// `GetParsedJsonList()`. Same borrowed/by-value split as the
// `GetOptional()` overloads above.
template std::enable_if_t, const ParsedJsonListValue&> Get() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetParsedJsonList();
}
template std::enable_if_t, const ParsedJsonListValue&> Get() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetParsedJsonList();
}
template std::enable_if_t, ParsedJsonListValue> Get() && {
  return std::move(*this).GetParsedJsonList();
}
template std::enable_if_t, ParsedJsonListValue> Get() const&& {
  return std::move(*this).GetParsedJsonList();
}

// Convenience method for use with template metaprogramming. See
// `GetParsedJsonMap()`. Same borrowed/by-value split as the
// `GetOptional()` overloads above.
template std::enable_if_t, const ParsedJsonMapValue&> Get() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetParsedJsonMap();
}
template std::enable_if_t, const ParsedJsonMapValue&> Get() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetParsedJsonMap();
}
template std::enable_if_t, ParsedJsonMapValue> Get() && {
  return std::move(*this).GetParsedJsonMap();
}
template std::enable_if_t, ParsedJsonMapValue> Get() const&& {
  return std::move(*this).GetParsedJsonMap();
}

// Convenience method for use with template metaprogramming. See
// `GetCustomList()`.
// The lvalue overloads return a `const` reference that borrows from `*this`;
// the rvalue overloads call `GetCustomList()` on `std::move(*this)` and
// return by value.
template std::enable_if_t, const CustomListValue&> Get() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetCustomList();
}
template std::enable_if_t, const CustomListValue&> Get() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetCustomList();
}
template std::enable_if_t, CustomListValue> Get() && {
  return std::move(*this).GetCustomList();
}
template std::enable_if_t, CustomListValue> Get() const&& {
  return std::move(*this).GetCustomList();
}

// Convenience method for use with template metaprogramming. See
// `GetCustomMap()`. Same borrowed/by-value split as the `GetCustomList()`
// overloads above.
template std::enable_if_t, const CustomMapValue&> Get() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetCustomMap();
}
template std::enable_if_t, const CustomMapValue&> Get() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetCustomMap();
}
template std::enable_if_t, CustomMapValue> Get() && {
  return std::move(*this).GetCustomMap();
}
template std::enable_if_t, CustomMapValue> Get() const&& {
  return std::move(*this).GetCustomMap();
}

// Convenience method for use with template metaprogramming. See
// `GetParsedMapField()`. Same borrowed/by-value split as the
// `GetCustomList()` overloads above.
template std::enable_if_t, const ParsedMapFieldValue&> Get() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetParsedMapField();
}
template std::enable_if_t, const ParsedMapFieldValue&> Get() const&
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetParsedMapField();
}
template std::enable_if_t, ParsedMapFieldValue> Get() && {
  return std::move(*this).GetParsedMapField();
}
template std::enable_if_t, ParsedMapFieldValue> Get() const&& {
  return std::move(*this).GetParsedMapField();
}

// Convenience method for use with template metaprogramming. See
// `GetParsedMessage()`.
template std::enable_if_t, const ParsedMessageValue&> Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetParsedMessage(); } template std::enable_if_t, const ParsedMessageValue&> Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetParsedMessage(); } template std::enable_if_t, ParsedMessageValue> Get() && { return std::move(*this).GetParsedMessage(); } template std::enable_if_t, ParsedMessageValue> Get() const&& { return std::move(*this).GetParsedMessage(); } // Convenience method for use with template metaprogramming. See // `GetParsedRepeatedField()`. template std::enable_if_t, const ParsedRepeatedFieldValue&> Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetParsedRepeatedField(); } template std::enable_if_t, const ParsedRepeatedFieldValue&> Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetParsedRepeatedField(); } template std::enable_if_t, ParsedRepeatedFieldValue> Get() && { return std::move(*this).GetParsedRepeatedField(); } template std::enable_if_t, ParsedRepeatedFieldValue> Get() const&& { return std::move(*this).GetParsedRepeatedField(); } // Convenience method for use with template metaprogramming. See // `GetCustomStruct()`. template std::enable_if_t, const CustomStructValue&> Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetCustomStruct(); } template std::enable_if_t, const CustomStructValue&> Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetCustomStruct(); } template std::enable_if_t, CustomStructValue> Get() && { return std::move(*this).GetCustomStruct(); } template std::enable_if_t, CustomStructValue> Get() const&& { return std::move(*this).GetCustomStruct(); } // Convenience method for use with template metaprogramming. See // `GetString()`. 
template std::enable_if_t, const StringValue&> Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetString(); } template std::enable_if_t, const StringValue&> Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetString(); } template std::enable_if_t, StringValue> Get() && { return std::move(*this).GetString(); } template std::enable_if_t, StringValue> Get() const&& { return std::move(*this).GetString(); } // Convenience method for use with template metaprogramming. See // `GetStruct()`. template std::enable_if_t, StructValue> Get() & { return GetStruct(); } template std::enable_if_t, StructValue> Get() const& { return GetStruct(); } template std::enable_if_t, StructValue> Get() && { return std::move(*this).GetStruct(); } template std::enable_if_t, StructValue> Get() const&& { return std::move(*this).GetStruct(); } // Convenience method for use with template metaprogramming. See // `GetTimestamp()`. template std::enable_if_t, TimestampValue> Get() & { return GetTimestamp(); } template std::enable_if_t, TimestampValue> Get() const& { return GetTimestamp(); } template std::enable_if_t, TimestampValue> Get() && { return GetTimestamp(); } template std::enable_if_t, TimestampValue> Get() const&& { return GetTimestamp(); } // Convenience method for use with template metaprogramming. See // `GetType()`. template std::enable_if_t, const TypeValue&> Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetType(); } template std::enable_if_t, const TypeValue&> Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetType(); } template std::enable_if_t, TypeValue> Get() && { return std::move(*this).GetType(); } template std::enable_if_t, TypeValue> Get() const&& { return std::move(*this).GetType(); } // Convenience method for use with template metaprogramming. See // `GetUint()`. 
template std::enable_if_t, UintValue> Get() & { return GetUint(); } template std::enable_if_t, UintValue> Get() const& { return GetUint(); } template std::enable_if_t, UintValue> Get() && { return GetUint(); } template std::enable_if_t, UintValue> Get() const&& { return GetUint(); } // Convenience method for use with template metaprogramming. See // `GetUnknown()`. template std::enable_if_t, const UnknownValue&> Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetUnknown(); } template std::enable_if_t, const UnknownValue&> Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetUnknown(); } template std::enable_if_t, UnknownValue> Get() && { return std::move(*this).GetUnknown(); } template std::enable_if_t, UnknownValue> Get() const&& { return std::move(*this).GetUnknown(); } // When `Value` is default constructed, it is in a valid but undefined state. // Any attempt to use it invokes undefined behavior. This mention can be used // to test whether this value is valid. explicit operator bool() const { return true; } private: friend struct NativeTypeTraits; friend bool common_internal::IsLegacyListValue(const Value& value); friend common_internal::LegacyListValue common_internal::GetLegacyListValue( const Value& value); friend bool common_internal::IsLegacyMapValue(const Value& value); friend common_internal::LegacyMapValue common_internal::GetLegacyMapValue( const Value& value); friend bool common_internal::IsLegacyStructValue(const Value& value); friend common_internal::LegacyStructValue common_internal::GetLegacyStructValue(const Value& value); friend class common_internal::ValueMixin; friend struct ArenaTraits; common_internal::ValueVariant variant_; }; // Overloads for heterogeneous equality of numeric values. 
bool operator==(IntValue lhs, UintValue rhs); bool operator==(UintValue lhs, IntValue rhs); bool operator==(IntValue lhs, DoubleValue rhs); bool operator==(DoubleValue lhs, IntValue rhs); bool operator==(UintValue lhs, DoubleValue rhs); bool operator==(DoubleValue lhs, UintValue rhs); inline bool operator!=(IntValue lhs, UintValue rhs) { return !operator==(lhs, rhs); } inline bool operator!=(UintValue lhs, IntValue rhs) { return !operator==(lhs, rhs); } inline bool operator!=(IntValue lhs, DoubleValue rhs) { return !operator==(lhs, rhs); } inline bool operator!=(DoubleValue lhs, IntValue rhs) { return !operator==(lhs, rhs); } inline bool operator!=(UintValue lhs, DoubleValue rhs) { return !operator==(lhs, rhs); } inline bool operator!=(DoubleValue lhs, UintValue rhs) { return !operator==(lhs, rhs); } template <> struct NativeTypeTraits final { static NativeTypeId Id(const Value& value) { return value.variant_.Visit([](const auto& alternative) -> NativeTypeId { return NativeTypeId::Of(alternative); }); } }; template <> struct ArenaTraits { static bool trivially_destructible(const Value& value) { return value.variant_.Visit([](const auto& alternative) -> bool { return ArenaTraits<>::trivially_destructible(alternative); }); } }; // Statically assert some expectations. 
static_assert(sizeof(Value) <= 32); static_assert(alignof(Value) <= alignof(std::max_align_t)); static_assert(std::is_default_constructible_v); static_assert(std::is_copy_constructible_v); static_assert(std::is_copy_assignable_v); static_assert(std::is_nothrow_move_constructible_v); static_assert(std::is_nothrow_move_assignable_v); static_assert(std::is_nothrow_swappable_v); inline common_internal::ImplicitlyConvertibleStatus ErrorValueAssign::operator()(absl::Status status) const { *value_ = ErrorValue(std::move(status)); return common_internal::ImplicitlyConvertibleStatus(); } namespace common_internal { template absl::StatusOr ValueMixin::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); Value result; CEL_RETURN_IF_ERROR(static_cast(this)->Equal( other, descriptor_pool, message_factory, arena, &result)); return result; } template absl::StatusOr ListValueMixin::Get( size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); Value result; CEL_RETURN_IF_ERROR(static_cast(this)->Get( index, descriptor_pool, message_factory, arena, &result)); return result; } template absl::StatusOr ListValueMixin::Contains( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); Value result; 
CEL_RETURN_IF_ERROR(static_cast(this)->Contains( other, descriptor_pool, message_factory, arena, &result)); return result; } template absl::StatusOr MapValueMixin::Get( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); Value result; CEL_RETURN_IF_ERROR(static_cast(this)->Get( key, descriptor_pool, message_factory, arena, &result)); return result; } template absl::StatusOr> MapValueMixin::Find( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); Value result; CEL_ASSIGN_OR_RETURN( bool found, static_cast(this)->Find( other, descriptor_pool, message_factory, arena, &result)); if (found) { return result; } return absl::nullopt; } template absl::StatusOr MapValueMixin::Has( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); Value result; CEL_RETURN_IF_ERROR(static_cast(this)->Has( key, descriptor_pool, message_factory, arena, &result)); return result; } template absl::StatusOr MapValueMixin::ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); 
ABSL_DCHECK(arena != nullptr); ListValue result; CEL_RETURN_IF_ERROR(static_cast(this)->ListKeys( descriptor_pool, message_factory, arena, &result)); return result; } template absl::StatusOr StructValueMixin::GetFieldByName( absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); Value result; CEL_RETURN_IF_ERROR(static_cast(this)->GetFieldByName( name, ProtoWrapperTypeOptions::kUnsetNull, descriptor_pool, message_factory, arena, &result)); return result; } template absl::StatusOr StructValueMixin::GetFieldByName( absl::string_view name, ProtoWrapperTypeOptions unboxing_options, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); Value result; CEL_RETURN_IF_ERROR(static_cast(this)->GetFieldByName( name, unboxing_options, descriptor_pool, message_factory, arena, &result)); return result; } template absl::StatusOr StructValueMixin::GetFieldByNumber( int64_t number, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); Value result; CEL_RETURN_IF_ERROR(static_cast(this)->GetFieldByNumber( number, ProtoWrapperTypeOptions::kUnsetNull, descriptor_pool, message_factory, arena, &result)); return result; } template absl::StatusOr StructValueMixin::GetFieldByNumber( int64_t number, ProtoWrapperTypeOptions unboxing_options, 
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); Value result; CEL_RETURN_IF_ERROR(static_cast(this)->GetFieldByNumber( number, unboxing_options, descriptor_pool, message_factory, arena, &result)); return result; } template absl::StatusOr> StructValueMixin::Qualify( absl::Span qualifiers, bool presence_test, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK_GT(qualifiers.size(), 0); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); Value result; int count; CEL_RETURN_IF_ERROR(static_cast(this)->Qualify( qualifiers, presence_test, descriptor_pool, message_factory, arena, &result, &count)); return std::pair{std::move(result), count}; } } // namespace common_internal using ValueIteratorPtr = std::unique_ptr; inline absl::StatusOr ValueIterator::Next( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); Value result; CEL_RETURN_IF_ERROR(Next(descriptor_pool, message_factory, arena, &result)); return result; } inline absl::StatusOr> ValueIterator::Next1( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); Value key_or_value; 
CEL_ASSIGN_OR_RETURN( bool ok, Next1(descriptor_pool, message_factory, arena, &key_or_value)); if (!ok) { return absl::nullopt; } return key_or_value; } inline absl::StatusOr>> ValueIterator::Next2(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); Value key; Value value; CEL_ASSIGN_OR_RETURN( bool ok, Next2(descriptor_pool, message_factory, arena, &key, &value)); if (!ok) { return absl::nullopt; } return std::pair{std::move(key), std::move(value)}; } absl_nonnull std::unique_ptr NewEmptyValueIterator(); class ValueBuilder { public: virtual ~ValueBuilder() = default; virtual absl::StatusOr> SetFieldByName( absl::string_view name, Value value) = 0; virtual absl::StatusOr> SetFieldByNumber( int64_t number, Value value) = 0; virtual absl::StatusOr Build() && = 0; }; using ValueBuilderPtr = std::unique_ptr; absl_nonnull ListValueBuilderPtr NewListValueBuilder(google::protobuf::Arena* absl_nonnull arena); absl_nonnull MapValueBuilderPtr NewMapValueBuilder(google::protobuf::Arena* absl_nonnull arena); // Returns a new `StructValueBuilder`. Returns `nullptr` if there is no such // message type with the name `name` in `descriptor_pool`. Returns an error if // `message_factory` is unable to provide a prototype for the descriptor // returned from `descriptor_pool`. 
absl_nullable StructValueBuilderPtr NewStructValueBuilder(
    google::protobuf::Arena* absl_nonnull arena,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    absl::string_view name);

// Aliases that keep the `...Interface` spellings of the builder types
// available; each is the same type as its unsuffixed counterpart.
using ListValueBuilderInterface = ListValueBuilder;
using MapValueBuilderInterface = MapValueBuilder;
using StructValueBuilderInterface = StructValueBuilder;

// Now that Value is complete, we can define various parts of list, map, opaque,
// and struct which depend on Value.

namespace common_internal {

// Function-pointer type obtained from `MapFieldKeyAccessorFor()`. The trailing
// `Value*` is presumably the output slot for the converted map key — confirm
// against the implementation.
using MapFieldKeyAccessor = void (*)(const google::protobuf::MapKey&,
                                     const google::protobuf::Message* absl_nonnull,
                                     google::protobuf::Arena* absl_nonnull,
                                     Value* absl_nonnull);

// Looks up the key accessor to use for map field `field`; may fail with a
// non-OK status.
absl::StatusOr<MapFieldKeyAccessor> MapFieldKeyAccessorFor(
    const google::protobuf::FieldDescriptor* absl_nonnull field);

// Function-pointer type obtained from `MapFieldValueAccessorFor()`. The
// trailing `Value*` is presumably the output slot for the converted map
// value — confirm against the implementation.
using MapFieldValueAccessor = void (*)(
    const google::protobuf::MapValueConstRef&, const google::protobuf::Message* absl_nonnull,
    const google::protobuf::FieldDescriptor* absl_nonnull,
    const google::protobuf::DescriptorPool* absl_nonnull,
    google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull,
    Value* absl_nonnull);

// Looks up the value accessor to use for map field `field`; may fail with a
// non-OK status.
absl::StatusOr<MapFieldValueAccessor> MapFieldValueAccessorFor(
    const google::protobuf::FieldDescriptor* absl_nonnull field);

// Function-pointer type obtained from `RepeatedFieldAccessorFor()`. The
// leading `int` is presumably the element index and the trailing `Value*`
// the output slot — confirm against the implementation.
using RepeatedFieldAccessor =
    void (*)(int, const google::protobuf::Message* absl_nonnull,
             const google::protobuf::FieldDescriptor* absl_nonnull,
             const google::protobuf::Reflection* absl_nonnull,
             const google::protobuf::DescriptorPool* absl_nonnull,
             google::protobuf::MessageFactory* absl_nonnull,
             google::protobuf::Arena* absl_nonnull, Value* absl_nonnull);

// Looks up the element accessor to use for repeated field `field`; may fail
// with a non-OK status.
absl::StatusOr<RepeatedFieldAccessor> RepeatedFieldAccessorFor(
    const google::protobuf::FieldDescriptor* absl_nonnull field);

}  // namespace common_internal

}  // namespace cel

#pragma pop_macro("GetMessage")

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUE_H_

================================================
FILE: common/value_kind.h
================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUE_KIND_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUE_KIND_H_ #include #include #include "absl/base/macros.h" #include "absl/strings/string_view.h" #include "common/kind.h" namespace cel { // `ValueKind` is a subset of `Kind`, representing all valid `Kind` for `Value`. // All `ValueKind` are valid `Kind`, but it is not guaranteed that all `Kind` // are valid `ValueKind`. enum class ValueKind : std::underlying_type_t { kNull = static_cast(Kind::kNull), kBool = static_cast(Kind::kBool), kInt = static_cast(Kind::kInt), kUint = static_cast(Kind::kUint), kDouble = static_cast(Kind::kDouble), kString = static_cast(Kind::kString), kBytes = static_cast(Kind::kBytes), kStruct = static_cast(Kind::kStruct), kDuration = static_cast(Kind::kDuration), kTimestamp = static_cast(Kind::kTimestamp), kList = static_cast(Kind::kList), kMap = static_cast(Kind::kMap), kUnknown = static_cast(Kind::kUnknown), kType = static_cast(Kind::kType), kError = static_cast(Kind::kError), kOpaque = static_cast(Kind::kOpaque), // Legacy aliases, deprecated do not use. kNullType = kNull, kInt64 = kInt, kUint64 = kUint, kMessage = kStruct, kUnknownSet = kUnknown, kCelType = kType, // INTERNAL: Do not exceed 63. Implementation details rely on the fact that // we can store `Kind` using 6 bits. 
kNotForUseWithExhaustiveSwitchStatements = static_cast(Kind::kNotForUseWithExhaustiveSwitchStatements), }; constexpr Kind ValueKindToKind(ValueKind kind) { return static_cast( static_cast>(kind)); } constexpr bool KindIsValueKind(Kind kind) { return kind != Kind::kBoolWrapper && kind != Kind::kIntWrapper && kind != Kind::kUintWrapper && kind != Kind::kDoubleWrapper && kind != Kind::kStringWrapper && kind != Kind::kBytesWrapper && kind != Kind::kDyn && kind != Kind::kAny && kind != Kind::kTypeParam && kind != Kind::kFunction; } constexpr bool operator==(Kind lhs, ValueKind rhs) { return lhs == ValueKindToKind(rhs); } constexpr bool operator==(ValueKind lhs, Kind rhs) { return ValueKindToKind(lhs) == rhs; } constexpr bool operator!=(Kind lhs, ValueKind rhs) { return !operator==(lhs, rhs); } constexpr bool operator!=(ValueKind lhs, Kind rhs) { return !operator==(lhs, rhs); } inline absl::string_view ValueKindToString(ValueKind kind) { // All ValueKind are valid Kind. return KindToString(ValueKindToKind(kind)); } constexpr ValueKind KindToValueKind(Kind kind) { ABSL_ASSERT(KindIsValueKind(kind)); return static_cast( static_cast>(kind)); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUE_KIND_H_ ================================================ FILE: common/value_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/value.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/type.pb.h" #include "google/protobuf/descriptor.pb.h" #include "absl/base/attributes.h" #include "absl/log/die_if_null.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/types/optional.h" #include "common/type.h" #include "common/value_testing.h" #include "internal/parse_text_proto.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/generated_enum_reflection.h" namespace cel { namespace { using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::internal::DynamicParseTextProto; using ::cel::internal::GetTestingDescriptorPool; using ::cel::internal::GetTestingMessageFactory; using ::testing::An; using ::testing::Eq; using ::testing::NotNull; using ::testing::Optional; using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; TEST(Value, GeneratedEnum) { EXPECT_EQ(Value::Enum(google::protobuf::NULL_VALUE), NullValue()); EXPECT_EQ(Value::Enum(google::protobuf::SYNTAX_EDITIONS), IntValue(2)); } TEST(Value, DynamicEnum) { EXPECT_THAT( Value::Enum(google::protobuf::GetEnumDescriptor(), 0), test::IsNullValue()); EXPECT_THAT( Value::Enum(google::protobuf::GetEnumDescriptor() ->FindValueByNumber(0)), test::IsNullValue()); EXPECT_THAT( Value::Enum(google::protobuf::GetEnumDescriptor(), 2), test::IntValueIs(2)); EXPECT_THAT(Value::Enum(google::protobuf::GetEnumDescriptor() ->FindValueByNumber(2)), test::IntValueIs(2)); } TEST(Value, DynamicClosedEnum) { google::protobuf::FileDescriptorProto file_descriptor; file_descriptor.set_name("test/closed_enum.proto"); file_descriptor.set_package("test"); file_descriptor.set_syntax("editions"); 
file_descriptor.set_edition(google::protobuf::EDITION_2023); { auto* enum_descriptor = file_descriptor.add_enum_type(); enum_descriptor->set_name("ClosedEnum"); enum_descriptor->mutable_options()->mutable_features()->set_enum_type( google::protobuf::FeatureSet::CLOSED); auto* enum_value_descriptor = enum_descriptor->add_value(); enum_value_descriptor->set_number(1); enum_value_descriptor->set_name("FOO"); enum_value_descriptor = enum_descriptor->add_value(); enum_value_descriptor->set_number(2); enum_value_descriptor->set_name("BAR"); } google::protobuf::DescriptorPool pool; ASSERT_THAT(pool.BuildFile(file_descriptor), NotNull()); const auto* enum_descriptor = pool.FindEnumTypeByName("test.ClosedEnum"); ASSERT_THAT(enum_descriptor, NotNull()); EXPECT_THAT(Value::Enum(enum_descriptor, 0), test::ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))); } TEST(Value, Is) { google::protobuf::Arena arena; EXPECT_TRUE(Value(BoolValue()).Is()); EXPECT_TRUE(Value(BoolValue(true)).IsTrue()); EXPECT_TRUE(Value(BoolValue(false)).IsFalse()); EXPECT_TRUE(Value(BytesValue()).Is()); EXPECT_TRUE(Value(DoubleValue()).Is()); EXPECT_TRUE(Value(DurationValue()).Is()); EXPECT_TRUE(Value(ErrorValue()).Is()); EXPECT_TRUE(Value(IntValue()).Is()); EXPECT_TRUE(Value(ListValue()).Is()); EXPECT_TRUE(Value(CustomListValue()).Is()); EXPECT_TRUE(Value(CustomListValue()).Is()); EXPECT_TRUE(Value(ParsedJsonListValue()).Is()); EXPECT_TRUE(Value(ParsedJsonListValue()).Is()); { auto message = DynamicParseTextProto( &arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()); const auto* field = ABSL_DIE_IF_NULL( message->GetDescriptor()->FindFieldByName("repeated_int32")); EXPECT_TRUE(Value(ParsedRepeatedFieldValue(message, field, &arena)) .Is()); EXPECT_TRUE(Value(ParsedRepeatedFieldValue(message, field, &arena)) .Is()); } EXPECT_TRUE(Value(MapValue()).Is()); EXPECT_TRUE(Value(CustomMapValue()).Is()); EXPECT_TRUE(Value(CustomMapValue()).Is()); 
EXPECT_TRUE(Value(ParsedJsonMapValue()).Is()); EXPECT_TRUE(Value(ParsedJsonMapValue()).Is()); { auto message = DynamicParseTextProto( &arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()); const auto* field = ABSL_DIE_IF_NULL( message->GetDescriptor()->FindFieldByName("map_int32_int32")); EXPECT_TRUE( Value(ParsedMapFieldValue(message, field, &arena)).Is()); EXPECT_TRUE(Value(ParsedMapFieldValue(message, field, &arena)) .Is()); } EXPECT_TRUE(Value(NullValue()).Is()); EXPECT_TRUE(Value(OptionalValue()).Is()); EXPECT_TRUE(Value(OptionalValue()).Is()); EXPECT_TRUE(Value(ParsedMessageValue()).Is()); EXPECT_TRUE(Value(ParsedMessageValue()).Is()); EXPECT_TRUE(Value(ParsedMessageValue()).Is()); EXPECT_TRUE(Value(StringValue()).Is()); EXPECT_TRUE(Value(TimestampValue()).Is()); EXPECT_TRUE(Value(TypeValue(StringType())).Is()); EXPECT_TRUE(Value(UintValue()).Is()); EXPECT_TRUE(Value(UnknownValue()).Is()); } template constexpr T& AsLValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) { return t; } template constexpr const T& AsConstLValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) { return t; } template constexpr T&& AsRValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) { return static_cast(t); } template constexpr const T&& AsConstRValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) { return static_cast(t); } TEST(Value, As) { google::protobuf::Arena arena; EXPECT_THAT(Value(BoolValue()).As(), Optional(An())); EXPECT_THAT(Value(BoolValue()).As(), Eq(absl::nullopt)); { Value value(BytesValue{}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); } EXPECT_THAT(Value(DoubleValue()).As(), Optional(An())); EXPECT_THAT(Value(DoubleValue()).As(), Eq(absl::nullopt)); EXPECT_THAT(Value(DurationValue()).As(), Optional(An())); EXPECT_THAT(Value(DurationValue()).As(), 
Eq(absl::nullopt)); { Value value(ErrorValue{}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); EXPECT_THAT(Value(ErrorValue()).As(), Eq(absl::nullopt)); } EXPECT_THAT(Value(IntValue()).As(), Optional(An())); EXPECT_THAT(Value(IntValue()).As(), Eq(absl::nullopt)); { Value value(ListValue{}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); EXPECT_THAT(Value(ListValue()).As(), Eq(absl::nullopt)); } { Value value(ParsedJsonListValue{}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); EXPECT_THAT(Value(ListValue()).As(), Eq(absl::nullopt)); } { Value value(ParsedJsonListValue{}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); } { Value value(CustomListValue{}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); EXPECT_THAT(Value(ListValue()).As(), Eq(absl::nullopt)); } { Value value(CustomListValue{}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), 
Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); } { auto message = DynamicParseTextProto( &arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()); const auto* field = ABSL_DIE_IF_NULL( message->GetDescriptor()->FindFieldByName("repeated_int32")); Value value(ParsedRepeatedFieldValue{message, field, &arena}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); } { auto message = DynamicParseTextProto( &arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()); const auto* field = ABSL_DIE_IF_NULL( message->GetDescriptor()->FindFieldByName("repeated_int32")); Value value(ParsedRepeatedFieldValue{message, field, &arena}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT( AsConstRValueRef(other_value).As(), Optional(An())); } { Value value(MapValue{}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); EXPECT_THAT(Value(MapValue()).As(), Eq(absl::nullopt)); } { Value value(ParsedJsonMapValue{}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); EXPECT_THAT(Value(MapValue()).As(), Eq(absl::nullopt)); } { Value value(ParsedJsonMapValue{}); Value other_value = value; 
EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); } { Value value(CustomMapValue{}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); EXPECT_THAT(Value(MapValue()).As(), Eq(absl::nullopt)); } { Value value(CustomMapValue{}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); } { auto message = DynamicParseTextProto( &arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()); const auto* field = ABSL_DIE_IF_NULL( message->GetDescriptor()->FindFieldByName("map_int32_int32")); Value value(ParsedMapFieldValue{message, field, &arena}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); } { auto message = DynamicParseTextProto( &arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()); const auto* field = ABSL_DIE_IF_NULL( message->GetDescriptor()->FindFieldByName("map_int32_int32")); Value value(ParsedMapFieldValue{message, field, &arena}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); } { Value value(ParsedMessageValue{ 
DynamicParseTextProto(&arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()), &arena}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); EXPECT_THAT(Value(ParsedMessageValue{ DynamicParseTextProto( &arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()), &arena}) .As(), Eq(absl::nullopt)); } EXPECT_THAT(Value(NullValue()).As(), Optional(An())); EXPECT_THAT(Value(NullValue()).As(), Eq(absl::nullopt)); { Value value(OptionalValue{}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); EXPECT_THAT(Value(OpaqueValue(OptionalValue())).As(), Eq(absl::nullopt)); } { Value value(OptionalValue{}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); EXPECT_THAT(Value(OptionalValue()).As(), Eq(absl::nullopt)); } { OpaqueValue value(OptionalValue{}); OpaqueValue other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); } { Value value(ParsedMessageValue{ DynamicParseTextProto(&arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()), &arena}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), 
Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); } { Value value(StringValue{}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); EXPECT_THAT(Value(StringValue()).As(), Eq(absl::nullopt)); } { Value value(ParsedMessageValue{ DynamicParseTextProto(&arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()), &arena}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); } EXPECT_THAT(Value(TimestampValue()).As(), Optional(An())); EXPECT_THAT(Value(TimestampValue()).As(), Eq(absl::nullopt)); { Value value(TypeValue(StringType{})); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); EXPECT_THAT(Value(TypeValue(StringType())).As(), Eq(absl::nullopt)); } EXPECT_THAT(Value(UintValue()).As(), Optional(An())); EXPECT_THAT(Value(UintValue()).As(), Eq(absl::nullopt)); { Value value(UnknownValue{}); Value other_value = value; EXPECT_THAT(AsLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstLValueRef(value).As(), Optional(An())); EXPECT_THAT(AsRValueRef(value).As(), Optional(An())); EXPECT_THAT(AsConstRValueRef(other_value).As(), Optional(An())); EXPECT_THAT(Value(UnknownValue()).As(), Eq(absl::nullopt)); } } template decltype(auto) DoGet(From&& from) { return std::forward(from).template Get(); } TEST(Value, Get) { google::protobuf::Arena arena; EXPECT_THAT(DoGet(Value(BoolValue())), An()); { Value 
value(BytesValue{}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } EXPECT_THAT(DoGet(Value(DoubleValue())), An()); EXPECT_THAT(DoGet(Value(DurationValue())), An()); { Value value(ErrorValue{}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } EXPECT_THAT(DoGet(Value(IntValue())), An()); { Value value(ListValue{}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } { Value value(ParsedJsonListValue{}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } { Value value(ParsedJsonListValue{}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT( DoGet(AsConstRValueRef(other_value)), An()); } { Value value(CustomListValue{}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } { Value value(CustomListValue{}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } { auto message = 
DynamicParseTextProto( &arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()); const auto* field = ABSL_DIE_IF_NULL( message->GetDescriptor()->FindFieldByName("repeated_int32")); Value value(ParsedRepeatedFieldValue{message, field, &arena}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } { auto message = DynamicParseTextProto( &arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()); const auto* field = ABSL_DIE_IF_NULL( message->GetDescriptor()->FindFieldByName("repeated_int32")); Value value(ParsedRepeatedFieldValue{message, field, &arena}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT( DoGet(AsConstRValueRef(other_value)), An()); } { Value value(MapValue{}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } { Value value(ParsedJsonMapValue{}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } { Value value(ParsedJsonMapValue{}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } { Value value(CustomMapValue{}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); 
EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } { Value value(CustomMapValue{}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } { auto message = DynamicParseTextProto( &arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()); const auto* field = ABSL_DIE_IF_NULL( message->GetDescriptor()->FindFieldByName("map_int32_int32")); Value value(ParsedMapFieldValue{message, field, &arena}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } { auto message = DynamicParseTextProto( &arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()); const auto* field = ABSL_DIE_IF_NULL( message->GetDescriptor()->FindFieldByName("map_int32_int32")); Value value(ParsedMapFieldValue{message, field, &arena}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT( DoGet(AsConstRValueRef(other_value)), An()); } { Value value(ParsedMessageValue{ DynamicParseTextProto(&arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()), &arena}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } EXPECT_THAT(DoGet(Value(NullValue())), An()); { Value value(OptionalValue{}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), 
An()); } { Value value(OptionalValue{}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } { OpaqueValue value(OptionalValue{}); OpaqueValue other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT( DoGet(AsConstRValueRef(other_value)), An()); } { Value value(ParsedMessageValue{ DynamicParseTextProto(&arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()), &arena}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } { Value value(StringValue{}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } { Value value(ParsedMessageValue{ DynamicParseTextProto(&arena, R"pb()pb", GetTestingDescriptorPool(), GetTestingMessageFactory()), &arena}); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } EXPECT_THAT(DoGet(Value(TimestampValue())), An()); { Value value(TypeValue(StringType{})); Value other_value = value; EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } EXPECT_THAT(DoGet(Value(UintValue())), An()); { Value value(UnknownValue{}); Value other_value = value; 
EXPECT_THAT(DoGet(AsLValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstLValueRef(value)), An()); EXPECT_THAT(DoGet(AsRValueRef(value)), An()); EXPECT_THAT(DoGet(AsConstRValueRef(other_value)), An()); } } TEST(Value, NumericHeterogeneousEquality) { EXPECT_EQ(IntValue(1), UintValue(1)); EXPECT_EQ(UintValue(1), IntValue(1)); EXPECT_EQ(IntValue(1), DoubleValue(1)); EXPECT_EQ(DoubleValue(1), IntValue(1)); EXPECT_EQ(UintValue(1), DoubleValue(1)); EXPECT_EQ(DoubleValue(1), UintValue(1)); EXPECT_NE(IntValue(1), UintValue(2)); EXPECT_NE(UintValue(1), IntValue(2)); EXPECT_NE(IntValue(1), DoubleValue(2)); EXPECT_NE(DoubleValue(1), IntValue(2)); EXPECT_NE(UintValue(1), DoubleValue(2)); EXPECT_NE(DoubleValue(1), UintValue(2)); } using ValueIteratorTest = common_internal::ValueTest<>; TEST_F(ValueIteratorTest, Empty) { auto iterator = NewEmptyValueIterator(); EXPECT_FALSE(iterator->HasNext()); EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), StatusIs(absl::StatusCode::kFailedPrecondition)); } TEST_F(ValueIteratorTest, Empty1) { auto iterator = NewEmptyValueIterator(); EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Eq(absl::nullopt))); } TEST_F(ValueIteratorTest, Empty2) { auto iterator = NewEmptyValueIterator(); EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Eq(absl::nullopt))); } } // namespace } // namespace cel ================================================ FILE: common/value_testing.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/value_testing.h" #include #include #include #include #include "absl/status/status.h" #include "absl/time/time.h" #include "common/value.h" #include "common/value_kind.h" #include "internal/testing.h" namespace cel { void PrintTo(const Value& value, std::ostream* os) { *os << value << "\n"; } namespace test { namespace { using ::testing::Matcher; template constexpr ValueKind ToValueKind() { if constexpr (std::is_same_v) { return ValueKind::kBool; } else if constexpr (std::is_same_v) { return ValueKind::kInt; } else if constexpr (std::is_same_v) { return ValueKind::kUint; } else if constexpr (std::is_same_v) { return ValueKind::kDouble; } else if constexpr (std::is_same_v) { return ValueKind::kString; } else if constexpr (std::is_same_v) { return ValueKind::kBytes; } else if constexpr (std::is_same_v) { return ValueKind::kDuration; } else if constexpr (std::is_same_v) { return ValueKind::kTimestamp; } else if constexpr (std::is_same_v) { return ValueKind::kError; } else if constexpr (std::is_same_v) { return ValueKind::kMap; } else if constexpr (std::is_same_v) { return ValueKind::kList; } else if constexpr (std::is_same_v) { return ValueKind::kStruct; } else if constexpr (std::is_same_v) { return ValueKind::kOpaque; } else { // Otherwise, unspecified (uninitialized value) return ValueKind::kError; } } template class SimpleTypeMatcherImpl : public testing::MatcherInterface { public: using MatcherType = Matcher; explicit SimpleTypeMatcherImpl(MatcherType&& matcher) : matcher_(std::forward(matcher)) {} bool MatchAndExplain(const Value& v, 
testing::MatchResultListener* listener) const override { return v.Is() && matcher_.MatchAndExplain(v.Get().NativeValue(), listener); } void DescribeTo(std::ostream* os) const override { *os << absl::StrCat("kind is ", ValueKindToString(ToValueKind()), " and "); matcher_.DescribeTo(os); } private: MatcherType matcher_; }; template class StringTypeMatcherImpl : public testing::MatcherInterface { public: using MatcherType = Matcher; explicit StringTypeMatcherImpl(MatcherType matcher) : matcher_((std::move(matcher))) {} bool MatchAndExplain(const Value& v, testing::MatchResultListener* listener) const override { return v.Is() && matcher_.Matches(v.Get().ToString()); } void DescribeTo(std::ostream* os) const override { *os << absl::StrCat("kind is ", ValueKindToString(ToValueKind()), " and "); matcher_.DescribeTo(os); } private: MatcherType matcher_; }; template class AbstractTypeMatcherImpl : public testing::MatcherInterface { public: using MatcherType = Matcher; explicit AbstractTypeMatcherImpl(MatcherType&& matcher) : matcher_(std::forward(matcher)) {} bool MatchAndExplain(const Value& v, testing::MatchResultListener* listener) const override { return v.Is() && matcher_.Matches(v.template Get()); } void DescribeTo(std::ostream* os) const override { *os << absl::StrCat("kind is ", ValueKindToString(ToValueKind()), " and "); matcher_.DescribeTo(os); } private: MatcherType matcher_; }; class OptionalValueMatcherImpl : public testing::MatcherInterface { public: explicit OptionalValueMatcherImpl(ValueMatcher matcher) : matcher_(std::move(matcher)) {} bool MatchAndExplain(const Value& v, testing::MatchResultListener* listener) const override { if (!v.IsOptional()) { *listener << "wanted OptionalValue, got " << ValueKindToString(v.kind()); return false; } const auto& optional_value = v.GetOptional(); if (!optional_value.HasValue()) { *listener << "OptionalValue is not engaged"; return false; } return matcher_.MatchAndExplain(optional_value.Value(), listener); } void 
DescribeTo(std::ostream* os) const override { *os << "is OptionalValue that is engaged with value whose "; matcher_.DescribeTo(os); } private: ValueMatcher matcher_; }; MATCHER(OptionalValueIsEmptyImpl, "is empty OptionalValue") { const Value& v = arg; if (!v.IsOptional()) { *result_listener << "wanted OptionalValue, got " << ValueKindToString(v.kind()); return false; } const auto& optional_value = v.GetOptional(); *result_listener << (optional_value.HasValue() ? "is not empty" : "is empty"); return !optional_value.HasValue(); } } // namespace ValueMatcher BoolValueIs(Matcher m) { return ValueMatcher(new SimpleTypeMatcherImpl(std::move(m))); } ValueMatcher IntValueIs(Matcher m) { return ValueMatcher( new SimpleTypeMatcherImpl(std::move(m))); } ValueMatcher UintValueIs(Matcher m) { return ValueMatcher( new SimpleTypeMatcherImpl(std::move(m))); } ValueMatcher DoubleValueIs(Matcher m) { return ValueMatcher( new SimpleTypeMatcherImpl(std::move(m))); } ValueMatcher TimestampValueIs(Matcher m) { return ValueMatcher( new SimpleTypeMatcherImpl(std::move(m))); } ValueMatcher DurationValueIs(Matcher m) { return ValueMatcher( new SimpleTypeMatcherImpl(std::move(m))); } ValueMatcher ErrorValueIs(Matcher m) { return ValueMatcher( new SimpleTypeMatcherImpl(std::move(m))); } ValueMatcher StringValueIs(Matcher m) { return ValueMatcher(new StringTypeMatcherImpl(std::move(m))); } ValueMatcher BytesValueIs(Matcher m) { return ValueMatcher(new StringTypeMatcherImpl(std::move(m))); } ValueMatcher MapValueIs(Matcher m) { return ValueMatcher(new AbstractTypeMatcherImpl(std::move(m))); } ValueMatcher ListValueIs(Matcher m) { return ValueMatcher(new AbstractTypeMatcherImpl(std::move(m))); } ValueMatcher StructValueIs(Matcher m) { return ValueMatcher(new AbstractTypeMatcherImpl(std::move(m))); } ValueMatcher OptionalValueIs(ValueMatcher m) { return ValueMatcher(new OptionalValueMatcherImpl(std::move(m))); } ValueMatcher OptionalValueIsEmpty() { return OptionalValueIsEmptyImpl(); } } // 
namespace test } // namespace cel ================================================ FILE: common/value_testing.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUE_TESTING_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUE_TESTING_H_ #include #include #include #include #include #include #include "google/protobuf/struct.pb.h" #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/die_if_null.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "common/value.h" #include "common/value_kind.h" #include "internal/equals_text_proto.h" #include "internal/parse_text_proto.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel { // GTest Printer void PrintTo(const Value& value, std::ostream* os); namespace test { using ValueMatcher = testing::Matcher; MATCHER_P(ValueKindIs, m, "") { return ExplainMatchResult(m, arg.kind(), result_listener); } // Returns a matcher for CEL null value. inline ValueMatcher IsNullValue() { return ValueKindIs(ValueKind::kNull); } // Returns a matcher for CEL bool values. 
ValueMatcher BoolValueIs(testing::Matcher m); // Returns a matcher for CEL int values. ValueMatcher IntValueIs(testing::Matcher m); // Returns a matcher for CEL uint values. ValueMatcher UintValueIs(testing::Matcher m); // Returns a matcher for CEL double values. ValueMatcher DoubleValueIs(testing::Matcher m); // Returns a matcher for CEL duration values. ValueMatcher DurationValueIs(testing::Matcher m); // Returns a matcher for CEL timestamp values. ValueMatcher TimestampValueIs(testing::Matcher m); // Returns a matcher for CEL error values. ValueMatcher ErrorValueIs(testing::Matcher m); // Returns a matcher for CEL string values. ValueMatcher StringValueIs(testing::Matcher m); // Returns a matcher for CEL bytes values. ValueMatcher BytesValueIs(testing::Matcher m); // Returns a matcher for CEL map values. ValueMatcher MapValueIs(testing::Matcher m); // Returns a matcher for CEL list values. ValueMatcher ListValueIs(testing::Matcher m); // Returns a matcher for CEL struct values. ValueMatcher StructValueIs(testing::Matcher m); // Returns a matcher for CEL struct values. ValueMatcher OptionalValueIsEmpty(); // Returns a matcher for CEL struct values. ValueMatcher OptionalValueIs(ValueMatcher m); // Returns a Matcher that tests the value of a CEL struct's field. // ValueManager* mgr must remain valid for the lifetime of the matcher. MATCHER_P5(StructValueFieldIs, name, m, descriptor_pool, message_factory, arena, "") { auto wrapped_m = ::absl_testing::IsOkAndHolds(m); return ExplainMatchResult(wrapped_m, cel::StructValue(arg).GetFieldByName( name, descriptor_pool, message_factory, arena), result_listener); } // Returns a Matcher that tests the presence of a CEL struct's field. // ValueManager* mgr must remain valid for the lifetime of the matcher. 
MATCHER_P2(StructValueFieldHas, name, m, "") { auto wrapped_m = ::absl_testing::IsOkAndHolds(m); return ExplainMatchResult( wrapped_m, cel::StructValue(arg).HasFieldByName(name), result_listener); } class ListValueElementsMatcher { public: using is_gtest_matcher = void; explicit ListValueElementsMatcher( testing::Matcher>&& m, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) : m_(std::move(m)), descriptor_pool_(ABSL_DIE_IF_NULL(descriptor_pool)), // Crash OK message_factory_(ABSL_DIE_IF_NULL(message_factory)), // Crash OK arena_(ABSL_DIE_IF_NULL(arena)) // Crash OK {} bool MatchAndExplain(const ListValue& arg, testing::MatchResultListener* result_listener) const { std::vector elements; absl::Status s = arg.ForEach( [&](const Value& v) -> absl::StatusOr { elements.push_back(v); return true; }, descriptor_pool_, message_factory_, arena_); if (!s.ok()) { *result_listener << "cannot convert to list of values: " << s; return false; } return m_.MatchAndExplain(elements, result_listener); } void DescribeTo(std::ostream* os) const { *os << m_; } void DescribeNegationTo(std::ostream* os) const { *os << m_; } private: testing::Matcher> m_; const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_; google::protobuf::MessageFactory* absl_nonnull message_factory_; google::protobuf::Arena* absl_nonnull arena_; }; // Returns a matcher that tests the elements of a cel::ListValue on a given // matcher as if they were a std::vector. // ValueManager* mgr must remain valid for the lifetime of the matcher. 
inline ListValueElementsMatcher ListValueElements( testing::Matcher>&& m, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { return ListValueElementsMatcher(std::move(m), descriptor_pool, message_factory, arena); } class MapValueElementsMatcher { public: using is_gtest_matcher = void; explicit MapValueElementsMatcher( testing::Matcher>>&& m, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) : m_(std::move(m)), descriptor_pool_(ABSL_DIE_IF_NULL(descriptor_pool)), // Crash OK message_factory_(ABSL_DIE_IF_NULL(message_factory)), // Crash OK arena_(ABSL_DIE_IF_NULL(arena)) // Crash OK {} bool MatchAndExplain(const MapValue& arg, testing::MatchResultListener* result_listener) const { std::vector> elements; absl::Status s = arg.ForEach( [&](const Value& key, const Value& value) -> absl::StatusOr { elements.push_back({key, value}); return true; }, descriptor_pool_, message_factory_, arena_); if (!s.ok()) { *result_listener << "cannot convert to list of values: " << s; return false; } return m_.MatchAndExplain(elements, result_listener); } void DescribeTo(std::ostream* os) const { *os << m_; } void DescribeNegationTo(std::ostream* os) const { *os << m_; } private: testing::Matcher>> m_; const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_; google::protobuf::MessageFactory* absl_nonnull message_factory_; google::protobuf::Arena* absl_nonnull arena_; }; // Returns a matcher that tests the elements of a cel::MapValue on a given // matcher as if they were a std::vector>. 
// ValueManager* mgr must remain valid for the lifetime of the matcher. inline MapValueElementsMatcher MapValueElements( testing::Matcher>>&& m, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::MessageFactory* absl_nonnull message_factory ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { return MapValueElementsMatcher(std::move(m), descriptor_pool, message_factory, arena); } } // namespace test } // namespace cel namespace cel::common_internal { template class ValueTest : public ::testing::TestWithParam> { public: google::protobuf::Arena* absl_nonnull arena() { return &arena_; } const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() { return ::cel::internal::GetTestingDescriptorPool(); } google::protobuf::MessageFactory* absl_nonnull message_factory() { return ::cel::internal::GetTestingMessageFactory(); } google::protobuf::Message* absl_nonnull NewArenaValueMessage() { return ABSL_DIE_IF_NULL( // Crash OK message_factory()->GetPrototype(ABSL_DIE_IF_NULL( // Crash OK descriptor_pool()->FindMessageTypeByName( "google.protobuf.Value")))) ->New(arena()); } template auto GeneratedParseTextProto(absl::string_view text = "") { return ::cel::internal::GeneratedParseTextProto( arena(), text, descriptor_pool(), message_factory()); } template auto DynamicParseTextProto(absl::string_view text = "") { return ::cel::internal::DynamicParseTextProto( arena(), text, descriptor_pool(), message_factory()); } template auto EqualsTextProto(absl::string_view text) { return ::cel::internal::EqualsTextProto(arena(), text, descriptor_pool(), message_factory()); } auto EqualsValueTextProto(absl::string_view text) { return EqualsTextProto(text); } template const google::protobuf::FieldDescriptor* absl_nonnull DynamicGetField( absl::string_view name) { return ABSL_DIE_IF_NULL( // Crash OK ABSL_DIE_IF_NULL(descriptor_pool()->FindMessageTypeByName( // 
Crash OK internal::MessageTypeNameFor())) ->FindFieldByName(name)); } template ParsedMessageValue MakeParsedMessage(absl::string_view text = R"pb()pb") { return ParsedMessageValue(DynamicParseTextProto(text), arena()); } private: google::protobuf::Arena arena_; }; } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUE_TESTING_H_ ================================================ FILE: common/value_testing_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/value_testing.h"

#include <utility>

#include "gtest/gtest-spi.h"
#include "absl/status/status.h"
#include "absl/time/time.h"
#include "common/value.h"
#include "internal/testing.h"

namespace cel::test {
namespace {

using ::absl_testing::StatusIs;
using ::testing::_;
using ::testing::ElementsAre;
// `Not` and `Pair` are imported explicitly instead of relying on
// argument-dependent lookup through the matcher arguments.
using ::testing::Not;
using ::testing::Pair;
using ::testing::Truly;
using ::testing::UnorderedElementsAre;

// Each scalar matcher is checked for a match, for a non-match (including a
// value of a different kind), and for its failure description.

TEST(BoolValueIs, Match) { EXPECT_THAT(BoolValue(true), BoolValueIs(true)); }

TEST(BoolValueIs, NoMatch) {
  EXPECT_THAT(BoolValue(false), Not(BoolValueIs(true)));
  EXPECT_THAT(IntValue(2), Not(BoolValueIs(true)));
}

TEST(BoolValueIs, NonMatchMessage) {
  EXPECT_NONFATAL_FAILURE(
      []() { EXPECT_THAT(IntValue(42), BoolValueIs(true)); }(),
      "kind is bool and is equal to true");
}

TEST(IntValueIs, Match) { EXPECT_THAT(IntValue(42), IntValueIs(42)); }

TEST(IntValueIs, NoMatch) {
  EXPECT_THAT(IntValue(-42), Not(IntValueIs(42)));
  EXPECT_THAT(UintValue(2), Not(IntValueIs(42)));
}

TEST(IntValueIs, NonMatchMessage) {
  EXPECT_NONFATAL_FAILURE(
      []() { EXPECT_THAT(UintValue(42), IntValueIs(42)); }(),
      "kind is int and is equal to 42");
}

TEST(UintValueIs, Match) { EXPECT_THAT(UintValue(42), UintValueIs(42)); }

TEST(UintValueIs, NoMatch) {
  EXPECT_THAT(UintValue(41), Not(UintValueIs(42)));
  EXPECT_THAT(IntValue(2), Not(UintValueIs(42)));
}

TEST(UintValueIs, NonMatchMessage) {
  EXPECT_NONFATAL_FAILURE(
      []() { EXPECT_THAT(IntValue(42), UintValueIs(42)); }(),
      "kind is uint and is equal to 42");
}

TEST(DoubleValueIs, Match) {
  EXPECT_THAT(DoubleValue(1.2), DoubleValueIs(1.2));
}

TEST(DoubleValueIs, NoMatch) {
  EXPECT_THAT(DoubleValue(41), Not(DoubleValueIs(1.2)));
  EXPECT_THAT(IntValue(2), Not(DoubleValueIs(1.2)));
}

TEST(DoubleValueIs, NonMatchMessage) {
  EXPECT_NONFATAL_FAILURE(
      []() { EXPECT_THAT(IntValue(42), DoubleValueIs(1.2)); }(),
      "kind is double and is equal to 1.2");
}

TEST(DurationValueIs, Match) {
  EXPECT_THAT(DurationValue(absl::Minutes(2)),
              DurationValueIs(absl::Minutes(2)));
}

TEST(DurationValueIs, NoMatch) {
  EXPECT_THAT(DurationValue(absl::Minutes(5)),
              Not(DurationValueIs(absl::Minutes(2))));
  EXPECT_THAT(IntValue(2), Not(DurationValueIs(absl::Minutes(2))));
}

TEST(DurationValueIs, NonMatchMessage) {
  EXPECT_NONFATAL_FAILURE(
      []() {
        EXPECT_THAT(IntValue(42), DurationValueIs(absl::Minutes(2)));
      }(),
      "kind is duration and is equal to 2m");
}

TEST(TimestampValueIs, Match) {
  EXPECT_THAT(TimestampValue(absl::UnixEpoch() + absl::Minutes(2)),
              TimestampValueIs(absl::UnixEpoch() + absl::Minutes(2)));
}

TEST(TimestampValueIs, NoMatch) {
  EXPECT_THAT(TimestampValue(absl::UnixEpoch()),
              Not(TimestampValueIs(absl::UnixEpoch() + absl::Minutes(2))));
  EXPECT_THAT(IntValue(2),
              Not(TimestampValueIs(absl::UnixEpoch() + absl::Minutes(2))));
}

TEST(TimestampValueIs, NonMatchMessage) {
  EXPECT_NONFATAL_FAILURE(
      []() {
        EXPECT_THAT(IntValue(42),
                    TimestampValueIs(absl::UnixEpoch() + absl::Minutes(2)));
      }(),
      "kind is timestamp and is equal to 19");
}

TEST(StringValueIs, Match) {
  EXPECT_THAT(StringValue("hello!"), StringValueIs("hello!"));
}

TEST(StringValueIs, NoMatch) {
  EXPECT_THAT(StringValue("hello!"), Not(StringValueIs("goodbye!")));
  EXPECT_THAT(IntValue(2), Not(StringValueIs("goodbye!")));
}

TEST(StringValueIs, NonMatchMessage) {
  EXPECT_NONFATAL_FAILURE(
      []() { EXPECT_THAT(IntValue(42), StringValueIs("hello!")); }(),
      "kind is string and is equal to \"hello!\"");
}

TEST(BytesValueIs, Match) {
  EXPECT_THAT(BytesValue("hello!"), BytesValueIs("hello!"));
}

TEST(BytesValueIs, NoMatch) {
  EXPECT_THAT(BytesValue("hello!"), Not(BytesValueIs("goodbye!")));
  EXPECT_THAT(IntValue(2), Not(BytesValueIs("goodbye!")));
}

TEST(BytesValueIs, NonMatchMessage) {
  EXPECT_NONFATAL_FAILURE(
      []() { EXPECT_THAT(IntValue(42), BytesValueIs("hello!")); }(),
      "kind is bytes and is equal to \"hello!\"");
}

TEST(ErrorValueIs, Match) {
  EXPECT_THAT(ErrorValue(absl::InternalError("test")),
              ErrorValueIs(StatusIs(absl::StatusCode::kInternal, "test")));
}

TEST(ErrorValueIs, NoMatch) {
  EXPECT_THAT(ErrorValue(absl::UnknownError("test")),
              Not(ErrorValueIs(StatusIs(absl::StatusCode::kInternal, "test"))));
  EXPECT_THAT(IntValue(2), Not(ErrorValueIs(_)));
}

TEST(ErrorValueIs, NonMatchMessage) {
  EXPECT_NONFATAL_FAILURE(
      []() {
        EXPECT_THAT(IntValue(42), ErrorValueIs(StatusIs(
                                      absl::StatusCode::kInternal, "test")));
      }(),
      "kind is *error* and");
}

using ValueMatcherTest = common_internal::ValueTest<>;

TEST_F(ValueMatcherTest, OptionalValueIsMatch) {
  EXPECT_THAT(OptionalValue::Of(IntValue(42), arena()),
              OptionalValueIs(IntValueIs(42)));
}

TEST_F(ValueMatcherTest, OptionalValueIsHeldValueDifferent) {
  EXPECT_NONFATAL_FAILURE(
      [&]() {
        EXPECT_THAT(OptionalValue::Of(IntValue(-42), arena()),
                    OptionalValueIs(IntValueIs(42)));
      }(),
      "is OptionalValue that is engaged with value whose kind is int and is "
      "equal to 42");
}

TEST_F(ValueMatcherTest, OptionalValueIsNotEngaged) {
  EXPECT_NONFATAL_FAILURE(
      [&]() {
        EXPECT_THAT(OptionalValue::None(), OptionalValueIs(IntValueIs(42)));
      }(),
      "is not engaged");
}

TEST_F(ValueMatcherTest, OptionalValueIsNotAnOptional) {
  EXPECT_NONFATAL_FAILURE(
      [&]() { EXPECT_THAT(IntValue(42), OptionalValueIs(IntValueIs(42))); }(),
      "wanted OptionalValue, got int");
}

TEST_F(ValueMatcherTest, OptionalValueIsEmptyMatch) {
  EXPECT_THAT(OptionalValue::None(), OptionalValueIsEmpty());
}

TEST_F(ValueMatcherTest, OptionalValueIsEmptyNotEmpty) {
  EXPECT_NONFATAL_FAILURE(
      [&]() {
        EXPECT_THAT(OptionalValue::Of(IntValue(42), arena()),
                    OptionalValueIsEmpty());
      }(),
      "is not empty");
}

TEST_F(ValueMatcherTest, OptionalValueIsEmptyNotOptional) {
  EXPECT_NONFATAL_FAILURE(
      [&]() { EXPECT_THAT(IntValue(42), OptionalValueIsEmpty()); }(),
      "wanted OptionalValue, got int");
}

TEST_F(ValueMatcherTest, ListMatcherBasic) {
  auto builder = NewListValueBuilder(arena());
  ASSERT_OK(builder->Add(IntValue(42)));

  Value list_value = std::move(*builder).Build();

  EXPECT_THAT(list_value, ListValueIs(Truly([](const ListValue& v) {
                auto size = v.Size();
                return size.ok() && *size == 1;
              })));
}

TEST_F(ValueMatcherTest, ListMatcherMatchesElements) {
  auto builder = NewListValueBuilder(arena());
  ASSERT_OK(builder->Add(IntValue(42)));
  ASSERT_OK(builder->Add(IntValue(1337)));
  ASSERT_OK(builder->Add(IntValue(42)));
  ASSERT_OK(builder->Add(IntValue(100)));

  EXPECT_THAT(std::move(*builder).Build(),
              ListValueIs(ListValueElements(
                  ElementsAre(IntValueIs(42), IntValueIs(1337),
                              IntValueIs(42), IntValueIs(100)),
                  descriptor_pool(), message_factory(), arena())));
}

TEST_F(ValueMatcherTest, MapMatcherBasic) {
  auto builder = NewMapValueBuilder(arena());
  ASSERT_OK(builder->Put(IntValue(42), IntValue(42)));

  Value map_value = std::move(*builder).Build();

  EXPECT_THAT(map_value, MapValueIs(Truly([](const MapValue& v) {
                auto size = v.Size();
                return size.ok() && *size == 1;
              })));
}

TEST_F(ValueMatcherTest, MapMatcherMatchesElements) {
  auto builder = NewMapValueBuilder(arena());
  ASSERT_OK(builder->Put(IntValue(42), StringValue("answer")));
  ASSERT_OK(builder->Put(IntValue(1337), StringValue("leet")));
  EXPECT_THAT(
      std::move(*builder).Build(),
      MapValueIs(MapValueElements(
          UnorderedElementsAre(Pair(IntValueIs(42), StringValueIs("answer")),
                               Pair(IntValueIs(1337), StringValueIs("leet"))),
          descriptor_pool(), message_factory(), arena())));
}

}  // namespace
}  // namespace cel::test

================================================
FILE: common/values/bool_value.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include #include "google/protobuf/wrappers.pb.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "common/value.h" #include "internal/status_macros.h" #include "internal/well_known_types.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { namespace { using ::cel::well_known_types::ValueReflection; std::string BoolDebugString(bool value) { return value ? "true" : "false"; } } // namespace std::string BoolValue::DebugString() const { return BoolDebugString(NativeValue()); } absl::Status BoolValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); google::protobuf::BoolValue message; message.set_value(NativeValue()); if (!message.SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( absl::StrCat("failed to serialize message: ", message.GetTypeName())); } return absl::OkStatus(); } absl::Status BoolValue::ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); ValueReflection value_reflection; CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); 
value_reflection.SetBoolValue(json, NativeValue()); return absl::OkStatus(); } absl::Status BoolValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (auto other_value = other.AsBool(); other_value.has_value()) { *result = BoolValue{NativeValue() == other_value->NativeValue()}; return absl::OkStatus(); } *result = FalseValue(); return absl::OkStatus(); } } // namespace cel ================================================ FILE: common/values/bool_value.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_BOOL_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_BOOL_VALUE_H_ #include #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "common/type.h" #include "common/value_kind.h" #include "common/values/values.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { class Value; class BoolValue; // `BoolValue` represents values of the primitive `bool` type. class BoolValue final : private common_internal::ValueMixin { public: static constexpr ValueKind kKind = ValueKind::kBool; BoolValue() = default; BoolValue(const BoolValue&) = default; BoolValue(BoolValue&&) = default; BoolValue& operator=(const BoolValue&) = default; BoolValue& operator=(BoolValue&&) = default; explicit BoolValue(bool value) noexcept : value_(value) {} // NOLINTNEXTLINE(google-explicit-constructor) operator bool() const noexcept { return value_; } ValueKind kind() const { return kKind; } absl::string_view GetTypeName() const { return BoolType::kName; } std::string DebugString() const; // See Value::SerializeTo(). absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; // See Value::ConvertToJson(). 
absl::Status ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; absl::Status Equal(const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using ValueMixin::Equal; bool IsZeroValue() const { return NativeValue() == false; } bool NativeValue() const { return static_cast(*this); } friend void swap(BoolValue& lhs, BoolValue& rhs) noexcept { using std::swap; swap(lhs.value_, rhs.value_); } private: friend class common_internal::ValueMixin; bool value_ = false; }; template H AbslHashValue(H state, BoolValue value) { return H::combine(std::move(state), value.NativeValue()); } inline std::ostream& operator<<(std::ostream& out, BoolValue value) { return out << value.DebugString(); } inline BoolValue FalseValue() noexcept { return BoolValue(false); } inline BoolValue TrueValue() noexcept { return BoolValue(true); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_BOOL_VALUE_H_ ================================================ FILE: common/values/bool_value_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "absl/hash/hash.h" #include "absl/status/status_matchers.h" #include "common/native_type.h" #include "common/value.h" #include "common/value_testing.h" #include "internal/testing.h" namespace cel { namespace { using ::absl_testing::IsOk; using BoolValueTest = common_internal::ValueTest<>; TEST_F(BoolValueTest, Kind) { EXPECT_EQ(BoolValue(true).kind(), BoolValue::kKind); EXPECT_EQ(Value(BoolValue(true)).kind(), BoolValue::kKind); } TEST_F(BoolValueTest, DebugString) { { std::ostringstream out; out << BoolValue(true); EXPECT_EQ(out.str(), "true"); } { std::ostringstream out; out << Value(BoolValue(true)); EXPECT_EQ(out.str(), "true"); } } TEST_F(BoolValueTest, ConvertToJson) { auto* message = NewArenaValueMessage(); EXPECT_THAT(BoolValue(false).ConvertToJson(descriptor_pool(), message_factory(), message), IsOk()); EXPECT_THAT(*message, EqualsValueTextProto(R"pb(bool_value: false)pb")); } TEST_F(BoolValueTest, NativeTypeId) { EXPECT_EQ(NativeTypeId::Of(BoolValue(true)), NativeTypeId::For()); EXPECT_EQ(NativeTypeId::Of(Value(BoolValue(true))), NativeTypeId::For()); } TEST_F(BoolValueTest, HashValue) { EXPECT_EQ(absl::HashOf(BoolValue(true)), absl::HashOf(true)); } TEST_F(BoolValueTest, Equality) { EXPECT_NE(BoolValue(false), true); EXPECT_NE(true, BoolValue(false)); EXPECT_NE(BoolValue(false), BoolValue(true)); } TEST_F(BoolValueTest, LessThan) { EXPECT_LT(BoolValue(false), true); EXPECT_LT(false, BoolValue(true)); EXPECT_LT(BoolValue(false), BoolValue(true)); } } // namespace } // namespace cel ================================================ FILE: common/values/bytes_value.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include "google/protobuf/wrappers.pb.h" #include "absl/base/nullability.h" #include "absl/functional/overload.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "common/internal/byte_string.h" #include "common/value.h" #include "internal/status_macros.h" #include "internal/strings.h" #include "internal/well_known_types.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { namespace { using ::cel::well_known_types::ValueReflection; template std::string BytesDebugString(const Bytes& value) { return value.NativeValue(absl::Overload( [](absl::string_view string) -> std::string { return internal::FormatBytesLiteral(string); }, [](const absl::Cord& cord) -> std::string { if (auto flat = cord.TryFlat(); flat.has_value()) { return internal::FormatBytesLiteral(*flat); } return internal::FormatBytesLiteral(static_cast(cord)); })); } } // namespace BytesValue BytesValue::Concat(const BytesValue& lhs, const BytesValue& rhs, google::protobuf::Arena* absl_nonnull arena) { return BytesValue( common_internal::ByteString::Concat(lhs.value_, rhs.value_, arena)); } std::string BytesValue::DebugString() const { return BytesDebugString(*this); } absl::Status BytesValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull 
message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); google::protobuf::BytesValue message; message.set_value(NativeString()); if (!message.SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( absl::StrCat("failed to serialize message: ", message.GetTypeName())); } return absl::OkStatus(); } absl::Status BytesValue::ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); ValueReflection value_reflection; CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); NativeValue([&](const auto& value) { value_reflection.SetStringValueFromBytes(json, value); }); return absl::OkStatus(); } absl::Status BytesValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (auto other_value = other.AsBytes(); other_value.has_value()) { *result = NativeValue([other_value](const auto& value) -> BoolValue { return other_value->NativeValue( [&value](const auto& other_value) -> BoolValue { return BoolValue{value == other_value}; }); }); return absl::OkStatus(); } *result = FalseValue(); return absl::OkStatus(); } BytesValue BytesValue::Clone(google::protobuf::Arena* absl_nonnull arena) const { return 
BytesValue(value_.Clone(arena)); } size_t BytesValue::Size() const { return NativeValue( [](const auto& alternative) -> size_t { return alternative.size(); }); } bool BytesValue::IsEmpty() const { return NativeValue( [](const auto& alternative) -> bool { return alternative.empty(); }); } bool BytesValue::Equals(absl::string_view bytes) const { return NativeValue([bytes](const auto& alternative) -> bool { return alternative == bytes; }); } bool BytesValue::Equals(const absl::Cord& bytes) const { return NativeValue([&bytes](const auto& alternative) -> bool { return alternative == bytes; }); } bool BytesValue::Equals(const BytesValue& bytes) const { return bytes.NativeValue( [this](const auto& alternative) -> bool { return Equals(alternative); }); } namespace { int CompareImpl(absl::string_view lhs, absl::string_view rhs) { return lhs.compare(rhs); } int CompareImpl(absl::string_view lhs, const absl::Cord& rhs) { return -rhs.Compare(lhs); } int CompareImpl(const absl::Cord& lhs, absl::string_view rhs) { return lhs.Compare(rhs); } int CompareImpl(const absl::Cord& lhs, const absl::Cord& rhs) { return lhs.Compare(rhs); } } // namespace int BytesValue::Compare(absl::string_view bytes) const { return NativeValue([bytes](const auto& alternative) -> int { return CompareImpl(alternative, bytes); }); } int BytesValue::Compare(const absl::Cord& bytes) const { return NativeValue([&bytes](const auto& alternative) -> int { return CompareImpl(alternative, bytes); }); } int BytesValue::Compare(const BytesValue& bytes) const { return bytes.NativeValue( [this](const auto& alternative) -> int { return Compare(alternative); }); } } // namespace cel ================================================ FILE: common/values/bytes_value.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private, include "common/value.h"
// IWYU pragma: friend "common/value.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_H_

// NOTE(review): the five bare `#include` directives below have no header name
// (apparently lost in extraction). Judging by what this header uses (size_t,
// std::ostream, std::string, std::common_type_t / std::true_type,
// std::move / std::forward), they were presumably <cstddef>, <ostream>,
// <string>, <type_traits> and <utility> -- confirm against upstream.
#include
#include
#include
#include
#include
#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/allocator.h"
#include "common/arena.h"
#include "common/internal/byte_string.h"
#include "common/memory.h"
#include "common/type.h"
#include "common/value_kind.h"
#include "common/values/values.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

class Value;
class BytesValue;
class BytesValueInputStream;
class BytesValueOutputStream;

namespace common_internal {
// Exposes `value`'s bytes as an `absl::string_view` by forwarding to
// `LegacyByteString(value.value_, stable, arena)` (inline definition at the
// end of this header). The meaning of `stable` is defined by
// `LegacyByteString`, which is not visible here.
absl::string_view LegacyBytesValue(const BytesValue& value, bool stable,
                                   google::protobuf::Arena* absl_nonnull arena);
}  // namespace common_internal

// `BytesValue` represents values of the primitive `bytes` type.
//
// NOTE(review): several template argument lists in this header appear to have
// been stripped in extraction: the `common_internal::ValueMixin` base and
// friend, the `ArenaTraits` friend and specialization, the
// `std::invoke_result_t` / `std::forward` uses in `NativeValue`, and the
// `absl::optional` returned by `TryFlat`. Presumably these read
// `ValueMixin<BytesValue>`, `ArenaTraits<BytesValue>`, etc. upstream --
// confirm.
class BytesValue final : private common_internal::ValueMixin {
 public:
  static constexpr ValueKind kKind = ValueKind::kBytes;

  // `From` overloads presumably copy `value` into storage associated with
  // `arena` (contrast `Wrap`, which borrows). The string_view/std::string
  // definitions below ABSL_DCHECK that `arena` is non-null, and
  // `From(const char*)` maps a null pointer to an empty string via
  // `absl::NullSafeStringView`.
  static BytesValue From(const char* absl_nullable value,
                         google::protobuf::Arena* absl_nonnull arena
                             ABSL_ATTRIBUTE_LIFETIME_BOUND);
  static BytesValue From(absl::string_view value,
                         google::protobuf::Arena* absl_nonnull arena
                             ABSL_ATTRIBUTE_LIFETIME_BOUND);
  static BytesValue From(const absl::Cord& value);
  static BytesValue From(std::string&& value,
                         google::protobuf::Arena* absl_nonnull arena
                             ABSL_ATTRIBUTE_LIFETIME_BOUND);

  // `Wrap(absl::string_view, arena)` borrows `value` through
  // `Borrower::Arena(arena)` rather than copying it; `Wrap(const absl::Cord&)`
  // takes the Cord directly. Overloads that would wrap without an arena or
  // wrap a temporary `std::string` are deleted.
  //
  // NOTE(review): `arena` is annotated `absl_nullable` here, but the inline
  // definition ABSL_DCHECKs `arena != nullptr`; one of the two is probably
  // wrong -- confirm the intended contract.
  static BytesValue Wrap(absl::string_view value,
                         google::protobuf::Arena* absl_nullable arena
                             ABSL_ATTRIBUTE_LIFETIME_BOUND);
  static BytesValue Wrap(absl::string_view value) = delete;
  static BytesValue Wrap(const absl::Cord& value);
  static BytesValue Wrap(std::string&& value) = delete;
  static BytesValue Wrap(std::string&& value,
                         google::protobuf::Arena* absl_nullable arena
                             ABSL_ATTRIBUTE_LIFETIME_BOUND) = delete;

  // Returns a BytesValue that aliases the provided string. Caller must ensure
  // the provided string outlives the use of the returned BytesValue.
  static BytesValue WrapUnsafe(absl::string_view value);

  // Concatenates `lhs` and `rhs`; the result is lifetime-bound to `arena`.
  static BytesValue Concat(const BytesValue& lhs, const BytesValue& rhs,
                           google::protobuf::Arena* absl_nonnull arena
                               ABSL_ATTRIBUTE_LIFETIME_BOUND);

  // Deprecated constructors: use `From` / `Wrap` as the ABSL_DEPRECATED
  // messages indicate.
  ABSL_DEPRECATED("Use From")
  explicit BytesValue(const char* absl_nullable value) : value_(value) {}

  ABSL_DEPRECATED("Use From")
  explicit BytesValue(absl::string_view value) : value_(value) {}

  ABSL_DEPRECATED("Use From")
  explicit BytesValue(const absl::Cord& value) : value_(value) {}

  ABSL_DEPRECATED("Use From")
  explicit BytesValue(std::string&& value) : value_(std::move(value)) {}

  ABSL_DEPRECATED("Use From")
  BytesValue(Allocator<> allocator, const char* absl_nullable value)
      : value_(allocator, value) {}

  ABSL_DEPRECATED("Use From")
  BytesValue(Allocator<> allocator, absl::string_view value)
      : value_(allocator, value) {}

  ABSL_DEPRECATED("Use From")
  BytesValue(Allocator<> allocator, const absl::Cord& value)
      : value_(allocator, value) {}

  ABSL_DEPRECATED("Use From")
  BytesValue(Allocator<> allocator, std::string&& value)
      : value_(allocator, std::move(value)) {}

  ABSL_DEPRECATED("Use Wrap")
  BytesValue(Borrower borrower, absl::string_view value)
      : value_(borrower, value) {}

  ABSL_DEPRECATED("Use Wrap")
  BytesValue(Borrower borrower, const absl::Cord& value)
      : value_(borrower, value) {}

  BytesValue() = default;
  BytesValue(const BytesValue&) = default;
  BytesValue(BytesValue&&) = default;
  BytesValue& operator=(const BytesValue&) = default;
  BytesValue& operator=(BytesValue&&) = default;

  constexpr ValueKind kind() const { return kKind; }

  absl::string_view GetTypeName() const { return BytesType::kName; }

  std::string DebugString() const;

  // See Value::SerializeTo().
  absl::Status SerializeTo(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const;

  // See Value::ConvertToJson().
  absl::Status ConvertToJson(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  absl::Status Equal(const Value& other,
                     const google::protobuf::DescriptorPool* absl_nonnull
                         descriptor_pool,
                     google::protobuf::MessageFactory* absl_nonnull
                         message_factory,
                     google::protobuf::Arena* absl_nonnull arena,
                     Value* absl_nonnull result) const;
  using ValueMixin::Equal;

  // The zero value of `bytes` is the empty byte string.
  bool IsZeroValue() const {
    return NativeValue([](const auto& value) -> bool { return value.empty(); });
  }

  BytesValue Clone(google::protobuf::Arena* absl_nonnull arena) const;

  ABSL_DEPRECATED("Use ToString()")
  std::string NativeString() const { return value_.ToString(); }

  ABSL_DEPRECATED("Use ToStringView()")
  absl::string_view NativeString(
      std::string& scratch
          ABSL_ATTRIBUTE_LIFETIME_BOUND) const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return value_.ToStringView(&scratch);
  }

  ABSL_DEPRECATED("Use ToCord()")
  absl::Cord NativeCord() const { return value_.ToCord(); }

  // Invokes `visitor` on the underlying representation via
  // `ByteString::Visit` and returns its result.
  template
  ABSL_DEPRECATED("Use TryFlat()")
  std::common_type_t<
      std::invoke_result_t,
      std::invoke_result_t>
  NativeValue(Visitor&& visitor) const {
    return value_.Visit(std::forward(visitor));
  }

  void swap(BytesValue& other) noexcept {
    using std::swap;
    swap(value_, other.value_);
  }

  size_t Size() const;

  bool IsEmpty() const;

  // Content equality and three-way comparison against flat bytes, a Cord, or
  // another `BytesValue`, independent of the underlying representation.
  bool Equals(absl::string_view bytes) const;
  bool Equals(const absl::Cord& bytes) const;
  bool Equals(const BytesValue& bytes) const;

  int Compare(absl::string_view bytes) const;
  int Compare(const absl::Cord& bytes) const;
  int Compare(const BytesValue& bytes) const;

  // Forwards to `ByteString::TryFlat()`; any returned view is lifetime-bound
  // to `*this`.
  absl::optional TryFlat() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return value_.TryFlat();
  }

  std::string ToString() const { return value_.ToString(); }

  void CopyToString(std::string* absl_nonnull out) const {
    value_.CopyToString(out);
  }

  void AppendToString(std::string* absl_nonnull out) const {
    value_.AppendToString(out);
  }

  absl::Cord ToCord() const { return value_.ToCord(); }

  void CopyToCord(absl::Cord* absl_nonnull out) const {
    value_.CopyToCord(out);
  }

  void AppendToCord(absl::Cord* absl_nonnull out) const {
    value_.AppendToCord(out);
  }

  // The returned view may point into `*scratch` or into `*this`; both are
  // lifetime-bound.
  absl::string_view ToStringView(
      std::string* absl_nonnull scratch
          ABSL_ATTRIBUTE_LIFETIME_BOUND) const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return value_.ToStringView(scratch);
  }

  friend bool operator<(const BytesValue& lhs, const BytesValue& rhs) {
    return lhs.value_ < rhs.value_;
  }

 private:
  friend class common_internal::ValueMixin;
  friend class BytesValueInputStream;
  friend class BytesValueOutputStream;
  friend absl::string_view common_internal::LegacyBytesValue(
      const BytesValue& value, bool stable,
      google::protobuf::Arena* absl_nonnull arena);
  friend struct ArenaTraits;

  explicit BytesValue(common_internal::ByteString value) noexcept
      : value_(std::move(value)) {}

  common_internal::ByteString value_;
};

inline void swap(BytesValue& lhs, BytesValue& rhs) noexcept { lhs.swap(rhs); }

inline std::ostream& operator<<(std::ostream& out, const BytesValue& value) {
  return out << value.DebugString();
}

inline bool operator==(const BytesValue& lhs, absl::string_view rhs) {
  return lhs.Equals(rhs);
}

inline bool operator==(absl::string_view lhs, const BytesValue& rhs) {
  return rhs == lhs;
}

inline bool operator!=(const BytesValue& lhs, absl::string_view rhs) {
  return !lhs.Equals(rhs);
}

inline bool operator!=(absl::string_view lhs, const BytesValue& rhs) {
  return rhs != lhs;
}

inline BytesValue BytesValue::From(const char* absl_nullable value,
                                   google::protobuf::Arena* absl_nonnull arena
                                       ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return From(absl::NullSafeStringView(value), arena);
}

inline BytesValue BytesValue::From(absl::string_view value,
                                   google::protobuf::Arena* absl_nonnull arena
                                       ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  ABSL_DCHECK(arena != nullptr);
  return BytesValue(arena, value);
}

inline BytesValue BytesValue::From(const absl::Cord& value) {
  return BytesValue(value);
}

inline BytesValue BytesValue::From(std::string&& value,
                                   google::protobuf::Arena* absl_nonnull arena
                                       ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  ABSL_DCHECK(arena != nullptr);
  return BytesValue(arena, std::move(value));
}

inline BytesValue BytesValue::Wrap(absl::string_view value,
                                   google::protobuf::Arena* absl_nullable arena
                                       ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  // NOTE(review): contradicts the `absl_nullable` annotation on `arena`.
  ABSL_DCHECK(arena != nullptr);
  return BytesValue(Borrower::Arena(arena), value);
}

inline BytesValue BytesValue::WrapUnsafe(absl::string_view value) {
  // The ByteString refers to caller-owned external storage; nothing is copied
  // here.
  return BytesValue(common_internal::ByteString::FromExternal(value));
}

inline BytesValue BytesValue::Wrap(const absl::Cord& value) {
  return BytesValue(value);
}

namespace common_internal {

inline absl::string_view LegacyBytesValue(const BytesValue& value, bool stable,
                                          google::protobuf::Arena* absl_nonnull arena) {
  return LegacyByteString(value.value_, stable, arena);
}

}  // namespace common_internal

// Arena integration: declares `constructible = std::true_type`, and reports a
// value as trivially destructible exactly when its underlying `ByteString` is.
template <>
struct ArenaTraits {
  using constructible = std::true_type;

  static bool trivially_destructible(const BytesValue& value) {
    return ArenaTraits<>::trivially_destructible(value.value_);
  }
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_H_

================================================
FILE: common/values/bytes_value_input_stream.h
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// IWYU pragma: private, include "common/value.h"
// IWYU pragma: friend "common/value.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_INPUT_STREAM_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_INPUT_STREAM_H_

// NOTE(review): this extracted copy has lost its angle-bracketed text: the
// standard header names below and the template arguments of `absl::variant`,
// `absl::in_place_type`, `static_cast`, `reinterpret_cast` and
// `std::numeric_limits`. Restore them from upstream before compiling.
#include
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "absl/types/variant.h"
#include "absl/utility/utility.h"
#include "common/internal/byte_string.h"
#include "common/values/bytes_value.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"

namespace cel {

// `BytesValueInputStream` reads the bytes of a `BytesValue` through the
// `google::protobuf::io::ZeroCopyInputStream` interface. It reads from the
// value's own storage: a view of the bytes for small/medium values, or a
// pointer to the `absl::Cord` for large ones. The `BytesValue` must therefore
// outlive the stream (see `ABSL_ATTRIBUTE_LIFETIME_BOUND`).
class BytesValueInputStream final
    : public google::protobuf::io::ZeroCopyInputStream {
 public:
  // `value` must be non-null (checked in debug builds).
  explicit BytesValueInputStream(
      const BytesValue* absl_nonnull value ABSL_ATTRIBUTE_LIFETIME_BOUND) {
    Construct(value);
  }

  // The variant was created with placement new inside `impl_`, so it has to be
  // destroyed explicitly. This is the same operation as `Destruct()`.
  ~BytesValueInputStream() override { AsVariant().~variant(); }

  // The `ZeroCopyInputStream` overrides below forward unchanged to whichever
  // stream is currently held by the variant.

  bool Next(const void** data, int* size) override {
    return absl::visit(
        [&data, &size](auto& alternative) -> bool {
          return alternative.Next(data, size);
        },
        AsVariant());
  }

  void BackUp(int count) override {
    absl::visit(
        [&count](auto& alternative) -> void { alternative.BackUp(count); },
        AsVariant());
  }

  bool Skip(int count) override {
    return absl::visit(
        [&count](auto& alternative) -> bool {
          return alternative.Skip(count);
        },
        AsVariant());
  }

  int64_t ByteCount() const override {
    return absl::visit(
        [](const auto& alternative) -> int64_t {
          return alternative.ByteCount();
        },
        AsVariant());
  }

  bool ReadCord(absl::Cord* cord, int count) override {
    return absl::visit(
        [&cord, &count](auto& alternative) -> bool {
          return alternative.ReadCord(cord, count);
        },
        AsVariant());
  }

 private:
  // NOTE(review): the alternatives are stripped in this copy. Judging from the
  // constructor arguments used below, they are presumably an array-backed
  // stream (data pointer + size) and a cord-backed stream taking a
  // `const absl::Cord*`. Confirm against upstream.
  using Variant = absl::variant;

  // Picks the alternative from the representation of `value->value_`. Small
  // and medium byte strings are read as a contiguous view; large ones are read
  // through the underlying cord.
  void Construct(const BytesValue* absl_nonnull value) {
    ABSL_DCHECK(value != nullptr);
    switch (value->value_.GetKind()) {
      case common_internal::ByteStringKind::kSmall:
        Construct(value->value_.GetSmall());
        break;
      case common_internal::ByteStringKind::kMedium:
        Construct(value->value_.GetMedium());
        break;
      case common_internal::ByteStringKind::kLarge:
        Construct(&value->value_.GetLarge());
        break;
    }
  }

  // The stream API measures sizes in `int` (see `Next`, `BackUp`, `Skip`).
  // The debug check presumably bounds the view's size by the `int` maximum;
  // the template argument is stripped here.
  void Construct(absl::string_view value) {
    ABSL_DCHECK_LE(value.size(),
                   static_cast(std::numeric_limits::max()));
    ::new (static_cast(&impl_[0]))
        Variant(absl::in_place_type, value.data(), static_cast(value.size()));
  }

  // Builds the cord-reading alternative over `*value`, which is not copied.
  void Construct(const absl::Cord* absl_nonnull value) {
    ::new (static_cast(&impl_[0])) Variant(absl::in_place_type, value);
  }

  // Not called within this class; the destructor performs the same operation
  // inline.
  void Destruct() { AsVariant().~variant(); }

  // Accesses the variant that lives in `impl_`. `std::launder` yields a usable
  // pointer to the object that was placement-new'ed into the char buffer.
  Variant& AsVariant() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return *std::launder(reinterpret_cast(&impl_[0]));
  }

  const Variant& AsVariant() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return *std::launder(reinterpret_cast(&impl_[0]));
  }

  // Raw storage, sized and aligned for `Variant`.
  alignas(Variant) char impl_[sizeof(Variant)];
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_INPUT_STREAM_H_

================================================
FILE: common/values/bytes_value_output_stream.h
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_OUTPUT_STREAM_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_OUTPUT_STREAM_H_ #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/functional/overload.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "absl/utility/utility.h" #include "common/internal/byte_string.h" #include "common/values/bytes_value.h" #include "google/protobuf/arena.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/io/zero_copy_stream_impl_lite.h" namespace cel { class BytesValueOutputStream final : public google::protobuf::io::ZeroCopyOutputStream { public: explicit BytesValueOutputStream(const BytesValue& value) : BytesValueOutputStream(value, /*arena=*/nullptr) {} BytesValueOutputStream(const BytesValue& value, google::protobuf::Arena* absl_nullable arena) { Construct(value, arena); } bool Next(void** data, int* size) override { return absl::visit(absl::Overload( [&data, &size](String& string) -> bool { return string.stream.Next(data, size); }, [&data, &size](Cord& cord) -> bool { return cord.Next(data, size); }), AsVariant()); } void BackUp(int count) override { absl::visit( absl::Overload( [&count](String& string) -> void { string.stream.BackUp(count); }, [&count](Cord& cord) -> void { cord.BackUp(count); }), AsVariant()); } int64_t ByteCount() const override { return absl::visit( absl::Overload( [](const String& string) -> int64_t { return string.stream.ByteCount(); }, [](const Cord& cord) -> int64_t { return cord.ByteCount(); }), AsVariant()); } bool WriteAliasedRaw(const void* data, int size) override { return absl::visit(absl::Overload( [&data, &size](String& string) -> bool { return string.stream.WriteAliasedRaw(data, size); }, [&data, &size](Cord& cord) -> bool { return 
cord.WriteAliasedRaw(data, size); }), AsVariant()); } bool AllowsAliasing() const override { return absl::visit( absl::Overload( [](const String& string) -> bool { return string.stream.AllowsAliasing(); }, [](const Cord& cord) -> bool { return cord.AllowsAliasing(); }), AsVariant()); } bool WriteCord(const absl::Cord& out) override { return absl::visit( absl::Overload( [&out](String& string) -> bool { return string.stream.WriteCord(out); }, [&out](Cord& cord) -> bool { return cord.WriteCord(out); }), AsVariant()); } BytesValue Consume() && { return absl::visit(absl::Overload( [](String& string) -> BytesValue { return BytesValue(string.arena, std::move(string.target)); }, [](Cord& cord) -> BytesValue { return BytesValue(cord.Consume()); }), AsVariant()); } private: struct String final { String(absl::string_view target, google::protobuf::Arena* absl_nullable arena) : target(target), stream(&this->target), arena(arena) {} std::string target; google::protobuf::io::StringOutputStream stream; google::protobuf::Arena* absl_nullable arena; }; using Cord = google::protobuf::io::CordOutputStream; using Variant = absl::variant; void Construct(const BytesValue& value, google::protobuf::Arena* absl_nullable arena) { switch (value.value_.GetKind()) { case common_internal::ByteStringKind::kSmall: Construct(value.value_.GetSmall(), arena); break; case common_internal::ByteStringKind::kMedium: Construct(value.value_.GetMedium(), arena); break; case common_internal::ByteStringKind::kLarge: Construct(value.value_.GetLarge()); break; } } void Construct(absl::string_view value, google::protobuf::Arena* absl_nullable arena) { ::new (static_cast(&impl_[0])) Variant(absl::in_place_type, value, arena); } void Construct(const absl::Cord& value) { ::new (static_cast(&impl_[0])) Variant(absl::in_place_type, value); } void Destruct() { AsVariant().~variant(); } Variant& AsVariant() ABSL_ATTRIBUTE_LIFETIME_BOUND { return *std::launder(reinterpret_cast(&impl_[0])); } const Variant& AsVariant() 
const ABSL_ATTRIBUTE_LIFETIME_BOUND { return *std::launder(reinterpret_cast(&impl_[0])); } alignas(Variant) char impl_[sizeof(Variant)]; }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_OUTPUT_STREAM_H_ ================================================ FILE: common/values/bytes_value_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "google/protobuf/struct.pb.h" #include "absl/status/status_matchers.h" #include "absl/strings/cord.h" #include "absl/strings/cord_test_helpers.h" #include "absl/types/optional.h" #include "common/native_type.h" #include "common/value.h" #include "common/value_testing.h" #include "internal/testing.h" namespace cel { namespace { using ::absl_testing::IsOk; using ::testing::An; using ::testing::Eq; using ::testing::NotNull; using ::testing::Optional; using BytesValueTest = common_internal::ValueTest<>; TEST_F(BytesValueTest, Kind) { EXPECT_EQ(BytesValue("foo").kind(), BytesValue::kKind); EXPECT_EQ(Value(BytesValue(absl::Cord("foo"))).kind(), BytesValue::kKind); } TEST_F(BytesValueTest, DebugString) { { std::ostringstream out; out << BytesValue("foo"); EXPECT_EQ(out.str(), "b\"foo\""); } { std::ostringstream out; out << BytesValue(absl::MakeFragmentedCord({"f", "o", "o"})); EXPECT_EQ(out.str(), "b\"foo\""); } { std::ostringstream out; out << Value(BytesValue(absl::Cord("foo"))); 
EXPECT_EQ(out.str(), "b\"foo\""); } } TEST_F(BytesValueTest, ConvertToJson) { auto* message = NewArenaValueMessage(); EXPECT_THAT(BytesValue("foo").ConvertToJson(descriptor_pool(), message_factory(), message), IsOk()); EXPECT_THAT(*message, EqualsValueTextProto(R"pb(string_value: "Zm9v")pb")); } TEST_F(BytesValueTest, NativeValue) { std::string scratch; EXPECT_EQ(BytesValue("foo").NativeString(), "foo"); EXPECT_EQ(BytesValue("foo").NativeString(scratch), "foo"); EXPECT_EQ(BytesValue("foo").NativeCord(), "foo"); } TEST_F(BytesValueTest, TryFlat) { EXPECT_THAT(BytesValue("foo").TryFlat(), Optional(Eq("foo"))); EXPECT_THAT( BytesValue(absl::MakeFragmentedCord({"Hello, World!", "World, Hello!"})) .TryFlat(), Eq(absl::nullopt)); } TEST_F(BytesValueTest, ToString) { EXPECT_EQ(BytesValue("foo").ToString(), "foo"); EXPECT_EQ(BytesValue(absl::MakeFragmentedCord({"f", "o", "o"})).ToString(), "foo"); } TEST_F(BytesValueTest, CopyToString) { std::string out; BytesValue("foo").CopyToString(&out); EXPECT_EQ(out, "foo"); BytesValue(absl::MakeFragmentedCord({"f", "o", "o"})).CopyToString(&out); EXPECT_EQ(out, "foo"); } TEST_F(BytesValueTest, AppendToString) { std::string out; BytesValue("foo").AppendToString(&out); EXPECT_EQ(out, "foo"); BytesValue(absl::MakeFragmentedCord({"f", "o", "o"})).AppendToString(&out); EXPECT_EQ(out, "foofoo"); } TEST_F(BytesValueTest, ToCord) { EXPECT_EQ(BytesValue("foo").ToCord(), "foo"); EXPECT_EQ(BytesValue(absl::MakeFragmentedCord({"f", "o", "o"})).ToCord(), "foo"); } TEST_F(BytesValueTest, CopyToCord) { absl::Cord out; BytesValue("foo").CopyToCord(&out); EXPECT_EQ(out, "foo"); BytesValue(absl::MakeFragmentedCord({"f", "o", "o"})).CopyToCord(&out); EXPECT_EQ(out, "foo"); } TEST_F(BytesValueTest, AppendToCord) { absl::Cord out; BytesValue("foo").AppendToCord(&out); EXPECT_EQ(out, "foo"); BytesValue(absl::MakeFragmentedCord({"f", "o", "o"})).AppendToCord(&out); EXPECT_EQ(out, "foofoo"); } TEST_F(BytesValueTest, NativeTypeId) { 
EXPECT_EQ(NativeTypeId::Of(BytesValue("foo")), NativeTypeId::For()); EXPECT_EQ(NativeTypeId::Of(Value(BytesValue(absl::Cord("foo")))), NativeTypeId::For()); } TEST_F(BytesValueTest, StringViewEquality) { // NOLINTBEGIN(readability/check) EXPECT_TRUE(BytesValue("foo") == "foo"); EXPECT_FALSE(BytesValue("foo") == "bar"); EXPECT_TRUE("foo" == BytesValue("foo")); EXPECT_FALSE("bar" == BytesValue("foo")); // NOLINTEND(readability/check) } TEST_F(BytesValueTest, StringViewInequality) { // NOLINTBEGIN(readability/check) EXPECT_FALSE(BytesValue("foo") != "foo"); EXPECT_TRUE(BytesValue("foo") != "bar"); EXPECT_FALSE("foo" != BytesValue("foo")); EXPECT_TRUE("bar" != BytesValue("foo")); // NOLINTEND(readability/check) } TEST_F(BytesValueTest, Comparison) { EXPECT_LT(BytesValue("bar"), BytesValue("foo")); EXPECT_FALSE(BytesValue("foo") < BytesValue("foo")); EXPECT_FALSE(BytesValue("foo") < BytesValue("bar")); } TEST_F(BytesValueTest, StringInputStream) { BytesValue value = BytesValue("foo"); BytesValueInputStream stream(&value); const void* data; int size; absl::Cord cord; ASSERT_TRUE(stream.Next(&data, &size)); EXPECT_THAT(data, NotNull()); EXPECT_EQ(size, 3); EXPECT_EQ(stream.ByteCount(), 3); stream.BackUp(size); ASSERT_TRUE(stream.Skip(3)); EXPECT_FALSE(stream.ReadCord(&cord, 3)); EXPECT_FALSE(stream.Next(&data, &size)); } TEST_F(BytesValueTest, CordInputStream) { BytesValue value = BytesValue(absl::Cord("foo")); BytesValueInputStream stream(&value); const void* data; int size; absl::Cord cord; ASSERT_TRUE(stream.Next(&data, &size)); EXPECT_THAT(data, NotNull()); EXPECT_EQ(size, 3); EXPECT_EQ(stream.ByteCount(), 3); stream.BackUp(size); ASSERT_TRUE(stream.Skip(3)); EXPECT_FALSE(stream.ReadCord(&cord, 3)); EXPECT_FALSE(stream.Next(&data, &size)); } TEST_F(BytesValueTest, ArenaStringOutputStream) { BytesValue value = BytesValue(""); { BytesValueOutputStream stream(value, arena()); EXPECT_THAT(stream.AllowsAliasing(), An()); EXPECT_EQ(stream.ByteCount(), 0); 
google::protobuf::Value value_proto; auto* struct_proto = value_proto.mutable_struct_value(); (*struct_proto->mutable_fields())["foo"].set_string_value("bar"); (*struct_proto->mutable_fields())["baz"].set_number_value(3.14159); ASSERT_TRUE(value_proto.SerializePartialToZeroCopyStream(&stream)); EXPECT_EQ(std::move(stream).Consume(), value_proto.SerializePartialAsString()); } { BytesValueOutputStream stream(value); EXPECT_EQ(std::move(stream).Consume(), ""); } } TEST_F(BytesValueTest, StringOutputStream) { BytesValue value = BytesValue(""); { BytesValueOutputStream stream(value); EXPECT_THAT(stream.AllowsAliasing(), An()); EXPECT_EQ(stream.ByteCount(), 0); google::protobuf::Value value_proto; auto* struct_proto = value_proto.mutable_struct_value(); (*struct_proto->mutable_fields())["foo"].set_string_value("bar"); (*struct_proto->mutable_fields())["baz"].set_number_value(3.14159); ASSERT_TRUE(value_proto.SerializePartialToZeroCopyStream(&stream)); EXPECT_EQ(std::move(stream).Consume(), value_proto.SerializePartialAsString()); } { BytesValueOutputStream stream(value); EXPECT_EQ(std::move(stream).Consume(), ""); } } TEST_F(BytesValueTest, CordOutputStream) { BytesValue value = BytesValue(absl::Cord()); { BytesValueOutputStream stream(value); EXPECT_THAT(stream.AllowsAliasing(), An()); EXPECT_EQ(stream.ByteCount(), 0); google::protobuf::Value value_proto; auto* struct_proto = value_proto.mutable_struct_value(); (*struct_proto->mutable_fields())["foo"].set_string_value("bar"); (*struct_proto->mutable_fields())["baz"].set_number_value(3.14159); ASSERT_TRUE(value_proto.SerializePartialToZeroCopyStream(&stream)); EXPECT_EQ(std::move(stream).Consume(), value_proto.SerializePartialAsString()); } { BytesValueOutputStream stream(value); EXPECT_EQ(std::move(stream).Consume(), ""); } } } // namespace } // namespace cel ================================================ FILE: common/values/custom_list_value.cc ================================================ // Copyright 2023 Google 
// LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): this extracted copy has lost its angle-bracketed text: the
// standard header names below and the template arguments of
// `absl::NoDestructor`, `absl::StatusOr`, `content_.To`, `As`,
// `std::make_unique` and `google::protobuf::Arena::Create`. Restore them from
// upstream before compiling.
#include
#include
#include

#include "absl/base/no_destructor.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "common/casting.h"
#include "common/native_type.h"
#include "common/value.h"
#include "common/values/list_value_builder.h"
#include "common/values/values.h"
#include "eval/public/cel_value.h"
#include "internal/status_macros.h"
#include "internal/well_known_types.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

namespace {

using ::cel::well_known_types::ListValueReflection;
using ::cel::well_known_types::ValueReflection;
using ::google::api::expr::runtime::CelValue;

// The empty list. `Get()` returns a single lazily-created instance that is
// never destroyed (`absl::NoDestructor`). It serves both as the `interface` of
// a default-constructed `CustomListValue` and, through
// `common_internal::CompatListValue`, as a `CelValue`-based list
// (`size`, `operator[]`, `Get(arena, index)`).
class EmptyListValue final : public common_internal::CompatListValue {
 public:
  // Returns the shared instance.
  static const EmptyListValue& Get() {
    static const absl::NoDestructor empty;
    return *empty;
  }

  EmptyListValue() = default;

  std::string DebugString() const override { return "[]"; }

  bool IsEmpty() const override { return true; }

  size_t Size() const override { return 0; }

  // Converts to an empty `google.protobuf.ListValue` by clearing `json`.
  absl::Status ConvertToJsonArray(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const override {
    ABSL_DCHECK(descriptor_pool != nullptr);
    ABSL_DCHECK(message_factory != nullptr);
    ABSL_DCHECK(json != nullptr);
    ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
                   google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE);

    json->Clear();
    return absl::OkStatus();
  }

  // The instance is shared, so "cloning" just re-wraps it with `arena`.
  CustomListValue Clone(google::protobuf::Arena* absl_nonnull arena) const override {
    return CustomListValue(&EmptyListValue::Get(), arena);
  }

  int size() const override { return 0; }

  // Every index is out of bounds. The error status is a function-local static
  // that is never destroyed, so the returned `CelValue` can point at it
  // without an arena.
  CelValue operator[](int index) const override {
    static const absl::NoDestructor error(
        absl::InvalidArgumentError("index out of bounds"));
    return CelValue::CreateError(&*error);
  }

  // Like `operator[]`, but when an arena is supplied the error status is
  // allocated on it rather than shared.
  CelValue Get(google::protobuf::Arena* arena, int index) const override {
    if (arena == nullptr) {
      return (*this)[index];
    }
    return CelValue::CreateError(google::protobuf::Arena::Create(
        arena, absl::InvalidArgumentError("index out of bounds")));
  }

 private:
  // Any index produces an index-out-of-bounds error *value* in `result`; the
  // call itself still returns OK.
  absl::Status Get(size_t index, const google::protobuf::DescriptorPool* absl_nonnull,
                   google::protobuf::MessageFactory* absl_nonnull,
                   google::protobuf::Arena* absl_nonnull,
                   Value* absl_nonnull result) const override {
    *result = IndexOutOfBoundsError(index);
    return absl::OkStatus();
  }
};

}  // namespace

namespace common_internal {

// Returns the shared empty list as a `CompatListValue`.
const CompatListValue* absl_nonnull EmptyCompatListValue() {
  return &EmptyListValue::Get();
}

}  // namespace common_internal

// Index-based iterator over a `CustomListValueInterface`. The size is captured
// once at construction. `interface` is held by reference, so it must outlive
// the iterator.
class CustomListValueInterfaceIterator final : public ValueIterator {
 public:
  explicit CustomListValueInterfaceIterator(
      const CustomListValueInterface& interface)
      : interface_(interface), size_(interface_.Size()) {}

  bool HasNext() override { return index_ < size_; }

  // Writes the next element to `result`. Calling this after the end has been
  // reached is a `FailedPrecondition` error.
  absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                    google::protobuf::MessageFactory* absl_nonnull message_factory,
                    google::protobuf::Arena* absl_nonnull arena,
                    Value* absl_nonnull result) override {
    if (ABSL_PREDICT_FALSE(index_ >= size_)) {
      return absl::FailedPreconditionError(
          "ValueIterator::Next() called when "
          "ValueIterator::HasNext() returns false");
    }
    CEL_RETURN_IF_ERROR(interface_.Get(index_, descriptor_pool, message_factory,
                                       arena, result));
    ++index_;
    return absl::OkStatus();
  }

  // Like `Next`, but reports the end by returning `false` instead of an error.
  absl::StatusOr Next1(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull key_or_value) override {
    ABSL_DCHECK(descriptor_pool != nullptr);
    ABSL_DCHECK(message_factory != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK(key_or_value != nullptr);

    if (index_ >= size_) {
      return false;
    }
    CEL_RETURN_IF_ERROR(interface_.Get(index_, descriptor_pool, message_factory,
                                       arena, key_or_value));
    ++index_;
    return true;
  }

  // Yields the position as an `IntValue` key and, when `value` is non-null,
  // the element at that position. Returns `false` at the end.
  absl::StatusOr Next2(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key,
      Value* absl_nullable value) override {
    ABSL_DCHECK(descriptor_pool != nullptr);
    ABSL_DCHECK(message_factory != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK(key != nullptr);

    if (index_ >= size_) {
      return false;
    }
    if (value != nullptr) {
      CEL_RETURN_IF_ERROR(interface_.Get(index_, descriptor_pool,
                                         message_factory, arena, value));
    }
    *key = IntValue(index_);
    ++index_;
    return true;
  }

 private:
  const CustomListValueInterface& interface_;
  const size_t size_;
  // Position of the next element to produce.
  size_t index_ = 0;
};

namespace {

// Index-based iterator over dispatcher-backed content, reading elements
// through the dispatcher's `get` callback. `CustomListValue::NewIterator`
// appears to use it as the fallback when the dispatcher has no `new_iterator`
// (the argument list there matches this constructor).
class CustomListValueDispatcherIterator final : public ValueIterator {
 public:
  explicit CustomListValueDispatcherIterator(
      const CustomListValueDispatcher* absl_nonnull dispatcher,
      CustomListValueContent content, size_t size)
      : dispatcher_(dispatcher), content_(content), size_(size) {}

  bool HasNext() override { return index_ < size_; }

  // Writes the next element to `result`. Calling this after the end has been
  // reached is a `FailedPrecondition` error.
  absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                    google::protobuf::MessageFactory* absl_nonnull message_factory,
                    google::protobuf::Arena* absl_nonnull arena,
                    Value* absl_nonnull result) override {
    if (ABSL_PREDICT_FALSE(index_ >= size_)) {
      return absl::FailedPreconditionError(
          "ValueIterator::Next() called when "
          "ValueIterator::HasNext() returns false");
    }
    CEL_RETURN_IF_ERROR(dispatcher_->get(dispatcher_, content_, index_,
                                         descriptor_pool, message_factory,
                                         arena, result));
    ++index_;
    return absl::OkStatus();
  }

  // Like `Next`, but reports the end by returning `false` instead of an error.
  absl::StatusOr Next1(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull key_or_value) override {
    ABSL_DCHECK(descriptor_pool != nullptr);
    ABSL_DCHECK(message_factory != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK(key_or_value != nullptr);

    if (index_ >= size_) {
      return false;
    }
    CEL_RETURN_IF_ERROR(dispatcher_->get(dispatcher_, content_, index_,
                                         descriptor_pool, message_factory,
                                         arena, key_or_value));
    ++index_;
    return true;
  }

  // Yields the position as an `IntValue` key and, when `value` is non-null,
  // the element at that position. Returns `false` at the end.
  absl::StatusOr Next2(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key,
      Value* absl_nullable value) override {
    ABSL_DCHECK(descriptor_pool != nullptr);
    ABSL_DCHECK(message_factory != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK(key != nullptr);

    if (index_ >= size_) {
      return false;
    }
    if (value != nullptr) {
      CEL_RETURN_IF_ERROR(dispatcher_->get(dispatcher_, content_, index_,
                                           descriptor_pool, message_factory,
                                           arena, value));
    }
    *key = IntValue(index_);
    ++index_;
    return true;
  }

 private:
  const CustomListValueDispatcher* absl_nonnull const dispatcher_;
  const CustomListValueContent content_;
  const size_t size_;
  // Position of the next element to produce.
  size_t index_ = 0;
};

}  // namespace

// Default serialization. The list is converted to a `google.protobuf.ListValue`
// message allocated on a local arena, and that message is written to `output`.
absl::Status CustomListValueInterface::SerializeTo(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(output != nullptr);

  ListValueReflection reflection;
  CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor_pool));
  const google::protobuf::Message* prototype =
      message_factory->GetPrototype(reflection.GetDescriptor());
  if (prototype == nullptr) {
    return absl::UnknownError(
        absl::StrCat("failed to get message prototype: ",
                     reflection.GetDescriptor()->full_name()));
  }
  google::protobuf::Arena arena;
  google::protobuf::Message* message = prototype->New(&arena);
  CEL_RETURN_IF_ERROR(
      ConvertToJsonArray(descriptor_pool, message_factory, message));
  if (!message->SerializePartialToZeroCopyStream(output)) {
    return absl::UnknownError(
        "failed to serialize message: google.protobuf.ListValue");
  }
  return absl::OkStatus();
}

// Default iteration. Calls `callback` with each (index, element) pair in order
// and stops early when it returns `false`. A non-OK status from `Get` or from
// `callback` is returned immediately.
absl::Status CustomListValueInterface::ForEach(
    ForEachWithIndexCallback callback,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) const {
  const size_t size = Size();
  for (size_t index = 0; index < size; ++index) {
    Value element;
    CEL_RETURN_IF_ERROR(
        Get(index, descriptor_pool, message_factory, arena, &element));
    CEL_ASSIGN_OR_RETURN(auto ok, callback(index, element));
    if (!ok) {
      break;
    }
  }
  return absl::OkStatus();
}

// Returns an index-based iterator over this list. The `make_unique` template
// argument is stripped here; it is presumably
// `CustomListValueInterfaceIterator`.
absl::StatusOr CustomListValueInterface::NewIterator() const {
  return std::make_unique(*this);
}

// Default equality: delegates to `ListValueEqual`.
absl::Status CustomListValueInterface::Equal(
    const ListValue& other,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const {
  return ListValueEqual(*this, other, descriptor_pool, message_factory, arena,
                        result);
}

// Default membership test: a linear scan via `ForEach`.
// - `result` becomes `true` as soon as an element's `Equal` yields a true
//   boolean; the `As` target is stripped here and is presumably `BoolValue`.
// - Any other `Equal` result counts as "not equal".
// - A non-OK status from `Equal` aborts the scan and is returned.
absl::Status CustomListValueInterface::Contains(
    const Value& other,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const {
  Value outcome = BoolValue(false);
  Value equal;
  CEL_RETURN_IF_ERROR(ForEach(
      [&](size_t index, const Value& element) -> absl::StatusOr {
        CEL_RETURN_IF_ERROR(element.Equal(other, descriptor_pool,
                                          message_factory, arena, &equal));
        if (auto bool_result = As(equal);
            bool_result.has_value() && bool_result->NativeValue()) {
          outcome = BoolValue(true);
          return false;  // Found a match; stop iterating.
        }
        return true;
      },
      descriptor_pool, message_factory, arena));
  *result = outcome;
  return absl::OkStatus();
}

// A default-constructed `CustomListValue` wraps the empty-list singleton with
// no arena.
// NOTE(review): this relies on `dispatcher_` defaulting to null in the class
// definition (not visible here), so that the interface path is taken.
CustomListValue::CustomListValue() {
  content_ = CustomListValueContent::From(CustomListValueInterface::Content{
      .interface = &EmptyListValue::Get(), .arena = nullptr});
}

// The `CustomListValue` members below share one dispatch scheme:
// - When `dispatcher_` is null, `content_` holds a
//   `CustomListValueInterface::Content` and the call is forwarded to
//   `content.interface`.
// - Otherwise the dispatcher's function pointers are invoked. Callbacks that
//   are checked against null are optional, and each such method falls back as
//   shown in its body.

NativeTypeId CustomListValue::GetTypeId() const {
  if (dispatcher_ == nullptr) {
    CustomListValueInterface::Content content =
        content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->GetNativeTypeId();
  }
  return dispatcher_->get_type_id(dispatcher_, content_);
}

absl::string_view CustomListValue::GetTypeName() const { return "list"; }

// Fallback when the dispatcher has no `debug_string`: the type name "list".
std::string CustomListValue::DebugString() const {
  if (dispatcher_ == nullptr) {
    CustomListValueInterface::Content content =
        content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->DebugString();
  }
  if (dispatcher_->debug_string != nullptr) {
    return dispatcher_->debug_string(dispatcher_, content_);
  }
  return "list";
}

// Fallback when the dispatcher has no `serialize_to`: `Unimplemented` error.
absl::Status CustomListValue::SerializeTo(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const {
  if (dispatcher_ == nullptr) {
    CustomListValueInterface::Content content =
        content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->SerializeTo(descriptor_pool, message_factory,
                                          output);
  }
  if (dispatcher_->serialize_to != nullptr) {
    return dispatcher_->serialize_to(dispatcher_, content_, descriptor_pool,
                                     message_factory, output);
  }
  return absl::UnimplementedError(
      absl::StrCat(GetTypeName(), " is unserializable"));
}

// `json` must be a `google.protobuf.Value`. Its list field is filled in via
// `ConvertToJsonArray`.
absl::Status CustomListValue::ConvertToJson(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(json != nullptr);
  ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
                 google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE);

  ValueReflection value_reflection;
  CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor()));
  google::protobuf::Message* json_array = value_reflection.MutableListValue(json);

  return ConvertToJsonArray(descriptor_pool, message_factory, json_array);
}

// `json` must be a `google.protobuf.ListValue`. Fallback when the dispatcher
// has no `convert_to_json_array`: `Unimplemented` error.
absl::Status CustomListValue::ConvertToJsonArray(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(json != nullptr);
  ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
                 google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE);

  if (dispatcher_ == nullptr) {
    CustomListValueInterface::Content content =
        content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->ConvertToJsonArray(descriptor_pool,
                                                 message_factory, json);
  }
  if (dispatcher_->convert_to_json_array != nullptr) {
    return dispatcher_->convert_to_json_array(
        dispatcher_, content_, descriptor_pool, message_factory, json);
  }
  return absl::UnimplementedError(
      absl::StrCat(GetTypeName(), " is not convertable to JSON"));
}

// Equality with an arbitrary value.
// - A value that is not a list compares unequal (`FalseValue`).
// - Lists are compared by the interface, by the dispatcher's `equal`, or, if
//   that is null, by `common_internal::ListValueEqual`.
absl::Status CustomListValue::Equal(
    const Value& other,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);

  if (auto other_list_value = other.AsList(); other_list_value) {
    if (dispatcher_ == nullptr) {
      CustomListValueInterface::Content content =
          content_.To();
      ABSL_DCHECK(content.interface != nullptr);
      return content.interface->Equal(*other_list_value, descriptor_pool,
                                      message_factory, arena, result);
    }
    if (dispatcher_->equal != nullptr) {
      return dispatcher_->equal(dispatcher_, content_, *other_list_value,
                                descriptor_pool, message_factory, arena,
                                result);
    }
    return common_internal::ListValueEqual(*this, *other_list_value,
                                           descriptor_pool, message_factory,
                                           arena, result);
  }
  *result = FalseValue();
  return absl::OkStatus();
}

// The dispatcher's `is_zero_value` is called without a null check here.
bool CustomListValue::IsZeroValue() const {
  if (dispatcher_ == nullptr) {
    CustomListValueInterface::Content content =
        content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->IsZeroValue();
  }
  return dispatcher_->is_zero_value(dispatcher_, content_);
}

// Interface-backed values are cloned only when their content lives on a
// different arena than `arena`; otherwise `*this` is returned as-is.
// Dispatcher-backed values always go through the dispatcher's `clone`.
CustomListValue CustomListValue::Clone(
    google::protobuf::Arena* absl_nonnull arena) const {
  ABSL_DCHECK(arena != nullptr);

  if (dispatcher_ == nullptr) {
    CustomListValueInterface::Content content =
        content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    if (content.arena != arena) {
      return content.interface->Clone(arena);
    }
    return *this;
  }
  return dispatcher_->clone(dispatcher_, content_, arena);
}

// Fallback when the dispatcher has no `is_empty`: `size(...) == 0`.
bool CustomListValue::IsEmpty() const {
  if (dispatcher_ == nullptr) {
    CustomListValueInterface::Content content =
        content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->IsEmpty();
  }
  if (dispatcher_->is_empty != nullptr) {
    return dispatcher_->is_empty(dispatcher_, content_);
  }
  return dispatcher_->size(dispatcher_, content_) == 0;
}

size_t CustomListValue::Size() const {
  if (dispatcher_ == nullptr) {
    CustomListValueInterface::Content content =
        content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->Size();
  }
  return dispatcher_->size(dispatcher_, content_);
}

absl::Status CustomListValue::Get(
    size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const {
  if (dispatcher_ == nullptr) {
    CustomListValueInterface::Content content =
        content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->Get(index, descriptor_pool, message_factory,
                                  arena, result);
  }
  return dispatcher_->get(dispatcher_, content_, index, descriptor_pool,
                          message_factory, arena, result);
}

// Fallback when the dispatcher has no `for_each`: an index loop using the
// dispatcher's `size` and `get`. It has the same early-stop and error
// semantics as `CustomListValueInterface::ForEach`.
absl::Status CustomListValue::ForEach(
    ForEachWithIndexCallback callback,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) const {
  if (dispatcher_ == nullptr) {
    CustomListValueInterface::Content content =
        content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->ForEach(callback, descriptor_pool,
                                      message_factory, arena);
  }
  if (dispatcher_->for_each != nullptr) {
    return dispatcher_->for_each(dispatcher_, content_, callback,
                                 descriptor_pool, message_factory, arena);
  }
  const size_t size = dispatcher_->size(dispatcher_, content_);
  for (size_t index = 0; index < size; ++index) {
    Value element;
    CEL_RETURN_IF_ERROR(dispatcher_->get(dispatcher_, content_, index,
                                         descriptor_pool, message_factory,
                                         arena, &element));
    CEL_ASSIGN_OR_RETURN(auto ok, callback(index, element));
    if (!ok) {
      break;
    }
  }
  return absl::OkStatus();
}

// Fallback when the dispatcher has no `new_iterator`: an index-based iterator
// built from the dispatcher, the content and the current size.
absl::StatusOr CustomListValue::NewIterator() const {
  if (dispatcher_ == nullptr) {
    CustomListValueInterface::Content content =
        content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->NewIterator();
  }
  if (dispatcher_->new_iterator != nullptr) {
    return dispatcher_->new_iterator(dispatcher_, content_);
  }
  return std::make_unique(
      dispatcher_, content_, dispatcher_->size(dispatcher_, content_));
}

// Fallback when the dispatcher has no `contains`: the same linear scan as
// `CustomListValueInterface::Contains`, run through `ForEach`.
absl::Status CustomListValue::Contains(
    const Value& other,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const {
  if (dispatcher_ == nullptr) {
    CustomListValueInterface::Content content =
        content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->Contains(other, descriptor_pool, message_factory,
                                       arena, result);
  }
  if (dispatcher_->contains != nullptr) {
    return dispatcher_->contains(dispatcher_, content_, other, descriptor_pool,
                                 message_factory, arena, result);
  }
  Value outcome = BoolValue(false);
  Value equal;
  CEL_RETURN_IF_ERROR(ForEach(
      [&](size_t index, const Value& element) -> absl::StatusOr {
        CEL_RETURN_IF_ERROR(element.Equal(other, descriptor_pool,
                                          message_factory, arena, &equal));
        if (auto bool_result = As(equal);
            bool_result.has_value() && bool_result->NativeValue()) {
          outcome = BoolValue(true);
          return false;  // Found a match; stop iterating.
        }
        return true;
      },
      descriptor_pool, message_factory, arena));
  *result = outcome;
  return absl::OkStatus();
}

}  // namespace cel

================================================
FILE: common/values/custom_list_value.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" // `CustomListValue` represents values of the primitive `list` type. // `CustomListValueView` is a non-owning view of `CustomListValue`. // `CustomListValueInterface` is the abstract base class of implementations. // `CustomListValue` and `CustomListValueView` act as smart pointers to // `CustomListValueInterface`. #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_LIST_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_LIST_VALUE_H_ #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/functional/function_ref.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "common/native_type.h" #include "common/value_kind.h" #include "common/values/custom_value.h" #include "common/values/values.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { class Value; class CustomListValueInterface; class CustomListValueInterfaceIterator; class CustomListValue; struct CustomListValueDispatcher; using CustomListValueContent = CustomValueContent; struct CustomListValueDispatcher { using GetTypeId = NativeTypeId (*)(const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content); using GetArena = google::protobuf::Arena* absl_nullable (*)( const CustomListValueDispatcher* 
absl_nonnull dispatcher, CustomListValueContent content); using DebugString = std::string (*)(const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content); using SerializeTo = absl::Status (*)( const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output); using ConvertToJsonArray = absl::Status (*)( const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json); using Equal = absl::Status (*)( const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content, const ListValue& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); using IsZeroValue = bool (*)(const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content); using IsEmpty = bool (*)(const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content); using Size = size_t (*)(const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content); using Get = absl::Status (*)( const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content, size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); using ForEach = absl::Status (*)( const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content, 
absl::FunctionRef(size_t, const Value&)> callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena); using NewIterator = absl::StatusOr (*)( const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content); using Contains = absl::Status (*)( const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content, const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); using Clone = CustomListValue (*)( const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content, google::protobuf::Arena* absl_nonnull arena); absl_nonnull GetTypeId get_type_id; absl_nonnull GetArena get_arena; // If null, simply returns "list". absl_nullable DebugString debug_string = nullptr; // If null, attempts to serialize results in an UNIMPLEMENTED error. absl_nullable SerializeTo serialize_to = nullptr; // If null, attempts to convert to JSON results in an UNIMPLEMENTED error. absl_nullable ConvertToJsonArray convert_to_json_array = nullptr; // If null, an nonoptimal fallback implementation for equality is used. absl_nullable Equal equal = nullptr; absl_nonnull IsZeroValue is_zero_value; // If null, `size(...) == 0` is used. absl_nullable IsEmpty is_empty = nullptr; absl_nonnull Size size; absl_nonnull Get get; // If null, a fallback implementation using `size` and `get` is used. absl_nullable ForEach for_each = nullptr; // If null, a fallback implementation using `size` and `get` is used. absl_nullable NewIterator new_iterator = nullptr; // If null, a fallback implementation is used. 
absl_nullable Contains contains = nullptr; absl_nonnull Clone clone; }; class CustomListValueInterface { public: CustomListValueInterface() = default; CustomListValueInterface(const CustomListValueInterface&) = delete; CustomListValueInterface(CustomListValueInterface&&) = delete; virtual ~CustomListValueInterface() = default; CustomListValueInterface& operator=(const CustomListValueInterface&) = delete; CustomListValueInterface& operator=(CustomListValueInterface&&) = delete; using ForEachCallback = absl::FunctionRef(const Value&)>; using ForEachWithIndexCallback = absl::FunctionRef(size_t, const Value&)>; private: friend class CustomListValueInterfaceIterator; friend class CustomListValue; friend absl::Status common_internal::ListValueEqual( const CustomListValueInterface& lhs, const ListValue& rhs, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); virtual std::string DebugString() const = 0; virtual absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; virtual absl::Status ConvertToJsonArray( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const = 0; virtual absl::Status Equal( const ListValue& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; virtual bool IsZeroValue() const { return IsEmpty(); } virtual bool IsEmpty() const { return Size() == 0; } virtual size_t Size() const = 0; virtual absl::Status Get( size_t index, const 
google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const = 0; virtual absl::Status ForEach( ForEachWithIndexCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; virtual absl::StatusOr NewIterator() const; virtual absl::Status Contains( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; virtual CustomListValue Clone(google::protobuf::Arena* absl_nonnull arena) const = 0; virtual NativeTypeId GetNativeTypeId() const = 0; struct Content { const CustomListValueInterface* absl_nonnull interface; const google::protobuf::Arena* absl_nullable arena; }; }; // Creates a custom list value from a manual dispatch table `dispatcher` and // opaque data `content` whose format is only know to functions in the manual // dispatch table. The dispatch table should probably be valid for the lifetime // of the process, but at a minimum must outlive all instances of the resulting // value. // // IMPORTANT: This approach to implementing CustomListValue should only be // used when you know exactly what you are doing. When in doubt, just implement // CustomListValueInterface. 
CustomListValue UnsafeCustomListValue( const CustomListValueDispatcher* absl_nonnull dispatcher ABSL_ATTRIBUTE_LIFETIME_BOUND, CustomListValueContent content); class CustomListValue final : private common_internal::ListValueMixin { public: static constexpr ValueKind kKind = ValueKind::kList; // Constructs a custom list value from an implementation of // `CustomListValueInterface` `interface` whose lifetime is tied to that of // the arena `arena`. CustomListValue(const CustomListValueInterface* absl_nonnull interface ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { ABSL_DCHECK(interface != nullptr); ABSL_DCHECK(arena != nullptr); content_ = CustomListValueContent::From(CustomListValueInterface::Content{ .interface = interface, .arena = arena}); } CustomListValue(); CustomListValue(const CustomListValue&) = default; CustomListValue(CustomListValue&&) = default; CustomListValue& operator=(const CustomListValue&) = default; CustomListValue& operator=(CustomListValue&&) = default; static constexpr ValueKind kind() { return kKind; } NativeTypeId GetTypeId() const; absl::string_view GetTypeName() const; std::string DebugString() const; // See Value::SerializeTo(). absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; // See Value::ConvertToJson(). absl::Status ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; // See Value::ConvertToJsonArray(). 
absl::Status ConvertToJsonArray( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; absl::Status Equal(const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using ListValueMixin::Equal; bool IsZeroValue() const; CustomListValue Clone(google::protobuf::Arena* absl_nonnull arena) const; bool IsEmpty() const; size_t Size() const; // See ListValueInterface::Get for documentation. absl::Status Get(size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using ListValueMixin::Get; using ForEachCallback = typename CustomListValueInterface::ForEachCallback; using ForEachWithIndexCallback = typename CustomListValueInterface::ForEachWithIndexCallback; absl::Status ForEach( ForEachWithIndexCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; using ListValueMixin::ForEach; absl::StatusOr NewIterator() const; absl::Status Contains( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using ListValueMixin::Contains; const CustomListValueDispatcher* absl_nullable dispatcher() const { return dispatcher_; } CustomListValueContent content() const { ABSL_DCHECK(dispatcher_ != nullptr); return content_; } const CustomListValueInterface* absl_nullable interface() const { if 
(dispatcher_ == nullptr) { return content_.To().interface; } return nullptr; } friend void swap(CustomListValue& lhs, CustomListValue& rhs) noexcept { using std::swap; swap(lhs.dispatcher_, rhs.dispatcher_); swap(lhs.content_, rhs.content_); } private: friend class common_internal::ValueMixin; friend class common_internal::ListValueMixin; friend CustomListValue UnsafeCustomListValue( const CustomListValueDispatcher* absl_nonnull dispatcher ABSL_ATTRIBUTE_LIFETIME_BOUND, CustomListValueContent content); CustomListValue(const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content) : dispatcher_(dispatcher), content_(content) { ABSL_DCHECK(dispatcher != nullptr); ABSL_DCHECK(dispatcher->get_type_id != nullptr); ABSL_DCHECK(dispatcher->get_arena != nullptr); ABSL_DCHECK(dispatcher->is_zero_value != nullptr); ABSL_DCHECK(dispatcher->size != nullptr); ABSL_DCHECK(dispatcher->get != nullptr); ABSL_DCHECK(dispatcher->clone != nullptr); } const CustomListValueDispatcher* absl_nullable dispatcher_ = nullptr; CustomListValueContent content_ = CustomListValueContent::Zero(); }; inline std::ostream& operator<<(std::ostream& out, const CustomListValue& type) { return out << type.DebugString(); } template <> struct NativeTypeTraits final { static NativeTypeId Id(const CustomListValue& type) { return type.GetTypeId(); } }; inline CustomListValue UnsafeCustomListValue( const CustomListValueDispatcher* absl_nonnull dispatcher ABSL_ATTRIBUTE_LIFETIME_BOUND, CustomListValueContent content) { return CustomListValue(dispatcher, content); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_LIST_VALUE_H_ ================================================ FILE: common/values/custom_list_value_test.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include "google/protobuf/struct.pb.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/memory.h" #include "common/native_type.h" #include "common/value.h" #include "common/value_testing.h" #include "internal/testing.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/io/zero_copy_stream_impl_lite.h" #include "google/protobuf/message.h" namespace cel { namespace { using ::absl_testing::IsOk; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::test::BoolValueIs; using ::cel::test::ErrorValueIs; using ::cel::test::IntValueIs; using ::testing::Eq; using ::testing::IsEmpty; using ::testing::IsNull; using ::testing::Not; using ::testing::NotNull; using ::testing::Optional; using ::testing::Pair; using ::testing::UnorderedElementsAre; class CustomListValueTest; struct CustomListValueTestContent { google::protobuf::Arena* absl_nonnull arena; }; class CustomListValueInterfaceTest final : public CustomListValueInterface { public: std::string DebugString() const override { return "[true, 1]"; } absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* 
absl_nonnull output) const override { google::protobuf::Value json; google::protobuf::ListValue* json_array = json.mutable_list_value(); json_array->add_values()->set_bool_value(true); json_array->add_values()->set_number_value(1.0); if (!json.SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( "failed to serialize message: google.protobuf.Value"); } return absl::OkStatus(); } absl::Status ConvertToJsonArray( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const override { google::protobuf::ListValue json_array; json_array.add_values()->set_bool_value(true); json_array.add_values()->set_number_value(1.0); absl::Cord serialized; if (!json_array.SerializePartialToString(&serialized)) { return absl::UnknownError( "failed to serialize google.protobuf.ListValue"); } if (!json->ParsePartialFromString(serialized)) { return absl::UnknownError("failed to parse google.protobuf.ListValue"); } return absl::OkStatus(); } size_t Size() const override { return 2; } CustomListValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { return CustomListValue( (::new (arena->AllocateAligned(sizeof(CustomListValueInterfaceTest), alignof(CustomListValueInterfaceTest))) CustomListValueInterfaceTest()), arena); } private: absl::Status Get(size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const override { if (index == 0) { *result = TrueValue(); return absl::OkStatus(); } if (index == 1) { *result = IntValue(1); return absl::OkStatus(); } *result = IndexOutOfBoundsError(index); return absl::OkStatus(); } NativeTypeId GetNativeTypeId() const override { return NativeTypeId::For(); } }; class CustomListValueTest : public common_internal::ValueTest<> { public: 
CustomListValue MakeInterface() { return CustomListValue( (::new (arena()->AllocateAligned(sizeof(CustomListValueInterfaceTest), alignof(CustomListValueInterfaceTest))) CustomListValueInterfaceTest()), arena()); } CustomListValue MakeDispatcher() { return UnsafeCustomListValue( &test_dispatcher_, CustomValueContent::From( CustomListValueTestContent{.arena = arena()})); } protected: CustomListValueDispatcher test_dispatcher_ = { .get_type_id = [](const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content) -> NativeTypeId { return NativeTypeId::For(); }, .get_arena = [](const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content) -> google::protobuf::Arena* absl_nullable { return content.To().arena; }, .debug_string = [](const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content) -> std::string { return "[true, 1]"; }, .serialize_to = [](const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) -> absl::Status { google::protobuf::Value json; google::protobuf::Struct* json_object = json.mutable_struct_value(); (*json_object->mutable_fields())["foo"].set_bool_value(true); (*json_object->mutable_fields())["bar"].set_number_value(1.0); if (!json.SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( "failed to serialize message: google.protobuf.Value"); } return absl::OkStatus(); }, .convert_to_json_array = [](const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) -> absl::Status { { google::protobuf::ListValue json_array; 
json_array.add_values()->set_bool_value(true); json_array.add_values()->set_number_value(1.0); absl::Cord serialized; if (!json_array.SerializePartialToString(&serialized)) { return absl::UnknownError( "failed to serialize google.protobuf.ListValue"); } if (!json->ParsePartialFromString(serialized)) { return absl::UnknownError( "failed to parse google.protobuf.ListValue"); } return absl::OkStatus(); } }, .is_zero_value = [](const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content) -> bool { return false; }, .size = [](const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content) -> size_t { return 2; }, .get = [](const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content, size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) -> absl::Status { if (index == 0) { *result = TrueValue(); return absl::OkStatus(); } if (index == 1) { *result = IntValue(1); return absl::OkStatus(); } *result = IndexOutOfBoundsError(index); return absl::OkStatus(); }, .clone = [](const CustomListValueDispatcher* absl_nonnull dispatcher, CustomListValueContent content, google::protobuf::Arena* absl_nonnull arena) -> CustomListValue { return UnsafeCustomListValue( dispatcher, CustomValueContent::From( CustomListValueTestContent{.arena = arena})); }, }; }; TEST_F(CustomListValueTest, Kind) { EXPECT_EQ(CustomListValue::kind(), CustomListValue::kKind); } TEST_F(CustomListValueTest, Dispatcher_GetTypeId) { EXPECT_EQ(MakeDispatcher().GetTypeId(), NativeTypeId::For()); } TEST_F(CustomListValueTest, Interface_GetTypeId) { EXPECT_EQ(MakeInterface().GetTypeId(), NativeTypeId::For()); } TEST_F(CustomListValueTest, Dispatcher_GetTypeName) { EXPECT_EQ(MakeDispatcher().GetTypeName(), "list"); } TEST_F(CustomListValueTest, Interface_GetTypeName) { 
EXPECT_EQ(MakeInterface().GetTypeName(), "list"); } TEST_F(CustomListValueTest, Dispatcher_DebugString) { EXPECT_EQ(MakeDispatcher().DebugString(), "[true, 1]"); } TEST_F(CustomListValueTest, Interface_DebugString) { EXPECT_EQ(MakeInterface().DebugString(), "[true, 1]"); } TEST_F(CustomListValueTest, Dispatcher_IsZeroValue) { EXPECT_FALSE(MakeDispatcher().IsZeroValue()); } TEST_F(CustomListValueTest, Interface_IsZeroValue) { EXPECT_FALSE(MakeInterface().IsZeroValue()); } TEST_F(CustomListValueTest, Dispatcher_SerializeTo) { google::protobuf::io::CordOutputStream output; EXPECT_THAT(MakeDispatcher().SerializeTo(descriptor_pool(), message_factory(), &output), IsOk()); EXPECT_THAT(std::move(output).Consume(), Not(IsEmpty())); } TEST_F(CustomListValueTest, Interface_SerializeTo) { google::protobuf::io::CordOutputStream output; EXPECT_THAT(MakeInterface().SerializeTo(descriptor_pool(), message_factory(), &output), IsOk()); EXPECT_THAT(std::move(output).Consume(), Not(IsEmpty())); } TEST_F(CustomListValueTest, Dispatcher_ConvertToJson) { auto message = DynamicParseTextProto(); EXPECT_THAT( MakeDispatcher().ConvertToJson(descriptor_pool(), message_factory(), cel::to_address(message)), IsOk()); EXPECT_THAT(*message, EqualsTextProto(R"pb( list_value: { values: { bool_value: true } values: { number_value: 1.0 } } )pb")); } TEST_F(CustomListValueTest, Interface_ConvertToJson) { auto message = DynamicParseTextProto(); EXPECT_THAT( MakeInterface().ConvertToJson(descriptor_pool(), message_factory(), cel::to_address(message)), IsOk()); EXPECT_THAT(*message, EqualsTextProto(R"pb( list_value: { values: { bool_value: true } values: { number_value: 1.0 } } )pb")); } TEST_F(CustomListValueTest, Dispatcher_ConvertToJsonArray) { auto message = DynamicParseTextProto(); EXPECT_THAT( MakeDispatcher().ConvertToJsonArray(descriptor_pool(), message_factory(), cel::to_address(message)), IsOk()); EXPECT_THAT(*message, EqualsTextProto(R"pb( values: { bool_value: true } values: { number_value: 
1.0 } )pb")); } TEST_F(CustomListValueTest, Interface_ConvertToJsonArray) { auto message = DynamicParseTextProto(); EXPECT_THAT( MakeInterface().ConvertToJsonArray(descriptor_pool(), message_factory(), cel::to_address(message)), IsOk()); EXPECT_THAT(*message, EqualsTextProto(R"pb( values: { bool_value: true } values: { number_value: 1.0 } )pb")); } TEST_F(CustomListValueTest, Dispatcher_IsEmpty) { EXPECT_FALSE(MakeDispatcher().IsEmpty()); } TEST_F(CustomListValueTest, Interface_IsEmpty) { EXPECT_FALSE(MakeInterface().IsEmpty()); } TEST_F(CustomListValueTest, Dispatcher_Size) { EXPECT_EQ(MakeDispatcher().Size(), 2); } TEST_F(CustomListValueTest, Interface_Size) { EXPECT_EQ(MakeInterface().Size(), 2); } TEST_F(CustomListValueTest, Dispatcher_Get) { CustomListValue list = MakeDispatcher(); ASSERT_THAT(list.Get(0, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); ASSERT_THAT(list.Get(1, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(IntValueIs(1))); ASSERT_THAT( list.Get(2, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))); } TEST_F(CustomListValueTest, Interface_Get) { CustomListValue list = MakeInterface(); ASSERT_THAT(list.Get(0, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); ASSERT_THAT(list.Get(1, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(IntValueIs(1))); ASSERT_THAT( list.Get(2, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument)))); } TEST_F(CustomListValueTest, Dispatcher_ForEach) { std::vector> fields; EXPECT_THAT( MakeDispatcher().ForEach( [&](size_t index, const Value& value) -> absl::StatusOr { fields.push_back(std::pair{index, value}); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(fields, UnorderedElementsAre(Pair(0, BoolValueIs(true)), Pair(1, IntValueIs(1)))); } 
// ForEach must report both elements together with their indices.
TEST_F(CustomListValueTest, Interface_ForEach) {
  std::vector> fields;
  EXPECT_THAT(
      MakeInterface().ForEach(
          [&](size_t index, const Value& value) -> absl::StatusOr {
            fields.push_back(std::pair{index, value});
            return true;
          },
          descriptor_pool(), message_factory(), arena()),
      IsOk());
  EXPECT_THAT(fields, UnorderedElementsAre(Pair(0, BoolValueIs(true)),
                                           Pair(1, IntValueIs(1))));
}

// Iterators yield `true` then `1`. After the last element `HasNext()` is
// false, and `Next` fails with FAILED_PRECONDITION.
TEST_F(CustomListValueTest, Dispatcher_NewIterator) {
  CustomListValue list = MakeDispatcher();
  ASSERT_OK_AND_ASSIGN(auto iterator, list.NewIterator());
  ASSERT_TRUE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
  ASSERT_TRUE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(IntValueIs(1)));
  EXPECT_FALSE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              StatusIs(absl::StatusCode::kFailedPrecondition));
}

TEST_F(CustomListValueTest, Interface_NewIterator) {
  CustomListValue list = MakeInterface();
  ASSERT_OK_AND_ASSIGN(auto iterator, list.NewIterator());
  ASSERT_TRUE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
  ASSERT_TRUE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(IntValueIs(1)));
  EXPECT_FALSE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              StatusIs(absl::StatusCode::kFailedPrecondition));
}

// `Next1` yields each element wrapped in an optional, then `nullopt` once
// the list is exhausted.
TEST_F(CustomListValueTest, Dispatcher_NewIterator1) {
  CustomListValue list = MakeDispatcher();
  ASSERT_OK_AND_ASSIGN(auto iterator, list.NewIterator());
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(BoolValueIs(true))));
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(IntValueIs(1))));
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Eq(absl::nullopt)));
}

TEST_F(CustomListValueTest, Interface_NewIterator1) {
  CustomListValue list = MakeInterface();
  ASSERT_OK_AND_ASSIGN(auto iterator, list.NewIterator());
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(BoolValueIs(true))));
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(IntValueIs(1))));
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Eq(absl::nullopt)));
}

// `Next2` yields (index as an IntValue, element) pairs, then `nullopt`.
TEST_F(CustomListValueTest, Dispatcher_NewIterator2) {
  CustomListValue list = MakeDispatcher();
  ASSERT_OK_AND_ASSIGN(auto iterator, list.NewIterator());
  EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(Pair(IntValueIs(0), BoolValueIs(true)))));
  EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(Pair(IntValueIs(1), IntValueIs(1)))));
  EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Eq(absl::nullopt)));
}

TEST_F(CustomListValueTest, Interface_NewIterator2) {
  CustomListValue list = MakeInterface();
  ASSERT_OK_AND_ASSIGN(auto iterator, list.NewIterator());
  EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(Pair(IntValueIs(0), BoolValueIs(true)))));
  EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(Pair(IntValueIs(1), IntValueIs(1)))));
  EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Eq(absl::nullopt)));
}

// Contains: `true` and the numeric value 1 (as int, uint or double) are
// expected to be found. `false` and the numeric value 0 are not.
TEST_F(CustomListValueTest, Dispatcher_Contains) {
  CustomListValue list = MakeDispatcher();
  EXPECT_THAT(
      list.Contains(TrueValue(), descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(
      list.Contains(IntValue(1), descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(list.Contains(UintValue(1u), descriptor_pool(), message_factory(),
                            arena()),
              IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(list.Contains(DoubleValue(1.0), descriptor_pool(),
                            message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(list.Contains(FalseValue(), descriptor_pool(), message_factory(),
                            arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(
      list.Contains(IntValue(0), descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(list.Contains(UintValue(0u), descriptor_pool(), message_factory(),
                            arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(list.Contains(DoubleValue(0.0), descriptor_pool(),
                            message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
}

TEST_F(CustomListValueTest, Interface_Contains) {
  CustomListValue list = MakeInterface();
  EXPECT_THAT(
      list.Contains(TrueValue(), descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(
      list.Contains(IntValue(1), descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(list.Contains(UintValue(1u), descriptor_pool(), message_factory(),
                            arena()),
              IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(list.Contains(DoubleValue(1.0), descriptor_pool(),
                            message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(list.Contains(FalseValue(), descriptor_pool(), message_factory(),
                            arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(
      list.Contains(IntValue(0), descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(list.Contains(UintValue(0u), descriptor_pool(), message_factory(),
                            arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(list.Contains(DoubleValue(0.0), descriptor_pool(),
                            message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
}

// Exactly one of `dispatcher()` / `interface()` is non-null, depending on
// how the value was built.
TEST_F(CustomListValueTest, Dispatcher) {
  EXPECT_THAT(MakeDispatcher().dispatcher(), NotNull());
  EXPECT_THAT(MakeDispatcher().interface(), IsNull());
}

TEST_F(CustomListValueTest, Interface) {
  EXPECT_THAT(MakeInterface().dispatcher(), IsNull());
  EXPECT_THAT(MakeInterface().interface(), NotNull());
}

}  // namespace
}  // namespace cel
================================================
FILE: common/values/custom_map_value.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the header names of the three standard includes below appear
// to have been stripped from this copy; compare against upstream.
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/no_destructor.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/native_type.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "common/values/list_value_builder.h"
#include "common/values/map_value_builder.h"
#include "common/values/values.h"
#include "eval/public/cel_value.h"
#include "internal/status_macros.h"
#include "internal/well_known_types.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

namespace {

using ::cel::well_known_types::StructReflection;
using ::cel::well_known_types::ValueReflection;
using ::google::api::expr::runtime::CelList;
using
::google::api::expr::runtime::CelValue; absl::Status NoSuchKeyError(const Value& key) { return absl::NotFoundError( absl::StrCat("Key not found in map : ", key.DebugString())); } absl::Status InvalidMapKeyTypeError(ValueKind kind) { return absl::InvalidArgumentError( absl::StrCat("Invalid map key type: '", ValueKindToString(kind), "'")); } class EmptyMapValue final : public common_internal::CompatMapValue { public: static const EmptyMapValue& Get() { static const absl::NoDestructor empty; return *empty; } EmptyMapValue() = default; std::string DebugString() const override { return "{}"; } bool IsEmpty() const override { return true; } size_t Size() const override { return 0; } absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const override { *result = ListValue(); return absl::OkStatus(); } absl::StatusOr NewIterator() const override { return NewEmptyValueIterator(); } absl::Status ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const override { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); json->Clear(); return absl::OkStatus(); } CustomMapValue Clone(google::protobuf::Arena* absl_nonnull) const override { return CustomMapValue(); } absl::optional operator[](CelValue key) const override { return absl::nullopt; } using CompatMapValue::Get; absl::optional Get(google::protobuf::Arena* arena, CelValue key) const override { return absl::nullopt; } absl::StatusOr Has(const CelValue& key) const override { return false; } int size() const override { return 
static_cast(Size()); } absl::StatusOr ListKeys() const override { return common_internal::EmptyCompatListValue(); } absl::StatusOr ListKeys(google::protobuf::Arena*) const override { return ListKeys(); } private: absl::StatusOr Find( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const override { return false; } absl::StatusOr Has( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const override { return false; } }; } // namespace namespace common_internal { const CompatMapValue* absl_nonnull EmptyCompatMapValue() { return &EmptyMapValue::Get(); } } // namespace common_internal class CustomMapValueInterfaceIterator final : public ValueIterator { public: explicit CustomMapValueInterfaceIterator( const CustomMapValueInterface* absl_nonnull interface) : interface_(interface) {} bool HasNext() override { if (keys_iterator_ == nullptr) { return !interface_->IsEmpty(); } return keys_iterator_->HasNext(); } absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) override { if (keys_iterator_ == nullptr) { if (interface_->IsEmpty()) { return absl::FailedPreconditionError( "ValueIterator::Next() called when " "ValueIterator::HasNext() returns false"); } CEL_RETURN_IF_ERROR(ProjectKeys(descriptor_pool, message_factory, arena)); } return keys_iterator_->Next(descriptor_pool, message_factory, arena, result); } absl::StatusOr Next1( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* 
absl_nonnull arena, Value* absl_nonnull key_or_value) override { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(key_or_value != nullptr); if (keys_iterator_ == nullptr) { if (interface_->IsEmpty()) { return false; } CEL_RETURN_IF_ERROR(ProjectKeys(descriptor_pool, message_factory, arena)); } return keys_iterator_->Next1(descriptor_pool, message_factory, arena, key_or_value); } absl::StatusOr Next2( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, Value* absl_nullable value) override { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(key != nullptr); if (keys_iterator_ == nullptr) { if (interface_->IsEmpty()) { return false; } CEL_RETURN_IF_ERROR(ProjectKeys(descriptor_pool, message_factory, arena)); } CEL_ASSIGN_OR_RETURN( bool ok, keys_iterator_->Next1(descriptor_pool, message_factory, arena, key)); if (!ok) { return false; } if (value != nullptr) { CEL_ASSIGN_OR_RETURN(ok, interface_->Find(*key, descriptor_pool, message_factory, arena, value)); if (!ok) { return absl::DataLossError( "map iterator returned key that was not present in the map"); } } return true; } private: // Projects the keys from the map, setting `keys_` and `keys_iterator_`. If // this returns OK it is guaranteed that `keys_iterator_` is not null. 
absl::Status ProjectKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { ABSL_DCHECK(keys_iterator_ == nullptr); CEL_RETURN_IF_ERROR( interface_->ListKeys(descriptor_pool, message_factory, arena, &keys_)); CEL_ASSIGN_OR_RETURN(keys_iterator_, keys_.NewIterator()); ABSL_CHECK(keys_iterator_->HasNext()); // Crash OK return absl::OkStatus(); } const CustomMapValueInterface* absl_nonnull const interface_; ListValue keys_; absl_nullable ValueIteratorPtr keys_iterator_; }; namespace { class CustomMapValueDispatcherIterator final : public ValueIterator { public: explicit CustomMapValueDispatcherIterator( const CustomMapValueDispatcher* absl_nonnull dispatcher, CustomMapValueContent content) : dispatcher_(dispatcher), content_(content) {} bool HasNext() override { if (keys_iterator_ == nullptr) { if (dispatcher_->is_empty != nullptr) { return !dispatcher_->is_empty(dispatcher_, content_); } return dispatcher_->size(dispatcher_, content_) != 0; } return keys_iterator_->HasNext(); } absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) override { if (keys_iterator_ == nullptr) { if (dispatcher_->is_empty != nullptr ? 
dispatcher_->is_empty(dispatcher_, content_) : dispatcher_->size(dispatcher_, content_) == 0) { return absl::FailedPreconditionError( "ValueIterator::Next() called when " "ValueIterator::HasNext() returns false"); } CEL_RETURN_IF_ERROR(ProjectKeys(descriptor_pool, message_factory, arena)); } return keys_iterator_->Next(descriptor_pool, message_factory, arena, result); } absl::StatusOr Next1( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key_or_value) override { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(key_or_value != nullptr); if (keys_iterator_ == nullptr) { if (dispatcher_->is_empty != nullptr ? dispatcher_->is_empty(dispatcher_, content_) : dispatcher_->size(dispatcher_, content_) == 0) { return false; } CEL_RETURN_IF_ERROR(ProjectKeys(descriptor_pool, message_factory, arena)); } return keys_iterator_->Next1(descriptor_pool, message_factory, arena, key_or_value); } absl::StatusOr Next2( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, Value* absl_nullable value) override { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(key != nullptr); ABSL_DCHECK(value != nullptr); if (keys_iterator_ == nullptr) { if (dispatcher_->is_empty != nullptr ? 
dispatcher_->is_empty(dispatcher_, content_) : dispatcher_->size(dispatcher_, content_) == 0) { return false; } CEL_RETURN_IF_ERROR(ProjectKeys(descriptor_pool, message_factory, arena)); } CEL_ASSIGN_OR_RETURN( bool ok, keys_iterator_->Next1(descriptor_pool, message_factory, arena, key)); if (!ok) { return false; } if (value != nullptr) { CEL_ASSIGN_OR_RETURN( ok, dispatcher_->find(dispatcher_, content_, *key, descriptor_pool, message_factory, arena, value)); if (!ok) { return absl::DataLossError( "map iterator returned key that was not present in the map"); } } return true; } private: absl::Status ProjectKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { ABSL_DCHECK(keys_iterator_ == nullptr); CEL_RETURN_IF_ERROR(dispatcher_->list_keys(dispatcher_, content_, descriptor_pool, message_factory, arena, &keys_)); CEL_ASSIGN_OR_RETURN(keys_iterator_, keys_.NewIterator()); ABSL_CHECK(keys_iterator_->HasNext()); // Crash OK return absl::OkStatus(); } const CustomMapValueDispatcher* absl_nonnull const dispatcher_; const CustomMapValueContent content_; ListValue keys_; absl_nullable ValueIteratorPtr keys_iterator_; }; } // namespace absl::Status CustomMapValueInterface::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); StructReflection reflection; CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor_pool)); const google::protobuf::Message* prototype = message_factory->GetPrototype(reflection.GetDescriptor()); if (prototype == nullptr) { return absl::UnknownError( absl::StrCat("failed to get message prototype: ", reflection.GetDescriptor()->full_name())); } 
google::protobuf::Arena arena; google::protobuf::Message* message = prototype->New(&arena); CEL_RETURN_IF_ERROR( ConvertToJsonObject(descriptor_pool, message_factory, message)); if (!message->SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( "failed to serialize message: google.protobuf.Struct"); } return absl::OkStatus(); } absl::Status CustomMapValueInterface::ForEach( ForEachCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { CEL_ASSIGN_OR_RETURN(auto iterator, NewIterator()); while (iterator->HasNext()) { Value key; Value value; CEL_RETURN_IF_ERROR( iterator->Next(descriptor_pool, message_factory, arena, &key)); CEL_ASSIGN_OR_RETURN( bool found, Find(key, descriptor_pool, message_factory, arena, &value)); if (!found) { value = ErrorValue(NoSuchKeyError(key)); } CEL_ASSIGN_OR_RETURN(auto ok, callback(key, value)); if (!ok) { break; } } return absl::OkStatus(); } absl::StatusOr CustomMapValueInterface::NewIterator() const { return std::make_unique(this); } absl::Status CustomMapValueInterface::Equal( const MapValue& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { return MapValueEqual(*this, other, descriptor_pool, message_factory, arena, result); } CustomMapValue::CustomMapValue() { content_ = CustomMapValueContent::From(CustomMapValueInterface::Content{ .interface = &EmptyMapValue::Get(), .arena = nullptr}); } NativeTypeId CustomMapValue::GetTypeId() const { if (dispatcher_ == nullptr) { CustomMapValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); return content.interface->GetNativeTypeId(); } return dispatcher_->get_type_id(dispatcher_, content_); } absl::string_view 
CustomMapValue::GetTypeName() const { return "map"; } std::string CustomMapValue::DebugString() const { if (dispatcher_ == nullptr) { CustomMapValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); return content.interface->DebugString(); } if (dispatcher_->debug_string != nullptr) { return dispatcher_->debug_string(dispatcher_, content_); } return "map"; } absl::Status CustomMapValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { if (dispatcher_ == nullptr) { CustomMapValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); return content.interface->SerializeTo(descriptor_pool, message_factory, output); } if (dispatcher_->serialize_to != nullptr) { return dispatcher_->serialize_to(dispatcher_, content_, descriptor_pool, message_factory, output); } return absl::UnimplementedError( absl::StrCat(GetTypeName(), " is unserializable")); } absl::Status CustomMapValue::ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); ValueReflection value_reflection; CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); google::protobuf::Message* json_object = value_reflection.MutableStructValue(json); return ConvertToJsonObject(descriptor_pool, message_factory, json_object); } absl::Status CustomMapValue::ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, 
google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); if (dispatcher_ == nullptr) { CustomMapValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); return content.interface->ConvertToJsonObject(descriptor_pool, message_factory, json); } if (dispatcher_->convert_to_json_object != nullptr) { return dispatcher_->convert_to_json_object( dispatcher_, content_, descriptor_pool, message_factory, json); } return absl::UnimplementedError( absl::StrCat(GetTypeName(), " is not convertable to JSON")); } absl::Status CustomMapValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (auto other_map_value = other.AsMap(); other_map_value) { if (dispatcher_ == nullptr) { CustomMapValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); return content.interface->Equal(*other_map_value, descriptor_pool, message_factory, arena, result); } if (dispatcher_->equal != nullptr) { return dispatcher_->equal(dispatcher_, content_, *other_map_value, descriptor_pool, message_factory, arena, result); } return common_internal::MapValueEqual(*this, *other_map_value, descriptor_pool, message_factory, arena, result); } *result = FalseValue(); return absl::OkStatus(); } bool CustomMapValue::IsZeroValue() const { if (dispatcher_ == nullptr) { CustomMapValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); return 
content.interface->IsZeroValue(); } return dispatcher_->is_zero_value(dispatcher_, content_); } CustomMapValue CustomMapValue::Clone(google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(arena != nullptr); if (dispatcher_ == nullptr) { CustomMapValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); if (content.arena != arena) { return content.interface->Clone(arena); } return *this; } return dispatcher_->clone(dispatcher_, content_, arena); } bool CustomMapValue::IsEmpty() const { if (dispatcher_ == nullptr) { CustomMapValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); return content.interface->IsEmpty(); } if (dispatcher_->is_empty != nullptr) { return dispatcher_->is_empty(dispatcher_, content_); } return dispatcher_->size(dispatcher_, content_) == 0; } size_t CustomMapValue::Size() const { if (dispatcher_ == nullptr) { CustomMapValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); return content.interface->Size(); } return dispatcher_->size(dispatcher_, content_); } absl::Status CustomMapValue::Get( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); CEL_ASSIGN_OR_RETURN( bool ok, Find(key, descriptor_pool, message_factory, arena, result)); if (ABSL_PREDICT_FALSE(!ok)) { switch (result->kind()) { case ValueKind::kError: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUnknown: break; default: *result = ErrorValue(NoSuchKeyError(key)); break; } } return absl::OkStatus(); } absl::StatusOr CustomMapValue::Find( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, 
google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); switch (key.kind()) { case ValueKind::kError: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUnknown: *result = key; return false; case ValueKind::kBool: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kInt: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUint: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kString: break; default: *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); return false; } bool ok; if (dispatcher_ == nullptr) { CustomMapValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); CEL_ASSIGN_OR_RETURN( ok, content.interface->Find(key, descriptor_pool, message_factory, arena, result)); } else { CEL_ASSIGN_OR_RETURN( ok, dispatcher_->find(dispatcher_, content_, key, descriptor_pool, message_factory, arena, result)); } if (ok) { return true; } *result = NullValue{}; return false; } absl::Status CustomMapValue::Has( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); switch (key.kind()) { case ValueKind::kError: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUnknown: *result = key; return absl::OkStatus(); case ValueKind::kBool: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kInt: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUint: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kString: break; default: *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); return absl::OkStatus(); } bool has; if (dispatcher_ == nullptr) { 
CustomMapValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); CEL_ASSIGN_OR_RETURN(has, content.interface->Has(key, descriptor_pool, message_factory, arena)); } else { CEL_ASSIGN_OR_RETURN( has, dispatcher_->has(dispatcher_, content_, key, descriptor_pool, message_factory, arena)); } *result = BoolValue(has); return absl::OkStatus(); } absl::Status CustomMapValue::ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (dispatcher_ == nullptr) { CustomMapValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); return content.interface->ListKeys(descriptor_pool, message_factory, arena, result); } return dispatcher_->list_keys(dispatcher_, content_, descriptor_pool, message_factory, arena, result); } absl::Status CustomMapValue::ForEach( ForEachCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); if (dispatcher_ == nullptr) { CustomMapValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); return content.interface->ForEach(callback, descriptor_pool, message_factory, arena); } if (dispatcher_->for_each != nullptr) { return dispatcher_->for_each(dispatcher_, content_, callback, descriptor_pool, message_factory, arena); } absl_nonnull ValueIteratorPtr iterator; if (dispatcher_->new_iterator != nullptr) { CEL_ASSIGN_OR_RETURN(iterator, dispatcher_->new_iterator(dispatcher_, content_)); } 
else { iterator = std::make_unique(dispatcher_, content_); } while (iterator->HasNext()) { Value key; Value value; CEL_RETURN_IF_ERROR( iterator->Next(descriptor_pool, message_factory, arena, &key)); CEL_ASSIGN_OR_RETURN( bool found, dispatcher_->find(dispatcher_, content_, key, descriptor_pool, message_factory, arena, &value)); if (!found) { value = ErrorValue(NoSuchKeyError(key)); } CEL_ASSIGN_OR_RETURN(auto ok, callback(key, value)); if (!ok) { break; } } return absl::OkStatus(); } absl::StatusOr CustomMapValue::NewIterator() const { if (dispatcher_ == nullptr) { CustomMapValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); return content.interface->NewIterator(); } if (dispatcher_->new_iterator != nullptr) { return dispatcher_->new_iterator(dispatcher_, content_); } return std::make_unique(dispatcher_, content_); } } // namespace cel ================================================ FILE: common/values/custom_map_value.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" // `CustomMapValue` represents values of the primitive `map` type. // `CustomMapValueView` is a non-owning view of `CustomMapValue`. // `CustomMapValueInterface` is the abstract base class of implementations. 
// `CustomMapValue` is a handle referring either to a `CustomMapValueInterface`
// implementation or to opaque content interpreted by a manual
// `CustomMapValueDispatcher` table.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MAP_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MAP_VALUE_H_

#include
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/functional/function_ref.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "common/native_type.h"
#include "common/value_kind.h"
#include "common/values/custom_value.h"
#include "common/values/values.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

class Value;
class ListValue;
class CustomMapValueInterface;
class CustomMapValueInterfaceKeysIterator;
class CustomMapValue;

using CustomMapValueContent = CustomValueContent;

// Manual dispatch table backing a `CustomMapValue` created through
// `UnsafeCustomMapValue`. Every entry receives the table itself plus the
// opaque `CustomMapValueContent`. Members annotated `absl_nonnull` must be
// set; `absl_nullable` members may be left null, in which case the fallback
// documented on that member is used.
struct CustomMapValueDispatcher {
  using GetTypeId =
      NativeTypeId (*)(const CustomMapValueDispatcher* absl_nonnull dispatcher,
                       CustomMapValueContent content);

  using GetArena = google::protobuf::Arena* absl_nullable (*)(
      const CustomMapValueDispatcher* absl_nonnull dispatcher,
      CustomMapValueContent content);

  using DebugString =
      std::string (*)(const CustomMapValueDispatcher* absl_nonnull dispatcher,
                      CustomMapValueContent content);

  using SerializeTo = absl::Status (*)(
      const CustomMapValueDispatcher* absl_nonnull dispatcher,
      CustomMapValueContent content,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output);

  using ConvertToJsonObject = absl::Status (*)(
      const CustomMapValueDispatcher* absl_nonnull dispatcher,
      CustomMapValueContent content,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json);

  using Equal = absl::Status (*)(
      const CustomMapValueDispatcher* absl_nonnull dispatcher,
      CustomMapValueContent content, const MapValue& other,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result);

  using IsZeroValue =
      bool (*)(const CustomMapValueDispatcher* absl_nonnull dispatcher,
               CustomMapValueContent content);

  using IsEmpty =
      bool (*)(const CustomMapValueDispatcher* absl_nonnull dispatcher,
               CustomMapValueContent content);

  using Size =
      size_t (*)(const CustomMapValueDispatcher* absl_nonnull dispatcher,
                 CustomMapValueContent content);

  using Find = absl::StatusOr (*)(
      const CustomMapValueDispatcher* absl_nonnull dispatcher,
      CustomMapValueContent content, const Value& key,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result);

  using Has = absl::StatusOr (*)(
      const CustomMapValueDispatcher* absl_nonnull dispatcher,
      CustomMapValueContent content, const Value& key,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena);

  using ListKeys = absl::Status (*)(
      const CustomMapValueDispatcher* absl_nonnull dispatcher,
      CustomMapValueContent content,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result);

  using ForEach = absl::Status (*)(
      const CustomMapValueDispatcher* absl_nonnull dispatcher,
      CustomMapValueContent content,
      absl::FunctionRef(const Value&, const Value&)> callback,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena);

  using NewIterator = absl::StatusOr (*)(
      const CustomMapValueDispatcher* absl_nonnull dispatcher,
      CustomMapValueContent content);

  using Clone = CustomMapValue (*)(
      const CustomMapValueDispatcher* absl_nonnull dispatcher,
      CustomMapValueContent content, google::protobuf::Arena* absl_nonnull arena);

  // Backs `CustomMapValue::GetTypeId()`.
  absl_nonnull GetTypeId get_type_id;

  absl_nonnull GetArena get_arena;

  // If null, simply returns "map".
  absl_nullable DebugString debug_string = nullptr;

  // If null, attempts to serialize results in an UNIMPLEMENTED error.
  absl_nullable SerializeTo serialize_to = nullptr;

  // If null, attempts to convert to JSON results in an UNIMPLEMENTED error.
  absl_nullable ConvertToJsonObject convert_to_json_object = nullptr;

  // If null, a generic, non-optimal fallback implementation for equality is
  // used.
  absl_nullable Equal equal = nullptr;

  // Backs `CustomMapValue::IsZeroValue()`.
  absl_nonnull IsZeroValue is_zero_value;

  // If null, `size(...) == 0` is used.
  absl_nullable IsEmpty is_empty = nullptr;

  // Backs `CustomMapValue::Size()`; also used for emptiness checks when
  // `is_empty` is null.
  absl_nonnull Size size;

  // Backs `CustomMapValue::Find()`, which only forwards keys of kind bool,
  // int, uint or string. Also used by the fallback iterator and `ForEach`.
  absl_nonnull Find find;

  // Backs `CustomMapValue::Has()`, which only forwards keys of kind bool,
  // int, uint or string.
  absl_nonnull Has has;

  // Backs `CustomMapValue::ListKeys()` and the fallback key iterator.
  absl_nonnull ListKeys list_keys;

  // If null, a fallback implementation is used that iterates with
  // `new_iterator` (or, if that is also null, a `list_keys`-based iterator)
  // and looks up each value with `find`.
  absl_nullable ForEach for_each = nullptr;

  // If null, a fallback implementation based on `list_keys` (and `find`, for
  // key/value iteration) is used.
  absl_nullable NewIterator new_iterator = nullptr;

  // Backs `CustomMapValue::Clone()`.
  absl_nonnull Clone clone;
};

// Abstract base class for custom map implementations. Implementations are
// neither copyable nor movable; `CustomMapValue` refers to them by pointer.
class CustomMapValueInterface {
 public:
  CustomMapValueInterface() = default;
  CustomMapValueInterface(const CustomMapValueInterface&) = delete;
  CustomMapValueInterface(CustomMapValueInterface&&) = delete;

  virtual ~CustomMapValueInterface() = default;

  CustomMapValueInterface& operator=(const CustomMapValueInterface&) = delete;
  CustomMapValueInterface& operator=(CustomMapValueInterface&&) = delete;

  using ForEachCallback =
      absl::FunctionRef(const Value&, const Value&)>;

 private:
  friend class CustomMapValueInterfaceIterator;
  friend class CustomMapValue;
  friend absl::Status common_internal::MapValueEqual(
      const CustomMapValueInterface& lhs, const MapValue& rhs,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result);

  virtual std::string DebugString() const = 0;

  // The default implementation converts the map to a
  // `google.protobuf.Struct` via `ConvertToJsonObject()` and serializes it.
  virtual absl::Status SerializeTo(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const;

  virtual absl::Status ConvertToJsonObject(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const = 0;

  // The default implementation uses the generic `MapValueEqual` comparison.
  virtual absl::Status Equal(
      const MapValue& other,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const;

  // Defaults to `IsEmpty()`.
  virtual bool IsZeroValue() const { return IsEmpty(); }

  // Returns `true` if this map contains no entries, `false` otherwise.
  virtual bool IsEmpty() const { return Size() == 0; }

  // Returns the number of entries in this map.
  virtual size_t Size() const = 0;

  // See the corresponding member function of `MapValue` for
  // documentation.
  virtual absl::Status ListKeys(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      ListValue* absl_nonnull result) const = 0;

  // See the corresponding member function of `MapValue` for
  // documentation. The default implementation walks `NewIterator()` and looks
  // up each value with `Find()`.
  virtual absl::Status ForEach(
      ForEachCallback callback,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  // By default, implementations do not guarantee any iteration order. Unless
  // specified otherwise, assume the iteration order is random. The default
  // implementation iterates over the keys produced by `ListKeys()`.
  virtual absl::StatusOr NewIterator() const;

  virtual CustomMapValue Clone(google::protobuf::Arena* absl_nonnull arena) const = 0;

  // Looks up `key`, returning `false` when it is absent. `CustomMapValue::Find`
  // only forwards keys of kind bool, int, uint or string.
  virtual absl::StatusOr Find(
      const Value& key,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const = 0;

  // Reports whether `key` is present. `CustomMapValue::Has` only forwards keys
  // of kind bool, int, uint or string.
  virtual absl::StatusOr Has(
      const Value& key,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const = 0;

  virtual NativeTypeId GetNativeTypeId() const = 0;

  // Payload stored in `CustomMapValueContent` for interface-backed values.
  struct Content {
    const CustomMapValueInterface* absl_nonnull interface;
    google::protobuf::Arena* absl_nullable arena;
  };
};

// Creates a custom map value from a manual dispatch table `dispatcher` and
// opaque data `content` whose format is only known to functions in the manual
// dispatch table. The dispatch table should probably be valid for the lifetime
// of the process, but at a minimum must outlive all instances of the resulting
// value.
//
// IMPORTANT: This approach to implementing CustomMapValue should only be
// used when you know exactly what you are doing. When in doubt, just implement
// CustomMapValueInterface.
//
// The result stores `dispatcher` and `content` unchanged; it forwards to the
// private dispatcher-mode constructor of `CustomMapValue`, which DCHECKs that
// `dispatcher` and several of its entries are non-null.
CustomMapValue UnsafeCustomMapValue(const CustomMapValueDispatcher* absl_nonnull
                                        dispatcher ABSL_ATTRIBUTE_LIFETIME_BOUND,
                                    CustomMapValueContent content);

// A CEL map value whose behavior is supplied in one of two ways:
//   * interface mode: `dispatcher_` is null and `content_` holds a
//     `CustomMapValueInterface::Content` (implementation pointer + arena);
//   * dispatcher mode: `dispatcher_` is non-null and `content_` is opaque data
//     interpreted only by the dispatch table (see `UnsafeCustomMapValue`).
class CustomMapValue final : private common_internal::MapValueMixin {
 public:
  static constexpr ValueKind kKind = ValueKind::kMap;

  // Constructs a custom map value from an implementation of
  // `CustomMapValueInterface` `interface` whose lifetime is tied to that of
  // the arena `arena`.
  CustomMapValue(const CustomMapValueInterface* absl_nonnull interface
                     ABSL_ATTRIBUTE_LIFETIME_BOUND,
                 google::protobuf::Arena* absl_nonnull arena
                     ABSL_ATTRIBUTE_LIFETIME_BOUND) {
    ABSL_DCHECK(interface != nullptr);
    ABSL_DCHECK(arena != nullptr);
    // `dispatcher_` keeps its null default, which is what marks this value as
    // interface-backed (see `interface()`).
    content_ = CustomMapValueContent::From(CustomMapValueInterface::Content{
        .interface = interface, .arena = arena});
  }

  // By default, this creates an empty map whose type is `map(dyn, dyn)`. Unless
  // you can help it, you should use a more specific typed map value.
  CustomMapValue();
  CustomMapValue(const CustomMapValue&) = default;
  CustomMapValue(CustomMapValue&&) = default;
  CustomMapValue& operator=(const CustomMapValue&) = default;
  CustomMapValue& operator=(CustomMapValue&&) = default;

  static constexpr ValueKind kind() { return kKind; }

  NativeTypeId GetTypeId() const;

  absl::string_view GetTypeName() const;

  std::string DebugString() const;

  // See Value::SerializeTo().
  absl::Status SerializeTo(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const;

  // See Value::ConvertToJson().
  absl::Status ConvertToJson(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  // See Value::ConvertToJsonObject().
  absl::Status ConvertToJsonObject(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  absl::Status Equal(const Value& other,
                     const google::protobuf::DescriptorPool* absl_nonnull
                         descriptor_pool,
                     google::protobuf::MessageFactory* absl_nonnull
                         message_factory,
                     google::protobuf::Arena* absl_nonnull arena,
                     Value* absl_nonnull result) const;
  using MapValueMixin::Equal;

  bool IsZeroValue() const;

  CustomMapValue Clone(google::protobuf::Arena* absl_nonnull arena) const;

  bool IsEmpty() const;

  size_t Size() const;

  // See the corresponding member function of `MapValue` for
  // documentation.
  absl::Status Get(const Value& key,
                   const google::protobuf::DescriptorPool* absl_nonnull
                       descriptor_pool,
                   google::protobuf::MessageFactory* absl_nonnull
                       message_factory,
                   google::protobuf::Arena* absl_nonnull arena,
                   Value* absl_nonnull result) const;
  using MapValueMixin::Get;

  // See the corresponding member function of `MapValue` for
  // documentation.
  absl::StatusOr Find(
      const Value& key,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const;
  using MapValueMixin::Find;

  // See the corresponding member function of `MapValue` for
  // documentation.
  absl::Status Has(const Value& key,
                   const google::protobuf::DescriptorPool* absl_nonnull
                       descriptor_pool,
                   google::protobuf::MessageFactory* absl_nonnull
                       message_factory,
                   google::protobuf::Arena* absl_nonnull arena,
                   Value* absl_nonnull result) const;
  using MapValueMixin::Has;

  // See the corresponding member function of `MapValue` for
  // documentation.
  absl::Status ListKeys(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      ListValue* absl_nonnull result) const;
  using MapValueMixin::ListKeys;

  // See the corresponding type declaration of `CustomMapValueInterface` for
  // documentation.
  using ForEachCallback = typename CustomMapValueInterface::ForEachCallback;

  // See the corresponding member function of `MapValue` for
  // documentation.
  absl::Status ForEach(
      ForEachCallback callback,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  // See the corresponding member function of `MapValue` for
  // documentation.
  absl::StatusOr NewIterator() const;

  // Returns the dispatch table of a dispatcher-backed value, or null for an
  // interface-backed value.
  const CustomMapValueDispatcher* absl_nullable dispatcher() const {
    return dispatcher_;
  }

  // Returns the opaque content of a dispatcher-backed value. Only valid when
  // `dispatcher()` is non-null (DCHECKed).
  CustomMapValueContent content() const {
    ABSL_DCHECK(dispatcher_ != nullptr);
    return content_;
  }

  // Returns the `CustomMapValueInterface` of an interface-backed value, or
  // null for a dispatcher-backed value.
  const CustomMapValueInterface* absl_nullable interface() const {
    if (dispatcher_ == nullptr) {
      return content_.To().interface;
    }
    return nullptr;
  }

  friend void swap(CustomMapValue& lhs, CustomMapValue& rhs) noexcept {
    using std::swap;
    swap(lhs.dispatcher_, rhs.dispatcher_);
    swap(lhs.content_, rhs.content_);
  }

 private:
  friend class common_internal::ValueMixin;
  friend class common_internal::MapValueMixin;
  friend CustomMapValue UnsafeCustomMapValue(
      const CustomMapValueDispatcher* absl_nonnull dispatcher
          ABSL_ATTRIBUTE_LIFETIME_BOUND,
      CustomMapValueContent content);

  // Dispatcher-mode constructor; `UnsafeCustomMapValue` is its public entry
  // point. DCHECKs `dispatcher` and each dispatcher entry listed below; entries
  // not listed (e.g. the nullable `new_iterator`) are not checked here.
  CustomMapValue(const CustomMapValueDispatcher* absl_nonnull dispatcher,
                 CustomMapValueContent content)
      : dispatcher_(dispatcher), content_(content) {
    ABSL_DCHECK(dispatcher != nullptr);
    ABSL_DCHECK(dispatcher->get_type_id != nullptr);
    ABSL_DCHECK(dispatcher->get_arena != nullptr);
    ABSL_DCHECK(dispatcher->is_zero_value != nullptr);
    ABSL_DCHECK(dispatcher->size != nullptr);
    ABSL_DCHECK(dispatcher->find != nullptr);
    ABSL_DCHECK(dispatcher->has != nullptr);
    ABSL_DCHECK(dispatcher->list_keys != nullptr);
    ABSL_DCHECK(dispatcher->clone != nullptr);
  }

  // Null for interface-backed values, non-null for dispatcher-backed values.
  const CustomMapValueDispatcher* absl_nullable dispatcher_ = nullptr;
  // Interface mode: a `CustomMapValueInterface::Content`.
  // Dispatcher mode: opaque data owned by the dispatch table's semantics.
  CustomMapValueContent content_ = CustomMapValueContent::Zero();
};

inline std::ostream& operator<<(std::ostream& out,
                                const CustomMapValue& type) {
  return out << type.DebugString();
}

template <>
struct NativeTypeTraits final {
  static NativeTypeId Id(const CustomMapValue& type) {
    return type.GetTypeId();
  }
};

inline CustomMapValue UnsafeCustomMapValue(
    const CustomMapValueDispatcher* absl_nonnull dispatcher
        ABSL_ATTRIBUTE_LIFETIME_BOUND,
    CustomMapValueContent content) {
  return CustomMapValue(dispatcher, content);
}

}  // namespace cel

#endif  //
THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MAP_VALUE_H_ ================================================ FILE: common/values/custom_map_value_test.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include "google/protobuf/struct.pb.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/memory.h" #include "common/native_type.h" #include "common/value.h" #include "common/value_testing.h" #include "common/values/list_value_builder.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream_impl_lite.h" #include "google/protobuf/message.h" namespace cel { namespace { using ::absl_testing::IsOk; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::test::BoolValueIs; using ::cel::test::ErrorValueIs; using ::cel::test::IntValueIs; using ::cel::test::StringValueIs; using ::testing::Eq; using ::testing::IsEmpty; using ::testing::IsNull; using ::testing::Not; using ::testing::NotNull; using ::testing::Optional; using ::testing::Pair; using ::testing::UnorderedElementsAre; class CustomMapValueTest; struct 
CustomMapValueTestContent { google::protobuf::Arena* absl_nonnull arena; }; class CustomMapValueInterfaceTest final : public CustomMapValueInterface { public: std::string DebugString() const override { return "{\"foo\": true, \"bar\": 1}"; } absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const override { google::protobuf::Value json; google::protobuf::ListValue* json_array = json.mutable_list_value(); json_array->add_values()->set_bool_value(true); json_array->add_values()->set_number_value(1.0); if (!json.SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( "failed to serialize message: google.protobuf.Value"); } return absl::OkStatus(); } absl::Status ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const override { google::protobuf::Struct json_object; (*json_object.mutable_fields())["foo"].set_bool_value(true); (*json_object.mutable_fields())["bar"].set_number_value(1.0); absl::Cord serialized; if (!json_object.SerializePartialToString(&serialized)) { return absl::UnknownError("failed to serialize google.protobuf.Struct"); } if (!json->ParsePartialFromString(serialized)) { return absl::UnknownError("failed to parse google.protobuf.Struct"); } return absl::OkStatus(); } size_t Size() const override { return 2; } absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const override { auto builder = common_internal::NewListValueBuilder(arena); builder->Reserve(2); CEL_RETURN_IF_ERROR(builder->Add(StringValue("foo"))); 
CEL_RETURN_IF_ERROR(builder->Add(StringValue("bar"))); *result = std::move(*builder).Build(); return absl::OkStatus(); } CustomMapValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { return CustomMapValue( (::new (arena->AllocateAligned(sizeof(CustomMapValueInterfaceTest), alignof(CustomMapValueInterfaceTest))) CustomMapValueInterfaceTest()), arena); } private: absl::StatusOr Find( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const override { if (auto string_key = key.AsString(); string_key) { if (*string_key == "foo") { *result = TrueValue(); return true; } if (*string_key == "bar") { *result = IntValue(1); return true; } } return false; } absl::StatusOr Has( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const override { if (auto string_key = key.AsString(); string_key) { if (*string_key == "foo") { return true; } if (*string_key == "bar") { return true; } } return false; } NativeTypeId GetNativeTypeId() const override { return NativeTypeId::For(); } }; class CustomMapValueTest : public common_internal::ValueTest<> { public: CustomMapValue MakeInterface() { return CustomMapValue( (::new (arena()->AllocateAligned(sizeof(CustomMapValueInterfaceTest), alignof(CustomMapValueInterfaceTest))) CustomMapValueInterfaceTest()), arena()); } CustomMapValue MakeDispatcher() { return UnsafeCustomMapValue( &test_dispatcher_, CustomValueContent::From( CustomMapValueTestContent{.arena = arena()})); } protected: CustomMapValueDispatcher test_dispatcher_ = { .get_type_id = [](const CustomMapValueDispatcher* absl_nonnull dispatcher, CustomMapValueContent content) -> NativeTypeId { return NativeTypeId::For(); }, .get_arena = [](const 
CustomMapValueDispatcher* absl_nonnull dispatcher, CustomMapValueContent content) -> google::protobuf::Arena* absl_nullable { return content.To().arena; }, .debug_string = [](const CustomMapValueDispatcher* absl_nonnull dispatcher, CustomMapValueContent content) -> std::string { return "{\"foo\": true, \"bar\": 1}"; }, .serialize_to = [](const CustomMapValueDispatcher* absl_nonnull dispatcher, CustomMapValueContent content, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) -> absl::Status { google::protobuf::Value json; google::protobuf::Struct* json_object = json.mutable_struct_value(); (*json_object->mutable_fields())["foo"].set_bool_value(true); (*json_object->mutable_fields())["bar"].set_number_value(1.0); if (!json.SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( "failed to serialize message: google.protobuf.Value"); } return absl::OkStatus(); }, .convert_to_json_object = [](const CustomMapValueDispatcher* absl_nonnull dispatcher, CustomMapValueContent content, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) -> absl::Status { { google::protobuf::Struct json_object; (*json_object.mutable_fields())["foo"].set_bool_value(true); (*json_object.mutable_fields())["bar"].set_number_value(1.0); absl::Cord serialized; if (!json_object.SerializePartialToString(&serialized)) { return absl::UnknownError( "failed to serialize google.protobuf.Struct"); } if (!json->ParsePartialFromString(serialized)) { return absl::UnknownError("failed to parse google.protobuf.Struct"); } return absl::OkStatus(); } }, .is_zero_value = [](const CustomMapValueDispatcher* absl_nonnull dispatcher, CustomMapValueContent content) -> bool { return false; }, .size = [](const CustomMapValueDispatcher* 
absl_nonnull dispatcher, CustomMapValueContent content) -> size_t { return 2; }, .find = [](const CustomMapValueDispatcher* absl_nonnull dispatcher, CustomMapValueContent content, const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) -> absl::StatusOr { if (auto string_key = key.AsString(); string_key) { if (*string_key == "foo") { *result = TrueValue(); return true; } if (*string_key == "bar") { *result = IntValue(1); return true; } } return false; }, .has = [](const CustomMapValueDispatcher* absl_nonnull dispatcher, CustomMapValueContent content, const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { if (auto string_key = key.AsString(); string_key) { if (*string_key == "foo") { return true; } if (*string_key == "bar") { return true; } } return false; }, .list_keys = [](const CustomMapValueDispatcher* absl_nonnull dispatcher, CustomMapValueContent content, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) -> absl::Status { auto builder = common_internal::NewListValueBuilder(arena); builder->Reserve(2); CEL_RETURN_IF_ERROR(builder->Add(StringValue("foo"))); CEL_RETURN_IF_ERROR(builder->Add(StringValue("bar"))); *result = std::move(*builder).Build(); return absl::OkStatus(); }, .clone = [](const CustomMapValueDispatcher* absl_nonnull dispatcher, CustomMapValueContent content, google::protobuf::Arena* absl_nonnull arena) -> CustomMapValue { return UnsafeCustomMapValue( dispatcher, CustomValueContent::From( CustomMapValueTestContent{.arena = arena})); }, }; }; 
// Most behaviors below are tested twice. `Dispatcher_*` cases use the
// dispatcher-backed value from `MakeDispatcher()`; `Interface_*` cases use the
// interface-backed value from `MakeInterface()`. Both represent
// `{"foo": true, "bar": 1}`.
TEST_F(CustomMapValueTest, Kind) {
  EXPECT_EQ(CustomMapValue::kind(), CustomMapValue::kKind);
}

TEST_F(CustomMapValueTest, Dispatcher_GetTypeId) {
  EXPECT_EQ(MakeDispatcher().GetTypeId(), NativeTypeId::For());
}

TEST_F(CustomMapValueTest, Interface_GetTypeId) {
  EXPECT_EQ(MakeInterface().GetTypeId(), NativeTypeId::For());
}

TEST_F(CustomMapValueTest, Dispatcher_GetTypeName) {
  EXPECT_EQ(MakeDispatcher().GetTypeName(), "map");
}

TEST_F(CustomMapValueTest, Interface_GetTypeName) {
  EXPECT_EQ(MakeInterface().GetTypeName(), "map");
}

TEST_F(CustomMapValueTest, Dispatcher_DebugString) {
  EXPECT_EQ(MakeDispatcher().DebugString(), "{\"foo\": true, \"bar\": 1}");
}

TEST_F(CustomMapValueTest, Interface_DebugString) {
  EXPECT_EQ(MakeInterface().DebugString(), "{\"foo\": true, \"bar\": 1}");
}

TEST_F(CustomMapValueTest, Dispatcher_IsZeroValue) {
  EXPECT_FALSE(MakeDispatcher().IsZeroValue());
}

TEST_F(CustomMapValueTest, Interface_IsZeroValue) {
  EXPECT_FALSE(MakeInterface().IsZeroValue());
}

// The SerializeTo cases only check that serialization succeeds and writes at
// least one byte; the payload itself is not decoded.
TEST_F(CustomMapValueTest, Dispatcher_SerializeTo) {
  google::protobuf::io::CordOutputStream output;
  EXPECT_THAT(MakeDispatcher().SerializeTo(descriptor_pool(), message_factory(),
                                           &output),
              IsOk());
  EXPECT_THAT(std::move(output).Consume(), Not(IsEmpty()));
}

TEST_F(CustomMapValueTest, Interface_SerializeTo) {
  google::protobuf::io::CordOutputStream output;
  EXPECT_THAT(MakeInterface().SerializeTo(descriptor_pool(), message_factory(),
                                          &output),
              IsOk());
  EXPECT_THAT(std::move(output).Consume(), Not(IsEmpty()));
}

// ConvertToJson wraps the object form in `struct_value`; ConvertToJsonObject
// (further below) produces the bare fields.
TEST_F(CustomMapValueTest, Dispatcher_ConvertToJson) {
  auto message = DynamicParseTextProto();
  EXPECT_THAT(
      MakeDispatcher().ConvertToJson(descriptor_pool(), message_factory(),
                                     cel::to_address(message)),
      IsOk());
  EXPECT_THAT(*message, EqualsTextProto(R"pb( struct_value: { fields: { key: "foo" value: { bool_value: true } } fields: { key: "bar" value: { number_value: 1.0 } } } )pb"));
}

TEST_F(CustomMapValueTest, Interface_ConvertToJson) {
  auto message = DynamicParseTextProto();
  EXPECT_THAT(
      MakeInterface().ConvertToJson(descriptor_pool(), message_factory(),
                                    cel::to_address(message)),
      IsOk());
  EXPECT_THAT(*message, EqualsTextProto(R"pb( struct_value: { fields: { key: "foo" value: { bool_value: true } } fields: { key: "bar" value: { number_value: 1.0 } } } )pb"));
}

TEST_F(CustomMapValueTest, Dispatcher_ConvertToJsonObject) {
  auto message = DynamicParseTextProto();
  EXPECT_THAT(
      MakeDispatcher().ConvertToJsonObject(descriptor_pool(), message_factory(),
                                           cel::to_address(message)),
      IsOk());
  EXPECT_THAT(*message, EqualsTextProto(R"pb( fields: { key: "foo" value: { bool_value: true } } fields: { key: "bar" value: { number_value: 1.0 } } )pb"));
}

TEST_F(CustomMapValueTest, Interface_ConvertToJsonObject) {
  auto message = DynamicParseTextProto();
  EXPECT_THAT(
      MakeInterface().ConvertToJsonObject(descriptor_pool(), message_factory(),
                                          cel::to_address(message)),
      IsOk());
  EXPECT_THAT(*message, EqualsTextProto(R"pb( fields: { key: "foo" value: { bool_value: true } } fields: { key: "bar" value: { number_value: 1.0 } } )pb"));
}

TEST_F(CustomMapValueTest, Dispatcher_IsEmpty) {
  EXPECT_FALSE(MakeDispatcher().IsEmpty());
}

TEST_F(CustomMapValueTest, Interface_IsEmpty) {
  EXPECT_FALSE(MakeInterface().IsEmpty());
}

TEST_F(CustomMapValueTest, Dispatcher_Size) {
  EXPECT_EQ(MakeDispatcher().Size(), 2);
}

TEST_F(CustomMapValueTest, Interface_Size) {
  EXPECT_EQ(MakeInterface().Size(), 2);
}

// A missing key in Get is expected to yield an OK status holding an error
// value with code kNotFound, not a failed status.
TEST_F(CustomMapValueTest, Dispatcher_Get) {
  CustomMapValue map = MakeDispatcher();
  ASSERT_THAT(map.Get(StringValue("foo"), descriptor_pool(), message_factory(),
                      arena()),
              IsOkAndHolds(BoolValueIs(true)));
  ASSERT_THAT(map.Get(StringValue("bar"), descriptor_pool(), message_factory(),
                      arena()),
              IsOkAndHolds(IntValueIs(1)));
  ASSERT_THAT(
      map.Get(StringValue("baz"), descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))));
}

TEST_F(CustomMapValueTest, Interface_Get) {
  CustomMapValue map = MakeInterface();
  ASSERT_THAT(map.Get(StringValue("foo"), descriptor_pool(), message_factory(),
                      arena()),
              IsOkAndHolds(BoolValueIs(true)));
  ASSERT_THAT(map.Get(StringValue("bar"), descriptor_pool(), message_factory(),
                      arena()),
              IsOkAndHolds(IntValueIs(1)));
  ASSERT_THAT(
      map.Get(StringValue("baz"), descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))));
}

// Find reports a missing key as an empty optional.
TEST_F(CustomMapValueTest, Dispatcher_Find) {
  CustomMapValue map = MakeDispatcher();
  ASSERT_THAT(map.Find(StringValue("foo"), descriptor_pool(), message_factory(),
                       arena()),
              IsOkAndHolds(Optional(BoolValueIs(true))));
  ASSERT_THAT(map.Find(StringValue("bar"), descriptor_pool(), message_factory(),
                       arena()),
              IsOkAndHolds(Optional(IntValueIs(1))));
  ASSERT_THAT(map.Find(StringValue("baz"), descriptor_pool(), message_factory(),
                       arena()),
              IsOkAndHolds(Eq(absl::nullopt)));
}

TEST_F(CustomMapValueTest, Interface_Find) {
  CustomMapValue map = MakeInterface();
  ASSERT_THAT(map.Find(StringValue("foo"), descriptor_pool(), message_factory(),
                       arena()),
              IsOkAndHolds(Optional(BoolValueIs(true))));
  ASSERT_THAT(map.Find(StringValue("bar"), descriptor_pool(), message_factory(),
                       arena()),
              IsOkAndHolds(Optional(IntValueIs(1))));
  ASSERT_THAT(map.Find(StringValue("baz"), descriptor_pool(), message_factory(),
                       arena()),
              IsOkAndHolds(Eq(absl::nullopt)));
}

// Has returns its answer as a CEL bool value.
TEST_F(CustomMapValueTest, Dispatcher_Has) {
  CustomMapValue map = MakeDispatcher();
  ASSERT_THAT(map.Has(StringValue("foo"), descriptor_pool(), message_factory(),
                      arena()),
              IsOkAndHolds(BoolValueIs(true)));
  ASSERT_THAT(map.Has(StringValue("bar"), descriptor_pool(), message_factory(),
                      arena()),
              IsOkAndHolds(BoolValueIs(true)));
  ASSERT_THAT(map.Has(StringValue("baz"), descriptor_pool(), message_factory(),
                      arena()),
              IsOkAndHolds(BoolValueIs(false)));
}

TEST_F(CustomMapValueTest, Interface_Has) {
  CustomMapValue map = MakeInterface();
  ASSERT_THAT(map.Has(StringValue("foo"), descriptor_pool(), message_factory(),
                      arena()),
              IsOkAndHolds(BoolValueIs(true)));
  ASSERT_THAT(map.Has(StringValue("bar"), descriptor_pool(), message_factory(),
                      arena()),
              IsOkAndHolds(BoolValueIs(true)));
  ASSERT_THAT(map.Has(StringValue("baz"), descriptor_pool(), message_factory(),
                      arena()),
              IsOkAndHolds(BoolValueIs(false)));
}

// ForEach visits every entry; the callback returns true to keep iterating.
// Order is not asserted.
TEST_F(CustomMapValueTest, Dispatcher_ForEach) {
  std::vector> entries;
  EXPECT_THAT(
      MakeDispatcher().ForEach(
          [&](const Value& key, const Value& value) -> absl::StatusOr {
            entries.push_back(std::pair{key, value});
            return true;
          },
          descriptor_pool(), message_factory(), arena()),
      IsOk());
  EXPECT_THAT(entries, UnorderedElementsAre(
                           Pair(StringValueIs("foo"), BoolValueIs(true)),
                           Pair(StringValueIs("bar"), IntValueIs(1))));
}

TEST_F(CustomMapValueTest, Interface_ForEach) {
  std::vector> entries;
  EXPECT_THAT(
      MakeInterface().ForEach(
          [&](const Value& key, const Value& value) -> absl::StatusOr {
            entries.push_back(std::pair{key, value});
            return true;
          },
          descriptor_pool(), message_factory(), arena()),
      IsOk());
  EXPECT_THAT(entries, UnorderedElementsAre(
                           Pair(StringValueIs("foo"), BoolValueIs(true)),
                           Pair(StringValueIs("bar"), IntValueIs(1))));
}

// The iterator cases expect keys in the order the fixtures' ListKeys adds
// them ("foo", then "bar"). Calling Next past the end is expected to fail with
// kFailedPrecondition, while Next1/Next2 return an empty optional instead.
TEST_F(CustomMapValueTest, Dispatcher_NewIterator) {
  CustomMapValue map = MakeDispatcher();
  ASSERT_OK_AND_ASSIGN(auto iterator, map.NewIterator());
  ASSERT_TRUE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(StringValueIs("foo")));
  ASSERT_TRUE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(StringValueIs("bar")));
  EXPECT_FALSE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              StatusIs(absl::StatusCode::kFailedPrecondition));
}

TEST_F(CustomMapValueTest, Interface_NewIterator) {
  CustomMapValue map = MakeInterface();
  ASSERT_OK_AND_ASSIGN(auto iterator, map.NewIterator());
  ASSERT_TRUE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(StringValueIs("foo")));
  ASSERT_TRUE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(StringValueIs("bar")));
  EXPECT_FALSE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              StatusIs(absl::StatusCode::kFailedPrecondition));
}

TEST_F(CustomMapValueTest, Dispatcher_NewIterator1) {
  CustomMapValue map = MakeDispatcher();
  ASSERT_OK_AND_ASSIGN(auto iterator, map.NewIterator());
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(StringValueIs("foo"))));
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(StringValueIs("bar"))));
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Eq(absl::nullopt)));
}

TEST_F(CustomMapValueTest, Interface_NewIterator1) {
  CustomMapValue map = MakeInterface();
  ASSERT_OK_AND_ASSIGN(auto iterator, map.NewIterator());
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(StringValueIs("foo"))));
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(StringValueIs("bar"))));
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Eq(absl::nullopt)));
}

// Next2 yields (key, value) pairs.
TEST_F(CustomMapValueTest, Dispatcher_NewIterator2) {
  CustomMapValue map = MakeDispatcher();
  ASSERT_OK_AND_ASSIGN(auto iterator, map.NewIterator());
  EXPECT_THAT(
      iterator->Next2(descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(Optional(Pair(StringValueIs("foo"), BoolValueIs(true)))));
  EXPECT_THAT(
      iterator->Next2(descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(Optional(Pair(StringValueIs("bar"), IntValueIs(1)))));
  EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Eq(absl::nullopt)));
}

TEST_F(CustomMapValueTest, Interface_NewIterator2) {
  CustomMapValue map = MakeInterface();
  ASSERT_OK_AND_ASSIGN(auto iterator, map.NewIterator());
  EXPECT_THAT(
      iterator->Next2(descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(Optional(Pair(StringValueIs("foo"), BoolValueIs(true)))));
  EXPECT_THAT(
      iterator->Next2(descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(Optional(Pair(StringValueIs("bar"), IntValueIs(1)))));
  EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Eq(absl::nullopt)));
}

// Exactly one of dispatcher()/interface() is non-null, depending on how the
// value was constructed.
TEST_F(CustomMapValueTest, Dispatcher) {
  EXPECT_THAT(MakeDispatcher().dispatcher(), NotNull());
  EXPECT_THAT(MakeDispatcher().interface(), IsNull());
}

TEST_F(CustomMapValueTest, Interface) {
  EXPECT_THAT(MakeInterface().dispatcher(), IsNull());
  EXPECT_THAT(MakeInterface().interface(), NotNull());
}

}  // namespace
}  // namespace cel
================================================
FILE: common/values/custom_struct_value.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// NOTE(review): the two bare `#include` directives below, and several
// template uses in this file (e.g. `content_.To()`, `absl::Span qualifiers`),
// appear to have lost their angle-bracketed text during extraction. Restore
// them from upstream before building.
#include
#include

#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/functional/function_ref.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "base/attribute.h"
#include "common/native_type.h"
#include "common/type.h"
#include "common/value.h"
#include "common/values/values.h"
#include "internal/status_macros.h"
#include "internal/well_known_types.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

namespace {

// Used by CustomStructValue::ConvertToJson to reach into a
// google.protobuf.Value message reflectively.
using ::cel::well_known_types::ValueReflection;

}  // namespace

// Default equality for interface-backed custom structs. It delegates to the
// generic field-wise comparison in common_internal::StructValueEqual.
// Implementations may override this with a faster or type-specific version.
absl::Status CustomStructValueInterface::Equal(
    const StructValue& other,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  return common_internal::StructValueEqual(*this, other, descriptor_pool,
                                           message_factory, arena, result);
}

// Default field-selection optimization hook. It always reports
// kUnimplemented, naming the offending type in the message. Implementations
// that support evaluating a chain of qualifiers in one step override this.
absl::Status CustomStructValueInterface::Qualify(
    absl::Span qualifiers, bool presence_test,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result,
    int* absl_nonnull count) const {
  return absl::UnimplementedError(absl::StrCat(
      GetTypeName(), " does not implement field selection optimization"));
}

// Returns the native type id of the underlying implementation.
//
// Unlike most accessors in this file, this does not ABSL_DCHECK(*this). An
// interface-backed value whose interface pointer is null yields a
// default-constructed NativeTypeId instead of asserting.
NativeTypeId CustomStructValue::GetTypeId() const {
  if (dispatcher_ == nullptr) {
    // Interface-backed: decode the (interface, arena) pair from content_.
    CustomStructValueInterface::Content content = content_.To();
    if (content.interface == nullptr) {
      return NativeTypeId();
    }
    return content.interface->GetNativeTypeId();
  }
  // Dispatcher-backed: get_type_id is a required dispatch-table entry.
  return dispatcher_->get_type_id(dispatcher_, content_);
}
// Every CustomStructValue method below follows the same two-way dispatch:
//   * dispatcher_ == nullptr: the value is interface-backed. content_ holds a
//     CustomStructValueInterface::Content (interface pointer + arena), and the
//     call is forwarded to the virtual method on the interface.
//   * dispatcher_ != nullptr: the value was built by UnsafeCustomStructValue.
//     The call goes through the matching function pointer in the dispatch
//     table. For optional (nullable) entries there is either a fallback or an
//     UnimplementedError.

// Returns the runtime StructType. For a dispatcher without get_runtime_type,
// this falls back to a basic struct type built from GetTypeName().
StructType CustomStructValue::GetRuntimeType() const {
  ABSL_DCHECK(*this);
  if (dispatcher_ == nullptr) {
    CustomStructValueInterface::Content content = content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->GetRuntimeType();
  }
  if (dispatcher_->get_runtime_type != nullptr) {
    return dispatcher_->get_runtime_type(dispatcher_, content_);
  }
  return common_internal::MakeBasicStructType(GetTypeName());
}

// Returns the fully-qualified type name. get_type_name is a required
// dispatcher entry, so there is no fallback.
absl::string_view CustomStructValue::GetTypeName() const {
  ABSL_DCHECK(*this);
  if (dispatcher_ == nullptr) {
    CustomStructValueInterface::Content content = content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->GetTypeName();
  }
  return dispatcher_->get_type_name(dispatcher_, content_);
}

// Returns a human-readable rendering. For a dispatcher without debug_string,
// this falls back to the type name.
std::string CustomStructValue::DebugString() const {
  ABSL_DCHECK(*this);
  if (dispatcher_ == nullptr) {
    CustomStructValueInterface::Content content = content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->DebugString();
  }
  if (dispatcher_->debug_string != nullptr) {
    return dispatcher_->debug_string(dispatcher_, content_);
  }
  return std::string(GetTypeName());
}

// Serializes the value to `output`. Returns kUnimplemented when the
// dispatcher provides no serialize_to entry.
absl::Status CustomStructValue::SerializeTo(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const {
  ABSL_DCHECK(*this);
  if (dispatcher_ == nullptr) {
    CustomStructValueInterface::Content content = content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->SerializeTo(descriptor_pool, message_factory,
                                          output);
  }
  if (dispatcher_->serialize_to != nullptr) {
    return dispatcher_->serialize_to(dispatcher_, content_, descriptor_pool,
                                     message_factory, output);
  }
  return absl::UnimplementedError(
      absl::StrCat(GetTypeName(), " is unserializable"));
}

// Converts to JSON. `json` must be a google.protobuf.Value message; this is
// enforced only by ABSL_DCHECK_EQ on the descriptor's well-known type. The
// nested object message is obtained through
// ValueReflection::MutableStructValue, and the work is delegated to
// ConvertToJsonObject. Reflection-initialization failures are propagated.
absl::Status CustomStructValue::ConvertToJson(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(json != nullptr);
  ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
                 google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE);
  ABSL_DCHECK(*this);

  ValueReflection value_reflection;
  CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor()));
  google::protobuf::Message* json_object =
      value_reflection.MutableStructValue(json);

  return ConvertToJsonObject(descriptor_pool, message_factory, json_object);
}

// Writes the struct's fields into the JSON object message `json`.
// An interface-backed value with a null interface clears `json` and succeeds.
// In debug builds the preceding ABSL_DCHECK(*this) already fails for that
// state. A dispatcher without convert_to_json_object yields kUnimplemented.
// NOTE(review): the error text spells "convertable". It is left unchanged
// here because it is runtime output.
absl::Status CustomStructValue::ConvertToJsonObject(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(json != nullptr);
  ABSL_DCHECK(*this);

  if (dispatcher_ == nullptr) {
    CustomStructValueInterface::Content content = content_.To();
    if (ABSL_PREDICT_FALSE(content.interface == nullptr)) {
      json->Clear();
      return absl::OkStatus();
    }
    return content.interface->ConvertToJsonObject(descriptor_pool,
                                                  message_factory, json);
  }
  if (dispatcher_->convert_to_json_object != nullptr) {
    return dispatcher_->convert_to_json_object(
        dispatcher_, content_, descriptor_pool, message_factory, json);
  }
  return absl::UnimplementedError(
      absl::StrCat(GetTypeName(), " is not convertable to JSON"));
}

// Compares against `other` and stores the boolean outcome in `*result`.
// A non-struct `other` is simply unequal: the result is FalseValue() with an
// OK status. For a dispatcher without an `equal` entry, this falls back to
// the generic common_internal::StructValueEqual.
absl::Status CustomStructValue::Equal(
    const Value& other,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);
  ABSL_DCHECK(*this);

  if (auto other_struct_value = other.AsStruct(); other_struct_value) {
    if (dispatcher_ == nullptr) {
      CustomStructValueInterface::Content content = content_.To();
      ABSL_DCHECK(content.interface != nullptr);
      return content.interface->Equal(*other_struct_value, descriptor_pool,
                                      message_factory, arena, result);
    }
    if (dispatcher_->equal != nullptr) {
      return dispatcher_->equal(dispatcher_, content_, *other_struct_value,
                                descriptor_pool, message_factory, arena,
                                result);
    }
    return common_internal::StructValueEqual(*this, *other_struct_value,
                                             descriptor_pool, message_factory,
                                             arena, result);
  }
  *result = FalseValue();
  return absl::OkStatus();
}

// Reports whether this is the zero value of its type. An interface-backed
// value with a null interface counts as zero. In debug builds the
// ABSL_DCHECK(*this) above already rejects that state.
bool CustomStructValue::IsZeroValue() const {
  ABSL_DCHECK(*this);
  if (dispatcher_ == nullptr) {
    CustomStructValueInterface::Content content = content_.To();
    if (content.interface == nullptr) {
      return true;
    }
    return content.interface->IsZeroValue();
  }
  return dispatcher_->is_zero_value(dispatcher_, content_);
}

// Returns a copy whose storage is associated with `arena`.
// Interface-backed: if the stored arena already equals `arena` (or the
// interface is null), *this is returned as-is. Otherwise the interface's own
// Clone performs the copy. Dispatcher-backed: `clone` is a required entry.
CustomStructValue CustomStructValue::Clone(
    google::protobuf::Arena* absl_nonnull arena) const {
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(*this);

  if (dispatcher_ == nullptr) {
    CustomStructValueInterface::Content content = content_.To();
    if (content.interface == nullptr) {
      return *this;
    }
    if (content.arena != arena) {
      return content.interface->Clone(arena);
    }
    return *this;
  }
  return dispatcher_->clone(dispatcher_, content_, arena);
}

// Looks up a field by name and writes it to `*result`. get_field_by_name is
// a required dispatcher entry.
absl::Status CustomStructValue::GetFieldByName(
    absl::string_view name, ProtoWrapperTypeOptions unboxing_options,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);
  ABSL_DCHECK(*this);

  if (dispatcher_ == nullptr) {
    CustomStructValueInterface::Content content = content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->GetFieldByName(name, unboxing_options,
                                             descriptor_pool, message_factory,
                                             arena, result);
  }
  return dispatcher_->get_field_by_name(dispatcher_, content_, name,
                                        unboxing_options, descriptor_pool,
                                        message_factory, arena, result);
}

// Looks up a field by number. Returns kUnimplemented when the dispatcher
// provides no get_field_by_number entry.
absl::Status CustomStructValue::GetFieldByNumber(
    int64_t number, ProtoWrapperTypeOptions unboxing_options,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);
  ABSL_DCHECK(*this);

  if (dispatcher_ == nullptr) {
    CustomStructValueInterface::Content content = content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->GetFieldByNumber(number, unboxing_options,
                                               descriptor_pool, message_factory,
                                               arena, result);
  }
  if (dispatcher_->get_field_by_number != nullptr) {
    return dispatcher_->get_field_by_number(dispatcher_, content_, number,
                                            unboxing_options, descriptor_pool,
                                            message_factory, arena, result);
  }
  return absl::UnimplementedError(absl::StrCat(
      GetTypeName(), " does not implement access by field number"));
}

// Presence test by field name. has_field_by_name is a required dispatcher
// entry.
absl::StatusOr CustomStructValue::HasFieldByName(
    absl::string_view name) const {
  ABSL_DCHECK(*this);
  if (dispatcher_ == nullptr) {
    CustomStructValueInterface::Content content = content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->HasFieldByName(name);
  }
  return dispatcher_->has_field_by_name(dispatcher_, content_, name);
}

// Presence test by field number. Returns kUnimplemented when the dispatcher
// provides no has_field_by_number entry.
absl::StatusOr CustomStructValue::HasFieldByNumber(int64_t number) const {
  ABSL_DCHECK(*this);
  if (dispatcher_ == nullptr) {
    CustomStructValueInterface::Content content = content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->HasFieldByNumber(number);
  }
  if (dispatcher_->has_field_by_number != nullptr) {
    return dispatcher_->has_field_by_number(dispatcher_, content_, number);
  }
  return absl::UnimplementedError(absl::StrCat(
      GetTypeName(), " does not implement access by field number"));
}

// Invokes `callback` for fields of this struct. Which fields are visited,
// and how the callback's return value is interpreted, is up to the
// implementation. for_each_field is a required dispatcher entry.
absl::Status CustomStructValue::ForEachField(
    ForEachFieldCallback callback,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(*this);

  if (dispatcher_ == nullptr) {
    CustomStructValueInterface::Content content = content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->ForEachField(callback, descriptor_pool,
                                           message_factory, arena);
  }
  return dispatcher_->for_each_field(dispatcher_, content_, callback,
                                     descriptor_pool, message_factory, arena);
}

// Optional optimization hook that applies a chain of `qualifiers` (at least
// one, per the DCHECK) in one call, writing the value to `*result` and a count
// to `*count`. Returns kUnimplemented when the dispatcher provides no
// `qualify` entry.
absl::Status CustomStructValue::Qualify(
    absl::Span qualifiers, bool presence_test,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result,
    int* absl_nonnull count) const {
  ABSL_DCHECK_GT(qualifiers.size(), 0);
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);
  ABSL_DCHECK(count != nullptr);
  ABSL_DCHECK(*this);

  if (dispatcher_ == nullptr) {
    CustomStructValueInterface::Content content = content_.To();
    ABSL_DCHECK(content.interface != nullptr);
    return content.interface->Qualify(qualifiers, presence_test,
                                      descriptor_pool, message_factory, arena,
                                      result, count);
  }
  if (dispatcher_->qualify != nullptr) {
    return dispatcher_->qualify(dispatcher_, content_, qualifiers,
                                presence_test, descriptor_pool,
                                message_factory, arena, result, count);
  }
  return absl::UnimplementedError(absl::StrCat(
      GetTypeName(), " does not implement field selection optimization"));
}

}  // namespace cel

================================================
FILE: common/values/custom_struct_value.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private, include "common/value.h"
// IWYU pragma: friend "common/value.h"

// NOTE(review): this guard is named *_PARSED_STRUCT_VALUE_H_ although the file
// is custom_struct_value.h. It looks like a leftover from a rename. Confirm it
// cannot collide with another header's guard.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_STRUCT_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_STRUCT_VALUE_H_

// NOTE(review): the bare `#include` directives below (and several template
// argument lists in this header) appear to have lost their angle-bracketed
// text during extraction.
#include
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/functional/function_ref.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "base/attribute.h"
#include "common/native_type.h"
#include "common/type.h"
#include "common/value_kind.h"
#include "common/values/custom_value.h"
#include "common/values/values.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

class CustomStructValueInterface;
class CustomStructValue;
class Value;
struct CustomStructValueDispatcher;

// Opaque payload carried by a CustomStructValue. For interface-backed values
// it encodes a CustomStructValueInterface::Content. For dispatcher-backed
// values its format is known only to the dispatch table's functions.
using CustomStructValueContent = CustomValueContent;

// Manual dispatch table for CustomStructValues created via
// UnsafeCustomStructValue.
//
// Entries annotated absl_nonnull are required. The private CustomStructValue
// constructor ABSL_DCHECKs that get_type_id, get_arena, get_type_name,
// is_zero_value, get_field_by_name, has_field_by_name, for_each_field and
// clone are set. The absl_nullable entries default to nullptr. When one is
// unset, CustomStructValue either uses a fallback or returns
// kUnimplemented:
//   debug_string            -> the type name
//   get_runtime_type        -> MakeBasicStructType(type name)
//   equal                   -> common_internal::StructValueEqual
//   serialize_to, convert_to_json_object, get_field_by_number,
//   has_field_by_number, qualify -> kUnimplemented
// Every entry receives the dispatch table itself plus the opaque content.
struct CustomStructValueDispatcher {
  using GetTypeId = NativeTypeId (*)(
      const CustomStructValueDispatcher* absl_nonnull dispatcher,
      CustomStructValueContent content);

  using GetArena = google::protobuf::Arena* absl_nullable (*)(
      const CustomStructValueDispatcher* absl_nonnull dispatcher,
      CustomStructValueContent content);

  using GetTypeName = absl::string_view (*)(
      const CustomStructValueDispatcher* absl_nonnull dispatcher,
      CustomStructValueContent content);

  using DebugString = std::string (*)(
      const CustomStructValueDispatcher* absl_nonnull dispatcher,
      CustomStructValueContent content);

  using GetRuntimeType =
      StructType (*)(const CustomStructValueDispatcher* absl_nonnull dispatcher,
                     CustomStructValueContent content);

  using SerializeTo = absl::Status (*)(
      const CustomStructValueDispatcher* absl_nonnull dispatcher,
      CustomStructValueContent content,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output);

  using ConvertToJsonObject = absl::Status (*)(
      const CustomStructValueDispatcher* absl_nonnull dispatcher,
      CustomStructValueContent content,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json);

  using Equal = absl::Status (*)(
      const CustomStructValueDispatcher* absl_nonnull dispatcher,
      CustomStructValueContent content, const StructValue& other,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result);

  using IsZeroValue =
      bool (*)(const CustomStructValueDispatcher* absl_nonnull dispatcher,
               CustomStructValueContent content);

  using GetFieldByName = absl::Status (*)(
      const CustomStructValueDispatcher* absl_nonnull dispatcher,
      CustomStructValueContent content, absl::string_view name,
      ProtoWrapperTypeOptions unboxing_options,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result);

  using GetFieldByNumber = absl::Status (*)(
      const CustomStructValueDispatcher* absl_nonnull dispatcher,
      CustomStructValueContent content, int64_t number,
      ProtoWrapperTypeOptions unboxing_options,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result);

  using HasFieldByName = absl::StatusOr (*)(
      const CustomStructValueDispatcher* absl_nonnull dispatcher,
      CustomStructValueContent content, absl::string_view name);

  using HasFieldByNumber = absl::StatusOr (*)(
      const CustomStructValueDispatcher* absl_nonnull dispatcher,
      CustomStructValueContent content, int64_t number);

  using ForEachField = absl::Status (*)(
      const CustomStructValueDispatcher* absl_nonnull dispatcher,
      CustomStructValueContent content,
      absl::FunctionRef(absl::string_view, const Value&)> callback,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena);

  // Signature of the optional `qualify` entry.
  // NOTE(review): the alias is spelled `Quality`, but it types `qualify` and
  // mirrors Qualify(). Renaming it would break external code that names
  // CustomStructValueDispatcher::Quality, so consider adding a `Qualify`
  // alias alongside it instead.
  using Quality = absl::Status (*)(
      const CustomStructValueDispatcher* absl_nonnull dispatcher,
      CustomStructValueContent content,
      absl::Span qualifiers, bool presence_test,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result,
      int* absl_nonnull count);

  using Clone = CustomStructValue (*)(
      const CustomStructValueDispatcher* absl_nonnull dispatcher,
      CustomStructValueContent content,
      google::protobuf::Arena* absl_nonnull arena);

  absl_nonnull GetTypeId get_type_id;

  absl_nonnull GetArena get_arena;

  absl_nonnull GetTypeName get_type_name;

  absl_nullable DebugString debug_string = nullptr;

  absl_nullable GetRuntimeType get_runtime_type = nullptr;

  absl_nullable SerializeTo serialize_to = nullptr;

  absl_nullable ConvertToJsonObject convert_to_json_object = nullptr;

  absl_nullable Equal equal = nullptr;

  absl_nonnull IsZeroValue is_zero_value;

  absl_nonnull GetFieldByName get_field_by_name;

  absl_nullable GetFieldByNumber get_field_by_number = nullptr;

  absl_nonnull HasFieldByName has_field_by_name;

  absl_nullable HasFieldByNumber has_field_by_number = nullptr;

  absl_nonnull ForEachField for_each_field;

  absl_nullable Quality qualify = nullptr;

  absl_nonnull Clone clone;
};

// Virtual-dispatch base for implementing custom struct values.
//
// The class is neither copyable nor movable. Its hooks are private virtuals,
// reachable only by the friends CustomStructValue and
// common_internal::StructValueEqual. Equal, Qualify and GetRuntimeType have
// default implementations; every other hook is pure virtual.
class CustomStructValueInterface {
 public:
  CustomStructValueInterface() = default;
  CustomStructValueInterface(const CustomStructValueInterface&) = delete;
  CustomStructValueInterface(CustomStructValueInterface&&) = delete;

  virtual ~CustomStructValueInterface() = default;

  CustomStructValueInterface& operator=(const CustomStructValueInterface&) =
      delete;
  CustomStructValueInterface& operator=(CustomStructValueInterface&&) = delete;

  using ForEachFieldCallback =
      absl::FunctionRef(absl::string_view, const Value&)>;

 private:
  friend class CustomStructValue;
  friend absl::Status common_internal::StructValueEqual(
      const CustomStructValueInterface& lhs, const StructValue& rhs,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result);

  virtual std::string DebugString() const = 0;

  virtual absl::Status SerializeTo(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const = 0;

  virtual absl::Status ConvertToJsonObject(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const = 0;

  virtual absl::string_view GetTypeName() const = 0;

  // Default: a basic struct type named after GetTypeName().
  virtual StructType GetRuntimeType() const {
    return common_internal::MakeBasicStructType(GetTypeName());
  }

  // Default (defined in the .cc): common_internal::StructValueEqual.
  virtual absl::Status Equal(
      const StructValue& other,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const;

  virtual bool IsZeroValue() const = 0;

  virtual absl::Status GetFieldByName(
      absl::string_view name, ProtoWrapperTypeOptions unboxing_options,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const = 0;

  virtual absl::Status GetFieldByNumber(
      int64_t number, ProtoWrapperTypeOptions unboxing_options,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const = 0;

  virtual absl::StatusOr HasFieldByName(absl::string_view name) const = 0;

  virtual absl::StatusOr HasFieldByNumber(int64_t number) const = 0;

  virtual absl::Status ForEachField(
      ForEachFieldCallback callback,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const = 0;

  // Default (defined in the .cc): returns kUnimplemented.
  virtual absl::Status Qualify(
      absl::Span qualifiers, bool presence_test,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result,
      int* absl_nonnull count) const;

  virtual CustomStructValue Clone(
      google::protobuf::Arena* absl_nonnull arena) const = 0;

  virtual NativeTypeId GetNativeTypeId() const = 0;

  // What an interface-backed CustomStructValue stores in its content: the
  // implementation plus the arena it was constructed with. CustomStructValue
  // compares the stored arena with the target arena in Clone.
  struct Content {
    const CustomStructValueInterface* absl_nonnull interface;
    google::protobuf::Arena* absl_nonnull arena;
  };
};

// Creates a custom struct value from a manual dispatch table `dispatcher` and
// opaque data `content` whose format is only known to functions in the manual
// dispatch table. The dispatch table should probably be valid for the lifetime
// of the process, but at a minimum must outlive all instances of the resulting
// value.
//
// IMPORTANT: This approach to implementing CustomStructValues should only be
// used when you know exactly what you are doing. When in doubt, just implement
// CustomStructValueInterface.
CustomStructValue UnsafeCustomStructValue(
    const CustomStructValueDispatcher* absl_nonnull dispatcher
        ABSL_ATTRIBUTE_LIFETIME_BOUND,
    CustomStructValueContent content);

// A struct value backed by user code, in one of two representations:
//   * interface-backed (dispatcher_ == nullptr): content_ encodes a
//     CustomStructValueInterface::Content;
//   * dispatcher-backed (dispatcher_ != nullptr): content_ is opaque and is
//     interpreted only by the dispatch table.
// A default-constructed value has a null dispatcher and
// CustomStructValueContent::Zero() content. That presumably decodes to a null
// interface, making `operator bool` false (TODO confirm against
// CustomValueContent). Most accessors ABSL_DCHECK(*this).
class CustomStructValue final
    : private common_internal::StructValueMixin {
 public:
  static constexpr ValueKind kKind = ValueKind::kStruct;

  // Constructs a custom struct value from an implementation of
  // `CustomStructValueInterface` `interface` whose lifetime is tied to that of
  // the arena `arena`.
  CustomStructValue(const CustomStructValueInterface* absl_nonnull
                        interface ABSL_ATTRIBUTE_LIFETIME_BOUND,
                    google::protobuf::Arena* absl_nonnull arena
                        ABSL_ATTRIBUTE_LIFETIME_BOUND) {
    ABSL_DCHECK(interface != nullptr);
    ABSL_DCHECK(arena != nullptr);
    content_ = CustomStructValueContent::From(CustomStructValueInterface::Content{
        .interface = interface, .arena = arena});
  }

  CustomStructValue() = default;
  CustomStructValue(const CustomStructValue&) = default;
  CustomStructValue(CustomStructValue&&) = default;
  CustomStructValue& operator=(const CustomStructValue&) = default;
  CustomStructValue& operator=(CustomStructValue&&) = default;

  static constexpr ValueKind kind() { return kKind; }

  NativeTypeId GetTypeId() const;

  StructType GetRuntimeType() const;

  absl::string_view GetTypeName() const;

  std::string DebugString() const;

  // See Value::SerializeTo().
  absl::Status SerializeTo(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const;

  // See Value::ConvertToJson().
  absl::Status ConvertToJson(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  // See Value::ConvertToJsonObject().
  absl::Status ConvertToJsonObject(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  absl::Status Equal(const Value& other,
                     const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                     google::protobuf::MessageFactory* absl_nonnull message_factory,
                     google::protobuf::Arena* absl_nonnull arena,
                     Value* absl_nonnull result) const;
  using StructValueMixin::Equal;

  bool IsZeroValue() const;

  CustomStructValue Clone(google::protobuf::Arena* absl_nonnull arena) const;

  absl::Status GetFieldByName(
      absl::string_view name, ProtoWrapperTypeOptions unboxing_options,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const;
  using StructValueMixin::GetFieldByName;

  absl::Status GetFieldByNumber(
      int64_t number, ProtoWrapperTypeOptions unboxing_options,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const;
  using StructValueMixin::GetFieldByNumber;

  absl::StatusOr HasFieldByName(absl::string_view name) const;

  absl::StatusOr HasFieldByNumber(int64_t number) const;

  using ForEachFieldCallback = CustomStructValueInterface::ForEachFieldCallback;

  absl::Status ForEachField(
      ForEachFieldCallback callback,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  absl::Status Qualify(
      absl::Span qualifiers, bool presence_test,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result,
      int* absl_nonnull count) const;
  using StructValueMixin::Qualify;

  // The dispatch table, or nullptr for interface-backed values.
  const CustomStructValueDispatcher* absl_nullable dispatcher() const {
    return dispatcher_;
  }

  // The opaque content. Only meaningful for dispatcher-backed values
  // (ABSL_DCHECKed).
  CustomStructValueContent content() const {
    ABSL_DCHECK(dispatcher_ != nullptr);
    return content_;
  }

  // The implementation for interface-backed values; nullptr for
  // dispatcher-backed values.
  const CustomStructValueInterface* absl_nullable interface() const {
    if (dispatcher_ == nullptr) {
      return content_.To().interface;
    }
    return nullptr;
  }

  // Dispatcher-backed values are always truthy. Interface-backed values are
  // truthy iff their interface pointer is non-null.
  explicit operator bool() const {
    if (dispatcher_ == nullptr) {
      return content_.To().interface != nullptr;
    }
    return true;
  }

  friend void swap(CustomStructValue& lhs, CustomStructValue& rhs) noexcept {
    using std::swap;
    swap(lhs.dispatcher_, rhs.dispatcher_);
    swap(lhs.content_, rhs.content_);
  }

 private:
  friend class common_internal::ValueMixin;
  friend class common_internal::StructValueMixin;
  friend CustomStructValue UnsafeCustomStructValue(
      const CustomStructValueDispatcher* absl_nonnull dispatcher
          ABSL_ATTRIBUTE_LIFETIME_BOUND,
      CustomStructValueContent content);

  // Constructs a custom struct value from a dispatcher and content. Only
  // accessible from `UnsafeCustomStructValue`.
  // The DCHECKs below enforce that every absl_nonnull dispatch entry is set.
  CustomStructValue(const CustomStructValueDispatcher* absl_nonnull dispatcher
                        ABSL_ATTRIBUTE_LIFETIME_BOUND,
                    CustomStructValueContent content)
      : dispatcher_(dispatcher), content_(content) {
    ABSL_DCHECK(dispatcher != nullptr);
    ABSL_DCHECK(dispatcher->get_type_id != nullptr);
    ABSL_DCHECK(dispatcher->get_arena != nullptr);
    ABSL_DCHECK(dispatcher->get_type_name != nullptr);
    ABSL_DCHECK(dispatcher->is_zero_value != nullptr);
    ABSL_DCHECK(dispatcher->get_field_by_name != nullptr);
    ABSL_DCHECK(dispatcher->has_field_by_name != nullptr);
    ABSL_DCHECK(dispatcher->for_each_field != nullptr);
    ABSL_DCHECK(dispatcher->clone != nullptr);
  }

  const CustomStructValueDispatcher* absl_nullable dispatcher_ = nullptr;
  CustomStructValueContent content_ = CustomStructValueContent::Zero();
};

// Streams DebugString().
inline std::ostream& operator<<(std::ostream& out,
                                const CustomStructValue& value) {
  return out << value.DebugString();
}

// Exposes CustomStructValue::GetTypeId() through the NativeTypeTraits
// customization point.
template <>
struct NativeTypeTraits final {
  static NativeTypeId Id(const CustomStructValue& type) {
    return type.GetTypeId();
  }
};

inline CustomStructValue UnsafeCustomStructValue(
    const CustomStructValueDispatcher* absl_nonnull dispatcher
        ABSL_ATTRIBUTE_LIFETIME_BOUND,
    CustomStructValueContent content) {
  return CustomStructValue(dispatcher, content);
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_STRUCT_VALUE_H_

================================================
FILE: common/values/custom_struct_value_test.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include "google/protobuf/struct.pb.h" #include "absl/base/nullability.h" #include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "base/attribute.h" #include "common/memory.h" #include "common/native_type.h" #include "common/type.h" #include "common/value.h" #include "common/value_testing.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream_impl_lite.h" #include "google/protobuf/message.h" namespace cel { namespace { using ::absl_testing::IsOk; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::test::BoolValueIs; using ::cel::test::IntValueIs; using ::testing::IsEmpty; using ::testing::IsNull; using ::testing::Not; using ::testing::NotNull; using ::testing::Pair; using ::testing::UnorderedElementsAre; class CustomStructValueTest; struct CustomStructValueTestContent { google::protobuf::Arena* absl_nonnull arena; }; class CustomStructValueInterfaceTest final : public CustomStructValueInterface { public: absl::string_view GetTypeName() const override { return "test.Interface"; } std::string DebugString() const override { return std::string(GetTypeName()); } bool IsZeroValue() const override { return false; } absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const override { google::protobuf::Value json; google::protobuf::Struct* json_object = 
json.mutable_struct_value(); (*json_object->mutable_fields())["foo"].set_bool_value(true); (*json_object->mutable_fields())["bar"].set_number_value(1.0); if (!json.SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( "failed to serialize message: google.protobuf.Value"); } return absl::OkStatus(); } absl::Status ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const override { google::protobuf::Struct json_object; (*json_object.mutable_fields())["foo"].set_bool_value(true); (*json_object.mutable_fields())["bar"].set_number_value(1.0); absl::Cord serialized; if (!json_object.SerializePartialToString(&serialized)) { return absl::UnknownError("failed to serialize google.protobuf.Struct"); } if (!json->ParsePartialFromString(serialized)) { return absl::UnknownError("failed to parse google.protobuf.Struct"); } return absl::OkStatus(); } absl::Status GetFieldByName( absl::string_view name, ProtoWrapperTypeOptions unboxing_options, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const override { if (name == "foo") { *result = TrueValue(); return absl::OkStatus(); } if (name == "bar") { *result = IntValue(1); return absl::OkStatus(); } return NoSuchFieldError(name).ToStatus(); } absl::Status GetFieldByNumber( int64_t number, ProtoWrapperTypeOptions unboxing_options, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const override { if (number == 1) { *result = TrueValue(); return absl::OkStatus(); } if (number == 2) { *result = IntValue(1); return absl::OkStatus(); } return 
NoSuchFieldError(absl::StrCat(number)).ToStatus(); } absl::StatusOr HasFieldByName(absl::string_view name) const override { if (name == "foo") { return true; } if (name == "bar") { return true; } return NoSuchFieldError(name).ToStatus(); } absl::StatusOr HasFieldByNumber(int64_t number) const override { if (number == 1) { return true; } if (number == 2) { return true; } return NoSuchFieldError(absl::StrCat(number)).ToStatus(); } absl::Status ForEachField( ForEachFieldCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const override { CEL_ASSIGN_OR_RETURN(bool ok, callback("foo", TrueValue())); if (!ok) { return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN(ok, callback("bar", IntValue(1))); return absl::OkStatus(); } CustomStructValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { return CustomStructValue( (::new (arena->AllocateAligned(sizeof(CustomStructValueInterfaceTest), alignof(CustomStructValueInterfaceTest))) CustomStructValueInterfaceTest()), arena); } private: NativeTypeId GetNativeTypeId() const override { return NativeTypeId::For(); } }; class CustomStructValueTest : public common_internal::ValueTest<> { public: CustomStructValue MakeInterface() { return CustomStructValue((::new (arena()->AllocateAligned( sizeof(CustomStructValueInterfaceTest), alignof(CustomStructValueInterfaceTest))) CustomStructValueInterfaceTest()), arena()); } CustomStructValue MakeDispatcher() { return UnsafeCustomStructValue( &test_dispatcher_, CustomValueContent::From( CustomStructValueTestContent{.arena = arena()})); } protected: CustomStructValueDispatcher test_dispatcher_ = { .get_type_id = [](const CustomStructValueDispatcher* absl_nonnull dispatcher, CustomStructValueContent content) -> NativeTypeId { return NativeTypeId::For(); }, .get_arena = [](const CustomStructValueDispatcher* absl_nonnull dispatcher, 
CustomStructValueContent content) -> google::protobuf::Arena* absl_nullable { return content.To().arena; }, .get_type_name = [](const CustomStructValueDispatcher* absl_nonnull dispatcher, CustomStructValueContent content) -> absl::string_view { return "test.Dispatcher"; }, .debug_string = [](const CustomStructValueDispatcher* absl_nonnull dispatcher, CustomStructValueContent content) -> std::string { return "test.Dispatcher"; }, .get_runtime_type = [](const CustomStructValueDispatcher* absl_nonnull dispatcher, CustomStructValueContent content) -> StructType { return common_internal::MakeBasicStructType("test.Dispatcher"); }, .serialize_to = [](const CustomStructValueDispatcher* absl_nonnull dispatcher, CustomStructValueContent content, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) -> absl::Status { google::protobuf::Value json; google::protobuf::Struct* json_object = json.mutable_struct_value(); (*json_object->mutable_fields())["foo"].set_bool_value(true); (*json_object->mutable_fields())["bar"].set_number_value(1.0); if (!json.SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( "failed to serialize message: google.protobuf.Value"); } return absl::OkStatus(); }, .convert_to_json_object = [](const CustomStructValueDispatcher* absl_nonnull dispatcher, CustomStructValueContent content, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) -> absl::Status { google::protobuf::Struct json_object; (*json_object.mutable_fields())["foo"].set_bool_value(true); (*json_object.mutable_fields())["bar"].set_number_value(1.0); absl::Cord serialized; if (!json_object.SerializePartialToString(&serialized)) { return absl::UnknownError( "failed to serialize google.protobuf.Struct"); } if 
(!json->ParsePartialFromString(serialized)) { return absl::UnknownError("failed to parse google.protobuf.Struct"); } return absl::OkStatus(); }, .is_zero_value = [](const CustomStructValueDispatcher* absl_nonnull dispatcher, CustomStructValueContent content) -> bool { return false; }, .get_field_by_name = [](const CustomStructValueDispatcher* absl_nonnull dispatcher, CustomStructValueContent content, absl::string_view name, ProtoWrapperTypeOptions unboxing_options, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) -> absl::Status { if (name == "foo") { *result = TrueValue(); return absl::OkStatus(); } if (name == "bar") { *result = IntValue(1); return absl::OkStatus(); } return NoSuchFieldError(name).ToStatus(); }, .get_field_by_number = [](const CustomStructValueDispatcher* absl_nonnull dispatcher, CustomStructValueContent content, int64_t number, ProtoWrapperTypeOptions unboxing_options, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) -> absl::Status { if (number == 1) { *result = TrueValue(); return absl::OkStatus(); } if (number == 2) { *result = IntValue(1); return absl::OkStatus(); } return NoSuchFieldError(absl::StrCat(number)).ToStatus(); }, .has_field_by_name = [](const CustomStructValueDispatcher* absl_nonnull dispatcher, CustomStructValueContent content, absl::string_view name) -> absl::StatusOr { if (name == "foo") { return true; } if (name == "bar") { return true; } return NoSuchFieldError(name).ToStatus(); }, .has_field_by_number = [](const CustomStructValueDispatcher* absl_nonnull dispatcher, CustomStructValueContent content, int64_t number) -> absl::StatusOr { if (number == 1) { return true; } if (number == 2) { return true; } return 
NoSuchFieldError(absl::StrCat(number)).ToStatus(); }, .for_each_field = [](const CustomStructValueDispatcher* absl_nonnull dispatcher, CustomStructValueContent content, absl::FunctionRef(absl::string_view, const Value&)> callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::Status { CEL_ASSIGN_OR_RETURN(bool ok, callback("foo", TrueValue())); if (!ok) { return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN(ok, callback("bar", IntValue(1))); return absl::OkStatus(); }, .clone = [](const CustomStructValueDispatcher* absl_nonnull dispatcher, CustomStructValueContent content, google::protobuf::Arena* absl_nonnull arena) -> CustomStructValue { return UnsafeCustomStructValue( dispatcher, CustomValueContent::From( CustomStructValueTestContent{.arena = arena})); }, }; }; TEST_F(CustomStructValueTest, Kind) { EXPECT_EQ(CustomStructValue::kind(), CustomStructValue::kKind); } TEST_F(CustomStructValueTest, Dispatcher_GetTypeId) { EXPECT_EQ(MakeDispatcher().GetTypeId(), NativeTypeId::For()); } TEST_F(CustomStructValueTest, Interface_GetTypeId) { EXPECT_EQ(MakeInterface().GetTypeId(), NativeTypeId::For()); } TEST_F(CustomStructValueTest, Dispatcher_GetTypeName) { EXPECT_EQ(MakeDispatcher().GetTypeName(), "test.Dispatcher"); } TEST_F(CustomStructValueTest, Interface_GetTypeName) { EXPECT_EQ(MakeInterface().GetTypeName(), "test.Interface"); } TEST_F(CustomStructValueTest, Dispatcher_DebugString) { EXPECT_EQ(MakeDispatcher().DebugString(), "test.Dispatcher"); } TEST_F(CustomStructValueTest, Interface_DebugString) { EXPECT_EQ(MakeInterface().DebugString(), "test.Interface"); } TEST_F(CustomStructValueTest, Dispatcher_GetRuntimeType) { EXPECT_EQ(MakeDispatcher().GetRuntimeType(), common_internal::MakeBasicStructType("test.Dispatcher")); } TEST_F(CustomStructValueTest, Interface_GetRuntimeType) { EXPECT_EQ(MakeInterface().GetRuntimeType(), 
common_internal::MakeBasicStructType("test.Interface")); } TEST_F(CustomStructValueTest, Dispatcher_IsZeroValue) { EXPECT_FALSE(MakeDispatcher().IsZeroValue()); } TEST_F(CustomStructValueTest, Interface_IsZeroValue) { EXPECT_FALSE(MakeInterface().IsZeroValue()); } TEST_F(CustomStructValueTest, Dispatcher_SerializeTo) { google::protobuf::io::CordOutputStream output; EXPECT_THAT(MakeDispatcher().SerializeTo(descriptor_pool(), message_factory(), &output), IsOk()); EXPECT_THAT(std::move(output).Consume(), Not(IsEmpty())); } TEST_F(CustomStructValueTest, Interface_SerializeTo) { google::protobuf::io::CordOutputStream output; EXPECT_THAT(MakeInterface().SerializeTo(descriptor_pool(), message_factory(), &output), IsOk()); EXPECT_THAT(std::move(output).Consume(), Not(IsEmpty())); } TEST_F(CustomStructValueTest, Dispatcher_ConvertToJson) { auto message = DynamicParseTextProto(); EXPECT_THAT( MakeDispatcher().ConvertToJson(descriptor_pool(), message_factory(), cel::to_address(message)), IsOk()); EXPECT_THAT(*message, EqualsTextProto(R"pb( struct_value: { fields: { key: "foo" value: { bool_value: true } } fields: { key: "bar" value: { number_value: 1.0 } } } )pb")); } TEST_F(CustomStructValueTest, Interface_ConvertToJson) { auto message = DynamicParseTextProto(); EXPECT_THAT( MakeInterface().ConvertToJson(descriptor_pool(), message_factory(), cel::to_address(message)), IsOk()); EXPECT_THAT(*message, EqualsTextProto(R"pb( struct_value: { fields: { key: "foo" value: { bool_value: true } } fields: { key: "bar" value: { number_value: 1.0 } } } )pb")); } TEST_F(CustomStructValueTest, Dispatcher_ConvertToJsonObject) { auto message = DynamicParseTextProto(); EXPECT_THAT( MakeDispatcher().ConvertToJsonObject(descriptor_pool(), message_factory(), cel::to_address(message)), IsOk()); EXPECT_THAT(*message, EqualsTextProto(R"pb( fields: { key: "foo" value: { bool_value: true } } fields: { key: "bar" value: { number_value: 1.0 } } )pb")); } TEST_F(CustomStructValueTest, 
Interface_ConvertToJsonObject) { auto message = DynamicParseTextProto(); EXPECT_THAT( MakeInterface().ConvertToJsonObject(descriptor_pool(), message_factory(), cel::to_address(message)), IsOk()); EXPECT_THAT(*message, EqualsTextProto(R"pb( fields: { key: "foo" value: { bool_value: true } } fields: { key: "bar" value: { number_value: 1.0 } } )pb")); } TEST_F(CustomStructValueTest, Dispatcher_GetFieldByName) { EXPECT_THAT(MakeDispatcher().GetFieldByName("foo", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT(MakeDispatcher().GetFieldByName("bar", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(IntValueIs(1))); } TEST_F(CustomStructValueTest, Interface_GetFieldByName) { EXPECT_THAT(MakeInterface().GetFieldByName("foo", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT(MakeInterface().GetFieldByName("bar", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(IntValueIs(1))); } TEST_F(CustomStructValueTest, Dispatcher_GetFieldByNumber) { EXPECT_THAT(MakeDispatcher().GetFieldByNumber(1, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT(MakeDispatcher().GetFieldByNumber(2, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(IntValueIs(1))); } TEST_F(CustomStructValueTest, Interface_GetFieldByNumber) { EXPECT_THAT(MakeInterface().GetFieldByNumber(1, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT(MakeInterface().GetFieldByNumber(2, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(IntValueIs(1))); } TEST_F(CustomStructValueTest, Dispatcher_HasFieldByName) { EXPECT_THAT(MakeDispatcher().HasFieldByName("foo"), IsOkAndHolds(true)); EXPECT_THAT(MakeDispatcher().HasFieldByName("bar"), IsOkAndHolds(true)); } TEST_F(CustomStructValueTest, Interface_HasFieldByName) { EXPECT_THAT(MakeInterface().HasFieldByName("foo"), IsOkAndHolds(true)); 
EXPECT_THAT(MakeInterface().HasFieldByName("bar"), IsOkAndHolds(true)); } TEST_F(CustomStructValueTest, Dispatcher_HasFieldByNumber) { EXPECT_THAT(MakeDispatcher().HasFieldByNumber(1), IsOkAndHolds(true)); EXPECT_THAT(MakeDispatcher().HasFieldByNumber(2), IsOkAndHolds(true)); } TEST_F(CustomStructValueTest, Interface_HasFieldByNumber) { EXPECT_THAT(MakeInterface().HasFieldByNumber(1), IsOkAndHolds(true)); EXPECT_THAT(MakeInterface().HasFieldByNumber(2), IsOkAndHolds(true)); } TEST_F(CustomStructValueTest, Default_Bool) { EXPECT_FALSE(CustomStructValue()); } TEST_F(CustomStructValueTest, Dispatcher_Bool) { EXPECT_TRUE(MakeDispatcher()); } TEST_F(CustomStructValueTest, Interface_Bool) { EXPECT_TRUE(MakeInterface()); } TEST_F(CustomStructValueTest, Dispatcher_ForEachField) { std::vector> fields; EXPECT_THAT(MakeDispatcher().ForEachField( [&](absl::string_view name, const Value& value) -> absl::StatusOr { fields.push_back(std::pair{std::string(name), value}); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(fields, UnorderedElementsAre(Pair("foo", BoolValueIs(true)), Pair("bar", IntValueIs(1)))); } TEST_F(CustomStructValueTest, Interface_ForEachField) { std::vector> fields; EXPECT_THAT(MakeInterface().ForEachField( [&](absl::string_view name, const Value& value) -> absl::StatusOr { fields.push_back(std::pair{std::string(name), value}); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(fields, UnorderedElementsAre(Pair("foo", BoolValueIs(true)), Pair("bar", IntValueIs(1)))); } TEST_F(CustomStructValueTest, Dispatcher_Qualify) { EXPECT_THAT( MakeDispatcher().Qualify({AttributeQualifier::OfString("foo")}, false, descriptor_pool(), message_factory(), arena()), StatusIs(absl::StatusCode::kUnimplemented)); } TEST_F(CustomStructValueTest, Interface_Qualify) { EXPECT_THAT( MakeInterface().Qualify({AttributeQualifier::OfString("foo")}, false, descriptor_pool(), message_factory(), arena()), 
StatusIs(absl::StatusCode::kUnimplemented)); } TEST_F(CustomStructValueTest, Dispatcher) { EXPECT_THAT(MakeDispatcher().dispatcher(), NotNull()); EXPECT_THAT(MakeDispatcher().interface(), IsNull()); } TEST_F(CustomStructValueTest, Interface) { EXPECT_THAT(MakeInterface().dispatcher(), IsNull()); EXPECT_THAT(MakeInterface().interface(), NotNull()); } } // namespace } // namespace cel ================================================ FILE: common/values/custom_value.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_CUSTOM_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_CUSTOM_VALUE_H_ #include #include #include #include namespace cel { // CustomValueContent is an opaque 16-byte trivially copyable value. The format // of the data stored within is unknown to everything except the the caller // which creates it. Do not try to interpret it otherwise. 
class CustomValueContent final { public: static CustomValueContent Zero() { CustomValueContent content; std::memset(&content, 0, sizeof(content)); return content; } template static CustomValueContent From(T value) { static_assert(std::is_trivially_copyable_v, "T must be trivially copyable"); static_assert(sizeof(T) <= 16, "sizeof(T) must be no greater than 16"); CustomValueContent content; std::memcpy(content.raw_, std::addressof(value), sizeof(T)); return content; } template static CustomValueContent From(const T (&array)[N]) { static_assert(std::is_trivially_copyable_v, "T must be trivially copyable"); static_assert((sizeof(T) * N) <= 16, "sizeof(T[N]) must be no greater than 16"); CustomValueContent content; std::memcpy(content.raw_, array, sizeof(T) * N); return content; } template T To() const { static_assert(std::is_trivially_copyable_v, "T must be trivially copyable"); static_assert(sizeof(T) <= 16, "sizeof(T) must be no greater than 16"); T value; std::memcpy(std::addressof(value), raw_, sizeof(T)); return value; } bool IsZero() const { static const CustomValueContent kZero = Zero(); return std::memcmp(raw_, kZero.raw_, sizeof(raw_)) == 0; } private: alignas(void*) std::byte raw_[16]; }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_CUSTOM_VALUE_H_ ================================================ FILE: common/values/double_value.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include #include #include "google/protobuf/wrappers.pb.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "common/value.h" #include "internal/number.h" #include "internal/status_macros.h" #include "internal/well_known_types.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { namespace { using ::cel::well_known_types::ValueReflection; std::string DoubleDebugString(double value) { if (std::isfinite(value)) { if (std::floor(value) != value) { // The double is not representable as a whole number, so use // absl::StrCat which will add decimal places. return absl::StrCat(value); } // absl::StrCat historically would represent 0.0 as 0, and we want the // decimal places so ZetaSQL correctly assumes the type as double // instead of int64. std::string stringified = absl::StrCat(value); if (!absl::StrContains(stringified, '.')) { absl::StrAppend(&stringified, ".0"); } else { // absl::StrCat has a decimal now? Use it directly. 
} return stringified; } if (std::isnan(value)) { return "nan"; } if (std::signbit(value)) { return "-infinity"; } return "+infinity"; } } // namespace std::string DoubleValue::DebugString() const { return DoubleDebugString(NativeValue()); } absl::Status DoubleValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); google::protobuf::DoubleValue message; message.set_value(NativeValue()); if (!message.SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( absl::StrCat("failed to serialize message: ", message.GetTypeName())); } return absl::OkStatus(); } absl::Status DoubleValue::ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); ValueReflection value_reflection; CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); value_reflection.SetNumberValue(json, NativeValue()); return absl::OkStatus(); } absl::Status DoubleValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (auto other_value = other.AsDouble(); other_value.has_value()) { *result = 
BoolValue{NativeValue() == other_value->NativeValue()}; return absl::OkStatus(); } if (auto other_value = other.AsInt(); other_value.has_value()) { *result = BoolValue{internal::Number::FromDouble(NativeValue()) == internal::Number::FromInt64(other_value->NativeValue())}; return absl::OkStatus(); } if (auto other_value = other.AsUint(); other_value.has_value()) { *result = BoolValue{internal::Number::FromDouble(NativeValue()) == internal::Number::FromUint64(other_value->NativeValue())}; return absl::OkStatus(); } *result = FalseValue(); return absl::OkStatus(); } } // namespace cel ================================================ FILE: common/values/double_value.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_DOUBLE_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_DOUBLE_VALUE_H_ #include #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "common/type.h" #include "common/value_kind.h" #include "common/values/values.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { class Value; class DoubleValue; class DoubleValue final : private common_internal::ValueMixin { public: static constexpr ValueKind kKind = ValueKind::kDouble; explicit DoubleValue(double value) noexcept : value_(value) {} DoubleValue() = default; DoubleValue(const DoubleValue&) = default; DoubleValue(DoubleValue&&) = default; DoubleValue& operator=(const DoubleValue&) = default; DoubleValue& operator=(DoubleValue&&) = default; ValueKind kind() const { return kKind; } absl::string_view GetTypeName() const { return DoubleType::kName; } std::string DebugString() const; // See Value::SerializeTo(). absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; // See Value::ConvertToJson(). 
absl::Status ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; absl::Status Equal(const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using ValueMixin::Equal; bool IsZeroValue() const { return NativeValue() == 0.0; } double NativeValue() const { return static_cast(*this); } // NOLINTNEXTLINE(google-explicit-constructor) operator double() const noexcept { return value_; } friend void swap(DoubleValue& lhs, DoubleValue& rhs) noexcept { using std::swap; swap(lhs.value_, rhs.value_); } private: friend class common_internal::ValueMixin; double value_ = 0.0; }; inline std::ostream& operator<<(std::ostream& out, DoubleValue value) { return out << value.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_DOUBLE_VALUE_H_ ================================================ FILE: common/values/double_value_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include "absl/status/status_matchers.h" #include "common/native_type.h" #include "common/value.h" #include "common/value_testing.h" #include "internal/testing.h" namespace cel { namespace { using ::absl_testing::IsOk; using DoubleValueTest = common_internal::ValueTest<>; TEST_F(DoubleValueTest, Kind) { EXPECT_EQ(DoubleValue(1.0).kind(), DoubleValue::kKind); EXPECT_EQ(Value(DoubleValue(1.0)).kind(), DoubleValue::kKind); } TEST_F(DoubleValueTest, DebugString) { { std::ostringstream out; out << DoubleValue(0.0); EXPECT_EQ(out.str(), "0.0"); } { std::ostringstream out; out << DoubleValue(1.0); EXPECT_EQ(out.str(), "1.0"); } { std::ostringstream out; out << DoubleValue(1.1); EXPECT_EQ(out.str(), "1.1"); } { std::ostringstream out; out << DoubleValue(NAN); EXPECT_EQ(out.str(), "nan"); } { std::ostringstream out; out << DoubleValue(INFINITY); EXPECT_EQ(out.str(), "+infinity"); } { std::ostringstream out; out << DoubleValue(-INFINITY); EXPECT_EQ(out.str(), "-infinity"); } { std::ostringstream out; out << Value(DoubleValue(0.0)); EXPECT_EQ(out.str(), "0.0"); } } TEST_F(DoubleValueTest, ConvertToJson) { auto* message = NewArenaValueMessage(); EXPECT_THAT(DoubleValue(1.0).ConvertToJson(descriptor_pool(), message_factory(), message), IsOk()); EXPECT_THAT(*message, EqualsValueTextProto(R"pb(number_value: 1)pb")); } TEST_F(DoubleValueTest, NativeTypeId) { EXPECT_EQ(NativeTypeId::Of(DoubleValue(1.0)), NativeTypeId::For()); EXPECT_EQ(NativeTypeId::Of(Value(DoubleValue(1.0))), NativeTypeId::For()); } TEST_F(DoubleValueTest, Equality) { EXPECT_NE(DoubleValue(0.0), 1.0); EXPECT_NE(1.0, DoubleValue(0.0)); EXPECT_NE(DoubleValue(0.0), DoubleValue(1.0)); } } // namespace } // namespace cel ================================================ FILE: common/values/duration_value.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance 
with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "google/protobuf/duration.pb.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/time/time.h" #include "common/value.h" #include "internal/status_macros.h" #include "internal/time.h" #include "internal/well_known_types.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { namespace { using ::cel::well_known_types::DurationReflection; using ::cel::well_known_types::ValueReflection; std::string DurationDebugString(absl::Duration value) { return internal::DebugStringDuration(value); } } // namespace std::string DurationValue::DebugString() const { return DurationDebugString(NativeValue()); } absl::Status DurationValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); google::protobuf::Duration message; CEL_RETURN_IF_ERROR( DurationReflection::SetFromAbslDuration(&message, NativeValue())); if (!message.SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( absl::StrCat("failed to serialize message: ", message.GetTypeName())); } return absl::OkStatus(); } absl::Status DurationValue::ConvertToJson( const 
google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); ValueReflection value_reflection; CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); value_reflection.SetStringValueFromDuration(json, NativeValue()); return absl::OkStatus(); } absl::Status DurationValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (auto other_value = other.AsDuration(); other_value.has_value()) { *result = BoolValue{NativeValue() == other_value->NativeValue()}; return absl::OkStatus(); } *result = FalseValue(); return absl::OkStatus(); } } // namespace cel ================================================ FILE: common/values/duration_value.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_DURATION_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_DURATION_VALUE_H_ #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/utility/utility.h" #include "common/type.h" #include "common/value_kind.h" #include "common/values/values.h" #include "internal/time.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { class Value; class DurationValue; DurationValue UnsafeDurationValue(absl::Duration value); absl::StatusOr SafeDurationValue(absl::Duration value); // `DurationValue` represents values of the primitive `duration` type. class DurationValue final : private common_internal::ValueMixin { public: static constexpr ValueKind kKind = ValueKind::kDuration; explicit DurationValue(absl::Duration value) noexcept : DurationValue(absl::in_place, value) { ABSL_DCHECK_OK(internal::ValidateDuration(value)); } DurationValue() = default; DurationValue(const DurationValue&) = default; DurationValue(DurationValue&&) = default; DurationValue& operator=(const DurationValue&) = default; DurationValue& operator=(DurationValue&&) = default; ValueKind kind() const { return kKind; } absl::string_view GetTypeName() const { return DurationType::kName; } std::string DebugString() const; // See Value::SerializeTo(). absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; // See Value::ConvertToJson(). 
absl::Status ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; absl::Status Equal(const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using ValueMixin::Equal; bool IsZeroValue() const { return ToDuration() == absl::ZeroDuration(); } ABSL_DEPRECATED("Use ToDuration()") absl::Duration NativeValue() const { return static_cast(*this); } ABSL_DEPRECATED("Use ToDuration()") // NOLINTNEXTLINE(google-explicit-constructor) operator absl::Duration() const noexcept { return value_; } absl::Duration ToDuration() const { return value_; } friend void swap(DurationValue& lhs, DurationValue& rhs) noexcept { using std::swap; swap(lhs.value_, rhs.value_); } friend bool operator==(DurationValue lhs, DurationValue rhs) { return lhs.value_ == rhs.value_; } friend bool operator<(const DurationValue& lhs, const DurationValue& rhs) { return lhs.value_ < rhs.value_; } private: friend class common_internal::ValueMixin; friend DurationValue UnsafeDurationValue(absl::Duration value); DurationValue(absl::in_place_t, absl::Duration value) : value_(value) {} absl::Duration value_ = absl::ZeroDuration(); }; inline DurationValue UnsafeDurationValue(absl::Duration value) { return DurationValue(absl::in_place, value); } inline absl::StatusOr SafeDurationValue(absl::Duration value) { absl::Status status = internal::ValidateDuration(value); if (!status.ok()) { return status; } return UnsafeDurationValue(value); } inline bool operator!=(DurationValue lhs, DurationValue rhs) { return !operator==(lhs, rhs); } inline std::ostream& operator<<(std::ostream& out, DurationValue value) { return out << value.DebugString(); } } // namespace cel #endif // 
THIRD_PARTY_CEL_CPP_COMMON_VALUES_DURATION_VALUE_H_ ================================================ FILE: common/values/duration_value_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include "absl/status/status_matchers.h" #include "absl/time/time.h" #include "common/native_type.h" #include "common/value.h" #include "common/value_testing.h" #include "internal/testing.h" #include "google/protobuf/io/zero_copy_stream_impl_lite.h" namespace cel { namespace { using ::absl_testing::IsOk; using ::testing::IsEmpty; using DurationValueTest = common_internal::ValueTest<>; TEST_F(DurationValueTest, Kind) { EXPECT_EQ(DurationValue().kind(), DurationValue::kKind); EXPECT_EQ(Value(DurationValue(absl::Seconds(1))).kind(), DurationValue::kKind); } TEST_F(DurationValueTest, DebugString) { { std::ostringstream out; out << DurationValue(absl::Seconds(1)); EXPECT_EQ(out.str(), "1s"); } { std::ostringstream out; out << Value(DurationValue(absl::Seconds(1))); EXPECT_EQ(out.str(), "1s"); } } TEST_F(DurationValueTest, SerializeTo) { google::protobuf::io::CordOutputStream output; EXPECT_THAT(DurationValue().SerializeTo(descriptor_pool(), message_factory(), &output), IsOk()); EXPECT_THAT(std::move(output).Consume(), IsEmpty()); } TEST_F(DurationValueTest, ConvertToJson) { auto* message = NewArenaValueMessage(); EXPECT_THAT(DurationValue().ConvertToJson(descriptor_pool(), 
message_factory(), message), IsOk()); EXPECT_THAT(*message, EqualsValueTextProto(R"pb(string_value: "0s")pb")); } TEST_F(DurationValueTest, NativeTypeId) { EXPECT_EQ(NativeTypeId::Of(DurationValue(absl::Seconds(1))), NativeTypeId::For()); EXPECT_EQ(NativeTypeId::Of(Value(DurationValue(absl::Seconds(1)))), NativeTypeId::For()); } TEST_F(DurationValueTest, Equality) { EXPECT_NE(DurationValue(absl::ZeroDuration()), absl::Seconds(1)); EXPECT_NE(absl::Seconds(1), DurationValue(absl::ZeroDuration())); EXPECT_NE(DurationValue(absl::ZeroDuration()), DurationValue(absl::Seconds(1))); } TEST_F(DurationValueTest, Comparison) { EXPECT_LT(DurationValue(absl::ZeroDuration()), absl::Seconds(1)); EXPECT_FALSE(DurationValue(absl::Seconds(1)) < DurationValue(absl::Seconds(1))); EXPECT_FALSE(DurationValue(absl::Seconds(2)) < DurationValue(absl::Seconds(1))); } } // namespace } // namespace cel ================================================ FILE: common/values/enum_value.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_ENUM_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_ENUM_VALUE_H_ #include #include "google/protobuf/struct.pb.h" #include "absl/meta/type_traits.h" #include "google/protobuf/generated_enum_util.h" namespace cel::common_internal { template > inline constexpr bool kIsWellKnownEnumType = std::is_same::value; template > inline constexpr bool kIsGeneratedEnum = google::protobuf::is_proto_enum::value; template using EnableIfWellKnownEnum = std::enable_if_t< kIsWellKnownEnumType && std::is_same, U>::value, R>; template using EnableIfGeneratedEnum = std::enable_if_t< absl::conjunction< std::bool_constant>, absl::negation>>>::value, R>; } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_ENUM_VALUE_H_ ================================================ FILE: common/values/error_value.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include #include #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "common/type.h" #include "common/value.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { namespace { std::string ErrorDebugString(const absl::Status& value) { ABSL_DCHECK(!value.ok()) << "use of moved-from ErrorValue"; return value.ToString(absl::StatusToStringMode::kWithEverything); } const absl::Status& DefaultErrorValue() { static const absl::NoDestructor value( absl::UnknownError("unknown error")); return *value; } } // namespace ErrorValue::ErrorValue() : ErrorValue(DefaultErrorValue()) {} ErrorValue NoSuchFieldError(absl::string_view field) { return ErrorValue(absl::NotFoundError( absl::StrCat("no_such_field", field.empty() ? 
"" : " : ", field))); } ErrorValue NoSuchKeyError(absl::string_view key) { return ErrorValue( absl::NotFoundError(absl::StrCat("Key not found in map : ", key))); } ErrorValue NoSuchTypeError(absl::string_view type) { return ErrorValue( absl::NotFoundError(absl::StrCat("type not found: ", type))); } ErrorValue DuplicateKeyError() { return ErrorValue(absl::AlreadyExistsError("duplicate key in map")); } ErrorValue TypeConversionError(absl::string_view from, absl::string_view to) { return ErrorValue(absl::InvalidArgumentError( absl::StrCat("type conversion error from '", from, "' to '", to, "'"))); } ErrorValue TypeConversionError(const Type& from, const Type& to) { return TypeConversionError(from.DebugString(), to.DebugString()); } ErrorValue IndexOutOfBoundsError(size_t index) { return ErrorValue( absl::InvalidArgumentError(absl::StrCat("index out of bounds: ", index))); } ErrorValue IndexOutOfBoundsError(ptrdiff_t index) { return ErrorValue( absl::InvalidArgumentError(absl::StrCat("index out of bounds: ", index))); } bool IsNoSuchField(const ErrorValue& value) { return absl::IsNotFound(value.NativeValue()) && absl::StartsWith(value.NativeValue().message(), "no_such_field"); } bool IsNoSuchKey(const ErrorValue& value) { return absl::IsNotFound(value.NativeValue()) && absl::StartsWith(value.NativeValue().message(), "Key not found in map"); } std::string ErrorValue::DebugString() const { return ErrorDebugString(NativeValue()); } absl::Status ErrorValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); ABSL_DCHECK(*this); return absl::FailedPreconditionError( absl::StrCat(GetTypeName(), " is unserializable")); } absl::Status ErrorValue::ConvertToJson( const google::protobuf::DescriptorPool* 
absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); ABSL_DCHECK(*this); return absl::FailedPreconditionError( absl::StrCat(GetTypeName(), " is not convertable to JSON")); } absl::Status ErrorValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK(*this); *result = FalseValue(); return absl::OkStatus(); } ErrorValue ErrorValue::Clone(google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(*this); if (arena_ == nullptr || arena_ != arena) { return ErrorValue(arena, google::protobuf::Arena::Create(arena, ToStatus())); } return *this; } absl::Status ErrorValue::ToStatus() const& { ABSL_DCHECK(*this); if (arena_ == nullptr) { return *std::launder( reinterpret_cast(&status_.val[0])); } return *status_.ptr; } absl::Status ErrorValue::ToStatus() && { ABSL_DCHECK(*this); if (arena_ == nullptr) { return std::move( *std::launder(reinterpret_cast(&status_.val[0]))); } return *status_.ptr; } ErrorValue::operator bool() const { if (arena_ == nullptr) { return !std::launder(reinterpret_cast(&status_.val[0])) ->ok(); } return status_.ptr != nullptr && !status_.ptr->ok(); } void swap(ErrorValue& lhs, ErrorValue& rhs) noexcept { ErrorValue tmp(std::move(lhs)); lhs = std::move(rhs); rhs = std::move(tmp); } } // namespace cel ================================================ FILE: 
common/values/error_value.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_ERROR_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_ERROR_VALUE_H_ #include #include #include #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "common/arena.h" #include "common/type.h" #include "common/value_kind.h" #include "common/values/values.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { class Value; // `ErrorValue` represents values of the `ErrorType`. class ABSL_ATTRIBUTE_TRIVIAL_ABI ErrorValue final : private common_internal::ValueMixin { public: static constexpr ValueKind kKind = ValueKind::kError; explicit ErrorValue(absl::Status value) : arena_(nullptr) { ::new (static_cast(&status_.val[0])) absl::Status(std::move(value)); ABSL_DCHECK(*this) << "ErrorValue requires a non-OK absl::Status"; } // By default, this creates an UNKNOWN error. You should always create a more // specific error value. 
ErrorValue(); ErrorValue(const ErrorValue& other) { CopyConstruct(other); } ErrorValue(ErrorValue&& other) noexcept { MoveConstruct(other); } ~ErrorValue() { Destruct(); } ErrorValue& operator=(const ErrorValue& other) { if (this != &other) { Destruct(); CopyConstruct(other); } return *this; } ErrorValue& operator=(ErrorValue&& other) noexcept { if (this != &other) { Destruct(); MoveConstruct(other); } return *this; } static constexpr ValueKind kind() { return kKind; } static absl::string_view GetTypeName() { return ErrorType::kName; } std::string DebugString() const; // See Value::SerializeTo(). absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; // See Value::ConvertToJson(). absl::Status ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; absl::Status Equal(const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using ValueMixin::Equal; bool IsZeroValue() const { return false; } ErrorValue Clone(google::protobuf::Arena* absl_nonnull arena) const; absl::Status ToStatus() const&; absl::Status ToStatus() &&; ABSL_DEPRECATED("Use ToStatus()") absl::Status NativeValue() const& { return ToStatus(); } ABSL_DEPRECATED("Use ToStatus()") absl::Status NativeValue() && { return std::move(*this).ToStatus(); } friend void swap(ErrorValue& lhs, ErrorValue& rhs) noexcept; explicit operator bool() const; private: friend class common_internal::ValueMixin; friend struct ArenaTraits; ErrorValue(google::protobuf::Arena* absl_nonnull arena, const absl::Status* absl_nonnull status) : 
arena_(arena) { status_.ptr = status; } void CopyConstruct(const ErrorValue& other) { arena_ = other.arena_; if (arena_ == nullptr) { ::new (static_cast(&status_.val[0])) absl::Status(*std::launder( reinterpret_cast(&other.status_.val[0]))); } else { status_.ptr = other.status_.ptr; } } void MoveConstruct(ErrorValue& other) { arena_ = other.arena_; if (arena_ == nullptr) { ::new (static_cast(&status_.val[0])) absl::Status(std::move(*std::launder( reinterpret_cast(&other.status_.val[0])))); } else { status_.ptr = other.status_.ptr; } } void Destruct() { if (arena_ == nullptr) { std::launder(reinterpret_cast(&status_.val[0]))->~Status(); } } google::protobuf::Arena* absl_nullable arena_; union { alignas(absl::Status) char val[sizeof(absl::Status)]; const absl::Status* absl_nonnull ptr; } status_; }; ErrorValue NoSuchFieldError(absl::string_view field); ErrorValue NoSuchKeyError(absl::string_view key); ErrorValue NoSuchTypeError(absl::string_view type); ErrorValue DuplicateKeyError(); ErrorValue TypeConversionError(absl::string_view from, absl::string_view to); ErrorValue TypeConversionError(const Type& from, const Type& to); ErrorValue IndexOutOfBoundsError(size_t index); ErrorValue IndexOutOfBoundsError(ptrdiff_t index); // Catch other integrals and forward them to the above ones. This is needed to // avoid ambiguous overload issues for smaller integral types like `int`. 
template std::enable_if_t, std::is_unsigned, std::negation>>, ErrorValue> IndexOutOfBoundsError(T index) { static_assert(sizeof(T) <= sizeof(size_t)); return IndexOutOfBoundsError(static_cast(index)); } template std::enable_if_t, std::is_signed, std::negation>>, ErrorValue> IndexOutOfBoundsError(T index) { static_assert(sizeof(T) <= sizeof(ptrdiff_t)); return IndexOutOfBoundsError(static_cast(index)); } inline std::ostream& operator<<(std::ostream& out, const ErrorValue& value) { return out << value.DebugString(); } bool IsNoSuchField(const ErrorValue& value); bool IsNoSuchKey(const ErrorValue& value); class ErrorValueReturn final { public: ErrorValueReturn() = default; ErrorValue operator()(absl::Status status) const { return ErrorValue(std::move(status)); } }; namespace common_internal { struct ImplicitlyConvertibleStatus { // NOLINTNEXTLINE(google-explicit-constructor) operator absl::Status() const { return absl::OkStatus(); } template // NOLINTNEXTLINE(google-explicit-constructor) operator absl::StatusOr() const { return T(); } }; } // namespace common_internal // For use with `RETURN_IF_ERROR(...).With(cel::ErrorValueAssign(&result))` and // `ASSIGN_OR_RETURN(..., ..., _.With(cel::ErrorValueAssign(&result)))`. // // IMPORTANT: // If the returning type is `absl::Status` the result will be // `absl::OkStatus()`. If the returning type is `absl::StatusOr` the result // will be `T()`. 
class ErrorValueAssign final { public: ErrorValueAssign() = delete; explicit ErrorValueAssign(Value& value ABSL_ATTRIBUTE_LIFETIME_BOUND) : ErrorValueAssign(std::addressof(value)) {} explicit ErrorValueAssign( Value* absl_nonnull value ABSL_ATTRIBUTE_LIFETIME_BOUND) : value_(value) { ABSL_DCHECK(value != nullptr); } common_internal::ImplicitlyConvertibleStatus operator()( absl::Status status) const; private: Value* absl_nonnull value_; }; template <> struct ArenaTraits { static bool trivially_destructible(const ErrorValue& value) { return value.arena_ != nullptr; } }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_ERROR_VALUE_H_ ================================================ FILE: common/values/error_value_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "absl/status/status.h" #include "common/native_type.h" #include "common/value.h" #include "common/value_testing.h" #include "internal/testing.h" #include "google/protobuf/io/zero_copy_stream_impl_lite.h" namespace cel { namespace { using ::absl_testing::StatusIs; using ::testing::_; using ::testing::IsEmpty; using ::testing::Not; using ErrorValueTest = common_internal::ValueTest<>; TEST_F(ErrorValueTest, Default) { ErrorValue value; EXPECT_THAT(value.NativeValue(), StatusIs(absl::StatusCode::kUnknown)); } TEST_F(ErrorValueTest, OkStatus) { EXPECT_DEBUG_DEATH(static_cast(ErrorValue(absl::OkStatus())), _); } TEST_F(ErrorValueTest, Kind) { EXPECT_EQ(ErrorValue(absl::CancelledError()).kind(), ErrorValue::kKind); EXPECT_EQ(Value(ErrorValue(absl::CancelledError())).kind(), ErrorValue::kKind); } TEST_F(ErrorValueTest, DebugString) { { std::ostringstream out; out << ErrorValue(absl::CancelledError()); EXPECT_THAT(out.str(), Not(IsEmpty())); } { std::ostringstream out; out << Value(ErrorValue(absl::CancelledError())); EXPECT_THAT(out.str(), Not(IsEmpty())); } } TEST_F(ErrorValueTest, SerializeTo) { google::protobuf::io::CordOutputStream output; EXPECT_THAT( ErrorValue().SerializeTo(descriptor_pool(), message_factory(), &output), StatusIs(absl::StatusCode::kFailedPrecondition)); } TEST_F(ErrorValueTest, ConvertToJson) { auto* message = NewArenaValueMessage(); EXPECT_THAT( ErrorValue().ConvertToJson(descriptor_pool(), message_factory(), message), StatusIs(absl::StatusCode::kFailedPrecondition)); } TEST_F(ErrorValueTest, NativeTypeId) { EXPECT_EQ(NativeTypeId::Of(ErrorValue(absl::CancelledError())), NativeTypeId::For()); EXPECT_EQ(NativeTypeId::Of(Value(ErrorValue(absl::CancelledError()))), NativeTypeId::For()); } } // namespace } // namespace cel ================================================ FILE: common/values/int_value.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 
(the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include "google/protobuf/wrappers.pb.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "common/value.h" #include "internal/number.h" #include "internal/status_macros.h" #include "internal/well_known_types.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { namespace { using ::cel::well_known_types::ValueReflection; std::string IntDebugString(int64_t value) { return absl::StrCat(value); } } // namespace std::string IntValue::DebugString() const { return IntDebugString(NativeValue()); } absl::Status IntValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); google::protobuf::Int64Value message; message.set_value(NativeValue()); if (!message.SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( absl::StrCat("failed to serialize message: ", message.GetTypeName())); } return absl::OkStatus(); } absl::Status IntValue::ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* 
absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); ValueReflection value_reflection; CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); value_reflection.SetNumberValue(json, NativeValue()); return absl::OkStatus(); } absl::Status IntValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (auto other_value = other.AsInt(); other_value.has_value()) { *result = BoolValue{NativeValue() == other_value->NativeValue()}; return absl::OkStatus(); } if (auto other_value = other.AsDouble(); other_value.has_value()) { *result = BoolValue{internal::Number::FromInt64(NativeValue()) == internal::Number::FromDouble(other_value->NativeValue())}; return absl::OkStatus(); } if (auto other_value = other.AsUint(); other_value.has_value()) { *result = BoolValue{internal::Number::FromInt64(NativeValue()) == internal::Number::FromUint64(other_value->NativeValue())}; return absl::OkStatus(); } *result = FalseValue(); return absl::OkStatus(); } } // namespace cel ================================================ FILE: common/values/int_value.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_INT_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_INT_VALUE_H_ #include #include #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "common/type.h" #include "common/value_kind.h" #include "common/values/values.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { class Value; class IntValue; // `IntValue` represents values of the primitive `int` type. class IntValue final : private common_internal::ValueMixin { public: static constexpr ValueKind kKind = ValueKind::kInt; explicit IntValue(int64_t value) noexcept : value_(value) {} IntValue() = default; IntValue(const IntValue&) = default; IntValue(IntValue&&) = default; IntValue& operator=(const IntValue&) = default; IntValue& operator=(IntValue&&) = default; ValueKind kind() const { return kKind; } absl::string_view GetTypeName() const { return IntType::kName; } std::string DebugString() const; // See Value::SerializeTo(). absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; // See Value::ConvertToJson(). 
absl::Status ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; absl::Status Equal(const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using ValueMixin::Equal; bool IsZeroValue() const { return NativeValue() == 0; } int64_t NativeValue() const { return static_cast(*this); } // NOLINTNEXTLINE(google-explicit-constructor) operator int64_t() const noexcept { return value_; } friend void swap(IntValue& lhs, IntValue& rhs) noexcept { using std::swap; swap(lhs.value_, rhs.value_); } private: friend class common_internal::ValueMixin; int64_t value_ = 0; }; template H AbslHashValue(H state, IntValue value) { return H::combine(std::move(state), value.NativeValue()); } inline bool operator==(IntValue lhs, IntValue rhs) { return lhs.NativeValue() == rhs.NativeValue(); } inline bool operator!=(IntValue lhs, IntValue rhs) { return !operator==(lhs, rhs); } inline std::ostream& operator<<(std::ostream& out, IntValue value) { return out << value.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_INT_VALUE_H_ ================================================ FILE: common/values/int_value_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include #include #include "absl/hash/hash.h" #include "absl/status/status_matchers.h" #include "common/native_type.h" #include "common/value.h" #include "common/value_testing.h" #include "internal/testing.h" namespace cel { namespace { using ::absl_testing::IsOk; using IntValueTest = common_internal::ValueTest<>; TEST_F(IntValueTest, Kind) { EXPECT_EQ(IntValue(1).kind(), IntValue::kKind); EXPECT_EQ(Value(IntValue(1)).kind(), IntValue::kKind); } TEST_F(IntValueTest, DebugString) { { std::ostringstream out; out << IntValue(1); EXPECT_EQ(out.str(), "1"); } { std::ostringstream out; out << Value(IntValue(1)); EXPECT_EQ(out.str(), "1"); } } TEST_F(IntValueTest, ConvertToJson) { auto* message = NewArenaValueMessage(); EXPECT_THAT( IntValue(1).ConvertToJson(descriptor_pool(), message_factory(), message), IsOk()); EXPECT_THAT(*message, EqualsValueTextProto(R"pb(number_value: 1)pb")); } TEST_F(IntValueTest, NativeTypeId) { EXPECT_EQ(NativeTypeId::Of(IntValue(1)), NativeTypeId::For()); EXPECT_EQ(NativeTypeId::Of(Value(IntValue(1))), NativeTypeId::For()); } TEST_F(IntValueTest, HashValue) { EXPECT_EQ(absl::HashOf(IntValue(1)), absl::HashOf(int64_t{1})); } TEST_F(IntValueTest, Equality) { EXPECT_NE(IntValue(0), 1); EXPECT_NE(1, IntValue(0)); EXPECT_NE(IntValue(0), IntValue(1)); } TEST_F(IntValueTest, LessThan) { EXPECT_LT(IntValue(0), 1); EXPECT_LT(0, IntValue(1)); EXPECT_LT(IntValue(0), IntValue(1)); } } // namespace } // namespace cel ================================================ FILE: common/values/legacy_list_value.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "common/values/legacy_list_value.h"

#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/types/optional.h"
#include "common/native_type.h"
#include "common/value.h"
#include "common/values/list_value_builder.h"
#include "common/values/values.h"
#include "eval/public/cel_value.h"
#include "internal/casts.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::common_internal {

// Compares this list with `other` and stores the outcome in `*result`.
// If `other.AsList()` yields a list, the comparison is delegated to
// `ListValueEqual`. Otherwise `*result` is set to `FalseValue()` and an OK
// status is returned.
absl::Status LegacyListValue::Equal(
    const Value& other,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  if (auto list_value = other.AsList(); list_value.has_value()) {
    return ListValueEqual(*this, *list_value, descriptor_pool, message_factory,
                          arena, result);
  }
  *result = FalseValue();
  return absl::OkStatus();
}

// Reports whether `value`'s variant holds the legacy list alternative.
// NOTE(review): the template argument of `Is` is elided in this copy
// (presumably `LegacyListValue`).
bool IsLegacyListValue(const Value& value) {
  return value.variant_.Is();
}

// Extracts the legacy list alternative from `value`. Callers must ensure
// `IsLegacyListValue(value)` holds; this is only checked in debug builds.
LegacyListValue GetLegacyListValue(const Value& value) {
  ABSL_DCHECK(IsLegacyListValue(value));
  return value.variant_.Get();
}

// Returns `value` viewed as a `LegacyListValue`, or `absl::nullopt`.
//
// Two sources are accepted:
//   1. `value` directly holds the legacy list alternative.
//   2. `value` is a custom list whose `GetTypeId()` matches one of two
//      recognized implementations. Its `interface()` is down-cast and
//      re-wrapped as a `LegacyListValue`.
// NOTE(review): the template arguments that distinguish the two custom-list
// branches (`NativeTypeId::For`, `static_cast`, `down_cast`) are elided in
// this copy, so the branches look identical here. They presumably name two
// different list implementations -- confirm against upstream.
absl::optional AsLegacyListValue(const Value& value) {
  if (IsLegacyListValue(value)) {
    return GetLegacyListValue(value);
  }
  if (auto custom_list_value = value.AsCustomList(); custom_list_value) {
    NativeTypeId native_type_id = custom_list_value->GetTypeId();
    if (native_type_id == NativeTypeId::For()) {
      return LegacyListValue(
          static_cast(
              cel::internal::down_cast(
                  custom_list_value->interface())));
    } else if (native_type_id == NativeTypeId::For()) {
      return LegacyListValue(
          static_cast(
              cel::internal::down_cast(
                  custom_list_value->interface())));
    }
  }
  return absl::nullopt;
}

}  // namespace cel::common_internal

================================================
FILE: common/values/legacy_list_value.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private, include "common/values/list_value.h"
// IWYU pragma: friend "common/values/list_value.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_LIST_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_LIST_VALUE_H_

// NOTE(review): the targets of the three bare `#include` directives below are
// missing in this copy (angle-bracket text appears stripped). Given the uses
// of size_t, std::string, std::ostream and std::swap below they were
// presumably standard headers -- confirm against upstream.
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/value_kind.h"
#include "common/values/custom_list_value.h"
#include "common/values/values.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {
class CelList;
}

namespace cel {

class Value;

namespace common_internal {

class LegacyListValue;

// `LegacyListValue` holds a (possibly null) pointer to a legacy
// `google::api::expr::runtime::CelList` and exposes list operations on it.
// NOTE(review): the template argument of the `ListValueMixin` base is elided
// in this copy.
class LegacyListValue final : private common_internal::ListValueMixin {
 public:
  static constexpr ValueKind kKind = ValueKind::kList;

  // Stores `impl` as a raw pointer. Nothing in this class deletes it, so the
  // caller presumably keeps the `CelList` alive -- confirm.
  explicit LegacyListValue(
      const google::api::expr::runtime::CelList* absl_nullability_unknown impl)
      : impl_(impl) {}

  // By default, this creates an empty list whose type is `list(dyn)`. Unless
  // you can help it, you should use a more specific typed list value.
  LegacyListValue() = default;
  LegacyListValue(const LegacyListValue&) = default;
  LegacyListValue(LegacyListValue&&) = default;
  LegacyListValue& operator=(const LegacyListValue&) = default;
  LegacyListValue& operator=(LegacyListValue&&) = default;

  // Always `ValueKind::kList`.
  constexpr ValueKind kind() const { return kKind; }

  // Always "list".
  absl::string_view GetTypeName() const { return "list"; }

  std::string DebugString() const;

  // See Value::SerializeTo().
  absl::Status SerializeTo(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const;

  // See Value::ConvertToJson().
  absl::Status ConvertToJson(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  // See Value::ConvertToJsonArray().
  absl::Status ConvertToJsonArray(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  absl::Status Equal(const Value& other,
                     const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                     google::protobuf::MessageFactory* absl_nonnull message_factory,
                     google::protobuf::Arena* absl_nonnull arena,
                     Value* absl_nonnull result) const;
  // Re-exposes the mixin's `Equal` overload(s), which the declaration above
  // would otherwise hide. The same applies to the other `using` lines below.
  using ListValueMixin::Equal;

  // The zero value of a list is the empty list.
  bool IsZeroValue() const { return IsEmpty(); }

  bool IsEmpty() const;

  size_t Size() const;

  // See ListValueInterface::Get for documentation.
  absl::Status Get(size_t index,
                   const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                   google::protobuf::MessageFactory* absl_nonnull message_factory,
                   google::protobuf::Arena* absl_nonnull arena,
                   Value* absl_nonnull result) const;
  using ListValueMixin::Get;

  using ForEachCallback = typename CustomListValueInterface::ForEachCallback;

  using ForEachWithIndexCallback =
      typename CustomListValueInterface::ForEachWithIndexCallback;

  absl::Status ForEach(
      ForEachWithIndexCallback callback,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;
  using ListValueMixin::ForEach;

  // NOTE(review): the `StatusOr` template argument is elided in this copy.
  absl::StatusOr NewIterator() const;

  absl::Status Contains(
      const Value& other,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const;
  using ListValueMixin::Contains;

  // Returns the wrapped legacy list. It may be null: `impl_` defaults to
  // `nullptr` for a default-constructed value.
  const google::api::expr::runtime::CelList* absl_nullability_unknown
  cel_list() const {
    return impl_;
  }

  // Exchanges the wrapped pointers.
  friend void swap(LegacyListValue& lhs, LegacyListValue& rhs) noexcept {
    using std::swap;
    swap(lhs.impl_, rhs.impl_);
  }

 private:
  friend class common_internal::ValueMixin;
  friend class common_internal::ListValueMixin;

  const google::api::expr::runtime::CelList* absl_nullability_unknown impl_ =
      nullptr;
};

// Streams `DebugString()`.
inline std::ostream& operator<<(std::ostream& out,
                                const LegacyListValue& type) {
  return out << type.DebugString();
}

bool IsLegacyListValue(const Value& value);

LegacyListValue GetLegacyListValue(const Value& value);

absl::optional AsLegacyListValue(const Value& value);

}  // namespace common_internal

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_LIST_VALUE_H_

================================================
FILE: common/values/legacy_map_value.cc
================================================
// Copyright 2023 Google
LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/values/legacy_map_value.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/types/optional.h" #include "common/native_type.h" #include "common/value.h" #include "common/values/map_value_builder.h" #include "common/values/values.h" #include "eval/public/cel_value.h" #include "internal/casts.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel::common_internal { absl::Status LegacyMapValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { if (auto map_value = other.AsMap(); map_value.has_value()) { return MapValueEqual(*this, *map_value, descriptor_pool, message_factory, arena, result); } *result = FalseValue(); return absl::OkStatus(); } bool IsLegacyMapValue(const Value& value) { return value.variant_.Is(); } LegacyMapValue GetLegacyMapValue(const Value& value) { ABSL_DCHECK(IsLegacyMapValue(value)); return value.variant_.Get(); } absl::optional AsLegacyMapValue(const Value& value) { if (IsLegacyMapValue(value)) { return GetLegacyMapValue(value); } if (auto custom_map_value = value.AsCustomMap(); custom_map_value) { NativeTypeId native_type_id = 
NativeTypeId::Of(*custom_map_value); if (native_type_id == NativeTypeId::For()) { return LegacyMapValue( static_cast( cel::internal::down_cast( custom_map_value->interface()))); } else if (native_type_id == NativeTypeId::For()) { return LegacyMapValue( static_cast( cel::internal::down_cast( custom_map_value->interface()))); } } return absl::nullopt; } } // namespace cel::common_internal ================================================ FILE: common/values/legacy_map_value.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/values/map_value.h" // IWYU pragma: friend "common/values/map_value.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_MAP_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_MAP_VALUE_H_ #include #include #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/value_kind.h" #include "common/values/custom_map_value.h" #include "common/values/values.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace google::api::expr::runtime { class CelMap; } namespace cel { class Value; namespace common_internal { class LegacyMapValue; class LegacyMapValue final : private common_internal::MapValueMixin { public: static constexpr ValueKind kKind = ValueKind::kMap; explicit LegacyMapValue( const google::api::expr::runtime::CelMap* absl_nullability_unknown impl) : impl_(impl) {} // By default, this creates an empty map whose type is `map(dyn, dyn)`. // Unless you can help it, you should use a more specific typed map value. LegacyMapValue() = default; LegacyMapValue(const LegacyMapValue&) = default; LegacyMapValue(LegacyMapValue&&) = default; LegacyMapValue& operator=(const LegacyMapValue&) = default; LegacyMapValue& operator=(LegacyMapValue&&) = default; constexpr ValueKind kind() const { return kKind; } absl::string_view GetTypeName() const { return "map"; } std::string DebugString() const; // See Value::SerializeTo(). absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; // See Value::ConvertToJson(). 
absl::Status ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; // See Value::ConvertToJsonObject(). absl::Status ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; absl::Status Equal(const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Equal; bool IsZeroValue() const { return IsEmpty(); } bool IsEmpty() const; size_t Size() const; // See the corresponding member function of `MapValue` for // documentation. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Get; // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr Find( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; // See the corresponding member function of `MapValue` for // documentation. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Has; // See the corresponding member function of `MapValue` for // documentation. 
absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; using MapValueMixin::ListKeys; // See the corresponding type declaration of `MapValue` for // documentation. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; // See the corresponding member function of `MapValue` for // documentation. absl::Status ForEach( ForEachCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; absl::StatusOr NewIterator() const; const google::api::expr::runtime::CelMap* absl_nonnull cel_map() const { return impl_; } friend void swap(LegacyMapValue& lhs, LegacyMapValue& rhs) noexcept { using std::swap; swap(lhs.impl_, rhs.impl_); } private: friend class common_internal::ValueMixin; friend class common_internal::MapValueMixin; const google::api::expr::runtime::CelMap* absl_nullability_unknown impl_ = nullptr; }; inline std::ostream& operator<<(std::ostream& out, const LegacyMapValue& type) { return out << type.DebugString(); } bool IsLegacyMapValue(const Value& value); LegacyMapValue GetLegacyMapValue(const Value& value); absl::optional AsLegacyMapValue(const Value& value); } // namespace common_internal } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_MAP_VALUE_H_ ================================================ FILE: common/values/legacy_struct_value.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "absl/log/absl_check.h" #include "absl/types/optional.h" #include "common/type.h" #include "common/value.h" #include "google/protobuf/message.h" namespace cel::common_internal { StructType LegacyStructValue::GetRuntimeType() const { return MessageType(message_ptr_->GetDescriptor()); } bool IsLegacyStructValue(const Value& value) { return value.variant_.Is(); } LegacyStructValue GetLegacyStructValue(const Value& value) { ABSL_DCHECK(IsLegacyStructValue(value)); return value.variant_.Get(); } absl::optional AsLegacyStructValue(const Value& value) { if (IsLegacyStructValue(value)) { return GetLegacyStructValue(value); } return absl::nullopt; } } // namespace cel::common_internal ================================================ FILE: common/values/legacy_struct_value.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/value.h"
// IWYU pragma: friend "common/value.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_STRUCT_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_STRUCT_VALUE_H_

// NOTE(review): the targets of the four bare `#include` directives below are
// missing in this copy (angle-bracket text appears stripped). Given the uses
// of int64_t, std::string, std::ostream and std::swap below they were
// presumably standard headers -- confirm against upstream.
#include
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "base/attribute.h"
#include "common/type.h"
#include "common/value_kind.h"
#include "common/values/custom_struct_value.h"
#include "common/values/values.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {
class LegacyTypeInfoApis;
}

namespace cel {

class Value;

namespace common_internal {

class LegacyStructValue;

// `LegacyStructValue` is a wrapper around the old representation of protocol
// buffer messages in `google::api::expr::runtime::CelValue`. It only supports
// arena allocation.
//
// It holds two raw pointers: the message and its legacy type-info adapter.
// Both default to `nullptr`. NOTE(review): the template argument of the
// `StructValueMixin` base is elided in this copy.
class LegacyStructValue final
    : private common_internal::StructValueMixin {
 public:
  static constexpr ValueKind kKind = ValueKind::kStruct;

  // Leaves both pointers null.
  LegacyStructValue() = default;

  // Stores both pointers as-is. Nothing in this class deletes them, so the
  // caller (presumably an arena) keeps them alive -- confirm.
  LegacyStructValue(
      const google::protobuf::Message* absl_nullability_unknown message_ptr,
      const google::api::expr::runtime::
          LegacyTypeInfoApis* absl_nullability_unknown legacy_type_info)
      : message_ptr_(message_ptr), legacy_type_info_(legacy_type_info) {}

  // No move operations are declared, so moves fall back to these copies,
  // which only copy the two pointers.
  LegacyStructValue(const LegacyStructValue&) = default;
  LegacyStructValue& operator=(const LegacyStructValue&) = default;

  // Always `ValueKind::kStruct`.
  constexpr ValueKind kind() const { return kKind; }

  // NOTE(review): the out-of-line definition dereferences `message_ptr_`, so
  // this presumably requires a value constructed with a non-null message.
  StructType GetRuntimeType() const;

  absl::string_view GetTypeName() const;

  std::string DebugString() const;

  // See Value::SerializeTo().
  absl::Status SerializeTo(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const;

  // See Value::ConvertToJson().
  absl::Status ConvertToJson(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  // See Value::ConvertToJsonObject().
  absl::Status ConvertToJsonObject(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  absl::Status Equal(const Value& other,
                     const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                     google::protobuf::MessageFactory* absl_nonnull message_factory,
                     google::protobuf::Arena* absl_nonnull arena,
                     Value* absl_nonnull result) const;
  // Re-exposes the mixin's `Equal` overload(s), which the declaration above
  // would otherwise hide. The same applies to the other `using` lines below.
  using StructValueMixin::Equal;

  bool IsZeroValue() const;

  // Looks a field up by name and writes it to `*result`. `unboxing_options`
  // controls how protobuf wrapper-type fields are surfaced.
  absl::Status GetFieldByName(
      absl::string_view name, ProtoWrapperTypeOptions unboxing_options,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const;
  using StructValueMixin::GetFieldByName;

  // Same as `GetFieldByName`, keyed by field number.
  absl::Status GetFieldByNumber(
      int64_t number, ProtoWrapperTypeOptions unboxing_options,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const;
  using StructValueMixin::GetFieldByNumber;

  // NOTE(review): the `StatusOr` template arguments are elided in this copy.
  absl::StatusOr HasFieldByName(absl::string_view name) const;

  absl::StatusOr HasFieldByNumber(int64_t number) const;

  using ForEachFieldCallback = CustomStructValueInterface::ForEachFieldCallback;

  absl::Status ForEachField(
      ForEachFieldCallback callback,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  // Applies a sequence of qualifiers, writing the selected value to `*result`
  // and a count to `*count`. NOTE(review): the `Span` element type is elided
  // in this copy, and the exact meaning of `*count` is defined out of line.
  absl::Status Qualify(
      absl::Span qualifiers, bool presence_test,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result,
      int* absl_nonnull count) const;
  using StructValueMixin::Qualify;

  // Returns the wrapped message; null for a default-constructed value.
  const google::protobuf::Message* absl_nullability_unknown message_ptr() const {
    return message_ptr_;
  }

  // Returns the legacy type-info adapter; null for a default-constructed
  // value.
  const google::api::expr::runtime::LegacyTypeInfoApis* absl_nullability_unknown
  legacy_type_info() const {
    return legacy_type_info_;
  }

  // Exchanges both pointers.
  friend void swap(LegacyStructValue& lhs, LegacyStructValue& rhs) noexcept {
    using std::swap;
    swap(lhs.message_ptr_, rhs.message_ptr_);
    swap(lhs.legacy_type_info_, rhs.legacy_type_info_);
  }

 private:
  friend class common_internal::ValueMixin;
  friend class common_internal::StructValueMixin;

  const google::protobuf::Message* absl_nullability_unknown message_ptr_ = nullptr;
  const google::api::expr::runtime::
      LegacyTypeInfoApis* absl_nullability_unknown legacy_type_info_ = nullptr;
};

// Streams `DebugString()`.
inline std::ostream& operator<<(std::ostream& out,
                                const LegacyStructValue& value) {
  return out << value.DebugString();
}

bool IsLegacyStructValue(const Value& value);

LegacyStructValue GetLegacyStructValue(const Value& value);

absl::optional AsLegacyStructValue(const Value& value);

}  // namespace common_internal

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_STRUCT_VALUE_H_

================================================
FILE: common/values/list_value.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "common/native_type.h" #include "common/optional_ref.h" #include "common/value.h" #include "common/values/value_variant.h" #include "internal/status_macros.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { NativeTypeId ListValue::GetTypeId() const { return variant_.Visit([](const auto& alternative) -> NativeTypeId { return NativeTypeId::Of(alternative); }); } std::string ListValue::DebugString() const { return variant_.Visit([](const auto& alternative) -> std::string { return alternative.DebugString(); }); } absl::Status ListValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); return variant_.Visit([&](const auto& alternative) -> absl::Status { return alternative.SerializeTo(descriptor_pool, message_factory, output); }); } absl::Status ListValue::ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { 
ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); return variant_.Visit([&](const auto& alternative) -> absl::Status { return alternative.ConvertToJson(descriptor_pool, message_factory, json); }); } absl::Status ListValue::ConvertToJsonArray( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); return variant_.Visit([&](const auto& alternative) -> absl::Status { return alternative.ConvertToJsonArray(descriptor_pool, message_factory, json); }); } absl::Status ListValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); return variant_.Visit([&](const auto& alternative) -> absl::Status { return alternative.Equal(other, descriptor_pool, message_factory, arena, result); }); } bool ListValue::IsZeroValue() const { return variant_.Visit([](const auto& alternative) -> bool { return alternative.IsZeroValue(); }); } absl::StatusOr ListValue::IsEmpty() const { return variant_.Visit([](const auto& alternative) -> absl::StatusOr { return alternative.IsEmpty(); }); } absl::StatusOr ListValue::Size() const { return variant_.Visit([](const auto& alternative) -> absl::StatusOr { return alternative.Size(); }); } absl::Status 
ListValue::Get( size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); return variant_.Visit([&](const auto& alternative) -> absl::Status { return alternative.Get(index, descriptor_pool, message_factory, arena, result); }); } absl::Status ListValue::ForEach( ForEachWithIndexCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); return variant_.Visit([&](const auto& alternative) -> absl::Status { return alternative.ForEach(callback, descriptor_pool, message_factory, arena); }); } absl::StatusOr ListValue::NewIterator() const { return variant_.Visit([](const auto& alternative) -> absl::StatusOr { return alternative.NewIterator(); }); } absl::Status ListValue::Contains( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); return variant_.Visit([&](const auto& alternative) -> absl::Status { return alternative.Contains(other, descriptor_pool, message_factory, arena, result); }); } namespace common_internal { absl::Status ListValueEqual( const ListValue& lhs, const ListValue& rhs, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, 
google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); CEL_ASSIGN_OR_RETURN(auto lhs_size, lhs.Size()); CEL_ASSIGN_OR_RETURN(auto rhs_size, rhs.Size()); if (lhs_size != rhs_size) { *result = FalseValue(); return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN(auto lhs_iterator, lhs.NewIterator()); CEL_ASSIGN_OR_RETURN(auto rhs_iterator, rhs.NewIterator()); Value lhs_element; Value rhs_element; for (size_t index = 0; index < lhs_size; ++index) { ABSL_CHECK(lhs_iterator->HasNext()); // Crash OK ABSL_CHECK(rhs_iterator->HasNext()); // Crash OK CEL_RETURN_IF_ERROR(lhs_iterator->Next(descriptor_pool, message_factory, arena, &lhs_element)); CEL_RETURN_IF_ERROR(rhs_iterator->Next(descriptor_pool, message_factory, arena, &rhs_element)); CEL_RETURN_IF_ERROR(lhs_element.Equal(rhs_element, descriptor_pool, message_factory, arena, result)); if (result->IsFalse()) { return absl::OkStatus(); } } ABSL_DCHECK(!lhs_iterator->HasNext()); ABSL_DCHECK(!rhs_iterator->HasNext()); *result = TrueValue(); return absl::OkStatus(); } absl::Status ListValueEqual( const CustomListValueInterface& lhs, const ListValue& rhs, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); auto lhs_size = lhs.Size(); CEL_ASSIGN_OR_RETURN(auto rhs_size, rhs.Size()); if (lhs_size != rhs_size) { *result = FalseValue(); return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN(auto lhs_iterator, lhs.NewIterator()); CEL_ASSIGN_OR_RETURN(auto rhs_iterator, rhs.NewIterator()); Value lhs_element; Value 
rhs_element; for (size_t index = 0; index < lhs_size; ++index) { ABSL_CHECK(lhs_iterator->HasNext()); // Crash OK ABSL_CHECK(rhs_iterator->HasNext()); // Crash OK CEL_RETURN_IF_ERROR(lhs_iterator->Next(descriptor_pool, message_factory, arena, &lhs_element)); CEL_RETURN_IF_ERROR(rhs_iterator->Next(descriptor_pool, message_factory, arena, &rhs_element)); CEL_RETURN_IF_ERROR(lhs_element.Equal(rhs_element, descriptor_pool, message_factory, arena, result)); if (result->IsFalse()) { return absl::OkStatus(); } } ABSL_DCHECK(!lhs_iterator->HasNext()); ABSL_DCHECK(!rhs_iterator->HasNext()); *result = TrueValue(); return absl::OkStatus(); } } // namespace common_internal optional_ref ListValue::AsCustom() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional ListValue::AsCustom() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } const CustomListValue& ListValue::GetCustom() const& { ABSL_DCHECK(IsCustom()); return variant_.Get(); } CustomListValue ListValue::GetCustom() && { ABSL_DCHECK(IsCustom()); return std::move(variant_).Get(); } common_internal::ValueVariant ListValue::ToValueVariant() const& { return variant_.Visit( [](const auto& alternative) -> common_internal::ValueVariant { return common_internal::ValueVariant(alternative); }); } common_internal::ValueVariant ListValue::ToValueVariant() && { return std::move(variant_).Visit( [](auto&& alternative) -> common_internal::ValueVariant { // NOLINTNEXTLINE(bugprone-move-forwarding-reference) return common_internal::ValueVariant(std::move(alternative)); }); } } // namespace cel ================================================ FILE: common/values/list_value.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with 
the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private, include "common/value.h"
// IWYU pragma: friend "common/value.h"

// `ListValue` represents values of the primitive `list` type. It stores exactly
// one concrete list representation in a `common_internal::ListValueVariant`
// (custom, parsed repeated field, parsed JSON, or legacy) and forwards its
// operations to that alternative.
// NOTE(review): this comment previously described `ListValue` as a smart
// pointer to `ListValueInterface`. That class is still forward-declared below
// but is not otherwise used in this header; confirm whether it can be removed.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_H_

#include
#include
#include
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/meta/type_traits.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/utility/utility.h"
#include "common/native_type.h"
#include "common/optional_ref.h"
#include "common/value_kind.h"
#include "common/values/custom_list_value.h"
#include "common/values/legacy_list_value.h"
#include "common/values/list_value_variant.h"
#include "common/values/parsed_json_list_value.h"
#include "common/values/parsed_repeated_field_value.h"
#include "common/values/values.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

class ListValueInterface;
class ListValue;
class Value;

class ListValue final : private common_internal::ListValueMixin {
 public:
  static constexpr ValueKind kKind = ValueKind::kList;

  // Implicit converting constructor from any list value alternative (one of
  // the types accepted by `common_internal::IsListValueAlternativeV`). The
  // argument is perfectly forwarded into the variant.
  template <
      typename T,
      typename = std::enable_if_t<
          common_internal::IsListValueAlternativeV>>>
  // NOLINTNEXTLINE(google-explicit-constructor)
  ListValue(T&& value)
      : variant_(absl::in_place_type>,
                 std::forward(value)) {}

  // A default-constructed `ListValue` is empty and prints as `[]` (see the
  // `Default` test in list_value_test.cc).
  ListValue() = default;
  ListValue(const ListValue&) = default;
  ListValue(ListValue&&) = default;
  ListValue& operator=(const ListValue&) = default;
  ListValue& operator=(ListValue&&) = default;

  static constexpr ValueKind kind() { return kKind; }

  static absl::string_view GetTypeName() { return "list"; }

  NativeTypeId GetTypeId() const;

  std::string DebugString() const;

  // See Value::SerializeTo().
  absl::Status SerializeTo(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const;

  // See Value::ConvertToJson().
  absl::Status ConvertToJson(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  // Like ConvertToJson(), except `json` **MUST** be an instance of
  // `google.protobuf.ListValue`.
  absl::Status ConvertToJsonArray(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  // Compares with `other`, storing the resulting value in `*result`.
  absl::Status Equal(const Value& other,
                     const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                     google::protobuf::MessageFactory* absl_nonnull message_factory,
                     google::protobuf::Arena* absl_nonnull arena,
                     Value* absl_nonnull result) const;
  using ListValueMixin::Equal;

  bool IsZeroValue() const;

  absl::StatusOr IsEmpty() const;

  absl::StatusOr Size() const;

  // Stores the element at `index` in `*result`. Per the `Get` test, an
  // out-of-range index yields an OK status with an error *value*
  // (`kInvalidArgument`), not a non-OK status.
  absl::Status Get(size_t index,
                   const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                   google::protobuf::MessageFactory* absl_nonnull message_factory,
                   google::protobuf::Arena* absl_nonnull arena,
                   Value* absl_nonnull result) const;
  using ListValueMixin::Get;

  using ForEachCallback = typename CustomListValueInterface::ForEachCallback;

  using ForEachWithIndexCallback =
      typename CustomListValueInterface::ForEachWithIndexCallback;

  // Invokes `callback` for every element together with its index.
  absl::Status ForEach(
      ForEachWithIndexCallback callback,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;
  using ListValueMixin::ForEach;

  absl::StatusOr NewIterator() const;

  // Tests whether `other` is an element, storing the result in `*result`.
  absl::Status Contains(
      const Value& other,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const;
  using ListValueMixin::Contains;

  // Returns `true` if this value is an instance of a custom list value.
  bool IsCustom() const { return variant_.Is(); }

  // Convenience method for use with template metaprogramming. See
  // `IsCustom()`.
  template
  std::enable_if_t, bool> Is() const {
    return IsCustom();
  }

  // Performs a checked cast from a value to a custom list value,
  // returning a non-empty optional with either a value or reference to the
  // custom list value. Otherwise an empty optional is returned.
  optional_ref AsCustom() &
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return std::as_const(*this).AsCustom();
  }
  optional_ref AsCustom()
      const& ABSL_ATTRIBUTE_LIFETIME_BOUND;
  absl::optional AsCustom() &&;
  absl::optional AsCustom() const&& {
    return common_internal::AsOptional(AsCustom());
  }

  // Convenience method for use with template metaprogramming. See
  // `AsCustom()`.
  template
  std::enable_if_t,
                   optional_ref>
  As() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return AsCustom();
  }
  template
  std::enable_if_t,
                   optional_ref>
  As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return AsCustom();
  }
  template
  std::enable_if_t,
                   absl::optional>
  As() && {
    return std::move(*this).AsCustom();
  }
  template
  std::enable_if_t,
                   absl::optional>
  As() const&& {
    return std::move(*this).AsCustom();
  }

  // Performs an unchecked cast from a value to a custom list value. In
  // debug builds a best effort is made to crash. If `IsCustom()` would
  // return false, calling this method is undefined behavior.
  const CustomListValue& GetCustom() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return std::as_const(*this).GetCustom();
  }
  const CustomListValue& GetCustom() const& ABSL_ATTRIBUTE_LIFETIME_BOUND;
  CustomListValue GetCustom() &&;
  // Calls the `const&` overload (`*this` is an lvalue here), i.e. copies.
  CustomListValue GetCustom() const&& { return GetCustom(); }

  // Convenience method for use with template metaprogramming. See
  // `GetCustom()`.
  template
  std::enable_if_t, const CustomListValue&>
  Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return GetCustom();
  }
  template
  std::enable_if_t, const CustomListValue&> Get()
      const& ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return GetCustom();
  }
  template
  std::enable_if_t, CustomListValue> Get() && {
    return std::move(*this).GetCustom();
  }
  template
  std::enable_if_t, CustomListValue> Get()
      const&& {
    return std::move(*this).GetCustom();
  }

  friend void swap(ListValue& lhs, ListValue& rhs) noexcept {
    using std::swap;
    swap(lhs.variant_, rhs.variant_);
  }

 private:
  friend class Value;
  friend class common_internal::ValueMixin;
  friend class common_internal::ListValueMixin;

  // Re-wrap the active alternative as a top-level `ValueVariant`.
  common_internal::ValueVariant ToValueVariant() const&;
  common_internal::ValueVariant ToValueVariant() &&;

  // Unlike many of the other derived values, `ListValue` is itself a composed
  // type. This is to avoid making `ListValue` too big and by extension
  // `Value` too big. Instead we store the derived `ListValue` values in
  // `Value` and not `ListValue` itself.
  common_internal::ListValueVariant variant_;
};

inline std::ostream& operator<<(std::ostream& out, const ListValue& value) {
  return out << value.DebugString();
}

template <>
struct NativeTypeTraits final {
  static NativeTypeId Id(const ListValue& value) { return value.GetTypeId(); }
};

// Abstract interface for incrementally building a `ListValue`.
class ListValueBuilder {
 public:
  virtual ~ListValueBuilder() = default;

  // Appends `value`; reports failure through the returned status.
  virtual absl::Status Add(Value value) = 0;

  // Appends `value` without returning a status. The name suggests the caller
  // is responsible for validity — TODO confirm the exact contract with the
  // implementations.
  virtual void UnsafeAdd(Value value) = 0;

  // Defaults to `Size() == 0`.
  virtual bool IsEmpty() const { return Size() == 0; }

  virtual size_t Size() const = 0;

  // Capacity hint; the default implementation ignores it.
  virtual void Reserve(size_t capacity [[maybe_unused]]) {}

  // Produces the finished list. Rvalue-qualified, so it is invoked on a
  // moved-from builder, e.g. `std::move(*builder).Build()`.
  virtual ListValue Build() && = 0;
};

using ListValueBuilderPtr = std::unique_ptr;

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_H_

================================================
FILE: common/values/list_value_builder.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_BUILDER_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_BUILDER_H_

#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "common/native_type.h"
#include "common/value.h"
#include "eval/public/cel_value.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel {

class ValueFactory;

namespace common_internal {

// Special implementation of list which is both a modern list and legacy list.
// Do not try this at home. This should only be implemented in
// `list_value_builder.cc`.
class CompatListValue : public CustomListValueInterface,
                        public google::api::expr::runtime::CelList {
 private:
  NativeTypeId GetNativeTypeId() const final {
    return NativeTypeId::For();
  }
};

// Returns a non-null `CompatListValue`, presumably an empty one (per the name).
// TODO confirm lifetime/ownership in list_value_builder.cc.
const CompatListValue* absl_nonnull EmptyCompatListValue();

// Produces a `CompatListValue` for `value`. `arena` is presumably used for
// allocation; confirm in list_value_builder.cc.
absl::StatusOr MakeCompatListValue(
    const CustomListValue& value,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena);

// Extension of CustomListValueInterface which is also mutable. Accessing this
// like a normal list before all elements are finished being appended is a bug.
// This is primarily used by the runtime to efficiently implement comprehensions
// which accumulate results into a list.
//
// IMPORTANT: This type is only meant to be utilized by the runtime.
class MutableListValue : public CustomListValueInterface {
 public:
  // Appends `value`. Declared `const` even though it mutates, so
  // implementations presumably keep their storage `mutable` — TODO confirm.
  virtual absl::Status Append(Value value) const = 0;

  // Capacity hint; the default implementation does nothing.
  virtual void Reserve(size_t capacity) const {}

 private:
  NativeTypeId GetNativeTypeId() const override {
    return NativeTypeId::For();
  }
};

// Special implementation of list which is both a modern list, legacy list, and
// mutable.
//
// NOTE: We do not extend CompatListValue to avoid having to use virtual
// inheritance and `dynamic_cast`.
class MutableCompatListValue : public MutableListValue,
                               public google::api::expr::runtime::CelList {
 private:
  NativeTypeId GetNativeTypeId() const final {
    return NativeTypeId::For();
  }
};

// Creates a `MutableListValue` whose lifetime is bound to `arena`.
MutableListValue* absl_nonnull NewMutableListValue(
    google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND);

// Type tests and casts for `MutableListValue`: `Is*` tests, `As*` is checked
// (returns null on mismatch) and `Get*` is unchecked.
bool IsMutableListValue(const Value& value);
bool IsMutableListValue(const ListValue& value);

const MutableListValue* absl_nullable AsMutableListValue(
    const Value& value ABSL_ATTRIBUTE_LIFETIME_BOUND);
const MutableListValue* absl_nullable AsMutableListValue(
    const ListValue& value ABSL_ATTRIBUTE_LIFETIME_BOUND);

const MutableListValue& GetMutableListValue(
    const Value& value ABSL_ATTRIBUTE_LIFETIME_BOUND);
const MutableListValue& GetMutableListValue(
    const ListValue& value ABSL_ATTRIBUTE_LIFETIME_BOUND);

// Returns a non-null list builder; `arena` is presumably used for the built
// list's storage — TODO confirm.
absl_nonnull cel::ListValueBuilderPtr NewListValueBuilder(
    google::protobuf::Arena* absl_nonnull arena);

}  // namespace common_internal

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_BUILDER_H_

================================================
FILE: common/values/list_value_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include
#include
#include
#include
#include

#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "common/casting.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "internal/testing.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::cel::test::ErrorValueIs;
using ::testing::ElementsAreArray;

// Fixture that provides an arena, descriptor pool and message factory (from
// `common_internal::ValueTest`) plus a helper for building small lists.
class ListValueTest : public common_internal::ValueTest<> {
 public:
  // Builds a list containing `args` in order. Statuses returned by `Add()` are
  // discarded (cast to void), so a failed add would only show up as a missing
  // element in later assertions.
  template
  absl::StatusOr NewIntListValue(Args&&... args) {
    auto builder = NewListValueBuilder(arena());
    (static_cast(builder->Add(std::forward(args))), ...);
    return std::move(*builder).Build();
  }
};

// A default-constructed list is empty.
TEST_F(ListValueTest, Default) {
  ListValue value;
  EXPECT_THAT(value.IsEmpty(), IsOkAndHolds(true));
  EXPECT_THAT(value.Size(), IsOkAndHolds(0));
  EXPECT_EQ(value.DebugString(), "[]");
}

TEST_F(ListValueTest, Kind) {
  ASSERT_OK_AND_ASSIGN(auto value,
                       NewIntListValue(IntValue(0), IntValue(1), IntValue(2)));
  EXPECT_EQ(value.kind(), ListValue::kKind);
  EXPECT_EQ(Value(value).kind(), ListValue::kKind);
}

// Streaming a `ListValue` directly or wrapped in `Value` prints the same text.
TEST_F(ListValueTest, DebugString) {
  ASSERT_OK_AND_ASSIGN(auto value,
                       NewIntListValue(IntValue(0), IntValue(1), IntValue(2)));
  {
    std::ostringstream out;
    out << value;
    EXPECT_EQ(out.str(), "[0, 1, 2]");
  }
  {
    std::ostringstream out;
    out << Value(value);
    EXPECT_EQ(out.str(), "[0, 1, 2]");
  }
}

TEST_F(ListValueTest, IsEmpty) {
  ASSERT_OK_AND_ASSIGN(auto value,
                       NewIntListValue(IntValue(0), IntValue(1), IntValue(2)));
  EXPECT_THAT(value.IsEmpty(), IsOkAndHolds(false));
}

TEST_F(ListValueTest, Size) {
  ASSERT_OK_AND_ASSIGN(auto value,
                       NewIntListValue(IntValue(0), IntValue(1), IntValue(2)));
  EXPECT_THAT(value.Size(), IsOkAndHolds(3));
}

// In-range indices return the element. An out-of-range index returns OK with
// an error value (kInvalidArgument) rather than a non-OK status.
TEST_F(ListValueTest, Get) {
  ASSERT_OK_AND_ASSIGN(auto value,
                       NewIntListValue(IntValue(0), IntValue(1), IntValue(2)));
  ASSERT_OK_AND_ASSIGN(auto element, value.Get(0, descriptor_pool(),
                                               message_factory(), arena()));
  ASSERT_TRUE(InstanceOf(element));
  ASSERT_EQ(Cast(element).NativeValue(), 0);
  ASSERT_OK_AND_ASSIGN(
      element, value.Get(1, descriptor_pool(), message_factory(), arena()));
  ASSERT_TRUE(InstanceOf(element));
  ASSERT_EQ(Cast(element).NativeValue(), 1);
  ASSERT_OK_AND_ASSIGN(
      element, value.Get(2, descriptor_pool(), message_factory(), arena()));
  ASSERT_TRUE(InstanceOf(element));
  ASSERT_EQ(Cast(element).NativeValue(), 2);
  EXPECT_THAT(
      value.Get(3, descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))));
}

// Returning `true` from the callback continues iteration over every element.
TEST_F(ListValueTest, ForEach) {
  ASSERT_OK_AND_ASSIGN(auto value,
                       NewIntListValue(IntValue(0), IntValue(1), IntValue(2)));
  std::vector elements;
  EXPECT_THAT(value.ForEach(
                  [&elements](const Value& element) {
                    elements.push_back(Cast(element).NativeValue());
                    return true;
                  },
                  descriptor_pool(), message_factory(), arena()),
              IsOk());
  EXPECT_THAT(elements, ElementsAreArray({0, 1, 2}));
}

TEST_F(ListValueTest, Contains) {
  ASSERT_OK_AND_ASSIGN(auto value,
                       NewIntListValue(IntValue(0), IntValue(1), IntValue(2)));
  ASSERT_OK_AND_ASSIGN(auto contained,
                       value.Contains(IntValue(2), descriptor_pool(),
                                      message_factory(), arena()));
  ASSERT_TRUE(InstanceOf(contained));
  EXPECT_TRUE(Cast(contained).NativeValue());
  ASSERT_OK_AND_ASSIGN(contained,
                       value.Contains(IntValue(3), descriptor_pool(),
                                      message_factory(), arena()));
  ASSERT_TRUE(InstanceOf(contained));
  EXPECT_FALSE(Cast(contained).NativeValue());
}

// Iterating past the end makes `Next()` fail with kFailedPrecondition.
TEST_F(ListValueTest, NewIterator) {
  ASSERT_OK_AND_ASSIGN(auto value,
                       NewIntListValue(IntValue(0), IntValue(1), IntValue(2)));
  ASSERT_OK_AND_ASSIGN(auto iterator, value.NewIterator());
  std::vector elements;
  while (iterator->HasNext()) {
    ASSERT_OK_AND_ASSIGN(
        auto element,
        iterator->Next(descriptor_pool(), message_factory(), arena()));
    ASSERT_TRUE(InstanceOf(element));
    elements.push_back(Cast(element).NativeValue());
  }
  EXPECT_EQ(iterator->HasNext(), false);
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              StatusIs(absl::StatusCode::kFailedPrecondition));
  EXPECT_THAT(elements, ElementsAreArray({0, 1, 2}));
}

// Integers are emitted as JSON `number_value`s.
TEST_F(ListValueTest, ConvertToJson) {
  ASSERT_OK_AND_ASSIGN(auto value,
                       NewIntListValue(IntValue(0), IntValue(1), IntValue(2)));
  auto* message = NewArenaValueMessage();
  EXPECT_THAT(
      value.ConvertToJson(descriptor_pool(), message_factory(), message),
      IsOk());
  EXPECT_THAT(*message, EqualsValueTextProto(R"pb(list_value: { values: { number_value: 0 } values: { number_value: 1 } values: { number_value: 2 } })pb"));
}

}  // namespace
}  // namespace cel

================================================
FILE: common/values/list_value_variant.h
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_VARIANT_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_VARIANT_H_

#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/meta/type_traits.h"
#include "absl/utility/utility.h"
#include "common/values/custom_list_value.h"
#include "common/values/legacy_list_value.h"
#include "common/values/parsed_json_list_value.h"
#include "common/values/parsed_repeated_field_value.h"

namespace cel::common_internal {

// Discriminator stored in `ListValueVariant::index_`. Each alternative type is
// mapped to one enumerator by a `ListValueAlternative` specialization.
enum class ListValueIndex : uint16_t {
  kCustom = 0,
  kParsedField,
  kParsedJson,
  kLegacy,
};

// Maps an alternative type to its `ListValueIndex`. The primary template is
// declared but never defined, so only the specializations below are usable.
template
struct ListValueAlternative;

template <>
struct ListValueAlternative {
  static constexpr ListValueIndex kIndex = ListValueIndex::kCustom;
};

template <>
struct ListValueAlternative {
  static constexpr ListValueIndex kIndex = ListValueIndex::kParsedField;
};

template <>
struct ListValueAlternative {
  static constexpr ListValueIndex kIndex = ListValueIndex::kParsedJson;
};

template <>
struct ListValueAlternative {
  static constexpr ListValueIndex kIndex = ListValueIndex::kLegacy;
};

// Detection trait: true for types that have a `ListValueAlternative`
// specialization. The detection expression is only partially legible in this
// copy (presumably `std::void_t`/`decltype`-based) — verify against upstream.
template
struct IsListValueAlternative : std::false_type {};

template
struct IsListValueAlternative{})>> : std::true_type {};

template
inline constexpr bool IsListValueAlternativeV =
    IsListValueAlternative::value;

// Storage limits for every alternative; enforced by the `static_assert`s in the
// constructor, `Assign()` and `At()`.
inline constexpr size_t kListValueVariantAlign = 8;
inline constexpr size_t kListValueVariantSize = 24;

// ListValueVariant is a subset of alternatives from the main ValueVariant that
// is only lists. It is not stored directly in ValueVariant.
//
// Alternatives must be trivially copyable (static_assert'd). This is what makes
// the defaulted copy/move operations correct. It also allows `swap` to exchange
// the raw bytes, and lets `Assign()` placement-new over the previous occupant
// without destroying it.
class alignas(kListValueVariantAlign) ListValueVariant final {
 public:
  ListValueVariant() : ListValueVariant(absl::in_place_type) {}
  ListValueVariant(const ListValueVariant&) = default;
  ListValueVariant(ListValueVariant&&) = default;
  ListValueVariant& operator=(const ListValueVariant&) = default;
  ListValueVariant& operator=(ListValueVariant&&) = default;

  // Constructs alternative `T` in place from `args`.
  template
  explicit ListValueVariant(absl::in_place_type_t, Args&&... args)
      : index_(ListValueAlternative::kIndex) {
    static_assert(alignof(T) <= kListValueVariantAlign);
    static_assert(sizeof(T) <= kListValueVariantSize);
    static_assert(std::is_trivially_copyable_v);

    ::new (static_cast(&raw_[0])) T(std::forward(args)...);
  }

  // Constructs from an alternative value, deducing the alternative type.
  template >>>
  explicit ListValueVariant(T&& value)
      : ListValueVariant(absl::in_place_type>,
                         std::forward(value)) {}

  // Replaces the current alternative with `value`. No destructor is run for
  // the previous occupant; this relies on the trivially-copyable requirement.
  template
  void Assign(T&& value) {
    using U = absl::remove_cvref_t;

    static_assert(alignof(U) <= kListValueVariantAlign);
    static_assert(sizeof(U) <= kListValueVariantSize);
    static_assert(std::is_trivially_copyable_v);

    index_ = ListValueAlternative::kIndex;
    ::new (static_cast(&raw_[0])) U(std::forward(value));
  }

  // Returns whether `T` is the active alternative.
  template
  bool Is() const {
    return index_ == ListValueAlternative::kIndex;
  }

  // Unchecked access to alternative `T`; only debug builds verify `Is<T>()`.
  template
  T& Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(Is());

    return *At();
  }

  template
  const T& Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(Is());

    return *At();
  }

  template
  T&& Get() && ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(Is());

    return std::move(*At());
  }

  template
  const T&& Get() const&& ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(Is());

    return std::move(*At());
  }

  // Checked access: returns nullptr when `T` is not the active alternative.
  template
  T* absl_nullable As() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    if (Is()) {
      return At();
    }
    return nullptr;
  }

  template
  const T* absl_nullable As() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    if (Is()) {
      return At();
    }
    return nullptr;
  }

  // Invokes `visitor` with the active alternative as a const reference. Only a
  // `const` overload exists, so calling this on an rvalue still yields const
  // lvalues to the visitor.
  // NOTE(review): there is no statement after the switch. If `index_` ever held
  // a value outside the enumerators, control would fall off the end of a
  // non-void function (UB, and a -Wreturn-type warning on some compilers).
  // Consider `ABSL_UNREACHABLE()`.
  template
  decltype(auto) Visit(Visitor&& visitor) const {
    switch (index_) {
      case ListValueIndex::kCustom:
        return std::forward(visitor)(Get());
      case ListValueIndex::kParsedField:
        return std::forward(visitor)(Get());
      case ListValueIndex::kParsedJson:
        return std::forward(visitor)(Get());
      case ListValueIndex::kLegacy:
        return std::forward(visitor)(Get());
    }
  }

  // Swaps discriminators and raw storage bytes (valid because alternatives are
  // trivially copyable).
  friend void swap(ListValueVariant& lhs, ListValueVariant& rhs) noexcept {
    using std::swap;
    swap(lhs.index_, rhs.index_);
    swap(lhs.raw_, rhs.raw_);
  }

 private:
  // Reinterprets `raw_` as the `T` previously created there by placement new.
  // `std::launder` is used because the object was constructed into a byte
  // buffer.
  template
  ABSL_ATTRIBUTE_ALWAYS_INLINE T* absl_nonnull At()
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    static_assert(alignof(T) <= kListValueVariantAlign);
    static_assert(sizeof(T) <= kListValueVariantSize);
    static_assert(std::is_trivially_copyable_v);

    return std::launder(reinterpret_cast(&raw_[0]));
  }

  template
  ABSL_ATTRIBUTE_ALWAYS_INLINE const T* absl_nonnull At() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    static_assert(alignof(T) <= kListValueVariantAlign);
    static_assert(sizeof(T) <= kListValueVariantSize);
    static_assert(std::is_trivially_copyable_v);

    return std::launder(reinterpret_cast(&raw_[0]));
  }

  ListValueIndex index_ = ListValueIndex::kCustom;
  // NOTE(review): hard-codes 8 instead of `kListValueVariantAlign`; the two
  // are currently equal.
  alignas(8) std::byte raw_[kListValueVariantSize];
};

}  // namespace cel::common_internal

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_LIST_VALUE_VARIANT_H_

================================================
FILE: common/values/map_value.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// map_value.cc: out-of-line members of cel::MapValue. Nearly every member
// forwards to the active alternative through `variant_.Visit`. Members that
// take descriptor_pool/message_factory/arena/result first ABSL_DCHECK that
// those pointers are non-null; the JSON conversions additionally DCHECK that
// `json` is a google.protobuf.Value (ConvertToJson) or a
// google.protobuf.Struct (ConvertToJsonObject).
// common_internal::MapValueEqual first compares sizes (a mismatch yields
// FalseValue). It then walks the lhs keys with an iterator: each key is looked
// up in rhs with Find (a missing key yields FalseValue), the lhs value is
// fetched with Get, and the two values are compared with Equal, returning as
// soon as the comparison is false. It ABSL_CHECK-crashes if the iterator runs
// out before lhs_size keys. The CustomMapValueInterface overload wraps `lhs`
// in a CustomMapValue for each Get call.
// CheckMapKey accepts bool, int, uint and string keys, returns the status held
// by an error value, and rejects every other kind with InvalidArgument via
// InvalidMapKeyTypeError.
// NOTE(review): each equality step performs one Find on rhs plus one Get on
// lhs for a key that the lhs iterator just produced; an entry-yielding
// iterator would avoid the second lookup if this path is hot.
// NOTE(review): template arguments and #include targets appear stripped in
// this dump and line breaks were collapsed; this text does not compile as-is.
#include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "common/native_type.h" #include "common/optional_ref.h" #include "common/value.h" #include "common/value_kind.h" #include "common/values/value_variant.h" #include "internal/status_macros.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { namespace { absl::Status InvalidMapKeyTypeError(ValueKind kind) { return absl::InvalidArgumentError( absl::StrCat("Invalid map key type: '", ValueKindToString(kind), "'")); } } // namespace NativeTypeId MapValue::GetTypeId() const { return variant_.Visit([](const auto& alternative) -> NativeTypeId { return NativeTypeId::Of(alternative); }); } std::string MapValue::DebugString() const { return variant_.Visit([](const auto& alternative) -> std::string { return alternative.DebugString(); }); } absl::Status MapValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); return variant_.Visit([&](const auto& alternative) -> absl::Status { return alternative.SerializeTo(descriptor_pool, message_factory, output); }); } absl::Status MapValue::ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr);
ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); return variant_.Visit([&](const auto& alternative) -> absl::Status { return alternative.ConvertToJson(descriptor_pool, message_factory, json); }); } absl::Status MapValue::ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); return variant_.Visit([&](const auto& alternative) -> absl::Status { return alternative.ConvertToJsonObject(descriptor_pool, message_factory, json); }); } absl::Status MapValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); return variant_.Visit([&](const auto& alternative) -> absl::Status { return alternative.Equal(other, descriptor_pool, message_factory, arena, result); }); } bool MapValue::IsZeroValue() const { return variant_.Visit([](const auto& alternative) -> bool { return alternative.IsZeroValue(); }); } absl::StatusOr MapValue::IsEmpty() const { return variant_.Visit([](const auto& alternative) -> absl::StatusOr { return alternative.IsEmpty(); }); } absl::StatusOr MapValue::Size() const { return variant_.Visit([](const auto& alternative) -> absl::StatusOr { return alternative.Size(); }); } absl::Status MapValue::Get( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); return variant_.Visit([&](const auto& alternative) -> absl::Status { return alternative.Get(key, descriptor_pool, message_factory, arena, result); }); } absl::StatusOr MapValue::Find( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); return variant_.Visit([&](const auto& alternative) -> absl::StatusOr { return alternative.Find(key, descriptor_pool, message_factory, arena, result); }); } absl::Status MapValue::Has( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); return variant_.Visit([&](const auto& alternative) -> absl::Status { return alternative.Has(key, descriptor_pool, message_factory, arena, result); }); } absl::Status MapValue::ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); return variant_.Visit([&](const
auto& alternative) -> absl::Status { return alternative.ListKeys(descriptor_pool, message_factory, arena, result); }); } absl::Status MapValue::ForEach( ForEachCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); return variant_.Visit([&](const auto& alternative) -> absl::Status { return alternative.ForEach(callback, descriptor_pool, message_factory, arena); }); } absl::StatusOr MapValue::NewIterator() const { return variant_.Visit([](const auto& alternative) -> absl::StatusOr { return alternative.NewIterator(); }); } namespace common_internal { absl::Status MapValueEqual( const MapValue& lhs, const MapValue& rhs, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); CEL_ASSIGN_OR_RETURN(auto lhs_size, lhs.Size()); CEL_ASSIGN_OR_RETURN(auto rhs_size, rhs.Size()); if (lhs_size != rhs_size) { *result = FalseValue(); return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN(auto lhs_iterator, lhs.NewIterator()); Value lhs_key; Value lhs_value; Value rhs_value; for (size_t index = 0; index < lhs_size; ++index) { ABSL_CHECK(lhs_iterator->HasNext()); // Crash OK CEL_RETURN_IF_ERROR( lhs_iterator->Next(descriptor_pool, message_factory, arena, &lhs_key)); bool rhs_value_found; CEL_ASSIGN_OR_RETURN( rhs_value_found, rhs.Find(lhs_key, descriptor_pool, message_factory, arena, &rhs_value)); if (!rhs_value_found) { *result = FalseValue(); return absl::OkStatus(); } CEL_RETURN_IF_ERROR( lhs.Get(lhs_key, descriptor_pool,
message_factory, arena, &lhs_value)); CEL_RETURN_IF_ERROR(lhs_value.Equal(rhs_value, descriptor_pool, message_factory, arena, result)); if (result->IsFalse()) { return absl::OkStatus(); } } ABSL_DCHECK(!lhs_iterator->HasNext()); *result = TrueValue(); return absl::OkStatus(); } absl::Status MapValueEqual( const CustomMapValueInterface& lhs, const MapValue& rhs, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); auto lhs_size = lhs.Size(); CEL_ASSIGN_OR_RETURN(auto rhs_size, rhs.Size()); if (lhs_size != rhs_size) { *result = FalseValue(); return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN(auto lhs_iterator, lhs.NewIterator()); Value lhs_key; Value lhs_value; Value rhs_value; for (size_t index = 0; index < lhs_size; ++index) { ABSL_CHECK(lhs_iterator->HasNext()); // Crash OK CEL_RETURN_IF_ERROR( lhs_iterator->Next(descriptor_pool, message_factory, arena, &lhs_key)); bool rhs_value_found; CEL_ASSIGN_OR_RETURN( rhs_value_found, rhs.Find(lhs_key, descriptor_pool, message_factory, arena, &rhs_value)); if (!rhs_value_found) { *result = FalseValue(); return absl::OkStatus(); } CEL_RETURN_IF_ERROR( CustomMapValue(&lhs, arena) .Get(lhs_key, descriptor_pool, message_factory, arena, &lhs_value)); CEL_RETURN_IF_ERROR(lhs_value.Equal(rhs_value, descriptor_pool, message_factory, arena, result)); if (result->IsFalse()) { return absl::OkStatus(); } } ABSL_DCHECK(!lhs_iterator->HasNext()); *result = TrueValue(); return absl::OkStatus(); } } // namespace common_internal absl::Status CheckMapKey(const Value& key) { switch (key.kind()) { case ValueKind::kBool: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kInt: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUint: ABSL_FALLTHROUGH_INTENDED; case
ValueKind::kString: return absl::OkStatus(); case ValueKind::kError: return key.GetError().NativeValue(); default: return InvalidMapKeyTypeError(key.kind()); } } optional_ref MapValue::AsCustom() const& { if (const auto* alternative = variant_.As(); alternative != nullptr) { return *alternative; } return absl::nullopt; } absl::optional MapValue::AsCustom() && { if (auto* alternative = variant_.As(); alternative != nullptr) { return std::move(*alternative); } return absl::nullopt; } const CustomMapValue& MapValue::GetCustom() const& { ABSL_DCHECK(IsCustom()); return variant_.Get(); } CustomMapValue MapValue::GetCustom() && { ABSL_DCHECK(IsCustom()); return std::move(variant_).Get(); } common_internal::ValueVariant MapValue::ToValueVariant() const& { return variant_.Visit( [](const auto& alternative) -> common_internal::ValueVariant { return common_internal::ValueVariant(alternative); }); } common_internal::ValueVariant MapValue::ToValueVariant() && { return std::move(variant_).Visit( [](auto&& alternative) -> common_internal::ValueVariant { // NOLINTNEXTLINE(bugprone-move-forwarding-reference) return common_internal::ValueVariant(std::move(alternative)); }); } } // namespace cel ================================================ FILE: common/values/map_value.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
// map_value.h: declares cel::MapValue, a thin wrapper around
// common_internal::MapValueVariant, together with its operator<< (which prints
// DebugString), its NativeTypeTraits specialization (which forwards to
// GetTypeId) and the abstract MapValueBuilder interface (Put, UnsafePut,
// IsEmpty, Size, Reserve, Build). In MapValueBuilder, IsEmpty defaults to
// Size() == 0 and Reserve defaults to a no-op.
// NOTE(review): angle-bracket contents appear stripped in this dump and line
// breaks were collapsed; this text does not compile as-is.
// IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" // `MapValue` represents values of the primitive `map` type. It provides a // unified interface for accessing map contents, regardless of the underlying // implementation (e.g., JSON, protobuf map field, or custom implementation). // // Public member functions: // - `IsEmpty()` / `Size()`: Query map size. // - `Get()` / `Find()` / `Has()`: Access entries by key. // - `ListKeys()` / `NewIterator()` / `ForEach()`: Iterate over entries. // - `ConvertToJson()` / `ConvertToJsonObject()`: JSON conversion. // - `IsCustom()` / `AsCustom()` / `GetCustom()`: Access custom implementation. #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_H_ #include #include #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/meta/type_traits.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/utility/utility.h" #include "common/native_type.h" #include "common/optional_ref.h" #include "common/value_kind.h" #include "common/values/custom_map_value.h" #include "common/values/legacy_map_value.h" #include "common/values/map_value_variant.h" #include "common/values/parsed_json_map_value.h" #include "common/values/parsed_map_field_value.h" #include "common/values/values.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { class MapValue; class Value; absl::Status CheckMapKey(const Value& key); class MapValue final : private common_internal::MapValueMixin { public: static constexpr ValueKind kKind = ValueKind::kMap; // Converting constructor that forwards `value` into `variant_` as an in-place // map value alternative.
template >>> // NOLINTNEXTLINE(google-explicit-constructor) MapValue(T&& value) : variant_(absl::in_place_type>, std::forward(value)) {} MapValue() = default; MapValue(const MapValue&) = default; MapValue(MapValue&&) = default; MapValue& operator=(const MapValue&) = default; MapValue& operator=(MapValue&&) = default; constexpr ValueKind kind() const { return kKind; } static absl::string_view GetTypeName() { return "map"; } NativeTypeId GetTypeId() const; std::string DebugString() const; // See Value::SerializeTo(). absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; // See Value::ConvertToJson(). absl::Status ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; // Like ConvertToJson(), except `json` **MUST** be an instance of // `google.protobuf.Struct`. absl::Status ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; absl::Status Equal(const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Equal; bool IsZeroValue() const; absl::StatusOr IsEmpty() const; absl::StatusOr Size() const; // `Get` sets `result` to the value associated with `key`. If `key` is not // found, a `no such key` error value is set to `result`. If another error // occurs (e.g., invalid key type), an error value describing it is presumably // set to `result` (NOTE(review): the previous wording repeated `no such key` // here, which looks like a copy-paste slip; confirm against implementations).
// // A non-ok status may be returned if an unexpected error is encountered or to // propagate an error from a custom implementation, in which case `result` is // unspecified. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Get; // `Find` returns `true` if `key` is found in the map, and stores the // associated value in `result`. If `key` is not found, `false` is returned // and `result` is unchanged. // // A non-ok status may be returned if an unexpected error is encountered or to // propagate an error from a custom implementation, in which case `result` is // unspecified. absl::StatusOr Find( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; // `Has` sets `result` to a BoolValue that is `true` if `key` is found in the // map and `false` otherwise. In case of an error, the result is set to an // ErrorValue. // // A non-ok status may be returned if an unexpected error is encountered or to // propagate an error from a custom implementation, in which case `result` is // unspecified. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Has; // `ListKeys` sets `result` to a `ListValue` containing all keys in the map.
absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; using MapValueMixin::ListKeys; // `ForEachCallback` is the callback type for `ForEach`. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; // `ForEach` calls `callback` for each entry in the map. Iteration continues // until all entries are visited or `callback` returns an error or `false`. absl::Status ForEach( ForEachCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; // `NewIterator` returns a new iterator for the map. absl::StatusOr NewIterator() const; // Returns `true` if this value is an instance of a custom map value. bool IsCustom() const { return variant_.Is(); } // Convenience method for use with template metaprogramming. See // `IsCustom()`. template std::enable_if_t, bool> Is() const { return IsCustom(); } // Performs a checked cast from a value to a custom map value, // returning a non-empty optional with either a value or reference to the // custom map value. Otherwise an empty optional is returned. optional_ref AsCustom() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::as_const(*this).AsCustom(); } optional_ref AsCustom() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::optional AsCustom() &&; absl::optional AsCustom() const&& { return common_internal::AsOptional(AsCustom()); } // Convenience method for use with template metaprogramming. See // `AsCustom()`.
template std::enable_if_t, optional_ref> As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return AsCustom(); } template std::enable_if_t, optional_ref> As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { return AsCustom(); } template std::enable_if_t, absl::optional> As() && { return std::move(*this).AsCustom(); } template std::enable_if_t, absl::optional> As() const&& { return std::move(*this).AsCustom(); } // Performs an unchecked cast from a value to a custom map value. In // debug builds a best effort is made to crash. If `IsCustom()` would // return false, calling this method is undefined behavior. const CustomMapValue& GetCustom() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::as_const(*this).GetCustom(); } const CustomMapValue& GetCustom() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; CustomMapValue GetCustom() &&; CustomMapValue GetCustom() const&& { return GetCustom(); } // Convenience method for use with template metaprogramming. See // `GetCustom()`. template std::enable_if_t, const CustomMapValue&> Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetCustom(); } template std::enable_if_t, const CustomMapValue&> Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetCustom(); } template std::enable_if_t, CustomMapValue> Get() && { return std::move(*this).GetCustom(); } template std::enable_if_t, CustomMapValue> Get() const&& { return std::move(*this).GetCustom(); } friend void swap(MapValue& lhs, MapValue& rhs) noexcept { using std::swap; swap(lhs.variant_, rhs.variant_); } private: friend class Value; friend class common_internal::ValueMixin; friend class common_internal::MapValueMixin; common_internal::ValueVariant ToValueVariant() const&; common_internal::ValueVariant ToValueVariant() &&; // Unlike many of the other derived values, `MapValue` is itself a composed // type. This is to avoid making `MapValue` too big and by extension // `Value` too big. Instead we store the derived `MapValue` values in // `Value` and not `MapValue` itself.
common_internal::MapValueVariant variant_; }; inline std::ostream& operator<<(std::ostream& out, const MapValue& value) { return out << value.DebugString(); } template <> struct NativeTypeTraits final { static NativeTypeId Id(const MapValue& value) { return value.GetTypeId(); } }; class MapValueBuilder { public: virtual ~MapValueBuilder() = default; virtual absl::Status Put(Value key, Value value) = 0; virtual void UnsafePut(Value key, Value value) = 0; virtual bool IsEmpty() const { return Size() == 0; } virtual size_t Size() const = 0; virtual void Reserve(size_t capacity [[maybe_unused]]) {} virtual MapValue Build() && = 0; }; using MapValueBuilderPtr = std::unique_ptr; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_H_ ================================================ FILE: common/values/map_value_builder.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
// map_value_builder.h: map helpers that are internal to the runtime.
// CompatMapValue implements both CustomMapValueInterface and the legacy CelMap
// interface. MutableMapValue extends CustomMapValueInterface with a pure
// virtual const Put and a const Reserve that does nothing by default.
// MutableCompatMapValue combines MutableMapValue with CelMap. The free
// functions create, detect and downcast mutable map values, and create a
// MapValueBuilder on an arena.
// NOTE(review): angle-bracket contents appear stripped in this dump and line
// breaks were collapsed; this text does not compile as-is.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_BUILDER_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_BUILDER_H_ #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "common/native_type.h" #include "common/value.h" #include "eval/public/cel_value.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel { class ValueFactory; namespace common_internal { // Special implementation of map which is both a modern map and legacy map. Do // not try this at home. This should only be implemented in // `map_value_builder.cc`. class CompatMapValue : public CustomMapValueInterface, public google::api::expr::runtime::CelMap { private: NativeTypeId GetNativeTypeId() const final { return NativeTypeId::For(); } }; const CompatMapValue* absl_nonnull EmptyCompatMapValue(); absl::StatusOr MakeCompatMapValue( const CustomMapValue& value, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena); // Extension of CustomMapValueInterface which is also mutable. Accessing this // like a normal map before all entries are finished being inserted is a bug. // This is primarily used by the runtime to efficiently implement comprehensions // which accumulate results into a map. // // IMPORTANT: This type is only meant to be utilized by the runtime. class MutableMapValue : public CustomMapValueInterface { public: virtual absl::Status Put(Value key, Value value) const = 0; virtual void Reserve(size_t capacity) const {} private: NativeTypeId GetNativeTypeId() const override { return NativeTypeId::For(); } }; // Special implementation of map which is both a modern map, legacy map, and // mutable.
// // NOTE: We do not extend CompatMapValue to avoid having to use virtual // inheritance and `dynamic_cast`. class MutableCompatMapValue : public MutableMapValue, public google::api::expr::runtime::CelMap { private: NativeTypeId GetNativeTypeId() const final { return NativeTypeId::For(); } }; MutableMapValue* absl_nonnull NewMutableMapValue( google::protobuf::Arena* absl_nonnull arena); bool IsMutableMapValue(const Value& value); bool IsMutableMapValue(const MapValue& value); const MutableMapValue* absl_nullable AsMutableMapValue( const Value& value ABSL_ATTRIBUTE_LIFETIME_BOUND); const MutableMapValue* absl_nullable AsMutableMapValue( const MapValue& value ABSL_ATTRIBUTE_LIFETIME_BOUND); const MutableMapValue& GetMutableMapValue( const Value& value ABSL_ATTRIBUTE_LIFETIME_BOUND); const MutableMapValue& GetMutableMapValue( const MapValue& value ABSL_ATTRIBUTE_LIFETIME_BOUND); absl_nonnull cel::MapValueBuilderPtr NewMapValueBuilder( google::protobuf::Arena* absl_nonnull arena); } // namespace common_internal } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_BUILDER_H_ ================================================ FILE: common/values/map_value_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include #include #include #include #include #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "common/casting.h" #include "common/value.h" #include "common/value_testing.h" #include "internal/testing.h" namespace cel { namespace { using ::absl_testing::IsOk; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::test::ErrorValueIs; using ::testing::IsEmpty; using ::testing::Not; using ::testing::UnorderedElementsAreArray; TEST(MapValue, CheckKey) { EXPECT_THAT(CheckMapKey(BoolValue()), IsOk()); EXPECT_THAT(CheckMapKey(IntValue()), IsOk()); EXPECT_THAT(CheckMapKey(UintValue()), IsOk()); EXPECT_THAT(CheckMapKey(StringValue()), IsOk()); EXPECT_THAT(CheckMapKey(BytesValue()), StatusIs(absl::StatusCode::kInvalidArgument)); } class MapValueTest : public common_internal::ValueTest<> { public: template absl::StatusOr NewIntDoubleMapValue(Args&&... args) { auto builder = NewMapValueBuilder(arena()); (static_cast(builder->Put(std::forward(args).first, std::forward(args).second)), ...); return std::move(*builder).Build(); } template absl::StatusOr NewJsonMapValue(Args&&... 
args) { auto builder = NewMapValueBuilder(arena()); (static_cast(builder->Put(std::forward(args).first, std::forward(args).second)), ...); return std::move(*builder).Build(); } }; TEST_F(MapValueTest, Default) { MapValue map_value; EXPECT_THAT(map_value.IsEmpty(), IsOkAndHolds(true)); EXPECT_THAT(map_value.Size(), IsOkAndHolds(0)); EXPECT_EQ(map_value.DebugString(), "{}"); ASSERT_OK_AND_ASSIGN( auto list_value, map_value.ListKeys(descriptor_pool(), message_factory(), arena())); EXPECT_THAT(list_value.IsEmpty(), IsOkAndHolds(true)); EXPECT_THAT(list_value.Size(), IsOkAndHolds(0)); EXPECT_EQ(list_value.DebugString(), "[]"); ASSERT_OK_AND_ASSIGN(auto iterator, map_value.NewIterator()); EXPECT_FALSE(iterator->HasNext()); EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), StatusIs(absl::StatusCode::kFailedPrecondition)); } TEST_F(MapValueTest, Kind) { ASSERT_OK_AND_ASSIGN( auto value, NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, std::pair{IntValue(1), DoubleValue(4.0)}, std::pair{IntValue(2), DoubleValue(5.0)})); EXPECT_EQ(value.kind(), MapValue::kKind); EXPECT_EQ(Value(value).kind(), MapValue::kKind); } TEST_F(MapValueTest, DebugString) { ASSERT_OK_AND_ASSIGN( auto value, NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, std::pair{IntValue(1), DoubleValue(4.0)}, std::pair{IntValue(2), DoubleValue(5.0)})); { std::ostringstream out; out << value; EXPECT_THAT(out.str(), Not(IsEmpty())); } { std::ostringstream out; out << Value(value); EXPECT_THAT(out.str(), Not(IsEmpty())); } } TEST_F(MapValueTest, IsEmpty) { ASSERT_OK_AND_ASSIGN( auto value, NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, std::pair{IntValue(1), DoubleValue(4.0)}, std::pair{IntValue(2), DoubleValue(5.0)})); EXPECT_THAT(value.IsEmpty(), IsOkAndHolds(false)); } TEST_F(MapValueTest, Size) { ASSERT_OK_AND_ASSIGN( auto value, NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, std::pair{IntValue(1), DoubleValue(4.0)}, 
std::pair{IntValue(2), DoubleValue(5.0)})); EXPECT_THAT(value.Size(), IsOkAndHolds(3)); } TEST_F(MapValueTest, Get) { ASSERT_OK_AND_ASSIGN( auto map_value, NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, std::pair{IntValue(1), DoubleValue(4.0)}, std::pair{IntValue(2), DoubleValue(5.0)})); ASSERT_OK_AND_ASSIGN(auto value, map_value.Get(IntValue(0), descriptor_pool(), message_factory(), arena())); ASSERT_TRUE(InstanceOf(value)); ASSERT_EQ(Cast(value).NativeValue(), 3.0); ASSERT_OK_AND_ASSIGN(value, map_value.Get(IntValue(1), descriptor_pool(), message_factory(), arena())); ASSERT_TRUE(InstanceOf(value)); ASSERT_EQ(Cast(value).NativeValue(), 4.0); ASSERT_OK_AND_ASSIGN(value, map_value.Get(IntValue(2), descriptor_pool(), message_factory(), arena())); ASSERT_TRUE(InstanceOf(value)); ASSERT_EQ(Cast(value).NativeValue(), 5.0); EXPECT_THAT( map_value.Get(IntValue(3), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))); } TEST_F(MapValueTest, Find) { ASSERT_OK_AND_ASSIGN( auto map_value, NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, std::pair{IntValue(1), DoubleValue(4.0)}, std::pair{IntValue(2), DoubleValue(5.0)})); absl::optional entry; ASSERT_OK_AND_ASSIGN(entry, map_value.Find(IntValue(0), descriptor_pool(), message_factory(), arena())); ASSERT_TRUE(entry); ASSERT_TRUE(InstanceOf(*entry)); ASSERT_EQ(Cast(*entry).NativeValue(), 3.0); ASSERT_OK_AND_ASSIGN(entry, map_value.Find(IntValue(1), descriptor_pool(), message_factory(), arena())); ASSERT_TRUE(entry); ASSERT_TRUE(InstanceOf(*entry)); ASSERT_EQ(Cast(*entry).NativeValue(), 4.0); ASSERT_OK_AND_ASSIGN(entry, map_value.Find(IntValue(2), descriptor_pool(), message_factory(), arena())); ASSERT_TRUE(entry); ASSERT_TRUE(InstanceOf(*entry)); ASSERT_EQ(Cast(*entry).NativeValue(), 5.0); ASSERT_OK_AND_ASSIGN(entry, map_value.Find(IntValue(3), descriptor_pool(), message_factory(), arena())); ASSERT_FALSE(entry); } TEST_F(MapValueTest, 
Has) { ASSERT_OK_AND_ASSIGN( auto map_value, NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, std::pair{IntValue(1), DoubleValue(4.0)}, std::pair{IntValue(2), DoubleValue(5.0)})); ASSERT_OK_AND_ASSIGN(auto value, map_value.Has(IntValue(0), descriptor_pool(), message_factory(), arena())); ASSERT_TRUE(InstanceOf(value)); ASSERT_TRUE(Cast(value).NativeValue()); ASSERT_OK_AND_ASSIGN(value, map_value.Has(IntValue(1), descriptor_pool(), message_factory(), arena())); ASSERT_TRUE(InstanceOf(value)); ASSERT_TRUE(Cast(value).NativeValue()); ASSERT_OK_AND_ASSIGN(value, map_value.Has(IntValue(2), descriptor_pool(), message_factory(), arena())); ASSERT_TRUE(InstanceOf(value)); ASSERT_TRUE(Cast(value).NativeValue()); ASSERT_OK_AND_ASSIGN(value, map_value.Has(IntValue(3), descriptor_pool(), message_factory(), arena())); ASSERT_TRUE(InstanceOf(value)); ASSERT_FALSE(Cast(value).NativeValue()); } TEST_F(MapValueTest, ListKeys) { ASSERT_OK_AND_ASSIGN( auto map_value, NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, std::pair{IntValue(1), DoubleValue(4.0)}, std::pair{IntValue(2), DoubleValue(5.0)})); ASSERT_OK_AND_ASSIGN( auto list_keys, map_value.ListKeys(descriptor_pool(), message_factory(), arena())); std::vector keys; ASSERT_THAT(list_keys.ForEach( [&keys](const Value& element) -> bool { keys.push_back(Cast(element).NativeValue()); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(keys, UnorderedElementsAreArray({0, 1, 2})); } TEST_F(MapValueTest, ForEach) { ASSERT_OK_AND_ASSIGN( auto value, NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, std::pair{IntValue(1), DoubleValue(4.0)}, std::pair{IntValue(2), DoubleValue(5.0)})); std::vector> entries; EXPECT_THAT(value.ForEach( [&entries](const Value& key, const Value& value) { entries.push_back( std::pair{Cast(key).NativeValue(), Cast(value).NativeValue()}); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(entries, 
UnorderedElementsAreArray( {std::pair{0, 3.0}, std::pair{1, 4.0}, std::pair{2, 5.0}})); } TEST_F(MapValueTest, NewIterator) { ASSERT_OK_AND_ASSIGN( auto value, NewIntDoubleMapValue(std::pair{IntValue(0), DoubleValue(3.0)}, std::pair{IntValue(1), DoubleValue(4.0)}, std::pair{IntValue(2), DoubleValue(5.0)})); ASSERT_OK_AND_ASSIGN(auto iterator, value.NewIterator()); std::vector keys; while (iterator->HasNext()) { ASSERT_OK_AND_ASSIGN( auto element, iterator->Next(descriptor_pool(), message_factory(), arena())); ASSERT_TRUE(InstanceOf(element)); keys.push_back(Cast(element).NativeValue()); } EXPECT_EQ(iterator->HasNext(), false); EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), StatusIs(absl::StatusCode::kFailedPrecondition)); EXPECT_THAT(keys, UnorderedElementsAreArray({0, 1, 2})); } TEST_F(MapValueTest, ConvertToJson) { ASSERT_OK_AND_ASSIGN( auto value, NewJsonMapValue(std::pair{StringValue("0"), DoubleValue(3.0)}, std::pair{StringValue("1"), DoubleValue(4.0)}, std::pair{StringValue("2"), DoubleValue(5.0)})); auto* message = NewArenaValueMessage(); EXPECT_THAT( value.ConvertToJson(descriptor_pool(), message_factory(), message), IsOk()); EXPECT_THAT(*message, EqualsValueTextProto(R"pb(struct_value: { fields: { key: "0" value: { number_value: 3 } } fields: { key: "1" value: { number_value: 4 } } fields: { key: "2" value: { number_value: 5 } } })pb")); } } // namespace } // namespace cel ================================================ FILE: common/values/map_value_variant.h ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_VARIANT_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_VARIANT_H_

#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/meta/type_traits.h"
#include "absl/utility/utility.h"
#include "common/values/custom_map_value.h"
#include "common/values/legacy_map_value.h"
#include "common/values/parsed_json_map_value.h"
#include "common/values/parsed_map_field_value.h"

namespace cel::common_internal {

// Tag identifying which alternative a `MapValueVariant` currently holds. It is
// stored in `MapValueVariant::index_` and drives the `switch` in
// `MapValueVariant::Visit()`. Presumably one enumerator per included map
// representation (custom, parsed map field, parsed JSON, legacy).
enum class MapValueIndex : uint16_t {
  kCustom = 0,
  kParsedField,
  kParsedJson,
  kLegacy,
};

// Compile-time mapping from an alternative type to its `MapValueIndex`. The
// primary template is only declared, never defined; each permitted
// alternative gets an explicit specialization below that exposes `kIndex`.
template struct MapValueAlternative;

template <> struct MapValueAlternative {
  static constexpr MapValueIndex kIndex = MapValueIndex::kCustom;
};

template <> struct MapValueAlternative {
  static constexpr MapValueIndex kIndex = MapValueIndex::kParsedField;
};

template <> struct MapValueAlternative {
  static constexpr MapValueIndex kIndex = MapValueIndex::kParsedJson;
};

template <> struct MapValueAlternative {
  static constexpr MapValueIndex kIndex = MapValueIndex::kLegacy;
};

// Detection trait: `std::false_type` by default, `std::true_type` when the
// partial specialization's detection expression is well-formed. This is
// presumably the case exactly when a `MapValueAlternative` specialization
// exists for the type (TODO confirm against the unstripped header).
template struct IsMapValueAlternative : std::false_type {};

template struct IsMapValueAlternative{})>> : std::true_type {};

// Variable-template shorthand for `IsMapValueAlternative<...>::value`.
template inline constexpr bool IsMapValueAlternativeV = IsMapValueAlternative::value;

// Alignment and byte size of the inline storage buffer. Every alternative is
// `static_assert`ed to fit within both bounds.
inline constexpr size_t kMapValueVariantAlign = 8;
inline constexpr size_t kMapValueVariantSize = 24;

// MapValueVariant is a subset of alternatives from the main ValueVariant that
// is only maps. It is not stored directly in ValueVariant.
//
// Implementation notes:
// - The active alternative is constructed in place inside `raw_` with
//   placement new and read back through `std::launder`.
// - Every alternative must be trivially copyable (enforced by
//   `static_assert`). That makes the defaulted copy/move operations, the lack
//   of a user-declared destructor, overwriting in `Assign()`, and the
//   byte-wise `swap` valid.
class alignas(kMapValueVariantAlign) MapValueVariant final {
 public:
  // Delegates to the in-place constructor; `index_` is set from the chosen
  // alternative's `kIndex`.
  MapValueVariant() : MapValueVariant(absl::in_place_type) {}

  MapValueVariant(const MapValueVariant&) = default;
  MapValueVariant(MapValueVariant&&) = default;
  MapValueVariant& operator=(const MapValueVariant&) = default;
  MapValueVariant& operator=(MapValueVariant&&) = default;

  // Constructs alternative `T` in place from `args` and records its index.
  template explicit MapValueVariant(absl::in_place_type_t, Args&&... args)
      : index_(MapValueAlternative::kIndex) {
    static_assert(alignof(T) <= kMapValueVariantAlign);
    static_assert(sizeof(T) <= kMapValueVariantSize);
    static_assert(std::is_trivially_copyable_v);

    ::new (static_cast(&raw_[0])) T(std::forward(args)...);
  }

  // Converting constructor from a value of an alternative type; delegates to
  // the in-place constructor with the cv/ref-stripped type. The constraint's
  // arguments are not visible here; presumably it uses IsMapValueAlternativeV.
  template >>>
  explicit MapValueVariant(T&& value)
      : MapValueVariant(absl::in_place_type>, std::forward(value)) {}

  // Replaces the held alternative with `value`. The previous alternative is
  // not destroyed first; this is sound only because every alternative is
  // trivially copyable and therefore trivially destructible.
  template void Assign(T&& value) {
    using U = absl::remove_cvref_t;
    static_assert(alignof(U) <= kMapValueVariantAlign);
    static_assert(sizeof(U) <= kMapValueVariantSize);
    static_assert(std::is_trivially_copyable_v);

    index_ = MapValueAlternative::kIndex;
    ::new (static_cast(&raw_[0])) U(std::forward(value));
  }

  // Returns whether the stored index matches the alternative's `kIndex`.
  template bool Is() const { return index_ == MapValueAlternative::kIndex; }

  // Unchecked accessors for the held alternative. A mismatched type is only
  // caught by `ABSL_DCHECK` (debug builds); otherwise behavior is undefined.
  template T& Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(Is());
    return *At();
  }

  template const T& Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(Is());
    return *At();
  }

  template T&& Get() && ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(Is());
    return std::move(*At());
  }

  template const T&& Get() const&& ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(Is());
    return std::move(*At());
  }

  // Checked accessors: return a pointer to the held alternative, or nullptr
  // when a different alternative is active.
  template T* absl_nullable As() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    if (Is()) {
      return At();
    }
    return nullptr;
  }

  template const T* absl_nullable As() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    if (Is()) {
      return At();
    }
    return nullptr;
  }

  // Invokes `visitor` with a const reference to the active alternative and
  // returns whatever the visitor returns.
  // NOTE(review): there is no `default:` and no statement after the switch.
  // All current MapValueIndex enumerators are handled, but if `index_` ever
  // held another value, control would fall off the end of a value-returning
  // function (UB). Consider `ABSL_UNREACHABLE()` after the switch.
  template decltype(auto) Visit(Visitor&& visitor) const {
    switch (index_) {
      case MapValueIndex::kCustom:
        return std::forward(visitor)(Get());
      case MapValueIndex::kParsedField:
        return std::forward(visitor)(Get());
      case MapValueIndex::kParsedJson:
        return std::forward(visitor)(Get());
      case MapValueIndex::kLegacy:
        return std::forward(visitor)(Get());
    }
  }

  // Exchanges the tag and the raw storage bytes. Valid without running any
  // constructors because alternatives are trivially copyable.
  friend void swap(MapValueVariant& lhs, MapValueVariant& rhs) noexcept {
    using std::swap;
    swap(lhs.index_, rhs.index_);
    swap(lhs.raw_, rhs.raw_);
  }

 private:
  // Views `raw_` as the alternative `T`. Callers are responsible for `T`
  // being the active alternative (`Get()` DCHECKs it, `As()` tests it).
  template ABSL_ATTRIBUTE_ALWAYS_INLINE T* absl_nonnull At() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    static_assert(alignof(T) <= kMapValueVariantAlign);
    static_assert(sizeof(T) <= kMapValueVariantSize);
    static_assert(std::is_trivially_copyable_v);

    return std::launder(reinterpret_cast(&raw_[0]));
  }

  template ABSL_ATTRIBUTE_ALWAYS_INLINE const T* absl_nonnull At() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    static_assert(alignof(T) <= kMapValueVariantAlign);
    static_assert(sizeof(T) <= kMapValueVariantSize);
    static_assert(std::is_trivially_copyable_v);

    return std::launder(reinterpret_cast(&raw_[0]));
  }

  // Which alternative currently lives in `raw_`.
  MapValueIndex index_ = MapValueIndex::kCustom;
  // Inline storage for the active alternative. The alignment literal `8`
  // matches `kMapValueVariantAlign`.
  alignas(8) std::byte raw_[kMapValueVariantSize];
};

}  // namespace cel::common_internal

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_VARIANT_H_

================================================
FILE: common/values/message_value.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/values/message_value.h"

#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/functional/overload.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "base/attribute.h"
#include "common/optional_ref.h"
#include "common/value.h"
#include "common/values/parsed_message_value.h"
#include "common/values/value_variant.h"
#include "common/values/values.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

// Every member below dispatches on `variant_`. A default-constructed
// `MessageValue` holds `absl::monostate`. In that state each operation does
// one of three things: it CHECK-fails, it returns an `absl::InternalError`
// that names the invoked method, or it returns a fixed placeholder. Otherwise
// the call is forwarded to the held `ParsedMessageValue`.

// Returns the descriptor of the held message. CHECK-fails on an invalid
// (default-constructed) value, which is why the `absl::monostate` branch is
// unreachable.
const google::protobuf::Descriptor* absl_nonnull MessageValue::GetDescriptor() const {
  ABSL_CHECK(*this);  // Crash OK
  return absl::visit(
      absl::Overload(
          [](absl::monostate) -> const google::protobuf::Descriptor* absl_nonnull {
            ABSL_UNREACHABLE();
          },
          [](const ParsedMessageValue& alternative)
              -> const google::protobuf::Descriptor* absl_nonnull {
            return alternative.GetDescriptor();
          }),
      variant_);
}

// Returns "INVALID" for a default-constructed value; otherwise forwards.
std::string MessageValue::DebugString() const {
  return absl::visit(
      absl::Overload([](absl::monostate) -> std::string { return "INVALID"; },
                     [](const ParsedMessageValue& alternative) -> std::string {
                       return alternative.DebugString();
                     }),
      variant_);
}

// DCHECKs validity. In builds without DCHECKs an invalid value reports true.
bool MessageValue::IsZeroValue() const {
  ABSL_DCHECK(*this);
  return absl::visit(
      absl::Overload([](absl::monostate) -> bool { return true; },
                     [](const ParsedMessageValue& alternative) -> bool {
                       return alternative.IsZeroValue();
                     }),
      variant_);
}

// Forwards to `ParsedMessageValue::SerializeTo`. Returns `absl::InternalError`
// on an invalid value.
absl::Status MessageValue::SerializeTo(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const {
  return absl::visit(
      absl::Overload(
          [](absl::monostate) -> absl::Status {
            // The message must name this method so the error points at the
            // call that actually failed.
            return absl::InternalError(
                "unexpected attempt to invoke `SerializeTo` on "
                "an invalid `MessageValue`");
          },
          [&](const ParsedMessageValue& alternative) -> absl::Status {
            return alternative.SerializeTo(descriptor_pool, message_factory,
                                           output);
          }),
      variant_);
}

// Forwards to `ParsedMessageValue::ConvertToJson`. Returns
// `absl::InternalError` on an invalid value.
absl::Status MessageValue::ConvertToJson(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  return absl::visit(
      absl::Overload(
          [](absl::monostate) -> absl::Status {
            return absl::InternalError(
                "unexpected attempt to invoke `ConvertToJson` on "
                "an invalid `MessageValue`");
          },
          [&](const ParsedMessageValue& alternative) -> absl::Status {
            return alternative.ConvertToJson(descriptor_pool, message_factory,
                                             json);
          }),
      variant_);
}

// Forwards to `ParsedMessageValue::ConvertToJsonObject`. Returns
// `absl::InternalError` on an invalid value.
absl::Status MessageValue::ConvertToJsonObject(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  return absl::visit(
      absl::Overload(
          [](absl::monostate) -> absl::Status {
            return absl::InternalError(
                "unexpected attempt to invoke `ConvertToJsonObject` on "
                "an invalid `MessageValue`");
          },
          [&](const ParsedMessageValue& alternative) -> absl::Status {
            return alternative.ConvertToJsonObject(descriptor_pool,
                                                   message_factory, json);
          }),
      variant_);
}

// Forwards to `ParsedMessageValue::Equal`, writing the outcome to `result`.
// Returns `absl::InternalError` on an invalid value.
absl::Status MessageValue::Equal(
    const Value& other,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const {
  return absl::visit(
      absl::Overload(
          [](absl::monostate) -> absl::Status {
            return absl::InternalError(
                "unexpected attempt to invoke `Equal` on "
                "an invalid `MessageValue`");
          },
          [&](const ParsedMessageValue& alternative) -> absl::Status {
            return alternative.Equal(other, descriptor_pool, message_factory,
                                     arena, result);
          }),
      variant_);
}

// Forwards to `ParsedMessageValue::GetFieldByName`. Returns
// `absl::InternalError` on an invalid value.
absl::Status MessageValue::GetFieldByName(
    absl::string_view name, ProtoWrapperTypeOptions unboxing_options,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const {
  return absl::visit(
      absl::Overload(
          [](absl::monostate) -> absl::Status {
            return absl::InternalError(
                "unexpected attempt to invoke `GetFieldByName` on "
                "an invalid `MessageValue`");
          },
          [&](const ParsedMessageValue& alternative) -> absl::Status {
            return alternative.GetFieldByName(name, unboxing_options,
                                              descriptor_pool, message_factory,
                                              arena, result);
          }),
      variant_);
}

// Forwards to `ParsedMessageValue::GetFieldByNumber`. Returns
// `absl::InternalError` on an invalid value.
absl::Status MessageValue::GetFieldByNumber(
    int64_t number, ProtoWrapperTypeOptions unboxing_options,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const {
  return absl::visit(
      absl::Overload(
          [](absl::monostate) -> absl::Status {
            return absl::InternalError(
                "unexpected attempt to invoke `GetFieldByNumber` on "
                "an invalid `MessageValue`");
          },
          [&](const ParsedMessageValue& alternative) -> absl::Status {
            return alternative.GetFieldByNumber(number, unboxing_options,
                                                descriptor_pool,
                                                message_factory, arena, result);
          }),
      variant_);
}

// Forwards to `ParsedMessageValue::HasFieldByName`. Returns
// `absl::InternalError` on an invalid value.
absl::StatusOr MessageValue::HasFieldByName(
    absl::string_view name) const {
  return absl::visit(
      absl::Overload(
          [](absl::monostate) -> absl::StatusOr {
            return absl::InternalError(
                "unexpected attempt to invoke `HasFieldByName` on "
                "an invalid `MessageValue`");
          },
          [&](const ParsedMessageValue& alternative) -> absl::StatusOr {
            return alternative.HasFieldByName(name);
          }),
      variant_);
}

// Forwards to `ParsedMessageValue::HasFieldByNumber`. Returns
// `absl::InternalError` on an invalid value.
absl::StatusOr MessageValue::HasFieldByNumber(int64_t number) const {
  return absl::visit(
      absl::Overload(
          [](absl::monostate) -> absl::StatusOr {
            return absl::InternalError(
                "unexpected attempt to invoke `HasFieldByNumber` on "
                "an invalid `MessageValue`");
          },
          [&](const ParsedMessageValue& alternative) -> absl::StatusOr {
            return alternative.HasFieldByNumber(number);
          }),
      variant_);
}

// Forwards to `ParsedMessageValue::ForEachField`. Returns
// `absl::InternalError` on an invalid value.
absl::Status MessageValue::ForEachField(
    ForEachFieldCallback callback,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) const {
  return absl::visit(
      absl::Overload(
          [](absl::monostate) -> absl::Status {
            return absl::InternalError(
                "unexpected attempt to invoke `ForEachField` on "
                "an invalid `MessageValue`");
          },
          [&](const ParsedMessageValue& alternative) -> absl::Status {
            return alternative.ForEachField(callback, descriptor_pool,
                                            message_factory, arena);
          }),
      variant_);
}

// Forwards to `ParsedMessageValue::Qualify`. Returns `absl::InternalError` on
// an invalid value.
absl::Status MessageValue::Qualify(
    absl::Span qualifiers, bool presence_test,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result,
    int* absl_nonnull count) const {
  return absl::visit(
      absl::Overload(
          [](absl::monostate) -> absl::Status {
            return absl::InternalError(
                "unexpected attempt to invoke `Qualify` on "
                "an invalid `MessageValue`");
          },
          [&](const ParsedMessageValue& alternative) -> absl::Status {
            return alternative.Qualify(qualifiers, presence_test,
                                       descriptor_pool, message_factory, arena,
                                       result, count);
          }),
      variant_);
}

// Returns a reference to the held `ParsedMessageValue`, or `absl::nullopt`
// when the value is invalid.
cel::optional_ref MessageValue::AsParsed() const& {
  if (const auto* alternative = absl::get_if(&variant_);
      alternative != nullptr) {
    return *alternative;
  }
  return absl::nullopt;
}

// Moves the held `ParsedMessageValue` out, or returns `absl::nullopt` when the
// value is invalid.
absl::optional MessageValue::AsParsed() && {
  if (auto* alternative = absl::get_if(&variant_);
      alternative != nullptr) {
    return std::move(*alternative);
  }
  return absl::nullopt;
}

// Unchecked accessors; a non-parsed value is only caught by `ABSL_DCHECK`.
const ParsedMessageValue& MessageValue::GetParsed() const& {
  ABSL_DCHECK(IsParsed());
  return absl::get(variant_);
}

ParsedMessageValue MessageValue::GetParsed() && {
  ABSL_DCHECK(IsParsed());
  return absl::get(std::move(variant_));
}

// Re-wrap the held alternative in the general-purpose variant types.
common_internal::ValueVariant MessageValue::ToValueVariant() const& {
  return common_internal::ValueVariant(absl::get(variant_));
}

common_internal::ValueVariant MessageValue::ToValueVariant() && {
  return common_internal::ValueVariant(
      absl::get(std::move(variant_)));
}

common_internal::StructValueVariant MessageValue::ToStructValueVariant()
    const& {
  return common_internal::StructValueVariant(
      absl::get(variant_));
}

common_internal::StructValueVariant MessageValue::ToStructValueVariant() && {
  return common_internal::StructValueVariant(
      absl::get(std::move(variant_)));
}

}  // namespace cel

================================================
FILE: common/values/message_value.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_MESSAGE_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_MESSAGE_VALUE_H_ #include #include #include #include #include #include "google/protobuf/any.pb.h" #include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/protobuf/wrappers.pb.h" #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "absl/types/variant.h" #include "absl/utility/utility.h" #include "base/attribute.h" #include "common/arena.h" #include "common/optional_ref.h" #include "common/type.h" #include "common/value_kind.h" #include "common/values/custom_struct_value.h" #include "common/values/parsed_message_value.h" #include "common/values/values.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { class Value; class StructValue; class MessageValue final : private common_internal::StructValueMixin { public: static constexpr ValueKind kKind = ValueKind::kStruct; // NOLINTNEXTLINE(google-explicit-constructor) MessageValue(const ParsedMessageValue& other) : variant_(absl::in_place_type, other) {} // NOLINTNEXTLINE(google-explicit-constructor) MessageValue(ParsedMessageValue&& other) : variant_(absl::in_place_type, std::move(other)) {} // Places the `MessageValue` into an unspecified state. Anything except // assigning to `MessageValue` is undefined behavior. 
MessageValue() = default; MessageValue(const MessageValue&) = default; MessageValue(MessageValue&&) = default; MessageValue& operator=(const MessageValue&) = default; MessageValue& operator=(MessageValue&&) = default; static ValueKind kind() { return kKind; } absl::string_view GetTypeName() const { return GetDescriptor()->full_name(); } MessageType GetRuntimeType() const { return MessageType(GetDescriptor()); } const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const; bool IsZeroValue() const; std::string DebugString() const; // See Value::SerializeTo(). absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; // See Value::ConvertToJson(). absl::Status ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; // See Value::ConvertToJsonObject(). 
absl::Status ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; absl::Status Equal(const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using StructValueMixin::Equal; absl::Status GetFieldByName( absl::string_view name, ProtoWrapperTypeOptions unboxing_options, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using StructValueMixin::GetFieldByName; absl::Status GetFieldByNumber( int64_t number, ProtoWrapperTypeOptions unboxing_options, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using StructValueMixin::GetFieldByNumber; absl::StatusOr HasFieldByName(absl::string_view name) const; absl::StatusOr HasFieldByNumber(int64_t number) const; using ForEachFieldCallback = CustomStructValueInterface::ForEachFieldCallback; absl::Status ForEachField( ForEachFieldCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; absl::Status Qualify( absl::Span qualifiers, bool presence_test, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result, int* absl_nonnull count) const; using StructValueMixin::Qualify; bool IsParsed() const { 
return absl::holds_alternative(variant_); } template std::enable_if_t, bool> Is() const { return IsParsed(); } cel::optional_ref AsParsed() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::as_const(*this).AsParsed(); } cel::optional_ref AsParsed() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::optional AsParsed() &&; absl::optional AsParsed() const&& { return common_internal::AsOptional(AsParsed()); } template std::enable_if_t, cel::optional_ref> As() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return AsParsed(); } template std::enable_if_t, cel::optional_ref> As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { return IsParsed(); } template std::enable_if_t, absl::optional> As() && ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::move(*this).AsParsed(); } template std::enable_if_t, absl::optional> As() const&& ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::move(*this).AsParsed(); } const ParsedMessageValue& GetParsed() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::as_const(*this).GetParsed(); } const ParsedMessageValue& GetParsed() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; ParsedMessageValue GetParsed() &&; ParsedMessageValue GetParsed() const&& { return GetParsed(); } template std::enable_if_t, const ParsedMessageValue&> Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetParsed(); } template std::enable_if_t, const ParsedMessageValue&> Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND { return GetParsed(); } template std::enable_if_t, ParsedMessageValue> Get() && { return std::move(*this).GetParsed(); } template std::enable_if_t, ParsedMessageValue> Get() const&& { return std::move(*this).GetParsed(); } explicit operator bool() const { return !absl::holds_alternative(variant_); } friend void swap(MessageValue& lhs, MessageValue& rhs) noexcept { lhs.variant_.swap(rhs.variant_); } private: friend class Value; friend class StructValue; friend class common_internal::ValueMixin; friend class common_internal::StructValueMixin; friend struct ArenaTraits; common_internal::ValueVariant ToValueVariant() const&; 
common_internal::ValueVariant ToValueVariant() &&; common_internal::StructValueVariant ToStructValueVariant() const&; common_internal::StructValueVariant ToStructValueVariant() &&; absl::variant variant_; }; inline std::ostream& operator<<(std::ostream& out, const MessageValue& value) { return out << value.DebugString(); } template <> struct ArenaTraits { static bool trivially_destructible(const MessageValue& value) { return absl::visit( [](const auto& alternative) -> bool { return ArenaTraits<>::trivially_destructible(alternative); }, value.variant_); } }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_MESSAGE_VALUE_H_ ================================================ FILE: common/values/message_value_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "absl/base/attributes.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "base/attribute.h"
#include "common/type.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "common/value_testing.h"
#include "internal/testing.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"

namespace cel {
namespace {

using ::absl_testing::StatusIs;
using ::testing::An;
using ::testing::Optional;

using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes;

using MessageValueTest = common_internal::ValueTest<>;

// A default-constructed MessageValue is invalid. It converts to false, and
// every fallible operation (including the StructValueMixin convenience
// overloads without an out-parameter) reports kInternal.
// NOTE(review): only status codes are asserted here, not error messages, so a
// wrong method name inside a message would go unnoticed.
TEST_F(MessageValueTest, Default) {
  MessageValue value;
  EXPECT_FALSE(value);
  google::protobuf::io::CordOutputStream output;
  EXPECT_THAT(value.SerializeTo(descriptor_pool(), message_factory(), &output),
              StatusIs(absl::StatusCode::kInternal));
  // Out-parameters for the overloads that write results; never inspected.
  Value scratch;
  int count;
  EXPECT_THAT(
      value.Equal(NullValue(), descriptor_pool(), message_factory(), arena()),
      StatusIs(absl::StatusCode::kInternal));
  EXPECT_THAT(value.Equal(NullValue(), descriptor_pool(), message_factory(),
                          arena(), &scratch),
              StatusIs(absl::StatusCode::kInternal));
  EXPECT_THAT(
      value.GetFieldByName("", descriptor_pool(), message_factory(), arena()),
      StatusIs(absl::StatusCode::kInternal));
  EXPECT_THAT(value.GetFieldByName("", descriptor_pool(), message_factory(),
                                   arena(), &scratch),
              StatusIs(absl::StatusCode::kInternal));
  EXPECT_THAT(
      value.GetFieldByNumber(0, descriptor_pool(), message_factory(), arena()),
      StatusIs(absl::StatusCode::kInternal));
  EXPECT_THAT(value.GetFieldByNumber(0, descriptor_pool(), message_factory(),
                                     arena(), &scratch),
              StatusIs(absl::StatusCode::kInternal));
  EXPECT_THAT(value.HasFieldByName(""), StatusIs(absl::StatusCode::kInternal));
  EXPECT_THAT(value.HasFieldByNumber(0),
              StatusIs(absl::StatusCode::kInternal));
  EXPECT_THAT(value.ForEachField([](absl::string_view, const Value&)
                                     -> absl::StatusOr { return true; },
                                 descriptor_pool(), message_factory(), arena()),
              StatusIs(absl::StatusCode::kInternal));
  EXPECT_THAT(value.Qualify({AttributeQualifier::OfString("foo")}, false,
                            descriptor_pool(), message_factory(), arena()),
              StatusIs(absl::StatusCode::kInternal));
  EXPECT_THAT(value.Qualify({AttributeQualifier::OfString("foo")}, false,
                            descriptor_pool(), message_factory(), arena(),
                            &scratch, &count),
              StatusIs(absl::StatusCode::kInternal));
}

// Helpers that re-present `t` with a chosen value category (lvalue, const
// lvalue, rvalue, const rvalue). This lets the tests below select each
// ref-qualified overload of `Get()`.
template constexpr T& AsLValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return t;
}

template constexpr const T& AsConstLValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return t;
}

template constexpr T&& AsRValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return static_cast(t);
}

template constexpr const T&& AsConstRValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return static_cast(t);
}

// A MessageValue built from a ParsedMessageValue is valid and exposes the
// parsed alternative through Is/As/Get across value categories.
// `other_value` is a copy, so the const-rvalue check does not observe `value`
// after the rvalue `Get()` above it.
TEST_F(MessageValueTest, Parsed) {
  MessageValue value(ParsedMessageValue(
      DynamicParseTextProto(R"pb()pb"), arena()));
  MessageValue other_value = value;
  EXPECT_TRUE(value);
  EXPECT_TRUE(value.Is());
  EXPECT_THAT(value.As(),
              Optional(An()));
  EXPECT_THAT(AsLValueRef(value).Get(),
              An());
  EXPECT_THAT(AsConstLValueRef(value).Get(),
              An());
  EXPECT_THAT(AsRValueRef(value).Get(),
              An());
  EXPECT_THAT(
      AsConstRValueRef(other_value).Get(),
      An());
}

// kind() is static, so it reports kStruct even for an invalid value.
TEST_F(MessageValueTest, Kind) {
  MessageValue value;
  EXPECT_EQ(value.kind(), ParsedMessageValue::kKind);
  EXPECT_EQ(value.kind(), ValueKind::kStruct);
}

TEST_F(MessageValueTest, GetTypeName) {
  MessageValue value(ParsedMessageValue(
      DynamicParseTextProto(R"pb()pb"), arena()));
  EXPECT_EQ(value.GetTypeName(), "cel.expr.conformance.proto3.TestAllTypes");
}

TEST_F(MessageValueTest, GetRuntimeType) {
  MessageValue value(ParsedMessageValue(
      DynamicParseTextProto(R"pb()pb"), arena()));
  EXPECT_EQ(value.GetRuntimeType(), MessageType(value.GetDescriptor()));
}

}  // namespace
}  // namespace cel

================================================
FILE: common/values/mutable_list_value_test.cc
================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "common/value.h" #include "common/value_testing.h" #include "common/values/list_value_builder.h" #include "internal/testing.h" namespace cel::common_internal { namespace { using ::absl_testing::IsOk; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::test::ErrorValueIs; using ::cel::test::StringValueIs; using ::testing::IsEmpty; using ::testing::Pair; using ::testing::UnorderedElementsAre; using MutableListValueTest = common_internal::ValueTest<>; TEST_F(MutableListValueTest, DebugString) { auto* mutable_list_value = NewMutableListValue(arena()); EXPECT_THAT(CustomListValue(mutable_list_value, arena()).DebugString(), "[]"); } TEST_F(MutableListValueTest, IsEmpty) { auto* mutable_list_value = NewMutableListValue(arena()); mutable_list_value->Reserve(1); EXPECT_TRUE(CustomListValue(mutable_list_value, arena()).IsEmpty()); EXPECT_THAT(mutable_list_value->Append(StringValue("foo")), IsOk()); EXPECT_FALSE(CustomListValue(mutable_list_value, arena()).IsEmpty()); } TEST_F(MutableListValueTest, Size) { auto* mutable_list_value = NewMutableListValue(arena()); mutable_list_value->Reserve(1); EXPECT_THAT(CustomListValue(mutable_list_value, arena()).Size(), 0); 
EXPECT_THAT(mutable_list_value->Append(StringValue("foo")), IsOk()); EXPECT_THAT(CustomListValue(mutable_list_value, arena()).Size(), 1); } TEST_F(MutableListValueTest, ForEach) { auto* mutable_list_value = NewMutableListValue(arena()); mutable_list_value->Reserve(1); std::vector> elements; auto for_each_callback = [&](size_t index, const Value& value) -> absl::StatusOr { elements.push_back(std::pair{index, value}); return true; }; EXPECT_THAT(CustomListValue(mutable_list_value, arena()) .ForEach(for_each_callback, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(elements, IsEmpty()); EXPECT_THAT(mutable_list_value->Append(StringValue("foo")), IsOk()); EXPECT_THAT(CustomListValue(mutable_list_value, arena()) .ForEach(for_each_callback, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(elements, UnorderedElementsAre(Pair(0, StringValueIs("foo")))); } TEST_F(MutableListValueTest, NewIterator) { auto* mutable_list_value = NewMutableListValue(arena()); mutable_list_value->Reserve(1); ASSERT_OK_AND_ASSIGN( auto iterator, CustomListValue(mutable_list_value, arena()).NewIterator()); EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), StatusIs(absl::StatusCode::kFailedPrecondition)); EXPECT_THAT(mutable_list_value->Append(StringValue("foo")), IsOk()); ASSERT_OK_AND_ASSIGN( iterator, CustomListValue(mutable_list_value, arena()).NewIterator()); EXPECT_TRUE(iterator->HasNext()); EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), IsOkAndHolds(StringValueIs("foo"))); EXPECT_FALSE(iterator->HasNext()); EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), StatusIs(absl::StatusCode::kFailedPrecondition)); } TEST_F(MutableListValueTest, Get) { auto* mutable_list_value = NewMutableListValue(arena()); mutable_list_value->Reserve(1); Value value; EXPECT_THAT( CustomListValue(mutable_list_value, arena()) .Get(0, descriptor_pool(), message_factory(), arena(), &value), IsOk()); 
EXPECT_THAT(value, ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))); EXPECT_THAT(mutable_list_value->Append(StringValue("foo")), IsOk()); EXPECT_THAT( CustomListValue(mutable_list_value, arena()) .Get(0, descriptor_pool(), message_factory(), arena(), &value), IsOk()); EXPECT_THAT(value, StringValueIs("foo")); } TEST_F(MutableListValueTest, IsMutablListValue) { auto* mutable_list_value = NewMutableListValue(arena()); EXPECT_TRUE( IsMutableListValue(Value(CustomListValue(mutable_list_value, arena())))); EXPECT_TRUE(IsMutableListValue( ListValue(CustomListValue(mutable_list_value, arena())))); } TEST_F(MutableListValueTest, AsMutableListValue) { auto* mutable_list_value = NewMutableListValue(arena()); EXPECT_EQ( AsMutableListValue(Value(CustomListValue(mutable_list_value, arena()))), mutable_list_value); EXPECT_EQ(AsMutableListValue( ListValue(CustomListValue(mutable_list_value, arena()))), mutable_list_value); } TEST_F(MutableListValueTest, GetMutableListValue) { auto* mutable_list_value = NewMutableListValue(arena()); EXPECT_EQ( &GetMutableListValue(Value(CustomListValue(mutable_list_value, arena()))), mutable_list_value); EXPECT_EQ(&GetMutableListValue( ListValue(CustomListValue(mutable_list_value, arena()))), mutable_list_value); } } // namespace } // namespace cel::common_internal ================================================ FILE: common/values/mutable_map_value_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. /* NOTE(review): the two bare #include directives below have lost their header names in this dump (angle-bracket contents appear stripped, as do template arguments elsewhere in this file); restore them from the upstream file. */ #include #include #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "common/value.h" #include "common/value_testing.h" #include "common/values/map_value_builder.h" #include "internal/testing.h" namespace cel::common_internal { namespace { using ::absl_testing::IsOk; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::test::BoolValueIs; using ::cel::test::IntValueIs; using ::cel::test::IsNullValue; using ::cel::test::ListValueElements; using ::cel::test::ListValueIs; using ::cel::test::StringValueIs; using ::testing::IsEmpty; using ::testing::IsFalse; using ::testing::IsTrue; using ::testing::Pair; using ::testing::UnorderedElementsAre; using MutableMapValueTest = common_internal::ValueTest<>; /* DebugString: an empty mutable map renders as "{}". */ TEST_F(MutableMapValueTest, DebugString) { auto mutable_map_value = NewMutableMapValue(arena()); EXPECT_THAT(CustomMapValue(mutable_map_value, arena()).DebugString(), "{}"); } /* IsEmpty: true before any Put, false after one. */ TEST_F(MutableMapValueTest, IsEmpty) { auto mutable_map_value = NewMutableMapValue(arena()); mutable_map_value->Reserve(1); EXPECT_TRUE(CustomMapValue(mutable_map_value, arena()).IsEmpty()); EXPECT_THAT(mutable_map_value->Put(StringValue("foo"), IntValue(1)), IsOk()); EXPECT_FALSE(CustomMapValue(mutable_map_value, arena()).IsEmpty()); } /* Size: 0 before any Put, 1 after one. */ TEST_F(MutableMapValueTest, Size) { auto mutable_map_value = NewMutableMapValue(arena()); mutable_map_value->Reserve(1); EXPECT_THAT(CustomMapValue(mutable_map_value, arena()).Size(), 0); EXPECT_THAT(mutable_map_value->Put(StringValue("foo"), IntValue(1)), IsOk()); EXPECT_THAT(CustomMapValue(mutable_map_value, arena()).Size(), 1); } /* ListKeys: after one Put, ListKeys stores into `keys` a list whose only element is "foo". */ TEST_F(MutableMapValueTest, ListKeys) { auto mutable_map_value = NewMutableMapValue(arena()); mutable_map_value->Reserve(1); ListValue keys; EXPECT_THAT(mutable_map_value->Put(StringValue("foo"), IntValue(1)), IsOk()); EXPECT_THAT( 
CustomMapValue(mutable_map_value, arena()) .ListKeys(descriptor_pool(), message_factory(), arena(), &keys), IsOk()); EXPECT_THAT(keys, ListValueIs(ListValueElements( UnorderedElementsAre(StringValueIs("foo")), descriptor_pool(), message_factory(), arena()))); } /* ForEach: the callback observes no entries while the map is empty, then exactly ("foo", 1) after one Put. NOTE(review): the element type of `entries` and the callback's StatusOr argument appear stripped in this dump. */ TEST_F(MutableMapValueTest, ForEach) { auto mutable_map_value = NewMutableMapValue(arena()); mutable_map_value->Reserve(1); std::vector> entries; auto for_each_callback = [&](const Value& key, const Value& value) -> absl::StatusOr { entries.push_back(std::pair{key, value}); return true; }; EXPECT_THAT(CustomMapValue(mutable_map_value, arena()) .ForEach(for_each_callback, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(entries, IsEmpty()); EXPECT_THAT(mutable_map_value->Put(StringValue("foo"), IntValue(1)), IsOk()); EXPECT_THAT(CustomMapValue(mutable_map_value, arena()) .ForEach(for_each_callback, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(entries, UnorderedElementsAre(Pair(StringValueIs("foo"), IntValueIs(1)))); } /* NewIterator: map iterators yield keys. An empty map's iterator has no next element and Next() fails with kFailedPrecondition; after one Put a fresh iterator yields "foo" once and then fails the same way. */ TEST_F(MutableMapValueTest, NewIterator) { auto mutable_map_value = NewMutableMapValue(arena()); mutable_map_value->Reserve(1); ASSERT_OK_AND_ASSIGN( auto iterator, CustomMapValue(mutable_map_value, arena()).NewIterator()); EXPECT_FALSE(iterator->HasNext()); EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), StatusIs(absl::StatusCode::kFailedPrecondition)); EXPECT_THAT(mutable_map_value->Put(StringValue("foo"), IntValue(1)), IsOk()); ASSERT_OK_AND_ASSIGN( iterator, CustomMapValue(mutable_map_value, arena()).NewIterator()); EXPECT_TRUE(iterator->HasNext()); EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), IsOkAndHolds(StringValueIs("foo"))); EXPECT_FALSE(iterator->HasNext()); EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), StatusIs(absl::StatusCode::kFailedPrecondition)); } /* FindHas: for a missing key, Find reports absence (false, with `value` set to null) and Has stores false; after one Put, Find reports presence with value 1 and Has stores true. */ TEST_F(MutableMapValueTest, FindHas) { auto* mutable_map_value = NewMutableMapValue(arena()); 
mutable_map_value->Reserve(1); Value value; EXPECT_THAT(CustomMapValue(mutable_map_value, arena()) .Find(StringValue("foo"), descriptor_pool(), message_factory(), arena(), &value), IsOkAndHolds(IsFalse())); EXPECT_THAT(value, IsNullValue()); EXPECT_THAT(CustomMapValue(mutable_map_value, arena()) .Has(StringValue("foo"), descriptor_pool(), message_factory(), arena(), &value), IsOk()); EXPECT_THAT(value, BoolValueIs(false)); EXPECT_THAT(mutable_map_value->Put(StringValue("foo"), IntValue(1)), IsOk()); EXPECT_THAT(CustomMapValue(mutable_map_value, arena()) .Find(StringValue("foo"), descriptor_pool(), message_factory(), arena(), &value), IsOkAndHolds(IsTrue())); EXPECT_THAT(value, IntValueIs(1)); EXPECT_THAT(CustomMapValue(mutable_map_value, arena()) .Has(StringValue("foo"), descriptor_pool(), message_factory(), arena(), &value), IsOk()); EXPECT_THAT(value, BoolValueIs(true)); } /* IsMutableMapValue: recognized through both the Value and the MapValue wrappers. */ TEST_F(MutableMapValueTest, IsMutableMapValue) { auto* mutable_map_value = NewMutableMapValue(arena()); EXPECT_TRUE( IsMutableMapValue(Value(CustomMapValue(mutable_map_value, arena())))); EXPECT_TRUE( IsMutableMapValue(MapValue(CustomMapValue(mutable_map_value, arena())))); } /* AsMutableMapValue: extraction through either wrapper returns the original mutable map pointer. */ TEST_F(MutableMapValueTest, AsMutableMapValue) { auto* mutable_map_value = NewMutableMapValue(arena()); EXPECT_EQ( AsMutableMapValue(Value(CustomMapValue(mutable_map_value, arena()))), mutable_map_value); EXPECT_EQ( AsMutableMapValue(MapValue(CustomMapValue(mutable_map_value, arena()))), mutable_map_value); } /* GetMutableMapValue: returns a reference to the same mutable map through either wrapper. */ TEST_F(MutableMapValueTest, GetMutableMapValue) { auto* mutable_map_value = NewMutableMapValue(arena()); EXPECT_EQ( &GetMutableMapValue(Value(CustomMapValue(mutable_map_value, arena()))), mutable_map_value); EXPECT_EQ( &GetMutableMapValue(MapValue(CustomMapValue(mutable_map_value, arena()))), mutable_map_value); } } // namespace } // namespace cel::common_internal ================================================ FILE: common/values/null_value.cc ================================================ // Copyright 2023 Google 
LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "google/protobuf/struct.pb.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "common/value.h" #include "internal/status_macros.h" #include "internal/well_known_types.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { using ::cel::well_known_types::ValueReflection; /* Serializes null as a google.protobuf.Value whose null_value field is set; a failed stream write is reported as UnknownError. */ absl::Status NullValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); google::protobuf::Value message; message.set_null_value(google::protobuf::NULL_VALUE); if (!message.SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( "failed to serialize message: google.protobuf.Value"); } return absl::OkStatus(); } /* Stores null into `json`, which is expected to be a google.protobuf.Value (checked only by a debug assertion); errors from initializing the reflection helper are propagated. */ absl::Status NullValue::ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); 
ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); ValueReflection value_reflection; CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); value_reflection.SetNullValue(json); return absl::OkStatus(); } /* Null equals only null: `result` becomes BoolValue(other.IsNull()) and the returned status is always OK. */ absl::Status NullValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); *result = BoolValue(other.IsNull()); return absl::OkStatus(); } } // namespace cel ================================================ FILE: common/values/null_value.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_NULL_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_NULL_VALUE_H_ /* NOTE(review): the two bare #include directives below, and the template argument of common_internal::ValueMixin further down, appear stripped in this dump; restore them from the upstream file. */ #include #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "common/type.h" #include "common/value_kind.h" #include "common/values/values.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { class Value; class NullValue; // `NullValue` represents the CEL `null` value. class NullValue final : private common_internal::ValueMixin { public: static constexpr ValueKind kKind = ValueKind::kNull; NullValue() = default; NullValue(const NullValue&) = default; NullValue(NullValue&&) = default; NullValue& operator=(const NullValue&) = default; NullValue& operator=(NullValue&&) = default; /* Kind (kNull), type name (NullType::kName) and debug form ("null") are the same for every instance. */ constexpr ValueKind kind() const { return kKind; } absl::string_view GetTypeName() const { return NullType::kName; } std::string DebugString() const { return "null"; } // See Value::SerializeTo(). absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; // See Value::ConvertToJson(). 
absl::Status ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; absl::Status Equal(const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using ValueMixin::Equal; /* null is always its own zero value. */ bool IsZeroValue() const { return true; } /* The type holds no state, so swapping is a no-op. */ friend void swap(NullValue&, NullValue&) noexcept {} private: friend class common_internal::ValueMixin; }; /* All NullValue instances compare equal. */ inline bool operator==(NullValue, NullValue) { return true; } inline bool operator!=(NullValue lhs, NullValue rhs) { return !operator==(lhs, rhs); } /* Streams the debug form, i.e. "null". */ inline std::ostream& operator<<(std::ostream& out, const NullValue& value) { return out << value.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_NULL_VALUE_H_ ================================================ FILE: common/values/null_value_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
/* NOTE(review): the bare #include below (presumably the header for std::ostringstream) and the template arguments in NativeTypeId::For(), InstanceOf(...), Cast(...), An() and As(...) appear stripped in this dump; restore them from the upstream file. */ #include #include "absl/status/status_matchers.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/casting.h" #include "common/native_type.h" #include "common/value.h" #include "common/value_testing.h" #include "internal/testing.h" namespace cel { namespace { using ::absl_testing::IsOk; using ::testing::An; using ::testing::Ne; using NullValueTest = common_internal::ValueTest<>; /* Kind: kind() reports NullValue::kKind both directly and through Value. */ TEST_F(NullValueTest, Kind) { EXPECT_EQ(NullValue().kind(), NullValue::kKind); EXPECT_EQ(Value(NullValue()).kind(), NullValue::kKind); } /* DebugString: operator<< prints "null" for both NullValue and Value. */ TEST_F(NullValueTest, DebugString) { { std::ostringstream out; out << NullValue(); EXPECT_EQ(out.str(), "null"); } { std::ostringstream out; out << Value(NullValue()); EXPECT_EQ(out.str(), "null"); } } /* ConvertToJson: conversion succeeds and yields a google.protobuf.Value with null_value set. */ TEST_F(NullValueTest, ConvertToJson) { auto* message = NewArenaValueMessage(); EXPECT_THAT( NullValue().ConvertToJson(descriptor_pool(), message_factory(), message), IsOk()); EXPECT_THAT(*message, EqualsValueTextProto(R"pb(null_value: NULL_VALUE)pb")); } /* The remaining tests check that the type-identification helpers recognize NullValue both directly and through Value. */ TEST_F(NullValueTest, NativeTypeId) { EXPECT_EQ(NativeTypeId::Of(NullValue()), NativeTypeId::For()); EXPECT_EQ(NativeTypeId::Of(Value(NullValue())), NativeTypeId::For()); } TEST_F(NullValueTest, InstanceOf) { EXPECT_TRUE(InstanceOf(NullValue())); EXPECT_TRUE(InstanceOf(Value(NullValue()))); } TEST_F(NullValueTest, Cast) { EXPECT_THAT(Cast(NullValue()), An()); EXPECT_THAT(Cast(Value(NullValue())), An()); } TEST_F(NullValueTest, As) { EXPECT_THAT(As(Value(NullValue())), Ne(absl::nullopt)); } } // namespace } // namespace cel ================================================ FILE: common/values/opaque_value.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "absl/base/nullability.h" #include "absl/base/optimization.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/native_type.h" #include "common/optional_ref.h" #include "common/type.h" #include "common/value.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { // Code below assumes OptionalValue has the same layout as OpaqueValue. 
/* Layout guards for the reinterpret_casts between OpaqueValue and OptionalValue performed by AsOptional() and GetOptional() in this file. NOTE(review): the template arguments of std::is_base_of_v and of the content_.To() calls below appear stripped in this dump. */ static_assert(std::is_base_of_v); static_assert(sizeof(OpaqueValue) == sizeof(OptionalValue)); static_assert(alignof(OpaqueValue) == alignof(OptionalValue)); /* Returns a value whose data lives on `arena`. An interface-backed value is cloned through its interface only when it was created on a different arena; a dispatcher-backed value is cloned only when the dispatcher's get_arena reports a different arena. Otherwise, or when the interface is null (only reachable with the DCHECK compiled out), `*this` is returned unchanged. */ OpaqueValue OpaqueValue::Clone(google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(*this); if (ABSL_PREDICT_FALSE(dispatcher_ == nullptr)) { OpaqueValueInterface::Content content = content_.To(); if (content.interface == nullptr) { return *this; } if (content.arena != arena) { return content.interface->Clone(arena); } return *this; } if (dispatcher_->get_arena(dispatcher_, content_) != arena) { return dispatcher_->clone(dispatcher_, content_, arena); } return *this; } /* The accessors below forward to the OpaqueValueInterface when no dispatcher is set, and to the corresponding dispatch-table entry otherwise. */ OpaqueType OpaqueValue::GetRuntimeType() const { ABSL_DCHECK(*this); if (ABSL_PREDICT_FALSE(dispatcher_ == nullptr)) { OpaqueValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); return content.interface->GetRuntimeType(); } return dispatcher_->get_runtime_type(dispatcher_, content_); } absl::string_view OpaqueValue::GetTypeName() const { ABSL_DCHECK(*this); if (ABSL_PREDICT_FALSE(dispatcher_ == nullptr)) { OpaqueValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); return content.interface->GetTypeName(); } return dispatcher_->get_type_name(dispatcher_, content_); } std::string OpaqueValue::DebugString() const { ABSL_DCHECK(*this); if (ABSL_PREDICT_FALSE(dispatcher_ == nullptr)) { OpaqueValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); return content.interface->DebugString(); } return dispatcher_->debug_string(dispatcher_, content_); } // See Value::SerializeTo(). 
absl::Status OpaqueValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); return absl::FailedPreconditionError( absl::StrCat(GetTypeName(), "is unserializable")); } // See Value::ConvertToJson(). absl::Status OpaqueValue::ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); return absl::FailedPreconditionError( absl::StrCat(GetTypeName(), " is not convertable to JSON")); } absl::Status OpaqueValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (auto other_opaque = other.AsOpaque(); other_opaque) { if (ABSL_PREDICT_FALSE(dispatcher_ == nullptr)) { OpaqueValueInterface::Content content = content_.To(); ABSL_DCHECK(content.interface != nullptr); return content.interface->Equal(*other_opaque, descriptor_pool, message_factory, arena, result); } return dispatcher_->equal(dispatcher_, content_, *other_opaque, descriptor_pool, message_factory, arena, result); } *result = FalseValue(); return absl::OkStatus(); } NativeTypeId OpaqueValue::GetTypeId() const { ABSL_DCHECK(*this); if (ABSL_PREDICT_FALSE(dispatcher_ == nullptr)) { OpaqueValueInterface::Content content = content_.To(); if 
(content.interface == nullptr) { return NativeTypeId(); } return content.interface->GetNativeTypeId(); } return dispatcher_->get_type_id(dispatcher_, content_); } bool OpaqueValue::IsOptional() const { return dispatcher_ != nullptr && dispatcher_->get_type_id(dispatcher_, content_) == NativeTypeId::For(); } optional_ref OpaqueValue::AsOptional() const& { if (IsOptional()) { return *reinterpret_cast(this); } return absl::nullopt; } absl::optional OpaqueValue::AsOptional() && { if (IsOptional()) { return std::move(*reinterpret_cast(this)); } return absl::nullopt; } const OptionalValue& OpaqueValue::GetOptional() const& { ABSL_DCHECK(IsOptional()) << *this; return *reinterpret_cast(this); } OptionalValue OpaqueValue::GetOptional() && { ABSL_DCHECK(IsOptional()) << *this; return std::move(*reinterpret_cast(this)); } } // namespace cel ================================================ FILE: common/values/opaque_value.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" // IWYU pragma: friend "common/values/optional_value.h" // `OpaqueValue` represents values of the `opaque` type. `OpaqueValueView` // is a non-owning view of `OpaqueValue`. `OpaqueValueInterface` is the abstract // base class of implementations. `OpaqueValue` and `OpaqueValueView` act as // smart pointers to `OpaqueValueInterface`. 
#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_OPAQUE_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_OPAQUE_VALUE_H_ /* NOTE(review): the four bare #include directives below appear to have lost their header names in this dump; restore them from the upstream file. */ #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/native_type.h" #include "common/optional_ref.h" #include "common/type.h" #include "common/value_kind.h" #include "common/values/custom_value.h" #include "common/values/values.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { class Value; class OpaqueValueInterface; class OpaqueValueInterfaceIterator; /* NOTE(review): OpaqueValueInterfaceIterator is not referenced elsewhere in this header; confirm whether the forward declaration is still needed. */ class OpaqueValue; using OpaqueValueContent = CustomValueContent; /* Manual dispatch table for opaque values that do not implement OpaqueValueInterface. Every hook receives the table itself plus the opaque content. The OpaqueValue dispatcher constructor DCHECKs only get_type_id, get_type_name and clone. */ struct OpaqueValueDispatcher { using GetTypeId = NativeTypeId (*)(const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content); using GetArena = google::protobuf::Arena* absl_nullable (*)( const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content); using GetTypeName = absl::string_view (*)( const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content); using DebugString = std::string (*)(const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content); using GetRuntimeType = OpaqueType (*)(const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content); using Equal = absl::Status (*)( const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content, const OpaqueValue& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result); using Clone = OpaqueValue (*)( const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content, 
google::protobuf::Arena* absl_nonnull arena); absl_nonnull GetTypeId get_type_id; absl_nonnull GetArena get_arena; absl_nonnull GetTypeName get_type_name; absl_nonnull DebugString debug_string; absl_nonnull GetRuntimeType get_runtime_type; absl_nonnull Equal equal; absl_nonnull Clone clone; }; /* Abstract base for interface-backed opaque values. Instances are neither copyable nor movable, and the hooks are private so that only the friend OpaqueValue invokes them. */ class OpaqueValueInterface { public: OpaqueValueInterface() = default; OpaqueValueInterface(const OpaqueValueInterface&) = delete; OpaqueValueInterface(OpaqueValueInterface&&) = delete; virtual ~OpaqueValueInterface() = default; OpaqueValueInterface& operator=(const OpaqueValueInterface&) = delete; OpaqueValueInterface& operator=(OpaqueValueInterface&&) = delete; private: friend class OpaqueValue; virtual std::string DebugString() const = 0; virtual absl::string_view GetTypeName() const = 0; virtual OpaqueType GetRuntimeType() const = 0; virtual absl::Status Equal( const OpaqueValue& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const = 0; virtual OpaqueValue Clone(google::protobuf::Arena* absl_nonnull arena) const = 0; virtual NativeTypeId GetNativeTypeId() const = 0; /* The (interface, arena) pair that OpaqueValue packs into OpaqueValueContent for interface-backed values. */ struct Content { const OpaqueValueInterface* absl_nonnull interface; google::protobuf::Arena* absl_nonnull arena; }; }; // Creates an opaque value from a manual dispatch table `dispatcher` and // opaque data `content` whose format is only known to functions in the manual // dispatch table. The dispatch table should probably be valid for the lifetime // of the process, but at a minimum must outlive all instances of the resulting // value. // // IMPORTANT: This approach to implementing OpaqueValue should only be // used when you know exactly what you are doing. When in doubt, just implement // OpaqueValueInterface. 
OpaqueValue UnsafeOpaqueValue(const OpaqueValueDispatcher* absl_nonnull dispatcher ABSL_ATTRIBUTE_LIFETIME_BOUND, OpaqueValueContent content); /* An OpaqueValue is either interface-backed (dispatcher_ is null and content_ packs an OpaqueValueInterface::Content) or dispatcher-backed (dispatcher_ points at the dispatch table). NOTE(review): the template argument of OpaqueValueMixin and of several helpers below appear stripped in this dump. */ class OpaqueValue : private common_internal::OpaqueValueMixin { public: static constexpr ValueKind kKind = ValueKind::kOpaque; // Constructs an opaque value from an implementation of // `OpaqueValueInterface` `interface` whose lifetime is tied to that of // the arena `arena`. OpaqueValue(const OpaqueValueInterface* absl_nonnull interface ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) { ABSL_DCHECK(interface != nullptr); ABSL_DCHECK(arena != nullptr); content_ = OpaqueValueContent::From( OpaqueValueInterface::Content{.interface = interface, .arena = arena}); } OpaqueValue() = default; OpaqueValue(const OpaqueValue&) = default; OpaqueValue(OpaqueValue&&) = default; OpaqueValue& operator=(const OpaqueValue&) = default; OpaqueValue& operator=(OpaqueValue&&) = default; static constexpr ValueKind kind() { return kKind; } NativeTypeId GetTypeId() const; OpaqueType GetRuntimeType() const; absl::string_view GetTypeName() const; std::string DebugString() const; // See Value::SerializeTo(). absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; // See Value::ConvertToJson(). 
absl::Status ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; absl::Status Equal(const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using OpaqueValueMixin::Equal; /* Opaque values never count as zero values. */ bool IsZeroValue() const { return false; } OpaqueValue Clone(google::protobuf::Arena* absl_nonnull arena) const; // Returns `true` if this opaque value is an instance of an optional value. bool IsOptional() const; // Convenience method for use with template metaprogramming. See // `IsOptional()`. template std::enable_if_t, bool> Is() const { return IsOptional(); } // Performs a checked cast from an opaque value to an optional value, // returning a non-empty optional with either a value or reference to the // optional value. Otherwise an empty optional is returned. optional_ref AsOptional() & ABSL_ATTRIBUTE_LIFETIME_BOUND; optional_ref AsOptional() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; absl::optional AsOptional() &&; absl::optional AsOptional() const&&; // Convenience method for use with template metaprogramming. See // `AsOptional()`. template std::enable_if_t, optional_ref> As() & ABSL_ATTRIBUTE_LIFETIME_BOUND; template std::enable_if_t, optional_ref> As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; template std::enable_if_t, absl::optional> As() &&; template std::enable_if_t, absl::optional> As() const&&; // Performs an unchecked cast from an opaque value to an optional value. In // debug builds a best effort is made to crash. If `IsOptional()` would return // false, calling this method is undefined behavior. 
const OptionalValue& GetOptional() & ABSL_ATTRIBUTE_LIFETIME_BOUND; const OptionalValue& GetOptional() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; OptionalValue GetOptional() &&; OptionalValue GetOptional() const&&; // Convenience method for use with template metaprogramming. See // `GetOptional()`. template std::enable_if_t, const OptionalValue&> Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND; template std::enable_if_t, const OptionalValue&> Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND; template std::enable_if_t, OptionalValue> Get() &&; template std::enable_if_t, OptionalValue> Get() const&&; /* Null for interface-backed values. */ const OpaqueValueDispatcher* absl_nullable dispatcher() const { return dispatcher_; } /* Only valid for dispatcher-backed values (checked by a debug assertion). */ OpaqueValueContent content() const { ABSL_DCHECK(dispatcher_ != nullptr); return content_; } /* Null for dispatcher-backed values. */ const OpaqueValueInterface* absl_nullable interface() const { if (dispatcher_ == nullptr) { return content_.To().interface; } return nullptr; } friend void swap(OpaqueValue& lhs, OpaqueValue& rhs) noexcept { using std::swap; swap(lhs.dispatcher_, rhs.dispatcher_); swap(lhs.content_, rhs.content_); } /* False only for an interface-backed value whose interface is null; presumably this is the default-constructed state, depending on what OpaqueValueContent::Zero() decodes to. */ explicit operator bool() const { if (dispatcher_ == nullptr) { return content_.To().interface != nullptr; } return true; } protected: /* Dispatcher-backed constructor used by UnsafeOpaqueValue; only some dispatch entries are checked, and only in debug builds. */ OpaqueValue(const OpaqueValueDispatcher* absl_nonnull dispatcher ABSL_ATTRIBUTE_LIFETIME_BOUND, OpaqueValueContent content) : dispatcher_(dispatcher), content_(content) { ABSL_DCHECK(dispatcher != nullptr); ABSL_DCHECK(dispatcher->get_type_id != nullptr); ABSL_DCHECK(dispatcher->get_type_name != nullptr); ABSL_DCHECK(dispatcher->clone != nullptr); } private: friend class common_internal::ValueMixin; friend class common_internal::OpaqueValueMixin; friend OpaqueValue UnsafeOpaqueValue(const OpaqueValueDispatcher* absl_nonnull dispatcher ABSL_ATTRIBUTE_LIFETIME_BOUND, OpaqueValueContent content); const OpaqueValueDispatcher* absl_nullable dispatcher_ = nullptr; OpaqueValueContent content_ = OpaqueValueContent::Zero(); }; inline std::ostream& operator<<(std::ostream& out, const OpaqueValue& type) 
{ return out << type.DebugString(); } /* NativeTypeId lookup delegates to OpaqueValue::GetTypeId(). */ template <> struct NativeTypeTraits final { static NativeTypeId Id(const OpaqueValue& type) { return type.GetTypeId(); } }; /* Thin wrapper over the protected dispatcher-backed constructor. */ inline OpaqueValue UnsafeOpaqueValue(const OpaqueValueDispatcher* absl_nonnull dispatcher ABSL_ATTRIBUTE_LIFETIME_BOUND, OpaqueValueContent content) { return OpaqueValue(dispatcher, content); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_OPAQUE_VALUE_H_ ================================================ FILE: common/values/optional_value.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include #include "absl/base/attributes.h" #include "absl/base/casts.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "common/arena.h" #include "common/native_type.h" #include "common/type.h" #include "common/value.h" #include "common/value_kind.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel { namespace { struct OptionalValueDispatcher : public OpaqueValueDispatcher { using HasValue = bool (*)(const OptionalValueDispatcher* absl_nonnull dispatcher, CustomValueContent content); using Value = void (*)(const OptionalValueDispatcher* absl_nonnull dispatcher, CustomValueContent content, cel::Value* absl_nonnull result); absl_nonnull HasValue has_value; absl_nonnull Value value; }; NativeTypeId OptionalValueGetTypeId(const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) { return NativeTypeId::For(); } absl::string_view OptionalValueGetTypeName( const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) { return "optional_type"; } OpaqueType OptionalValueGetRuntimeType( const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) { return OptionalType(); } std::string OptionalValueDebugString( const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content) { if (!static_cast(dispatcher) ->has_value(static_cast(dispatcher), content)) { return "optional.none()"; } Value value; static_cast(dispatcher) ->value(static_cast(dispatcher), content, &value); return absl::StrCat("optional.of(", value.DebugString(), ")"); } bool OptionalValueHasValue(const OptionalValueDispatcher* absl_nonnull, OpaqueValueContent) { return true; } absl::Status OptionalValueEqual( const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content, const OpaqueValue& other, const 
google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (auto other_optional = other.AsOptional(); other_optional) { const bool lhs_has_value = static_cast(dispatcher) ->has_value(static_cast(dispatcher), content); const bool rhs_has_value = other_optional->HasValue(); if (lhs_has_value != rhs_has_value) { *result = FalseValue(); return absl::OkStatus(); } if (!lhs_has_value) { *result = TrueValue(); return absl::OkStatus(); } Value lhs_value; Value rhs_value; static_cast(dispatcher) ->value(static_cast(dispatcher), content, &lhs_value); other_optional->Value(&rhs_value); return lhs_value.Equal(rhs_value, descriptor_pool, message_factory, arena, result); } *result = FalseValue(); return absl::OkStatus(); } ABSL_CONST_INIT const OptionalValueDispatcher empty_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { return nullptr; }, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { return common_internal::MakeOptionalValue(dispatcher, content); }, }, [](const OptionalValueDispatcher* absl_nonnull dispatcher, CustomValueContent content) -> bool { return false; }, [](const OptionalValueDispatcher* absl_nonnull dispatcher, CustomValueContent content, cel::Value* absl_nonnull result) -> void { *result = ErrorValue( absl::FailedPreconditionError("optional.none() dereference")); 
}, }; ABSL_CONST_INIT const OptionalValueDispatcher null_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { return nullptr; }, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { return common_internal::MakeOptionalValue(dispatcher, content); }, }, &OptionalValueHasValue, [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent, cel::Value* absl_nonnull result) -> void { *result = NullValue(); }, }; ABSL_CONST_INIT const OptionalValueDispatcher bool_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { return nullptr; }, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { return common_internal::MakeOptionalValue(dispatcher, content); }, }, &OptionalValueHasValue, [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, cel::Value* absl_nonnull result) -> void { *result = BoolValue(content.To()); }, }; ABSL_CONST_INIT const OptionalValueDispatcher int_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { return nullptr; }, .get_type_name = &OptionalValueGetTypeName, .debug_string = 
&OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { return common_internal::MakeOptionalValue(dispatcher, content); }, }, &OptionalValueHasValue, [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, cel::Value* absl_nonnull result) -> void { *result = IntValue(content.To()); }, }; ABSL_CONST_INIT const OptionalValueDispatcher uint_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { return nullptr; }, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { return common_internal::MakeOptionalValue(dispatcher, content); }, }, &OptionalValueHasValue, [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, cel::Value* absl_nonnull result) -> void { *result = UintValue(content.To()); }, }; ABSL_CONST_INIT const OptionalValueDispatcher double_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { return nullptr; }, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { return 
common_internal::MakeOptionalValue(dispatcher, content); }, }, &OptionalValueHasValue, [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, cel::Value* absl_nonnull result) -> void { *result = DoubleValue(content.To()); }, }; ABSL_CONST_INIT const OptionalValueDispatcher duration_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { return nullptr; }, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { return common_internal::MakeOptionalValue(dispatcher, content); }, }, &OptionalValueHasValue, [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, cel::Value* absl_nonnull result) -> void { *result = UnsafeDurationValue(content.To()); }, }; ABSL_CONST_INIT const OptionalValueDispatcher timestamp_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { return nullptr; }, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { return common_internal::MakeOptionalValue(dispatcher, content); }, }, &OptionalValueHasValue, [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, cel::Value* absl_nonnull result) -> void { *result = UnsafeTimestampValue(content.To()); }, }; struct OptionalValueContent { const 
Value* absl_nonnull value; google::protobuf::Arena* absl_nonnull arena; }; ABSL_CONST_INIT const OptionalValueDispatcher optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent content) -> google::protobuf::Arena* absl_nullable { return content.To().arena; }, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { ABSL_DCHECK(arena != nullptr); cel::Value* absl_nonnull result = ::new ( arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) cel::Value( content.To().value->Clone(arena)); if (!ArenaTraits<>::trivially_destructible(result)) { arena->OwnDestructor(result); } return common_internal::MakeOptionalValue( &optional_value_dispatcher, OpaqueValueContent::From( OptionalValueContent{.value = result, .arena = arena})); }, }, &OptionalValueHasValue, [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, cel::Value* absl_nonnull result) -> void { *result = *content.To().value; }, }; } // namespace OptionalValue OptionalValue::Of(cel::Value value, google::protobuf::Arena* absl_nonnull arena) { ABSL_DCHECK(value.kind() != ValueKind::kError && value.kind() != ValueKind::kUnknown); ABSL_DCHECK(arena != nullptr); // We can actually fit a lot more of the underlying values, avoiding arena // allocations and destructors. For now, we just do scalars. 
switch (value.kind()) { case ValueKind::kNull: return OptionalValue(&null_optional_value_dispatcher, OpaqueValueContent::Zero()); case ValueKind::kBool: return OptionalValue( &bool_optional_value_dispatcher, OpaqueValueContent::From(absl::implicit_cast(value.GetBool()))); case ValueKind::kInt: return OptionalValue(&int_optional_value_dispatcher, OpaqueValueContent::From( absl::implicit_cast(value.GetInt()))); case ValueKind::kUint: return OptionalValue(&uint_optional_value_dispatcher, OpaqueValueContent::From( absl::implicit_cast(value.GetUint()))); case ValueKind::kDouble: return OptionalValue(&double_optional_value_dispatcher, OpaqueValueContent::From( absl::implicit_cast(value.GetDouble()))); case ValueKind::kDuration: return OptionalValue( &duration_optional_value_dispatcher, OpaqueValueContent::From(value.GetDuration().ToDuration())); case ValueKind::kTimestamp: return OptionalValue( ×tamp_optional_value_dispatcher, OpaqueValueContent::From(value.GetTimestamp().ToTime())); default: { cel::Value* absl_nonnull result = ::new ( arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) cel::Value(std::move(value)); if (!ArenaTraits<>::trivially_destructible(result)) { arena->OwnDestructor(result); } return OptionalValue(&optional_value_dispatcher, OpaqueValueContent::From(OptionalValueContent{ .value = result, .arena = arena})); } } } OptionalValue OptionalValue::None() { return OptionalValue(&empty_optional_value_dispatcher, OpaqueValueContent::Zero()); } bool OptionalValue::HasValue() const { return static_cast(OpaqueValue::dispatcher()) ->has_value(static_cast( OpaqueValue::dispatcher()), OpaqueValue::content()); } void OptionalValue::Value(cel::Value* absl_nonnull result) const { ABSL_DCHECK(result != nullptr); static_cast(OpaqueValue::dispatcher()) ->value(static_cast( OpaqueValue::dispatcher()), OpaqueValue::content(), result); } cel::Value OptionalValue::Value() const { cel::Value result; Value(&result); return result; } } // namespace cel 
================================================ FILE: common/values/optional_value.h ================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private, include "common/value.h"
// IWYU pragma: friend "common/value.h"

// `OptionalValue` represents values of the `optional_type` type.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_OPTIONAL_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_OPTIONAL_VALUE_H_

// NOTE(review): this extract appears to have lost all text between angle
// brackets (system header names, template parameter lists and template
// arguments). That is why the `#include` lines below are bare, the `template`
// keywords have no parameter lists, and the `enable_if_t` conditions are
// missing. Restore them from the upstream header before compiling.
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/types/optional.h"
#include "common/optional_ref.h"
#include "common/type.h"
#include "common/values/opaque_value.h"
#include "google/protobuf/arena.h"

namespace cel {

class Value;
class OptionalValue;

namespace common_internal {
// Wraps a raw dispatcher/content pair in an `OptionalValue`. It is defined at
// the bottom of this header and befriended by `OptionalValue` so it can reach
// the private constructor.
OptionalValue MakeOptionalValue(
    const OpaqueValueDispatcher* absl_nonnull dispatcher,
    OpaqueValueContent content);
}

// An opaque value that is either empty (`optional.none()`) or wraps exactly
// one CEL value (`optional.of(x)`).
class OptionalValue final : public OpaqueValue {
 public:
  // Returns the empty optional.
  static OptionalValue None();

  // Returns an optional wrapping `value`. `value` must not be an error or
  // unknown value (debug-checked by the implementation). `arena` provides
  // storage when `value` is not one of the kinds kept inline.
  static OptionalValue Of(cel::Value value,
                          google::protobuf::Arena* absl_nonnull arena);

  // Default construction yields the empty optional.
  OptionalValue() : OptionalValue(None()) {}
  OptionalValue(const OptionalValue&) = default;
  OptionalValue(OptionalValue&&) = default;
  OptionalValue& operator=(const OptionalValue&) = default;
  OptionalValue& operator=(OptionalValue&&) = default;

  // The runtime type, narrowed from the opaque type to `OptionalType`.
  OptionalType GetRuntimeType() const {
    return OpaqueValue::GetRuntimeType().GetOptional();
  }

  // Whether this optional holds a value.
  bool HasValue() const;

  // Writes the wrapped value to `result`. For the empty optional an error
  // value is written instead.
  void Value(cel::Value* absl_nonnull result) const;
  // Same as above, but returns the result by value.
  cel::Value Value() const;

  // `OpaqueValue`'s accessors for testing for / converting to `OptionalValue`
  // are deleted on this type. This hides them from callers that already hold
  // an `OptionalValue`.
  bool IsOptional() const = delete;

  template
  std::enable_if_t, bool> Is() const = delete;

  optional_ref AsOptional() & = delete;
  optional_ref AsOptional() const& = delete;
  absl::optional AsOptional() && = delete;
  absl::optional AsOptional() const&& = delete;

  const OptionalValue& GetOptional() & = delete;
  const OptionalValue& GetOptional() const& = delete;
  OptionalValue GetOptional() && = delete;
  OptionalValue GetOptional() const&& = delete;

  template
  std::enable_if_t, optional_ref> As() & = delete;
  template
  std::enable_if_t, optional_ref> As() const& = delete;
  template
  std::enable_if_t, absl::optional> As() && = delete;
  template
  std::enable_if_t, absl::optional> As() const&& = delete;

  template
  std::enable_if_t, optional_ref> Get() & = delete;
  template
  std::enable_if_t, optional_ref> Get() const& = delete;
  template
  std::enable_if_t, absl::optional> Get() && = delete;
  template
  std::enable_if_t, absl::optional> Get() const&& = delete;

 private:
  friend OptionalValue common_internal::MakeOptionalValue(
      const OpaqueValueDispatcher* absl_nonnull dispatcher,
      OpaqueValueContent content);

  // Only reachable from this class and `common_internal::MakeOptionalValue`.
  OptionalValue(const OpaqueValueDispatcher* absl_nonnull dispatcher,
                OpaqueValueContent content)
      : OpaqueValue(dispatcher, content) {}

  // Redeclared in the private section so that the raw
  // dispatcher/content/interface accessors inherited from `OpaqueValue` are
  // not part of `OptionalValue`'s public interface.
  using OpaqueValue::content;
  using OpaqueValue::dispatcher;
  using OpaqueValue::interface;
};

// Out-of-line definitions of `OpaqueValue`'s optional accessors. They are
// placed here, where `OptionalValue` is a complete type.

// Non-const lvalue overload: forwards to the const lvalue overload.
inline optional_ref OpaqueValue::AsOptional() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).AsOptional();
}

// Const rvalue overload: `*this` is an lvalue inside the body, so the inner
// call selects the lvalue overload. Its borrowed result is converted via
// `common_internal::AsOptional`.
inline absl::optional OpaqueValue::AsOptional() const&& {
  return common_internal::AsOptional(AsOptional());
}

// Templated `As` forwards to `AsOptional`. The rvalue overloads use
// `std::move(*this)` so the rvalue `AsOptional` overloads are selected.
template
inline std::enable_if_t, optional_ref>
OpaqueValue::As() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsOptional();
}

template
inline std::enable_if_t, optional_ref>
OpaqueValue::As() const& ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return AsOptional();
}

template
inline std::enable_if_t, absl::optional>
OpaqueValue::As() && {
  return std::move(*this).AsOptional();
}

template
inline std::enable_if_t, absl::optional>
OpaqueValue::As() const&& {
  return std::move(*this).AsOptional();
}

// Non-const lvalue overload: forwards to the const lvalue overload.
inline const OptionalValue& OpaqueValue::GetOptional() &
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return std::as_const(*this).GetOptional();
}

// Const rvalue overload: calls the const lvalue overload on `*this` and returns
// a copy of the referenced `OptionalValue`.
inline OptionalValue OpaqueValue::GetOptional() const&& { return GetOptional(); }

// Templated `Get` forwards to `GetOptional`, using `std::move(*this)` for the
// rvalue overloads.
template
std::enable_if_t, const OptionalValue&>
OpaqueValue::Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetOptional();
}

template
std::enable_if_t, const OptionalValue&>
OpaqueValue::Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND {
  return GetOptional();
}

template
std::enable_if_t, OptionalValue> OpaqueValue::Get() && {
  return std::move(*this).GetOptional();
}

template
std::enable_if_t, OptionalValue> OpaqueValue::Get() const&& {
  return std::move(*this).GetOptional();
}

namespace common_internal {

// Builds an `OptionalValue` directly from a dispatcher and its content.
inline OptionalValue MakeOptionalValue(
    const OpaqueValueDispatcher* absl_nonnull dispatcher,
    OpaqueValueContent content) {
  return OptionalValue(dispatcher, content);
}

}  // namespace common_internal

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_OPTIONAL_VALUE_H_
================================================ FILE: common/values/optional_value_test.cc ================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// NOTE(review): this extract appears to have lost all text between angle
// brackets (the system header below and the `NativeTypeId::For` template
// arguments in the GetTypeId test). Restore them from upstream before
// compiling.
#include

#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/time/time.h"
#include "common/native_type.h"
#include "common/type.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "internal/testing.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"

namespace cel {
namespace {

using ::absl_testing::StatusIs;
using ::cel::test::BoolValueIs;
using ::cel::test::DoubleValueIs;
using ::cel::test::DurationValueIs;
using ::cel::test::ErrorValueIs;
using ::cel::test::IntValueIs;
using ::cel::test::IsNullValue;
using ::cel::test::StringValueIs;
using ::cel::test::TimestampValueIs;
using ::cel::test::UintValueIs;

// Fixture built on `common_internal::ValueTest<>`, which supplies `arena()`,
// `descriptor_pool()`, `message_factory()` and `NewArenaValueMessage()` used
// below.
class OptionalValueTest : public common_internal::ValueTest<> {
 public:
  // NOTE(review): not referenced by any test in this file; the tests build the
  // empty optional via `OptionalValue()` instead.
  OptionalValue OptionalNone() { return OptionalValue::None(); }

  // Wraps `value` in an optional, using the fixture's arena for storage.
  OptionalValue OptionalOf(Value value) {
    return OptionalValue::Of(std::move(value), arena());
  }
};

// `kind()` is callable without an instance and agrees with `kKind`.
TEST_F(OptionalValueTest, Kind) {
  EXPECT_EQ(OptionalValue::kind(), OptionalValue::kKind);
}

// Both the typed handle and its type-erased `OpaqueValue` copy report
// `OptionalType`.
TEST_F(OptionalValueTest, GetRuntimeType) {
  EXPECT_EQ(OptionalValue().GetRuntimeType(), OptionalType());
  EXPECT_EQ(OpaqueValue(OptionalValue()).GetRuntimeType(), OptionalType());
}

// Covers the empty optional, every inline-stored kind (null, bool, int, uint,
// double, duration, timestamp) and one arena-stored kind (string).
TEST_F(OptionalValueTest, DebugString) {
  EXPECT_EQ(OptionalValue().DebugString(), "optional.none()");
  EXPECT_EQ(OptionalOf(NullValue()).DebugString(), "optional.of(null)");
  EXPECT_EQ(OptionalOf(TrueValue()).DebugString(), "optional.of(true)");
  EXPECT_EQ(OptionalOf(IntValue(1)).DebugString(), "optional.of(1)");
  EXPECT_EQ(OptionalOf(UintValue(1u)).DebugString(), "optional.of(1u)");
  EXPECT_EQ(OptionalOf(DoubleValue(1.0)).DebugString(), "optional.of(1.0)");
  EXPECT_EQ(OptionalOf(DurationValue()).DebugString(), "optional.of(0)");
  EXPECT_EQ(OptionalOf(TimestampValue()).DebugString(),
            "optional.of(1970-01-01T00:00:00Z)");
  EXPECT_EQ(OptionalOf(StringValue()).DebugString(), "optional.of(\"\")");
}

// Serialization is expected to fail with FAILED_PRECONDITION for both the
// typed and the type-erased handle.
TEST_F(OptionalValueTest, SerializeTo) {
  google::protobuf::io::CordOutputStream output;
  EXPECT_THAT(OptionalValue().SerializeTo(descriptor_pool(), message_factory(),
                                          &output),
              StatusIs(absl::StatusCode::kFailedPrecondition));
  EXPECT_THAT(OpaqueValue(OptionalValue())
                  .SerializeTo(descriptor_pool(), message_factory(), &output),
              StatusIs(absl::StatusCode::kFailedPrecondition));
}

// JSON conversion is expected to fail with FAILED_PRECONDITION for both the
// typed and the type-erased handle.
TEST_F(OptionalValueTest, ConvertToJson) {
  auto* message = NewArenaValueMessage();
  EXPECT_THAT(OptionalValue().ConvertToJson(descriptor_pool(), message_factory(),
                                            message),
              StatusIs(absl::StatusCode::kFailedPrecondition));
  EXPECT_THAT(OpaqueValue(OptionalValue())
                  .ConvertToJson(descriptor_pool(), message_factory(), message),
              StatusIs(absl::StatusCode::kFailedPrecondition));
}

// Every payload representation reports the same native type id.
TEST_F(OptionalValueTest, GetTypeId) {
  EXPECT_EQ(OpaqueValue(OptionalValue()).GetTypeId(), NativeTypeId::For());
  EXPECT_EQ(OpaqueValue(OptionalOf(NullValue())).GetTypeId(),
            NativeTypeId::For());
  EXPECT_EQ(OpaqueValue(OptionalOf(TrueValue())).GetTypeId(),
            NativeTypeId::For());
  EXPECT_EQ(OpaqueValue(OptionalOf(IntValue(1))).GetTypeId(),
            NativeTypeId::For());
  EXPECT_EQ(OpaqueValue(OptionalOf(UintValue(1u))).GetTypeId(),
            NativeTypeId::For());
  EXPECT_EQ(OpaqueValue(OptionalOf(DoubleValue(1.0))).GetTypeId(),
            NativeTypeId::For());
  EXPECT_EQ(OpaqueValue(OptionalOf(DurationValue())).GetTypeId(),
            NativeTypeId::For());
  EXPECT_EQ(OpaqueValue(OptionalOf(TimestampValue())).GetTypeId(),
            NativeTypeId::For());
  EXPECT_EQ(OpaqueValue(OptionalOf(StringValue())).GetTypeId(),
            NativeTypeId::For());
}

// Only the empty optional reports no value; `optional.of(null)` is engaged.
TEST_F(OptionalValueTest, HasValue) {
  EXPECT_FALSE(OptionalValue().HasValue());
  EXPECT_TRUE(OptionalOf(NullValue()).HasValue());
  EXPECT_TRUE(OptionalOf(TrueValue()).HasValue());
  EXPECT_TRUE(OptionalOf(IntValue(1)).HasValue());
  EXPECT_TRUE(OptionalOf(UintValue(1u)).HasValue());
  EXPECT_TRUE(OptionalOf(DoubleValue(1.0)).HasValue());
  EXPECT_TRUE(OptionalOf(DurationValue()).HasValue());
  EXPECT_TRUE(OptionalOf(TimestampValue()).HasValue());
  EXPECT_TRUE(OptionalOf(StringValue()).HasValue());
}

// Dereferencing the empty optional yields a FAILED_PRECONDITION error *value*.
// Every engaged optional round-trips its payload.
TEST_F(OptionalValueTest, Value) {
  EXPECT_THAT(OptionalValue().Value(),
              ErrorValueIs(StatusIs(absl::StatusCode::kFailedPrecondition)));
  EXPECT_THAT(OptionalOf(NullValue()).Value(), IsNullValue());
  EXPECT_THAT(OptionalOf(TrueValue()).Value(), BoolValueIs(true));
  EXPECT_THAT(OptionalOf(IntValue(1)).Value(), IntValueIs(1));
  EXPECT_THAT(OptionalOf(UintValue(1u)).Value(), UintValueIs(1u));
  EXPECT_THAT(OptionalOf(DoubleValue(1.0)).Value(), DoubleValueIs(1.0));
  EXPECT_THAT(OptionalOf(DurationValue()).Value(),
              DurationValueIs(absl::ZeroDuration()));
  EXPECT_THAT(OptionalOf(TimestampValue()).Value(),
              TimestampValueIs(absl::UnixEpoch()));
  EXPECT_THAT(OptionalOf(StringValue()).Value(), StringValueIs(""));
}

}  // namespace
}  // namespace cel
================================================ FILE: common/values/parsed_json_list_value.cc ================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/values/parsed_json_list_value.h" #include #include #include #include "google/protobuf/struct.pb.h" #include "absl/base/nullability.h" #include "absl/base/optimization.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "absl/types/variant.h" #include "common/memory.h" #include "common/value.h" #include "common/values/parsed_json_value.h" #include "common/values/values.h" #include "internal/json.h" #include "internal/message_equality.h" #include "internal/number.h" #include "internal/status_macros.h" #include "internal/well_known_types.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { using ::cel::well_known_types::ValueReflection; namespace common_internal { absl::Status CheckWellKnownListValueMessage(const google::protobuf::Message& message) { return internal::CheckJsonList(message); } } // namespace common_internal std::string ParsedJsonListValue::DebugString() const { if (value_ == nullptr) { return "[]"; } return internal::JsonListDebugString(*value_); } absl::Status ParsedJsonListValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); if (value_ == nullptr) { return absl::OkStatus(); } if (!value_->SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( "failed to serialize message: google.protobuf.ListValue"); } return absl::OkStatus(); } absl::Status ParsedJsonListValue::ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, 
google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); ValueReflection value_reflection; CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); auto* message = value_reflection.MutableListValue(json); message->Clear(); if (value_ == nullptr) { return absl::OkStatus(); } if (value_->GetDescriptor() == message->GetDescriptor()) { // We can directly use google::protobuf::Message::Copy(). message->CopyFrom(*value_); } else { // Equivalent descriptors but not identical. Must serialize and deserialize. absl::Cord serialized; if (!value_->SerializePartialToString(&serialized)) { return absl::UnknownError( absl::StrCat("failed to serialize message: ", value_->GetTypeName())); } if (!message->ParsePartialFromString(serialized)) { return absl::UnknownError( absl::StrCat("failed to parsed message: ", message->GetTypeName())); } } return absl::OkStatus(); } absl::Status ParsedJsonListValue::ConvertToJsonArray( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); if (value_ == nullptr) { json->Clear(); return absl::OkStatus(); } if (value_->GetDescriptor() == json->GetDescriptor()) { // We can directly use google::protobuf::Message::Copy(). json->CopyFrom(*value_); } else { // Equivalent descriptors but not identical. Must serialize and deserialize. 
absl::Cord serialized; if (!value_->SerializePartialToString(&serialized)) { return absl::UnknownError( absl::StrCat("failed to serialize message: ", value_->GetTypeName())); } if (!json->ParsePartialFromString(serialized)) { return absl::UnknownError( absl::StrCat("failed to parsed message: ", json->GetTypeName())); } } return absl::OkStatus(); } absl::Status ParsedJsonListValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (auto other_value = other.AsParsedJsonList(); other_value) { *result = BoolValue(*this == *other_value); return absl::OkStatus(); } if (auto other_value = other.AsParsedRepeatedField(); other_value) { if (value_ == nullptr) { *result = BoolValue(other_value->IsEmpty()); return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN( auto equal, internal::MessageFieldEquals( *value_, *other_value->message_, other_value->field_, descriptor_pool, message_factory)); *result = BoolValue(equal); return absl::OkStatus(); } if (auto other_value = other.AsList(); other_value) { return common_internal::ListValueEqual(ListValue(*this), *other_value, descriptor_pool, message_factory, arena, result); } *result = BoolValue(false); return absl::OkStatus(); } ParsedJsonListValue ParsedJsonListValue::Clone( google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(arena != nullptr); if (value_ == nullptr) { return ParsedJsonListValue(); } if (arena_ == arena) { return *this; } auto* cloned = value_->New(arena); cloned->CopyFrom(*value_); return ParsedJsonListValue(cloned, arena); } size_t ParsedJsonListValue::Size() const { if (value_ == nullptr) { return 0; } return static_cast( 
well_known_types::GetListValueReflectionOrDie(value_->GetDescriptor()) .ValuesSize(*value_)); } // See ListValueInterface::Get for documentation. absl::Status ParsedJsonListValue::Get( size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (value_ == nullptr) { *result = IndexOutOfBoundsError(index); return absl::OkStatus(); } const auto reflection = well_known_types::GetListValueReflectionOrDie(value_->GetDescriptor()); if (ABSL_PREDICT_FALSE(index >= static_cast(reflection.ValuesSize(*value_)))) { *result = IndexOutOfBoundsError(index); return absl::OkStatus(); } *result = common_internal::ParsedJsonValue( &reflection.Values(*value_, static_cast(index)), arena); return absl::OkStatus(); } absl::Status ParsedJsonListValue::ForEach( ForEachWithIndexCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); if (value_ == nullptr) { return absl::OkStatus(); } Value scratch; const auto reflection = well_known_types::GetListValueReflectionOrDie(value_->GetDescriptor()); const int size = reflection.ValuesSize(*value_); for (int i = 0; i < size; ++i) { scratch = common_internal::ParsedJsonValue(&reflection.Values(*value_, i), arena); CEL_ASSIGN_OR_RETURN(auto ok, callback(static_cast(i), scratch)); if (!ok) { break; } } return absl::OkStatus(); } namespace { class ParsedJsonListValueIterator final : public ValueIterator { public: explicit ParsedJsonListValueIterator( const google::protobuf::Message* 
absl_nonnull message) : message_(message), reflection_(well_known_types::GetListValueReflectionOrDie( message_->GetDescriptor())), size_(reflection_.ValuesSize(*message_)) {} bool HasNext() override { return index_ < size_; } absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) override { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (ABSL_PREDICT_FALSE(index_ >= size_)) { return absl::FailedPreconditionError( "`ValueIterator::Next` called after `ValueIterator::HasNext` " "returned false"); } *result = common_internal::ParsedJsonValue( &reflection_.Values(*message_, index_), arena); ++index_; return absl::OkStatus(); } absl::StatusOr Next1( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key_or_value) override { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(key_or_value != nullptr); if (index_ >= size_) { return false; } *key_or_value = common_internal::ParsedJsonValue( &reflection_.Values(*message_, index_), arena); ++index_; return true; } absl::StatusOr Next2( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, Value* absl_nullable value) override { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(key != nullptr); if (index_ >= size_) { return false; } if (value != nullptr) { *value = common_internal::ParsedJsonValue( 
&reflection_.Values(*message_, index_), arena); } *key = IntValue(index_); ++index_; return true; } private: const google::protobuf::Message* absl_nonnull const message_; const well_known_types::ListValueReflection reflection_; const int size_; int index_ = 0; }; } // namespace absl::StatusOr> ParsedJsonListValue::NewIterator() const { if (value_ == nullptr) { return NewEmptyValueIterator(); } return std::make_unique(value_); } namespace { absl::optional AsNumber(const Value& value) { if (auto int_value = value.AsInt(); int_value) { return internal::Number::FromInt64(*int_value); } if (auto uint_value = value.AsUint(); uint_value) { return internal::Number::FromUint64(*uint_value); } if (auto double_value = value.AsDouble(); double_value) { return internal::Number::FromDouble(*double_value); } return absl::nullopt; } } // namespace absl::Status ParsedJsonListValue::Contains( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (value_ == nullptr) { *result = FalseValue(); return absl::OkStatus(); } if (ABSL_PREDICT_FALSE(other.IsError() || other.IsUnknown())) { *result = other; return absl::OkStatus(); } // Other must be comparable to `null`, `double`, `string`, `list`, or `map`. 
const auto reflection = well_known_types::GetListValueReflectionOrDie(value_->GetDescriptor()); if (reflection.ValuesSize(*value_) > 0) { const auto value_reflection = well_known_types::GetValueReflectionOrDie( reflection.GetValueDescriptor()); if (other.IsNull()) { for (const auto& element : reflection.Values(*value_)) { const auto element_kind_case = value_reflection.GetKindCase(element); if (element_kind_case == google::protobuf::Value::KIND_NOT_SET || element_kind_case == google::protobuf::Value::kNullValue) { *result = TrueValue(); return absl::OkStatus(); } } } else if (const auto other_value = other.AsBool(); other_value) { for (const auto& element : reflection.Values(*value_)) { if (value_reflection.GetKindCase(element) == google::protobuf::Value::kBoolValue && value_reflection.GetBoolValue(element) == *other_value) { *result = TrueValue(); return absl::OkStatus(); } } } else if (const auto other_value = AsNumber(other); other_value) { for (const auto& element : reflection.Values(*value_)) { if (value_reflection.GetKindCase(element) == google::protobuf::Value::kNumberValue && internal::Number::FromDouble( value_reflection.GetNumberValue(element)) == *other_value) { *result = TrueValue(); return absl::OkStatus(); } } } else if (const auto other_value = other.AsString(); other_value) { std::string scratch; for (const auto& element : reflection.Values(*value_)) { if (value_reflection.GetKindCase(element) == google::protobuf::Value::kStringValue && absl::visit( [&](const auto& alternative) -> bool { return *other_value == alternative; }, well_known_types::AsVariant( value_reflection.GetStringValue(element, scratch)))) { *result = TrueValue(); return absl::OkStatus(); } } } else if (const auto other_value = other.AsList(); other_value) { for (const auto& element : reflection.Values(*value_)) { if (value_reflection.GetKindCase(element) == google::protobuf::Value::kListValue) { CEL_RETURN_IF_ERROR(other_value->Equal( 
ParsedJsonListValue(&value_reflection.GetListValue(element), arena), descriptor_pool, message_factory, arena, result)); if (result->IsTrue()) { return absl::OkStatus(); } } } } else if (const auto other_value = other.AsMap(); other_value) { for (const auto& element : reflection.Values(*value_)) { if (value_reflection.GetKindCase(element) == google::protobuf::Value::kStructValue) { CEL_RETURN_IF_ERROR(other_value->Equal( ParsedJsonMapValue(&value_reflection.GetStructValue(element), arena), descriptor_pool, message_factory, arena, result)); if (result->IsTrue()) { return absl::OkStatus(); } } } } } *result = FalseValue(); return absl::OkStatus(); } bool operator==(const ParsedJsonListValue& lhs, const ParsedJsonListValue& rhs) { if (cel::to_address(lhs.value_) == cel::to_address(rhs.value_)) { return true; } if (cel::to_address(lhs.value_) == nullptr) { return rhs.IsEmpty(); } if (cel::to_address(rhs.value_) == nullptr) { return lhs.IsEmpty(); } return internal::JsonListEquals(*lhs.value_, *rhs.value_); } } // namespace cel ================================================ FILE: common/values/parsed_json_list_value.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/value.h"
// IWYU pragma: friend "common/value.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_LIST_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_LIST_VALUE_H_

// NOTE(review): in this copy the names of the five system headers below, and
// several template argument lists further down (the `ListValueMixin` base,
// `std::pointer_traits`, `absl::StatusOr`), appear to have been stripped; they
// must be restored from upstream before this header can compile.
#include
#include
#include
#include
#include

#include "google/protobuf/any.pb.h"
#include "google/protobuf/struct.pb.h"
#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "common/memory.h"
#include "common/type.h"
#include "common/value_kind.h"
#include "common/values/custom_list_value.h"
#include "common/values/values.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

class Value;
class ValueIterator;
class ParsedRepeatedFieldValue;

namespace common_internal {

// Validates that `message` has the shape of google.protobuf.ListValue.
// ParsedJsonListValue's constructor DCHECKs that this returns OK.
absl::Status CheckWellKnownListValueMessage(const google::protobuf::Message& message);

}  // namespace common_internal

// ParsedJsonListValue is a ListValue backed by the google.protobuf.ListValue
// well known message type.
//
// The value does not own its message: it holds a raw message pointer and the
// arena it was constructed with (both lifetime-bound), so both must outlive
// it. A default-constructed instance holds no message.
class ParsedJsonListValue final
    : private common_internal::ListValueMixin {
 public:
  static constexpr ValueKind kKind = ValueKind::kList;
  static constexpr absl::string_view kName = "google.protobuf.ListValue";

  using element_type = const google::protobuf::Message;

  // Wraps `value`. Both pointers must be non-null. In debug builds `value` is
  // also checked to be a well-formed ListValue that is either not
  // arena-allocated or allocated on `arena`.
  ParsedJsonListValue(
      const google::protobuf::Message* absl_nonnull value ABSL_ATTRIBUTE_LIFETIME_BOUND,
      google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND)
      : value_(value), arena_(arena) {
    ABSL_DCHECK(value != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK_OK(CheckListValue(value_));
    ABSL_DCHECK_OK(CheckArena(value_, arena_));
  }

  // Constructs an empty `ParsedJsonListValue`.
  ParsedJsonListValue() = default;
  ParsedJsonListValue(const ParsedJsonListValue&) = default;
  ParsedJsonListValue(ParsedJsonListValue&&) = default;
  ParsedJsonListValue& operator=(const ParsedJsonListValue&) = default;
  ParsedJsonListValue& operator=(ParsedJsonListValue&&) = default;

  static ValueKind kind() { return kKind; }

  static absl::string_view GetTypeName() { return kName; }

  // The runtime type is always JsonListType(), independent of the contents.
  static ListType GetRuntimeType() { return JsonListType(); }

  // Access to the backing message. Must not be used on a default-constructed
  // value (checked with ABSL_DCHECK).
  const google::protobuf::Message& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(*this);
    return *value_;
  }

  const google::protobuf::Message* absl_nonnull operator->() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(*this);
    return value_;
  }

  std::string DebugString() const;

  // See Value::SerializeTo().
  absl::Status SerializeTo(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const;

  // See Value::ConvertToJson().
  absl::Status ConvertToJson(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  // See Value::ConvertToJsonArray().
  absl::Status ConvertToJsonArray(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  absl::Status Equal(const Value& other,
                     const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                     google::protobuf::MessageFactory* absl_nonnull message_factory,
                     google::protobuf::Arena* absl_nonnull arena,
                     Value* absl_nonnull result) const;
  using ListValueMixin::Equal;

  // The zero value of a list is the empty list.
  bool IsZeroValue() const { return IsEmpty(); }

  ParsedJsonListValue Clone(google::protobuf::Arena* absl_nonnull arena) const;

  bool IsEmpty() const { return Size() == 0; }

  // Number of elements in the list.
  size_t Size() const;

  // See ListValueInterface::Get for documentation.
  absl::Status Get(size_t index,
                   const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                   google::protobuf::MessageFactory* absl_nonnull message_factory,
                   google::protobuf::Arena* absl_nonnull arena,
                   Value* absl_nonnull result) const;
  using ListValueMixin::Get;

  using ForEachCallback = typename CustomListValueInterface::ForEachCallback;

  using ForEachWithIndexCallback =
      typename CustomListValueInterface::ForEachWithIndexCallback;

  absl::Status ForEach(
      ForEachWithIndexCallback callback,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;
  using ListValueMixin::ForEach;

  absl::StatusOr NewIterator() const;

  absl::Status Contains(
      const Value& other,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const;
  using ListValueMixin::Contains;

  // True when this value is backed by a message (not default-constructed).
  explicit operator bool() const { return value_ != nullptr; }

  // Swaps both the message pointer and the arena pointer.
  friend void swap(ParsedJsonListValue& lhs,
                   ParsedJsonListValue& rhs) noexcept {
    using std::swap;
    swap(lhs.value_, rhs.value_);
    swap(lhs.arena_, rhs.arena_);
  }

  friend bool operator==(const ParsedJsonListValue& lhs,
                         const ParsedJsonListValue& rhs);

 private:
  friend std::pointer_traits;
  friend class ParsedRepeatedFieldValue;
  friend class common_internal::ValueMixin;
  friend class common_internal::ListValueMixin;

  // OK for a null message (the empty value); otherwise delegates to
  // common_internal::CheckWellKnownListValueMessage.
  static absl::Status CheckListValue(
      const google::protobuf::Message* absl_nullable message) {
    return message == nullptr
               ? absl::OkStatus()
               : common_internal::CheckWellKnownListValueMessage(*message);
  }

  // Returns InvalidArgument when `message` is allocated on an arena other than
  // `arena`; null messages and messages without an arena are accepted.
  static absl::Status CheckArena(const google::protobuf::Message* absl_nullable message,
                                 google::protobuf::Arena* absl_nonnull arena) {
    if (message != nullptr && message->GetArena() != nullptr &&
        message->GetArena() != arena) {
      return absl::InvalidArgumentError(
          "message arena must be the same as arena");
    }
    return absl::OkStatus();
  }

  // Not owned; both are null for a default-constructed value.
  const google::protobuf::Message* absl_nullable value_ = nullptr;
  google::protobuf::Arena* absl_nullable arena_ = nullptr;
};

inline bool operator!=(const ParsedJsonListValue& lhs,
                       const ParsedJsonListValue& rhs) {
  return !operator==(lhs, rhs);
}

// Streams DebugString().
inline std::ostream& operator<<(std::ostream& out,
                                const ParsedJsonListValue& value) {
  return out << value.DebugString();
}

}  // namespace cel

namespace std {

// Pointer-traits specialization whose to_address() yields the backing message
// pointer (null for a default-constructed value).
template <>
struct pointer_traits {
  using pointer = cel::ParsedJsonListValue;
  using element_type = typename cel::ParsedJsonListValue::element_type;
  using difference_type = ptrdiff_t;

  static element_type* to_address(const pointer& p) noexcept {
    return cel::to_address(p.value_);
  }
};

}  // namespace std

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_LIST_VALUE_H_

================================================
FILE: common/values/parsed_json_list_value_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Unit tests for ParsedJsonListValue, the ListValue backed by a
// google.protobuf.ListValue message. Test messages are built with
// DynamicParseTextProto and wrapped together with the fixture's arena().
//
// NOTE(review): in this copy the names of the three system headers below and
// the template arguments of `DynamicParseTextProto`, `EqualsTextProto`,
// `std::vector` and `absl::StatusOr` appear to have been stripped; restore
// them from upstream before building.
#include
#include
#include

#include "google/protobuf/struct.pb.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/memory.h"
#include "common/type.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "common/value_testing.h"
#include "internal/testing.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::cel::test::BoolValueIs;
using ::cel::test::ErrorValueIs;
using ::cel::test::IntValueIs;
using ::cel::test::IsNullValue;
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::IsEmpty;
using ::testing::Optional;
using ::testing::Pair;

// NOTE(review): this alias is not referenced by any test in this file.
using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes;

// Fixture supplying arena(), descriptor_pool(), message_factory() and the
// DynamicParseTextProto / EqualsTextProto helpers used below.
using ParsedJsonListValueTest = common_internal::ValueTest<>;

// The static kind is the list kind.
TEST_F(ParsedJsonListValueTest, Kind) {
  EXPECT_EQ(ParsedJsonListValue::kind(), ParsedJsonListValue::kKind);
  EXPECT_EQ(ParsedJsonListValue::kind(), ValueKind::kList);
}

// The type name is the full name of the backing well-known message.
TEST_F(ParsedJsonListValueTest, GetTypeName) {
  EXPECT_EQ(ParsedJsonListValue::GetTypeName(), ParsedJsonListValue::kName);
  EXPECT_EQ(ParsedJsonListValue::GetTypeName(), "google.protobuf.ListValue");
}

TEST_F(ParsedJsonListValueTest, GetRuntimeType) {
  EXPECT_EQ(ParsedJsonListValue::GetRuntimeType(), JsonListType());
}

// An empty ListValue message renders as "[]".
TEST_F(ParsedJsonListValueTest, DebugString_Dynamic) {
  ParsedJsonListValue valid_value(
      DynamicParseTextProto(R"pb()pb"), arena());
  EXPECT_EQ(valid_value.DebugString(), "[]");
}

// An empty list is the zero value.
TEST_F(ParsedJsonListValueTest, IsZeroValue_Dynamic) {
  ParsedJsonListValue valid_value(
      DynamicParseTextProto(R"pb()pb"), arena());
  EXPECT_TRUE(valid_value.IsZeroValue());
}

// Serializing an empty ListValue produces no bytes.
TEST_F(ParsedJsonListValueTest, SerializeTo_Dynamic) {
  ParsedJsonListValue valid_value(
      DynamicParseTextProto(R"pb()pb"), arena());
  google::protobuf::io::CordOutputStream output;
  EXPECT_THAT(
      valid_value.SerializeTo(descriptor_pool(), message_factory(), &output),
      IsOk());
  EXPECT_THAT(std::move(output).Consume(), IsEmpty());
}

// Converting into a google.protobuf.Value sets its list_value field.
TEST_F(ParsedJsonListValueTest, ConvertToJson_Dynamic) {
  auto json = DynamicParseTextProto(R"pb()pb");
  ParsedJsonListValue valid_value(
      DynamicParseTextProto(R"pb()pb"), arena());
  EXPECT_THAT(valid_value.ConvertToJson(descriptor_pool(), message_factory(),
                                        cel::to_address(json)),
              IsOk());
  EXPECT_THAT(
      *json, EqualsTextProto(R"pb(list_value: {})pb"));
}

// An empty list is not equal to a bool, but is equal to another empty parsed
// list and to the default ListValue().
TEST_F(ParsedJsonListValueTest, Equal_Dynamic) {
  ParsedJsonListValue valid_value(
      DynamicParseTextProto(R"pb()pb"), arena());
  EXPECT_THAT(valid_value.Equal(BoolValue(), descriptor_pool(),
                                message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(
      valid_value.Equal(
          ParsedJsonListValue(
              DynamicParseTextProto(R"pb()pb"), arena()),
          descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(valid_value.Equal(ListValue(), descriptor_pool(),
                                message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
}

TEST_F(ParsedJsonListValueTest, Empty_Dynamic) {
  ParsedJsonListValue valid_value(
      DynamicParseTextProto(R"pb()pb"), arena());
  EXPECT_TRUE(valid_value.IsEmpty());
}

TEST_F(ParsedJsonListValueTest, Size_Dynamic) {
  ParsedJsonListValue valid_value(
      DynamicParseTextProto(R"pb()pb"), arena());
  EXPECT_EQ(valid_value.Size(), 0);
}

// An element with no kind set reads back as null. An out-of-range index yields
// an OK status holding an ErrorValue with kInvalidArgument.
TEST_F(ParsedJsonListValueTest, Get_Dynamic) {
  ParsedJsonListValue valid_value(
      DynamicParseTextProto(
          R"pb(values
{} values { bool_value: true })pb"),
      arena());
  EXPECT_THAT(valid_value.Get(0, descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(IsNullValue()));
  EXPECT_THAT(valid_value.Get(1, descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(
      valid_value.Get(2, descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))));
}

// Both ForEach overloads (value-only and index+value) visit every element in
// order.
TEST_F(ParsedJsonListValueTest, ForEach_Dynamic) {
  ParsedJsonListValue valid_value(
      DynamicParseTextProto(
          R"pb(values {} values { bool_value: true })pb"),
      arena());
  {
    std::vector values;
    EXPECT_THAT(valid_value.ForEach(
                    [&](const Value& element) -> absl::StatusOr {
                      values.push_back(element);
                      return true;
                    },
                    descriptor_pool(), message_factory(), arena()),
                IsOk());
    EXPECT_THAT(values, ElementsAre(IsNullValue(), BoolValueIs(true)));
  }
  {
    std::vector values;
    EXPECT_THAT(valid_value.ForEach(
                    [&](size_t, const Value& element) -> absl::StatusOr {
                      values.push_back(element);
                      return true;
                    },
                    descriptor_pool(), message_factory(), arena()),
                IsOk());
    EXPECT_THAT(values, ElementsAre(IsNullValue(), BoolValueIs(true)));
  }
}

// Next() yields elements in order; calling it after HasNext() returns false
// fails with kFailedPrecondition.
TEST_F(ParsedJsonListValueTest, NewIterator_Dynamic) {
  ParsedJsonListValue valid_value(
      DynamicParseTextProto(
          R"pb(values {} values { bool_value: true })pb"),
      arena());
  ASSERT_OK_AND_ASSIGN(auto iterator, valid_value.NewIterator());
  ASSERT_TRUE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(IsNullValue()));
  ASSERT_TRUE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
  ASSERT_FALSE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              StatusIs(absl::StatusCode::kFailedPrecondition));
}

// Next1() yields each element, then nullopt once exhausted.
TEST_F(ParsedJsonListValueTest, NewIterator1) {
  ParsedJsonListValue valid_value(
      DynamicParseTextProto(
          R"pb(values {} values { bool_value: true })pb"),
      arena());
  ASSERT_OK_AND_ASSIGN(auto iterator, valid_value.NewIterator());
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(IsNullValue())));
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(BoolValueIs(true))));
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Eq(absl::nullopt)));
}

// Next2() yields (index, element) pairs, then nullopt once exhausted.
TEST_F(ParsedJsonListValueTest, NewIterator2) {
  ParsedJsonListValue valid_value(
      DynamicParseTextProto(
          R"pb(values {} values { bool_value: true })pb"),
      arena());
  ASSERT_OK_AND_ASSIGN(auto iterator, valid_value.NewIterator());
  EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(Pair(IntValueIs(0), IsNullValue()))));
  EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(Pair(IntValueIs(1), BoolValueIs(true)))));
  EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Eq(absl::nullopt)));
}

// Membership for every JSON element kind. BytesValue() is never found, and
// scalars are found only when an element of the same kind holds an equal
// value. A list equal to `valid_value` itself is not an element, while the
// empty ListValue() and MapValue() match the `list_value: {}` and
// `struct_value: {}` elements.
TEST_F(ParsedJsonListValueTest, Contains_Dynamic) {
  ParsedJsonListValue valid_value(
      DynamicParseTextProto(
          R"pb(values {} values { bool_value: true } values { number_value: 1.0 } values { string_value: "foo" } values { list_value: {} } values { struct_value: {} })pb"),
      arena());
  EXPECT_THAT(valid_value.Contains(BytesValue(), descriptor_pool(),
                                   message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(valid_value.Contains(NullValue(), descriptor_pool(),
                                   message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(valid_value.Contains(BoolValue(false), descriptor_pool(),
                                   message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(valid_value.Contains(BoolValue(true), descriptor_pool(),
                                   message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(valid_value.Contains(DoubleValue(0.0), descriptor_pool(),
                                   message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(valid_value.Contains(DoubleValue(1.0), descriptor_pool(),
                                   message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(valid_value.Contains(StringValue("bar"), descriptor_pool(),
                                   message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(valid_value.Contains(StringValue("foo"), descriptor_pool(),
                                   message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(valid_value.Contains(
                  ParsedJsonListValue(
                      DynamicParseTextProto(
                          R"pb(values {} values { bool_value: true } values { number_value: 1.0 } values { string_value: "foo" } values { list_value: {} } values { struct_value: {} })pb"),
                      arena()),
                  descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(valid_value.Contains(ListValue(), descriptor_pool(),
                                   message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(
      valid_value.Contains(
          ParsedJsonMapValue(DynamicParseTextProto(
                                 R"pb(fields { key: "foo" value: { bool_value: true } })pb"),
                             arena()),
          descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(valid_value.Contains(MapValue(), descriptor_pool(),
                                   message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
}

}  // namespace
}  // namespace cel

================================================
FILE: common/values/parsed_json_map_value.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/values/parsed_json_map_value.h"

#include <cstddef>
#include <memory>
#include <string>
#include <utility>

#include "google/protobuf/struct.pb.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "common/allocator.h"
#include "common/memory.h"
#include "common/value.h"
#include "common/values/parsed_json_value.h"
#include "common/values/values.h"
#include "internal/json.h"
#include "internal/message_equality.h"
#include "internal/status_macros.h"
#include "internal/well_known_types.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/map.h"
#include "google/protobuf/map_field.h"
#include "google/protobuf/message.h"
#include "google/protobuf/message_lite.h"

namespace cel {

using ::cel::well_known_types::ValueReflection;

namespace common_internal {

// Validates that `message` has the shape of google.protobuf.Struct.
absl::Status CheckWellKnownStructMessage(const google::protobuf::Message& message) {
  return internal::CheckJsonMap(message);
}

}  // namespace common_internal

namespace {

// Replaces the contents of `to` with the google.protobuf.Struct `from`.
// When both share a descriptor, Message::CopyFrom is used directly. Otherwise
// the descriptors are equivalent but not identical, so the message is
// round-tripped through its serialized form.
absl::Status CopyStructMessage(const google::protobuf::Message& from,
                               google::protobuf::Message* absl_nonnull to) {
  if (from.GetDescriptor() == to->GetDescriptor()) {
    to->CopyFrom(from);
    return absl::OkStatus();
  }
  absl::Cord serialized;
  if (!from.SerializePartialToString(&serialized)) {
    return absl::UnknownError(
        absl::StrCat("failed to serialize message: ", from.GetTypeName()));
  }
  if (!to->ParsePartialFromString(serialized)) {
    return absl::UnknownError(
        absl::StrCat("failed to parse message: ", to->GetTypeName()));
  }
  return absl::OkStatus();
}

}  // namespace

std::string ParsedJsonMapValue::DebugString() const {
  if (value_ == nullptr) {
    return "{}";
  }
  return internal::JsonMapDebugString(*value_);
}

// Writes the wire format of the backing Struct to `output`; a
// default-constructed value writes nothing.
absl::Status ParsedJsonMapValue::SerializeTo(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(output != nullptr);

  if (value_ == nullptr) {
    return absl::OkStatus();
  }

  if (!value_->SerializePartialToZeroCopyStream(output)) {
    return absl::UnknownError(
        "failed to serialize message: google.protobuf.Struct");
  }
  return absl::OkStatus();
}

// Stores this map into the `struct_value` field of the google.protobuf.Value
// `json`. The field is cleared first, so an empty value yields an empty Struct.
absl::Status ParsedJsonMapValue::ConvertToJson(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(json != nullptr);
  ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
                 google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE);

  ValueReflection value_reflection;
  CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor()));
  auto* message = value_reflection.MutableStructValue(json);
  message->Clear();

  if (value_ == nullptr) {
    return absl::OkStatus();
  }
  return CopyStructMessage(*value_, message);
}

// Replaces the google.protobuf.Struct `json` with this map's contents.
absl::Status ParsedJsonMapValue::ConvertToJsonObject(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(json != nullptr);
  ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
                 google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT);

  if (value_ == nullptr) {
    json->Clear();
    return absl::OkStatus();
  }
  return CopyStructMessage(*value_, json);
}

// Compares against another parsed JSON map or parsed map field directly, or
// any other map via the generic MapValueEqual. Non-map values are unequal.
absl::Status ParsedJsonMapValue::Equal(
    const Value& other,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  if (auto other_value = other.AsParsedJsonMap(); other_value) {
    *result = BoolValue(*this == *other_value);
    return absl::OkStatus();
  }
  if (auto other_value = other.AsParsedMapField(); other_value) {
    if (value_ == nullptr) {
      *result = BoolValue(other_value->IsEmpty());
      return absl::OkStatus();
    }
    ABSL_DCHECK(other_value->field_ != nullptr);
    ABSL_DCHECK(descriptor_pool != nullptr);
    ABSL_DCHECK(message_factory != nullptr);
    CEL_ASSIGN_OR_RETURN(
        auto equal, internal::MessageFieldEquals(
                        *value_, *other_value->message_, other_value->field_,
                        descriptor_pool, message_factory));
    *result = BoolValue(equal);
    return absl::OkStatus();
  }
  if (auto other_value = other.AsMap(); other_value) {
    return common_internal::MapValueEqual(MapValue(*this), *other_value,
                                          descriptor_pool, message_factory,
                                          arena, result);
  }
  *result = FalseValue();
  return absl::OkStatus();
}

// Returns `*this` when already associated with `arena`; otherwise deep-copies
// the backing message onto `arena`.
ParsedJsonMapValue ParsedJsonMapValue::Clone(
    google::protobuf::Arena* absl_nonnull arena) const {
  ABSL_DCHECK(arena != nullptr);

  if (value_ == nullptr) {
    return ParsedJsonMapValue();
  }
  if (arena_ == arena) {
    return *this;
  }
  auto* cloned = value_->New(arena);
  cloned->CopyFrom(*value_);
  return ParsedJsonMapValue(cloned, arena);
}

size_t ParsedJsonMapValue::Size() const {
  if (value_ == nullptr) {
    return 0;
  }
  return static_cast<size_t>(
      well_known_types::GetStructReflectionOrDie(value_->GetDescriptor())
          .FieldsSize(*value_));
}

// Like Find, but a missing key (that is not itself an error or unknown) is
// reported as a NoSuchKeyError stored in `result`.
absl::Status ParsedJsonMapValue::Get(
    const Value& key,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  CEL_ASSIGN_OR_RETURN(
      bool ok, Find(key, descriptor_pool, message_factory, arena, result));
  if (ABSL_PREDICT_FALSE(!ok) && !(result->IsError() || result->IsUnknown())) {
    *result = NoSuchKeyError(key.DebugString());
  }
  return absl::OkStatus();
}

// Looks up `key`. Returns true and stores the value when found. Otherwise
// returns false, with `result` holding the key itself for error/unknown keys
// and null in every other case.
absl::StatusOr<bool> ParsedJsonMapValue::Find(
    const Value& key,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  if (key.IsError() || key.IsUnknown()) {
    *result = key;
    return false;
  }
  // Only string keys can name a Struct field; other key kinds, and a
  // default-constructed map, fall through to "not found".
  if (value_ != nullptr) {
    if (auto string_key = key.AsString(); string_key) {
      std::string key_scratch;
      if (const auto* value =
              well_known_types::GetStructReflectionOrDie(
                  value_->GetDescriptor())
                  .FindField(*value_, string_key->NativeString(key_scratch));
          value != nullptr) {
        *result = common_internal::ParsedJsonValue(value, arena);
        return true;
      }
    }
  }
  *result = NullValue();
  return false;
}

// Stores whether `key` is present; error/unknown keys are propagated into
// `result` unchanged.
absl::Status ParsedJsonMapValue::Has(
    const Value& key,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  if (key.IsError() || key.IsUnknown()) {
    *result = key;
    return absl::OkStatus();
  }
  // Only string keys can name a Struct field.
  if (value_ != nullptr) {
    if (auto string_key = key.AsString(); string_key) {
      std::string key_scratch;
      if (well_known_types::GetStructReflectionOrDie(value_->GetDescriptor())
              .FindField(*value_, string_key->NativeString(key_scratch)) !=
          nullptr) {
        *result = TrueValue();
        return absl::OkStatus();
      }
    }
  }
  *result = FalseValue();
  return absl::OkStatus();
}

// Stores a list of all field names in `result`.
absl::Status ParsedJsonMapValue::ListKeys(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    ListValue* absl_nonnull result) const {
  if (value_ == nullptr) {
    *result = ListValue();
    return absl::OkStatus();
  }
  const auto reflection =
      well_known_types::GetStructReflectionOrDie(value_->GetDescriptor());
  auto builder = NewListValueBuilder(arena);
  builder->Reserve(static_cast<size_t>(reflection.FieldsSize(*value_)));
  auto keys_begin = reflection.BeginFields(*value_);
  const auto keys_end = reflection.EndFields(*value_);
  for (; keys_begin != keys_end; ++keys_begin) {
    CEL_RETURN_IF_ERROR(builder->Add(
        Value::WrapMapFieldKeyString(keys_begin.GetKey(), value_, arena)));
  }
  *result = std::move(*builder).Build();
  return absl::OkStatus();
}

// Invokes `callback` with each (key, value) entry, stopping early when it
// returns false; a non-OK status from the callback is returned as-is.
absl::Status ParsedJsonMapValue::ForEach(
    ForEachCallback callback,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) const {
  if (value_ == nullptr) {
    return absl::OkStatus();
  }
  const auto reflection =
      well_known_types::GetStructReflectionOrDie(value_->GetDescriptor());
  Value key_scratch;
  Value value_scratch;
  auto map_begin = reflection.BeginFields(*value_);
  const auto map_end = reflection.EndFields(*value_);
  for (; map_begin != map_end; ++map_begin) {
    // We have to copy until `google::protobuf::MapKey` is just a view.
    key_scratch = StringValue(arena, map_begin.GetKey().GetStringValue());
    value_scratch = common_internal::ParsedJsonValue(
        &map_begin.GetValueRef().GetMessageValue(), arena);
    CEL_ASSIGN_OR_RETURN(auto ok, callback(key_scratch, value_scratch));
    if (!ok) {
      break;
    }
  }
  return absl::OkStatus();
}

namespace {

class ParsedJsonMapValueIterator final : public ValueIterator {
 public:
  explicit ParsedJsonMapValueIterator(
      const google::protobuf::Message* absl_nonnull message)
      : message_(message),
        reflection_(well_known_types::GetStructReflectionOrDie(
            message_->GetDescriptor())),
        begin_(reflection_.BeginFields(*message_)),
        end_(reflection_.EndFields(*message_)) {}

  bool HasNext() override { return begin_ != end_; }

  absl::Status Next(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) override {
    if (ABSL_PREDICT_FALSE(begin_ == end_)) {
      return absl::FailedPreconditionError(
          "`ValueIterator::Next` called after `ValueIterator::HasNext` "
          "returned false");
    }
    *result = Value::WrapMapFieldKeyString(begin_.GetKey(), message_, arena);
    ++begin_;
    return absl::OkStatus();
  }

  absl::StatusOr<bool> Next1(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull key_or_value) override {
    ABSL_DCHECK(descriptor_pool != nullptr);
    ABSL_DCHECK(message_factory != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK(key_or_value != nullptr);

    if (begin_ == end_) {
      return false;
    }
    *key_or_value =
        Value::WrapMapFieldKeyString(begin_.GetKey(), message_, arena);
    ++begin_;
    return true;
  }

  absl::StatusOr<bool> Next2(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key,
      Value*
absl_nullable value) override { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(key != nullptr); if (begin_ == end_) { return false; } *key = Value::WrapMapFieldKeyString(begin_.GetKey(), message_, arena); if (value != nullptr) { *value = common_internal::ParsedJsonValue( &begin_.GetValueRef().GetMessageValue(), arena); } ++begin_; return true; } private: const google::protobuf::Message* absl_nonnull const message_; const well_known_types::StructReflection reflection_; google::protobuf::ConstMapIterator begin_; const google::protobuf::ConstMapIterator end_; std::string scratch_; }; } // namespace absl::StatusOr> ParsedJsonMapValue::NewIterator() const { if (value_ == nullptr) { return NewEmptyValueIterator(); } return std::make_unique(value_); } bool operator==(const ParsedJsonMapValue& lhs, const ParsedJsonMapValue& rhs) { if (cel::to_address(lhs.value_) == cel::to_address(rhs.value_)) { return true; } if (cel::to_address(lhs.value_) == nullptr) { return rhs.IsEmpty(); } if (cel::to_address(rhs.value_) == nullptr) { return lhs.IsEmpty(); } return internal::JsonMapEquals(*lhs.value_, *rhs.value_); } } // namespace cel ================================================ FILE: common/values/parsed_json_map_value.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_MAP_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_MAP_VALUE_H_ #include #include #include #include #include #include "google/protobuf/any.pb.h" #include "google/protobuf/struct.pb.h" #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "common/memory.h" #include "common/type.h" #include "common/value_kind.h" #include "common/values/custom_map_value.h" #include "common/values/values.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { class Value; class ListValue; class ValueIterator; class ParsedMapFieldValue; namespace common_internal { absl::Status CheckWellKnownStructMessage(const google::protobuf::Message& message); } // namespace common_internal // ParsedJsonMapValue is a MapValue backed by the google.protobuf.Struct // well known message type. class ParsedJsonMapValue final : private common_internal::MapValueMixin { public: static constexpr ValueKind kKind = ValueKind::kMap; static constexpr absl::string_view kName = "google.protobuf.Struct"; using element_type = const google::protobuf::Message; ParsedJsonMapValue( const google::protobuf::Message* absl_nonnull value ABSL_ATTRIBUTE_LIFETIME_BOUND, google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) : value_(value), arena_(arena) { ABSL_DCHECK(value != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK_OK(CheckStruct(value_)); ABSL_DCHECK_OK(CheckArena(value_, arena_)); } // Constructs an empty `ParsedJsonMapValue`. 
ParsedJsonMapValue() = default; ParsedJsonMapValue(const ParsedJsonMapValue&) = default; ParsedJsonMapValue(ParsedJsonMapValue&&) = default; ParsedJsonMapValue& operator=(const ParsedJsonMapValue&) = default; ParsedJsonMapValue& operator=(ParsedJsonMapValue&&) = default; static constexpr ValueKind kind() { return kKind; } static absl::string_view GetTypeName() { return kName; } static MapType GetRuntimeType() { return JsonMapType(); } const google::protobuf::Message& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK(*this); return *value_; } const google::protobuf::Message* absl_nonnull operator->() const ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK(*this); return value_; } std::string DebugString() const; // See Value::SerializeTo(). absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; // See Value::ConvertToJson(). absl::Status ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; // See Value::ConvertToJsonObject(). 
absl::Status ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; absl::Status Equal(const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Equal; bool IsZeroValue() const { return IsEmpty(); } ParsedJsonMapValue Clone(google::protobuf::Arena* absl_nonnull arena) const; bool IsEmpty() const { return Size() == 0; } size_t Size() const; // See the corresponding member function of `MapValue` for // documentation. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Get; // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr Find( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; // See the corresponding member function of `MapValue` for // documentation. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Has; // See the corresponding member function of `MapValue` for // documentation. 
absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; using MapValueMixin::ListKeys; // See the corresponding type declaration of `MapValue` for // documentation. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; // See the corresponding member function of `MapValue` for // documentation. absl::Status ForEach( ForEachCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; absl::StatusOr> NewIterator() const; explicit operator bool() const { return value_ != nullptr; } friend void swap(ParsedJsonMapValue& lhs, ParsedJsonMapValue& rhs) noexcept { using std::swap; swap(lhs.value_, rhs.value_); swap(lhs.arena_, rhs.arena_); } friend bool operator==(const ParsedJsonMapValue& lhs, const ParsedJsonMapValue& rhs); private: friend std::pointer_traits; friend class ParsedMapFieldValue; friend class common_internal::ValueMixin; friend class common_internal::MapValueMixin; static absl::Status CheckStruct( const google::protobuf::Message* absl_nullable message) { return message == nullptr ? 
absl::OkStatus() : common_internal::CheckWellKnownStructMessage(*message); } static absl::Status CheckArena(const google::protobuf::Message* absl_nullable message, google::protobuf::Arena* absl_nonnull arena) { if (message != nullptr && message->GetArena() != nullptr && message->GetArena() != arena) { return absl::InvalidArgumentError( "message arena must be the same as arena"); } return absl::OkStatus(); } const google::protobuf::Message* absl_nullable value_ = nullptr; google::protobuf::Arena* absl_nullable arena_ = nullptr; }; inline bool operator!=(const ParsedJsonMapValue& lhs, const ParsedJsonMapValue& rhs) { return !operator==(lhs, rhs); } inline std::ostream& operator<<(std::ostream& out, const ParsedJsonMapValue& value) { return out << value.DebugString(); } } // namespace cel namespace std { template <> struct pointer_traits { using pointer = cel::ParsedJsonMapValue; using element_type = typename cel::ParsedJsonMapValue::element_type; using difference_type = ptrdiff_t; static element_type* to_address(const pointer& p) noexcept { return cel::to_address(p.value_); } }; } // namespace std #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_MAP_VALUE_H_ ================================================ FILE: common/values/parsed_json_map_value_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include "google/protobuf/struct.pb.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/memory.h" #include "common/type.h" #include "common/value.h" #include "common/value_kind.h" #include "common/value_testing.h" #include "internal/testing.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/io/zero_copy_stream_impl_lite.h" namespace cel { namespace { using ::absl_testing::IsOk; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::test::BoolValueIs; using ::cel::test::ErrorValueIs; using ::cel::test::IsNullValue; using ::cel::test::StringValueIs; using ::testing::AnyOf; using ::testing::Eq; using ::testing::IsEmpty; using ::testing::Optional; using ::testing::Pair; using ::testing::UnorderedElementsAre; using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; using ParsedJsonMapValueTest = common_internal::ValueTest<>; TEST_F(ParsedJsonMapValueTest, Kind) { EXPECT_EQ(ParsedJsonMapValue::kind(), ParsedJsonMapValue::kKind); EXPECT_EQ(ParsedJsonMapValue::kind(), ValueKind::kMap); } TEST_F(ParsedJsonMapValueTest, GetTypeName) { EXPECT_EQ(ParsedJsonMapValue::GetTypeName(), ParsedJsonMapValue::kName); EXPECT_EQ(ParsedJsonMapValue::GetTypeName(), "google.protobuf.Struct"); } TEST_F(ParsedJsonMapValueTest, GetRuntimeType) { EXPECT_EQ(ParsedJsonMapValue::GetRuntimeType(), JsonMapType()); } TEST_F(ParsedJsonMapValueTest, DebugString_Dynamic) { ParsedJsonMapValue valid_value( DynamicParseTextProto(R"pb()pb"), arena()); EXPECT_EQ(valid_value.DebugString(), "{}"); } TEST_F(ParsedJsonMapValueTest, IsZeroValue_Dynamic) { ParsedJsonMapValue valid_value( DynamicParseTextProto(R"pb()pb"), arena()); EXPECT_TRUE(valid_value.IsZeroValue()); } TEST_F(ParsedJsonMapValueTest, SerializeTo_Dynamic) { ParsedJsonMapValue valid_value( 
DynamicParseTextProto(R"pb()pb"), arena()); google::protobuf::io::CordOutputStream output; EXPECT_THAT( valid_value.SerializeTo(descriptor_pool(), message_factory(), &output), IsOk()); EXPECT_THAT(std::move(output).Consume(), IsEmpty()); } TEST_F(ParsedJsonMapValueTest, ConvertToJson_Dynamic) { auto json = DynamicParseTextProto(R"pb()pb"); ParsedJsonMapValue valid_value( DynamicParseTextProto(R"pb()pb"), arena()); EXPECT_THAT(valid_value.ConvertToJson(descriptor_pool(), message_factory(), cel::to_address(json)), IsOk()); EXPECT_THAT(*json, EqualsTextProto( R"pb(struct_value: {})pb")); } TEST_F(ParsedJsonMapValueTest, Equal_Dynamic) { ParsedJsonMapValue valid_value( DynamicParseTextProto(R"pb()pb"), arena()); EXPECT_THAT(valid_value.Equal(BoolValue(), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(false))); EXPECT_THAT( valid_value.Equal( ParsedJsonMapValue( DynamicParseTextProto(R"pb()pb"), arena()), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT(valid_value.Equal(MapValue(), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); } TEST_F(ParsedJsonMapValueTest, Empty_Dynamic) { ParsedJsonMapValue valid_value( DynamicParseTextProto(R"pb()pb"), arena()); EXPECT_TRUE(valid_value.IsEmpty()); } TEST_F(ParsedJsonMapValueTest, Size_Dynamic) { ParsedJsonMapValue valid_value( DynamicParseTextProto(R"pb()pb"), arena()); EXPECT_EQ(valid_value.Size(), 0); } TEST_F(ParsedJsonMapValueTest, Get_Dynamic) { ParsedJsonMapValue valid_value( DynamicParseTextProto( R"pb(fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } })pb"), arena()); EXPECT_THAT( valid_value.Get(BoolValue(), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))); EXPECT_THAT(valid_value.Get(StringValue("foo"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(IsNullValue())); 
EXPECT_THAT(valid_value.Get(StringValue("bar"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT( valid_value.Get(StringValue("baz"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))); } TEST_F(ParsedJsonMapValueTest, Find_Dynamic) { ParsedJsonMapValue valid_value( DynamicParseTextProto( R"pb(fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } })pb"), arena()); EXPECT_THAT(valid_value.Find(BoolValue(), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(valid_value.Find(StringValue("foo"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional(IsNullValue()))); EXPECT_THAT(valid_value.Find(StringValue("bar"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional(BoolValueIs(true)))); EXPECT_THAT(valid_value.Find(StringValue("baz"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Eq(absl::nullopt))); } TEST_F(ParsedJsonMapValueTest, Has_Dynamic) { ParsedJsonMapValue valid_value( DynamicParseTextProto( R"pb(fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } })pb"), arena()); EXPECT_THAT(valid_value.Has(BoolValue(), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(false))); EXPECT_THAT(valid_value.Has(StringValue("foo"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT(valid_value.Has(StringValue("bar"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT(valid_value.Has(StringValue("baz"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(false))); } TEST_F(ParsedJsonMapValueTest, ListKeys_Dynamic) { ParsedJsonMapValue valid_value( DynamicParseTextProto( R"pb(fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } })pb"), arena()); ASSERT_OK_AND_ASSIGN( auto keys, 
valid_value.ListKeys(descriptor_pool(), message_factory(), arena())); EXPECT_THAT(keys.Size(), IsOkAndHolds(2)); EXPECT_THAT(keys.DebugString(), AnyOf("[\"foo\", \"bar\"]", "[\"bar\", \"foo\"]")); EXPECT_THAT( keys.Contains(BoolValue(), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(false))); EXPECT_THAT(keys.Contains(StringValue("bar"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT(keys.Get(0, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(AnyOf(StringValueIs("foo"), StringValueIs("bar")))); EXPECT_THAT(keys.Get(1, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(AnyOf(StringValueIs("foo"), StringValueIs("bar")))); } TEST_F(ParsedJsonMapValueTest, ForEach_Dynamic) { ParsedJsonMapValue valid_value( DynamicParseTextProto( R"pb(fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } })pb"), arena()); std::vector> entries; EXPECT_THAT( valid_value.ForEach( [&](const Value& key, const Value& value) -> absl::StatusOr { entries.push_back(std::pair{std::move(key), std::move(value)}); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(entries, UnorderedElementsAre( Pair(StringValueIs("foo"), IsNullValue()), Pair(StringValueIs("bar"), BoolValueIs(true)))); } TEST_F(ParsedJsonMapValueTest, NewIterator_Dynamic) { ParsedJsonMapValue valid_value( DynamicParseTextProto( R"pb(fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } })pb"), arena()); ASSERT_OK_AND_ASSIGN(auto iterator, valid_value.NewIterator()); ASSERT_TRUE(iterator->HasNext()); EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), IsOkAndHolds(AnyOf(StringValueIs("foo"), StringValueIs("bar")))); ASSERT_TRUE(iterator->HasNext()); EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), IsOkAndHolds(AnyOf(StringValueIs("foo"), StringValueIs("bar")))); ASSERT_FALSE(iterator->HasNext()); 
EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), StatusIs(absl::StatusCode::kFailedPrecondition)); } TEST_F(ParsedJsonMapValueTest, NewIterator1) { ParsedJsonMapValue valid_value( DynamicParseTextProto( R"pb(fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } })pb"), arena()); ASSERT_OK_AND_ASSIGN(auto iterator, valid_value.NewIterator()); EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), IsOkAndHolds( Optional(AnyOf(StringValueIs("foo"), StringValueIs("bar"))))); EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), IsOkAndHolds( Optional(AnyOf(StringValueIs("foo"), StringValueIs("bar"))))); EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Eq(absl::nullopt))); } TEST_F(ParsedJsonMapValueTest, NewIterator2) { ParsedJsonMapValue valid_value( DynamicParseTextProto( R"pb(fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } })pb"), arena()); ASSERT_OK_AND_ASSIGN(auto iterator, valid_value.NewIterator()); EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional( AnyOf(Pair(StringValueIs("foo"), IsNullValue()), Pair(StringValueIs("bar"), BoolValueIs(true)))))); EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional( AnyOf(Pair(StringValueIs("foo"), IsNullValue()), Pair(StringValueIs("bar"), BoolValueIs(true)))))); EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Eq(absl::nullopt))); } } // namespace } // namespace cel ================================================ FILE: common/values/parsed_json_value.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/values/parsed_json_value.h" #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/functional/overload.h" #include "absl/status/status.h" #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "common/allocator.h" #include "common/memory.h" #include "common/value.h" #include "internal/well_known_types.h" #include "google/protobuf/arena.h" #include "google/protobuf/message.h" namespace cel::common_internal { namespace { using ::cel::well_known_types::AsVariant; using ::cel::well_known_types::GetValueReflectionOrDie; google::protobuf::Arena* absl_nonnull MessageArenaOr( const google::protobuf::Message* absl_nonnull message, google::protobuf::Arena* absl_nonnull or_arena) { google::protobuf::Arena* absl_nullable arena = message->GetArena(); if (arena == nullptr) { arena = or_arena; } return arena; } } // namespace Value ParsedJsonValue(const google::protobuf::Message* absl_nonnull message, google::protobuf::Arena* absl_nonnull arena) { const auto reflection = GetValueReflectionOrDie(message->GetDescriptor()); const auto kind_case = reflection.GetKindCase(*message); switch (kind_case) { case google::protobuf::Value::KIND_NOT_SET: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::Value::kNullValue: return NullValue(); case google::protobuf::Value::kBoolValue: return BoolValue(reflection.GetBoolValue(*message)); case google::protobuf::Value::kNumberValue: return DoubleValue(reflection.GetNumberValue(*message)); 
case google::protobuf::Value::kStringValue: { std::string scratch; return absl::visit( absl::Overload( [&](absl::string_view string) -> StringValue { if (string.empty()) { return StringValue(); } if (string.data() == scratch.data() && string.size() == scratch.size()) { return StringValue(arena, std::move(scratch)); } else { return StringValue( Borrower::Arena(MessageArenaOr(message, arena)), string); } }, [&](absl::Cord&& cord) -> StringValue { if (cord.empty()) { return StringValue(); } return StringValue(std::move(cord)); }), AsVariant(reflection.GetStringValue(*message, scratch))); } case google::protobuf::Value::kListValue: return ParsedJsonListValue(&reflection.GetListValue(*message), MessageArenaOr(message, arena)); case google::protobuf::Value::kStructValue: return ParsedJsonMapValue(&reflection.GetStructValue(*message), MessageArenaOr(message, arena)); default: return ErrorValue(absl::InvalidArgumentError( absl::StrCat("unexpected value kind case: ", kind_case))); } } } // namespace cel::common_internal ================================================ FILE: common/values/parsed_json_value.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_VALUE_H_ #include "google/protobuf/struct.pb.h" #include "absl/base/nullability.h" #include "google/protobuf/arena.h" #include "google/protobuf/message.h" namespace cel { class Value; namespace common_internal { // Adapts the given instance of the well known message type // `google.protobuf.Value` to `cel::Value`. If the underlying value is a string // and the string had to be copied, `allocator` will be used to create a new // string value. This should be rare and unlikely. Value ParsedJsonValue(const google::protobuf::Message* absl_nonnull message, google::protobuf::Arena* absl_nonnull arena); } // namespace common_internal } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_JSON_VALUE_H_ ================================================ FILE: common/values/parsed_json_value_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "common/values/parsed_json_value.h"

#include "google/protobuf/struct.pb.h"
#include "absl/strings/string_view.h"
#include "common/value_testing.h"
#include "internal/testing.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"

namespace cel::common_internal {
namespace {

using ::cel::test::BoolValueIs;
using ::cel::test::DoubleValueIs;
using ::cel::test::IsNullValue;
using ::cel::test::ListValueElements;
using ::cel::test::ListValueIs;
using ::cel::test::MapValueElements;
using ::cel::test::MapValueIs;
using ::cel::test::StringValueIs;
using ::testing::ElementsAre;
using ::testing::Pair;
using ::testing::UnorderedElementsAre;

using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes;

using ParsedJsonValueTest = common_internal::ValueTest<>;

TEST_F(ParsedJsonValueTest, Null_Dynamic) {
  // Explicit `null_value`.
  EXPECT_THAT(ParsedJsonValue(DynamicParseTextProto<google::protobuf::Value>(
                                  R"pb(null_value: NULL_VALUE)pb"),
                              arena()),
              IsNullValue());
  // No `kind` set at all (KIND_NOT_SET) is also adapted to null.
  EXPECT_THAT(ParsedJsonValue(
                  DynamicParseTextProto<google::protobuf::Value>(R"pb()pb"),
                  arena()),
              IsNullValue());
}

TEST_F(ParsedJsonValueTest, Bool_Dynamic) {
  EXPECT_THAT(ParsedJsonValue(DynamicParseTextProto<google::protobuf::Value>(
                                  R"pb(bool_value: true)pb"),
                              arena()),
              BoolValueIs(true));
}

TEST_F(ParsedJsonValueTest, Double_Dynamic) {
  EXPECT_THAT(ParsedJsonValue(DynamicParseTextProto<google::protobuf::Value>(
                                  R"pb(number_value: 1.0)pb"),
                              arena()),
              DoubleValueIs(1.0));
}

TEST_F(ParsedJsonValueTest, String_Dynamic) {
  EXPECT_THAT(ParsedJsonValue(DynamicParseTextProto<google::protobuf::Value>(
                                  R"pb(string_value: "foo")pb"),
                              arena()),
              StringValueIs("foo"));
}

TEST_F(ParsedJsonValueTest, List_Dynamic) {
  EXPECT_THAT(ParsedJsonValue(DynamicParseTextProto<google::protobuf::Value>(
                                  R"pb(list_value: {
                                         values {}
                                         values { bool_value: true }
                                       })pb"),
                              arena()),
              ListValueIs(ListValueElements(
                  ElementsAre(IsNullValue(), BoolValueIs(true)),
                  descriptor_pool(), message_factory(), arena())));
}

TEST_F(ParsedJsonValueTest, Map_Dynamic) {
  EXPECT_THAT(
      ParsedJsonValue(DynamicParseTextProto<google::protobuf::Value>(
                          R"pb(struct_value: {
                                 fields {
                                   key: "foo"
                                   value: {}
                                 }
                                 fields {
                                   key: "bar"
                                   value: { bool_value: true }
                                 }
                               })pb"),
                      arena()),
      MapValueIs(MapValueElements(
          UnorderedElementsAre(Pair(StringValueIs("foo"), IsNullValue()),
                               Pair(StringValueIs("bar"), BoolValueIs(true))),
          descriptor_pool(), message_factory(), arena())));
}

}  // namespace
}  // namespace cel::common_internal
================================================
FILE: common/values/parsed_map_field_value.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/values/parsed_map_field_value.h"

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <utility>

#include "google/protobuf/struct.pb.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/optional.h"
#include "common/value.h"
#include "common/values/values.h"
#include "extensions/protobuf/internal/map_reflection.h"
#include "internal/json.h"
#include "internal/message_equality.h"
#include "internal/status_macros.h"
#include "internal/well_known_types.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/map_field.h"
#include "google/protobuf/message.h"

namespace cel {

using ::cel::well_known_types::ValueReflection;

// Returns "INVALID" for a default-constructed (invalid) value and "VALID"
// otherwise; the map contents are not rendered.
std::string ParsedMapFieldValue::DebugString() const {
  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    return "INVALID";
  }
  return "VALID";
}

// Writes the map to `output` as a serialized `google.protobuf.Struct`. An
// invalid value writes nothing and succeeds.
absl::Status ParsedMapFieldValue::SerializeTo(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(output != nullptr);
  ABSL_DCHECK(*this);

  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    return absl::OkStatus();
  }
  // We have to convert to google.protobuf.Struct first. The map field is
  // converted into a google.protobuf.Value whose `struct_value` holds the JSON
  // object; only that Struct is written. (Writing `list_value` here would
  // serialize an unset, empty message instead of the map.)
  google::protobuf::Value message;
  CEL_RETURN_IF_ERROR(internal::MessageFieldToJson(
      *message_, field_, descriptor_pool, message_factory, &message));
  if (!message.struct_value().SerializePartialToZeroCopyStream(output)) {
    return absl::UnknownError("failed to serialize google.protobuf.Struct");
  }
  return absl::OkStatus();
}

// Converts the map into `json`, which must be a `google.protobuf.Value`. An
// invalid value becomes an empty JSON object (`struct_value` cleared).
absl::Status ParsedMapFieldValue::ConvertToJson(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(json != nullptr);
  ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
                 google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE);
  ABSL_DCHECK(*this);

  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    ValueReflection value_reflection;
    CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor()));
    value_reflection.MutableStructValue(json)->Clear();
    return absl::OkStatus();
  }
  return internal::MessageFieldToJson(*message_, field_, descriptor_pool,
                                      message_factory, json);
}

// Converts the map into `json`, which must be a `google.protobuf.Struct`. An
// invalid value clears `json`.
absl::Status ParsedMapFieldValue::ConvertToJsonObject(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(json != nullptr);
  ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
                 google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT);
  ABSL_DCHECK(*this);

  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    json->Clear();
    return absl::OkStatus();
  }
  return internal::MessageFieldToJson(*message_, field_, descriptor_pool,
                                      message_factory, json);
}

// Sets `*result` to whether `other` is an equal map.
// - Another map field is compared by protobuf message-field equality.
// - A parsed JSON map is compared by message-field equality. An unset JSON map
//   equals this map only if this map is empty.
// - Any other map falls back to the generic element-wise comparison.
// - Non-map values are never equal.
absl::Status ParsedMapFieldValue::Equal(
    const Value& other,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  if (auto other_value = other.AsParsedMapField(); other_value) {
    ABSL_DCHECK(field_ != nullptr);
    ABSL_DCHECK(other_value->field_ != nullptr);
    CEL_ASSIGN_OR_RETURN(
        auto equal, internal::MessageFieldEquals(
                        *message_, field_, *other_value->message_,
                        other_value->field_, descriptor_pool, message_factory));
    *result = BoolValue(equal);
    return absl::OkStatus();
  }
  if (auto other_value = other.AsParsedJsonMap(); other_value) {
    if (other_value->value_ == nullptr) {
      *result = BoolValue(IsEmpty());
      return absl::OkStatus();
    }
    ABSL_DCHECK(field_ != nullptr);
    CEL_ASSIGN_OR_RETURN(
        auto equal,
        internal::MessageFieldEquals(*message_, field_, *other_value->value_,
                                     descriptor_pool, message_factory));
    *result = BoolValue(equal);
    return absl::OkStatus();
  }
  if (auto other_value = other.AsMap(); other_value) {
    return common_internal::MapValueEqual(MapValue(*this), *other_value,
                                          descriptor_pool, message_factory,
                                          arena, result);
  }
  *result = BoolValue(false);
  return absl::OkStatus();
}

// The zero value of a map is the empty map.
bool ParsedMapFieldValue::IsZeroValue() const { return IsEmpty(); }

// Returns a value whose backing message lives on `arena`. If this value is
// already associated with `arena` it is returned as-is. Otherwise a new
// message of the same type is created on `arena` and only the map field is
// copied into it.
ParsedMapFieldValue ParsedMapFieldValue::Clone(
    google::protobuf::Arena* absl_nonnull arena) const {
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(*this);

  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    return ParsedMapFieldValue();
  }
  if (arena_ == arena) {
    return *this;
  }
  auto field =
      message_->GetReflection()->GetRepeatedFieldRef<google::protobuf::Message>(
          *message_, field_);
  auto* cloned = message_->New(arena);
  auto cloned_field =
      cloned->GetReflection()
          ->GetMutableRepeatedFieldRef<google::protobuf::Message>(cloned,
                                                                  field_);
  cloned_field.CopyFrom(field);
  return ParsedMapFieldValue(cloned, field_, arena);
}

bool ParsedMapFieldValue::IsEmpty() const { return Size() == 0; }

// Number of entries in the map field; 0 for an invalid value.
size_t ParsedMapFieldValue::Size() const {
  ABSL_DCHECK(*this);
  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    return 0;
  }
  return static_cast<size_t>(extensions::protobuf_internal::MapSize(
      *GetReflection(), *message_, *field_));
}

namespace {

// Converts `value` to the integral type `T` when it is finite, has no
// fractional part, and lies within the range of `T`. Otherwise returns
// `absl::nullopt`.
//
// The range check is done in the floating-point domain *before* casting:
// converting a NaN or out-of-range double to an integer type is undefined
// behavior, so casting first and comparing afterwards is not safe.
template <typename T>
absl::optional<T> DoubleAsIntegral(double value) {
  // Exclusive upper bound 2^digits (== max() + 1). It is exactly representable
  // as a double even for 64-bit types, unlike max() itself.
  const double upper = std::ldexp(1.0, std::numeric_limits<T>::digits);
  const double lower = static_cast<double>(std::numeric_limits<T>::min());
  if (!(value >= lower && value < upper)) {
    // Also rejects NaN, for which every comparison is false.
    return absl::nullopt;
  }
  const T converted = static_cast<T>(value);
  if (static_cast<double>(converted) != value) {
    // `value` had a fractional part.
    return absl::nullopt;
  }
  return converted;
}

// Converts an int, uint or integral double `value` to int32 when it is
// representable; otherwise returns `absl::nullopt`.
absl::optional<int32_t> ValueAsInt32(const Value& value) {
  if (auto int_value = value.AsInt();
      int_value &&
      int_value->NativeValue() >= std::numeric_limits<int32_t>::min() &&
      int_value->NativeValue() <= std::numeric_limits<int32_t>::max()) {
    return static_cast<int32_t>(int_value->NativeValue());
  } else if (auto uint_value = value.AsUint();
             uint_value &&
             uint_value->NativeValue() <=
                 static_cast<uint64_t>(std::numeric_limits<int32_t>::max())) {
    return static_cast<int32_t>(uint_value->NativeValue());
  } else if (auto double_value = value.AsDouble(); double_value) {
    return DoubleAsIntegral<int32_t>(double_value->NativeValue());
  }
  return absl::nullopt;
}

// Converts an int, uint or integral double `value` to int64 when it is
// representable; otherwise returns `absl::nullopt`.
absl::optional<int64_t> ValueAsInt64(const Value& value) {
  if (auto int_value = value.AsInt(); int_value) {
    return int_value->NativeValue();
  } else if (auto uint_value = value.AsUint();
             uint_value &&
             uint_value->NativeValue() <=
                 static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
    return static_cast<int64_t>(uint_value->NativeValue());
  } else if (auto double_value = value.AsDouble(); double_value) {
    return DoubleAsIntegral<int64_t>(double_value->NativeValue());
  }
  return absl::nullopt;
}

// Converts an int, uint or integral double `value` to uint32 when it is
// representable; otherwise returns `absl::nullopt`.
absl::optional<uint32_t> ValueAsUInt32(const Value& value) {
  if (auto int_value = value.AsInt();
      int_value && int_value->NativeValue() >= 0 &&
      int_value->NativeValue() <= std::numeric_limits<uint32_t>::max()) {
    return static_cast<uint32_t>(int_value->NativeValue());
  } else if (auto uint_value = value.AsUint();
             uint_value &&
             uint_value->NativeValue() <= std::numeric_limits<uint32_t>::max()) {
    return static_cast<uint32_t>(uint_value->NativeValue());
  } else if (auto double_value = value.AsDouble(); double_value) {
    return DoubleAsIntegral<uint32_t>(double_value->NativeValue());
  }
  return absl::nullopt;
}

// Converts an int, uint or integral double `value` to uint64 when it is
// representable; otherwise returns `absl::nullopt`.
absl::optional<uint64_t> ValueAsUInt64(const Value& value) {
  if (auto int_value = value.AsInt();
      int_value && int_value->NativeValue() >= 0) {
    return static_cast<uint64_t>(int_value->NativeValue());
  } else if (auto uint_value = value.AsUint(); uint_value) {
    return uint_value->NativeValue();
  } else if (auto double_value = value.AsDouble(); double_value) {
    return DoubleAsIntegral<uint64_t>(double_value->NativeValue());
  }
  return absl::nullopt;
}

// Converts the CEL `key` into `*proto_key` for a map whose key C++ type is
// `cpp_type`. Returns false when `key` cannot be represented as that type.
// String keys are first copied into `proto_key_scratch`, which the caller owns.
bool ValueToProtoMapKey(const Value& key,
                        google::protobuf::FieldDescriptor::CppType cpp_type,
                        google::protobuf::MapKey* absl_nonnull proto_key,
                        std::string& proto_key_scratch) {
  switch (cpp_type) {
    case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: {
      if (auto bool_key = key.AsBool(); bool_key) {
        proto_key->SetBoolValue(bool_key->NativeValue());
        return true;
      }
      return false;
    }
    case google::protobuf::FieldDescriptor::CPPTYPE_INT32: {
      if (auto int_key = ValueAsInt32(key); int_key) {
        proto_key->SetInt32Value(*int_key);
        return true;
      }
      return false;
    }
    case google::protobuf::FieldDescriptor::CPPTYPE_INT64: {
      if (auto int_key = ValueAsInt64(key); int_key) {
        proto_key->SetInt64Value(*int_key);
        return true;
      }
      return false;
    }
    case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: {
      if (auto int_key = ValueAsUInt32(key); int_key) {
        proto_key->SetUInt32Value(*int_key);
        return true;
      }
      return false;
    }
    case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: {
      if (auto int_key = ValueAsUInt64(key); int_key) {
        proto_key->SetUInt64Value(*int_key);
        return true;
      }
      return false;
    }
    case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
      if (auto string_key = key.AsString(); string_key) {
        proto_key_scratch = string_key->NativeString();
        proto_key->SetStringValue(proto_key_scratch);
        return true;
      }
      return false;
    }
    default:
      // protobuf map keys can only be bool, integrals, or string.
      return false;
  }
}

}  // namespace

// Looks up `key`. A missing key yields a "no such key" error value, unless
// `Find` already produced an error or unknown (propagated from `key`).
absl::Status ParsedMapFieldValue::Get(
    const Value& key,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  CEL_ASSIGN_OR_RETURN(
      bool ok, Find(key, descriptor_pool, message_factory, arena, result));
  if (ABSL_PREDICT_FALSE(!ok) && !(result->IsError() || result->IsUnknown())) {
    *result = ErrorValue(NoSuchKeyError(key.DebugString()));
  }
  return absl::OkStatus();
}

// Looks up `key`. On a hit, stores the wrapped map value in `*result` and
// returns true. On a miss returns false: `*result` is `key` itself when `key`
// is an error or unknown, and null otherwise (including when `key` cannot be
// converted to the map's key type).
absl::StatusOr<bool> ParsedMapFieldValue::Find(
    const Value& key,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  ABSL_DCHECK(*this);
  ABSL_DCHECK(message_ != nullptr);
  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    *result = NullValue();
    return false;
  }

  if (key.IsError() || key.IsUnknown()) {
    *result = key;
    return false;
  }

  const google::protobuf::Descriptor* absl_nonnull entry_descriptor =
      field_->message_type();
  const google::protobuf::FieldDescriptor* absl_nonnull key_field =
      entry_descriptor->map_key();
  const google::protobuf::FieldDescriptor* absl_nonnull value_field =
      entry_descriptor->map_value();

  // String keys are staged in `proto_key_scratch`; it is kept alive for the
  // lookup below (presumably in case `MapKey` refers to it — confirm).
  std::string proto_key_scratch;
  google::protobuf::MapKey proto_key;
  if (!ValueToProtoMapKey(key, key_field->cpp_type(), &proto_key,
                          proto_key_scratch)) {
    *result = NullValue();
    return false;
  }
  google::protobuf::MapValueConstRef proto_value;
  if (!extensions::protobuf_internal::LookupMapValue(
          *GetReflection(), *message_, *field_, proto_key, &proto_value)) {
    *result = NullValue();
    return false;
  }
  // `arena_` can only be null for values built by the unsafe, arena-less
  // constructor over a message that has no arena.
  if (arena_ == nullptr) {
    *result = Value::WrapMapFieldValueUnsafe(proto_value, message_, value_field,
                                             descriptor_pool, message_factory,
                                             arena);
  } else {
    *result = Value::WrapMapFieldValue(proto_value, message_, value_field,
                                       descriptor_pool, message_factory, arena);
  }
  return true;
}

// Sets `*result` to whether `key` is present. Keys that cannot be converted
// to the map's key type are reported as absent.
absl::Status ParsedMapFieldValue::Has(
    const Value& key,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  ABSL_DCHECK(*this);
  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    *result = BoolValue(false);
    return absl::OkStatus();
  }
  const google::protobuf::FieldDescriptor* absl_nonnull key_field =
      field_->message_type()->map_key();
  std::string proto_key_scratch;
  google::protobuf::MapKey proto_key;
  bool bool_result;
  if (ValueToProtoMapKey(key, key_field->cpp_type(), &proto_key,
                         proto_key_scratch)) {
    google::protobuf::MapValueConstRef proto_value;
    bool_result = extensions::protobuf_internal::LookupMapValue(
        *GetReflection(), *message_, *field_, proto_key, &proto_value);
  } else {
    bool_result = false;
  }
  *result = BoolValue(bool_result);
  return absl::OkStatus();
}

// Stores a list of all keys, in map iteration order, into `*result`.
absl::Status ParsedMapFieldValue::ListKeys(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    ListValue* absl_nonnull result) const {
  ABSL_DCHECK(*this);
  if (field_ == nullptr) {
    *result = ListValue();
    return absl::OkStatus();
  }
  const auto* reflection = message_->GetReflection();
  if (reflection->FieldSize(*message_, field_) == 0) {
    *result = ListValue();
    return absl::OkStatus();
  }
  CEL_ASSIGN_OR_RETURN(auto key_accessor,
                       common_internal::MapFieldKeyAccessorFor(
                           field_->message_type()->map_key()));
  auto builder = NewListValueBuilder(arena);
  builder->Reserve(Size());
  auto begin =
      extensions::protobuf_internal::ConstMapBegin(*reflection, *message_, *field_);
  const auto end = extensions::protobuf_internal::ConstMapEnd(
      *reflection, *message_, *field_);
  for (; begin != end; ++begin) {
    Value scratch;
    (*key_accessor)(begin.GetKey(), message_, arena, &scratch);
    CEL_RETURN_IF_ERROR(builder->Add(std::move(scratch)));
  }
  *result = std::move(*builder).Build();
  return absl::OkStatus();
}

// Invokes `callback` with each (key, value) entry in map iteration order,
// stopping early when the callback returns false or an error.
absl::Status ParsedMapFieldValue::ForEach(
    ForEachCallback callback,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) const {
  ABSL_DCHECK(*this);
  if (field_ == nullptr) {
    return absl::OkStatus();
  }
  const auto* reflection = message_->GetReflection();
  if (reflection->FieldSize(*message_, field_) > 0) {
    const auto* value_field = field_->message_type()->map_value();
    CEL_ASSIGN_OR_RETURN(auto key_accessor,
                         common_internal::MapFieldKeyAccessorFor(
                             field_->message_type()->map_key()));
    CEL_ASSIGN_OR_RETURN(
        auto value_accessor,
        common_internal::MapFieldValueAccessorFor(value_field));
    auto begin = extensions::protobuf_internal::ConstMapBegin(
        *reflection, *message_, *field_);
    const auto end = extensions::protobuf_internal::ConstMapEnd(
        *reflection, *message_, *field_);
    // Scratch values are reused across iterations.
    Value key_scratch;
    Value value_scratch;
    for (; begin != end; ++begin) {
      (*key_accessor)(begin.GetKey(), message_, arena, &key_scratch);
      (*value_accessor)(begin.GetValueRef(), message_, value_field,
                        descriptor_pool, message_factory, arena,
                        &value_scratch);
      CEL_ASSIGN_OR_RETURN(auto ok, callback(key_scratch, value_scratch));
      if (!ok) {
        break;
      }
    }
  }
  return absl::OkStatus();
}

namespace {

// Iterates over the entries of a protobuf map field. `Next`/`Next1` yield
// keys, and `Next2` yields keys and, optionally, values.
class ParsedMapFieldValueIterator final : public ValueIterator {
 public:
  ParsedMapFieldValueIterator(
      const google::protobuf::Message* absl_nonnull message,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      absl_nonnull common_internal::MapFieldKeyAccessor key_accessor,
      absl_nonnull common_internal::MapFieldValueAccessor value_accessor)
      : message_(message),
        value_field_(field->message_type()->map_value()),
        key_accessor_(key_accessor),
        value_accessor_(value_accessor),
        begin_(extensions::protobuf_internal::ConstMapBegin(
            *message_->GetReflection(), *message_, *field)),
        end_(extensions::protobuf_internal::ConstMapEnd(
            *message_->GetReflection(), *message_, *field)) {}

  bool HasNext() override { return begin_ != end_; }

  // Yields the next key; fails if the iterator is exhausted.
  absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                    google::protobuf::MessageFactory* absl_nonnull message_factory,
                    google::protobuf::Arena* absl_nonnull arena,
                    Value* absl_nonnull result) override {
    if (ABSL_PREDICT_FALSE(begin_ == end_)) {
      return absl::FailedPreconditionError(
          "ValueIterator::Next called after ValueIterator::HasNext returned "
          "false");
    }
    (*key_accessor_)(begin_.GetKey(), message_, arena, result);
    ++begin_;
    return absl::OkStatus();
  }

  // Yields the next key, or returns false when exhausted.
  absl::StatusOr<bool> Next1(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull key_or_value) override {
    ABSL_DCHECK(descriptor_pool != nullptr);
    ABSL_DCHECK(message_factory != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK(key_or_value != nullptr);

    if (begin_ == end_) {
      return false;
    }
    (*key_accessor_)(begin_.GetKey(), message_, arena, key_or_value);
    ++begin_;
    return true;
  }

  // Yields the next key and, when `value` is non-null, its value. Returns
  // false when exhausted.
  absl::StatusOr<bool> Next2(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key,
      Value* absl_nullable value) override {
    ABSL_DCHECK(descriptor_pool != nullptr);
    ABSL_DCHECK(message_factory != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK(key != nullptr);

    if (begin_ == end_) {
      return false;
    }
    (*key_accessor_)(begin_.GetKey(), message_, arena, key);
    if (value != nullptr) {
      (*value_accessor_)(begin_.GetValueRef(), message_, value_field_,
                         descriptor_pool, message_factory, arena, value);
    }
    ++begin_;
    return true;
  }

 private:
  // Declaration order matters: `message_` is initialized before `begin_` and
  // `end_`, whose initializers dereference it.
  const google::protobuf::Message* absl_nonnull const message_;
  const google::protobuf::FieldDescriptor* absl_nonnull const value_field_;
  const absl_nonnull common_internal::MapFieldKeyAccessor key_accessor_;
  const absl_nonnull common_internal::MapFieldValueAccessor value_accessor_;
  google::protobuf::ConstMapIterator begin_;
  const google::protobuf::ConstMapIterator end_;
};

}  // namespace

// Returns an iterator over the map entries; an invalid value yields an empty
// iterator.
absl::StatusOr<absl_nonnull std::unique_ptr<ValueIterator>>
ParsedMapFieldValue::NewIterator() const {
  ABSL_DCHECK(*this);
  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    return NewEmptyValueIterator();
  }
  CEL_ASSIGN_OR_RETURN(auto key_accessor,
                       common_internal::MapFieldKeyAccessorFor(
                           field_->message_type()->map_key()));
  CEL_ASSIGN_OR_RETURN(auto value_accessor,
                       common_internal::MapFieldValueAccessorFor(
                           field_->message_type()->map_value()));
  return std::make_unique<ParsedMapFieldValueIterator>(
      message_, field_, key_accessor, value_accessor);
}

const google::protobuf::Reflection* absl_nonnull
ParsedMapFieldValue::GetReflection() const {
  return message_->GetReflection();
}

}  // namespace cel

================================================
FILE: common/values/parsed_map_field_value.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MAP_FIELD_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MAP_FIELD_VALUE_H_ #include #include #include #include #include #include "google/protobuf/any.pb.h" #include "google/protobuf/struct.pb.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "common/type.h" #include "common/value_kind.h" #include "common/values/custom_map_value.h" #include "common/values/values.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { class Value; class ValueIterator; class ListValue; class ParsedJsonMapValue; // ParsedMapFieldValue is a MapValue over a map field of a parsed protocol // buffer message. class ParsedMapFieldValue final : private common_internal::MapValueMixin { public: static constexpr ValueKind kKind = ValueKind::kMap; static constexpr absl::string_view kName = "map"; ParsedMapFieldValue(const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, google::protobuf::Arena* absl_nonnull arena) : message_(message), field_(field), arena_(arena) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(field_->is_map()) << field_->full_name() << " must be a map field"; ABSL_DCHECK_OK(CheckArena(message_, arena_)); } // Places the `ParsedMapFieldValue` into an invalid state. Anything // except assigning to `ParsedMapFieldValue` is undefined behavior. 
ParsedMapFieldValue() = default; ParsedMapFieldValue(const ParsedMapFieldValue&) = default; ParsedMapFieldValue(ParsedMapFieldValue&&) = default; ParsedMapFieldValue& operator=(const ParsedMapFieldValue&) = default; ParsedMapFieldValue& operator=(ParsedMapFieldValue&&) = default; static constexpr ValueKind kind() { return kKind; } static constexpr absl::string_view GetTypeName() { return kName; } static MapType GetRuntimeType() { return MapType(); } std::string DebugString() const; // See Value::SerializeTo(). absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; // See Value::ConvertToJson(). absl::Status ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; // See Value::ConvertToJsonObject(). absl::Status ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; absl::Status Equal(const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Equal; bool IsZeroValue() const; ParsedMapFieldValue Clone(google::protobuf::Arena* absl_nonnull arena) const; bool IsEmpty() const; size_t Size() const; // See the corresponding member function of `MapValue` for // documentation. 
absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Get; // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr Find( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; // See the corresponding member function of `MapValue` for // documentation. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Has; // See the corresponding member function of `MapValue` for // documentation. absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; using MapValueMixin::ListKeys; // See the corresponding type declaration of `MapValue` for // documentation. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; // See the corresponding member function of `MapValue` for // documentation. 
absl::Status ForEach( ForEachCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; absl::StatusOr> NewIterator() const; const google::protobuf::Message& message() const { ABSL_DCHECK(*this); return *message_; } const google::protobuf::FieldDescriptor* absl_nonnull field() const { ABSL_DCHECK(*this); return field_; } // Returns `true` if `ParsedMapFieldValue` is in a valid state. explicit operator bool() const { return field_ != nullptr; } friend void swap(ParsedMapFieldValue& lhs, ParsedMapFieldValue& rhs) noexcept { using std::swap; swap(lhs.message_, rhs.message_); swap(lhs.field_, rhs.field_); swap(lhs.arena_, rhs.arena_); } private: friend class ParsedJsonMapValue; friend class common_internal::ValueMixin; friend class common_internal::MapValueMixin; friend ParsedMapFieldValue UnsafeParsedMapFieldValue( const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field); ParsedMapFieldValue(const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field) : message_(message), field_(field), arena_(message->GetArena()) { ABSL_DCHECK(message != nullptr); ABSL_DCHECK(field != nullptr); ABSL_DCHECK(field_->is_map()) << field_->full_name() << " must be a map field"; } static absl::Status CheckArena(const google::protobuf::Message* absl_nullable message, google::protobuf::Arena* absl_nonnull arena) { if (message != nullptr && message->GetArena() != nullptr && message->GetArena() != arena) { return absl::InvalidArgumentError( "message arena must be the same as arena"); } return absl::OkStatus(); } const google::protobuf::Reflection* absl_nonnull GetReflection() const; const google::protobuf::Message* absl_nullable message_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable field_ = nullptr; 
google::protobuf::Arena* absl_nullable arena_ = nullptr; }; // Creates a `ParsedMapFieldValue` without specifying a managing arena. // The message must outlive the `ParsedMapFieldValue` or any value that // might be derived from it. Prefer to use // `cel::Value::WrapMapFieldValueUnsafe()`. inline ParsedMapFieldValue UnsafeParsedMapFieldValue( const google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field) { return ParsedMapFieldValue(message, field); } inline std::ostream& operator<<(std::ostream& out, const ParsedMapFieldValue& value) { return out << value.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MAP_FIELD_VALUE_H_ ================================================ FILE: common/values/parsed_map_field_value_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include "google/protobuf/struct.pb.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/optional.h" #include "common/memory.h" #include "common/type.h" #include "common/value.h" #include "common/value_kind.h" #include "common/value_testing.h" #include "internal/testing.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/io/zero_copy_stream_impl_lite.h" namespace cel { namespace { using ::absl_testing::IsOk; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::test::BoolValueIs; using ::cel::test::BytesValueIs; using ::cel::test::DoubleValueIs; using ::cel::test::DurationValueIs; using ::cel::test::ErrorValueIs; using ::cel::test::IntValueIs; using ::cel::test::IsNullValue; using ::cel::test::StringValueIs; using ::cel::test::UintValueIs; using ::testing::_; using ::testing::AnyOf; using ::testing::Eq; using ::testing::IsEmpty; using ::testing::Optional; using ::testing::Pair; using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; using ParsedMapFieldValueTest = common_internal::ValueTest<>; TEST_F(ParsedMapFieldValueTest, Field) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb()pb"), DynamicGetField("map_int64_int64"), arena()); EXPECT_TRUE(value); } TEST_F(ParsedMapFieldValueTest, Kind) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb()pb"), DynamicGetField("map_int64_int64"), arena()); EXPECT_EQ(value.kind(), ParsedMapFieldValue::kKind); EXPECT_EQ(value.kind(), ValueKind::kMap); } TEST_F(ParsedMapFieldValueTest, GetTypeName) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb()pb"), DynamicGetField("map_int64_int64"), arena()); EXPECT_EQ(value.GetTypeName(), ParsedMapFieldValue::kName); EXPECT_EQ(value.GetTypeName(), "map"); } TEST_F(ParsedMapFieldValueTest, GetRuntimeType) { ParsedMapFieldValue value( 
DynamicParseTextProto(R"pb()pb"), DynamicGetField("map_int64_int64"), arena()); EXPECT_EQ(value.GetRuntimeType(), MapType()); } TEST_F(ParsedMapFieldValueTest, DebugString) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb()pb"), DynamicGetField("map_int64_int64"), arena()); EXPECT_THAT(value.DebugString(), _); } TEST_F(ParsedMapFieldValueTest, IsZeroValue) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb()pb"), DynamicGetField("map_int64_int64"), arena()); EXPECT_TRUE(value.IsZeroValue()); } TEST_F(ParsedMapFieldValueTest, SerializeTo) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb()pb"), DynamicGetField("map_int64_int64"), arena()); google::protobuf::io::CordOutputStream output; EXPECT_THAT(value.SerializeTo(descriptor_pool(), message_factory(), &output), IsOk()); EXPECT_THAT(std::move(output).Consume(), IsEmpty()); } TEST_F(ParsedMapFieldValueTest, ConvertToJson) { auto json = DynamicParseTextProto(R"pb()pb"); ParsedMapFieldValue value( DynamicParseTextProto(R"pb()pb"), DynamicGetField("map_int64_int64"), arena()); EXPECT_THAT(value.ConvertToJson(descriptor_pool(), message_factory(), cel::to_address(json)), IsOk()); EXPECT_THAT(*json, EqualsTextProto( R"pb(struct_value: {})pb")); } TEST_F(ParsedMapFieldValueTest, Equal_MapField) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb()pb"), DynamicGetField("map_int64_int64"), arena()); EXPECT_THAT( value.Equal(BoolValue(), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(false))); EXPECT_THAT( value.Equal( ParsedMapFieldValue( DynamicParseTextProto(R"pb()pb"), DynamicGetField("map_int32_int32"), arena()), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT( value.Equal(MapValue(), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); } TEST_F(ParsedMapFieldValueTest, Equal_JsonMap) { ParsedMapFieldValue map_value( DynamicParseTextProto( R"pb(map_string_string { key: "foo" value: "bar" } 
map_string_string { key: "bar" value: "foo" })pb"), DynamicGetField("map_string_string"), arena()); ParsedJsonMapValue json_value(DynamicParseTextProto( R"pb( fields { key: "foo" value { string_value: "bar" } } fields { key: "bar" value { string_value: "foo" } } )pb"), arena()); EXPECT_THAT(map_value.Equal(json_value, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT(json_value.Equal(map_value, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); } TEST_F(ParsedMapFieldValueTest, Empty) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb()pb"), DynamicGetField("map_int64_int64"), arena()); EXPECT_TRUE(value.IsEmpty()); } TEST_F(ParsedMapFieldValueTest, Size) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb()pb"), DynamicGetField("map_int64_int64"), arena()); EXPECT_EQ(value.Size(), 0); } TEST_F(ParsedMapFieldValueTest, Get) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_string_bool { key: "foo" value: false } map_string_bool { key: "bar" value: true } )pb"), DynamicGetField("map_string_bool"), arena()); EXPECT_THAT( value.Get(BoolValue(), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))); EXPECT_THAT(value.Get(StringValue("foo"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(false))); EXPECT_THAT(value.Get(StringValue("bar"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT( value.Get(StringValue("baz"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound)))); } TEST_F(ParsedMapFieldValueTest, Find) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_string_bool { key: "foo" value: false } map_string_bool { key: "bar" value: true } )pb"), DynamicGetField("map_string_bool"), arena()); EXPECT_THAT( value.Find(BoolValue(), descriptor_pool(), message_factory(), arena()), 
IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(value.Find(StringValue("foo"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional(BoolValueIs(false)))); EXPECT_THAT(value.Find(StringValue("bar"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional(BoolValueIs(true)))); EXPECT_THAT(value.Find(StringValue("baz"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Eq(absl::nullopt))); } TEST_F(ParsedMapFieldValueTest, Has) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_string_bool { key: "foo" value: false } map_string_bool { key: "bar" value: true } )pb"), DynamicGetField("map_string_bool"), arena()); EXPECT_THAT( value.Has(BoolValue(), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(false))); EXPECT_THAT(value.Has(StringValue("foo"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT(value.Has(StringValue("bar"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT(value.Has(StringValue("baz"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(false))); } TEST_F(ParsedMapFieldValueTest, ListKeys) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_string_bool { key: "foo" value: false } map_string_bool { key: "bar" value: true } )pb"), DynamicGetField("map_string_bool"), arena()); ASSERT_OK_AND_ASSIGN( auto keys, value.ListKeys(descriptor_pool(), message_factory(), arena())); EXPECT_THAT(keys.Size(), IsOkAndHolds(2)); EXPECT_THAT(keys.DebugString(), AnyOf("[\"foo\", \"bar\"]", "[\"bar\", \"foo\"]")); EXPECT_THAT( keys.Contains(BoolValue(), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(false))); EXPECT_THAT(keys.Contains(StringValue("bar"), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT(keys.Get(0, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(AnyOf(StringValueIs("foo"), 
StringValueIs("bar")))); EXPECT_THAT(keys.Get(1, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(AnyOf(StringValueIs("foo"), StringValueIs("bar")))); } TEST_F(ParsedMapFieldValueTest, ForEach_StringBool) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_string_bool { key: "foo" value: false } map_string_bool { key: "bar" value: true } )pb"), DynamicGetField("map_string_bool"), arena()); std::vector> entries; EXPECT_THAT( value.ForEach( [&](const Value& key, const Value& value) -> absl::StatusOr { entries.push_back(std::pair{std::move(key), std::move(value)}); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(entries, UnorderedElementsAre( Pair(StringValueIs("foo"), BoolValueIs(false)), Pair(StringValueIs("bar"), BoolValueIs(true)))); } TEST_F(ParsedMapFieldValueTest, ForEach_Int32Double) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_int32_double { key: 1 value: 2 } map_int32_double { key: 2 value: 1 } )pb"), DynamicGetField("map_int32_double"), arena()); std::vector> entries; EXPECT_THAT( value.ForEach( [&](const Value& key, const Value& value) -> absl::StatusOr { entries.push_back(std::pair{std::move(key), std::move(value)}); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(entries, UnorderedElementsAre(Pair(IntValueIs(1), DoubleValueIs(2)), Pair(IntValueIs(2), DoubleValueIs(1)))); } TEST_F(ParsedMapFieldValueTest, ForEach_Int64Float) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_int64_float { key: 1 value: 2 } map_int64_float { key: 2 value: 1 } )pb"), DynamicGetField("map_int64_float"), arena()); std::vector> entries; EXPECT_THAT( value.ForEach( [&](const Value& key, const Value& value) -> absl::StatusOr { entries.push_back(std::pair{std::move(key), std::move(value)}); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(entries, UnorderedElementsAre(Pair(IntValueIs(1), DoubleValueIs(2)), Pair(IntValueIs(2), 
DoubleValueIs(1)))); } TEST_F(ParsedMapFieldValueTest, ForEach_UInt32UInt64) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_uint32_uint64 { key: 1 value: 2 } map_uint32_uint64 { key: 2 value: 1 } )pb"), DynamicGetField("map_uint32_uint64"), arena()); std::vector> entries; EXPECT_THAT( value.ForEach( [&](const Value& key, const Value& value) -> absl::StatusOr { entries.push_back(std::pair{std::move(key), std::move(value)}); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(entries, UnorderedElementsAre(Pair(UintValueIs(1), UintValueIs(2)), Pair(UintValueIs(2), UintValueIs(1)))); } TEST_F(ParsedMapFieldValueTest, ForEach_UInt64Int32) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_uint64_int32 { key: 1 value: 2 } map_uint64_int32 { key: 2 value: 1 } )pb"), DynamicGetField("map_uint64_int32"), arena()); std::vector> entries; EXPECT_THAT( value.ForEach( [&](const Value& key, const Value& value) -> absl::StatusOr { entries.push_back(std::pair{std::move(key), std::move(value)}); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(entries, UnorderedElementsAre(Pair(UintValueIs(1), IntValueIs(2)), Pair(UintValueIs(2), IntValueIs(1)))); } TEST_F(ParsedMapFieldValueTest, ForEach_BoolUInt32) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_bool_uint32 { key: true value: 2 } map_bool_uint32 { key: false value: 1 } )pb"), DynamicGetField("map_bool_uint32"), arena()); std::vector> entries; EXPECT_THAT( value.ForEach( [&](const Value& key, const Value& value) -> absl::StatusOr { entries.push_back(std::pair{std::move(key), std::move(value)}); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(entries, UnorderedElementsAre(Pair(BoolValueIs(true), UintValueIs(2)), Pair(BoolValueIs(false), UintValueIs(1)))); } TEST_F(ParsedMapFieldValueTest, ForEach_StringString) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_string_string { key: 
"foo" value: "bar" } map_string_string { key: "bar" value: "foo" } )pb"), DynamicGetField("map_string_string"), arena()); std::vector> entries; EXPECT_THAT( value.ForEach( [&](const Value& key, const Value& value) -> absl::StatusOr { entries.push_back(std::pair{std::move(key), std::move(value)}); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(entries, UnorderedElementsAre( Pair(StringValueIs("foo"), StringValueIs("bar")), Pair(StringValueIs("bar"), StringValueIs("foo")))); } TEST_F(ParsedMapFieldValueTest, ForEach_StringDuration) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_string_duration { key: "foo" value: { seconds: 1 nanos: 1 } } map_string_duration { key: "bar" value: {} } )pb"), DynamicGetField("map_string_duration"), arena()); std::vector> entries; EXPECT_THAT( value.ForEach( [&](const Value& key, const Value& value) -> absl::StatusOr { entries.push_back(std::pair{std::move(key), std::move(value)}); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT( entries, UnorderedElementsAre( Pair(StringValueIs("foo"), DurationValueIs(absl::Seconds(1) + absl::Nanoseconds(1))), Pair(StringValueIs("bar"), DurationValueIs(absl::ZeroDuration())))); } TEST_F(ParsedMapFieldValueTest, ForEach_StringBytes) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_string_bytes { key: "foo" value: "bar" } map_string_bytes { key: "bar" value: "foo" } )pb"), DynamicGetField("map_string_bytes"), arena()); std::vector> entries; EXPECT_THAT( value.ForEach( [&](const Value& key, const Value& value) -> absl::StatusOr { entries.push_back(std::pair{std::move(key), std::move(value)}); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(entries, UnorderedElementsAre( Pair(StringValueIs("foo"), BytesValueIs("bar")), Pair(StringValueIs("bar"), BytesValueIs("foo")))); } TEST_F(ParsedMapFieldValueTest, ForEach_StringEnum) { ParsedMapFieldValue value( 
DynamicParseTextProto(R"pb( map_string_enum { key: "foo" value: BAR } map_string_enum { key: "bar" value: FOO } )pb"), DynamicGetField("map_string_enum"), arena()); std::vector> entries; EXPECT_THAT( value.ForEach( [&](const Value& key, const Value& value) -> absl::StatusOr { entries.push_back(std::pair{std::move(key), std::move(value)}); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(entries, UnorderedElementsAre(Pair(StringValueIs("foo"), IntValueIs(1)), Pair(StringValueIs("bar"), IntValueIs(0)))); } TEST_F(ParsedMapFieldValueTest, ForEach_StringNull) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_string_null_value { key: "foo" value: NULL_VALUE } map_string_null_value { key: "bar" value: NULL_VALUE } )pb"), DynamicGetField("map_string_null_value"), arena()); std::vector> entries; EXPECT_THAT( value.ForEach( [&](const Value& key, const Value& value) -> absl::StatusOr { entries.push_back(std::pair{std::move(key), std::move(value)}); return true; }, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(entries, UnorderedElementsAre(Pair(StringValueIs("foo"), IsNullValue()), Pair(StringValueIs("bar"), IsNullValue()))); } TEST_F(ParsedMapFieldValueTest, NewIterator) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_string_bool { key: "foo" value: false } map_string_bool { key: "bar" value: true } )pb"), DynamicGetField("map_string_bool"), arena()); ASSERT_OK_AND_ASSIGN(auto iterator, value.NewIterator()); ASSERT_TRUE(iterator->HasNext()); EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), IsOkAndHolds(AnyOf(StringValueIs("foo"), StringValueIs("bar")))); ASSERT_TRUE(iterator->HasNext()); EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), IsOkAndHolds(AnyOf(StringValueIs("foo"), StringValueIs("bar")))); ASSERT_FALSE(iterator->HasNext()); EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()), 
StatusIs(absl::StatusCode::kFailedPrecondition)); } TEST_F(ParsedMapFieldValueTest, NewIterator1) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_string_bool { key: "foo" value: false } map_string_bool { key: "bar" value: true } )pb"), DynamicGetField("map_string_bool"), arena()); ASSERT_OK_AND_ASSIGN(auto iterator, value.NewIterator()); EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), IsOkAndHolds( Optional(AnyOf(StringValueIs("foo"), StringValueIs("bar"))))); EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), IsOkAndHolds( Optional(AnyOf(StringValueIs("foo"), StringValueIs("bar"))))); EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Eq(absl::nullopt))); } TEST_F(ParsedMapFieldValueTest, NewIterator2) { ParsedMapFieldValue value( DynamicParseTextProto(R"pb( map_string_bool { key: "foo" value: false } map_string_bool { key: "bar" value: true } )pb"), DynamicGetField("map_string_bool"), arena()); ASSERT_OK_AND_ASSIGN(auto iterator, value.NewIterator()); EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional( AnyOf(Pair(StringValueIs("foo"), BoolValueIs(false)), Pair(StringValueIs("bar"), BoolValueIs(true)))))); EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional( AnyOf(Pair(StringValueIs("foo"), BoolValueIs(false)), Pair(StringValueIs("bar"), BoolValueIs(true)))))); EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Eq(absl::nullopt))); } } // namespace } // namespace cel ================================================ FILE: common/values/parsed_message_value.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "common/values/parsed_message_value.h" #include #include #include #include #include #include #include "google/protobuf/empty.pb.h" #include "absl/base/nullability.h" #include "absl/base/optimization.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "base/attribute.h" #include "common/memory.h" #include "common/value.h" #include "extensions/protobuf/internal/qualify.h" #include "internal/empty_descriptors.h" #include "internal/json.h" #include "internal/message_equality.h" #include "internal/status_macros.h" #include "internal/well_known_types.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" namespace cel { namespace { using ::cel::well_known_types::ValueReflection; template std::enable_if_t, const google::protobuf::Message* absl_nonnull> EmptyParsedMessageValue() { return &T::default_instance(); } template std::enable_if_t< std::conjunction_v, std::negation>>, const google::protobuf::Message* absl_nonnull> EmptyParsedMessageValue() { return internal::GetEmptyDefaultInstance(); } } // namespace ParsedMessageValue::ParsedMessageValue() : value_(EmptyParsedMessageValue()), arena_(nullptr) {} bool ParsedMessageValue::IsZeroValue() const { const auto* reflection = 
GetReflection(); if (!reflection->GetUnknownFields(*value_).empty()) { return false; } std::vector fields; reflection->ListFields(*value_, &fields); return fields.empty(); } std::string ParsedMessageValue::DebugString() const { return absl::StrCat(*value_); } absl::Status ParsedMessageValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); if (!value_->SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( absl::StrCat("failed to serialize message: ", value_->GetTypeName())); } return absl::OkStatus(); } absl::Status ParsedMessageValue::ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); ValueReflection value_reflection; CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); google::protobuf::Message* json_object = value_reflection.MutableStructValue(json); return internal::MessageToJson(*value_, descriptor_pool, message_factory, json_object); } absl::Status ParsedMessageValue::ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), 
google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); return internal::MessageToJson(*value_, descriptor_pool, message_factory, json); } absl::Status ParsedMessageValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (auto other_message = other.AsParsedMessage(); other_message) { CEL_ASSIGN_OR_RETURN( auto equal, internal::MessageEquals(*value_, **other_message, descriptor_pool, message_factory)); *result = BoolValue(equal); return absl::OkStatus(); } if (auto other_struct = other.AsStruct(); other_struct) { return common_internal::StructValueEqual(StructValue(*this), *other_struct, descriptor_pool, message_factory, arena, result); } *result = BoolValue(false); return absl::OkStatus(); } ParsedMessageValue ParsedMessageValue::Clone( google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(arena != nullptr); if (arena_ == arena) { return *this; } auto* cloned = value_->New(arena); cloned->CopyFrom(*value_); return ParsedMessageValue(cloned, arena); } absl::Status ParsedMessageValue::GetFieldByName( absl::string_view name, ProtoWrapperTypeOptions unboxing_options, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); const auto* descriptor = GetDescriptor(); const auto* field = descriptor->FindFieldByName(name); if (field == nullptr) { field = descriptor->file()->pool()->FindExtensionByPrintableName(descriptor, name); if 
(field == nullptr) { *result = NoSuchFieldError(name); return absl::OkStatus(); } } return GetField(field, unboxing_options, descriptor_pool, message_factory, arena, result); } absl::Status ParsedMessageValue::GetFieldByNumber( int64_t number, ProtoWrapperTypeOptions unboxing_options, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); const auto* descriptor = GetDescriptor(); if (number < std::numeric_limits::min() || number > std::numeric_limits::max()) { *result = NoSuchFieldError(absl::StrCat(number)); return absl::OkStatus(); } const auto* field = descriptor->FindFieldByNumber(static_cast(number)); if (field == nullptr) { *result = NoSuchFieldError(absl::StrCat(number)); return absl::OkStatus(); } return GetField(field, unboxing_options, descriptor_pool, message_factory, arena, result); } absl::StatusOr ParsedMessageValue::HasFieldByName( absl::string_view name) const { const auto* descriptor = GetDescriptor(); const auto* field = descriptor->FindFieldByName(name); if (field == nullptr) { field = descriptor->file()->pool()->FindExtensionByPrintableName(descriptor, name); if (field == nullptr) { return NoSuchFieldError(name).NativeValue(); } } return HasField(field); } absl::StatusOr ParsedMessageValue::HasFieldByNumber( int64_t number) const { const auto* descriptor = GetDescriptor(); if (number < std::numeric_limits::min() || number > std::numeric_limits::max()) { return NoSuchFieldError(absl::StrCat(number)).NativeValue(); } const auto* field = descriptor->FindFieldByNumber(static_cast(number)); if (field == nullptr) { return NoSuchFieldError(absl::StrCat(number)).NativeValue(); } return HasField(field); } absl::Status ParsedMessageValue::ForEachField( 
ForEachFieldCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); std::vector fields; const auto* reflection = GetReflection(); reflection->ListFields(*value_, &fields); for (const auto* field : fields) { auto value = Value::WrapField(value_, field, descriptor_pool, message_factory, arena); CEL_ASSIGN_OR_RETURN(auto ok, callback(field->name(), value)); if (!ok) { break; } } return absl::OkStatus(); } namespace { class ParsedMessageValueQualifyState final : public extensions::protobuf_internal::ProtoQualifyState { public: ParsedMessageValueQualifyState( const google::protobuf::Message* absl_nonnull message, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) : ProtoQualifyState(message, message->GetDescriptor(), message->GetReflection()), descriptor_pool_(descriptor_pool), message_factory_(message_factory), arena_(arena) {} absl::optional& result() { return result_; } private: void SetResultFromError(absl::Status status, cel::MemoryManagerRef) override { result_ = ErrorValue(std::move(status)); } void SetResultFromBool(bool value) override { result_ = BoolValue(value); } absl::Status SetResultFromField(const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field, ProtoWrapperTypeOptions unboxing_option, cel::MemoryManagerRef) override { result_ = Value::WrapField(unboxing_option, message, field, descriptor_pool_, message_factory_, arena_); return absl::OkStatus(); } absl::Status SetResultFromRepeatedField(const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field, int index, cel::MemoryManagerRef) override { result_ = 
Value::WrapRepeatedField(index, message, field, descriptor_pool_, message_factory_, arena_); return absl::OkStatus(); } absl::Status SetResultFromMapField(const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field, const google::protobuf::MapValueConstRef& value, cel::MemoryManagerRef) override { result_ = Value::WrapMapFieldValue(value, message, field, descriptor_pool_, message_factory_, arena_); return absl::OkStatus(); } const google::protobuf::DescriptorPool* absl_nonnull const descriptor_pool_; google::protobuf::MessageFactory* absl_nonnull const message_factory_; google::protobuf::Arena* absl_nonnull const arena_; absl::optional result_; }; } // namespace absl::Status ParsedMessageValue::Qualify( absl::Span qualifiers, bool presence_test, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result, int* absl_nonnull count) const { ABSL_DCHECK(!qualifiers.empty()); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); ABSL_DCHECK(count != nullptr); if (ABSL_PREDICT_FALSE(qualifiers.empty())) { return absl::InvalidArgumentError("invalid select qualifier path."); } ParsedMessageValueQualifyState qualify_state(value_, descriptor_pool, message_factory, arena); for (int i = 0; i < qualifiers.size() - 1; i++) { const auto& qualifier = qualifiers[i]; CEL_RETURN_IF_ERROR(qualify_state.ApplySelectQualifier( qualifier, MemoryManagerRef::Pooling(arena))); if (qualify_state.result().has_value()) { *result = std::move(qualify_state.result()).value(); *count = result->Is() ? 
-1 : i + 1; return absl::OkStatus(); } } const auto& last_qualifier = qualifiers.back(); if (presence_test) { CEL_RETURN_IF_ERROR(qualify_state.ApplyLastQualifierHas( last_qualifier, MemoryManagerRef::Pooling(arena))); } else { CEL_RETURN_IF_ERROR(qualify_state.ApplyLastQualifierGet( last_qualifier, MemoryManagerRef::Pooling(arena))); } *result = std::move(qualify_state.result()).value(); *count = -1; return absl::OkStatus(); } absl::Status ParsedMessageValue::GetField( const google::protobuf::FieldDescriptor* absl_nonnull field, ProtoWrapperTypeOptions unboxing_options, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(field != nullptr); ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (arena_ == nullptr) { *result = Value::WrapFieldUnsafe(unboxing_options, value_, field, descriptor_pool, message_factory, arena); } else { *result = Value::WrapField(unboxing_options, value_, field, descriptor_pool, message_factory, arena); } return absl::OkStatus(); } bool ParsedMessageValue::HasField( const google::protobuf::FieldDescriptor* absl_nonnull field) const { ABSL_DCHECK(field != nullptr); const auto* reflection = GetReflection(); if (field->is_map() || field->is_repeated()) { return reflection->FieldSize(*value_, field) > 0; } return reflection->HasField(*value_, field); } } // namespace cel ================================================ FILE: common/values/parsed_message_value.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private, include "common/value.h"
// IWYU pragma: friend "common/value.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MESSAGE_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MESSAGE_VALUE_H_

// NOTE(review): several tokens are missing in this extract: the targets of
// the following six `#include`s, and the template argument lists below (the
// `StructValueMixin` base, `absl::Span`, `absl::StatusOr`,
// `std::pointer_traits`).
#include
#include
#include
#include
#include
#include

#include "google/protobuf/any.pb.h"
#include "google/protobuf/struct.pb.h"
#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "base/attribute.h"
#include "common/memory.h"
#include "common/type.h"
#include "common/value_kind.h"
#include "common/values/custom_struct_value.h"
#include "common/values/values.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

class MessageValue;
class StructValue;
class Value;

// ParsedMessageValue is a struct value (`ValueKind::kStruct`) backed by a
// protocol buffer message that is accessed through reflection. Debug builds
// reject well-known message types and messages without reflection.
//
// The value does not own the message. It holds a pointer to it plus the arena
// the message is attributed to.
class ParsedMessageValue final
    : private common_internal::StructValueMixin {
 public:
  static constexpr ValueKind kKind = ValueKind::kStruct;

  using element_type = const google::protobuf::Message;

  // Wraps `value`, which must be non-null. If `value` is arena-allocated, it
  // must live on `arena`; this is checked in debug builds by `CheckArena`.
  ParsedMessageValue(
      const google::protobuf::Message* absl_nonnull value ABSL_ATTRIBUTE_LIFETIME_BOUND,
      google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND)
      : value_(value), arena_(arena) {
    ABSL_DCHECK(value != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK(!value_ || !IsWellKnownMessageType(value_->GetDescriptor()))
        << value_->GetTypeName() << " is a well known type";
    ABSL_DCHECK(!value_ || value_->GetReflection() != nullptr)
        << value_->GetTypeName() << " is missing reflection";
    ABSL_DCHECK_OK(CheckArena(value_, arena_));
  }

  // Places the `ParsedMessageValue` into a special state where it is logically
  // equivalent to the default instance of `google.protobuf.Empty`, however
  // dereferencing via `operator*` or `operator->` is not allowed.
  ParsedMessageValue();
  ParsedMessageValue(const ParsedMessageValue&) = default;
  ParsedMessageValue(ParsedMessageValue&&) = default;
  ParsedMessageValue& operator=(const ParsedMessageValue&) = default;
  ParsedMessageValue& operator=(ParsedMessageValue&&) = default;

  static constexpr ValueKind kind() { return kKind; }

  // Fully-qualified protobuf type name of the wrapped message.
  absl::string_view GetTypeName() const {
    return GetDescriptor()->full_name();
  }

  MessageType GetRuntimeType() const { return MessageType(GetDescriptor()); }

  const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const {
    return (*this)->GetDescriptor();
  }

  const google::protobuf::Reflection* absl_nonnull GetReflection() const {
    return (*this)->GetReflection();
  }

  const google::protobuf::Message& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return *value_;
  }

  const google::protobuf::Message* absl_nonnull operator->() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return value_;
  }

  bool IsZeroValue() const;

  std::string DebugString() const;

  // See Value::SerializeTo().
  absl::Status SerializeTo(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const;

  // See Value::ConvertToJson().
  absl::Status ConvertToJson(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  // See Value::ConvertToJsonObject().
  absl::Status ConvertToJsonObject(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  absl::Status Equal(const Value& other,
                     const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                     google::protobuf::MessageFactory* absl_nonnull message_factory,
                     google::protobuf::Arena* absl_nonnull arena,
                     Value* absl_nonnull result) const;
  using StructValueMixin::Equal;

  ParsedMessageValue Clone(google::protobuf::Arena* absl_nonnull arena) const;

  absl::Status GetFieldByName(
      absl::string_view name, ProtoWrapperTypeOptions unboxing_options,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const;
  using StructValueMixin::GetFieldByName;

  absl::Status GetFieldByNumber(
      int64_t number, ProtoWrapperTypeOptions unboxing_options,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const;
  using StructValueMixin::GetFieldByNumber;

  absl::StatusOr HasFieldByName(absl::string_view name) const;

  absl::StatusOr HasFieldByNumber(int64_t number) const;

  using ForEachFieldCallback = CustomStructValueInterface::ForEachFieldCallback;

  absl::Status ForEachField(
      ForEachFieldCallback callback,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  absl::Status Qualify(
      absl::Span qualifiers, bool presence_test,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result,
      int* absl_nonnull count) const;
  using StructValueMixin::Qualify;

  friend void swap(ParsedMessageValue& lhs, ParsedMessageValue& rhs) noexcept {
    using std::swap;
    swap(lhs.value_, rhs.value_);
    swap(lhs.arena_, rhs.arena_);
  }

 private:
  friend std::pointer_traits;
  friend class StructValue;
  friend class common_internal::ValueMixin;
  friend class common_internal::StructValueMixin;
  friend ParsedMessageValue UnsafeParsedMessageValue(
      const google::protobuf::Message* absl_nonnull value);

  // Used by `UnsafeParsedMessageValue`. The arena is taken from the message
  // itself, so it is null for a heap-allocated message.
  // NOTE(review): `value->GetArena()` in the initializer list runs before
  // `ABSL_DCHECK(value != nullptr)`, so that check cannot catch a null
  // `value`.
  explicit ParsedMessageValue(
      const google::protobuf::Message* absl_nonnull value ABSL_ATTRIBUTE_LIFETIME_BOUND)
      : value_(value), arena_(value->GetArena()) {
    ABSL_DCHECK(value != nullptr);
    ABSL_DCHECK(!value_ || !IsWellKnownMessageType(value_->GetDescriptor()))
        << value_->GetTypeName() << " is a well known type";
    ABSL_DCHECK(!value_ || value_->GetReflection() != nullptr)
        << value_->GetTypeName() << " is missing reflection";
  }

  // Returns InvalidArgumentError if `message` is allocated on an arena other
  // than `arena`. Null and heap-allocated messages are accepted.
  static absl::Status CheckArena(const google::protobuf::Message* absl_nullable message,
                                 google::protobuf::Arena* absl_nonnull arena) {
    if (message != nullptr && message->GetArena() != nullptr &&
        message->GetArena() != arena) {
      return absl::InvalidArgumentError(
          "message arena must be the same as arena");
    }
    return absl::OkStatus();
  }

  absl::Status GetField(
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      ProtoWrapperTypeOptions unboxing_options,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const;

  bool HasField(const google::protobuf::FieldDescriptor* absl_nonnull field) const;

  // The wrapped message (not owned).
  const google::protobuf::Message* absl_nonnull value_;
  // Arena that is attributed as owning the value. May be null to indicate that
  // the value is managed externally.
  google::protobuf::Arena* absl_nullable arena_;
};

inline std::ostream& operator<<(std::ostream& out,
                                const ParsedMessageValue& value) {
  return out << value.DebugString();
}

// Creates a `ParsedMessageValue` without specifying a managing arena.
// The message must outlive the `ParsedMessageValue` or any value that might
// be derived from it. Prefer to use `cel::Value::WrapMessageUnsafe()`.
inline ParsedMessageValue UnsafeParsedMessageValue(
    const google::protobuf::Message* absl_nonnull value) {
  return ParsedMessageValue(value);
}

}  // namespace cel

namespace std {

// Exposes the wrapped message pointer through `pointer_traits::to_address`.
template <>
struct pointer_traits {
  using pointer = cel::ParsedMessageValue;
  using element_type = typename cel::ParsedMessageValue::element_type;
  using difference_type = ptrdiff_t;

  static element_type* to_address(const pointer& p) noexcept {
    return cel::to_address(p.value_);
  }
};

}  // namespace std

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_MESSAGE_VALUE_H_

================================================
FILE: common/values/parsed_message_value_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// NOTE(review): some tokens are missing in this extract: the target of the
// first `#include`, and the template arguments of `MakeParsedMessage`,
// `DynamicParseTextProto` and `EqualsTextProto`.
#include

#include "google/protobuf/struct.pb.h"
#include "absl/status/status_matchers.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "common/memory.h"
#include "common/type.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "common/value_testing.h"
#include "internal/testing.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::IsOkAndHolds;
using ::cel::test::BoolValueIs;
using ::testing::_;
using ::testing::IsEmpty;

using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes;

using ParsedMessageValueTest = common_internal::ValueTest<>;

// Each test wraps a message built by `MakeParsedMessage()`. The GetTypeName
// and IsZeroValue tests show it is an unset
// `cel.expr.conformance.proto3.TestAllTypes`.

TEST_F(ParsedMessageValueTest, Kind) {
  ParsedMessageValue value = MakeParsedMessage();
  EXPECT_EQ(value.kind(), ParsedMessageValue::kKind);
  EXPECT_EQ(value.kind(), ValueKind::kStruct);
}

TEST_F(ParsedMessageValueTest, GetTypeName) {
  ParsedMessageValue value = MakeParsedMessage();
  EXPECT_EQ(value.GetTypeName(), "cel.expr.conformance.proto3.TestAllTypes");
}

TEST_F(ParsedMessageValueTest, GetRuntimeType) {
  ParsedMessageValue value = MakeParsedMessage();
  EXPECT_EQ(value.GetRuntimeType(), MessageType(value.GetDescriptor()));
}

TEST_F(ParsedMessageValueTest, DebugString) {
  ParsedMessageValue value = MakeParsedMessage();
  // Only checks that DebugString() can be called; the text is not pinned.
  EXPECT_THAT(value.DebugString(), _);
}

TEST_F(ParsedMessageValueTest, IsZeroValue) {
  MessageValue value = MakeParsedMessage();
  EXPECT_TRUE(value.IsZeroValue());
}

TEST_F(ParsedMessageValueTest, SerializeTo) {
  MessageValue value = MakeParsedMessage();
  google::protobuf::io::CordOutputStream output;
  EXPECT_THAT(value.SerializeTo(descriptor_pool(), message_factory(), &output),
              IsOk());
  // A message with no fields set serializes to zero bytes.
  EXPECT_THAT(std::move(output).Consume(), IsEmpty());
}

TEST_F(ParsedMessageValueTest, ConvertToJson) {
  MessageValue value = MakeParsedMessage();
  auto json = DynamicParseTextProto(R"pb()pb");
  EXPECT_THAT(value.ConvertToJson(descriptor_pool(), message_factory(),
                                  cel::to_address(json)),
              IsOk());
  // An empty message converts to an empty JSON object.
  EXPECT_THAT(*json, EqualsTextProto(
                         R"pb(struct_value: {})pb"));
}

TEST_F(ParsedMessageValueTest, Equal) {
  MessageValue value = MakeParsedMessage();
  // A message never equals a non-struct value...
  EXPECT_THAT(
      value.Equal(BoolValue(), descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(BoolValueIs(false)));
  // ...but equals another, identically built message.
  EXPECT_THAT(value.Equal(MakeParsedMessage(), descriptor_pool(),
                          message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
}

TEST_F(ParsedMessageValueTest, GetFieldByName) {
  MessageValue value = MakeParsedMessage();
  EXPECT_THAT(value.GetFieldByName("single_bool", descriptor_pool(),
                                   message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
}

TEST_F(ParsedMessageValueTest, GetFieldByNumber) {
  MessageValue value = MakeParsedMessage();
  // Field number 13: presumably `single_bool` (cf. GetFieldByName above).
  EXPECT_THAT(
      value.GetFieldByNumber(13, descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(BoolValueIs(false)));
}

}  // namespace
}  // namespace cel

================================================
FILE: common/values/parsed_repeated_field_value.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/values/parsed_repeated_field_value.h"

// NOTE(review): some tokens are missing in this extract and are reproduced
// as-is: the targets of the following four `#include`s, and several template
// argument lists below (`static_cast`, `std::numeric_limits`,
// `GetRepeatedFieldRef`, `absl::StatusOr`, `std::make_unique`).
#include
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "common/value.h"
#include "internal/json.h"
#include "internal/message_equality.h"
#include "internal/status_macros.h"
#include "internal/well_known_types.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

using ::cel::well_known_types::ValueReflection;

// Returns "INVALID" for a default-constructed (invalid) value and "VALID"
// otherwise. The elements themselves are not rendered.
std::string ParsedRepeatedFieldValue::DebugString() const {
  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    return "INVALID";
  }
  return "VALID";
}

// Serializes the list as a `google.protobuf.ListValue`. An invalid value
// writes nothing, which is the encoding of an empty list.
absl::Status ParsedRepeatedFieldValue::SerializeTo(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(output != nullptr);
  ABSL_DCHECK(*this);

  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    return absl::OkStatus();
  }

  // The repeated field is converted to a `google.protobuf.Value` first. Only
  // its `list_value` (a `google.protobuf.ListValue`) is then serialized.
  google::protobuf::Value message;
  CEL_RETURN_IF_ERROR(internal::MessageFieldToJson(
      *message_, field_, descriptor_pool, message_factory, &message));
  if (!message.list_value().SerializePartialToZeroCopyStream(output)) {
    // The payload is a ListValue, not a Struct, so name it correctly.
    return absl::UnknownError("failed to serialize google.protobuf.ListValue");
  }
  return absl::OkStatus();
}

// Writes the list into `json`, which must be a `google.protobuf.Value`. An
// invalid value produces an empty `list_value`.
absl::Status ParsedRepeatedFieldValue::ConvertToJson(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(json != nullptr);
  ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
                 google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE);
  ABSL_DCHECK(*this);

  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    ValueReflection value_reflection;
    CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor()));
    value_reflection.MutableListValue(json)->Clear();
    return absl::OkStatus();
  }
  return internal::MessageFieldToJson(*message_, field_, descriptor_pool,
                                      message_factory, json);
}

// Writes the list into `json`, which must be a `google.protobuf.ListValue`.
// `json` is cleared first, so an invalid value leaves it empty.
absl::Status ParsedRepeatedFieldValue::ConvertToJsonArray(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(json != nullptr);
  ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
                 google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE);
  ABSL_DCHECK(*this);

  json->Clear();

  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    return absl::OkStatus();
  }
  return internal::MessageFieldToJson(*message_, field_, descriptor_pool,
                                      message_factory, json);
}

// Compares this list with `other` and stores a BoolValue in `*result`.
//
// Two fast paths compare through reflection: one when `other` is another
// repeated field, one when it is a parsed JSON list. Both dereference
// `message_`/`field_`, so they are taken only when the values involved are
// valid. Otherwise, and for any other list implementation, the generic
// element-wise `ListValueEqual` is used. Under that path an invalid value
// behaves as an empty list (see `Size()`). A non-list `other` is never equal.
absl::Status ParsedRepeatedFieldValue::Equal(
    const Value& other,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  ABSL_DCHECK(*this);

  if (ABSL_PREDICT_TRUE(field_ != nullptr)) {
    if (auto other_value = other.AsParsedRepeatedField();
        other_value && other_value->field_ != nullptr) {
      CEL_ASSIGN_OR_RETURN(
          auto equal,
          internal::MessageFieldEquals(
              *message_, field_, *other_value->message_, other_value->field_,
              descriptor_pool, message_factory));
      *result = BoolValue(equal);
      return absl::OkStatus();
    }
    if (auto other_value = other.AsParsedJsonList(); other_value) {
      if (other_value->value_ == nullptr) {
        *result = BoolValue(IsEmpty());
        return absl::OkStatus();
      }
      CEL_ASSIGN_OR_RETURN(
          auto equal,
          internal::MessageFieldEquals(*message_, field_, *other_value->value_,
                                       descriptor_pool, message_factory));
      *result = BoolValue(equal);
      return absl::OkStatus();
    }
  }
  if (auto other_value = other.AsList(); other_value) {
    return common_internal::ListValueEqual(ListValue(*this), *other_value,
                                           descriptor_pool, message_factory,
                                           arena, result);
  }
  *result = BoolValue(false);
  return absl::OkStatus();
}

// The zero value of a list is the empty list.
bool ParsedRepeatedFieldValue::IsZeroValue() const { return IsEmpty(); }

// Returns a copy whose storage is attributed to `arena`.
//
// If this value is already attributed to `arena`, `*this` is returned without
// copying. Otherwise a new message of the same type is created on `arena`, and
// only the repeated field `field_` is copied into it.
ParsedRepeatedFieldValue ParsedRepeatedFieldValue::Clone(
    google::protobuf::Arena* absl_nonnull arena) const {
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(*this);

  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    return ParsedRepeatedFieldValue();
  }
  if (arena_ == arena) {
    return *this;
  }
  auto field = message_->GetReflection()->GetRepeatedFieldRef(
      *message_, field_);
  auto* cloned_message = message_->New(arena);
  auto cloned_field =
      cloned_message->GetReflection()
          ->GetMutableRepeatedFieldRef(cloned_message, field_);
  cloned_field.CopyFrom(field);
  return ParsedRepeatedFieldValue(cloned_message, field_, arena);
}

bool ParsedRepeatedFieldValue::IsEmpty() const { return Size() == 0; }

// Number of elements in the repeated field; 0 for an invalid value.
size_t ParsedRepeatedFieldValue::Size() const {
  ABSL_DCHECK(*this);

  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    return 0;
  }
  return static_cast(GetReflection()->FieldSize(*message_, field_));
}

// See ListValueInterface::Get for documentation.
//
// An out-of-range `index`, or an invalid value, yields an
// IndexOutOfBoundsError *value* in `*result`; the status is still OK. The
// `numeric_limits` bound keeps the narrowing cast of `index` in range
// (presumably to `int`, the type `FieldSize` returns).
absl::Status ParsedRepeatedFieldValue::Get(
    size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  ABSL_DCHECK(*this);
  ABSL_DCHECK(message_ != nullptr);

  if (ABSL_PREDICT_FALSE(field_ == nullptr ||
                         index >= std::numeric_limits::max() ||
                         static_cast(index) >=
                             GetReflection()->FieldSize(*message_, field_))) {
    *result = IndexOutOfBoundsError(index);
    return absl::OkStatus();
  }
  // Without an owning arena the element must be wrapped with the Unsafe
  // variant.
  if (arena_ == nullptr) {
    *result = Value::WrapRepeatedFieldUnsafe(static_cast(index), message_,
                                             field_, descriptor_pool,
                                             message_factory, arena);
  } else {
    *result = Value::WrapRepeatedField(static_cast(index), message_, field_,
                                       descriptor_pool, message_factory, arena);
  }
  return absl::OkStatus();
}

// Invokes `callback(index, element)` for each element in order. Iteration
// stops early when the callback returns false, and callback errors are
// propagated. A single `scratch` value is reused for every element.
absl::Status ParsedRepeatedFieldValue::ForEach(
    ForEachWithIndexCallback callback,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) const {
  ABSL_DCHECK(*this);
  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    return absl::OkStatus();
  }
  const auto* reflection = message_->GetReflection();
  const int size = reflection->FieldSize(*message_, field_);
  if (size > 0) {
    CEL_ASSIGN_OR_RETURN(auto accessor,
                         common_internal::RepeatedFieldAccessorFor(field_));
    Value scratch;
    for (int i = 0; i < size; ++i) {
      (*accessor)(i, message_, field_, reflection, descriptor_pool,
                  message_factory, arena, &scratch);
      CEL_ASSIGN_OR_RETURN(auto ok, callback(static_cast(i), scratch));
      if (!ok) {
        break;
      }
    }
  }
  return absl::OkStatus();
}

namespace {

// Iterates the elements of a repeated field with a precomputed element
// accessor. The element count is captured once, at construction.
class ParsedRepeatedFieldValueIterator final : public ValueIterator {
 public:
  ParsedRepeatedFieldValueIterator(
      const google::protobuf::Message* absl_nonnull message,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      absl_nonnull common_internal::RepeatedFieldAccessor accessor)
      : message_(message),
        field_(field),
        reflection_(message_->GetReflection()),
        accessor_(accessor),
        size_(reflection_->FieldSize(*message_, field_)) {}

  bool HasNext() override { return index_ < size_; }

  // Stores the next element in `*result`. Returns FailedPrecondition once
  // the iterator is exhausted.
  absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                    google::protobuf::MessageFactory* absl_nonnull message_factory,
                    google::protobuf::Arena* absl_nonnull arena,
                    Value* absl_nonnull result) override {
    if (ABSL_PREDICT_FALSE(index_ >= size_)) {
      return absl::FailedPreconditionError(
          "ValueIterator::Next called after ValueIterator::HasNext returned "
          "false");
    }

    (*accessor_)(index_, message_, field_, reflection_, descriptor_pool,
                 message_factory, arena, result);
    ++index_;
    return absl::OkStatus();
  }

  // Stores the next element in `*key_or_value` and returns true, or returns
  // false once exhausted.
  absl::StatusOr Next1(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull key_or_value) override {
    ABSL_DCHECK(descriptor_pool != nullptr);
    ABSL_DCHECK(message_factory != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK(key_or_value != nullptr);

    if (index_ >= size_) {
      return false;
    }
    (*accessor_)(index_, message_, field_, reflection_, descriptor_pool,
                 message_factory, arena, key_or_value);
    ++index_;
    return true;
  }

  // Stores the element index (as an IntValue) in `*key`, and the element in
  // `*value` when `value` is non-null. Returns false once exhausted.
  absl::StatusOr Next2(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key,
      Value* absl_nullable value) override {
    ABSL_DCHECK(descriptor_pool != nullptr);
    ABSL_DCHECK(message_factory != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK(key != nullptr);

    if (index_ >= size_) {
      return false;
    }
    if (value != nullptr) {
      (*accessor_)(index_, message_, field_, reflection_, descriptor_pool,
                   message_factory, arena, value);
    }
    *key = IntValue(index_);
    ++index_;
    return true;
  }

 private:
  const google::protobuf::Message* absl_nonnull const message_;
  const google::protobuf::FieldDescriptor* absl_nonnull const field_;
  const google::protobuf::Reflection* absl_nonnull const reflection_;
  const absl_nonnull common_internal::RepeatedFieldAccessor accessor_;
  const int size_;
  int index_ = 0;
};

}  // namespace

// Returns an iterator over the elements; an invalid value yields an empty
// iterator.
absl::StatusOr> ParsedRepeatedFieldValue::NewIterator()
    const {
  ABSL_DCHECK(*this);
  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    return NewEmptyValueIterator();
  }
  CEL_ASSIGN_OR_RETURN(auto accessor,
                       common_internal::RepeatedFieldAccessorFor(field_));
  return std::make_unique(message_, field_,
                          accessor);
}

// Linear membership test. Each element is compared with `other` via
// `Value::Equal`, stopping at the first element whose comparison yields true.
// Stores FalseValue when no element matches or the value is invalid.
absl::Status ParsedRepeatedFieldValue::Contains(
    const Value& other,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  ABSL_DCHECK(*this);
  if (ABSL_PREDICT_FALSE(field_ == nullptr)) {
    *result = FalseValue();
    return absl::OkStatus();
  }
  const auto* reflection = message_->GetReflection();
  const int size = reflection->FieldSize(*message_, field_);
  if (size > 0) {
    CEL_ASSIGN_OR_RETURN(auto accessor,
                         common_internal::RepeatedFieldAccessorFor(field_));
    Value scratch;
    for (int i = 0; i < size; ++i) {
      (*accessor)(i, message_, field_, reflection, descriptor_pool,
                  message_factory, arena, &scratch);
      CEL_RETURN_IF_ERROR(scratch.Equal(other, descriptor_pool,
                                        message_factory, arena, result));
      if (result->IsTrue()) {
        return absl::OkStatus();
      }
    }
  }
  *result = FalseValue();
  return absl::OkStatus();
}

// Reflection of the wrapped message; requires a valid value.
const google::protobuf::Reflection* absl_nonnull
ParsedRepeatedFieldValue::GetReflection() const {
  return message_->GetReflection();
}

}  // namespace cel

================================================
FILE: common/values/parsed_repeated_field_value.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache
// License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private, include "common/value.h"
// IWYU pragma: friend "common/value.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_REPEATED_FIELD_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_REPEATED_FIELD_VALUE_H_

// NOTE(review): several tokens are missing in this extract: the targets of
// the following four `#include`s, and the template argument lists below (the
// `ListValueMixin` base and the `absl::StatusOr` of `NewIterator`).
#include
#include
#include
#include

#include "google/protobuf/any.pb.h"
#include "google/protobuf/struct.pb.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "common/type.h"
#include "common/value_kind.h"
#include "common/values/custom_list_value.h"
#include "common/values/values.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

class Value;
class ValueIterator;
class ParsedJsonListValue;

// ParsedRepeatedFieldValue is a ListValue over a repeated field of a parsed
// protocol buffer message.
//
// It does not own the message. It holds raw pointers to the message and the
// field descriptor, plus the arena the message is attributed to.
class ParsedRepeatedFieldValue final
    : private common_internal::ListValueMixin {
 public:
  static constexpr ValueKind kKind = ValueKind::kList;
  static constexpr absl::string_view kName = "list";

  // Wraps the repeated, non-map `field` of `message`. If `message` is
  // arena-allocated, it must live on `arena`; this is checked in debug builds
  // by `CheckArena`.
  ParsedRepeatedFieldValue(const google::protobuf::Message* absl_nonnull message,
                           const google::protobuf::FieldDescriptor* absl_nonnull field,
                           google::protobuf::Arena* absl_nonnull arena)
      : message_(message), field_(field), arena_(arena) {
    ABSL_DCHECK(message != nullptr);
    ABSL_DCHECK(field != nullptr);
    ABSL_DCHECK(arena != nullptr);
    ABSL_DCHECK(field_->is_repeated() && !field_->is_map())
        << field_->full_name() << " must be a repeated field";
    ABSL_DCHECK_OK(CheckArena(message_, arena_));
  }

  // Places the `ParsedRepeatedFieldValue` into an invalid state. Anything
  // except assigning to `ParsedRepeatedFieldValue` is undefined behavior.
  ParsedRepeatedFieldValue() = default;
  ParsedRepeatedFieldValue(const ParsedRepeatedFieldValue&) = default;
  ParsedRepeatedFieldValue(ParsedRepeatedFieldValue&&) = default;
  ParsedRepeatedFieldValue& operator=(const ParsedRepeatedFieldValue&) =
      default;
  ParsedRepeatedFieldValue& operator=(ParsedRepeatedFieldValue&&) = default;

  static constexpr ValueKind kind() { return kKind; }

  static constexpr absl::string_view GetTypeName() { return kName; }

  static ListType GetRuntimeType() { return ListType(); }

  std::string DebugString() const;

  // See Value::SerializeTo().
  absl::Status SerializeTo(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const;

  // See Value::ConvertToJson().
  absl::Status ConvertToJson(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  // See Value::ConvertToJsonArray().
  absl::Status ConvertToJsonArray(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  absl::Status Equal(const Value& other,
                     const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                     google::protobuf::MessageFactory* absl_nonnull message_factory,
                     google::protobuf::Arena* absl_nonnull arena,
                     Value* absl_nonnull result) const;
  using ListValueMixin::Equal;

  bool IsZeroValue() const;

  bool IsEmpty() const;

  // Returns a copy whose storage is attributed to `arena`.
  ParsedRepeatedFieldValue Clone(google::protobuf::Arena* absl_nonnull arena) const;

  // Number of elements in the repeated field.
  size_t Size() const;

  // See ListValueInterface::Get for documentation.
  absl::Status Get(size_t index,
                   const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                   google::protobuf::MessageFactory* absl_nonnull message_factory,
                   google::protobuf::Arena* absl_nonnull arena,
                   Value* absl_nonnull result) const;
  using ListValueMixin::Get;

  using ForEachCallback = typename CustomListValueInterface::ForEachCallback;

  using ForEachWithIndexCallback =
      typename CustomListValueInterface::ForEachWithIndexCallback;

  absl::Status ForEach(
      ForEachWithIndexCallback callback,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;
  using ListValueMixin::ForEach;

  absl::StatusOr NewIterator() const;

  absl::Status Contains(
      const Value& other,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const;
  using ListValueMixin::Contains;

  // The message that contains the repeated field. Requires a valid value.
  const google::protobuf::Message& message() const {
    ABSL_DCHECK(*this);
    return *message_;
  }

  // Descriptor of the wrapped repeated field. Requires a valid value.
  const google::protobuf::FieldDescriptor* absl_nonnull field() const {
    ABSL_DCHECK(*this);
    return field_;
  }

  // Returns `true` if `ParsedRepeatedFieldValue` is in a valid state.
  explicit operator bool() const { return field_ != nullptr; }

  friend void swap(ParsedRepeatedFieldValue& lhs,
                   ParsedRepeatedFieldValue& rhs) noexcept {
    using std::swap;
    swap(lhs.message_, rhs.message_);
    swap(lhs.field_, rhs.field_);
    swap(lhs.arena_, rhs.arena_);
  }

 private:
  friend class ParsedJsonListValue;
  friend class common_internal::ValueMixin;
  friend class common_internal::ListValueMixin;
  friend ParsedRepeatedFieldValue UnsafeParsedRepeatedFieldValue(
      const google::protobuf::Message* absl_nonnull message,
      const google::protobuf::FieldDescriptor* absl_nonnull field);

  // Used by `UnsafeParsedRepeatedFieldValue`. The arena is taken from the
  // message itself, so it is null for a heap-allocated message.
  // NOTE(review): `message->GetArena()` in the initializer list runs before
  // `ABSL_DCHECK(message != nullptr)`, so that check cannot catch a null
  // `message`.
  ParsedRepeatedFieldValue(const google::protobuf::Message* absl_nonnull message,
                           const google::protobuf::FieldDescriptor* absl_nonnull field)
      : message_(message), field_(field), arena_(message->GetArena()) {
    ABSL_DCHECK(message != nullptr);
    ABSL_DCHECK(field != nullptr);
    ABSL_DCHECK(field_->is_repeated() && !field_->is_map())
        << field_->full_name() << " must be a repeated field";
  }

  // Returns InvalidArgumentError if `message` is allocated on an arena other
  // than `arena`. Null and heap-allocated messages are accepted.
  static absl::Status CheckArena(const google::protobuf::Message* absl_nullable message,
                                 google::protobuf::Arena* absl_nonnull arena) {
    if (message != nullptr && message->GetArena() != nullptr &&
        message->GetArena() != arena) {
      return absl::InvalidArgumentError(
          "message arena must be the same as arena");
    }
    return absl::OkStatus();
  }

  const google::protobuf::Reflection* absl_nonnull GetReflection() const;

  // Both null in the default-constructed (invalid) state.
  const google::protobuf::Message* absl_nullable message_ = nullptr;
  const google::protobuf::FieldDescriptor* absl_nullable field_ = nullptr;
  // Arena the message is attributed to. May be null when the value was
  // created via `UnsafeParsedRepeatedFieldValue` from a heap-allocated
  // message.
  google::protobuf::Arena* absl_nullable arena_ = nullptr;
};

inline std::ostream& operator<<(std::ostream& out,
                                const ParsedRepeatedFieldValue& value) {
  return out << value.DebugString();
}

// Creates a `ParsedRepeatedFieldValue` without specifying a managing arena.
// The message must outlive the `ParsedRepeatedFieldValue` or any value that
// might be derived from it. Prefer to use
// `cel::Value::WrapRepeatedFieldUnsafe()`.
inline ParsedRepeatedFieldValue UnsafeParsedRepeatedFieldValue(
    const google::protobuf::Message* absl_nonnull message,
    const google::protobuf::FieldDescriptor* absl_nonnull field) {
  return ParsedRepeatedFieldValue(message, field);
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_PARSED_REPEATED_FIELD_VALUE_H_

================================================
FILE: common/values/parsed_repeated_field_value_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include
#include
#include

#include "google/protobuf/struct.pb.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "common/memory.h"
#include "common/type.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "common/value_testing.h"
#include "internal/testing.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"

// NOTE(review): in this copy the `<...>` targets of the three system
// #includes above, and the template arguments of `DynamicParseTextProto`,
// `DynamicGetField`, `EqualsTextProto`, `std::vector` and `absl::StatusOr`
// below, appear to have been stripped during extraction. Tokens are kept as
// found; restore them from upstream.

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::cel::test::BoolValueIs;
using ::cel::test::BytesValueIs;
using ::cel::test::DoubleValueIs;
using ::cel::test::DurationValueIs;
using ::cel::test::ErrorValueIs;
using ::cel::test::IntValueIs;
using ::cel::test::IsNullValue;
using ::cel::test::UintValueIs;
using ::testing::_;
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::IsEmpty;
using ::testing::Optional;
using ::testing::Pair;

using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes;

// Each test parses a text proto with the fixture's `DynamicParseTextProto`
// and wraps one repeated field looked up with `DynamicGetField` (presumably
// both parameterized on TestAllTypesProto3 -- the template arguments are not
// visible in this copy).
using ParsedRepeatedFieldValueTest = common_internal::ValueTest<>;

// A value constructed from a message and a field is engaged.
TEST_F(ParsedRepeatedFieldValueTest, Field) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb()pb"),
      DynamicGetField("repeated_int64"), arena());
  EXPECT_TRUE(value);
}

TEST_F(ParsedRepeatedFieldValueTest, Kind) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb()pb"),
      DynamicGetField("repeated_int64"), arena());
  EXPECT_EQ(value.kind(), ParsedRepeatedFieldValue::kKind);
  EXPECT_EQ(value.kind(), ValueKind::kList);
}

TEST_F(ParsedRepeatedFieldValueTest, GetTypeName) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb()pb"),
      DynamicGetField("repeated_int64"), arena());
  EXPECT_EQ(value.GetTypeName(), ParsedRepeatedFieldValue::kName);
  EXPECT_EQ(value.GetTypeName(), "list");
}

TEST_F(ParsedRepeatedFieldValueTest, GetRuntimeType) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb()pb"),
      DynamicGetField("repeated_int64"), arena());
  EXPECT_EQ(value.GetRuntimeType(), ListType());
}

// Only checks that DebugString() can be called; the output is not pinned.
TEST_F(ParsedRepeatedFieldValueTest, DebugString) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb()pb"),
      DynamicGetField("repeated_int64"), arena());
  EXPECT_THAT(value.DebugString(), _);
}

TEST_F(ParsedRepeatedFieldValueTest, IsZeroValue) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb()pb"),
      DynamicGetField("repeated_int64"), arena());
  EXPECT_TRUE(value.IsZeroValue());
}

// An empty repeated field serializes to zero bytes.
TEST_F(ParsedRepeatedFieldValueTest, SerializeTo) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb()pb"),
      DynamicGetField("repeated_int64"), arena());
  google::protobuf::io::CordOutputStream output;
  EXPECT_THAT(value.SerializeTo(descriptor_pool(), message_factory(), &output),
              IsOk());
  EXPECT_THAT(std::move(output).Consume(), IsEmpty());
}

// An empty repeated field converts to an empty JSON list.
TEST_F(ParsedRepeatedFieldValueTest, ConvertToJson) {
  auto json = DynamicParseTextProto(R"pb()pb");
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb()pb"),
      DynamicGetField("repeated_int64"), arena());
  EXPECT_THAT(value.ConvertToJson(descriptor_pool(), message_factory(),
                                  cel::to_address(json)),
              IsOk());
  EXPECT_THAT(
      *json, EqualsTextProto(R"pb(list_value: {})pb"));
}

// A list never equals a bool. An empty repeated field equals another empty
// repeated field and the default (empty) ListValue.
TEST_F(ParsedRepeatedFieldValueTest, Equal_RepeatedField) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb()pb"),
      DynamicGetField("repeated_int64"), arena());
  EXPECT_THAT(
      value.Equal(BoolValue(), descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(
      value.Equal(
          ParsedRepeatedFieldValue(
              DynamicParseTextProto(R"pb()pb"),
              DynamicGetField("repeated_int64"), arena()),
          descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(
      value.Equal(ListValue(), descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(BoolValueIs(true)));
}

// Equality holds in both directions between a repeated int64 field and a
// JSON list that carries the same numbers.
TEST_F(ParsedRepeatedFieldValueTest, Equal_JsonList) {
  ParsedRepeatedFieldValue repeated_value(
      DynamicParseTextProto(R"pb(repeated_int64: 1 repeated_int64: 0)pb"),
      DynamicGetField("repeated_int64"), arena());
  ParsedJsonListValue json_value(
      DynamicParseTextProto(
          R"pb( values { number_value: 1 } values { number_value: 0 } )pb"),
      arena());
  EXPECT_THAT(repeated_value.Equal(json_value, descriptor_pool(),
                                   message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(json_value.Equal(repeated_value, descriptor_pool(),
                               message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
}

TEST_F(ParsedRepeatedFieldValueTest, Empty) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb()pb"),
      DynamicGetField("repeated_int64"), arena());
  EXPECT_TRUE(value.IsEmpty());
}

TEST_F(ParsedRepeatedFieldValueTest, Size) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb()pb"),
      DynamicGetField("repeated_int64"), arena());
  EXPECT_EQ(value.Size(), 0);
}

// An out-of-range index yields an OK status that holds an InvalidArgument
// ErrorValue, not a failed status.
TEST_F(ParsedRepeatedFieldValueTest, Get) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb(repeated_bool: false repeated_bool: true)pb"),
      DynamicGetField("repeated_bool"), arena());
  EXPECT_THAT(value.Get(0, descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(value.Get(1, descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(
      value.Get(2, descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))));
}

// Exercises both ForEach callback shapes: element-only and (index, element).
TEST_F(ParsedRepeatedFieldValueTest, ForEach_Bool) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb(repeated_bool: false repeated_bool: true)pb"),
      DynamicGetField("repeated_bool"), arena());
  {
    std::vector values;
    EXPECT_THAT(value.ForEach(
                    [&](const Value& element) -> absl::StatusOr {
                      values.push_back(element);
                      return true;
                    },
                    descriptor_pool(), message_factory(), arena()),
                IsOk());
    EXPECT_THAT(values, ElementsAre(BoolValueIs(false), BoolValueIs(true)));
  }
  {
    std::vector values;
    EXPECT_THAT(value.ForEach(
                    [&](size_t, const Value& element) -> absl::StatusOr {
                      values.push_back(element);
                      return true;
                    },
                    descriptor_pool(), message_factory(), arena()),
                IsOk());
    EXPECT_THAT(values, ElementsAre(BoolValueIs(false), BoolValueIs(true)));
  }
}

TEST_F(ParsedRepeatedFieldValueTest, ForEach_Double) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb(repeated_double: 1 repeated_double: 0)pb"),
      DynamicGetField("repeated_double"), arena());
  std::vector values;
  EXPECT_THAT(value.ForEach(
                  [&](const Value& element) -> absl::StatusOr {
                    values.push_back(element);
                    return true;
                  },
                  descriptor_pool(), message_factory(), arena()),
              IsOk());
  EXPECT_THAT(values, ElementsAre(DoubleValueIs(1), DoubleValueIs(0)));
}

// float elements surface as DoubleValue.
TEST_F(ParsedRepeatedFieldValueTest, ForEach_Float) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb(repeated_float: 1 repeated_float: 0)pb"),
      DynamicGetField("repeated_float"), arena());
  std::vector values;
  EXPECT_THAT(value.ForEach(
                  [&](const Value& element) -> absl::StatusOr {
                    values.push_back(element);
                    return true;
                  },
                  descriptor_pool(), message_factory(), arena()),
              IsOk());
  EXPECT_THAT(values, ElementsAre(DoubleValueIs(1), DoubleValueIs(0)));
}

TEST_F(ParsedRepeatedFieldValueTest, ForEach_UInt64) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb(repeated_uint64: 1 repeated_uint64: 0)pb"),
      DynamicGetField("repeated_uint64"), arena());
  std::vector values;
  EXPECT_THAT(value.ForEach(
                  [&](const Value& element) -> absl::StatusOr {
                    values.push_back(element);
                    return true;
                  },
                  descriptor_pool(), message_factory(), arena()),
              IsOk());
  EXPECT_THAT(values, ElementsAre(UintValueIs(1), UintValueIs(0)));
}

// int32 elements surface as IntValue.
TEST_F(ParsedRepeatedFieldValueTest, ForEach_Int32) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb(repeated_int32: 1 repeated_int32: 0)pb"),
      DynamicGetField("repeated_int32"), arena());
  std::vector values;
  EXPECT_THAT(value.ForEach(
                  [&](const Value& element) -> absl::StatusOr {
                    values.push_back(element);
                    return true;
                  },
                  descriptor_pool(), message_factory(), arena()),
              IsOk());
  EXPECT_THAT(values, ElementsAre(IntValueIs(1), IntValueIs(0)));
}

// uint32 elements surface as UintValue.
TEST_F(ParsedRepeatedFieldValueTest, ForEach_UInt32) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb(repeated_uint32: 1 repeated_uint32: 0)pb"),
      DynamicGetField("repeated_uint32"), arena());
  std::vector values;
  EXPECT_THAT(value.ForEach(
                  [&](const Value& element) -> absl::StatusOr {
                    values.push_back(element);
                    return true;
                  },
                  descriptor_pool(), message_factory(), arena()),
              IsOk());
  EXPECT_THAT(values, ElementsAre(UintValueIs(1), UintValueIs(0)));
}

TEST_F(ParsedRepeatedFieldValueTest, ForEach_Duration) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(
          R"pb(repeated_duration: { seconds: 1 nanos: 1 } repeated_duration: {})pb"),
      DynamicGetField("repeated_duration"), arena());
  std::vector values;
  EXPECT_THAT(value.ForEach(
                  [&](const Value& element) -> absl::StatusOr {
                    values.push_back(element);
                    return true;
                  },
                  descriptor_pool(), message_factory(), arena()),
              IsOk());
  EXPECT_THAT(values, ElementsAre(DurationValueIs(absl::Seconds(1) +
                                                  absl::Nanoseconds(1)),
                                  DurationValueIs(absl::ZeroDuration())));
}

TEST_F(ParsedRepeatedFieldValueTest, ForEach_Bytes) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(
          R"pb(repeated_bytes: "bar" repeated_bytes: "foo")pb"),
      DynamicGetField("repeated_bytes"), arena());
  std::vector values;
  EXPECT_THAT(value.ForEach(
                  [&](const Value& element) -> absl::StatusOr {
                    values.push_back(element);
                    return true;
                  },
                  descriptor_pool(), message_factory(), arena()),
              IsOk());
  EXPECT_THAT(values, ElementsAre(BytesValueIs("bar"), BytesValueIs("foo")));
}

// Enum elements surface as IntValue of the enum number (the expectations
// below pin BAR -> 1 and FOO -> 0).
TEST_F(ParsedRepeatedFieldValueTest, ForEach_Enum) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(
          R"pb(repeated_nested_enum: BAR repeated_nested_enum: FOO)pb"),
      DynamicGetField("repeated_nested_enum"), arena());
  std::vector values;
  EXPECT_THAT(value.ForEach(
                  [&](const Value& element) -> absl::StatusOr {
                    values.push_back(element);
                    return true;
                  },
                  descriptor_pool(), message_factory(), arena()),
              IsOk());
  EXPECT_THAT(values, ElementsAre(IntValueIs(1), IntValueIs(0)));
}

// NULL_VALUE elements surface as CEL null.
TEST_F(ParsedRepeatedFieldValueTest, ForEach_Null) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb(repeated_null_value: NULL_VALUE repeated_null_value: NULL_VALUE)pb"),
      DynamicGetField("repeated_null_value"), arena());
  std::vector values;
  EXPECT_THAT(value.ForEach(
                  [&](const Value& element) -> absl::StatusOr {
                    values.push_back(element);
                    return true;
                  },
                  descriptor_pool(), message_factory(), arena()),
              IsOk());
  EXPECT_THAT(values, ElementsAre(IsNullValue(), IsNullValue()));
}

// Next() past the end fails with FailedPrecondition.
TEST_F(ParsedRepeatedFieldValueTest, NewIterator) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb(repeated_bool: false repeated_bool: true)pb"),
      DynamicGetField("repeated_bool"), arena());
  ASSERT_OK_AND_ASSIGN(auto iterator, value.NewIterator());
  ASSERT_TRUE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
  ASSERT_TRUE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
  ASSERT_FALSE(iterator->HasNext());
  EXPECT_THAT(iterator->Next(descriptor_pool(), message_factory(), arena()),
              StatusIs(absl::StatusCode::kFailedPrecondition));
}

// Next1() reports exhaustion as nullopt rather than as an error.
TEST_F(ParsedRepeatedFieldValueTest, NewIterator1) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb(repeated_bool: false repeated_bool: true)pb"),
      DynamicGetField("repeated_bool"), arena());
  ASSERT_OK_AND_ASSIGN(auto iterator, value.NewIterator());
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(BoolValueIs(false))));
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(BoolValueIs(true))));
  EXPECT_THAT(iterator->Next1(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Eq(absl::nullopt)));
}

// Next2() yields (index, element) pairs, then nullopt once exhausted.
TEST_F(ParsedRepeatedFieldValueTest, NewIterator2) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb(repeated_bool: false repeated_bool: true)pb"),
      DynamicGetField("repeated_bool"), arena());
  ASSERT_OK_AND_ASSIGN(auto iterator, value.NewIterator());
  EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(Pair(IntValueIs(0), BoolValueIs(false)))));
  EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Optional(Pair(IntValueIs(1), BoolValueIs(true)))));
  EXPECT_THAT(iterator->Next2(descriptor_pool(), message_factory(), arena()),
              IsOkAndHolds(Eq(absl::nullopt)));
}

// Only an equal bool is found; values of other kinds are never contained.
TEST_F(ParsedRepeatedFieldValueTest, Contains) {
  ParsedRepeatedFieldValue value(
      DynamicParseTextProto(R"pb(repeated_bool: true)pb"),
      DynamicGetField("repeated_bool"), arena());
  EXPECT_THAT(value.Contains(BytesValue(), descriptor_pool(),
                             message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(value.Contains(NullValue(), descriptor_pool(),
                             message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(value.Contains(BoolValue(false), descriptor_pool(),
                             message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(value.Contains(BoolValue(true), descriptor_pool(),
                             message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(true)));
  EXPECT_THAT(value.Contains(DoubleValue(0.0), descriptor_pool(),
                             message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(value.Contains(DoubleValue(1.0), descriptor_pool(),
                             message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(value.Contains(StringValue("bar"), descriptor_pool(),
                             message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(value.Contains(StringValue("foo"), descriptor_pool(),
                             message_factory(), arena()),
              IsOkAndHolds(BoolValueIs(false)));
  EXPECT_THAT(
      value.Contains(MapValue(), descriptor_pool(), message_factory(), arena()),
      IsOkAndHolds(BoolValueIs(false)));
}

}  // namespace
}  // namespace cel
================================================
FILE: 
common/values/string_value.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): in this copy the `<...>` targets of the eight system
// #includes below, and template argument lists (`template`, `static_cast`,
// `absl::optional` return types), appear to have been stripped during
// extraction. Tokens are kept as found; restore them from upstream.

#include
#include
#include
#include
#include
#include
#include
#include

#include "google/protobuf/wrappers.pb.h"
#include "absl/base/nullability.h"
#include "absl/functional/overload.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/cord.h"
#include "absl/strings/cord_buffer.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/internal/byte_string.h"
#include "common/internal/reference_count.h"
#include "common/value.h"
#include "internal/status_macros.h"
#include "internal/strings.h"
#include "internal/utf8.h"
#include "internal/well_known_types.h"
#include "runtime/internal/errors.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

namespace {

using ::cel::well_known_types::ValueReflection;

// Formats a string value as a CEL string literal via
// `internal::FormatStringLiteral`. A flat `absl::string_view` is formatted
// directly. A cord is used in place when `TryFlat()` succeeds; otherwise it
// is converted with `static_cast` first (target type not visible in this
// copy).
template
std::string StringDebugString(const Bytes& value) {
  return value.NativeValue(absl::Overload(
      [](absl::string_view string) -> std::string {
        return internal::FormatStringLiteral(string);
      },
      [](const absl::Cord& cord) -> std::string {
        if (auto flat = cord.TryFlat(); flat.has_value()) {
          return internal::FormatStringLiteral(*flat);
        }
        return internal::FormatStringLiteral(static_cast(cord));
      }));
}

}  // namespace

// Returns `lhs + rhs`, built by `ByteString::Concat` with `arena`.
StringValue StringValue::Concat(const StringValue& lhs, const StringValue& rhs,
                                google::protobuf::Arena* absl_nonnull arena) {
  return StringValue(
      common_internal::ByteString::Concat(lhs.value_, rhs.value_, arena));
}

// Returns the value rendered as a CEL string literal.
std::string StringValue::DebugString() const {
  return StringDebugString(*this);
}

// Serializes the value to `output` as a `google.protobuf.StringValue`
// wrapper message. `descriptor_pool` and `message_factory` are only
// DCHECKed here. Returns UnknownError when serialization fails.
absl::Status StringValue::SerializeTo(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(output != nullptr);

  google::protobuf::StringValue message;
  message.set_value(NativeString());
  if (!message.SerializePartialToZeroCopyStream(output)) {
    return absl::UnknownError(
        absl::StrCat("failed to serialize message: ", message.GetTypeName()));
  }

  return absl::OkStatus();
}

// Stores this value into `json` as a JSON string. `json` must be a
// `google.protobuf.Value` message (DCHECKed). Returns an error if the
// reflection helper cannot be initialized for its descriptor.
absl::Status StringValue::ConvertToJson(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(json != nullptr);
  ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
                 google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE);

  ValueReflection value_reflection;
  CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor()));
  NativeValue(
      [&](const auto& value) { value_reflection.SetStringValue(json, value); });

  return absl::OkStatus();
}

// Sets `*result` to whether `other` is a string with identical contents.
// The flat and cord representations are compared in every combination. A
// non-string `other` yields `FalseValue()`. Always returns OK.
absl::Status StringValue::Equal(
    const Value& other,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);

  if (auto other_value = other.AsString(); other_value.has_value()) {
    *result = NativeValue([other_value](const auto& value) -> BoolValue {
      return other_value->NativeValue(
          [&value](const auto& other_value) -> BoolValue {
            return BoolValue{value == other_value};
          });
    });
    return absl::OkStatus();
  }
  *result = FalseValue();
  return absl::OkStatus();
}

// Returns the length in Unicode code points, not in bytes.
size_t StringValue::Size() const {
  return NativeValue([](const auto& alternative) -> size_t {
    return internal::Utf8CodePointCount(alternative);
  });
}

// True when the underlying representation holds no bytes.
bool StringValue::IsEmpty() const {
  return NativeValue(
      [](const auto& alternative) -> bool { return alternative.empty(); });
}

// Equals / Compare / StartsWith / EndsWith below forward to the underlying
// `ByteString`.
bool StringValue::Equals(absl::string_view string) const {
  return value_.Equals(string);
}

bool StringValue::Equals(const absl::Cord& string) const {
  return value_.Equals(string);
}

bool StringValue::Equals(const StringValue& string) const {
  return value_.Equals(string.value_);
}

// Returns a copy whose storage is produced by `ByteString::Clone(arena)`.
StringValue StringValue::Clone(google::protobuf::Arena* absl_nonnull arena) const {
  return StringValue(value_.Clone(arena));
}

int StringValue::Compare(absl::string_view string) const {
  return value_.Compare(string);
}

int StringValue::Compare(const absl::Cord& string) const {
  return value_.Compare(string);
}

int StringValue::Compare(const StringValue& string) const {
  return value_.Compare(string.value_);
}

bool StringValue::StartsWith(absl::string_view string) const {
  return value_.StartsWith(string);
}

bool StringValue::StartsWith(const absl::Cord& string) const {
  return value_.StartsWith(string);
}

bool StringValue::StartsWith(const StringValue& string) const {
  return value_.StartsWith(string.value_);
}

bool StringValue::EndsWith(absl::string_view string) const {
  return value_.EndsWith(string);
}

bool StringValue::EndsWith(const absl::Cord& string) const {
  return value_.EndsWith(string);
}

bool StringValue::EndsWith(const StringValue& string) const {
  return value_.EndsWith(string.value_);
}

// Substring containment test against a flat needle, for either
// representation of this value.
bool StringValue::Contains(absl::string_view string) const {
  return value_.Visit(absl::Overload(
      [&](absl::string_view lhs) -> bool {
        return absl::StrContains(lhs, string);
      },
      [&](const absl::Cord& lhs) -> bool { return lhs.Contains(string); }));
}

// Substring containment test against a cord needle.
bool StringValue::Contains(const absl::Cord& string) const {
  return value_.Visit(absl::Overload(
      [&](absl::string_view lhs) -> bool {
        if (auto flat = string.TryFlat(); flat) {
          return absl::StrContains(lhs, *flat);
        }
        // There is no nice way to do this. We cannot use std::search due to
        // absl::Cord::CharIterator being an input iterator instead of a forward
        // iterator. So just make an external cord with a noop releaser. We know
        // the external cord will not outlive this function.
        return absl::MakeCordFromExternal(lhs, []() {}).Contains(string);
      },
      [&](const absl::Cord& lhs) -> bool { return lhs.Contains(string); }));
}

// Dispatches on the needle's representation.
bool StringValue::Contains(const StringValue& string) const {
  return string.value_.Visit(absl::Overload(
      [&](absl::string_view rhs) -> bool { return Contains(rhs); },
      [&](const absl::Cord& rhs) -> bool { return Contains(rhs); }));
}

// Returns the code-point index of the first occurrence of `string`, or
// nullopt. The scan advances one UTF-8 code point per step, so the result
// counts code points, not bytes. An empty needle matches at 0.
//
// The IndexOf / LastIndexOf overloads below share this loop shape; they
// differ only in the needle type, an optional `pos` bound, and whether they
// stop at the first match or remember the last one.
absl::optional StringValue::IndexOf(absl::string_view string) const {
  return value_.Visit(absl::Overload(
      [&](absl::string_view lhs) -> absl::optional {
        int64_t code_points = 0;
        while (lhs.size() >= string.size()) {
          if (absl::StartsWith(lhs, string)) {
            return code_points;
          }
          // Equal lengths: this was the last position where a match could
          // start. Stopping here also avoids decoding an empty remainder.
          if (lhs.size() == string.size()) {
            break;
          }
          size_t code_units =
              cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr);
          lhs.remove_prefix(code_units);
          ++code_points;
        }
        return absl::nullopt;
      },
      // Takes the cord by value because the scan consumes it with
      // RemovePrefix.
      [&](absl::Cord lhs) -> absl::optional {
        int64_t code_points = 0;
        while (lhs.size() >= string.size()) {
          if (lhs.StartsWith(string)) {
            return code_points;
          }
          if (lhs.size() == string.size()) {
            break;
          }
          size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(),
                                                        /*code_point=*/nullptr);
          lhs.RemovePrefix(code_units);
          ++code_points;
        }
        return absl::nullopt;
      }));
}

// As IndexOf(absl::string_view), with a cord needle.
absl::optional StringValue::IndexOf(const absl::Cord& string) const {
  return value_.Visit(absl::Overload(
      [&](absl::string_view lhs) -> absl::optional {
        int64_t code_points = 0;
        while (lhs.size() >= string.size()) {
          if (lhs.substr(0, string.size()) == string) {
            return code_points;
          }
          if (lhs.size() == string.size()) {
            break;
          }
          size_t code_units =
              cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr);
          lhs.remove_prefix(code_units);
          ++code_points;
        }
        return absl::nullopt;
      },
      [&](absl::Cord lhs) -> absl::optional {
        int64_t code_points = 0;
        while (lhs.size() >= string.size()) {
          if (lhs.StartsWith(string)) {
            return code_points;
          }
          if (lhs.size() == string.size()) {
            break;
          }
          size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(),
                                                        /*code_point=*/nullptr);
          lhs.RemovePrefix(code_units);
          ++code_points;
        }
        return absl::nullopt;
      }));
}

// Dispatches on the needle's representation.
absl::optional StringValue::IndexOf(const StringValue& string) const {
  return string.value_.Visit(absl::Overload(
      [this](absl::string_view rhs) -> absl::optional {
        return IndexOf(rhs);
      },
      [this](const absl::Cord& rhs) -> absl::optional {
        return IndexOf(rhs);
      }));
}

// As IndexOf(absl::string_view), but only matches starting at code-point
// index >= `pos`. A negative `pos` is treated as 0.
absl::optional StringValue::IndexOf(absl::string_view string,
                                             int64_t pos) const {
  if (pos < 0) {
    pos = 0;
  }
  return value_.Visit(absl::Overload(
      [&](absl::string_view lhs) -> absl::optional {
        int64_t code_points = 0;
        while (lhs.size() >= string.size()) {
          if (code_points >= pos && absl::StartsWith(lhs, string)) {
            return code_points;
          }
          if (lhs.size() == string.size()) {
            break;
          }
          size_t code_units =
              cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr);
          lhs.remove_prefix(code_units);
          ++code_points;
        }
        return absl::nullopt;
      },
      [&](absl::Cord lhs) -> absl::optional {
        int64_t code_points = 0;
        while (lhs.size() >= string.size()) {
          if (code_points >= pos && lhs.StartsWith(string)) {
            return code_points;
          }
          if (lhs.size() == string.size()) {
            break;
          }
          size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(),
                                                        /*code_point=*/nullptr);
          lhs.RemovePrefix(code_units);
          ++code_points;
        }
        return absl::nullopt;
      }));
}

// As IndexOf(absl::string_view, int64_t), with a cord needle.
absl::optional StringValue::IndexOf(const absl::Cord& string,
                                             int64_t pos) const {
  if (pos < 0) {
    pos = 0;
  }
  return value_.Visit(absl::Overload(
      [&](absl::string_view lhs) -> absl::optional {
        int64_t code_points = 0;
        while (lhs.size() >= string.size()) {
          if (code_points >= pos && lhs.substr(0, string.size()) == string) {
            return code_points;
          }
          if (lhs.size() == string.size()) {
            break;
          }
          size_t code_units =
              cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr);
          lhs.remove_prefix(code_units);
          ++code_points;
        }
        return absl::nullopt;
      },
      [&](absl::Cord lhs) -> absl::optional {
        int64_t code_points = 0;
        while (lhs.size() >= string.size()) {
          if (code_points >= pos && lhs.StartsWith(string)) {
            return code_points;
          }
          if (lhs.size() == string.size()) {
            break;
          }
          size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(),
                                                        /*code_point=*/nullptr);
          lhs.RemovePrefix(code_units);
          ++code_points;
        }
        return absl::nullopt;
      }));
}

// Dispatches on the needle's representation.
absl::optional StringValue::IndexOf(const StringValue& string,
                                             int64_t pos) const {
  return string.value_.Visit(absl::Overload(
      [this, pos](absl::string_view rhs) -> absl::optional {
        return IndexOf(rhs, pos);
      },
      [this, pos](const absl::Cord& rhs) -> absl::optional {
        return IndexOf(rhs, pos);
      }));
}

// Returns the code-point index of the last occurrence of `string`, or
// nullopt. The whole value is scanned and the latest match is remembered.
absl::optional StringValue::LastIndexOf(
    absl::string_view string) const {
  return value_.Visit(absl::Overload(
      [&](absl::string_view lhs) -> absl::optional {
        int64_t last_index = -1;
        int64_t code_points = 0;
        while (lhs.size() >= string.size()) {
          if (absl::StartsWith(lhs, string)) {
            last_index = code_points;
          }
          if (lhs.size() == string.size()) {
            break;
          }
          size_t code_units =
              cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr);
          lhs.remove_prefix(code_units);
          ++code_points;
        }
        if (last_index < 0) return absl::nullopt;
        return last_index;
      },
      [&](absl::Cord lhs) -> absl::optional {
        int64_t last_index = -1;
        int64_t code_points = 0;
        while (lhs.size() >= string.size()) {
          if (lhs.StartsWith(string)) {
            last_index = code_points;
          }
          if (lhs.size() == string.size()) {
            break;
          }
          size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(),
                                                        /*code_point=*/nullptr);
          lhs.RemovePrefix(code_units);
          ++code_points;
        }
        if (last_index < 0) return absl::nullopt;
        return last_index;
      }));
}

// As LastIndexOf(absl::string_view), with a cord needle.
absl::optional StringValue::LastIndexOf(
    const absl::Cord& string) const {
  return value_.Visit(absl::Overload(
      [&](absl::string_view lhs) -> absl::optional {
        int64_t last_index = -1;
        int64_t code_points = 0;
        while (lhs.size() >= string.size()) {
          if (lhs.substr(0, string.size()) == string) {
            last_index = code_points;
          }
          if (lhs.size() == string.size()) {
            break;
          }
          size_t code_units =
              cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr);
          lhs.remove_prefix(code_units);
          ++code_points;
        }
        if (last_index < 0) return absl::nullopt;
        return last_index;
      },
      [&](absl::Cord lhs) -> absl::optional {
        int64_t last_index = -1;
        int64_t code_points = 0;
        while (lhs.size() >= string.size()) {
          if (lhs.StartsWith(string)) {
            last_index = code_points;
          }
          if (lhs.size() == string.size()) {
            break;
          }
          size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(),
                                                        /*code_point=*/nullptr);
          lhs.RemovePrefix(code_units);
          ++code_points;
        }
        if (last_index < 0) return absl::nullopt;
        return last_index;
      }));
}

// Dispatches on the needle's representation.
absl::optional StringValue::LastIndexOf(
    const StringValue& string) const {
  return string.value_.Visit(absl::Overload(
      [this](absl::string_view rhs) -> absl::optional {
        return LastIndexOf(rhs);
      },
      [this](const absl::Cord& rhs) -> absl::optional {
        return LastIndexOf(rhs);
      }));
}

// Returns the code-point index of the last occurrence that starts at or
// before code point `pos`, or nullopt. The position equal to `pos` is still
// checked before the loop stops. A negative `pos` yields nullopt.
absl::optional StringValue::LastIndexOf(absl::string_view string,
                                                 int64_t pos) const {
  if (pos < 0) {
    return absl::nullopt;
  }
  return value_.Visit(absl::Overload(
      [&](absl::string_view lhs) -> absl::optional {
        int64_t last_index = -1;
        int64_t code_points = 0;
        while (lhs.size() >= string.size()) {
          if (absl::StartsWith(lhs, string)) {
            last_index = code_points;
          }
          if (code_points >= pos || lhs.size() == string.size()) {
            break;
          }
          size_t code_units =
              cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr);
          lhs.remove_prefix(code_units);
          ++code_points;
        }
        if (last_index < 0) return absl::nullopt;
        return last_index;
      },
      [&](absl::Cord lhs) -> absl::optional {
        int64_t last_index = -1;
        int64_t code_points = 0;
        while (lhs.size() >= string.size()) {
          if (lhs.StartsWith(string)) {
            last_index = code_points;
          }
          if (code_points >= pos || lhs.size() == string.size()) {
            break;
          }
          size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(),
                                                        /*code_point=*/nullptr);
          lhs.RemovePrefix(code_units);
          ++code_points;
        }
        if (last_index < 0) return absl::nullopt;
        return last_index;
      }));
}

// As LastIndexOf(absl::string_view, int64_t), with a cord needle.
absl::optional StringValue::LastIndexOf(const absl::Cord& string,
                                                 int64_t pos) const {
  if (pos < 0) {
    return absl::nullopt;
  }
  return value_.Visit(absl::Overload(
      [&](absl::string_view lhs) -> absl::optional {
        int64_t last_index = -1;
        int64_t code_points = 0;
        while (lhs.size() >= string.size()) {
          if (lhs.substr(0, string.size()) == string) {
            last_index = code_points;
          }
          if (code_points >= pos || lhs.size() == string.size()) {
            break;
          }
          size_t code_units =
              cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr);
          lhs.remove_prefix(code_units);
          ++code_points;
        }
        if (last_index < 0) return absl::nullopt;
        return last_index;
      },
      [&](absl::Cord lhs) -> absl::optional {
        int64_t last_index = -1;
        int64_t code_points = 0;
        while (lhs.size() >= string.size()) {
          if (lhs.StartsWith(string)) {
            last_index = code_points;
          }
          if (code_points >= pos || lhs.size() == string.size()) {
            break;
          }
          size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(),
                                                        /*code_point=*/nullptr);
          lhs.RemovePrefix(code_units);
          ++code_points;
        }
        if (last_index < 0) return absl::nullopt;
        return last_index;
      }));
}

// Dispatches on the needle's representation.
absl::optional StringValue::LastIndexOf(const StringValue& string,
                                                 int64_t pos) const {
  return string.value_.Visit(absl::Overload(
      [this, pos](absl::string_view rhs) -> absl::optional {
        return LastIndexOf(rhs, pos);
      },
      [this, pos](const absl::Cord& rhs) -> absl::optional {
        return LastIndexOf(rhs, pos);
      }));
}

namespace {
absl::StatusOr SubstringImpl(absl::string_view string, uint64_t start) { size_t size_code_points = 0; size_t size_code_units = 0; while (!string.empty()) { char32_t code_point; size_t code_units; std::tie(code_point, code_units) = cel::internal::Utf8Decode(string); if (size_code_points == start) { return size_code_units; } string.remove_prefix(code_units); ++size_code_points; size_code_units += code_units; } if (size_code_points == start) { return size_code_units; } return absl::InvalidArgumentError( ".substring(): is greater than .size()"); } absl::StatusOr SubstringImpl(const absl::Cord& cord, uint64_t start) { absl::Cord::CharIterator char_begin = cord.char_begin(); absl::Cord::CharIterator char_end = cord.char_end(); size_t size_code_points = 0; size_t size_code_units = 0; while (char_begin != char_end) { char32_t code_point; size_t code_units; std::tie(code_point, code_units) = cel::internal::Utf8Decode(char_begin); if (size_code_points == start) { return cord.Subcord(size_code_units, std::numeric_limits::max()); } absl::Cord::Advance(&char_begin, code_units); ++size_code_points; size_code_units += code_units; } if (size_code_points == start) { return cord; } return absl::InvalidArgumentError( ".substring(): is greater than .size()"); } } // namespace Value StringValue::Substring(int64_t start) const { if (start < 0) { return ErrorValue(absl::InvalidArgumentError( ".substring(): is less than 0")); } if (static_cast(start) > value_.size()) { return ErrorValue(absl::InvalidArgumentError( ".substring(, ): or is greater than " ".size()")); } if (start == 0) { return *this; } switch (value_.GetKind()) { case common_internal::ByteStringKind::kSmall: { absl::StatusOr status_or_index = (SubstringImpl)(value_.GetSmall(), start); if (!status_or_index.ok()) { return ErrorValue(std::move(status_or_index).status()); } StringValue result; result.value_.rep_.header.kind = common_internal::ByteStringKind::kSmall; result.value_.rep_.small.size = value_.rep_.small.size - 
*status_or_index; std::memcpy(result.value_.rep_.small.data, value_.rep_.small.data + *status_or_index, result.value_.rep_.small.size); result.value_.rep_.small.arena = value_.rep_.small.arena; return result; } case common_internal::ByteStringKind::kMedium: { absl::StatusOr status_or_index = (SubstringImpl)(value_.GetMedium(), start); if (!status_or_index.ok()) { return ErrorValue(std::move(status_or_index).status()); } StringValue result; result.value_.rep_.header.kind = common_internal::ByteStringKind::kMedium; result.value_.rep_.medium.size = value_.rep_.medium.size - *status_or_index; result.value_.rep_.medium.data = value_.rep_.medium.data + *status_or_index; result.value_.rep_.medium.owner = value_.rep_.medium.owner; common_internal::StrongRef(result.value_.GetMediumReferenceCount()); return result; } case common_internal::ByteStringKind::kLarge: { absl::StatusOr status_or_cord = (SubstringImpl)(value_.GetLarge(), start); if (!status_or_cord.ok()) { return ErrorValue(std::move(status_or_cord).status()); } return StringValue::Wrap(*std::move(status_or_cord)); } } } namespace { absl::StatusOr> SubstringImpl( absl::string_view string, uint64_t start, uint64_t end) { size_t size_code_points = 0; size_t size_code_units = 0; size_t start_code_units; while (!string.empty()) { if (size_code_points == start) { start_code_units = size_code_units; } if (size_code_points == end) { return std::pair{start_code_units, size_code_units}; } char32_t code_point; size_t code_units; std::tie(code_point, code_units) = cel::internal::Utf8Decode(string); string.remove_prefix(code_units); ++size_code_points; size_code_units += code_units; } if (size_code_points == start && start == end) { return std::pair{size_code_units, size_code_units}; } return absl::InvalidArgumentError( ".substring(, ): or is greater than " ".size()"); } absl::StatusOr SubstringImpl(const absl::Cord& cord, uint64_t start, uint64_t end) { absl::Cord::CharIterator char_begin = cord.char_begin(); 
absl::Cord::CharIterator char_end = cord.char_end(); size_t size_code_points = 0; size_t size_code_units = 0; size_t start_code_units; while (char_begin != char_end) { if (size_code_points == start) { start_code_units = size_code_units; } if (size_code_points == end) { return cord.Subcord(start_code_units, size_code_points - start_code_units); } char32_t code_point; size_t code_units; std::tie(code_point, code_units) = cel::internal::Utf8Decode(char_begin); absl::Cord::Advance(&char_begin, code_units); ++size_code_points; size_code_units += code_units; } if (size_code_points == start && start == end) { return absl::Cord(); } return absl::InvalidArgumentError( ".substring(, ): or is greater than " ".size()"); } } // namespace Value StringValue::Substring(int64_t start, int64_t end) const { if (start < 0) { return ErrorValue(absl::InvalidArgumentError( ".substring(, ): is less than 0")); } if (end < start) { return ErrorValue(absl::InvalidArgumentError( ".substring(, ): is less than ")); } if (static_cast(start) > value_.size() || static_cast(end) > value_.size()) { return ErrorValue(absl::InvalidArgumentError( ".substring(, ): or is greater than " ".size()")); } switch (value_.GetKind()) { case common_internal::ByteStringKind::kSmall: { absl::StatusOr> status_or_indices = (SubstringImpl)(value_.GetSmall(), start, end); if (!status_or_indices.ok()) { return ErrorValue(std::move(status_or_indices).status()); } StringValue result; result.value_.rep_.header.kind = common_internal::ByteStringKind::kSmall; result.value_.rep_.small.size = (status_or_indices->second - status_or_indices->first); std::memcpy(result.value_.rep_.small.data, value_.rep_.small.data + status_or_indices->first, result.value_.rep_.small.size); result.value_.rep_.small.arena = value_.rep_.small.arena; return result; } case common_internal::ByteStringKind::kMedium: { absl::StatusOr> status_or_indices = (SubstringImpl)(value_.GetMedium(), start, end); if (!status_or_indices.ok()) { return 
ErrorValue(std::move(status_or_indices).status()); } StringValue result; result.value_.rep_.header.kind = common_internal::ByteStringKind::kMedium; result.value_.rep_.medium.size = (status_or_indices->second - status_or_indices->first); result.value_.rep_.medium.data = value_.rep_.medium.data + status_or_indices->first; result.value_.rep_.medium.owner = value_.rep_.medium.owner; common_internal::StrongRef(result.value_.GetMediumReferenceCount()); return result; } case common_internal::ByteStringKind::kLarge: { absl::StatusOr status_or_cord = (SubstringImpl)(value_.GetLarge(), start, end); if (!status_or_cord.ok()) { return ErrorValue(std::move(status_or_cord).status()); } return StringValue::Wrap(*std::move(status_or_cord)); } } } namespace { bool LowerAsciiImpl(absl::string_view in, std::string* absl_nonnull out) { if (in.empty()) { return false; } bool needs_conversion = false; for (char c : in) { if (absl::ascii_isupper(c)) { needs_conversion = true; break; } } if (!needs_conversion) { return false; } *out = absl::AsciiStrToLower(in); return true; } absl::Cord LowerAsciiImpl(const absl::Cord& in) { if (in.empty()) { return in; } size_t pos = 0; bool needs_conversion = false; for (char c : in.Chars()) { if (absl::ascii_isupper(c)) { needs_conversion = true; break; } pos++; } if (!needs_conversion) { return in; } absl::Cord out = in.Subcord(0, pos); absl::Cord rest = in.Subcord(pos, in.size() - pos); std::string suffix; suffix.resize(rest.size()); size_t current = 0; for (char c : rest.Chars()) { suffix[current++] = absl::ascii_tolower(c); } out.Append(std::move(suffix)); return out; } } // namespace StringValue StringValue::LowerAscii(google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(arena != nullptr); switch (value_.GetKind()) { case common_internal::ByteStringKind::kSmall: { std::string out; if (!(LowerAsciiImpl)(value_.GetSmall(), &out)) { return *this; } return StringValue::From(std::move(out), arena); } case 
common_internal::ByteStringKind::kMedium: { std::string out; if (!(LowerAsciiImpl)(value_.GetMedium(), &out)) { return *this; } return StringValue::From(std::move(out), arena); } case common_internal::ByteStringKind::kLarge: return StringValue::Wrap((LowerAsciiImpl)(value_.GetLarge())); } } namespace { bool UpperAsciiImpl(absl::string_view in, std::string* absl_nonnull out) { if (in.empty()) { return false; } bool needs_conversion = false; for (char c : in) { if (absl::ascii_islower(c)) { needs_conversion = true; break; } } if (!needs_conversion) { return false; } *out = absl::AsciiStrToUpper(in); return true; } absl::Cord UpperAsciiImpl(const absl::Cord& in) { if (in.empty()) { return in; } size_t pos = 0; bool needs_conversion = false; for (char c : in.Chars()) { if (absl::ascii_islower(c)) { needs_conversion = true; break; } pos++; } if (!needs_conversion) { return in; } absl::Cord out = in.Subcord(0, pos); absl::Cord rest = in.Subcord(pos, in.size() - pos); std::string suffix; suffix.resize(rest.size()); size_t current = 0; for (char c : rest.Chars()) { suffix[current++] = absl::ascii_toupper(c); } out.Append(std::move(suffix)); return out; } } // namespace StringValue StringValue::UpperAscii(google::protobuf::Arena* absl_nonnull arena) const { ABSL_DCHECK(arena != nullptr); switch (value_.GetKind()) { case common_internal::ByteStringKind::kSmall: { std::string out; if (!(UpperAsciiImpl)(value_.GetSmall(), &out)) { return *this; } return StringValue::From(std::move(out), arena); } case common_internal::ByteStringKind::kMedium: { std::string out; if (!(UpperAsciiImpl)(value_.GetMedium(), &out)) { return *this; } return StringValue::From(std::move(out), arena); } case common_internal::ByteStringKind::kLarge: return StringValue::Wrap((UpperAsciiImpl)(value_.GetLarge())); } } namespace { // Per CEL spec, checking for Unicode whitespace. 
// Returns true if `c` is one of the whitespace code points tested below:
// U+0009..U+000D, U+0020, U+0085, U+00A0, U+1680, U+2000..U+200A, U+2028,
// U+2029, U+202F, U+205F and U+3000.
bool IsUnicodeWhitespace(char32_t c) {
  // Fast path for the ASCII range: space and the tab..carriage-return block.
  if (c <= 0x0020) {
    return c == 0x0020 || (c >= 0x0009 && c <= 0x000D);
  }
  // U+3000 is the largest whitespace code point checked.
  if (c > 0x3000) return false;
  if (c == 0x0085 || c == 0x00a0 || c == 0x1680) return true;
  if (c >= 0x2000 && c <= 0x200a) return true;
  return c == 0x2028 || c == 0x2029 || c == 0x202f || c == 0x205f ||
         c == 0x3000;
}

// Returns `{left, right}`: the number of bytes of leading and trailing
// whitespace in `string`. When `string` is entirely whitespace, returns
// `{string.size(), 0}` so that `left + right` never exceeds the size.
std::pair TrimImpl(absl::string_view string) {
  absl::string_view temp_string = string;
  size_t left_trim_bytes = 0;
  // Count leading whitespace bytes.
  while (!temp_string.empty()) {
    char32_t c;
    size_t char_len = cel::internal::Utf8Decode(temp_string, &c);
    if (!IsUnicodeWhitespace(c)) {
      break;
    }
    temp_string.remove_prefix(char_len);
    left_trim_bytes += char_len;
  }
  if (left_trim_bytes == string.size()) {
    return {left_trim_bytes, 0};
  }
  // Find the byte offset just past the last non-whitespace code point.
  size_t last_non_ws_end_bytes = 0;
  size_t current_pos_bytes = 0;
  temp_string = string;
  while (!temp_string.empty()) {
    char32_t c;
    size_t char_len = cel::internal::Utf8Decode(temp_string, &c);
    if (!IsUnicodeWhitespace(c)) {
      last_non_ws_end_bytes = current_pos_bytes + char_len;
    }
    current_pos_bytes += char_len;
    temp_string.remove_prefix(char_len);
  }
  return {left_trim_bytes, string.size() - last_non_ws_end_bytes};
}

// Returns `cord` with leading and trailing whitespace (as defined by
// `IsUnicodeWhitespace`) removed; returns an empty cord if `cord` is all
// whitespace.
absl::Cord TrimImpl(const absl::Cord& cord) {
  size_t left_trim_bytes = 0;
  {
    // Count leading whitespace bytes.
    absl::Cord::CharIterator begin = cord.char_begin();
    const absl::Cord::CharIterator end = cord.char_end();
    while (begin != end) {
      char32_t c;
      size_t char_len;
      std::tie(c, char_len) = cel::internal::Utf8Decode(begin);
      if (!IsUnicodeWhitespace(c)) {
        break;
      }
      absl::Cord::Advance(&begin, char_len);
      left_trim_bytes += char_len;
    }
  }
  if (left_trim_bytes == cord.size()) {
    return absl::Cord();
  }
  absl::Cord ltrimmed =
      cord.Subcord(left_trim_bytes, cord.size() - left_trim_bytes);
  size_t last_non_ws_end_bytes = 0;
  size_t current_pos_bytes = 0;
  {
    // Within the left-trimmed cord, find the byte offset just past the last
    // non-whitespace code point.
    absl::Cord::CharIterator begin = ltrimmed.char_begin();
    const absl::Cord::CharIterator end = ltrimmed.char_end();
    while (begin != end) {
      char32_t c;
      size_t char_len;
      std::tie(c, char_len) = cel::internal::Utf8Decode(begin);
      if (!IsUnicodeWhitespace(c)) {
        last_non_ws_end_bytes = current_pos_bytes + char_len;
      }
      absl::Cord::Advance(&begin, char_len);
      current_pos_bytes += char_len;
    }
  }
  return ltrimmed.Subcord(0, last_non_ws_end_bytes);
}

}  // namespace

// Returns this string with leading and trailing whitespace removed.
// Small strings copy the kept bytes inline; medium strings point into the
// same buffer and take an extra strong reference on the shared owner; large
// (cord) strings wrap the trimmed cord.
StringValue StringValue::Trim() const {
  switch (value_.GetKind()) {
    case common_internal::ByteStringKind::kSmall: {
      std::pair trims = (TrimImpl)(value_.GetSmall());
      StringValue result;
      result.value_.rep_.header.kind = common_internal::ByteStringKind::kSmall;
      result.value_.rep_.small.size =
          value_.rep_.small.size - trims.first - trims.second;
      std::memcpy(result.value_.rep_.small.data,
                  value_.rep_.small.data + trims.first,
                  result.value_.rep_.small.size);
      result.value_.rep_.small.arena = value_.GetSmallArena();
      return result;
    }
    case common_internal::ByteStringKind::kMedium: {
      std::pair trims = (TrimImpl)(value_.GetMedium());
      StringValue result;
      result.value_.rep_.header.kind =
          common_internal::ByteStringKind::kMedium;
      result.value_.rep_.medium.size =
          value_.rep_.medium.size - trims.first - trims.second;
      result.value_.rep_.medium.data = value_.rep_.medium.data + trims.first;
      result.value_.rep_.medium.owner = value_.rep_.medium.owner;
      common_internal::StrongRef(result.value_.GetMediumReferenceCount());
      return result;
    }
    case common_internal::ByteStringKind::kLarge: {
      return StringValue::Wrap((TrimImpl)(value_.GetLarge()));
    }
  }
}

namespace {

// Appends `code_point` to `dst`, escaping \a \b \f \n \r \t \v, backslash
// and double quote as two-character sequences; every other code point is
// appended UTF-8 encoded as-is.
void AppendQuoteCodePoint(char32_t code_point, std::string& dst) {
  switch (code_point) {
    case '\a':
      dst.append("\\a");
      break;
    case '\b':
      dst.append("\\b");
      break;
    case '\f':
      dst.append("\\f");
      break;
    case '\n':
      dst.append("\\n");
      break;
    case '\r':
      dst.append("\\r");
      break;
    case '\t':
      dst.append("\\t");
      break;
    case '\v':
      dst.append("\\v");
      break;
    case '\\':
      dst.append("\\\\");
      break;
    case '\"':
      dst.append("\\\"");
      break;
    default:
      cel::internal::Utf8Encode(code_point, &dst);
      break;
  }
}

}  // namespace

// Returns this string wrapped in double quotes, with each code point passed
// through `AppendQuoteCodePoint` for escaping.
StringValue StringValue::Quote(google::protobuf::Arena* absl_nonnull arena) const {
  return value_.Visit(absl::Overload(
      [&](absl::string_view rep) -> StringValue {
        std::string result;
        result.push_back('\"');
        while (!rep.empty()) {
          char32_t code_point;
          size_t code_units;
          std::tie(code_point, code_units) = cel::internal::Utf8Decode(rep);
          AppendQuoteCodePoint(code_point, result);
          rep.remove_prefix(code_units);
        }
        result.push_back('\"');
        return StringValue::From(std::move(result), arena);
      },
      [&](const absl::Cord& rep) -> StringValue {
        absl::Cord::CharIterator begin = rep.char_begin();
        absl::Cord::CharIterator end = rep.char_end();
        std::string result;
        result.push_back('\"');
        while (begin != end) {
          char32_t code_point;
          size_t code_units;
          std::tie(code_point, code_units) = cel::internal::Utf8Decode(begin);
          AppendQuoteCodePoint(code_point, result);
          absl::Cord::Advance(&begin, code_units);
        }
        result.push_back('\"');
        return StringValue::From(std::move(result), arena);
      }));
}

// Returns this string with its code points in reverse order.
// The flat path moves UTF-8 sequences as raw byte runs (no decoding); the
// cord path decodes to code points and re-encodes them.
// NOTE(review): for malformed UTF-8 the two paths may produce different
// bytes, since the cord path re-encodes whatever `Utf8Decode` reports —
// confirm whether that difference matters.
StringValue StringValue::Reverse(google::protobuf::Arena* absl_nonnull arena) const {
  return value_.Visit(absl::Overload(
      [arena](absl::string_view string) -> StringValue {
        if (string.empty()) {
          return StringValue();
        }
        std::string reversed;
        reversed.reserve(string.size());
        // Walk backwards from the end, emitting each encoded sequence with
        // its bytes kept in their original order.
        const char* ptr = string.data() + string.size();
        const char* begin = string.data();
        while (ptr > begin) {
          const char* char_end = ptr;
          --ptr;
          // Back up to beginning of encoded UTF-8 code point.
          while (ptr > begin && (*ptr & 0xC0) == 0x80) {
            --ptr;
          }
          reversed.append(ptr, char_end - ptr);
        }
        return StringValue::From(std::move(reversed), arena);
      },
      [arena](const absl::Cord& cord) -> StringValue {
        if (cord.empty()) {
          return StringValue();
        }
        std::vector code_points;
        absl::Cord::CharIterator char_begin = cord.char_begin();
        absl::Cord::CharIterator char_end = cord.char_end();
        while (char_begin != char_end) {
          char32_t code_point;
          size_t code_units = cel::internal::Utf8Decode(char_begin, &code_point);
          code_points.push_back(code_point);
          absl::Cord::Advance(&char_begin, code_units);
        }
        std::string reversed;
        reversed.reserve(cord.size());
        for (auto it = code_points.rbegin(); it != code_points.rend(); ++it) {
          cel::internal::Utf8Encode(*it, &reversed);
        }
        return StringValue::From(std::move(reversed), arena);
      }));
}

// StatusOr-returning convenience wrapper around the out-parameter overload.
absl::StatusOr StringValue::Join(
    const ListValue& list,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) const {
  Value result;
  CEL_RETURN_IF_ERROR(
      Join(list, descriptor_pool, message_factory, arena, &result));
  return result;
}

// Concatenates the string elements of `list`, inserting this string between
// consecutive elements, and stores the result in `*result`.
// If an element is not a string, `*result` is set to a no-matching-overload
// `ErrorValue` for "join" and OK is returned; a non-OK status is returned only
// when iterating the list fails.
absl::Status StringValue::Join(
    const ListValue& list,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);
  std::string joined;
  CEL_ASSIGN_OR_RETURN(auto iterator, list.NewIterator());
  CEL_ASSIGN_OR_RETURN(
      absl::optional element,
      iterator->Next1(descriptor_pool, message_factory, arena));
  if (element) {
    // The first element is appended without a leading separator.
    if (auto string_element = element->AsString(); string_element) {
      string_element->AppendToString(&joined);
    } else {
      ABSL_DCHECK(!element->Is());
      *result =
          ErrorValue(runtime_internal::CreateNoMatchingOverloadError("join"));
      return absl::OkStatus();
    }
    while (true) {
      CEL_ASSIGN_OR_RETURN(
          element, iterator->Next1(descriptor_pool, message_factory, arena));
      if (!element) {
        break;
      }
      // This string is the separator.
      AppendToString(&joined);
      if (auto string_element = element->AsString(); string_element) {
        string_element->AppendToString(&joined);
      } else {
        ABSL_DCHECK(!element->Is());
        *result =
            ErrorValue(runtime_internal::CreateNoMatchingOverloadError("join"));
        return absl::OkStatus();
      }
    }
  }
  // Results longer than the inline (small) capacity keep the std::string
  // buffer, so drop any excess capacity first.
  if (joined.size() > common_internal::kSmallByteStringCapacity) {
    joined.shrink_to_fit();
  }
  *result = StringValue::From(std::move(joined), arena);
  return absl::OkStatus();
}

// StatusOr-returning convenience wrapper around the out-parameter overload.
absl::StatusOr StringValue::Split(
    const StringValue& delimiter, int64_t limit,
    google::protobuf::Arena* absl_nonnull arena) const {
  Value result;
  CEL_RETURN_IF_ERROR(Split(delimiter, limit, arena, &result));
  return result;
}

// Unlimited split (limit of -1).
absl::Status StringValue::Split(const StringValue& delimiter,
                                google::protobuf::Arena* absl_nonnull arena,
                                Value* absl_nonnull result) const {
  return Split(delimiter, -1, arena, result);
}

// Unlimited split (limit of -1), returning the list directly.
absl::StatusOr StringValue::Split(
    const StringValue& delimiter,
    google::protobuf::Arena* absl_nonnull arena) const {
  Value result;
  CEL_RETURN_IF_ERROR(Split(delimiter, -1, arena, &result));
  return result;
}

// Splits this string on `delimiter` into a list of at most `limit` strings,
// stored in `*result`. A `limit` of 0 yields an empty list; a negative
// `limit` means unlimited. An empty `delimiter` splits between code points.
// The last element holds the unsplit remainder when `limit` is reached.
absl::Status StringValue::Split(const StringValue& delimiter, int64_t limit,
                                google::protobuf::Arena* absl_nonnull arena,
                                Value* absl_nonnull result) const {
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);
  if (limit == 0) {
    // Per spec, when limit is 0 return an empty list.
    *result = ListValue();
    return absl::OkStatus();
  }
  if (limit < 0) {
    // Per spec, when limit is negative treat it as unlimited splits.
    limit = std::numeric_limits::max();
  }
  // Byte ranges [first, second) of each piece within `value_`.
  std::vector> splits;
  size_t pos = 0;
  const size_t len = value_.size();
  if (delimiter.IsEmpty()) {
    // One piece per code point while more than one piece remains allowed;
    // the final piece (if any) is added after the loop.
    value_.Visit(absl::Overload(
        [&](absl::string_view s) {
          while (pos < len && limit > 1) {
            size_t char_len = cel::internal::Utf8Decode(s.substr(pos), nullptr);
            splits.push_back({pos, pos + char_len});
            pos += char_len;
            --limit;
          }
        },
        [&](const absl::Cord& s) {
          while (pos < len && limit > 1) {
            size_t char_len = cel::internal::Utf8Decode(
                s.Subcord(pos, len - pos).char_begin(), nullptr);
            splits.push_back({pos, pos + char_len});
            pos += char_len;
            --limit;
          }
        }));
  } else {
    while (pos < len && limit > 1) {
      absl::optional next = value_.Find(delimiter.value_, pos);
      if (!next) {
        break;
      }
      splits.push_back(std::pair{pos, *next});
      pos = *next + delimiter.value_.size();
      --limit;
      ABSL_DCHECK_LE(pos, len);
    }
  }
  // Append the remainder. With a non-empty delimiter this is always done
  // (possibly an empty trailing piece); with an empty delimiter it is done
  // only for leftover input or when nothing was produced.
  if (splits.empty() || !delimiter.IsEmpty() || pos < len) {
    splits.push_back(std::pair{pos, len});
  }
  auto builder = NewListValueBuilder(arena);
  builder->Reserve(splits.size());
  for (const std::pair& split : splits) {
    builder->UnsafeAdd(
        StringValue(value_.Substring(split.first, split.second)));
  }
  *result = std::move(*builder).Build();
  return absl::OkStatus();
}

// StatusOr-returning convenience wrapper around the out-parameter overload.
absl::StatusOr StringValue::Replace(
    const StringValue& needle, const StringValue& replacement, int64_t limit,
    google::protobuf::Arena* absl_nonnull arena) const {
  Value result;
  CEL_RETURN_IF_ERROR(Replace(needle, replacement, limit, arena, &result));
  return result;
}

// Unlimited replacement (limit of -1).
absl::Status StringValue::Replace(const StringValue& needle,
                                  const StringValue& replacement,
                                  google::protobuf::Arena* absl_nonnull arena,
                                  Value* absl_nonnull result) const {
  return Replace(needle, replacement, -1, arena, result);
}

// Unlimited replacement (limit of -1), returning the string directly.
absl::StatusOr StringValue::Replace(
    const StringValue& needle, const StringValue& replacement,
    google::protobuf::Arena* absl_nonnull arena) const {
  Value result;
  CEL_RETURN_IF_ERROR(Replace(needle, replacement, -1, arena, &result));
  return result;
}

// Replaces up to `limit` non-overlapping occurrences of `needle`, scanning
// left to right, and stores the new string in `*result`. A `limit` of 0
// returns this string unchanged; a negative `limit` means unlimited.
// An empty `needle` matches before each code point and once at the end,
// each match consuming one unit of `limit`.
absl::Status StringValue::Replace(const StringValue& needle,
                                  const StringValue& replacement, int64_t limit,
                                  google::protobuf::Arena* absl_nonnull arena,
                                  Value* absl_nonnull result) const {
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);
  if (limit == 0) {
    // Per spec, when limit is 0 return the original string.
    *result = *this;
    return absl::OkStatus();
  }
  if (limit < 0) {
    // Per spec, when limit is negative treat it as unlimited replacements.
    limit = std::numeric_limits::max();
  }
  size_t pos = 0;
  const size_t len = value_.size();
  const size_t needle_len = needle.value_.size();
  std::string res_str;
  if (needle.IsEmpty()) {
    // Emit `replacement` before each code point while `limit` allows.
    value_.Visit(absl::Overload(
        [&](absl::string_view s) {
          while (pos < len && limit > 0) {
            replacement.AppendToString(&res_str);
            size_t char_len = cel::internal::Utf8Decode(s.substr(pos), nullptr);
            value_.Substring(pos, pos + char_len).AppendToString(&res_str);
            pos += char_len;
            --limit;
          }
        },
        [&](const absl::Cord& s) {
          while (pos < len && limit > 0) {
            replacement.AppendToString(&res_str);
            size_t char_len = cel::internal::Utf8Decode(
                s.Subcord(pos, len - pos).char_begin(), nullptr);
            value_.Substring(pos, pos + char_len).AppendToString(&res_str);
            pos += char_len;
            --limit;
          }
        }));
    // The empty needle also matches at the very end of the input.
    if (limit > 0) {
      replacement.AppendToString(&res_str);
    }
  } else {
    while (pos < len && limit > 0) {
      absl::optional next = value_.Find(needle.value_, pos);
      if (!next) {
        break;
      }
      value_.Substring(pos, *next).AppendToString(&res_str);
      replacement.AppendToString(&res_str);
      pos = *next + needle_len;
      --limit;
    }
  }
  // Copy whatever input was not consumed by a match.
  if (pos < len) {
    value_.Substring(pos, len).AppendToString(&res_str);
  }
  if (res_str.size() > common_internal::kSmallByteStringCapacity) {
    res_str.shrink_to_fit();
  }
  *result = StringValue::From(std::move(res_str), arena);
  return absl::OkStatus();
}

// Returns the code point at index `pos` as a single-character string.
// `pos` equal to the number of code points yields an empty string; a
// negative or larger `pos` yields an `ErrorValue`.
Value StringValue::CharAt(int64_t pos) const {
  if (pos < 0) {
    return ErrorValue(absl::InvalidArgumentError(
        ".charAt(): is less than 0"));
  }
  return value_.Visit(absl::Overload(
      [this, pos](absl::string_view rep) mutable -> Value {
        while (!rep.empty()) {
          char32_t code_point;
          size_t code_units;
          std::tie(code_point, code_units) = cel::internal::Utf8Decode(rep);
          if (pos == 0) {
            // A single code point is built directly into the inline (small)
            // representation.
            StringValue result;
            result.value_.rep_.header.kind =
                common_internal::ByteStringKind::kSmall;
            result.value_.rep_.small.size = cel::internal::Utf8Encode(
                code_point, result.value_.rep_.small.data);
            result.value_.rep_.small.arena = value_.GetArena();
            return result;
          }
          rep.remove_prefix(code_units);
          --pos;
        }
        // If we exit the loop, we iterated through all the code points in
        // `rep`. `pos == 0` means we were looking for a character at index
        // `size()`, which is defined to return an empty string.
        if (pos == 0) {
          return StringValue();
        }
        return ErrorValue(absl::InvalidArgumentError(
            ".charAt(): is greater than .size()"));
      },
      [pos](const absl::Cord& rep) mutable -> Value {
        absl::Cord::CharIterator begin = rep.char_begin();
        absl::Cord::CharIterator end = rep.char_end();
        while (begin != end) {
          char32_t code_point;
          size_t code_units;
          std::tie(code_point, code_units) = cel::internal::Utf8Decode(begin);
          if (pos == 0) {
            // Same inline construction as the flat path; no arena is set.
            StringValue result;
            result.value_.rep_.header.kind =
                common_internal::ByteStringKind::kSmall;
            result.value_.rep_.small.size = cel::internal::Utf8Encode(
                code_point, result.value_.rep_.small.data);
            result.value_.rep_.small.arena = nullptr;
            return result;
          }
          absl::Cord::Advance(&begin, code_units);
          --pos;
        }
        // If we exit the loop, we iterated through all the code points in
        // `rep`. `pos == 0` means we were looking for a character at index
        // `size()`, which is defined to return an empty string.
        if (pos == 0) {
          return StringValue();
        }
        return ErrorValue(absl::InvalidArgumentError(
            ".charAt(): is greater than .size()"));
      }));
}

}  // namespace cel
================================================
FILE: common/values/string_value.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRING_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRING_VALUE_H_ #include #include #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/allocator.h" #include "common/arena.h" #include "common/internal/byte_string.h" #include "common/memory.h" #include "common/type.h" #include "common/value_kind.h" #include "common/values/values.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { class Value; class ListValue; class StringValue; namespace common_internal { absl::string_view LegacyStringValue(const StringValue& value, bool stable, google::protobuf::Arena* absl_nonnull arena); } // namespace common_internal // `StringValue` represents values of the primitive `string` type. 
class StringValue final : private common_internal::ValueMixin {
 public:
  static constexpr ValueKind kKind = ValueKind::kString;

  // `From` overloads build a value from the given text. The arena-taking
  // overloads DCHECK that `arena` is non-null (see the inline definitions
  // after the class).
  static StringValue From(const char* absl_nullable value,
                          google::protobuf::Arena* absl_nonnull arena
                              ABSL_ATTRIBUTE_LIFETIME_BOUND);
  static StringValue From(absl::string_view value,
                          google::protobuf::Arena* absl_nonnull arena
                              ABSL_ATTRIBUTE_LIFETIME_BOUND);
  static StringValue From(const absl::Cord& value);
  static StringValue From(std::string&& value,
                          google::protobuf::Arena* absl_nonnull arena
                              ABSL_ATTRIBUTE_LIFETIME_BOUND);

  // `Wrap(absl::string_view, arena)` constructs through an arena `Borrower`
  // (see the inline definition after the class), presumably referring to
  // `value` rather than copying it; overloads that would wrap a temporary
  // `std::string` are deleted.
  // NOTE(review): `arena` is annotated `absl_nullable` here, but the inline
  // definition DCHECKs `arena != nullptr` — confirm the intended contract.
  static StringValue Wrap(absl::string_view value,
                          google::protobuf::Arena* absl_nullable arena
                              ABSL_ATTRIBUTE_LIFETIME_BOUND);
  static StringValue Wrap(absl::string_view value) = delete;
  static StringValue Wrap(const absl::Cord& value);
  static StringValue Wrap(std::string&& value) = delete;
  static StringValue Wrap(std::string&& value,
                          google::protobuf::Arena* absl_nullable arena
                              ABSL_ATTRIBUTE_LIFETIME_BOUND) = delete;

  // Returns a StringValue that aliases the provided string. Caller must ensure
  // the provided string outlives the use of the returned StringValue.
  static StringValue WrapUnsafe(absl::string_view value);

  static StringValue Concat(const StringValue& lhs, const StringValue& rhs,
                            google::protobuf::Arena* absl_nonnull arena
                                ABSL_ATTRIBUTE_LIFETIME_BOUND);

  ABSL_DEPRECATED("Use From")
  explicit StringValue(const char* absl_nullable value) : value_(value) {}

  ABSL_DEPRECATED("Use From")
  explicit StringValue(absl::string_view value) : value_(value) {}

  ABSL_DEPRECATED("Use From")
  explicit StringValue(const absl::Cord& value) : value_(value) {}

  ABSL_DEPRECATED("Use From")
  explicit StringValue(std::string&& value) : value_(std::move(value)) {}

  ABSL_DEPRECATED("Use From")
  StringValue(Allocator<> allocator, const char* absl_nullable value)
      : value_(allocator, value) {}

  ABSL_DEPRECATED("Use From")
  StringValue(Allocator<> allocator, absl::string_view value)
      : value_(allocator, value) {}

  ABSL_DEPRECATED("Use From")
  StringValue(Allocator<> allocator, const absl::Cord& value)
      : value_(allocator, value) {}

  ABSL_DEPRECATED("Use From")
  StringValue(Allocator<> allocator, std::string&& value)
      : value_(allocator, std::move(value)) {}

  ABSL_DEPRECATED("Use Wrap")
  StringValue(Borrower borrower, absl::string_view value)
      : value_(borrower, value) {}

  ABSL_DEPRECATED("Use Wrap")
  StringValue(Borrower borrower, const absl::Cord& value)
      : value_(borrower, value) {}

  StringValue() = default;
  StringValue(const StringValue&) = default;
  StringValue(StringValue&&) = default;
  StringValue& operator=(const StringValue&) = default;
  StringValue& operator=(StringValue&&) = default;

  constexpr ValueKind kind() const { return kKind; }

  absl::string_view GetTypeName() const { return StringType::kName; }

  std::string DebugString() const;

  // See Value::SerializeTo().
  absl::Status SerializeTo(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const;

  // See Value::ConvertToJson().
  absl::Status ConvertToJson(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  absl::Status Equal(const Value& other,
                     const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                     google::protobuf::MessageFactory* absl_nonnull message_factory,
                     google::protobuf::Arena* absl_nonnull arena,
                     Value* absl_nonnull result) const;
  using ValueMixin::Equal;

  StringValue Clone(google::protobuf::Arena* absl_nonnull arena) const;

  // The zero value of `string` is the empty string.
  bool IsZeroValue() const {
    return NativeValue([](const auto& value) -> bool { return value.empty(); });
  }

  ABSL_DEPRECATED("Use ToString()")
  std::string NativeString() const { return value_.ToString(); }

  ABSL_DEPRECATED("Use ToStringView()")
  absl::string_view NativeString(
      std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return value_.ToStringView(&scratch);
  }

  ABSL_DEPRECATED("Use ToCord()")
  absl::Cord NativeCord() const { return value_.ToCord(); }

  // Invokes `visitor` on the underlying representation (flat string or cord)
  // via `ByteString::Visit`.
  template
  ABSL_DEPRECATED("Use TryFlat()")
  std::common_type_t<
      std::invoke_result_t,
      std::invoke_result_t> NativeValue(Visitor&& visitor) const {
    return value_.Visit(std::forward(visitor));
  }

  void swap(StringValue& other) noexcept {
    using std::swap;
    swap(value_, other.value_);
  }

  size_t Size() const;

  bool IsEmpty() const;

  bool Equals(absl::string_view string) const;
  bool Equals(const absl::Cord& string) const;
  bool Equals(const StringValue& string) const;

  int Compare(absl::string_view string) const;
  int Compare(const absl::Cord& string) const;
  int Compare(const StringValue& string) const;

  bool StartsWith(absl::string_view string) const;
  bool StartsWith(const absl::Cord& string) const;
  bool StartsWith(const StringValue& string) const;

  bool EndsWith(absl::string_view string) const;
  bool EndsWith(const absl::Cord& string) const;
  bool EndsWith(const StringValue& string) const;

  bool Contains(absl::string_view string) const;
  bool Contains(const absl::Cord& string) const;
  bool Contains(const StringValue& string) const;

  // Returns the 0-based index of the first occurrence of `string` in this
  // string, or `absl::nullopt` if `string` is not found.
  absl::optional IndexOf(absl::string_view string) const;
  absl::optional IndexOf(const absl::Cord& string) const;
  absl::optional IndexOf(const StringValue& string) const;

  // Returns the 0-based index of the first occurrence of `string` in this
  // string at or after `pos`, or `absl::nullopt` if `string` is not found.
  absl::optional IndexOf(absl::string_view string, int64_t pos) const;
  absl::optional IndexOf(const absl::Cord& string, int64_t pos) const;
  absl::optional IndexOf(const StringValue& string, int64_t pos) const;

  // Returns the 0-based index of the last occurrence of `string` in this
  // string, or `absl::nullopt` if `string` is not found.
  absl::optional LastIndexOf(absl::string_view string) const;
  absl::optional LastIndexOf(const absl::Cord& string) const;
  absl::optional LastIndexOf(const StringValue& string) const;

  // Returns the 0-based index of the last occurrence of `string` in this
  // string at or before `pos`, or `absl::nullopt` if `string` is not found.
  absl::optional LastIndexOf(absl::string_view string, int64_t pos) const;
  absl::optional LastIndexOf(const absl::Cord& string, int64_t pos) const;
  absl::optional LastIndexOf(const StringValue& string, int64_t pos) const;

  // Returns the code points starting at index `start` (and, in the
  // two-argument form, ending before index `end`). Indices count Unicode
  // code points. Out-of-range or negative indices, or `end < start`, yield an
  // `ErrorValue`.
  Value Substring(int64_t start) const;
  Value Substring(int64_t start, int64_t end) const;

  // Returns a new `StringValue` with all uppercase ASCII characters
  // converted to lowercase. Non-ASCII characters are left unchanged.
  StringValue LowerAscii(google::protobuf::Arena* absl_nonnull arena) const;

  // Returns a new `StringValue` with all lowercase ASCII characters
  // converted to uppercase. Non-ASCII characters are left unchanged.
  StringValue UpperAscii(google::protobuf::Arena* absl_nonnull arena) const;

  // Returns this string with leading and trailing Unicode whitespace removed.
  StringValue Trim() const;

  // Returns a new `StringValue` with the string surrounded by double quotes.
  // Control characters \a \b \f \n \r \t \v, backslash and double quote are
  // backslash-escaped.
  StringValue Quote(google::protobuf::Arena* absl_nonnull arena) const;

  // Returns a new `StringValue` with the characters in reverse order.
  StringValue Reverse(google::protobuf::Arena* absl_nonnull arena) const;

  // Joins the elements of `list`, which are expected to be strings, using
  // this string as the separator. If an element is not a string, the result
  // is a no-matching-overload `ErrorValue` rather than a non-OK status.
  absl::Status Join(const ListValue& list,
                    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                    google::protobuf::MessageFactory* absl_nonnull message_factory,
                    google::protobuf::Arena* absl_nonnull arena,
                    Value* absl_nonnull result) const;
  absl::StatusOr Join(
      const ListValue& list,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  // Splits this string on `delimiter`, returning a list of strings. If `limit`
  // is provided and non-negative, the string is split into at most `limit`
  // substrings; a `limit` of 0 yields an empty list. An empty `delimiter`
  // splits between code points.
  absl::Status Split(const StringValue& delimiter, int64_t limit,
                     google::protobuf::Arena* absl_nonnull arena,
                     Value* absl_nonnull result) const;
  absl::StatusOr Split(const StringValue& delimiter, int64_t limit,
                              google::protobuf::Arena* absl_nonnull arena) const;
  absl::Status Split(const StringValue& delimiter,
                     google::protobuf::Arena* absl_nonnull arena,
                     Value* absl_nonnull result) const;
  absl::StatusOr Split(const StringValue& delimiter,
                              google::protobuf::Arena* absl_nonnull arena) const;

  // Replaces occurrences of `needle` with `replacement`. If `limit` is provided
  // and non-negative, only the first `limit` occurrences are replaced; a
  // `limit` of 0 returns this string unchanged.
  absl::Status Replace(const StringValue& needle,
                       const StringValue& replacement, int64_t limit,
                       google::protobuf::Arena* absl_nonnull arena,
                       Value* absl_nonnull result) const;
  absl::StatusOr Replace(const StringValue& needle,
                                const StringValue& replacement, int64_t limit,
                                google::protobuf::Arena* absl_nonnull arena) const;
  absl::Status Replace(const StringValue& needle,
                       const StringValue& replacement,
                       google::protobuf::Arena* absl_nonnull arena,
                       Value* absl_nonnull result) const;
  absl::StatusOr Replace(const StringValue& needle,
                                const StringValue& replacement,
                                google::protobuf::Arena* absl_nonnull arena) const;

  // Returns the character at `pos` as a new `StringValue`. `pos` is a
  // 0-based index based on Unicode code points. `pos` equal to the number of
  // code points returns an empty string. Returns `ErrorValue` if `pos` is
  // otherwise out of range.
  Value CharAt(int64_t pos) const;

  absl::optional TryFlat() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return value_.TryFlat();
  }

  std::string ToString() const { return value_.ToString(); }

  void CopyToString(std::string* absl_nonnull out) const {
    value_.CopyToString(out);
  }

  void AppendToString(std::string* absl_nonnull out) const {
    value_.AppendToString(out);
  }

  absl::Cord ToCord() const { return value_.ToCord(); }

  void CopyToCord(absl::Cord* absl_nonnull out) const {
    value_.CopyToCord(out);
  }

  void AppendToCord(absl::Cord* absl_nonnull out) const {
    value_.AppendToCord(out);
  }

  absl::string_view ToStringView(
      std::string* absl_nonnull scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return value_.ToStringView(scratch);
  }

  // Hashing and comparisons delegate to the underlying `ByteString`.
  template
  friend H AbslHashValue(H state, const StringValue& string) {
    return H::combine(std::move(state), string.value_);
  }

  friend bool operator==(const StringValue& lhs, const StringValue& rhs) {
    return lhs.value_ == rhs.value_;
  }

  friend bool operator<(const StringValue& lhs, const StringValue& rhs) {
    return lhs.value_ < rhs.value_;
  }

 private:
  friend class common_internal::ValueMixin;
  friend absl::string_view common_internal::LegacyStringValue(
      const StringValue& value, bool stable, google::protobuf::Arena* absl_nonnull arena);
  friend struct ArenaTraits;

  explicit StringValue(common_internal::ByteString value) noexcept
      : value_(std::move(value)) {}

  common_internal::ByteString value_;
};

inline void swap(StringValue& lhs, StringValue& rhs) noexcept { lhs.swap(rhs); }

// Mixed-type comparisons against `absl::string_view` and `absl::Cord`
// delegate to `Equals` and `Compare`.
inline bool operator==(const StringValue& lhs, absl::string_view rhs) {
  return lhs.Equals(rhs);
}

inline bool operator==(absl::string_view lhs, const StringValue& rhs) {
  return rhs == lhs;
}

inline bool operator==(const StringValue& lhs, const absl::Cord& rhs) {
  return lhs.Equals(rhs);
}

inline bool operator==(const absl::Cord& lhs, const StringValue& rhs) {
  return rhs == lhs;
}

inline bool operator!=(const StringValue& lhs, absl::string_view rhs) {
  return !operator==(lhs, rhs);
}

inline bool operator!=(absl::string_view lhs, const StringValue& rhs) {
  return !operator==(lhs, rhs);
}

inline bool operator!=(const StringValue& lhs, const absl::Cord& rhs) {
  return !operator==(lhs, rhs);
}

inline bool operator!=(const absl::Cord& lhs, const StringValue& rhs) {
  return !operator==(lhs, rhs);
}

inline bool operator!=(const StringValue& lhs, const StringValue& rhs) {
  return !operator==(lhs, rhs);
}

inline bool operator<(const StringValue& lhs, absl::string_view rhs) {
  return lhs.Compare(rhs) < 0;
}

// Reversed operand order, so the sign of `Compare` is flipped.
inline bool operator<(absl::string_view lhs, const StringValue& rhs) {
  return rhs.Compare(lhs) > 0;
}

inline bool operator<(const StringValue& lhs, const absl::Cord& rhs) {
  return lhs.Compare(rhs) < 0;
}

// Reversed operand order, so the sign of `Compare` is flipped.
inline bool operator<(const absl::Cord& lhs, const StringValue& rhs) {
  return rhs.Compare(lhs) > 0;
}

inline std::ostream& operator<<(std::ostream& out, const StringValue& value) {
  return out << value.DebugString();
}

// A null `value` is treated as the empty string via `NullSafeStringView`.
inline StringValue StringValue::From(const char* absl_nullable value,
                                     google::protobuf::Arena* absl_nonnull arena
                                         ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return From(absl::NullSafeStringView(value), arena);
}

inline StringValue StringValue::From(absl::string_view value,
                                     google::protobuf::Arena* absl_nonnull arena
                                         ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  ABSL_DCHECK(arena != nullptr);

  return StringValue(arena, value);
}

inline StringValue StringValue::From(const absl::Cord& value) {
  return StringValue(value);
}

inline StringValue StringValue::From(std::string&& value,
                                     google::protobuf::Arena* absl_nonnull arena
                                         ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  ABSL_DCHECK(arena != nullptr);

  return StringValue(arena, std::move(value));
}

inline StringValue StringValue::Wrap(absl::string_view value,
                                     google::protobuf::Arena* absl_nullable arena
                                         ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  ABSL_DCHECK(arena != nullptr);

  return StringValue(Borrower::Arena(arena), value);
}

// Builds the value with `ByteString::FromExternal`; see the lifetime
// requirement documented on the declaration.
inline StringValue StringValue::WrapUnsafe(absl::string_view value) {
  return StringValue(common_internal::ByteString::FromExternal(value));
}

inline StringValue StringValue::Wrap(const absl::Cord& value) {
  return StringValue(value);
}

namespace common_internal {

// Forwards to `LegacyByteString` on the underlying `ByteString`, passing
// `stable` and `arena` through unchanged.
inline absl::string_view LegacyStringValue(const StringValue& value,
                                           bool stable,
                                           google::protobuf::Arena* absl_nonnull arena) {
  return LegacyByteString(value.value_, stable, arena);
}

}  // namespace common_internal

// Arena support: a StringValue is trivially destructible exactly when its
// underlying `ByteString` is.
template <>
struct ArenaTraits {
  using constructible = std::true_type;

  static bool trivially_destructible(const StringValue& value) {
    return ArenaTraits<>::trivially_destructible(value.value_);
  }
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRING_VALUE_H_
================================================
FILE: common/values/string_value_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Unit tests for cel::StringValue: kind/debug/JSON conversion, flat vs. Cord
// accessors, hashing and comparison, and the string helper methods
// (StartsWith, IndexOf, LowerAscii, Trim, CharAt, Join, Reverse, ...).

// NOTE(review): the header names of the next three #include directives
// appear to have been stripped by extraction (the `<...>` is missing) —
// restore them from upstream before compiling.
#include
#include
#include

#include "absl/hash/hash.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/strings/cord.h"
#include "absl/strings/cord_test_helpers.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/native_type.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "common/values/int_value.h"
#include "internal/testing.h"
#include "runtime/internal/errors.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::testing::Eq;
using ::testing::Optional;

// Shared fixture; presumably supplies arena(), descriptor_pool(),
// message_factory() and NewArenaValueMessage() used below — verify against
// common/value_testing.h.
using StringValueTest = common_internal::ValueTest<>;

// kind() reports StringValue::kKind both directly and through a Value that
// wraps a Cord-backed StringValue.
TEST_F(StringValueTest, Kind) {
  EXPECT_EQ(StringValue("foo").kind(), StringValue::kKind);
  EXPECT_EQ(Value(StringValue(absl::Cord("foo"))).kind(), StringValue::kKind);
}

// Streaming a StringValue (flat or fragmented Cord), or a Value holding one,
// prints the contents wrapped in double quotes.
TEST_F(StringValueTest, DebugString) {
  {
    std::ostringstream out;
    out << StringValue("foo");
    EXPECT_EQ(out.str(), "\"foo\"");
  }
  {
    std::ostringstream out;
    out << StringValue(absl::MakeFragmentedCord({"f", "o", "o"}));
    EXPECT_EQ(out.str(), "\"foo\"");
  }
  {
    std::ostringstream out;
    out << Value(StringValue(absl::Cord("foo")));
    EXPECT_EQ(out.str(), "\"foo\"");
  }
}

// JSON conversion fills the `string_value` field of the output message.
TEST_F(StringValueTest, ConvertToJson) {
  auto* message = NewArenaValueMessage();
  EXPECT_THAT(StringValue("foo").ConvertToJson(descriptor_pool(),
                                               message_factory(), message),
              IsOk());
  EXPECT_THAT(*message, EqualsValueTextProto(R"pb(string_value: "foo")pb"));
}

// NativeString() (with and without a scratch buffer) and NativeCord() return
// the stored contents.
TEST_F(StringValueTest, NativeValue) {
  std::string scratch;
  EXPECT_EQ(StringValue("foo").NativeString(), "foo");
  EXPECT_EQ(StringValue("foo").NativeString(scratch), "foo");
  EXPECT_EQ(StringValue("foo").NativeCord(), "foo");
}

// TryFlat() yields the contents for a flat value and nullopt for a
// fragmented Cord.
TEST_F(StringValueTest, TryFlat) {
  EXPECT_THAT(StringValue("foo").TryFlat(), Optional(Eq("foo")));
  EXPECT_THAT(
      StringValue(absl::MakeFragmentedCord({"Hello, World!", "World, Hello!"}))
          .TryFlat(),
      Eq(absl::nullopt));
}

// ToString() materializes both flat and fragmented-Cord contents.
TEST_F(StringValueTest, ToString) {
  EXPECT_EQ(StringValue("foo").ToString(), "foo");
  EXPECT_EQ(StringValue(absl::MakeFragmentedCord({"f", "o", "o"})).ToString(),
            "foo");
}

// CopyToString() overwrites the destination (second call still yields "foo").
TEST_F(StringValueTest, CopyToString) {
  std::string out;
  StringValue("foo").CopyToString(&out);
  EXPECT_EQ(out, "foo");
  StringValue(absl::MakeFragmentedCord({"f", "o", "o"})).CopyToString(&out);
  EXPECT_EQ(out, "foo");
}

// AppendToString() appends to the destination (second call yields "foofoo").
TEST_F(StringValueTest, AppendToString) {
  std::string out;
  StringValue("foo").AppendToString(&out);
  EXPECT_EQ(out, "foo");
  StringValue(absl::MakeFragmentedCord({"f", "o", "o"})).AppendToString(&out);
  EXPECT_EQ(out, "foofoo");
}

// ToCord() works for both flat and fragmented-Cord values.
TEST_F(StringValueTest, ToCord) {
  EXPECT_EQ(StringValue("foo").ToCord(), "foo");
  EXPECT_EQ(StringValue(absl::MakeFragmentedCord({"f", "o", "o"})).ToCord(),
            "foo");
}

// CopyToCord() overwrites the destination Cord.
TEST_F(StringValueTest, CopyToCord) {
  absl::Cord out;
  StringValue("foo").CopyToCord(&out);
  EXPECT_EQ(out, "foo");
  StringValue(absl::MakeFragmentedCord({"f", "o", "o"})).CopyToCord(&out);
  EXPECT_EQ(out, "foo");
}

// AppendToCord() appends to the destination Cord.
TEST_F(StringValueTest, AppendToCord) {
  absl::Cord out;
  StringValue("foo").AppendToCord(&out);
  EXPECT_EQ(out, "foo");
  StringValue(absl::MakeFragmentedCord({"f", "o", "o"})).AppendToCord(&out);
  EXPECT_EQ(out, "foofoo");
}

// NOTE(review): the template argument of `NativeTypeId::For` appears to have
// been stripped by extraction (presumably `StringValue`) — restore from
// upstream.
TEST_F(StringValueTest, NativeTypeId) {
  EXPECT_EQ(NativeTypeId::Of(StringValue("foo")), NativeTypeId::For());
  EXPECT_EQ(NativeTypeId::Of(Value(StringValue(absl::Cord("foo")))),
            NativeTypeId::For());
}

// Hashing a StringValue matches hashing the equivalent absl::string_view,
// regardless of how the value was constructed.
TEST_F(StringValueTest, HashValue) {
  EXPECT_EQ(absl::HashOf(StringValue("foo")),
            absl::HashOf(absl::string_view("foo")));
  EXPECT_EQ(absl::HashOf(StringValue(absl::string_view("foo"))),
            absl::HashOf(absl::string_view("foo")));
  EXPECT_EQ(absl::HashOf(StringValue(absl::Cord("foo"))),
            absl::HashOf(absl::string_view("foo")));
}

// Inequality against string literals, StringValue and Cord, in both
// argument orders.
TEST_F(StringValueTest, Equality) {
  EXPECT_NE(StringValue("foo"), "bar");
  EXPECT_NE("bar", StringValue("foo"));
  EXPECT_NE(StringValue("foo"), StringValue("bar"));
  EXPECT_NE(StringValue("foo"), absl::Cord("bar"));
  EXPECT_NE(absl::Cord("bar"), StringValue("foo"));
}

// Ordering against string literals, StringValue and Cord, in both argument
// orders.
TEST_F(StringValueTest, LessThan) {
  EXPECT_LT(StringValue("bar"), "foo");
  EXPECT_LT("bar", StringValue("foo"));
  EXPECT_LT(StringValue("bar"), StringValue("foo"));
  EXPECT_LT(StringValue("bar"), absl::Cord("foo"));
  EXPECT_LT(absl::Cord("bar"), StringValue("foo"));
}

// Every combination of flat/Cord receiver and flat/Cord prefix. The long
// literal is, per its own text, meant to exceed inline storage.
TEST_F(StringValueTest, StartsWith) {
  EXPECT_TRUE(
      StringValue("This string is large enough to not be stored inline!")
          .StartsWith(StringValue("This string is large enough")));
  EXPECT_TRUE(
      StringValue("This string is large enough to not be stored inline!")
          .StartsWith(StringValue(absl::Cord("This string is large enough"))));
  EXPECT_TRUE(
      StringValue(
          absl::Cord("This string is large enough to not be stored inline!"))
          .StartsWith(StringValue("This string is large enough")));
  EXPECT_TRUE(
      StringValue(
          absl::Cord("This string is large enough to not be stored inline!"))
          .StartsWith(StringValue(absl::Cord("This string is large enough"))));
}

// Every combination of flat/Cord receiver and flat/Cord suffix.
TEST_F(StringValueTest, EndsWith) {
  EXPECT_TRUE(
      StringValue("This string is large enough to not be stored inline!")
          .EndsWith(StringValue("to not be stored inline!")));
  EXPECT_TRUE(
      StringValue("This string is large enough to not be stored inline!")
          .EndsWith(StringValue(absl::Cord("to not be stored inline!"))));
  EXPECT_TRUE(
      StringValue(
          absl::Cord("This string is large enough to not be stored inline!"))
          .EndsWith(StringValue("to not be stored inline!")));
  EXPECT_TRUE(
      StringValue(
          absl::Cord("This string is large enough to not be stored inline!"))
          .EndsWith(StringValue(absl::Cord("to not be stored inline!"))));
}

// Every combination of flat/Cord receiver and flat/Cord infix.
TEST_F(StringValueTest, Contains) {
  EXPECT_TRUE(
      StringValue("This string is large enough to not be stored inline!")
          .Contains(StringValue("string is large enough")));
  EXPECT_TRUE(
      StringValue("This string is large enough to not be stored inline!")
          .Contains(StringValue(absl::Cord("string is large enough"))));
  EXPECT_TRUE(
      StringValue(
          absl::Cord("This string is large enough to not be stored inline!"))
          .Contains(StringValue("string is large enough")));
  EXPECT_TRUE(
      StringValue(
          absl::Cord("This string is large enough to not be stored inline!"))
          .Contains(StringValue(absl::Cord("string is large enough"))));
}

// "is" occurs at offsets 2 ("This") and 12 (the word "is"). Searching from
// offset 4 finds 12; searching from 13 finds nothing. Covers StringValue,
// string literal and Cord needles against flat and Cord receivers.
TEST_F(StringValueTest, IndexOf) {
  StringValue big_string =
      StringValue("This string is large enough to not be stored inline!");
  StringValue big_string_cord = StringValue(
      absl::Cord("This string is large enough to not be stored inline!"));
  StringValue small_string = StringValue("is");
  StringValue small_string_cord = StringValue(absl::Cord("is"));

  EXPECT_THAT(big_string.IndexOf(small_string), Optional(Eq(2)));
  EXPECT_THAT(big_string.IndexOf(small_string_cord), Optional(Eq(2)));
  EXPECT_THAT(big_string_cord.IndexOf(small_string), Optional(Eq(2)));
  EXPECT_THAT(big_string_cord.IndexOf(small_string_cord), Optional(Eq(2)));
  EXPECT_THAT(big_string.IndexOf("is"), Optional(Eq(2)));
  EXPECT_THAT(big_string_cord.IndexOf("is"), Optional(Eq(2)));
  EXPECT_THAT(big_string_cord.IndexOf("not found"), Eq(absl::nullopt));

  EXPECT_THAT(big_string.IndexOf(small_string, 4), Optional(Eq(12)));
  EXPECT_THAT(big_string.IndexOf(small_string_cord, 4), Optional(Eq(12)));
  EXPECT_THAT(big_string_cord.IndexOf(small_string, 4), Optional(Eq(12)));
  EXPECT_THAT(big_string_cord.IndexOf(small_string_cord, 4),
              Optional(Eq(12)));
  EXPECT_THAT(big_string.IndexOf("is", 4), Optional(Eq(12)));
  EXPECT_THAT(big_string_cord.IndexOf("is", 4), Optional(Eq(12)));

  EXPECT_THAT(big_string.IndexOf(small_string, 13), Eq(absl::nullopt));
  EXPECT_THAT(big_string.IndexOf(small_string_cord, 13), Eq(absl::nullopt));
  EXPECT_THAT(big_string_cord.IndexOf(small_string, 13), Eq(absl::nullopt));
  EXPECT_THAT(big_string_cord.IndexOf(small_string_cord, 13),
              Eq(absl::nullopt));

  EXPECT_THAT(big_string.IndexOf(absl::Cord("is"), 4), Optional(Eq(12)));
  EXPECT_THAT(big_string_cord.IndexOf(absl::Cord("is"), 4), Optional(Eq(12)));
  EXPECT_THAT(big_string.IndexOf(absl::Cord("is"), 13), Eq(absl::nullopt));
  EXPECT_THAT(big_string_cord.IndexOf(absl::Cord("is"), 13),
              Eq(absl::nullopt));
}

// LowerAscii() on mixed, already-lower and empty input (flat and Cord), a
// long Cord, and a 10000-char fragmented Cord split at offset 5000.
TEST_F(StringValueTest, LowerAscii) {
  EXPECT_EQ(StringValue("UPPER lower").LowerAscii(arena()), "upper lower");
  EXPECT_EQ(StringValue(absl::Cord("UPPER lower")).LowerAscii(arena()),
            "upper lower");
  EXPECT_EQ(StringValue("upper lower").LowerAscii(arena()), "upper lower");
  EXPECT_EQ(StringValue(absl::Cord("upper lower")).LowerAscii(arena()),
            "upper lower");
  EXPECT_EQ(StringValue("").LowerAscii(arena()), "");
  EXPECT_EQ(StringValue(absl::Cord("")).LowerAscii(arena()), "");

  const std::string kLongMixed =
      "A long STRING with MiXeD case to test conversion to lower case!";
  const std::string kLongLower =
      "a long string with mixed case to test conversion to lower case!";
  EXPECT_EQ(StringValue(absl::Cord(kLongMixed)).LowerAscii(arena()),
            kLongLower);

  std::string very_long_mixed(10000, 'A');
  std::string very_long_lower(10000, 'a');
  EXPECT_EQ(
      StringValue(absl::MakeFragmentedCord({very_long_mixed.substr(0, 5000),
                                            very_long_mixed.substr(5000)}))
          .LowerAscii(arena()),
      very_long_lower);
  EXPECT_EQ(StringValue(absl::MakeFragmentedCord({"hello", "WORLD"}))
                .LowerAscii(arena()),
            "helloworld");
}

// UpperAscii() mirror of the LowerAscii cases above.
TEST_F(StringValueTest, UpperAscii) {
  EXPECT_EQ(StringValue("UPPER lower").UpperAscii(arena()), "UPPER LOWER");
  EXPECT_EQ(StringValue(absl::Cord("UPPER lower")).UpperAscii(arena()),
            "UPPER LOWER");
  EXPECT_EQ(StringValue("UPPER LOWER").UpperAscii(arena()), "UPPER LOWER");
  EXPECT_EQ(StringValue(absl::Cord("UPPER LOWER")).UpperAscii(arena()),
            "UPPER LOWER");
  EXPECT_EQ(StringValue("").UpperAscii(arena()), "");
  EXPECT_EQ(StringValue(absl::Cord("")).UpperAscii(arena()), "");

  const std::string kLongMixed =
      "A long STRING with MiXeD case to test conversion to UPPER case!";
  const std::string kLongUpper =
      "A LONG STRING WITH MIXED CASE TO TEST CONVERSION TO UPPER CASE!";
  EXPECT_EQ(StringValue(absl::Cord(kLongMixed)).UpperAscii(arena()),
            kLongUpper);

  std::string very_long_mixed(10000, 'a');
  std::string very_long_upper(10000, 'A');
  EXPECT_EQ(
      StringValue(absl::MakeFragmentedCord({very_long_mixed.substr(0, 5000),
                                            very_long_mixed.substr(5000)}))
          .UpperAscii(arena()),
      very_long_upper);
  EXPECT_EQ(StringValue(absl::MakeFragmentedCord({"HELLO", "world"}))
                .UpperAscii(arena()),
            "HELLOWORLD");
}

// The last "is" is at offset 12. With a second argument of 4 the result is
// 2, so that argument appears to cap the start offset considered. A cap of
// 100 (past the end) still finds 12. An empty needle matches at offset 52,
// the length of the test string.
TEST_F(StringValueTest, LastIndexOf) {
  StringValue big_string =
      StringValue("This string is large enough to not be stored inline!");
  StringValue big_string_cord = StringValue(
      absl::Cord("This string is large enough to not be stored inline!"));
  StringValue small_string = StringValue("is");
  StringValue small_string_cord = StringValue(absl::Cord("is"));

  EXPECT_THAT(big_string.LastIndexOf(small_string), Optional(Eq(12)));
  EXPECT_THAT(big_string.LastIndexOf(small_string_cord), Optional(Eq(12)));
  EXPECT_THAT(big_string_cord.LastIndexOf(small_string), Optional(Eq(12)));
  EXPECT_THAT(big_string_cord.LastIndexOf(small_string_cord),
              Optional(Eq(12)));
  EXPECT_THAT(big_string.LastIndexOf("is"), Optional(Eq(12)));
  EXPECT_THAT(big_string_cord.LastIndexOf("is"), Optional(Eq(12)));
  EXPECT_THAT(big_string_cord.LastIndexOf("not found"), Eq(absl::nullopt));

  EXPECT_THAT(big_string.LastIndexOf(small_string, 4), Optional(Eq(2)));
  EXPECT_THAT(big_string.LastIndexOf(small_string_cord, 4), Optional(Eq(2)));
  EXPECT_THAT(big_string_cord.LastIndexOf(small_string, 4), Optional(Eq(2)));
  EXPECT_THAT(big_string_cord.LastIndexOf(small_string_cord, 4),
              Optional(Eq(2)));
  EXPECT_THAT(big_string.LastIndexOf("is", 4), Optional(Eq(2)));
  EXPECT_THAT(big_string_cord.LastIndexOf("is", 4), Optional(Eq(2)));

  EXPECT_THAT(big_string.LastIndexOf(small_string, 100), Optional(Eq(12)));
  EXPECT_THAT(big_string.LastIndexOf(small_string_cord, 100),
              Optional(Eq(12)));
  EXPECT_THAT(big_string_cord.LastIndexOf(small_string, 100),
              Optional(Eq(12)));
  EXPECT_THAT(big_string_cord.LastIndexOf(small_string_cord, 100),
              Optional(Eq(12)));

  EXPECT_THAT(big_string.LastIndexOf(absl::Cord("is"), 4), Optional(Eq(2)));
  EXPECT_THAT(big_string_cord.LastIndexOf(absl::Cord("is"), 4),
              Optional(Eq(2)));
  EXPECT_THAT(big_string.LastIndexOf(absl::Cord("is"), 100),
              Optional(Eq(12)));
  EXPECT_THAT(big_string_cord.LastIndexOf(absl::Cord("is"), 100),
              Optional(Eq(12)));
  EXPECT_THAT(big_string.LastIndexOf(absl::Cord(""), 100), Optional(Eq(52)));
  EXPECT_THAT(big_string_cord.LastIndexOf(absl::Cord(""), 100),
              Optional(Eq(52)));
}

// Trim() strips leading/trailing ' ', '\t', '\r', '\n' for flat and Cord
// values. All-whitespace and empty inputs both trim to "".
TEST_F(StringValueTest, Trim) {
  using ::cel::test::StringValueIs;

  StringValue unpadded = StringValue("no padding");
  StringValue front_padded = StringValue(" \t\r\nno padding");
  StringValue back_padded = StringValue("no padding \t\r\n");
  StringValue both_padded = StringValue(" \t\r\nno padding \t\r\n");
  StringValue whitespace = StringValue(" \t\r\n");
  StringValue empty = StringValue("");

  EXPECT_THAT(unpadded.Trim(), StringValueIs("no padding"));
  EXPECT_THAT(front_padded.Trim(), StringValueIs("no padding"));
  EXPECT_THAT(back_padded.Trim(), StringValueIs("no padding"));
  EXPECT_THAT(both_padded.Trim(), StringValueIs("no padding"));
  EXPECT_THAT(whitespace.Trim(), StringValueIs(""));
  EXPECT_THAT(empty.Trim(), StringValueIs(""));

  StringValue unpadded_cord = StringValue(absl::Cord("no padding"));
  StringValue front_padded_cord = StringValue(absl::Cord(" \t\r\nno padding"));
  StringValue back_padded_cord = StringValue(absl::Cord("no padding \t\r\n"));
  StringValue both_padded_cord =
      StringValue(absl::Cord(" \t\r\nno padding \t\r\n"));
  StringValue whitespace_cord = StringValue(absl::Cord(" \t\r\n"));
  StringValue empty_cord = StringValue(absl::Cord(""));

  EXPECT_THAT(unpadded_cord.Trim(), StringValueIs("no padding"));
  EXPECT_THAT(front_padded_cord.Trim(), StringValueIs("no padding"));
  EXPECT_THAT(back_padded_cord.Trim(), StringValueIs("no padding"));
  EXPECT_THAT(both_padded_cord.Trim(), StringValueIs("no padding"));
  EXPECT_THAT(whitespace_cord.Trim(), StringValueIs(""));
  EXPECT_THAT(empty_cord.Trim(), StringValueIs(""));
}

// CharAt() indexes by code point: index 1 of "aμc" is the multi-byte "μ".
// Positions past the end (100) or negative (-1) produce an InvalidArgument
// ErrorValue.
// NOTE(review): the expected messages below look like `<...>` placeholders
// were stripped by extraction — compare against upstream before relying on
// them.
TEST_F(StringValueTest, CharAt) {
  using ::cel::test::ErrorValueIs;
  using ::cel::test::StringValueIs;

  StringValue big_string =
      StringValue("This string is large enough to not be stored inline!");
  StringValue big_string_cord = StringValue(
      absl::Cord("This string is large enough to not be stored inline!"));
  StringValue small_string = StringValue("abc");
  StringValue small_string_cord = StringValue(absl::Cord("abc"));
  StringValue unicode_string = StringValue("aμc");
  StringValue unicode_string_cord = StringValue(absl::Cord("aμc"));

  EXPECT_THAT(big_string.CharAt(0), StringValueIs("T"));
  EXPECT_THAT(big_string_cord.CharAt(0), StringValueIs("T"));
  EXPECT_THAT(small_string.CharAt(1), StringValueIs("b"));
  EXPECT_THAT(small_string_cord.CharAt(1), StringValueIs("b"));
  EXPECT_THAT(unicode_string.CharAt(1), StringValueIs("μ"));
  EXPECT_THAT(unicode_string_cord.CharAt(1), StringValueIs("μ"));

  EXPECT_THAT(
      big_string.CharAt(100),
      ErrorValueIs(absl::InvalidArgumentError(
          ".charAt(): is greater than .size()")));
  EXPECT_THAT(
      big_string_cord.CharAt(100),
      ErrorValueIs(absl::InvalidArgumentError(
          ".charAt(): is greater than .size()")));
  EXPECT_THAT(big_string.CharAt(-1),
              ErrorValueIs(absl::InvalidArgumentError(
                  ".charAt(): is less than 0")));
  EXPECT_THAT(big_string_cord.CharAt(-1),
              ErrorValueIs(absl::InvalidArgumentError(
                  ".charAt(): is less than 0")));
}

// Join() over empty, single and multi-element lists. Any non-string element
// yields a no-matching-overload ErrorValue in `result` while the returned
// status itself is OK.
TEST_F(StringValueTest, Join) {
  using ::cel::runtime_internal::CreateNoMatchingOverloadError;
  using ::cel::test::ErrorValueIs;
  using ::cel::test::StringValueIs;

  StringValue separator(",");
  Value result;

  // Empty list.
  auto list_builder0 = NewListValueBuilder(arena());
  auto list0 = std::move(*list_builder0).Build();
  EXPECT_THAT(separator.Join(list0, descriptor_pool(), message_factory(),
                             arena(), &result),
              IsOk());
  EXPECT_THAT(result, StringValueIs(""));

  // Single element list.
  auto list_builder1 = NewListValueBuilder(arena());
  ASSERT_THAT(list_builder1->Add(StringValue("foo")), IsOk());
  auto list1 = std::move(*list_builder1).Build();
  EXPECT_THAT(separator.Join(list1, descriptor_pool(), message_factory(),
                             arena(), &result),
              IsOk());
  EXPECT_THAT(result, StringValueIs("foo"));

  // Multi element list.
  auto list_builder2 = NewListValueBuilder(arena());
  ASSERT_THAT(list_builder2->Add(StringValue("foo")), IsOk());
  ASSERT_THAT(list_builder2->Add(StringValue("bar")), IsOk());
  ASSERT_THAT(list_builder2->Add(StringValue("baz")), IsOk());
  auto list2 = std::move(*list_builder2).Build();
  EXPECT_THAT(separator.Join(list2, descriptor_pool(), message_factory(),
                             arena(), &result),
              IsOk());
  EXPECT_THAT(result, StringValueIs("foo,bar,baz"));

  // List with non-string.
  auto list_builder3 = NewListValueBuilder(arena());
  ASSERT_THAT(list_builder3->Add(IntValue(1)), IsOk());
  auto list3 = std::move(*list_builder3).Build();
  EXPECT_THAT(separator.Join(list3, descriptor_pool(), message_factory(),
                             arena(), &result),
              IsOk());
  EXPECT_THAT(result, ErrorValueIs(CreateNoMatchingOverloadError("join")));

  // List with string and non-string.
  auto list_builder4 = NewListValueBuilder(arena());
  ASSERT_THAT(list_builder4->Add(StringValue("foo")), IsOk());
  ASSERT_THAT(list_builder4->Add(IntValue(1)), IsOk());
  auto list4 = std::move(*list_builder4).Build();
  EXPECT_THAT(separator.Join(list4, descriptor_pool(), message_factory(),
                             arena(), &result),
              IsOk());
  EXPECT_THAT(result, ErrorValueIs(CreateNoMatchingOverloadError("join")));
}

// Reverse() is code-point aware ("aμc" -> "cμa") and handles default,
// empty, short, long and Cord-backed values.
TEST_F(StringValueTest, Reverse) {
  using ::cel::test::StringValueIs;

  EXPECT_THAT(StringValue().Reverse(arena()), StringValueIs(""));
  EXPECT_THAT(StringValue("").Reverse(arena()), StringValueIs(""));
  EXPECT_THAT(StringValue("hello").Reverse(arena()), StringValueIs("olleh"));
  EXPECT_THAT(StringValue("aμc").Reverse(arena()), StringValueIs("cμa"));
  EXPECT_THAT(
      StringValue("This string is large enough to not be stored inline!")
          .Reverse(arena()),
      StringValueIs("!enilni derots eb ton ot hguone egral si gnirts sihT"));
  EXPECT_THAT(StringValue(absl::Cord("hello")).Reverse(arena()),
              StringValueIs("olleh"));
  EXPECT_THAT(StringValue(absl::Cord("aμc")).Reverse(arena()),
              StringValueIs("cμa"));
  EXPECT_THAT(
      StringValue(
          absl::Cord("This string is large enough to not be stored inline!"))
          .Reverse(arena()),
      StringValueIs("!enilni derots eb ton ot hguone egral si gnirts sihT"));
}

}  // namespace
}  // namespace cel

================================================
FILE: common/values/struct_value.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Out-of-line implementation of cel::StructValue. Nearly every method
// forwards to whichever alternative `variant_` currently holds, via
// `variant_.Visit`.

// NOTE(review): the header names of the next four #include directives, and
// the template arguments of several types below (absl::StatusOr,
// absl::optional, absl::Span, absl::flat_hash_map, optional_ref,
// variant_.As/Get), appear to have been stripped by extraction — restore
// them from upstream before compiling.
#include
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "base/attribute.h"
#include "common/native_type.h"
#include "common/optional_ref.h"
#include "common/type.h"
#include "common/value.h"
#include "common/values/value_variant.h"
#include "internal/status_macros.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

// Returns the runtime StructType reported by the active alternative.
StructType StructValue::GetRuntimeType() const {
  return variant_.Visit([](const auto& alternative) -> StructType {
    return alternative.GetRuntimeType();
  });
}

// Returns the type name reported by the active alternative.
absl::string_view StructValue::GetTypeName() const {
  return variant_.Visit([](const auto& alternative) -> absl::string_view {
    return alternative.GetTypeName();
  });
}

// Returns the NativeTypeId of the active alternative (not of StructValue).
NativeTypeId StructValue::GetTypeId() const {
  return variant_.Visit([](const auto& alternative) -> NativeTypeId {
    return NativeTypeId::Of(alternative);
  });
}

// Returns the active alternative's debug representation.
std::string StructValue::DebugString() const {
  return variant_.Visit([](const auto& alternative) -> std::string {
    return alternative.DebugString();
  });
}

// Serializes the active alternative to `output`. All pointers must be
// non-null (DCHECKed).
absl::Status StructValue::SerializeTo(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(output != nullptr);

  return variant_.Visit([&](const auto& alternative) -> absl::Status {
    return alternative.SerializeTo(descriptor_pool, message_factory, output);
  });
}

// Converts the active alternative to JSON. `json` must be a
// google.protobuf.Value message (DCHECKed via its well-known type).
absl::Status StructValue::ConvertToJson(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(json != nullptr);
  ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
                 google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE);

  return variant_.Visit([&](const auto& alternative) -> absl::Status {
    return alternative.ConvertToJson(descriptor_pool, message_factory, json);
  });
}

// Converts the active alternative to a JSON object. `json` must be a
// google.protobuf.Struct message (DCHECKed via its well-known type).
absl::Status StructValue::ConvertToJsonObject(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(json != nullptr);
  ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
                 google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT);

  return variant_.Visit([&](const auto& alternative) -> absl::Status {
    return alternative.ConvertToJsonObject(descriptor_pool, message_factory,
                                           json);
  });
}

// Compares the active alternative with `other` and writes the outcome to
// `*result`.
absl::Status StructValue::Equal(
    const Value& other,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);

  return variant_.Visit([&](const auto& alternative) -> absl::Status {
    return alternative.Equal(other, descriptor_pool, message_factory, arena,
                             result);
  });
}

// Forwards to the active alternative's IsZeroValue().
bool StructValue::IsZeroValue() const {
  return variant_.Visit([](const auto& alternative) -> bool {
    return alternative.IsZeroValue();
  });
}

// Field-presence check by name, forwarded to the active alternative.
absl::StatusOr StructValue::HasFieldByName(absl::string_view name) const {
  return variant_.Visit(
      [name](const auto& alternative) -> absl::StatusOr {
        return alternative.HasFieldByName(name);
      });
}

// Field-presence check by field number, forwarded to the active alternative.
absl::StatusOr StructValue::HasFieldByNumber(int64_t number) const {
  return variant_.Visit(
      [number](const auto& alternative) -> absl::StatusOr {
        return alternative.HasFieldByNumber(number);
      });
}

// Reads the field named `name` into `*result`. `unboxing_options` is passed
// through unchanged.
absl::Status StructValue::GetFieldByName(
    absl::string_view name, ProtoWrapperTypeOptions unboxing_options,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);

  return variant_.Visit([&](const auto& alternative) -> absl::Status {
    return alternative.GetFieldByName(name, unboxing_options, descriptor_pool,
                                      message_factory, arena, result);
  });
}

// Reads the field with number `number` into `*result`. `unboxing_options`
// is passed through unchanged.
absl::Status StructValue::GetFieldByNumber(
    int64_t number, ProtoWrapperTypeOptions unboxing_options,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);

  return variant_.Visit([&](const auto& alternative) -> absl::Status {
    return alternative.GetFieldByNumber(number, unboxing_options,
                                        descriptor_pool, message_factory,
                                        arena, result);
  });
}

// Invokes `callback` for the fields of the active alternative.
absl::Status StructValue::ForEachField(
    ForEachFieldCallback callback,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);

  return variant_.Visit([&](const auto& alternative) -> absl::Status {
    return alternative.ForEachField(callback, descriptor_pool, message_factory,
                                    arena);
  });
}

// Applies the (non-empty, DCHECKed) `qualifiers` to the active alternative.
// The value is written to `*result` and a count to `*count`; the meaning of
// `count` is defined by the alternatives and is not visible here.
absl::Status StructValue::Qualify(
    absl::Span qualifiers, bool presence_test,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result,
    int* absl_nonnull count) const {
  ABSL_DCHECK(!qualifiers.empty());
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);
  ABSL_DCHECK(count != nullptr);

  return variant_.Visit([&](const auto& alternative) -> absl::Status {
    return alternative.Qualify(qualifiers, presence_test, descriptor_pool,
                               message_factory, arena, result, count);
  });
}

namespace common_internal {

// Generic field-by-field struct equality. Writes TrueValue or FalseValue to
// `*result`.
//
// Algorithm:
//  1. Different type names -> false.
//  2. Copy every lhs field into a name-keyed map.
//  3. Walk rhs fields. A field missing from lhs, or one whose Equal() result
//     IsFalse(), marks the pair unequal; the callback then returns false,
//     presumably to stop iteration early.
//  4. Equal only if no mismatch was seen and the rhs field count matches the
//     lhs field count.
//
// `*result` is used as scratch for each per-field comparison and is
// overwritten before returning OK.
// NOTE(review): only `IsFalse()` counts as a mismatch. A per-field result
// that is not a bool (e.g. an error value) is counted as a match, and the
// final TrueValue/FalseValue then overwrites it — confirm this is intended.
absl::Status StructValueEqual(
    const StructValue& lhs, const StructValue& rhs,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);

  if (lhs.GetTypeName() != rhs.GetTypeName()) {
    *result = FalseValue();
    return absl::OkStatus();
  }
  // Snapshot of lhs fields keyed by field name.
  absl::flat_hash_map lhs_fields;
  CEL_RETURN_IF_ERROR(lhs.ForEachField(
      [&lhs_fields](absl::string_view name,
                    const Value& lhs_value) -> absl::StatusOr {
        lhs_fields.insert_or_assign(std::string(name), Value(lhs_value));
        return true;
      },
      descriptor_pool, message_factory, arena));
  bool equal = true;
  size_t rhs_fields_count = 0;
  CEL_RETURN_IF_ERROR(rhs.ForEachField(
      [&](absl::string_view name,
          const Value& rhs_value) -> absl::StatusOr {
        auto lhs_field = lhs_fields.find(name);
        if (lhs_field == lhs_fields.end()) {
          // rhs has a field lhs does not.
          equal = false;
          return false;
        }
        CEL_RETURN_IF_ERROR(lhs_field->second.Equal(
            rhs_value, descriptor_pool, message_factory, arena, result));
        if (result->IsFalse()) {
          equal = false;
          return false;
        }
        ++rhs_fields_count;
        return true;
      },
      descriptor_pool, message_factory, arena));
  // A count mismatch means lhs has fields rhs lacks.
  if (!equal || rhs_fields_count != lhs_fields.size()) {
    *result = FalseValue();
    return absl::OkStatus();
  }
  *result = TrueValue();
  return absl::OkStatus();
}

// Same algorithm as the StructValue/StructValue overload above, with a
// CustomStructValueInterface on the left-hand side.
// NOTE(review): this body duplicates the overload above token for token; a
// shared template helper would keep the two from drifting apart.
absl::Status StructValueEqual(
    const CustomStructValueInterface& lhs, const StructValue& rhs,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);

  if (lhs.GetTypeName() != rhs.GetTypeName()) {
    *result = FalseValue();
    return absl::OkStatus();
  }
  // Snapshot of lhs fields keyed by field name.
  absl::flat_hash_map lhs_fields;
  CEL_RETURN_IF_ERROR(lhs.ForEachField(
      [&lhs_fields](absl::string_view name,
                    const Value& lhs_value) -> absl::StatusOr {
        lhs_fields.insert_or_assign(std::string(name), Value(lhs_value));
        return true;
      },
      descriptor_pool, message_factory, arena));
  bool equal = true;
  size_t rhs_fields_count = 0;
  CEL_RETURN_IF_ERROR(rhs.ForEachField(
      [&](absl::string_view name,
          const Value& rhs_value) -> absl::StatusOr {
        auto lhs_field = lhs_fields.find(name);
        if (lhs_field == lhs_fields.end()) {
          // rhs has a field lhs does not.
          equal = false;
          return false;
        }
        CEL_RETURN_IF_ERROR(lhs_field->second.Equal(
            rhs_value, descriptor_pool, message_factory, arena, result));
        if (result->IsFalse()) {
          equal = false;
          return false;
        }
        ++rhs_fields_count;
        return true;
      },
      descriptor_pool, message_factory, arena));
  // A count mismatch means lhs has fields rhs lacks.
  if (!equal || rhs_fields_count != lhs_fields.size()) {
    *result = FalseValue();
    return absl::OkStatus();
  }
  *result = TrueValue();
  return absl::OkStatus();
}

}  // namespace common_internal

// Returns a copy of the message alternative if the variant holds one,
// otherwise nullopt.
// NOTE(review): the `As` template argument was stripped in this copy
// (presumably ParsedMessageValue) — confirm against upstream.
absl::optional StructValue::AsMessage() const& {
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  return absl::nullopt;
}

// Rvalue overload of AsMessage(): moves the alternative out instead of
// copying it.
absl::optional StructValue::AsMessage() && {
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  return absl::nullopt;
}

// Returns a non-owning reference to the parsed-message alternative if
// present, otherwise nullopt.
optional_ref StructValue::AsParsedMessage() const& {
  if (const auto* alternative = variant_.As();
      alternative != nullptr) {
    return *alternative;
  }
  return absl::nullopt;
}

// Rvalue overload of AsParsedMessage(): returns the alternative by value,
// moved out of this object.
absl::optional StructValue::AsParsedMessage() && {
  if (auto* alternative = variant_.As();
      alternative != nullptr) {
    return std::move(*alternative);
  }
  return absl::nullopt;
}

// Unchecked accessor: the caller must ensure IsMessage() (DCHECKed, with
// *this streamed on failure).
MessageValue StructValue::GetMessage() const& {
  ABSL_DCHECK(IsMessage()) << *this;
  return variant_.Get();
}

// Rvalue overload of GetMessage(): moves out of the variant.
MessageValue StructValue::GetMessage() && {
  ABSL_DCHECK(IsMessage()) << *this;
  return std::move(variant_).Get();
}

// Unchecked accessor returning a reference into this object; the caller must
// ensure IsParsedMessage() (DCHECKed).
const ParsedMessageValue& StructValue::GetParsedMessage() const& {
  ABSL_DCHECK(IsParsedMessage()) << *this;
  return variant_.Get();
}

// Rvalue overload of GetParsedMessage(): moves out of the variant.
ParsedMessageValue StructValue::GetParsedMessage() && {
  ABSL_DCHECK(IsParsedMessage()) << *this;
  return std::move(variant_).Get();
}

// Re-wraps the active alternative (copied) as a general ValueVariant.
common_internal::ValueVariant StructValue::ToValueVariant() const& {
  return variant_.Visit(
      [](const auto& alternative) -> common_internal::ValueVariant {
        return common_internal::ValueVariant(alternative);
      });
}

// Rvalue overload of ToValueVariant(). `variant_` is visited as an rvalue,
// so moving from the forwarding reference is intentional here.
common_internal::ValueVariant StructValue::ToValueVariant() && {
  return std::move(variant_).Visit(
      [](auto&& alternative) -> common_internal::ValueVariant {
        // NOLINTNEXTLINE(bugprone-move-forwarding-reference)
        return common_internal::ValueVariant(std::move(alternative));
      });
}

}  // namespace cel

================================================
FILE: common/values/struct_value.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private, include "common/value.h"
// IWYU pragma: friend "common/value.h"

// `StructValue` is the value representation of `StructType`. `StructValue`
// itself is a composed type of more specific runtime representations.

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_H_

#include
#include
#include
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/meta/type_traits.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/utility/utility.h"
#include "base/attribute.h"
#include "common/native_type.h"
#include "common/optional_ref.h"
#include "common/type.h"
#include "common/value_kind.h"
#include "common/values/custom_struct_value.h"
#include "common/values/legacy_struct_value.h"
#include "common/values/message_value.h"
#include "common/values/parsed_message_value.h"
#include "common/values/struct_value_variant.h"
#include "common/values/values.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

class StructValue;
class Value;

// Runtime value for CEL struct values. The concrete representation is held in
// `variant_` (a `common_internal::StructValueVariant`); the `Is...`, `As...`
// and `Get...` accessors below inspect or extract that representation.
class StructValue final : private common_internal::StructValueMixin {
 public:
  // Kind tag shared by every `StructValue`.
  static constexpr ValueKind kKind = ValueKind::kStruct;

  // Implicitly converts from any type accepted by
  // `common_internal::IsStructValueAlternativeV`, constructing that alternative
  // in place inside `variant_` from the forwarded argument.
  template <
      typename T,
      typename = std::enable_if_t<
          common_internal::IsStructValueAlternativeV>>>
  // NOLINTNEXTLINE(google-explicit-constructor)
  StructValue(T&& value)
      : variant_(absl::in_place_type>,
                 std::forward(value)) {}

  // Implicit conversion from `MessageValue`, using the variant produced by
  // `MessageValue::ToStructValueVariant()`.
  // NOLINTNEXTLINE(google-explicit-constructor)
  StructValue(const MessageValue& other)
      : variant_(other.ToStructValueVariant()) {}

  // Same as above, but consumes `other`.
  // NOLINTNEXTLINE(google-explicit-constructor)
  StructValue(MessageValue&& other)
      : variant_(std::move(other).ToStructValueVariant()) {}

  // NOTE(review): the alternative held by a default-constructed value is
  // whatever `StructValueVariant` default-constructs to (not visible here).
  StructValue() = default;
  StructValue(const StructValue&) = default;
  StructValue(StructValue&& other) = default;
  StructValue& operator=(const StructValue&) = default;
  StructValue& operator=(StructValue&&) = default;

  // Always `kKind` (`ValueKind::kStruct`).
  constexpr ValueKind kind() const { return kKind; }

  // The following accessors are defined out of line.
  StructType GetRuntimeType() const;

  absl::string_view GetTypeName() const;

  NativeTypeId GetTypeId() const;

  std::string DebugString() const;

  // See Value::SerializeTo().
  absl::Status SerializeTo(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const;

  // See Value::ConvertToJson().
  absl::Status ConvertToJson(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  // Like ConvertToJson(), except `json` **MUST** be an instance of
  // `google.protobuf.Struct`.
  absl::Status ConvertToJsonObject(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  // Compares this value against `other`; the comparison outcome is written to
  // `*result`, while the returned status reports failures to compare.
  absl::Status Equal(
      const Value& other,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const;
  // Re-exposes the mixin's `Equal` overloads, which the declaration above
  // would otherwise hide.
  using StructValueMixin::Equal;

  bool IsZeroValue() const;

  // Looks up the field called `name` and writes it to `*result`.
  // NOTE(review): `unboxing_options` presumably controls how protobuf wrapper
  // fields are surfaced — confirm against `ProtoWrapperTypeOptions`.
  absl::Status GetFieldByName(
      absl::string_view name, ProtoWrapperTypeOptions unboxing_options,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const;
  using StructValueMixin::GetFieldByName;

  // Like GetFieldByName(), but addresses the field by its number.
  absl::Status GetFieldByNumber(
      int64_t number, ProtoWrapperTypeOptions unboxing_options,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const;
  using StructValueMixin::GetFieldByNumber;

  // Field presence checks, by name and by number respectively.
  absl::StatusOr HasFieldByName(absl::string_view name) const;

  absl::StatusOr HasFieldByNumber(int64_t number) const;

  using ForEachFieldCallback = CustomStructValueInterface::ForEachFieldCallback;

  // Invokes `callback` for fields of this value; see
  // `CustomStructValueInterface::ForEachFieldCallback` for the callback
  // contract.
  absl::Status ForEachField(
      ForEachFieldCallback callback,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  // Applies `qualifiers` to this value, writing the outcome to `*result` and
  // `*count`.
  // NOTE(review): `presence_test` presumably requests has()-style evaluation
  // and `*count` the number of qualifiers consumed — confirm in the
  // implementation.
  absl::Status Qualify(
      absl::Span qualifiers, bool presence_test,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result,
      int* absl_nonnull count) const;
  using StructValueMixin::Qualify;

  // Returns `true` if this value is an instance of a message value. Every
  // message value is currently represented as a parsed message value, so this
  // is equivalent to `IsParsedMessage()`.
  bool IsMessage() const { return IsParsedMessage(); }

  // Returns `true` if this value is an instance of a parsed message value. If
  // `true` is returned, it is implied that `IsMessage()` would also return
  // true.
  bool IsParsedMessage() const { return variant_.Is(); }

  // Convenience method for use with template metaprogramming. See
  // `IsMessage()`.
  template
  std::enable_if_t, bool> Is() const {
    return IsMessage();
  }

  // Convenience method for use with template metaprogramming. See
  // `IsParsedMessage()`.
  template
  std::enable_if_t, bool> Is() const {
    return IsParsedMessage();
  }

  // Performs a checked cast from a value to a message value,
  // returning a non-empty optional with either a value or reference to the
  // message value. Otherwise an empty optional is returned.
  absl::optional AsMessage() & {
    return std::as_const(*this).AsMessage();
  }
  absl::optional AsMessage() const&;
  absl::optional AsMessage() &&;
  absl::optional AsMessage() const&& {
    // `*this` is an lvalue here, so this forwards to the `const&` overload.
    return AsMessage();
  }

  // Performs a checked cast from a value to a parsed message value,
  // returning a non-empty optional with either a value or reference to the
  // parsed message value. Otherwise an empty optional is returned.
  optional_ref AsParsedMessage() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return std::as_const(*this).AsParsedMessage();
  }
  optional_ref AsParsedMessage() const& ABSL_ATTRIBUTE_LIFETIME_BOUND;
  absl::optional AsParsedMessage() &&;
  absl::optional AsParsedMessage() const&& {
    // A const rvalue cannot be moved from, so the reference returned by the
    // `const&` overload is converted into an owning optional.
    return common_internal::AsOptional(AsParsedMessage());
  }

  // Convenience method for use with template metaprogramming. See
  // `AsMessage()`.
  template
  std::enable_if_t, absl::optional> As() & {
    return AsMessage();
  }
  template
  std::enable_if_t, absl::optional> As()
      const& {
    return AsMessage();
  }
  template
  std::enable_if_t, absl::optional> As() && {
    return std::move(*this).AsMessage();
  }
  template
  std::enable_if_t, absl::optional> As()
      const&& {
    return std::move(*this).AsMessage();
  }

  // Convenience method for use with template metaprogramming. See
  // `AsParsedMessage()`.
  template
  std::enable_if_t, optional_ref> As() &
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return AsParsedMessage();
  }
  template
  std::enable_if_t, optional_ref> As() const&
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return AsParsedMessage();
  }
  template
  std::enable_if_t, absl::optional> As() && {
    return std::move(*this).AsParsedMessage();
  }
  template
  std::enable_if_t, absl::optional> As()
      const&& {
    return std::move(*this).AsParsedMessage();
  }

  // Performs an unchecked cast from a value to a message value. In
  // debug builds a best effort is made to crash. If `IsMessage()` would return
  // false, calling this method is undefined behavior.
  MessageValue GetMessage() & { return std::as_const(*this).GetMessage(); }
  MessageValue GetMessage() const&;
  MessageValue GetMessage() &&;
  MessageValue GetMessage() const&& { return GetMessage(); }

  // Performs an unchecked cast from a value to a parsed message value. In
  // debug builds a best effort is made to crash. If `IsParsedMessage()` would
  // return false, calling this method is undefined behavior.
  const ParsedMessageValue& GetParsedMessage() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return std::as_const(*this).GetParsedMessage();
  }
  const ParsedMessageValue& GetParsedMessage() const&
      ABSL_ATTRIBUTE_LIFETIME_BOUND;
  ParsedMessageValue GetParsedMessage() &&;
  // Returns a copy: a const rvalue cannot be moved from.
  ParsedMessageValue GetParsedMessage() const&& { return GetParsedMessage(); }

  // Convenience method for use with template metaprogramming. See
  // `GetMessage()`.
  template
  std::enable_if_t, MessageValue> Get() & {
    return GetMessage();
  }
  template
  std::enable_if_t, MessageValue> Get() const& {
    return GetMessage();
  }
  template
  std::enable_if_t, MessageValue> Get() && {
    return std::move(*this).GetMessage();
  }
  template
  std::enable_if_t, MessageValue> Get() const&& {
    return std::move(*this).GetMessage();
  }

  // Convenience method for use with template metaprogramming. See
  // `GetParsedMessage()`.
  template
  std::enable_if_t, const ParsedMessageValue&> Get() &
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return GetParsedMessage();
  }
  template
  std::enable_if_t, const ParsedMessageValue&> Get() const&
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return GetParsedMessage();
  }
  template
  std::enable_if_t, ParsedMessageValue> Get() && {
    return std::move(*this).GetParsedMessage();
  }
  template
  std::enable_if_t, ParsedMessageValue> Get() const&& {
    return std::move(*this).GetParsedMessage();
  }

  // Swaps the underlying representations (ADL-friendly `swap`).
  friend void swap(StructValue& lhs, StructValue& rhs) noexcept {
    using std::swap;
    swap(lhs.variant_, rhs.variant_);
  }

 private:
  // `Value` and the mixins need access to `variant_` / `ToValueVariant()`.
  friend class Value;
  friend class common_internal::ValueMixin;
  friend class common_internal::StructValueMixin;

  // Re-wraps the active alternative as a `common_internal::ValueVariant`; the
  // `&&` overload moves the alternative out instead of copying it.
  common_internal::ValueVariant ToValueVariant() const&;
  common_internal::ValueVariant ToValueVariant() &&;

  // Unlike many of the other derived values, `StructValue` is itself a composed
  // type. This is to avoid making `StructValue` too big and by extension
  // `Value` too big. Instead we store the derived `StructValue` values in
  // `Value` and not `StructValue` itself.
  common_internal::StructValueVariant variant_;
};

// Streams `value.DebugString()`.
inline std::ostream& operator<<(std::ostream& out, const StructValue& value) {
  return out << value.DebugString();
}

// Native type id of a `StructValue`, as reported by `StructValue::GetTypeId()`.
template <>
struct NativeTypeTraits final {
  static NativeTypeId Id(const StructValue& value) {
    return value.GetTypeId();
  }
};

// Abstract interface for incrementally building a struct value. Fields are set
// by name or by number; `Build()` is rvalue-qualified, so it must be invoked on
// an expiring builder (e.g. `std::move(builder).Build()`).
class StructValueBuilder {
 public:
  virtual ~StructValueBuilder() = default;

  // Sets the field called `name` to `value`.
  virtual absl::StatusOr> SetFieldByName(
      absl::string_view name, Value value) = 0;

  // Sets the field with field number `number` to `value`.
  virtual absl::StatusOr> SetFieldByNumber(
      int64_t number, Value value) = 0;

  // Produces the built value, consuming the builder.
  virtual absl::StatusOr Build() && = 0;
};

// Owning pointer alias for builders.
using StructValueBuilderPtr = std::unique_ptr;

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_H_

================================================
FILE: common/values/struct_value_builder.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/values/struct_value_builder.h"

#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <utility>

#include "google/protobuf/struct.pb.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/functional/overload.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/allocator.h"
#include "common/any.h"
#include "common/memory.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "common/values/value_builder.h"
#include "extensions/protobuf/internal/map_reflection.h"
#include "internal/status_macros.h"
#include "internal/well_known_types.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/message.h"

// TODO(uncreated-issue/82): Improve test coverage for struct value builder
// TODO(uncreated-issue/76): improve test coverage for JSON/Any

namespace cel::common_internal {

namespace {

// Returns the descriptor of `message`, or `InvalidArgumentError` when
// `message.GetDescriptor()` yields null.
absl::StatusOr<const google::protobuf::Descriptor*> GetDescriptor(
    const google::protobuf::Message& message) {
  const auto* desc = message.GetDescriptor();
  if (ABSL_PREDICT_FALSE(desc == nullptr)) {
    return absl::InvalidArgumentError(
        absl::StrCat(message.GetTypeName(), " is missing descriptor"));
  }
  return desc;
}

// Copies `from` into `to` by serializing `from` and parsing the bytes into
// `to`. Used when both messages share a type name but come from different
// descriptors, so `CopyFrom()` cannot be used. The `Partial` variants skip
// required-field checks.
//
// Returns `absl::nullopt` on success, or `UnknownError` if serialization or
// parsing fails.
absl::StatusOr<absl::optional<ErrorValue>> ProtoMessageCopyUsingSerialization(
    google::protobuf::MessageLite* to, const google::protobuf::MessageLite* from) {
  ABSL_DCHECK_EQ(to->GetTypeName(), from->GetTypeName());
  // The buffer is an `absl::Cord`, so the Cord-specific APIs are required: the
  // `...ToString` / `...FromString` variants take `std::string*` /
  // `absl::string_view` and do not accept a Cord.
  absl::Cord serialized;
  if (!from->SerializePartialToCord(&serialized)) {
    return absl::UnknownError(
        absl::StrCat("failed to serialize `", from->GetTypeName(), "`"));
  }
  if (!to->ParsePartialFromCord(serialized)) {
    return absl::UnknownError(
        absl::StrCat("failed to parse `", to->GetTypeName(), "`"));
  }
  return absl::nullopt;
}

// Copies `from_message` into `to_message`, whose descriptor is `to_descriptor`.
//
// - Identical descriptors: direct `CopyFrom()`.
// - Same full name, different descriptors: serialization round trip.
// - Otherwise: a type conversion error.
absl::StatusOr<absl::optional<ErrorValue>> ProtoMessageCopy(
    google::protobuf::Message* absl_nonnull to_message,
    const google::protobuf::Descriptor* absl_nonnull to_descriptor,
    const google::protobuf::Message* absl_nonnull from_message) {
  CEL_ASSIGN_OR_RETURN(const auto* from_descriptor,
                       GetDescriptor(*from_message));
  if (to_descriptor == from_descriptor) {
    // Same.
    to_message->CopyFrom(*from_message);
    return absl::nullopt;
  }
  if (to_descriptor->full_name() == from_descriptor->full_name()) {
    // Same type, different descriptors.
    return ProtoMessageCopyUsingSerialization(to_message, from_message);
  }
  return TypeConversionError(from_descriptor->full_name(),
                             to_descriptor->full_name());
}

// Writes `value` into `message`.
//
// Well-known wrapper, Any, Duration, Timestamp and JSON types (Value,
// ListValue, Struct) are populated through `well_known_types`. Any other
// message type requires `value` to be a legacy struct value or a parsed message
// value, which is then copied via `ProtoMessageCopy()`.
//
// Returns `absl::nullopt` on success; an `ErrorValue` for type mismatches or
// int32/uint32 out-of-range values; a non-OK status for other failures.
absl::StatusOr<absl::optional<ErrorValue>> ProtoMessageFromValueImpl(
    const Value& value, const google::protobuf::DescriptorPool* absl_nonnull pool,
    google::protobuf::MessageFactory* absl_nonnull factory,
    well_known_types::Reflection* absl_nonnull well_known_types,
    google::protobuf::Message* absl_nonnull message) {
  CEL_ASSIGN_OR_RETURN(const auto* to_desc, GetDescriptor(*message));
  switch (to_desc->well_known_type()) {
    case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: {
      if (auto double_value = value.AsDouble(); double_value) {
        CEL_RETURN_IF_ERROR(well_known_types->FloatValue().Initialize(
            message->GetDescriptor()));
        well_known_types->FloatValue().SetValue(
            message, static_cast<float>(double_value->NativeValue()));
        return absl::nullopt;
      }
      return TypeConversionError(value.GetTypeName(), to_desc->full_name());
    }
    case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: {
      if (auto double_value = value.AsDouble(); double_value) {
        CEL_RETURN_IF_ERROR(well_known_types->DoubleValue().Initialize(
            message->GetDescriptor()));
        well_known_types->DoubleValue().SetValue(message,
                                                 double_value->NativeValue());
        return absl::nullopt;
      }
      return TypeConversionError(value.GetTypeName(), to_desc->full_name());
    }
    case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: {
      if (auto int_value = value.AsInt(); int_value) {
        // CEL ints are 64-bit; reject values int32 cannot represent.
        if (int_value->NativeValue() < std::numeric_limits<int32_t>::min() ||
            int_value->NativeValue() > std::numeric_limits<int32_t>::max()) {
          return ErrorValue(absl::OutOfRangeError("int64 to int32 overflow"));
        }
        CEL_RETURN_IF_ERROR(well_known_types->Int32Value().Initialize(
            message->GetDescriptor()));
        well_known_types->Int32Value().SetValue(
            message, static_cast<int32_t>(int_value->NativeValue()));
        return absl::nullopt;
      }
      return TypeConversionError(value.GetTypeName(), to_desc->full_name());
    }
    case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: {
      if (auto int_value = value.AsInt(); int_value) {
        CEL_RETURN_IF_ERROR(well_known_types->Int64Value().Initialize(
            message->GetDescriptor()));
        well_known_types->Int64Value().SetValue(message,
                                                int_value->NativeValue());
        return absl::nullopt;
      }
      return TypeConversionError(value.GetTypeName(), to_desc->full_name());
    }
    case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: {
      if (auto uint_value = value.AsUint(); uint_value) {
        // CEL uints are 64-bit; reject values uint32 cannot represent.
        if (uint_value->NativeValue() > std::numeric_limits<uint32_t>::max()) {
          return ErrorValue(absl::OutOfRangeError("uint64 to uint32 overflow"));
        }
        CEL_RETURN_IF_ERROR(well_known_types->UInt32Value().Initialize(
            message->GetDescriptor()));
        well_known_types->UInt32Value().SetValue(
            message, static_cast<uint32_t>(uint_value->NativeValue()));
        return absl::nullopt;
      }
      return TypeConversionError(value.GetTypeName(), to_desc->full_name());
    }
    case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: {
      if (auto uint_value = value.AsUint(); uint_value) {
        CEL_RETURN_IF_ERROR(well_known_types->UInt64Value().Initialize(
            message->GetDescriptor()));
        well_known_types->UInt64Value().SetValue(message,
                                                 uint_value->NativeValue());
        return absl::nullopt;
      }
      return TypeConversionError(value.GetTypeName(), to_desc->full_name());
    }
    case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: {
      if (auto string_value = value.AsString(); string_value) {
        CEL_RETURN_IF_ERROR(well_known_types->StringValue().Initialize(
            message->GetDescriptor()));
        well_known_types->StringValue().SetValue(message,
                                                 string_value->NativeCord());
        return absl::nullopt;
      }
      return TypeConversionError(value.GetTypeName(), to_desc->full_name());
    }
    case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: {
      if (auto bytes_value = value.AsBytes(); bytes_value) {
        CEL_RETURN_IF_ERROR(well_known_types->BytesValue().Initialize(
            message->GetDescriptor()));
        well_known_types->BytesValue().SetValue(message,
                                                bytes_value->NativeCord());
        return absl::nullopt;
      }
      return TypeConversionError(value.GetTypeName(), to_desc->full_name());
    }
    case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: {
      if (auto bool_value = value.AsBool(); bool_value) {
        CEL_RETURN_IF_ERROR(
            well_known_types->BoolValue().Initialize(message->GetDescriptor()));
        well_known_types->BoolValue().SetValue(message,
                                               bool_value->NativeValue());
        return absl::nullopt;
      }
      return TypeConversionError(value.GetTypeName(), to_desc->full_name());
    }
    case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: {
      // Any accepts every value: serialize it and pick the type URL that
      // matches the message type the serialization corresponds to.
      google::protobuf::io::CordOutputStream serialized;
      CEL_RETURN_IF_ERROR(value.SerializeTo(pool, factory, &serialized));
      std::string type_url;
      switch (value.kind()) {
        case ValueKind::kNull:
          // Null has no wrapper type; it is packed as google.protobuf.Value.
          type_url = MakeTypeUrl("google.protobuf.Value");
          break;
        case ValueKind::kBool:
          type_url = MakeTypeUrl("google.protobuf.BoolValue");
          break;
        case ValueKind::kInt:
          type_url = MakeTypeUrl("google.protobuf.Int64Value");
          break;
        case ValueKind::kUint:
          type_url = MakeTypeUrl("google.protobuf.UInt64Value");
          break;
        case ValueKind::kDouble:
          type_url = MakeTypeUrl("google.protobuf.DoubleValue");
          break;
        case ValueKind::kBytes:
          type_url = MakeTypeUrl("google.protobuf.BytesValue");
          break;
        case ValueKind::kString:
          type_url = MakeTypeUrl("google.protobuf.StringValue");
          break;
        case ValueKind::kList:
          type_url = MakeTypeUrl("google.protobuf.ListValue");
          break;
        case ValueKind::kMap:
          type_url = MakeTypeUrl("google.protobuf.Struct");
          break;
        case ValueKind::kDuration:
          type_url = MakeTypeUrl("google.protobuf.Duration");
          break;
        case ValueKind::kTimestamp:
          type_url = MakeTypeUrl("google.protobuf.Timestamp");
          break;
        default:
          type_url = MakeTypeUrl(value.GetTypeName());
          break;
      }
      CEL_RETURN_IF_ERROR(
          well_known_types->Any().Initialize(message->GetDescriptor()));
      well_known_types->Any().SetTypeUrl(message, type_url);
      well_known_types->Any().SetValue(message,
                                       std::move(serialized).Consume());
      return absl::nullopt;
    }
    case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION: {
      if (auto duration_value = value.AsDuration(); duration_value) {
        CEL_RETURN_IF_ERROR(
            well_known_types->Duration().Initialize(message->GetDescriptor()));
        CEL_RETURN_IF_ERROR(well_known_types->Duration().SetFromAbslDuration(
            message, duration_value->NativeValue()));
        return absl::nullopt;
      }
      return TypeConversionError(value.GetTypeName(), to_desc->full_name());
    }
    case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP: {
      if (auto timestamp_value = value.AsTimestamp(); timestamp_value) {
        CEL_RETURN_IF_ERROR(
            well_known_types->Timestamp().Initialize(message->GetDescriptor()));
        CEL_RETURN_IF_ERROR(well_known_types->Timestamp().SetFromAbslTime(
            message, timestamp_value->NativeValue()));
        return absl::nullopt;
      }
      return TypeConversionError(value.GetTypeName(), to_desc->full_name());
    }
    case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: {
      CEL_RETURN_IF_ERROR(value.ConvertToJson(pool, factory, message));
      return absl::nullopt;
    }
    case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: {
      CEL_RETURN_IF_ERROR(value.ConvertToJsonArray(pool, factory, message));
      return absl::nullopt;
    }
    case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: {
      CEL_RETURN_IF_ERROR(value.ConvertToJsonObject(pool, factory, message));
      return absl::nullopt;
    }
    default:
      break;
  }

  // Not a well known type.

  // Deal with legacy values.
  if (auto legacy_value = common_internal::AsLegacyStructValue(value);
      legacy_value) {
    const auto* from_message = legacy_value->message_ptr();
    return ProtoMessageCopy(message, to_desc, from_message);
  }

  // Deal with modern values.
  if (auto parsed_message_value = value.AsParsedMessage();
      parsed_message_value) {
    return ProtoMessageCopy(message, to_desc,
                            cel::to_address(*parsed_message_value));
  }

  return TypeConversionError(value.GetTypeName(), message->GetTypeName());
}

// Converts a value to a specific protocol buffer map key.
//
// The trailing `std::string&` is caller-provided storage used by the string
// converter to hold the key bytes.
// NOTE(review): presumably this keeps the bytes alive for as long as `key`
// refers to them — confirm against the `google::protobuf::MapKey` contract.
using ProtoMapKeyFromValueConverter =
    absl::StatusOr<absl::optional<ErrorValue>> (*)(const Value&,
                                                   google::protobuf::MapKey&,
                                                   std::string&);

// Sets a bool map key; non-bool values yield a type conversion error.
absl::StatusOr<absl::optional<ErrorValue>> ProtoBoolMapKeyFromValueConverter(
    const Value& value, google::protobuf::MapKey& key, std::string&) {
  if (auto bool_value = value.AsBool(); bool_value) {
    key.SetBoolValue(bool_value->NativeValue());
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "bool");
}

// Sets an int32 map key from a CEL int, rejecting values outside int32 range.
absl::StatusOr<absl::optional<ErrorValue>> ProtoInt32MapKeyFromValueConverter(
    const Value& value, google::protobuf::MapKey& key, std::string&) {
  if (auto int_value = value.AsInt(); int_value) {
    if (int_value->NativeValue() < std::numeric_limits<int32_t>::min() ||
        int_value->NativeValue() > std::numeric_limits<int32_t>::max()) {
      return ErrorValue(absl::OutOfRangeError("int64 to int32 overflow"));
    }
    key.SetInt32Value(static_cast<int32_t>(int_value->NativeValue()));
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "int");
}

// Sets an int64 map key from a CEL int.
absl::StatusOr<absl::optional<ErrorValue>> ProtoInt64MapKeyFromValueConverter(
    const Value& value, google::protobuf::MapKey& key, std::string&) {
  if (auto int_value = value.AsInt(); int_value) {
    key.SetInt64Value(int_value->NativeValue());
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "int");
}

// Sets a uint32 map key from a CEL uint, rejecting values above uint32 max.
absl::StatusOr<absl::optional<ErrorValue>> ProtoUInt32MapKeyFromValueConverter(
    const Value& value, google::protobuf::MapKey& key, std::string&) {
  if (auto uint_value = value.AsUint(); uint_value) {
    if (uint_value->NativeValue() > std::numeric_limits<uint32_t>::max()) {
      return ErrorValue(absl::OutOfRangeError("uint64 to uint32 overflow"));
    }
    key.SetUInt32Value(static_cast<uint32_t>(uint_value->NativeValue()));
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "uint");
}

// Sets a uint64 map key from a CEL uint.
absl::StatusOr<absl::optional<ErrorValue>> ProtoUInt64MapKeyFromValueConverter(
    const Value& value, google::protobuf::MapKey& key, std::string&) {
  if (auto uint_value = value.AsUint(); uint_value) {
    key.SetUInt64Value(uint_value->NativeValue());
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "uint");
}

// Sets a string map key; the bytes are first copied into `key_string`.
absl::StatusOr<absl::optional<ErrorValue>> ProtoStringMapKeyFromValueConverter(
    const Value& value, google::protobuf::MapKey& key, std::string& key_string) {
  if (auto string_value = value.AsString(); string_value) {
    key_string = string_value->NativeString();
    key.SetStringValue(key_string);
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "string");
}

// Gets the converter for converting from values to protocol buffer map key.
// Returns `InvalidArgumentError` for C++ types that cannot be map keys.
absl::StatusOr<ProtoMapKeyFromValueConverter> GetProtoMapKeyFromValueConverter(
    google::protobuf::FieldDescriptor::CppType cpp_type) {
  switch (cpp_type) {
    case google::protobuf::FieldDescriptor::CPPTYPE_BOOL:
      return ProtoBoolMapKeyFromValueConverter;
    case google::protobuf::FieldDescriptor::CPPTYPE_INT32:
      return ProtoInt32MapKeyFromValueConverter;
    case google::protobuf::FieldDescriptor::CPPTYPE_INT64:
      return ProtoInt64MapKeyFromValueConverter;
    case google::protobuf::FieldDescriptor::CPPTYPE_UINT32:
      return ProtoUInt32MapKeyFromValueConverter;
    case google::protobuf::FieldDescriptor::CPPTYPE_UINT64:
      return ProtoUInt64MapKeyFromValueConverter;
    case google::protobuf::FieldDescriptor::CPPTYPE_STRING:
      return ProtoStringMapKeyFromValueConverter;
    default:
      return absl::InvalidArgumentError(
          absl::StrCat("unexpected protocol buffer map key type: ",
                       google::protobuf::FieldDescriptor::CppTypeName(cpp_type)));
  }
}

// Converts a value to a specific protocol buffer map value.
// Parameters: the source value, the map field, the descriptor pool, the message
// factory, the well-known-type reflection helpers, and the destination
// `MapValueRef`. Only the message converter uses the pool, factory and
// reflection arguments.
using ProtoMapValueFromValueConverter =
    absl::StatusOr<absl::optional<ErrorValue>> (*)(
        const Value&, const google::protobuf::FieldDescriptor* absl_nonnull,
        const google::protobuf::DescriptorPool* absl_nonnull,
        google::protobuf::MessageFactory* absl_nonnull,
        well_known_types::Reflection* absl_nonnull, google::protobuf::MapValueRef&);

// Stores a CEL bool into a bool map value.
absl::StatusOr<absl::optional<ErrorValue>> ProtoBoolMapValueFromValueConverter(
    const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull,
    const google::protobuf::DescriptorPool* absl_nonnull,
    google::protobuf::MessageFactory* absl_nonnull,
    well_known_types::Reflection* absl_nonnull,
    google::protobuf::MapValueRef& value_ref) {
  if (auto bool_value = value.AsBool(); bool_value) {
    value_ref.SetBoolValue(bool_value->NativeValue());
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "bool");
}

// Stores a CEL int into an int32 map value, rejecting out-of-range values.
absl::StatusOr<absl::optional<ErrorValue>> ProtoInt32MapValueFromValueConverter(
    const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull,
    const google::protobuf::DescriptorPool* absl_nonnull,
    google::protobuf::MessageFactory* absl_nonnull,
    well_known_types::Reflection* absl_nonnull,
    google::protobuf::MapValueRef& value_ref) {
  if (auto int_value = value.AsInt(); int_value) {
    if (int_value->NativeValue() < std::numeric_limits<int32_t>::min() ||
        int_value->NativeValue() > std::numeric_limits<int32_t>::max()) {
      return ErrorValue(absl::OutOfRangeError("int64 to int32 overflow"));
    }
    value_ref.SetInt32Value(static_cast<int32_t>(int_value->NativeValue()));
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "int");
}

// Stores a CEL int into an int64 map value.
absl::StatusOr<absl::optional<ErrorValue>> ProtoInt64MapValueFromValueConverter(
    const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull,
    const google::protobuf::DescriptorPool* absl_nonnull,
    google::protobuf::MessageFactory* absl_nonnull,
    well_known_types::Reflection* absl_nonnull,
    google::protobuf::MapValueRef& value_ref) {
  if (auto int_value = value.AsInt(); int_value) {
    value_ref.SetInt64Value(int_value->NativeValue());
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "int");
}

// Stores a CEL uint into a uint32 map value, rejecting out-of-range values.
absl::StatusOr<absl::optional<ErrorValue>>
ProtoUInt32MapValueFromValueConverter(
    const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull,
    const google::protobuf::DescriptorPool* absl_nonnull,
    google::protobuf::MessageFactory* absl_nonnull,
    well_known_types::Reflection* absl_nonnull,
    google::protobuf::MapValueRef& value_ref) {
  if (auto uint_value = value.AsUint(); uint_value) {
    if (uint_value->NativeValue() > std::numeric_limits<uint32_t>::max()) {
      return ErrorValue(absl::OutOfRangeError("uint64 to uint32 overflow"));
    }
    value_ref.SetUInt32Value(static_cast<uint32_t>(uint_value->NativeValue()));
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "uint");
}

// Stores a CEL uint into a uint64 map value.
absl::StatusOr<absl::optional<ErrorValue>>
ProtoUInt64MapValueFromValueConverter(
    const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull,
    const google::protobuf::DescriptorPool* absl_nonnull,
    google::protobuf::MessageFactory* absl_nonnull,
    well_known_types::Reflection* absl_nonnull,
    google::protobuf::MapValueRef& value_ref) {
  if (auto uint_value = value.AsUint(); uint_value) {
    value_ref.SetUInt64Value(uint_value->NativeValue());
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "uint");
}

// Stores a CEL double into a float map value.
absl::StatusOr<absl::optional<ErrorValue>> ProtoFloatMapValueFromValueConverter(
    const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull,
    const google::protobuf::DescriptorPool* absl_nonnull,
    google::protobuf::MessageFactory* absl_nonnull,
    well_known_types::Reflection* absl_nonnull,
    google::protobuf::MapValueRef& value_ref) {
  if (auto double_value = value.AsDouble(); double_value) {
    // Intentional double -> float narrowing (precision may be lost); made
    // explicit instead of relying on an implicit conversion.
    value_ref.SetFloatValue(static_cast<float>(double_value->NativeValue()));
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "double");
}

// Stores a CEL double into a double map value.
absl::StatusOr<absl::optional<ErrorValue>>
ProtoDoubleMapValueFromValueConverter(
    const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull,
    const google::protobuf::DescriptorPool* absl_nonnull,
    google::protobuf::MessageFactory* absl_nonnull,
    well_known_types::Reflection* absl_nonnull,
    google::protobuf::MapValueRef& value_ref) {
  if (auto double_value = value.AsDouble(); double_value) {
    value_ref.SetDoubleValue(double_value->NativeValue());
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "double");
}

// Stores CEL bytes into a bytes map value (protobuf uses the string setter).
absl::StatusOr<absl::optional<ErrorValue>> ProtoBytesMapValueFromValueConverter(
    const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull,
    const google::protobuf::DescriptorPool* absl_nonnull,
    google::protobuf::MessageFactory* absl_nonnull,
    well_known_types::Reflection* absl_nonnull,
    google::protobuf::MapValueRef& value_ref) {
  if (auto bytes_value = value.AsBytes(); bytes_value) {
    value_ref.SetStringValue(bytes_value->NativeString());
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "bytes");
}

// Stores a CEL string into a string map value.
absl::StatusOr<absl::optional<ErrorValue>>
ProtoStringMapValueFromValueConverter(
    const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull,
    const google::protobuf::DescriptorPool* absl_nonnull,
    google::protobuf::MessageFactory* absl_nonnull,
    well_known_types::Reflection* absl_nonnull,
    google::protobuf::MapValueRef& value_ref) {
  if (auto string_value = value.AsString(); string_value) {
    value_ref.SetStringValue(string_value->NativeString());
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "string");
}

// Stores `google.protobuf.NullValue` (enum number 0) for a CEL null or int.
// NOTE(review): any int, not only 0, is accepted and mapped to enum value 0 —
// confirm this leniency is intended.
absl::StatusOr<absl::optional<ErrorValue>> ProtoNullMapValueFromValueConverter(
    const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull,
    const google::protobuf::DescriptorPool* absl_nonnull,
    google::protobuf::MessageFactory* absl_nonnull,
    well_known_types::Reflection* absl_nonnull,
    google::protobuf::MapValueRef& value_ref) {
  if (value.IsNull() || value.IsInt()) {
    value_ref.SetEnumValue(0);
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "google.protobuf.NullValue");
}

// Stores a CEL int as an enum number, rejecting values outside int32 range.
// The number is not checked against the enum's declared values.
absl::StatusOr<absl::optional<ErrorValue>> ProtoEnumMapValueFromValueConverter(
    const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull field,
    const google::protobuf::DescriptorPool* absl_nonnull,
    google::protobuf::MessageFactory* absl_nonnull,
    well_known_types::Reflection* absl_nonnull,
    google::protobuf::MapValueRef& value_ref) {
  if (auto int_value = value.AsInt(); int_value) {
    if (int_value->NativeValue() < std::numeric_limits<int32_t>::min() ||
        int_value->NativeValue() > std::numeric_limits<int32_t>::max()) {
      return ErrorValue(absl::OutOfRangeError("int64 to int32 overflow"));
    }
    value_ref.SetEnumValue(static_cast<int32_t>(int_value->NativeValue()));
    return absl::nullopt;
  }
  return TypeConversionError(value.GetTypeName(), "enum");
}

// Writes `value` into the map's mutable message value via
// `ProtoMessageFromValueImpl()`.
absl::StatusOr<absl::optional<ErrorValue>>
ProtoMessageMapValueFromValueConverter(
    const Value& value, const google::protobuf::FieldDescriptor* absl_nonnull,
    const google::protobuf::DescriptorPool* absl_nonnull pool,
    google::protobuf::MessageFactory* absl_nonnull factory,
    well_known_types::Reflection* absl_nonnull well_known_types,
    google::protobuf::MapValueRef& value_ref) {
  return ProtoMessageFromValueImpl(value, pool, factory, well_known_types,
                                   value_ref.MutableMessageValue());
}

// Gets the converter for converting from values to protocol buffer map value.
absl::StatusOr GetProtoMapValueFromValueConverter( const google::protobuf::FieldDescriptor* absl_nonnull field) { ABSL_DCHECK(field->is_map()); const auto* value_field = field->message_type()->map_value(); switch (value_field->cpp_type()) { case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: return ProtoBoolMapValueFromValueConverter; case google::protobuf::FieldDescriptor::CPPTYPE_INT32: return ProtoInt32MapValueFromValueConverter; case google::protobuf::FieldDescriptor::CPPTYPE_INT64: return ProtoInt64MapValueFromValueConverter; case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: return ProtoUInt32MapValueFromValueConverter; case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: return ProtoUInt64MapValueFromValueConverter; case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: return ProtoFloatMapValueFromValueConverter; case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: return ProtoDoubleMapValueFromValueConverter; case google::protobuf::FieldDescriptor::CPPTYPE_STRING: if (value_field->type() == google::protobuf::FieldDescriptor::TYPE_BYTES) { return ProtoBytesMapValueFromValueConverter; } return ProtoStringMapValueFromValueConverter; case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: if (value_field->enum_type()->full_name() == "google.protobuf.NullValue") { return ProtoNullMapValueFromValueConverter; } return ProtoEnumMapValueFromValueConverter; case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: return ProtoMessageMapValueFromValueConverter; default: return absl::InvalidArgumentError(absl::StrCat( "unexpected protocol buffer map value type: ", google::protobuf::FieldDescriptor::CppTypeName(value_field->cpp_type()))); } } using ProtoRepeatedFieldFromValueMutator = absl::StatusOr> (*)( const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, well_known_types::Reflection* absl_nonnull, const google::protobuf::Reflection* absl_nonnull, google::protobuf::Message* absl_nonnull, const 
google::protobuf::FieldDescriptor* absl_nonnull, const Value&); absl::StatusOr> ProtoBoolRepeatedFieldFromValueMutator( const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, well_known_types::Reflection* absl_nonnull, const google::protobuf::Reflection* absl_nonnull reflection, google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { if (auto bool_value = value.AsBool(); bool_value) { reflection->AddBool(message, field, bool_value->NativeValue()); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "bool"); } absl::StatusOr> ProtoInt32RepeatedFieldFromValueMutator( const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, well_known_types::Reflection* absl_nonnull, const google::protobuf::Reflection* absl_nonnull reflection, google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { if (auto int_value = value.AsInt(); int_value) { if (int_value->NativeValue() < std::numeric_limits::min() || int_value->NativeValue() > std::numeric_limits::max()) { return ErrorValue(absl::OutOfRangeError("int64 to int32 overflow")); } reflection->AddInt32(message, field, static_cast(int_value->NativeValue())); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "int"); } absl::StatusOr> ProtoInt64RepeatedFieldFromValueMutator( const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, well_known_types::Reflection* absl_nonnull, const google::protobuf::Reflection* absl_nonnull reflection, google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { if (auto int_value = value.AsInt(); int_value) { reflection->AddInt64(message, field, int_value->NativeValue()); return absl::nullopt; } return 
TypeConversionError(value.GetTypeName(), "int"); } absl::StatusOr> ProtoUInt32RepeatedFieldFromValueMutator( const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, well_known_types::Reflection* absl_nonnull, const google::protobuf::Reflection* absl_nonnull reflection, google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { if (auto uint_value = value.AsUint(); uint_value) { if (uint_value->NativeValue() > std::numeric_limits::max()) { return ErrorValue(absl::OutOfRangeError("uint64 to uint32 overflow")); } reflection->AddUInt32(message, field, static_cast(uint_value->NativeValue())); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "uint"); } absl::StatusOr> ProtoUInt64RepeatedFieldFromValueMutator( const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, well_known_types::Reflection* absl_nonnull, const google::protobuf::Reflection* absl_nonnull reflection, google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { if (auto uint_value = value.AsUint(); uint_value) { reflection->AddUInt64(message, field, uint_value->NativeValue()); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "uint"); } absl::StatusOr> ProtoFloatRepeatedFieldFromValueMutator( const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, well_known_types::Reflection* absl_nonnull, const google::protobuf::Reflection* absl_nonnull reflection, google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { if (auto double_value = value.AsDouble(); double_value) { reflection->AddFloat(message, field, static_cast(double_value->NativeValue())); return absl::nullopt; } return 
TypeConversionError(value.GetTypeName(), "double"); } absl::StatusOr> ProtoDoubleRepeatedFieldFromValueMutator( const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, well_known_types::Reflection* absl_nonnull, const google::protobuf::Reflection* absl_nonnull reflection, google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { if (auto double_value = value.AsDouble(); double_value) { reflection->AddDouble(message, field, double_value->NativeValue()); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "double"); } absl::StatusOr> ProtoBytesRepeatedFieldFromValueMutator( const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, well_known_types::Reflection* absl_nonnull, const google::protobuf::Reflection* absl_nonnull reflection, google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { if (auto bytes_value = value.AsBytes(); bytes_value) { reflection->AddString(message, field, bytes_value->NativeString()); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "bytes"); } absl::StatusOr> ProtoStringRepeatedFieldFromValueMutator( const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, well_known_types::Reflection* absl_nonnull, const google::protobuf::Reflection* absl_nonnull reflection, google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { if (auto string_value = value.AsString(); string_value) { reflection->AddString(message, field, string_value->NativeString()); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "string"); } absl::StatusOr> ProtoNullRepeatedFieldFromValueMutator( const google::protobuf::DescriptorPool* absl_nonnull, 
google::protobuf::MessageFactory* absl_nonnull, well_known_types::Reflection* absl_nonnull, const google::protobuf::Reflection* absl_nonnull reflection, google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { if (value.IsNull() || value.IsInt()) { reflection->AddEnumValue(message, field, 0); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "null_type"); } absl::StatusOr> ProtoEnumRepeatedFieldFromValueMutator( const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, well_known_types::Reflection* absl_nonnull, const google::protobuf::Reflection* absl_nonnull reflection, google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { const auto* enum_descriptor = field->enum_type(); if (auto int_value = value.AsInt(); int_value) { if (int_value->NativeValue() < std::numeric_limits::min() || int_value->NativeValue() > std::numeric_limits::max()) { return TypeConversionError(value.GetTypeName(), enum_descriptor->full_name()); } reflection->AddEnumValue(message, field, static_cast(int_value->NativeValue())); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), enum_descriptor->full_name()); } absl::StatusOr> ProtoMessageRepeatedFieldFromValueMutator( const google::protobuf::DescriptorPool* absl_nonnull pool, google::protobuf::MessageFactory* absl_nonnull factory, well_known_types::Reflection* absl_nonnull well_known_types, const google::protobuf::Reflection* absl_nonnull reflection, google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { auto* element = reflection->AddMessage(message, field, factory); auto result = ProtoMessageFromValueImpl(value, pool, factory, well_known_types, element); if (!result.ok() || result->has_value()) { reflection->RemoveLast(message, 
field); } return result; } absl::StatusOr GetProtoRepeatedFieldFromValueMutator( const google::protobuf::FieldDescriptor* absl_nonnull field) { ABSL_DCHECK(!field->is_map()); ABSL_DCHECK(field->is_repeated()); switch (field->cpp_type()) { case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: return ProtoBoolRepeatedFieldFromValueMutator; case google::protobuf::FieldDescriptor::CPPTYPE_INT32: return ProtoInt32RepeatedFieldFromValueMutator; case google::protobuf::FieldDescriptor::CPPTYPE_INT64: return ProtoInt64RepeatedFieldFromValueMutator; case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: return ProtoUInt32RepeatedFieldFromValueMutator; case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: return ProtoUInt64RepeatedFieldFromValueMutator; case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: return ProtoFloatRepeatedFieldFromValueMutator; case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: return ProtoDoubleRepeatedFieldFromValueMutator; case google::protobuf::FieldDescriptor::CPPTYPE_STRING: if (field->type() == google::protobuf::FieldDescriptor::TYPE_BYTES) { return ProtoBytesRepeatedFieldFromValueMutator; } return ProtoStringRepeatedFieldFromValueMutator; case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: if (field->enum_type()->full_name() == "google.protobuf.NullValue") { return ProtoNullRepeatedFieldFromValueMutator; } return ProtoEnumRepeatedFieldFromValueMutator; case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: return ProtoMessageRepeatedFieldFromValueMutator; default: return absl::InvalidArgumentError(absl::StrCat( "unexpected protocol buffer repeated field type: ", google::protobuf::FieldDescriptor::CppTypeName(field->cpp_type()))); } } class MessageValueBuilderImpl { public: MessageValueBuilderImpl( google::protobuf::Arena* absl_nullable arena, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull message) : 
arena_(arena), descriptor_pool_(descriptor_pool), message_factory_(message_factory), message_(message), descriptor_(message_->GetDescriptor()), reflection_(message_->GetReflection()) {} ~MessageValueBuilderImpl() { if (arena_ == nullptr && message_ != nullptr) { delete message_; } } absl::StatusOr> SetFieldByName( absl::string_view name, Value value) { const auto* field = descriptor_->FindFieldByName(name); if (field == nullptr) { field = descriptor_pool_->FindExtensionByPrintableName(descriptor_, name); if (field == nullptr) { return NoSuchFieldError(name); } } return SetField(field, std::move(value)); } absl::StatusOr> SetFieldByNumber(int64_t number, Value value) { if (number < std::numeric_limits::min() || number > std::numeric_limits::max()) { return NoSuchFieldError(absl::StrCat(number)); } const auto* field = descriptor_->FindFieldByNumber(static_cast(number)); if (field == nullptr) { return NoSuchFieldError(absl::StrCat(number)); } return SetField(field, std::move(value)); } absl::StatusOr Build() && { return Value::WrapMessage(std::exchange(message_, nullptr), descriptor_pool_, message_factory_, arena_); } absl::StatusOr BuildStruct() && { return ParsedMessageValue(std::exchange(message_, nullptr), arena_); } private: absl::StatusOr> SetMapField( const google::protobuf::FieldDescriptor* absl_nonnull field, Value value) { auto map_value = value.AsMap(); if (!map_value) { return TypeConversionError(value.GetTypeName(), "map"); } CEL_ASSIGN_OR_RETURN(auto key_converter, GetProtoMapKeyFromValueConverter( field->message_type()->map_key()->cpp_type())); CEL_ASSIGN_OR_RETURN(auto value_converter, GetProtoMapValueFromValueConverter(field)); reflection_->ClearField(message_, field); const auto* map_value_field = field->message_type()->map_value(); absl::optional error_value; // Don't replace this pattern with a status macro; nested macro invocations // have the same __LINE__ on MSVC, causing CEL_ASSIGN_OR_RETURN invocations // to conflict with each-other. 
auto status = map_value->ForEach( [this, field, key_converter, map_value_field, value_converter, &error_value](const Value& entry_key, const Value& entry_value) -> absl::StatusOr { std::string proto_key_string; google::protobuf::MapKey proto_key; CEL_ASSIGN_OR_RETURN( error_value, (*key_converter)(entry_key, proto_key, proto_key_string)); if (error_value) { return false; } google::protobuf::MapValueRef proto_value; extensions::protobuf_internal::InsertOrLookupMapValue( *reflection_, message_, *field, proto_key, &proto_value); CEL_ASSIGN_OR_RETURN( error_value, (*value_converter)(entry_value, map_value_field, descriptor_pool_, message_factory_, &well_known_types_, proto_value)); if (error_value) { return false; } return true; }, descriptor_pool_, message_factory_, arena_); if (!status.ok()) { return status; } return error_value; } absl::StatusOr> SetRepeatedField( const google::protobuf::FieldDescriptor* absl_nonnull field, Value value) { auto list_value = value.AsList(); if (!list_value) { return TypeConversionError(value.GetTypeName(), "list").NativeValue(); } CEL_ASSIGN_OR_RETURN(auto accessor, GetProtoRepeatedFieldFromValueMutator(field)); reflection_->ClearField(message_, field); absl::optional error_value; CEL_RETURN_IF_ERROR(list_value->ForEach( [this, field, accessor, &error_value](const Value& element) -> absl::StatusOr { CEL_ASSIGN_OR_RETURN(error_value, (*accessor)(descriptor_pool_, message_factory_, &well_known_types_, reflection_, message_, field, element)); return !error_value; }, descriptor_pool_, message_factory_, arena_)); return error_value; } absl::StatusOr> SetSingularField( const google::protobuf::FieldDescriptor* absl_nonnull field, Value value) { switch (field->cpp_type()) { case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { if (auto bool_value = value.AsBool(); bool_value) { reflection_->SetBool(message_, field, bool_value->NativeValue()); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "bool"); } case 
google::protobuf::FieldDescriptor::CPPTYPE_INT32: { if (auto int_value = value.AsInt(); int_value) { if (int_value->NativeValue() < std::numeric_limits::min() || int_value->NativeValue() > std::numeric_limits::max()) { return ErrorValue(absl::OutOfRangeError("int64 to int32 overflow")); } reflection_->SetInt32(message_, field, static_cast(int_value->NativeValue())); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "int"); } case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { if (auto int_value = value.AsInt(); int_value) { reflection_->SetInt64(message_, field, int_value->NativeValue()); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "int"); } case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { if (auto uint_value = value.AsUint(); uint_value) { if (uint_value->NativeValue() > std::numeric_limits::max()) { return ErrorValue( absl::OutOfRangeError("uint64 to uint32 overflow")); } reflection_->SetUInt32( message_, field, static_cast(uint_value->NativeValue())); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "uint"); } case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { if (auto uint_value = value.AsUint(); uint_value) { reflection_->SetUInt64(message_, field, uint_value->NativeValue()); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "uint"); } case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { if (auto double_value = value.AsDouble(); double_value) { reflection_->SetFloat(message_, field, double_value->NativeValue()); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "double"); } case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { if (auto double_value = value.AsDouble(); double_value) { reflection_->SetDouble(message_, field, double_value->NativeValue()); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "double"); } case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { if (field->type() == 
google::protobuf::FieldDescriptor::TYPE_BYTES) { if (auto bytes_value = value.AsBytes(); bytes_value) { bytes_value->NativeValue(absl::Overload( [this, field](absl::string_view string) { reflection_->SetString(message_, field, std::string(string)); }, [this, field](const absl::Cord& cord) { reflection_->SetString(message_, field, cord); })); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "bytes"); } if (auto string_value = value.AsString(); string_value) { string_value->NativeValue(absl::Overload( [this, field](absl::string_view string) { reflection_->SetString(message_, field, std::string(string)); }, [this, field](const absl::Cord& cord) { reflection_->SetString(message_, field, cord); })); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "string"); } case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { if (field->enum_type()->full_name() == "google.protobuf.NullValue") { if (value.IsNull() || value.IsInt()) { reflection_->SetEnumValue(message_, field, 0); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), "null_type"); } if (auto int_value = value.AsInt(); int_value) { if (int_value->NativeValue() >= std::numeric_limits::min() && int_value->NativeValue() <= std::numeric_limits::max()) { reflection_->SetEnumValue( message_, field, static_cast(int_value->NativeValue())); return absl::nullopt; } } return TypeConversionError(value.GetTypeName(), field->enum_type()->full_name()); } case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { switch (field->message_type()->well_known_type()) { case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: { if (value.IsNull()) { // Allowing assigning `null` to message fields. 
return absl::nullopt; } if (auto bool_value = value.AsBool(); bool_value) { CEL_RETURN_IF_ERROR(well_known_types_.BoolValue().Initialize( field->message_type())); well_known_types_.BoolValue().SetValue( reflection_->MutableMessage(message_, field, message_factory_), bool_value->NativeValue()); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), field->message_type()->full_name()); } case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: { if (value.IsNull()) { // Allowing assigning `null` to message fields. return absl::nullopt; } if (auto int_value = value.AsInt(); int_value) { if (int_value->NativeValue() < std::numeric_limits::min() || int_value->NativeValue() > std::numeric_limits::max()) { return absl::OutOfRangeError("int64 to int32 overflow"); } CEL_RETURN_IF_ERROR(well_known_types_.Int32Value().Initialize( field->message_type())); well_known_types_.Int32Value().SetValue( reflection_->MutableMessage(message_, field, message_factory_), static_cast(int_value->NativeValue())); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), field->message_type()->full_name()); } case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: { if (value.IsNull()) { // Allowing assigning `null` to message fields. return absl::nullopt; } if (auto int_value = value.AsInt(); int_value) { CEL_RETURN_IF_ERROR(well_known_types_.Int64Value().Initialize( field->message_type())); well_known_types_.Int64Value().SetValue( reflection_->MutableMessage(message_, field, message_factory_), int_value->NativeValue()); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), field->message_type()->full_name()); } case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: { if (value.IsNull()) { // Allowing assigning `null` to message fields. 
return absl::nullopt; } if (auto uint_value = value.AsUint(); uint_value) { if (uint_value->NativeValue() > std::numeric_limits::max()) { return absl::OutOfRangeError("uint64 to uint32 overflow"); } CEL_RETURN_IF_ERROR(well_known_types_.UInt32Value().Initialize( field->message_type())); well_known_types_.UInt32Value().SetValue( reflection_->MutableMessage(message_, field, message_factory_), static_cast(uint_value->NativeValue())); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), field->message_type()->full_name()); } case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: { if (value.IsNull()) { // Allowing assigning `null` to message fields. return absl::nullopt; } if (auto uint_value = value.AsUint(); uint_value) { CEL_RETURN_IF_ERROR(well_known_types_.UInt64Value().Initialize( field->message_type())); well_known_types_.UInt64Value().SetValue( reflection_->MutableMessage(message_, field, message_factory_), uint_value->NativeValue()); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), field->message_type()->full_name()); } case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: { if (value.IsNull()) { // Allowing assigning `null` to message fields. return absl::nullopt; } if (auto double_value = value.AsDouble(); double_value) { CEL_RETURN_IF_ERROR(well_known_types_.FloatValue().Initialize( field->message_type())); well_known_types_.FloatValue().SetValue( reflection_->MutableMessage(message_, field, message_factory_), static_cast(double_value->NativeValue())); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), field->message_type()->full_name()); } case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: { if (value.IsNull()) { // Allowing assigning `null` to message fields. 
return absl::nullopt; } if (auto double_value = value.AsDouble(); double_value) { CEL_RETURN_IF_ERROR(well_known_types_.DoubleValue().Initialize( field->message_type())); well_known_types_.DoubleValue().SetValue( reflection_->MutableMessage(message_, field, message_factory_), double_value->NativeValue()); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), field->message_type()->full_name()); } case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: { if (value.IsNull()) { // Allowing assigning `null` to message fields. return absl::nullopt; } if (auto bytes_value = value.AsBytes(); bytes_value) { CEL_RETURN_IF_ERROR(well_known_types_.BytesValue().Initialize( field->message_type())); well_known_types_.BytesValue().SetValue( reflection_->MutableMessage(message_, field, message_factory_), bytes_value->NativeCord()); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), field->message_type()->full_name()); } case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: { if (value.IsNull()) { // Allowing assigning `null` to message fields. return absl::nullopt; } if (auto string_value = value.AsString(); string_value) { CEL_RETURN_IF_ERROR(well_known_types_.StringValue().Initialize( field->message_type())); well_known_types_.StringValue().SetValue( reflection_->MutableMessage(message_, field, message_factory_), string_value->NativeCord()); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), field->message_type()->full_name()); } case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION: { if (value.IsNull()) { // Allowing assigning `null` to message fields. 
return absl::nullopt; } if (auto duration_value = value.AsDuration(); duration_value) { CEL_RETURN_IF_ERROR(well_known_types_.Duration().Initialize( field->message_type())); CEL_RETURN_IF_ERROR( well_known_types_.Duration().SetFromAbslDuration( reflection_->MutableMessage(message_, field, message_factory_), duration_value->NativeValue())); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), field->message_type()->full_name()); } case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP: { if (value.IsNull()) { // Allowing assigning `null` to message fields. return absl::nullopt; } if (auto timestamp_value = value.AsTimestamp(); timestamp_value) { CEL_RETURN_IF_ERROR(well_known_types_.Timestamp().Initialize( field->message_type())); CEL_RETURN_IF_ERROR(well_known_types_.Timestamp().SetFromAbslTime( reflection_->MutableMessage(message_, field, message_factory_), timestamp_value->NativeValue())); return absl::nullopt; } return TypeConversionError(value.GetTypeName(), field->message_type()->full_name()); } case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: { CEL_RETURN_IF_ERROR( value.ConvertToJson(descriptor_pool_, message_factory_, reflection_->MutableMessage( message_, field, message_factory_))); return absl::nullopt; } case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: { CEL_RETURN_IF_ERROR(value.ConvertToJsonArray( descriptor_pool_, message_factory_, reflection_->MutableMessage(message_, field, message_factory_))); return absl::nullopt; } case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: { CEL_RETURN_IF_ERROR(value.ConvertToJsonObject( descriptor_pool_, message_factory_, reflection_->MutableMessage(message_, field, message_factory_))); return absl::nullopt; } case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: { // Probably not correct, need to use the parent/common one. 
google::protobuf::io::CordOutputStream serialized; CEL_RETURN_IF_ERROR(value.SerializeTo( descriptor_pool_, message_factory_, &serialized)); std::string type_url; switch (value.kind()) { case ValueKind::kNull: type_url = MakeTypeUrl("google.protobuf.Value"); break; case ValueKind::kBool: type_url = MakeTypeUrl("google.protobuf.BoolValue"); break; case ValueKind::kInt: type_url = MakeTypeUrl("google.protobuf.Int64Value"); break; case ValueKind::kUint: type_url = MakeTypeUrl("google.protobuf.UInt64Value"); break; case ValueKind::kDouble: type_url = MakeTypeUrl("google.protobuf.DoubleValue"); break; case ValueKind::kBytes: type_url = MakeTypeUrl("google.protobuf.BytesValue"); break; case ValueKind::kString: type_url = MakeTypeUrl("google.protobuf.StringValue"); break; case ValueKind::kList: type_url = MakeTypeUrl("google.protobuf.ListValue"); break; case ValueKind::kMap: type_url = MakeTypeUrl("google.protobuf.Struct"); break; case ValueKind::kDuration: type_url = MakeTypeUrl("google.protobuf.Duration"); break; case ValueKind::kTimestamp: type_url = MakeTypeUrl("google.protobuf.Timestamp"); break; default: type_url = MakeTypeUrl(value.GetTypeName()); break; } CEL_RETURN_IF_ERROR( well_known_types_.Any().Initialize(field->message_type())); well_known_types_.Any().SetTypeUrl( reflection_->MutableMessage(message_, field, message_factory_), type_url); well_known_types_.Any().SetValue( reflection_->MutableMessage(message_, field, message_factory_), std::move(serialized).Consume()); return absl::nullopt; } default: if (value.IsNull()) { // Allowing assigning `null` to message fields. 
return absl::nullopt; } break; } return ProtoMessageFromValueImpl( value, descriptor_pool_, message_factory_, &well_known_types_, reflection_->MutableMessage(message_, field, message_factory_)); } default: return absl::InternalError( absl::StrCat("unexpected protocol buffer message field type: ", field->cpp_type_name())); } } absl::StatusOr> SetField( const google::protobuf::FieldDescriptor* absl_nonnull field, Value value) { if (field->is_map()) { return SetMapField(field, std::move(value)); } if (field->is_repeated()) { return SetRepeatedField(field, std::move(value)); } return SetSingularField(field, std::move(value)); } google::protobuf::Arena* absl_nullable const arena_; const google::protobuf::DescriptorPool* absl_nonnull const descriptor_pool_; google::protobuf::MessageFactory* absl_nonnull const message_factory_; google::protobuf::Message* absl_nullable message_; const google::protobuf::Descriptor* absl_nonnull const descriptor_; const google::protobuf::Reflection* absl_nonnull const reflection_; well_known_types::Reflection well_known_types_; }; class ValueBuilderImpl final : public ValueBuilder { public: ValueBuilderImpl(google::protobuf::Arena* absl_nullable arena, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull message) : builder_(arena, descriptor_pool, message_factory, message) {} absl::StatusOr> SetFieldByName( absl::string_view name, Value value) override { return builder_.SetFieldByName(name, std::move(value)); } absl::StatusOr> SetFieldByNumber( int64_t number, Value value) override { return builder_.SetFieldByNumber(number, std::move(value)); } absl::StatusOr Build() && override { return std::move(builder_).Build(); } private: MessageValueBuilderImpl builder_; }; class StructValueBuilderImpl final : public StructValueBuilder { public: StructValueBuilderImpl( google::protobuf::Arena* absl_nullable arena, const 
google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull message) : builder_(arena, descriptor_pool, message_factory, message) {} absl::StatusOr> SetFieldByName( absl::string_view name, Value value) override { return builder_.SetFieldByName(name, std::move(value)); } absl::StatusOr> SetFieldByNumber( int64_t number, Value value) override { return builder_.SetFieldByNumber(number, std::move(value)); } absl::StatusOr Build() && override { return std::move(builder_).BuildStruct(); } private: MessageValueBuilderImpl builder_; }; } // namespace absl_nullable cel::ValueBuilderPtr NewValueBuilder( Allocator<> allocator, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, absl::string_view name) { const google::protobuf::Descriptor* absl_nullable descriptor = descriptor_pool->FindMessageTypeByName(name); if (descriptor == nullptr) { return nullptr; } const google::protobuf::Message* absl_nullable prototype = message_factory->GetPrototype(descriptor); ABSL_DCHECK(prototype != nullptr) << "failed to get message prototype from factory, did you pass a dynamic " "descriptor to the generated message factory? 
we consider this to be " "a logic error and not a runtime error: " << descriptor->full_name(); if (ABSL_PREDICT_FALSE(prototype == nullptr)) { return nullptr; } return std::make_unique(allocator.arena(), descriptor_pool, message_factory, prototype->New(allocator.arena())); } absl_nullable cel::StructValueBuilderPtr NewStructValueBuilder( Allocator<> allocator, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, absl::string_view name) { const google::protobuf::Descriptor* absl_nullable descriptor = descriptor_pool->FindMessageTypeByName(name); if (descriptor == nullptr) { return nullptr; } const google::protobuf::Message* absl_nullable prototype = message_factory->GetPrototype(descriptor); ABSL_DCHECK(prototype != nullptr) << "failed to get message prototype from factory, did you pass a dynamic " "descriptor to the generated message factory? we consider this to be " "a logic error and not a runtime error: " << descriptor->full_name(); if (ABSL_PREDICT_FALSE(prototype == nullptr)) { return nullptr; } return std::make_unique( allocator.arena(), descriptor_pool, message_factory, prototype->New(allocator.arena())); } } // namespace cel::common_internal ================================================ FILE: common/values/struct_value_builder.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_BUILDER_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_BUILDER_H_

#include "absl/base/nullability.h"
#include "absl/strings/string_view.h"
#include "common/allocator.h"
#include "common/value.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::common_internal {

// Creates a builder for the protobuf message type named `name`.
//
// `name` is resolved with `descriptor_pool->FindMessageTypeByName()`; when no
// such message type exists, nullptr is returned. Otherwise the prototype is
// obtained from `message_factory` and a new message instance is created from
// it on `allocator.arena()`; the returned builder populates that message.
//
// A factory that cannot produce a prototype for the descriptor (for example,
// the generated message factory given a dynamic descriptor) is treated as a
// logic error: it fails a debug check, and nullptr is returned in builds where
// debug checks are disabled.
absl_nullable cel::StructValueBuilderPtr NewStructValueBuilder(
    Allocator<> allocator,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    absl::string_view name);

}  // namespace cel::common_internal

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_BUILDER_H_
================================================
FILE: common/values/struct_value_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/base/attributes.h"
#include "common/value.h"
#include "internal/parse_text_proto.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "internal/testing_message_factory.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
#include "google/protobuf/arena.h"

// Tests for the alternative accessors of StructValue: Is(), As() and Get().

namespace cel {
namespace {

using ::cel::internal::DynamicParseTextProto;
using ::cel::internal::GetTestingDescriptorPool;
using ::cel::internal::GetTestingMessageFactory;
using ::testing::An;
using ::testing::Optional;

using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes;

// A StructValue built from a default-constructed ParsedMessageValue answers
// Is() queries for the alternative it holds.
TEST(StructValue, Is) {
  EXPECT_TRUE(StructValue(ParsedMessageValue()).Is());
  EXPECT_TRUE(StructValue(ParsedMessageValue()).Is());
}

// Helpers that hand back `t` in a chosen value category (lvalue, const
// lvalue, rvalue, const rvalue) so that each ref-qualified overload of the
// accessor under test can be selected explicitly.
template constexpr T& AsLValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return t;
}

template constexpr const T& AsConstLValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return t;
}

template constexpr T&& AsRValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return static_cast(t);
}

template constexpr const T&& AsConstRValueRef(T& t ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return static_cast(t);
}

// Each scope wraps an empty, dynamically parsed message in a StructValue and
// checks that As() yields an engaged optional through every value category.
// `other_value` is a copy of `value` used for the const-rvalue case,
// presumably so that case does not observe `value` after the rvalue access.
TEST(StructValue, As) {
  google::protobuf::Arena arena;

  {
    StructValue value(ParsedMessageValue{
        DynamicParseTextProto(&arena, R"pb()pb",
                              GetTestingDescriptorPool(),
                              GetTestingMessageFactory()),
        &arena});
    StructValue other_value = value;
    EXPECT_THAT(AsLValueRef(value).As(),
                Optional(An()));
    EXPECT_THAT(AsConstLValueRef(value).As(),
                Optional(An()));
    EXPECT_THAT(AsRValueRef(value).As(),
                Optional(An()));
    EXPECT_THAT(AsConstRValueRef(other_value).As(),
                Optional(An()));
  }
  {
    StructValue value(ParsedMessageValue{
        DynamicParseTextProto(&arena, R"pb()pb",
                              GetTestingDescriptorPool(),
                              GetTestingMessageFactory()),
        &arena});
    StructValue other_value = value;
    EXPECT_THAT(AsLValueRef(value).As(),
                Optional(An()));
    EXPECT_THAT(AsConstLValueRef(value).As(),
                Optional(An()));
    EXPECT_THAT(AsRValueRef(value).As(),
                Optional(An()));
    EXPECT_THAT(
        AsConstRValueRef(other_value).As(),
        Optional(An()));
  }
}

// Forwards `from` with its original value category and invokes the templated
// Get() member on it.
template decltype(auto) DoGet(From&& from) {
  return std::forward(from).template Get();
}

// Same structure as the As test, but for Get(), which returns the
// alternative directly instead of an optional.
TEST(StructValue, Get) {
  google::protobuf::Arena arena;

  {
    StructValue value(ParsedMessageValue{
        DynamicParseTextProto(&arena, R"pb()pb",
                              GetTestingDescriptorPool(),
                              GetTestingMessageFactory()),
        &arena});
    StructValue other_value = value;
    EXPECT_THAT(DoGet(AsLValueRef(value)),
                An());
    EXPECT_THAT(DoGet(AsConstLValueRef(value)),
                An());
    EXPECT_THAT(DoGet(AsRValueRef(value)),
                An());
    EXPECT_THAT(DoGet(AsConstRValueRef(other_value)),
                An());
  }
  {
    StructValue value(ParsedMessageValue{
        DynamicParseTextProto(&arena, R"pb()pb",
                              GetTestingDescriptorPool(),
                              GetTestingMessageFactory()),
        &arena});
    StructValue other_value = value;
    EXPECT_THAT(DoGet(AsLValueRef(value)),
                An());
    EXPECT_THAT(DoGet(AsConstLValueRef(value)),
                An());
    EXPECT_THAT(DoGet(AsRValueRef(value)),
                An());
    EXPECT_THAT(
        DoGet(AsConstRValueRef(other_value)),
        An());
  }
}

}  // namespace
}  // namespace cel
================================================
FILE: common/values/struct_value_variant.h
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_VARIANT_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_VARIANT_H_

#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/meta/type_traits.h"
#include "absl/utility/utility.h"
#include "common/values/custom_struct_value.h"
#include "common/values/legacy_struct_value.h"
#include "common/values/parsed_message_value.h"

namespace cel::common_internal {

// Tag identifying which alternative a StructValueVariant currently holds.
enum class StructValueIndex : uint16_t {
  kParsedMessage = 0,
  kCustom,
  kLegacy,
};

// Maps an alternative type to its StructValueIndex through a static `kIndex`
// member. The primary template is only declared; the three specializations
// below define the supported alternatives.
template struct StructValueAlternative;

template <> struct StructValueAlternative {
  static constexpr StructValueIndex kIndex = StructValueIndex::kCustom;
};

template <> struct StructValueAlternative {
  static constexpr StructValueIndex kIndex = StructValueIndex::kParsedMessage;
};

template <> struct StructValueAlternative {
  static constexpr StructValueIndex kIndex = StructValueIndex::kLegacy;
};

// Trait that is false by default and true when the std::void_t probe in the
// partial specialization is well-formed for `T` (presumably: when
// StructValueAlternative is specialized for `T` -- confirm).
template struct IsStructValueAlternative : std::false_type {};

template struct IsStructValueAlternative<
    T, std::void_t{})>> : std::true_type {};

template inline constexpr bool IsStructValueAlternativeV =
    IsStructValueAlternative::value;

// Alignment and size, in bytes, of the inline storage that every alternative
// must fit into (enforced by the static_asserts below).
inline constexpr size_t kStructValueVariantAlign = 8;
inline constexpr size_t kStructValueVariantSize = 24;

// StructValueVariant is a subset of alternatives from the main ValueVariant
// that is only structs. It is not stored directly in ValueVariant.
//
// Representation: an `index_` tag plus an untyped byte buffer `raw_` in which
// the active alternative is placement-constructed. Every alternative is
// statically required to fit `raw_` and to be trivially copyable. That
// requirement is what makes the defaulted copy/move operations, the bytewise
// `swap`, and Assign() overwriting the old alternative without destroying it
// valid.
class alignas(kStructValueVariantAlign) StructValueVariant final {
 public:
  StructValueVariant()
      : StructValueVariant(absl::in_place_type) {}

  StructValueVariant(const StructValueVariant&) = default;
  StructValueVariant(StructValueVariant&&) = default;
  StructValueVariant& operator=(const StructValueVariant&) = default;
  StructValueVariant& operator=(StructValueVariant&&) = default;

  // Constructs alternative `T` in place in `raw_` from `args`.
  template explicit StructValueVariant(absl::in_place_type_t, Args&&... args)
      : index_(StructValueAlternative::kIndex) {
    static_assert(alignof(T) <= kStructValueVariantAlign);
    static_assert(sizeof(T) <= kStructValueVariantSize);
    static_assert(std::is_trivially_copyable_v);

    ::new (static_cast(&raw_[0])) T(std::forward(args)...);
  }

  // Converting constructor: delegates to the in-place constructor above,
  // forwarding `value`.
  template >>> explicit StructValueVariant(T&& value)
      : StructValueVariant(absl::in_place_type>,
                           std::forward(value)) {}

  // Replaces the held alternative with `value`. The previous alternative is
  // not destroyed; the trivially-copyable requirement makes that safe.
  template void Assign(T&& value) {
    using U = absl::remove_cvref_t;

    static_assert(alignof(U) <= kStructValueVariantAlign);
    static_assert(sizeof(U) <= kStructValueVariantSize);
    static_assert(std::is_trivially_copyable_v);

    index_ = StructValueAlternative::kIndex;
    ::new (static_cast(&raw_[0])) U(std::forward(value));
  }

  // Returns true if the held alternative is `T`.
  template bool Is() const {
    return index_ == StructValueAlternative::kIndex;
  }

  // Accessors for the held alternative. Holding `T` is a precondition that is
  // checked only in debug builds (ABSL_DCHECK).
  template T& Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(Is());

    return *At();
  }

  template const T& Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(Is());

    return *At();
  }

  template T&& Get() && ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(Is());

    return std::move(*At());
  }

  template const T&& Get() const&& ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(Is());

    return std::move(*At());
  }

  // Returns a pointer to the held alternative if it is `T`, otherwise nullptr.
  template T* absl_nullable As() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    if (Is()) {
      return At();
    }
    return nullptr;
  }

  template const T* absl_nullable As() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    if (Is()) {
      return At();
    }
    return nullptr;
  }

  // Calls `visitor` with the held alternative and returns its result.
  // NOTE(review): nothing follows the switch. Every StructValueIndex
  // enumerator is handled, but an out-of-range `index_` would flow off the end
  // of the function (undefined behavior), and some compilers may warn about
  // the missing return.
  template decltype(auto) Visit(Visitor&& visitor) const {
    switch (index_) {
      case StructValueIndex::kCustom:
        return std::forward(visitor)(Get());
      case StructValueIndex::kParsedMessage:
        return std::forward(visitor)(Get());
      case StructValueIndex::kLegacy:
        return std::forward(visitor)(Get());
    }
  }

  // Swaps the tags and the raw storage bytes; valid because every
  // alternative is trivially copyable.
  friend void swap(StructValueVariant& lhs, StructValueVariant& rhs) noexcept {
    using std::swap;
    swap(lhs.index_, rhs.index_);
    swap(lhs.raw_, rhs.raw_);
  }

 private:
  // Views `raw_` as `T` without consulting `index_`. std::launder is used to
  // obtain a pointer to the object that was placement-constructed in `raw_`.
  template ABSL_ATTRIBUTE_ALWAYS_INLINE T* absl_nonnull At() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    static_assert(alignof(T) <= kStructValueVariantAlign);
    static_assert(sizeof(T) <= kStructValueVariantSize);
    static_assert(std::is_trivially_copyable_v);

    return std::launder(reinterpret_cast(&raw_[0]));
  }

  template ABSL_ATTRIBUTE_ALWAYS_INLINE const T* absl_nonnull At() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    static_assert(alignof(T) <= kStructValueVariantAlign);
    static_assert(sizeof(T) <= kStructValueVariantSize);
    static_assert(std::is_trivially_copyable_v);

    return std::launder(reinterpret_cast(&raw_[0]));
  }

  // Which alternative currently lives in `raw_`.
  StructValueIndex index_ = StructValueIndex::kCustom;
  // Untyped storage for the active alternative.
  alignas(8) std::byte raw_[kStructValueVariantSize];
};

}  // namespace cel::common_internal

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_VARIANT_H_
================================================
FILE: common/values/timestamp_value.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include "google/protobuf/timestamp.pb.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/time/time.h" #include "common/value.h" #include "internal/status_macros.h" #include "internal/time.h" #include "internal/well_known_types.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { namespace { using ::cel::well_known_types::TimestampReflection; using ::cel::well_known_types::ValueReflection; std::string TimestampDebugString(absl::Time value) { return internal::DebugStringTimestamp(value); } } // namespace std::string TimestampValue::DebugString() const { return TimestampDebugString(NativeValue()); } absl::Status TimestampValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); google::protobuf::Timestamp message; CEL_RETURN_IF_ERROR( TimestampReflection::SetFromAbslTime(&message, NativeValue())); if (!message.SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( absl::StrCat("failed to serialize message: ", message.GetTypeName())); } return absl::OkStatus(); } absl::Status TimestampValue::ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); ValueReflection value_reflection; 
CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); value_reflection.SetStringValueFromTimestamp(json, NativeValue()); return absl::OkStatus(); } absl::Status TimestampValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (auto other_value = other.AsTimestamp(); other_value.has_value()) { *result = BoolValue{NativeValue() == other_value->NativeValue()}; return absl::OkStatus(); } *result = FalseValue(); return absl::OkStatus(); } } // namespace cel ================================================ FILE: common/values/timestamp_value.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// IWYU pragma: private, include "common/value.h"
// IWYU pragma: friend "common/value.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_TIMESTAMP_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_TIMESTAMP_VALUE_H_

#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/utility/utility.h"
#include "common/type.h"
#include "common/value_kind.h"
#include "common/values/values.h"
#include "internal/time.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

class Value;
class TimestampValue;

// Wraps `value` without validating it.
TimestampValue UnsafeTimestampValue(absl::Time value);

// Validates `value` with internal::ValidateTimestamp(). On failure, the
// validation status is returned instead of a TimestampValue.
absl::StatusOr SafeTimestampValue(absl::Time value);

// `TimestampValue` represents values of the primitive `timestamp` type.
//
// A default-constructed instance holds absl::UnixEpoch(), which is also the
// value that IsZeroValue() tests for.
class TimestampValue final : private common_internal::ValueMixin {
 public:
  static constexpr ValueKind kKind = ValueKind::kTimestamp;

  // `value` is validated only in debug builds (ABSL_DCHECK_OK). Prefer
  // SafeTimestampValue() when `value` may fail validation.
  explicit TimestampValue(absl::Time value) noexcept
      : TimestampValue(absl::in_place, value) {
    ABSL_DCHECK_OK(internal::ValidateTimestamp(value));
  }

  TimestampValue() = default;
  TimestampValue(const TimestampValue&) = default;
  TimestampValue(TimestampValue&&) = default;
  TimestampValue& operator=(const TimestampValue&) = default;
  TimestampValue& operator=(TimestampValue&&) = default;

  ValueKind kind() const { return kKind; }

  absl::string_view GetTypeName() const { return TimestampType::kName; }

  std::string DebugString() const;

  // See Value::SerializeTo().
  absl::Status SerializeTo(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const;

  // See Value::ConvertToJson().
  absl::Status ConvertToJson(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  // Sets `*result` to a BoolValue indicating whether `other` is equal to this
  // value.
  absl::Status Equal(const Value& other,
                     const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                     google::protobuf::MessageFactory* absl_nonnull message_factory,
                     google::protobuf::Arena* absl_nonnull arena,
                     Value* absl_nonnull result) const;
  using ValueMixin::Equal;

  bool IsZeroValue() const { return ToTime() == absl::UnixEpoch(); }

  ABSL_DEPRECATED("Use ToTime()")
  absl::Time NativeValue() const { return static_cast(*this); }

  ABSL_DEPRECATED("Use ToTime()")
  // NOLINTNEXTLINE(google-explicit-constructor)
  operator absl::Time() const noexcept { return value_; }

  // Returns the wrapped instant. This is the non-deprecated accessor.
  absl::Time ToTime() const { return value_; }

  friend void swap(TimestampValue& lhs, TimestampValue& rhs) noexcept {
    using std::swap;
    swap(lhs.value_, rhs.value_);
  }

  friend bool operator==(TimestampValue lhs, TimestampValue rhs) {
    return lhs.value_ == rhs.value_;
  }

  // Orders timestamps by their wrapped instant.
  friend bool operator<(const TimestampValue& lhs,
                        const TimestampValue& rhs) {
    return lhs.value_ < rhs.value_;
  }

 private:
  friend class common_internal::ValueMixin;
  friend TimestampValue UnsafeTimestampValue(absl::Time value);

  // Unchecked constructor. Used by the public constructor, which adds a
  // debug-only validation, and directly by UnsafeTimestampValue().
  TimestampValue(absl::in_place_t, absl::Time value) : value_(value) {}

  absl::Time value_ = absl::UnixEpoch();
};

inline TimestampValue UnsafeTimestampValue(absl::Time value) {
  return TimestampValue(absl::in_place, value);
}

inline absl::StatusOr SafeTimestampValue(absl::Time value) {
  absl::Status status = internal::ValidateTimestamp(value);
  if (!status.ok()) {
    return status;
  }
  return UnsafeTimestampValue(value);
}

inline bool operator!=(TimestampValue lhs, TimestampValue rhs) {
  return !operator==(lhs, rhs);
}

inline std::ostream& operator<<(std::ostream& out, TimestampValue value) {
  return out << value.DebugString();
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_TIMESTAMP_VALUE_H_
================================================
FILE: common/values/timestamp_value_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include

#include "absl/status/status_matchers.h"
#include "absl/time/time.h"
#include "common/native_type.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "internal/testing.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;

// The tests below call descriptor_pool(), message_factory() and
// NewArenaValueMessage(), presumably provided by this fixture.
using TimestampValueTest = common_internal::ValueTest<>;

// kind() reports kTimestamp both directly and through a wrapping Value.
TEST_F(TimestampValueTest, Kind) {
  EXPECT_EQ(TimestampValue().kind(), TimestampValue::kKind);
  EXPECT_EQ(Value(TimestampValue(absl::UnixEpoch() + absl::Seconds(1))).kind(),
            TimestampValue::kKind);
}

// Streaming renders the instant as UTC text such as 1970-01-01T00:00:01Z,
// both for a bare TimestampValue and for one wrapped in Value.
TEST_F(TimestampValueTest, DebugString) {
  {
    std::ostringstream out;
    out << TimestampValue(absl::UnixEpoch() + absl::Seconds(1));
    EXPECT_EQ(out.str(), "1970-01-01T00:00:01Z");
  }
  {
    std::ostringstream out;
    out << Value(TimestampValue(absl::UnixEpoch() + absl::Seconds(1)));
    EXPECT_EQ(out.str(), "1970-01-01T00:00:01Z");
  }
}

// A default (epoch) timestamp converts to a JSON string value.
TEST_F(TimestampValueTest, ConvertToJson) {
  auto* message = NewArenaValueMessage();
  EXPECT_THAT(TimestampValue().ConvertToJson(descriptor_pool(),
                                             message_factory(), message),
              IsOk());
  EXPECT_THAT(*message, EqualsValueTextProto(
                            R"pb(string_value: "1970-01-01T00:00:00Z")pb"));
}

// NativeTypeId::Of() yields the expected id both for a bare TimestampValue
// and for one wrapped in Value.
TEST_F(TimestampValueTest, NativeTypeId) {
  EXPECT_EQ(
      NativeTypeId::Of(TimestampValue(absl::UnixEpoch() + absl::Seconds(1))),
      NativeTypeId::For());
  EXPECT_EQ(NativeTypeId::Of(
                Value(TimestampValue(absl::UnixEpoch() + absl::Seconds(1)))),
            NativeTypeId::For());
}

// Inequality holds for differing instants, including comparisons that mix a
// TimestampValue with a plain absl::Time operand.
TEST_F(TimestampValueTest, Equality) {
  EXPECT_NE(TimestampValue(absl::UnixEpoch()),
            absl::UnixEpoch() + absl::Seconds(1));
  EXPECT_NE(absl::UnixEpoch() + absl::Seconds(1),
            TimestampValue(absl::UnixEpoch()));
  EXPECT_NE(TimestampValue(absl::UnixEpoch()),
            TimestampValue(absl::UnixEpoch() + absl::Seconds(1)));
}

// operator< is true only when the left instant is strictly earlier.
TEST_F(TimestampValueTest, Comparison) {
  EXPECT_LT(TimestampValue(absl::UnixEpoch()),
            TimestampValue(absl::UnixEpoch() + absl::Seconds(1)));
  EXPECT_FALSE(TimestampValue(absl::UnixEpoch() + absl::Seconds(1)) <
               TimestampValue(absl::UnixEpoch() + absl::Seconds(1)));
  EXPECT_FALSE(TimestampValue(absl::UnixEpoch() + absl::Seconds(2)) <
               TimestampValue(absl::UnixEpoch() + absl::Seconds(1)));
}

}  // namespace
}  // namespace cel
================================================
FILE: common/values/type_value.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "common/type.h"
#include "common/value.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

// Type values have no protobuf wire representation, so serialization always
// fails with FAILED_PRECONDITION.
absl::Status TypeValue::SerializeTo(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(output != nullptr);

  return absl::FailedPreconditionError(
      absl::StrCat(GetTypeName(), " is unserializable"));
}

// Type values have no JSON representation, so conversion always fails with
// FAILED_PRECONDITION.
absl::Status TypeValue::ConvertToJson(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(json != nullptr);
  ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
                 google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE);

  return absl::FailedPreconditionError(
      absl::StrCat(GetTypeName(), " is not convertable to JSON"));
}

// A type value equals another type value whose Type compares equal.
// Comparing against any other kind yields false rather than an error.
// type() is used instead of the deprecated NativeValue(), which returns the
// same reference.
absl::Status TypeValue::Equal(
    const Value& other,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);

  if (auto other_value = other.AsType(); other_value.has_value()) {
    *result = BoolValue{type() == other_value->type()};
    return absl::OkStatus();
  }
  *result = FalseValue();
  return absl::OkStatus();
}

}  // namespace cel
================================================
FILE: common/values/type_value.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private, include "common/value.h"
// IWYU pragma: friend "common/value.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_TYPE_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_TYPE_VALUE_H_

#include <ostream>
#include <string>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "common/type.h"
#include "common/value_kind.h"
#include "common/values/values.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

class Value;
class TypeValue;

// `TypeValue` represents values of the primitive `type` type.
class TypeValue final : private common_internal::ValueMixin { public: static constexpr ValueKind kKind = ValueKind::kType; explicit TypeValue(Type value) : value_(value) {} TypeValue() = default; TypeValue(const TypeValue&) = default; TypeValue(TypeValue&&) = default; TypeValue& operator=(const TypeValue&) = default; TypeValue& operator=(TypeValue&&) = default; static constexpr ValueKind kind() { return kKind; } static absl::string_view GetTypeName() { return TypeType::kName; } std::string DebugString() const { return type().DebugString(); } // See Value::SerializeTo(). absl::Status SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const; // See Value::ConvertToJson(). absl::Status ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const; absl::Status Equal(const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using ValueMixin::Equal; bool IsZeroValue() const { return false; } ABSL_DEPRECATED(("Use type()")) const Type& NativeValue() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return type(); } const Type& type() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return value_; } absl::string_view name() const { return type().name(); } friend void swap(TypeValue& lhs, TypeValue& rhs) noexcept { using std::swap; swap(lhs.value_, rhs.value_); } private: friend class common_internal::ValueMixin; Type value_; }; inline std::ostream& operator<<(std::ostream& out, const TypeValue& value) { return out << value.DebugString(); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_TYPE_VALUE_H_ 
================================================ FILE: common/values/type_value_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/status/status.h" #include "common/native_type.h" #include "common/type.h" #include "common/value.h" #include "common/value_testing.h" #include "internal/testing.h" #include "google/protobuf/io/zero_copy_stream_impl_lite.h" namespace cel { namespace { using ::absl_testing::StatusIs; using TypeValueTest = common_internal::ValueTest<>; TEST_F(TypeValueTest, Kind) { EXPECT_EQ(TypeValue(AnyType()).kind(), TypeValue::kKind); EXPECT_EQ(Value(TypeValue(AnyType())).kind(), TypeValue::kKind); } TEST_F(TypeValueTest, DebugString) { { std::ostringstream out; out << TypeValue(AnyType()); EXPECT_EQ(out.str(), "google.protobuf.Any"); } { std::ostringstream out; out << Value(TypeValue(AnyType())); EXPECT_EQ(out.str(), "google.protobuf.Any"); } } TEST_F(TypeValueTest, SerializeTo) { google::protobuf::io::CordOutputStream output; EXPECT_THAT(TypeValue(AnyType()).SerializeTo(descriptor_pool(), message_factory(), &output), StatusIs(absl::StatusCode::kFailedPrecondition)); } TEST_F(TypeValueTest, ConvertToJson) { auto* message = NewArenaValueMessage(); EXPECT_THAT(TypeValue(AnyType()).ConvertToJson(descriptor_pool(), message_factory(), message), StatusIs(absl::StatusCode::kFailedPrecondition)); } TEST_F(TypeValueTest, NativeTypeId) { 
EXPECT_EQ(NativeTypeId::Of(TypeValue(AnyType())), NativeTypeId::For()); EXPECT_EQ(NativeTypeId::Of(Value(TypeValue(AnyType()))), NativeTypeId::For()); } } // namespace } // namespace cel ================================================ FILE: common/values/uint_value.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "common/value.h" #include "internal/number.h" #include "internal/status_macros.h" #include "internal/well_known_types.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/message.h" namespace cel { namespace { using ::cel::well_known_types::ValueReflection; std::string UintDebugString(int64_t value) { return absl::StrCat(value, "u"); } } // namespace std::string UintValue::DebugString() const { return UintDebugString(NativeValue()); } absl::Status UintValue::SerializeTo( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(output != nullptr); google::protobuf::UInt64Value message; 
message.set_value(NativeValue()); if (!message.SerializePartialToZeroCopyStream(output)) { return absl::UnknownError( absl::StrCat("failed to serialize message: ", message.GetTypeName())); } return absl::OkStatus(); } absl::Status UintValue::ConvertToJson( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); ValueReflection value_reflection; CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor())); value_reflection.SetNumberValue(json, NativeValue()); return absl::OkStatus(); } absl::Status UintValue::Equal( const Value& other, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(result != nullptr); if (auto other_value = other.AsUint(); other_value.has_value()) { *result = BoolValue{NativeValue() == other_value->NativeValue()}; return absl::OkStatus(); } if (auto other_value = other.AsDouble(); other_value.has_value()) { *result = BoolValue{internal::Number::FromUint64(NativeValue()) == internal::Number::FromDouble(other_value->NativeValue())}; return absl::OkStatus(); } if (auto other_value = other.AsInt(); other_value.has_value()) { *result = BoolValue{internal::Number::FromUint64(NativeValue()) == internal::Number::FromInt64(other_value->NativeValue())}; return absl::OkStatus(); } *result = FalseValue(); return absl::OkStatus(); } } // namespace cel ================================================ FILE: 
common/values/uint_value.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private, include "common/value.h"
// IWYU pragma: friend "common/value.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_UINT_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_UINT_VALUE_H_

#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "common/type.h"
#include "common/value_kind.h"
#include "common/values/values.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

class Value;
class UintValue;

// `UintValue` represents values of the primitive `uint` type.
//
// It holds a single `uint64_t` (0 when default-constructed). Construction from
// `uint64_t` is explicit, but the value converts implicitly back to `uint64_t`.
class UintValue final : private common_internal::ValueMixin {
 public:
  static constexpr ValueKind kKind = ValueKind::kUint;

  explicit UintValue(uint64_t value) noexcept : value_(value) {}

  UintValue() = default;
  UintValue(const UintValue&) = default;
  UintValue(UintValue&&) = default;
  UintValue& operator=(const UintValue&) = default;
  UintValue& operator=(UintValue&&) = default;

  constexpr ValueKind kind() const { return kKind; }

  absl::string_view GetTypeName() const { return UintType::kName; }

  // Returns a human-readable form of the value (a `u`-suffixed literal).
  std::string DebugString() const;

  // See Value::SerializeTo().
  absl::Status SerializeTo(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const;

  // See Value::ConvertToJson().
  absl::Status ConvertToJson(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  // Compares this value with `other` and stores the BoolValue outcome in
  // `*result`.
  absl::Status Equal(const Value& other,
                     const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                     google::protobuf::MessageFactory* absl_nonnull message_factory,
                     google::protobuf::Arena* absl_nonnull arena,
                     Value* absl_nonnull result) const;
  using ValueMixin::Equal;

  // The zero value of `uint` is 0.
  bool IsZeroValue() const { return NativeValue() == 0; }

  // Returns the wrapped integer (via the conversion operator below).
  constexpr uint64_t NativeValue() const { return static_cast(*this); }

  // NOLINTNEXTLINE(google-explicit-constructor)
  constexpr operator uint64_t() const noexcept { return value_; }

  friend void swap(UintValue& lhs, UintValue& rhs) noexcept {
    using std::swap;
    swap(lhs.value_, rhs.value_);
  }

 private:
  friend class common_internal::ValueMixin;

  uint64_t value_ = 0;
};

// Hashes exactly like the wrapped `uint64_t`.
template H AbslHashValue(H state, UintValue value) {
  return H::combine(std::move(state), value.NativeValue());
}

constexpr bool operator==(UintValue lhs, UintValue rhs) {
  return lhs.NativeValue() == rhs.NativeValue();
}

constexpr bool operator!=(UintValue lhs, UintValue rhs) {
  return !operator==(lhs, rhs);
}

// Streams DebugString().
inline std::ostream& operator<<(std::ostream& out, UintValue value) {
  return out << value.DebugString();
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_UINT_VALUE_H_
================================================
FILE: common/values/uint_value_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include "absl/hash/hash.h" #include "absl/status/status_matchers.h" #include "common/native_type.h" #include "common/value.h" #include "common/value_testing.h" #include "internal/testing.h" namespace cel { namespace { using ::absl_testing::IsOk; using UintValueTest = common_internal::ValueTest<>; TEST_F(UintValueTest, Kind) { EXPECT_EQ(UintValue(1).kind(), UintValue::kKind); EXPECT_EQ(Value(UintValue(1)).kind(), UintValue::kKind); } TEST_F(UintValueTest, DebugString) { { std::ostringstream out; out << UintValue(1); EXPECT_EQ(out.str(), "1u"); } { std::ostringstream out; out << Value(UintValue(1)); EXPECT_EQ(out.str(), "1u"); } } TEST_F(UintValueTest, ConvertToJson) { auto* message = NewArenaValueMessage(); EXPECT_THAT( UintValue(1).ConvertToJson(descriptor_pool(), message_factory(), message), IsOk()); EXPECT_THAT(*message, EqualsValueTextProto(R"pb(number_value: 1)pb")); } TEST_F(UintValueTest, NativeTypeId) { EXPECT_EQ(NativeTypeId::Of(UintValue(1)), NativeTypeId::For()); EXPECT_EQ(NativeTypeId::Of(Value(UintValue(1))), NativeTypeId::For()); } TEST_F(UintValueTest, HashValue) { EXPECT_EQ(absl::HashOf(UintValue(1)), absl::HashOf(uint64_t{1})); } TEST_F(UintValueTest, Equality) { EXPECT_NE(UintValue(0u), 1u); EXPECT_NE(1u, UintValue(0u)); EXPECT_NE(UintValue(0u), UintValue(1u)); } TEST_F(UintValueTest, LessThan) { EXPECT_LT(UintValue(0), 1); EXPECT_LT(0, UintValue(1)); EXPECT_LT(UintValue(0), UintValue(1)); } } // namespace } // namespace cel ================================================ FILE: common/values/unknown_value.cc 
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "common/value.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

// Unknown values have no serialized form: this always returns
// FAILED_PRECONDITION and writes nothing to `output`.
absl::Status UnknownValue::SerializeTo(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(output != nullptr);

  return absl::FailedPreconditionError(
      absl::StrCat(GetTypeName(), " is unserializable"));
}

// Unknown values cannot be represented as JSON: this always returns
// FAILED_PRECONDITION and leaves `json` untouched.
absl::Status UnknownValue::ConvertToJson(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull json) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(json != nullptr);
  ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
                 google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE);

  return absl::FailedPreconditionError(
      absl::StrCat(GetTypeName(), " is not convertable to JSON"));
}

// The other operand is ignored: `*result` is always set to false, so an
// unknown never compares equal to anything, including another unknown.
absl::Status UnknownValue::Equal(
    const Value&,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);

  *result = FalseValue();
  return absl::OkStatus();
}

}  // namespace cel
================================================
FILE: common/values/unknown_value.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// IWYU pragma: private, include "common/value.h"
// IWYU pragma: friend "common/value.h"

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_UNKNOWN_VALUE_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_UNKNOWN_VALUE_H_

#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "common/type.h"
#include "common/unknown.h"
#include "common/value_kind.h"
#include "common/values/values.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"

namespace cel {

class Value;
class UnknownValue;

// `UnknownValue` represents values of kind `ValueKind::kUnknown`: it wraps an
// `Unknown`, which carries the unknown attribute set and function result set.
class UnknownValue final : private common_internal::ValueMixin {
 public:
  static constexpr ValueKind kKind = ValueKind::kUnknown;

  explicit UnknownValue(Unknown unknown) : unknown_(std::move(unknown)) {}

  UnknownValue() = default;
  UnknownValue(const UnknownValue&) = default;
  UnknownValue(UnknownValue&&) = default;
  UnknownValue& operator=(const UnknownValue&) = default;
  UnknownValue& operator=(UnknownValue&&) = default;

  constexpr ValueKind kind() const { return kKind; }

  absl::string_view GetTypeName() const { return UnknownType::kName; }

  std::string DebugString() const { return ""; }

  // See Value::SerializeTo().
  absl::Status SerializeTo(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const;

  // See Value::ConvertToJson().
  absl::Status ConvertToJson(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Message* absl_nonnull json) const;

  // Compares this value with `other`, storing the BoolValue outcome in
  // `*result`.
  absl::Status Equal(const Value& other,
                     const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                     google::protobuf::MessageFactory* absl_nonnull message_factory,
                     google::protobuf::Arena* absl_nonnull arena,
                     Value* absl_nonnull result) const;
  using ValueMixin::Equal;

  // An unknown is never considered a zero value.
  bool IsZeroValue() const { return false; }

  void swap(UnknownValue& other) noexcept {
    using std::swap;
    swap(unknown_, other.unknown_);
  }

  // Returns a reference to the wrapped `Unknown`; it must not outlive `*this`.
  const Unknown& NativeValue() const& ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return unknown_;
  }

  // Rvalue overload: moves the wrapped `Unknown` out of this object.
  Unknown NativeValue() && {
    Unknown unknown = std::move(unknown_);
    return unknown;
  }

  // Forwards to the wrapped `Unknown`'s `unknown_attributes()`.
  const AttributeSet& attribute_set() const {
    return unknown_.unknown_attributes();
  }

  // Forwards to the wrapped `Unknown`'s `unknown_function_results()`.
  const FunctionResultSet& function_result_set() const {
    return unknown_.unknown_function_results();
  }

 private:
  friend class common_internal::ValueMixin;

  Unknown unknown_;
};

inline void swap(UnknownValue& lhs, UnknownValue& rhs) noexcept {
  lhs.swap(rhs);
}

// Streams DebugString().
inline std::ostream& operator<<(std::ostream& out, const UnknownValue& value) {
  return out << value.DebugString();
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_UNKNOWN_VALUE_H_
================================================
FILE: common/values/unknown_value_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include

#include "absl/status/status.h"
#include "common/native_type.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "internal/testing.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"

namespace cel {
namespace {

using ::absl_testing::StatusIs;

using UnknownValueTest = common_internal::ValueTest<>;

TEST_F(UnknownValueTest, Kind) {
  EXPECT_EQ(UnknownValue().kind(), UnknownValue::kKind);
  EXPECT_EQ(Value(UnknownValue()).kind(), UnknownValue::kKind);
}

TEST_F(UnknownValueTest, DebugString) {
  {
    std::ostringstream out;
    out << UnknownValue();
    EXPECT_EQ(out.str(), "");
  }
  {
    std::ostringstream out;
    out << Value(UnknownValue());
    EXPECT_EQ(out.str(), "");
  }
}

// Serialization of an unknown must fail with FAILED_PRECONDITION.
TEST_F(UnknownValueTest, SerializeTo) {
  google::protobuf::io::CordOutputStream output;
  EXPECT_THAT(
      UnknownValue().SerializeTo(descriptor_pool(), message_factory(), &output),
      StatusIs(absl::StatusCode::kFailedPrecondition));
}

// JSON conversion of an unknown must fail with FAILED_PRECONDITION.
TEST_F(UnknownValueTest, ConvertToJson) {
  auto* message = NewArenaValueMessage();
  EXPECT_THAT(UnknownValue().ConvertToJson(descriptor_pool(), message_factory(),
                                           message),
              StatusIs(absl::StatusCode::kFailedPrecondition));
}

TEST_F(UnknownValueTest, NativeTypeId) {
  EXPECT_EQ(NativeTypeId::Of(UnknownValue()), NativeTypeId::For());
  EXPECT_EQ(NativeTypeId::Of(Value(UnknownValue())), NativeTypeId::For());
}

}  // namespace
}  // namespace cel
================================================
FILE: common/values/value_builder.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include #include #include #include #include #include #include #include #include "absl/base/call_once.h" #include "absl/base/casts.h" #include "absl/base/nullability.h" #include "absl/base/optimization.h" #include "absl/container/flat_hash_map.h" #include "absl/hash/hash.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "common/allocator.h" #include "common/arena.h" #include "common/legacy_value.h" #include "common/native_type.h" #include "common/type.h" #include "common/value.h" #include "common/value_kind.h" #include "common/values/list_value_builder.h" #include "common/values/map_value_builder.h" #include "eval/public/cel_value.h" #include "internal/casts.h" #include "internal/manual.h" #include "internal/status_macros.h" #include "internal/well_known_types.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel { namespace common_internal { namespace { using ::cel::well_known_types::ListValueReflection; using ::cel::well_known_types::StructReflection; using ::cel::well_known_types::ValueReflection; using ::google::api::expr::runtime::CelValue; using ValueVector = std::vector>; absl::Status CheckListElement(const Value& value) { if (auto error_value = value.AsError(); ABSL_PREDICT_FALSE(error_value)) { return error_value->ToStatus(); } if (auto unknown_value = value.AsUnknown(); ABSL_PREDICT_FALSE(unknown_value)) { return absl::InvalidArgumentError("cannot add unknown value to list"); } return absl::OkStatus(); } template absl::Status ListValueToJsonArray( const Vector& vector, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) { ABSL_DCHECK(descriptor_pool != 
nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); ListValueReflection reflection; CEL_RETURN_IF_ERROR(reflection.Initialize(json->GetDescriptor())); json->Clear(); if (vector.empty()) { return absl::OkStatus(); } for (const auto& element : vector) { CEL_RETURN_IF_ERROR(element->ConvertToJson(descriptor_pool, message_factory, reflection.AddValues(json))); } return absl::OkStatus(); } template absl::Status ListValueToJson( const Vector& vector, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); ValueReflection reflection; CEL_RETURN_IF_ERROR(reflection.Initialize(json->GetDescriptor())); return ListValueToJsonArray(vector, descriptor_pool, message_factory, reflection.MutableListValue(json)); } class CompatListValueImplIterator final : public ValueIterator { public: explicit CompatListValueImplIterator(absl::Span elements) : elements_(elements) {} bool HasNext() override { return index_ < elements_.size(); } absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) override { if (ABSL_PREDICT_FALSE(index_ >= elements_.size())) { return absl::FailedPreconditionError( "ValueManager::Next called after ValueManager::HasNext returned " "false"); } *result = elements_[index_++]; return absl::OkStatus(); } absl::StatusOr Next1( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, 
google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key_or_value) override { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(key_or_value != nullptr); if (index_ >= elements_.size()) { return false; } *key_or_value = elements_[index_]; ++index_; return true; } absl::StatusOr Next2( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, Value* absl_nullable value) override { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(key != nullptr); if (index_ >= elements_.size()) { return false; } if (value != nullptr) { *value = elements_[index_]; } *key = IntValue(index_++); return true; } private: const absl::Span elements_; size_t index_ = 0; }; struct ValueFormatter { void operator()(std::string* out, const std::pair& value) const { (*this)(out, value.first); out->append(": "); (*this)(out, value.second); } void operator()(std::string* out, const Value& value) const { out->append(value.DebugString()); } }; class ListValueBuilderImpl final : public ListValueBuilder { public: explicit ListValueBuilderImpl(google::protobuf::Arena* absl_nonnull arena) : arena_(arena) { elements_.Construct(arena); } ~ListValueBuilderImpl() override { if (!elements_trivially_destructible_) { elements_.Destruct(); } } absl::Status Add(Value value) override { CEL_RETURN_IF_ERROR(CheckListElement(value)); UnsafeAdd(std::move(value)); return absl::OkStatus(); } void UnsafeAdd(Value value) override { ABSL_DCHECK_OK(CheckListElement(value)); elements_->emplace_back(std::move(value)); if (elements_trivially_destructible_) { elements_trivially_destructible_ = ArenaTraits<>::trivially_destructible(elements_->back()); } } size_t 
Size() const override { return elements_->size(); } void Reserve(size_t capacity) override { elements_->reserve(capacity); } ListValue Build() && override; CustomListValue BuildCustom() &&; const CompatListValue* absl_nonnull BuildCompat() &&; const CompatListValue* absl_nonnull BuildCompatAt( void* absl_nonnull address) &&; private: google::protobuf::Arena* absl_nonnull const arena_; internal::Manual elements_; bool elements_trivially_destructible_ = true; }; class CompatListValueImpl final : public CompatListValue { public: explicit CompatListValueImpl(ValueVector&& elements) : elements_(std::move(elements)) {} std::string DebugString() const override { return absl::StrCat("[", absl::StrJoin(elements_, ", ", ValueFormatter{}), "]"); } absl::Status ConvertToJsonArray( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const override { return ListValueToJsonArray(elements_, descriptor_pool, message_factory, json); } CustomListValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { ABSL_DCHECK(arena != nullptr); ListValueBuilderImpl builder(arena); builder.Reserve(elements_.size()); for (const auto& element : elements_) { builder.UnsafeAdd(element.Clone(arena)); } return std::move(builder).BuildCustom(); } size_t Size() const override { return elements_.size(); } absl::Status ForEach( ForEachWithIndexCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const override { const size_t size = elements_.size(); for (size_t i = 0; i < size; ++i) { CEL_ASSIGN_OR_RETURN(auto ok, callback(i, elements_[i])); if (!ok) { break; } } return absl::OkStatus(); } absl::StatusOr NewIterator() const override { return std::make_unique( absl::MakeConstSpan(elements_)); } CelValue operator[](int 
index) const override { return Get(elements_.get_allocator().arena(), index); } // Like `operator[](int)` above, but also accepts an arena. Prefer calling // this variant if the arena is known. CelValue Get(google::protobuf::Arena* arena, int index) const override { if (arena == nullptr) { arena = elements_.get_allocator().arena(); } if (ABSL_PREDICT_FALSE(index < 0 || index >= size())) { return CelValue::CreateError(google::protobuf::Arena::Create( arena, IndexOutOfBoundsError(index).ToStatus())); } return common_internal::UnsafeLegacyValue( elements_[index], /*stable=*/true, arena != nullptr ? arena : elements_.get_allocator().arena()); } int size() const override { return static_cast(Size()); } protected: absl::Status Get(size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const override { if (index >= elements_.size()) { *result = IndexOutOfBoundsError(index); } else { *result = elements_[index]; } return absl::OkStatus(); } private: const ValueVector elements_; }; } // namespace } // namespace common_internal template <> struct ArenaTraits { using always_trivially_destructible = std::true_type; }; namespace common_internal { namespace { ListValue ListValueBuilderImpl::Build() && { if (elements_->empty()) { return ListValue(); } return std::move(*this).BuildCustom(); } CustomListValue ListValueBuilderImpl::BuildCustom() && { if (elements_->empty()) { return CustomListValue(EmptyCompatListValue(), arena_); } return CustomListValue(std::move(*this).BuildCompat(), arena_); } const CompatListValue* absl_nonnull ListValueBuilderImpl::BuildCompat() && { if (elements_->empty()) { return EmptyCompatListValue(); } return std::move(*this).BuildCompatAt(arena_->AllocateAligned( sizeof(CompatListValueImpl), alignof(CompatListValueImpl))); } const CompatListValue* absl_nonnull 
ListValueBuilderImpl::BuildCompatAt( void* absl_nonnull address) && { CompatListValueImpl* absl_nonnull impl = ::new (address) CompatListValueImpl(std::move(*elements_)); if (!elements_trivially_destructible_) { arena_->OwnDestructor(impl); elements_trivially_destructible_ = true; } return impl; } class MutableCompatListValueImpl final : public MutableCompatListValue { public: explicit MutableCompatListValueImpl(google::protobuf::Arena* absl_nonnull arena) : elements_(arena) {} std::string DebugString() const override { return absl::StrCat("[", absl::StrJoin(elements_, ", ", ValueFormatter{}), "]"); } absl::Status ConvertToJsonArray( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const override { return ListValueToJsonArray(elements_, descriptor_pool, message_factory, json); } CustomListValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { ABSL_DCHECK(arena != nullptr); ListValueBuilderImpl builder(arena); builder.Reserve(elements_.size()); for (const auto& element : elements_) { builder.UnsafeAdd(element.Clone(arena)); } return std::move(builder).BuildCustom(); } size_t Size() const override { return elements_.size(); } absl::Status ForEach( ForEachWithIndexCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const override { const size_t size = elements_.size(); for (size_t i = 0; i < size; ++i) { CEL_ASSIGN_OR_RETURN(auto ok, callback(i, elements_[i])); if (!ok) { break; } } return absl::OkStatus(); } absl::StatusOr NewIterator() const override { return std::make_unique( absl::MakeConstSpan(elements_)); } CelValue operator[](int index) const override { return Get(elements_.get_allocator().arena(), index); } // Like `operator[](int)` above, but also accepts an arena. 
Prefer calling // this variant if the arena is known. CelValue Get(google::protobuf::Arena* arena, int index) const override { if (arena == nullptr) { arena = elements_.get_allocator().arena(); } if (ABSL_PREDICT_FALSE(index < 0 || index >= size())) { return CelValue::CreateError(google::protobuf::Arena::Create( arena, IndexOutOfBoundsError(index).ToStatus())); } return common_internal::UnsafeLegacyValue( elements_[index], /*stable=*/false, arena != nullptr ? arena : elements_.get_allocator().arena()); } int size() const override { return static_cast(Size()); } absl::Status Append(Value value) const override { CEL_RETURN_IF_ERROR(CheckListElement(value)); elements_.emplace_back(std::move(value)); if (elements_trivially_destructible_) { elements_trivially_destructible_ = ArenaTraits<>::trivially_destructible(elements_.back()); if (!elements_trivially_destructible_) { elements_.get_allocator().arena()->OwnDestructor( const_cast(this)); } } return absl::OkStatus(); } void Reserve(size_t capacity) const override { elements_.reserve(capacity); } protected: absl::Status Get(size_t index, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const override { if (index >= elements_.size()) { *result = IndexOutOfBoundsError(index); } else { *result = elements_[index]; } return absl::OkStatus(); } private: mutable ValueVector elements_; mutable bool elements_trivially_destructible_ = true; }; } // namespace } // namespace common_internal template <> struct ArenaTraits { using constructible = std::true_type; using always_trivially_destructible = std::true_type; }; namespace common_internal { namespace {} // namespace absl::StatusOr MakeCompatListValue( const CustomListValue& value, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, 
google::protobuf::Arena* absl_nonnull arena) { ListValueBuilderImpl builder(arena); builder.Reserve(value.Size()); CEL_RETURN_IF_ERROR(value.ForEach( [&](const Value& element) -> absl::StatusOr { CEL_RETURN_IF_ERROR(builder.Add(element)); return true; }, descriptor_pool, message_factory, arena)); return std::move(builder).BuildCompat(); } /* Constructs a new mutable list in memory allocated from `arena` (placement new). */ MutableListValue* absl_nonnull NewMutableListValue( google::protobuf::Arena* absl_nonnull arena) { return ::new (arena->AllocateAligned(sizeof(MutableCompatListValueImpl), alignof(MutableCompatListValueImpl))) MutableCompatListValueImpl(arena); } /* Returns true if `value` is a custom list whose native type id matches either of the two mutable-list type ids tested below. */ bool IsMutableListValue(const Value& value) { if (auto custom_list_value = value.AsCustomList(); custom_list_value) { NativeTypeId native_type_id = custom_list_value->GetTypeId(); if (native_type_id == NativeTypeId::For() || native_type_id == NativeTypeId::For()) { return true; } } return false; } /* ListValue overload of IsMutableListValue; same type-id test. */ bool IsMutableListValue(const ListValue& value) { if (auto custom_list_value = value.AsCustom(); custom_list_value) { NativeTypeId native_type_id = custom_list_value->GetTypeId(); if (native_type_id == NativeTypeId::For() || native_type_id == NativeTypeId::For()) { return true; } } return false; } /* Returns the mutable list behind `value`, or nullptr when `value` is not a mutable list. */ const MutableListValue* absl_nullable AsMutableListValue(const Value& value) { if (auto custom_list_value = value.AsCustomList(); custom_list_value) { NativeTypeId native_type_id = custom_list_value->GetTypeId(); if (native_type_id == NativeTypeId::For()) { return cel::internal::down_cast( custom_list_value->interface()); } if (native_type_id == NativeTypeId::For()) { return cel::internal::down_cast( custom_list_value->interface()); } } return nullptr; } /* ListValue overload of AsMutableListValue. */ const MutableListValue* absl_nullable AsMutableListValue( const ListValue& value) { if (auto custom_list_value = value.AsCustom(); custom_list_value) { NativeTypeId native_type_id = custom_list_value->GetTypeId(); if (native_type_id == NativeTypeId::For()) { return cel::internal::down_cast( custom_list_value->interface()); } if (native_type_id ==
NativeTypeId::For()) { return cel::internal::down_cast( custom_list_value->interface()); } } return nullptr; } /* Like AsMutableListValue, but `value` must be a mutable list (DCHECKed); any other input reaches ABSL_UNREACHABLE. */ const MutableListValue& GetMutableListValue(const Value& value) { ABSL_DCHECK(IsMutableListValue(value)) << value; const auto& custom_list_value = value.GetCustomList(); NativeTypeId native_type_id = custom_list_value.GetTypeId(); if (native_type_id == NativeTypeId::For()) { return cel::internal::down_cast( *custom_list_value.interface()); } if (native_type_id == NativeTypeId::For()) { return cel::internal::down_cast( *custom_list_value.interface()); } ABSL_UNREACHABLE(); } /* ListValue overload of GetMutableListValue. */ const MutableListValue& GetMutableListValue(const ListValue& value) { ABSL_DCHECK(IsMutableListValue(value)) << value; const auto& custom_list_value = value.GetCustom(); NativeTypeId native_type_id = custom_list_value.GetTypeId(); if (native_type_id == NativeTypeId::For()) { return cel::internal::down_cast( *custom_list_value.interface()); } if (native_type_id == NativeTypeId::For()) { return cel::internal::down_cast( *custom_list_value.interface()); } ABSL_UNREACHABLE(); } /* Returns a new list builder (owned via std::make_unique) bound to `arena`. */ absl_nonnull cel::ListValueBuilderPtr NewListValueBuilder( google::protobuf::Arena* absl_nonnull arena) { return std::make_unique(arena); } } // namespace common_internal } // namespace cel namespace cel { namespace common_internal { namespace { using ::google::api::expr::runtime::CelList; using ::google::api::expr::runtime::CelValue; /* Rejects values that may not be stored in a map: an error value is returned as its own status, an unknown value as InvalidArgument. */ absl::Status CheckMapValue(const Value& value) { if (auto error_value = value.AsError(); ABSL_PREDICT_FALSE(error_value)) { return error_value->ToStatus(); } if (auto unknown_value = value.AsUnknown(); ABSL_PREDICT_FALSE(unknown_value)) { return absl::InvalidArgumentError("cannot add unknown value to map"); } return absl::OkStatus(); } /* Hashes a map key: bool, int, uint or string. Any other kind hits ABSL_UNREACHABLE, so keys are validated first (Put/Find call CheckMapKey). The kind is mixed into every hash, and the CelValue overload below must stay in sync because the map supports heterogeneous lookup. */ size_t ValueHash(const Value& value) { switch (value.kind()) { case ValueKind::kBool: return absl::HashOf(value.kind(), value.GetBool()); case ValueKind::kInt: return absl::HashOf(ValueKind::kInt, absl::implicit_cast(value.GetInt())); case
ValueKind::kUint: return absl::HashOf(ValueKind::kUint, absl::implicit_cast(value.GetUint())); case ValueKind::kString: return absl::HashOf(value.kind(), value.GetString()); default: ABSL_UNREACHABLE(); } } /* CelValue counterpart of ValueHash; equal keys must hash identically under both overloads. */ size_t ValueHash(const CelValue& value) { switch (value.type()) { case CelValue::Type::kBool: return absl::HashOf(ValueKind::kBool, value.BoolOrDie()); case CelValue::Type::kInt: return absl::HashOf(ValueKind::kInt, value.Int64OrDie()); case CelValue::Type::kUint: return absl::HashOf(ValueKind::kUint, value.Uint64OrDie()); case CelValue::Type::kString: return absl::HashOf(ValueKind::kString, value.StringOrDie().value()); default: ABSL_UNREACHABLE(); } } /* Map-key equality. Keys of different kinds never compare equal, e.g. int 1 and uint 1 are distinct keys here. */ bool ValueEquals(const Value& lhs, const Value& rhs) { switch (lhs.kind()) { case ValueKind::kBool: switch (rhs.kind()) { case ValueKind::kBool: return lhs.GetBool() == rhs.GetBool(); case ValueKind::kInt: return false; case ValueKind::kUint: return false; case ValueKind::kString: return false; default: ABSL_UNREACHABLE(); } case ValueKind::kInt: switch (rhs.kind()) { case ValueKind::kBool: return false; case ValueKind::kInt: return lhs.GetInt() == rhs.GetInt(); case ValueKind::kUint: return false; case ValueKind::kString: return false; default: ABSL_UNREACHABLE(); } case ValueKind::kUint: switch (rhs.kind()) { case ValueKind::kBool: return false; case ValueKind::kInt: return false; case ValueKind::kUint: return lhs.GetUint() == rhs.GetUint(); case ValueKind::kString: return false; default: ABSL_UNREACHABLE(); } case ValueKind::kString: switch (rhs.kind()) { case ValueKind::kBool: return false; case ValueKind::kInt: return false; case ValueKind::kUint: return false; case ValueKind::kString: return lhs.GetString() == rhs.GetString(); default: ABSL_UNREACHABLE(); } default: ABSL_UNREACHABLE(); } } /* Heterogeneous map-key equality between a legacy CelValue and a Value, using the same kind rules as ValueEquals. */ bool CelValueEquals(const CelValue& lhs, const Value& rhs) { switch (lhs.type()) { case CelValue::Type::kBool: switch (rhs.kind()) { case ValueKind::kBool: return BoolValue(lhs.BoolOrDie()) == rhs.GetBool(); case
ValueKind::kInt: return false; case ValueKind::kUint: return false; case ValueKind::kString: return false; default: ABSL_UNREACHABLE(); } case CelValue::Type::kInt: switch (rhs.kind()) { case ValueKind::kBool: return false; case ValueKind::kInt: return IntValue(lhs.Int64OrDie()) == rhs.GetInt(); case ValueKind::kUint: return false; case ValueKind::kString: return false; default: ABSL_UNREACHABLE(); } case CelValue::Type::kUint: switch (rhs.kind()) { case ValueKind::kBool: return false; case ValueKind::kInt: return false; case ValueKind::kUint: return UintValue(lhs.Uint64OrDie()) == rhs.GetUint(); case ValueKind::kString: return false; default: ABSL_UNREACHABLE(); } case CelValue::Type::kString: switch (rhs.kind()) { case ValueKind::kBool: return false; case ValueKind::kInt: return false; case ValueKind::kUint: return false; case ValueKind::kString: return rhs.GetString().Equals(lhs.StringOrDie().value()); default: ABSL_UNREACHABLE(); } default: ABSL_UNREACHABLE(); } } /* JSON object keys must be strings; any other key kind yields a type conversion error. */ absl::StatusOr ValueToJsonString(const Value& value) { switch (value.kind()) { case ValueKind::kString: return value.GetString().NativeString(); default: return TypeConversionError(value.GetRuntimeType(), StringType()) .ToStatus(); } } /* Writes `map` into `json`, which must be a google.protobuf.Struct. `json` is cleared first, and every key must be a string. */ template absl::Status MapValueToJsonObject( const Map& map, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); StructReflection reflection; CEL_RETURN_IF_ERROR(reflection.Initialize(json->GetDescriptor())); json->Clear(); if (map.empty()) { return absl::OkStatus(); } for (const auto& entry : map) { CEL_ASSIGN_OR_RETURN(auto key, ValueToJsonString(entry.first)); CEL_RETURN_IF_ERROR(entry.second.ConvertToJson(
descriptor_pool, message_factory, reflection.InsertField(json, key))); } return absl::OkStatus(); } /* Same as MapValueToJsonObject, but `json` is a google.protobuf.Value and the object is written into its struct_value. */ template absl::Status MapValueToJson( const Map& map, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(json != nullptr); ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); ValueReflection reflection; CEL_RETURN_IF_ERROR(reflection.Initialize(json->GetDescriptor())); return MapValueToJsonObject(map, descriptor_pool, message_factory, reflection.MutableStructValue(json)); } /* Transparent hash and equality functors: they let the map be probed with either a Value or a legacy CelValue key without converting it first. */ struct ValueHasher { using is_transparent = void; size_t operator()(const Value& value) const { return (ValueHash)(value); } size_t operator()(const CelValue& value) const { return (ValueHash)(value); } }; struct ValueEqualer { using is_transparent = void; bool operator()(const Value& lhs, const CelValue& rhs) const { return (*this)(rhs, lhs); } bool operator()(const CelValue& lhs, const Value& rhs) const { return (CelValueEquals)(lhs, rhs); } bool operator()(const Value& lhs, const Value& rhs) const { return (ValueEquals)(lhs, rhs); } }; using ValueFlatHashMapAllocator = ArenaAllocator>; using ValueFlatHashMap = absl::flat_hash_map; /* Iterates the entries of a ValueFlatHashMap, yielding keys (Next, Next1) or key/value pairs (Next2). It stores iterators into the map, so the map must outlive it. */ class CompatMapValueImplIterator final : public ValueIterator { public: explicit CompatMapValueImplIterator(const ValueFlatHashMap* absl_nonnull map) : begin_(map->begin()), end_(map->end()) {} bool HasNext() override { return begin_ != end_; } absl::Status Next(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) override { if (ABSL_PREDICT_FALSE(begin_ == end_)) { return absl::FailedPreconditionError( "ValueManager::Next called after
ValueManager::HasNext returned " "false"); } *result = begin_->first; ++begin_; return absl::OkStatus(); } absl::StatusOr Next1( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key_or_value) override { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(key_or_value != nullptr); if (begin_ == end_) { return false; } *key_or_value = begin_->first; ++begin_; return true; } absl::StatusOr Next2( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull key, Value* absl_nullable value) override { ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); ABSL_DCHECK(arena != nullptr); ABSL_DCHECK(key != nullptr); if (begin_ == end_) { return false; } *key = begin_->first; if (value != nullptr) { *value = begin_->second; } ++begin_; return true; } private: typename ValueFlatHashMap::const_iterator begin_; const typename ValueFlatHashMap::const_iterator end_; }; /* Builds a ValueFlatHashMap whose storage comes from `arena`. The map (held in internal::Manual) is only destroyed explicitly if some entry is not trivially destructible. */ class MapValueBuilderImpl final : public MapValueBuilder { public: explicit MapValueBuilderImpl(google::protobuf::Arena* absl_nonnull arena) : arena_(arena) { map_.Construct(arena_); } ~MapValueBuilderImpl() override { if (!entries_trivially_destructible_) { map_.Destruct(); } } /* Validates the key and the value, rejects duplicate keys, then inserts. */ absl::Status Put(Value key, Value value) override { CEL_RETURN_IF_ERROR(CheckMapKey(key)); CEL_RETURN_IF_ERROR(CheckMapValue(value)); if (auto it = map_->find(key); ABSL_PREDICT_FALSE(it != map_->end())) { return DuplicateKeyError().ToStatus(); } UnsafePut(std::move(key), std::move(value)); return absl::OkStatus(); } /* Inserts without validation; the key must not already be present (DCHECKed). Also tracks whether every entry is still trivially destructible. */ void UnsafePut(Value key, Value value) override { auto insertion = map_->insert({std::move(key), std::move(value)});
ABSL_DCHECK(insertion.second); if (entries_trivially_destructible_) { entries_trivially_destructible_ = ArenaTraits<>::trivially_destructible(insertion.first->first) && ArenaTraits<>::trivially_destructible(insertion.first->second); } } size_t Size() const override { return map_->size(); } void Reserve(size_t capacity) override { map_->reserve(capacity); } MapValue Build() && override; CustomMapValue BuildCustom() &&; const CompatMapValue* absl_nonnull BuildCompat() &&; private: google::protobuf::Arena* absl_nonnull const arena_; internal::Manual map_; bool entries_trivially_destructible_ = true; }; /* Immutable map produced by MapValueBuilderImpl::BuildCompat. The key list handed out by ListKeys is materialized lazily, once, into the inline keys_ storage. */ class CompatMapValueImpl final : public CompatMapValue { public: explicit CompatMapValueImpl(ValueFlatHashMap&& map) : map_(std::move(map)) {} std::string DebugString() const override { return absl::StrCat("{", absl::StrJoin(map_, ", ", ValueFormatter{}), "}"); } absl::Status ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const override { return MapValueToJsonObject(map_, descriptor_pool, message_factory, json); } CustomMapValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { ABSL_DCHECK(arena != nullptr); MapValueBuilderImpl builder(arena); builder.Reserve(map_.size()); for (const auto& entry : map_) { builder.UnsafePut(entry.first.Clone(arena), entry.second.Clone(arena)); } return std::move(builder).BuildCustom(); } size_t Size() const override { return map_.size(); } absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const override { *result = CustomListValue(ProjectKeys(), map_.get_allocator().arena()); return absl::OkStatus(); } absl::Status ForEach( ForEachCallback callback, const google::protobuf::DescriptorPool*
absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const override { for (const auto& entry : map_) { CEL_ASSIGN_OR_RETURN(auto ok, callback(entry.first, entry.second)); if (!ok) { break; } } return absl::OkStatus(); } absl::StatusOr NewIterator() const override { return std::make_unique(&map_); } absl::optional operator[](CelValue key) const override { return Get(map_.get_allocator().arena(), key); } using CompatMapValue::Get; absl::optional Get(google::protobuf::Arena* arena, CelValue key) const override { if (auto status = CelValue::CheckMapKeyType(key); !status.ok()) { status.IgnoreError(); return absl::nullopt; } if (auto it = map_.find(key); it != map_.end()) { return common_internal::UnsafeLegacyValue( it->second, /*stable=*/true, arena != nullptr ? arena : map_.get_allocator().arena()); } return absl::nullopt; } absl::StatusOr Has(const CelValue& key) const override { // This check safeguards against issues with invalid key types such as NaN.
CEL_RETURN_IF_ERROR(CelValue::CheckMapKeyType(key)); return map_.find(key) != map_.end(); } int size() const override { return static_cast(Size()); } absl::StatusOr ListKeys() const override { return ProjectKeys(); } absl::StatusOr ListKeys(google::protobuf::Arena* arena) const override { return ProjectKeys(); } protected: absl::StatusOr Find( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const override { CEL_RETURN_IF_ERROR(CheckMapKey(key)); if (auto it = map_.find(key); it != map_.end()) { *result = it->second; return true; } return false; } absl::StatusOr Has( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const override { CEL_RETURN_IF_ERROR(CheckMapKey(key)); return map_.find(key) != map_.end(); } private: /* Builds the key list on first use (absl::call_once) directly into keys_ and returns it; later calls return the same list. */ const CompatListValue* absl_nonnull ProjectKeys() const { absl::call_once(keys_once_, [this]() { ListValueBuilderImpl builder(map_.get_allocator().arena()); builder.Reserve(map_.size()); for (const auto& entry : map_) { builder.UnsafeAdd(entry.first); } std::move(builder).BuildCompatAt(&keys_[0]); }); return std::launder( reinterpret_cast(&keys_[0])); } const ValueFlatHashMap map_; mutable absl::once_flag keys_once_; alignas(CompatListValueImpl) mutable char keys_[sizeof(CompatListValueImpl)]; }; /* An empty builder yields the default MapValue (or the shared empty compat map); otherwise the entries are handed to a CompatMapValueImpl. */ MapValue MapValueBuilderImpl::Build() && { if (map_->empty()) { return MapValue(); } return std::move(*this).BuildCustom(); } CustomMapValue MapValueBuilderImpl::BuildCustom() && { if (map_->empty()) { return CustomMapValue(EmptyCompatMapValue(), arena_); } return CustomMapValue(std::move(*this).BuildCompat(), arena_); } const CompatMapValue* absl_nonnull MapValueBuilderImpl::BuildCompat() && { if (map_->empty()) { return
EmptyCompatMapValue(); } CompatMapValueImpl* absl_nonnull impl = ::new (arena_->AllocateAligned( sizeof(CompatMapValueImpl), alignof(CompatMapValueImpl))) CompatMapValueImpl(std::move(*map_)); /* The entries now belong to impl, so the arena must run impl's destructor. Resetting the flag makes ~MapValueBuilderImpl skip destroying the moved-from map. */ if (!entries_trivially_destructible_) { arena_->OwnDestructor(impl); entries_trivially_destructible_ = true; } return impl; } /* Mutable map handed out by NewMutableMapValue. Put registers this object's destructor with the arena the first time an entry that is not trivially destructible is inserted. */ class TrivialMutableMapValueImpl final : public MutableCompatMapValue { public: explicit TrivialMutableMapValueImpl(google::protobuf::Arena* absl_nonnull arena) : map_(arena) {} std::string DebugString() const override { return absl::StrCat("{", absl::StrJoin(map_, ", ", ValueFormatter{}), "}"); } absl::Status ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Message* absl_nonnull json) const override { return MapValueToJsonObject(map_, descriptor_pool, message_factory, json); } CustomMapValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { ABSL_DCHECK(arena != nullptr); MapValueBuilderImpl builder(arena); builder.Reserve(map_.size()); for (const auto& entry : map_) { builder.UnsafePut(entry.first.Clone(arena), entry.second.Clone(arena)); } return std::move(builder).BuildCustom(); } size_t Size() const override { return map_.size(); } absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const override { *result = CustomListValue(ProjectKeys(), map_.get_allocator().arena()); return absl::OkStatus(); } absl::Status ForEach( ForEachCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const override { for (const auto& entry : map_) { CEL_ASSIGN_OR_RETURN(auto ok, callback(entry.first,
entry.second)); if (!ok) { break; } } return absl::OkStatus(); } absl::StatusOr NewIterator() const override { return std::make_unique(&map_); } absl::optional operator[](CelValue key) const override { return Get(map_.get_allocator().arena(), key); } using MutableCompatMapValue::Get; absl::optional Get(google::protobuf::Arena* arena, CelValue key) const override { if (auto status = CelValue::CheckMapKeyType(key); !status.ok()) { status.IgnoreError(); return absl::nullopt; } if (auto it = map_.find(key); it != map_.end()) { return common_internal::UnsafeLegacyValue( it->second, /*stable=*/false, arena != nullptr ? arena : map_.get_allocator().arena()); } return absl::nullopt; } absl::StatusOr Has(const CelValue& key) const override { // This check safeguards against issues with invalid key types such as NaN. CEL_RETURN_IF_ERROR(CelValue::CheckMapKeyType(key)); return map_.find(key) != map_.end(); } int size() const override { return static_cast(Size()); } absl::StatusOr ListKeys() const override { return ProjectKeys(); } absl::StatusOr ListKeys(google::protobuf::Arena* arena) const override { return ProjectKeys(); } absl::Status Put(Value key, Value value) const override { CEL_RETURN_IF_ERROR(CheckMapKey(key)); CEL_RETURN_IF_ERROR(CheckMapValue(value)); if (auto it = map_.find(key); ABSL_PREDICT_FALSE(it != map_.end())) { return DuplicateKeyError().ToStatus(); } auto insertion = map_.insert({std::move(key), std::move(value)}); ABSL_DCHECK(insertion.second); if (entries_trivially_destructible_) { entries_trivially_destructible_ = ArenaTraits<>::trivially_destructible(insertion.first->first) && ArenaTraits<>::trivially_destructible(insertion.first->second); if (!entries_trivially_destructible_) { map_.get_allocator().arena()->OwnDestructor( const_cast(this)); } } return absl::OkStatus(); } void Reserve(size_t capacity) const override { map_.reserve(capacity); } protected: absl::StatusOr Find( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull
descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const override { CEL_RETURN_IF_ERROR(CheckMapKey(key)); if (auto it = map_.find(key); it != map_.end()) { *result = it->second; return true; } return false; } absl::StatusOr Has( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const override { CEL_RETURN_IF_ERROR(CheckMapKey(key)); return map_.find(key) != map_.end(); } private: /* Builds the key list once (absl::call_once) into keys_. NOTE(review): a Put after the first call is not reflected in the cached key list; confirm callers never list keys while still inserting. */ const CompatListValue* absl_nonnull ProjectKeys() const { absl::call_once(keys_once_, [this]() { ListValueBuilderImpl builder(map_.get_allocator().arena()); builder.Reserve(map_.size()); for (const auto& entry : map_) { builder.UnsafeAdd(entry.first); } std::move(builder).BuildCompatAt(&keys_[0]); }); return std::launder( reinterpret_cast(&keys_[0])); } mutable ValueFlatHashMap map_; mutable bool entries_trivially_destructible_ = true; mutable absl::once_flag keys_once_; alignas(CompatListValueImpl) mutable char keys_[sizeof(CompatListValueImpl)]; }; } // namespace /* Copies every entry of a custom map into a compat map built on `arena` (the shared empty map when `value` is empty), validating each entry through Put. */ absl::StatusOr MakeCompatMapValue( const CustomMapValue& value, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { MapValueBuilderImpl builder(arena); builder.Reserve(value.Size()); CEL_RETURN_IF_ERROR(value.ForEach( [&](const Value& key, const Value& value) -> absl::StatusOr { CEL_RETURN_IF_ERROR(builder.Put(key, value)); return true; }, descriptor_pool, message_factory, arena)); return std::move(builder).BuildCompat(); } /* Constructs a new mutable map in memory allocated from `arena` (placement new). */ MutableMapValue* absl_nonnull NewMutableMapValue( google::protobuf::Arena* absl_nonnull arena) { return ::new (arena->AllocateAligned(sizeof(TrivialMutableMapValueImpl), alignof(TrivialMutableMapValueImpl)))
TrivialMutableMapValueImpl(arena); } /* Returns true if `value` is a custom map whose native type id matches either of the two mutable-map type ids tested below. */ bool IsMutableMapValue(const Value& value) { if (auto custom_map_value = value.AsCustomMap(); custom_map_value) { NativeTypeId native_type_id = custom_map_value->GetTypeId(); if (native_type_id == NativeTypeId::For() || native_type_id == NativeTypeId::For()) { return true; } } return false; } /* MapValue overload of IsMutableMapValue; same type-id test. */ bool IsMutableMapValue(const MapValue& value) { if (auto custom_map_value = value.AsCustom(); custom_map_value) { NativeTypeId native_type_id = custom_map_value->GetTypeId(); if (native_type_id == NativeTypeId::For() || native_type_id == NativeTypeId::For()) { return true; } } return false; } /* Returns the mutable map behind `value`, or nullptr when `value` is not a mutable map. */ const MutableMapValue* absl_nullable AsMutableMapValue(const Value& value) { if (auto custom_map_value = value.AsCustomMap(); custom_map_value) { NativeTypeId native_type_id = custom_map_value->GetTypeId(); if (native_type_id == NativeTypeId::For()) { return cel::internal::down_cast( custom_map_value->interface()); } if (native_type_id == NativeTypeId::For()) { return cel::internal::down_cast( custom_map_value->interface()); } } return nullptr; } /* MapValue overload of AsMutableMapValue. */ const MutableMapValue* absl_nullable AsMutableMapValue(const MapValue& value) { if (auto custom_map_value = value.AsCustom(); custom_map_value) { NativeTypeId native_type_id = custom_map_value->GetTypeId(); if (native_type_id == NativeTypeId::For()) { return cel::internal::down_cast( custom_map_value->interface()); } if (native_type_id == NativeTypeId::For()) { return cel::internal::down_cast( custom_map_value->interface()); } } return nullptr; } /* Like AsMutableMapValue, but `value` must be a mutable map (DCHECKed); any other input reaches ABSL_UNREACHABLE. */ const MutableMapValue& GetMutableMapValue(const Value& value) { ABSL_DCHECK(IsMutableMapValue(value)) << value; const auto& custom_map_value = value.GetCustomMap(); NativeTypeId native_type_id = custom_map_value.GetTypeId(); if (native_type_id == NativeTypeId::For()) { return cel::internal::down_cast( *custom_map_value.interface()); } if (native_type_id == NativeTypeId::For()) { return cel::internal::down_cast( *custom_map_value.interface()); } ABSL_UNREACHABLE(); }
/* MapValue overload of GetMutableMapValue: `value` must be a mutable map (DCHECKed); any other input reaches ABSL_UNREACHABLE. */ const MutableMapValue& GetMutableMapValue(const MapValue& value) { ABSL_DCHECK(IsMutableMapValue(value)) << value; const auto& custom_map_value = value.GetCustom(); NativeTypeId native_type_id = custom_map_value.GetTypeId(); if (native_type_id == NativeTypeId::For()) { return cel::internal::down_cast( *custom_map_value.interface()); } if (native_type_id == NativeTypeId::For()) { return cel::internal::down_cast( *custom_map_value.interface()); } ABSL_UNREACHABLE(); } /* Returns a new map builder (owned via std::make_unique) bound to `arena`. */ absl_nonnull cel::MapValueBuilderPtr NewMapValueBuilder( google::protobuf::Arena* absl_nonnull arena) { return std::make_unique(arena); } } // namespace common_internal } // namespace cel ================================================ FILE: common/values/value_builder.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_BUILDER_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_BUILDER_H_ #include "absl/base/nullability.h" #include "absl/strings/string_view.h" #include "common/allocator.h" #include "common/value.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel::common_internal { // Like NewStructValueBuilder, but deals with well known types.
/* The result is absl_nullable. NOTE(review): presumably nullptr when `name` is not a type this factory handles; confirm against the implementation. */ absl_nullable cel::ValueBuilderPtr NewValueBuilder( Allocator<> allocator, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, absl::string_view name); } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_BUILDER_H_ ================================================ FILE: common/values/value_variant.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "common/values/value_variant.h" #include #include #include #include #include "absl/base/optimization.h" #include "absl/log/absl_check.h" #include "common/values/bytes_value.h" #include "common/values/error_value.h" #include "common/values/string_value.h" #include "common/values/unknown_value.h" #include "common/values/values.h" namespace cel::common_internal { void ValueVariant::SlowCopyConstruct(const ValueVariant& other) noexcept { ABSL_DCHECK((flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNonTrivial); switch (index_) { case ValueIndex::kBytes: ::new (static_cast(&raw_[0])) BytesValue(*other.At()); break; case ValueIndex::kString: ::new (static_cast(&raw_[0])) StringValue(*other.At()); break; case ValueIndex::kError: ::new (static_cast(&raw_[0])) ErrorValue(*other.At()); break; case ValueIndex::kUnknown: ::new (static_cast(&raw_[0])) UnknownValue(*other.At()); break; default: ABSL_UNREACHABLE(); } } void ValueVariant::SlowMoveConstruct(ValueVariant& other) noexcept { ABSL_DCHECK((flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNonTrivial); switch (index_) { case ValueIndex::kBytes: ::new (static_cast(&raw_[0])) BytesValue(std::move(*other.At())); break; case ValueIndex::kString: ::new (static_cast(&raw_[0])) StringValue(std::move(*other.At())); break; case ValueIndex::kError: ::new (static_cast(&raw_[0])) ErrorValue(std::move(*other.At())); break; case ValueIndex::kUnknown: ::new (static_cast(&raw_[0])) UnknownValue(std::move(*other.At())); break; default: ABSL_UNREACHABLE(); } } void ValueVariant::SlowDestruct() noexcept { ABSL_DCHECK((flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNonTrivial); switch (index_) { case ValueIndex::kBytes: At()->~BytesValue(); break; case ValueIndex::kString: At()->~StringValue(); break; case ValueIndex::kError: At()->~ErrorValue(); break; case ValueIndex::kUnknown: At()->~UnknownValue(); break; default: ABSL_UNREACHABLE(); } } void ValueVariant::SlowCopyAssign(const ValueVariant& other, bool trivial, bool 
other_trivial) noexcept { ABSL_DCHECK(!trivial || !other_trivial); if (trivial) { switch (other.index_) { case ValueIndex::kBytes: ::new (static_cast(&raw_[0])) BytesValue(*other.At()); break; case ValueIndex::kString: ::new (static_cast(&raw_[0])) StringValue(*other.At()); break; case ValueIndex::kError: ::new (static_cast(&raw_[0])) ErrorValue(*other.At()); break; case ValueIndex::kUnknown: ::new (static_cast(&raw_[0])) UnknownValue(*other.At()); break; default: ABSL_UNREACHABLE(); } index_ = other.index_; kind_ = other.kind_; flags_ = other.flags_; } else if (other_trivial) { switch (index_) { case ValueIndex::kBytes: At()->~BytesValue(); break; case ValueIndex::kString: At()->~StringValue(); break; case ValueIndex::kError: At()->~ErrorValue(); break; case ValueIndex::kUnknown: At()->~UnknownValue(); break; default: ABSL_UNREACHABLE(); } FastCopyAssign(other); } else { switch (index_) { case ValueIndex::kBytes: switch (other.index_) { case ValueIndex::kBytes: *At() = *other.At(); break; case ValueIndex::kString: At()->~BytesValue(); ::new (static_cast(&raw_[0])) StringValue(*other.At()); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kError: At()->~BytesValue(); ::new (static_cast(&raw_[0])) ErrorValue(*other.At()); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kUnknown: At()->~BytesValue(); ::new (static_cast(&raw_[0])) UnknownValue(*other.At()); index_ = other.index_; kind_ = other.kind_; break; default: ABSL_UNREACHABLE(); } break; case ValueIndex::kString: switch (other.index_) { case ValueIndex::kBytes: At()->~StringValue(); ::new (static_cast(&raw_[0])) BytesValue(*other.At()); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kString: *At() = *other.At(); break; case ValueIndex::kError: At()->~StringValue(); ::new (static_cast(&raw_[0])) ErrorValue(*other.At()); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kUnknown: At()->~StringValue(); ::new (static_cast(&raw_[0])) 
UnknownValue(*other.At()); index_ = other.index_; kind_ = other.kind_; break; default: ABSL_UNREACHABLE(); } break; case ValueIndex::kError: switch (other.index_) { case ValueIndex::kBytes: At()->~ErrorValue(); ::new (static_cast(&raw_[0])) BytesValue(*other.At()); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kString: At()->~ErrorValue(); ::new (static_cast(&raw_[0])) StringValue(*other.At()); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kError: *At() = *other.At(); break; case ValueIndex::kUnknown: At()->~ErrorValue(); ::new (static_cast(&raw_[0])) UnknownValue(*other.At()); index_ = other.index_; kind_ = other.kind_; break; default: ABSL_UNREACHABLE(); } break; case ValueIndex::kUnknown: switch (other.index_) { case ValueIndex::kBytes: At()->~UnknownValue(); ::new (static_cast(&raw_[0])) BytesValue(*other.At()); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kString: At()->~UnknownValue(); ::new (static_cast(&raw_[0])) StringValue(*other.At()); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kError: At()->~UnknownValue(); ::new (static_cast(&raw_[0])) ErrorValue(*other.At()); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kUnknown: At()->~UnknownValue(); ::new (static_cast(&raw_[0])) UnknownValue(*other.At()); index_ = other.index_; kind_ = other.kind_; break; default: ABSL_UNREACHABLE(); } break; default: ABSL_UNREACHABLE(); } flags_ = other.flags_; } } void ValueVariant::SlowMoveAssign(ValueVariant& other, bool trivial, bool other_trivial) noexcept { ABSL_DCHECK(!trivial || !other_trivial); if (trivial) { switch (other.index_) { case ValueIndex::kBytes: ::new (static_cast(&raw_[0])) BytesValue(std::move(*other.At())); break; case ValueIndex::kString: ::new (static_cast(&raw_[0])) StringValue(std::move(*other.At())); break; case ValueIndex::kError: ::new (static_cast(&raw_[0])) ErrorValue(std::move(*other.At())); break; case ValueIndex::kUnknown: ::new 
(static_cast(&raw_[0])) UnknownValue(std::move(*other.At())); break; default: ABSL_UNREACHABLE(); } index_ = other.index_; kind_ = other.kind_; flags_ = other.flags_; } else if (other_trivial) { switch (index_) { case ValueIndex::kBytes: At()->~BytesValue(); break; case ValueIndex::kString: At()->~StringValue(); break; case ValueIndex::kError: At()->~ErrorValue(); break; case ValueIndex::kUnknown: At()->~UnknownValue(); break; default: ABSL_UNREACHABLE(); } FastMoveAssign(other); } else { switch (index_) { case ValueIndex::kBytes: switch (other.index_) { case ValueIndex::kBytes: *At() = std::move(*other.At()); break; case ValueIndex::kString: At()->~BytesValue(); ::new (static_cast(&raw_[0])) StringValue(std::move(*other.At())); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kError: At()->~BytesValue(); ::new (static_cast(&raw_[0])) ErrorValue(std::move(*other.At())); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kUnknown: At()->~BytesValue(); ::new (static_cast(&raw_[0])) UnknownValue(std::move(*other.At())); index_ = other.index_; kind_ = other.kind_; break; default: ABSL_UNREACHABLE(); } break; case ValueIndex::kString: switch (other.index_) { case ValueIndex::kBytes: At()->~StringValue(); ::new (static_cast(&raw_[0])) BytesValue(std::move(*other.At())); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kString: *At() = std::move(*other.At()); break; case ValueIndex::kError: At()->~StringValue(); ::new (static_cast(&raw_[0])) ErrorValue(std::move(*other.At())); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kUnknown: At()->~StringValue(); ::new (static_cast(&raw_[0])) UnknownValue(std::move(*other.At())); index_ = other.index_; kind_ = other.kind_; break; default: ABSL_UNREACHABLE(); } break; case ValueIndex::kError: switch (other.index_) { case ValueIndex::kBytes: At()->~ErrorValue(); ::new (static_cast(&raw_[0])) BytesValue(std::move(*other.At())); index_ = other.index_; kind_ = 
other.kind_; break; case ValueIndex::kString: At()->~ErrorValue(); ::new (static_cast(&raw_[0])) StringValue(std::move(*other.At())); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kError: *At() = std::move(*other.At()); break; case ValueIndex::kUnknown: At()->~ErrorValue(); ::new (static_cast(&raw_[0])) UnknownValue(std::move(*other.At())); index_ = other.index_; kind_ = other.kind_; break; default: ABSL_UNREACHABLE(); } break; case ValueIndex::kUnknown: switch (other.index_) { case ValueIndex::kBytes: At()->~UnknownValue(); ::new (static_cast(&raw_[0])) BytesValue(std::move(*other.At())); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kString: At()->~UnknownValue(); ::new (static_cast(&raw_[0])) StringValue(std::move(*other.At())); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kError: At()->~UnknownValue(); ::new (static_cast(&raw_[0])) ErrorValue(std::move(*other.At())); index_ = other.index_; kind_ = other.kind_; break; case ValueIndex::kUnknown: *At() = std::move(*other.At()); break; default: ABSL_UNREACHABLE(); } break; default: ABSL_UNREACHABLE(); } flags_ = other.flags_; } } void ValueVariant::SlowSwap(ValueVariant& lhs, ValueVariant& rhs, bool lhs_trivial, bool rhs_trivial) noexcept { using std::swap; ABSL_DCHECK(!lhs_trivial || !rhs_trivial); if (lhs_trivial) { alignas(ValueVariant) std::byte tmp[sizeof(ValueVariant)]; // This is acceptable. We know that both are trivially copyable at runtime. 
// NOLINTNEXTLINE(bugprone-undefined-memory-manipulation) std::memcpy(tmp, std::addressof(lhs), sizeof(ValueVariant)); switch (rhs.index_) { case ValueIndex::kBytes: ::new (static_cast(&lhs.raw_[0])) BytesValue(*rhs.At()); rhs.At()->~BytesValue(); break; case ValueIndex::kString: ::new (static_cast(&lhs.raw_[0])) StringValue(*rhs.At()); rhs.At()->~StringValue(); break; case ValueIndex::kError: ::new (static_cast(&lhs.raw_[0])) ErrorValue(*rhs.At()); rhs.At()->~ErrorValue(); break; case ValueIndex::kUnknown: ::new (static_cast(&lhs.raw_[0])) UnknownValue(*rhs.At()); rhs.At()->~UnknownValue(); break; default: ABSL_UNREACHABLE(); } lhs.index_ = rhs.index_; lhs.kind_ = rhs.kind_; lhs.flags_ = rhs.flags_; // This is acceptable. We know that both are trivially copyable at runtime. // NOLINTNEXTLINE(bugprone-undefined-memory-manipulation) std::memcpy(std::addressof(rhs), tmp, sizeof(ValueVariant)); } else if (rhs_trivial) { alignas(ValueVariant) std::byte tmp[sizeof(ValueVariant)]; // This is acceptable. We know that both are trivially copyable at runtime. // NOLINTNEXTLINE(bugprone-undefined-memory-manipulation) std::memcpy(tmp, std::addressof(rhs), sizeof(ValueVariant)); switch (lhs.index_) { case ValueIndex::kBytes: ::new (static_cast(&rhs.raw_[0])) BytesValue(*lhs.At()); lhs.At()->~BytesValue(); break; case ValueIndex::kString: ::new (static_cast(&rhs.raw_[0])) StringValue(*lhs.At()); lhs.At()->~StringValue(); break; case ValueIndex::kError: ::new (static_cast(&rhs.raw_[0])) ErrorValue(*lhs.At()); lhs.At()->~ErrorValue(); break; case ValueIndex::kUnknown: ::new (static_cast(&rhs.raw_[0])) UnknownValue(*lhs.At()); lhs.At()->~UnknownValue(); break; default: ABSL_UNREACHABLE(); } rhs.index_ = lhs.index_; rhs.kind_ = lhs.kind_; rhs.flags_ = lhs.flags_; // This is acceptable. We know that both are trivially copyable at runtime. 
// NOLINTNEXTLINE(bugprone-undefined-memory-manipulation) std::memcpy(std::addressof(lhs), tmp, sizeof(ValueVariant)); } else { ValueVariant tmp = std::move(lhs); lhs = std::move(rhs); rhs = std::move(tmp); } } } // namespace cel::common_internal ================================================ FILE: common/values/value_variant.h ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_VARIANT_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_VARIANT_H_ #include #include #include #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/meta/type_traits.h" #include "absl/utility/utility.h" #include "common/arena.h" #include "common/value_kind.h" #include "common/values/bool_value.h" #include "common/values/bytes_value.h" #include "common/values/custom_list_value.h" #include "common/values/custom_map_value.h" #include "common/values/custom_struct_value.h" #include "common/values/double_value.h" #include "common/values/duration_value.h" #include "common/values/error_value.h" #include "common/values/int_value.h" #include "common/values/legacy_list_value.h" #include "common/values/legacy_map_value.h" #include "common/values/legacy_struct_value.h" #include "common/values/list_value.h" #include "common/values/map_value.h" #include "common/values/null_value.h" 
#include "common/values/opaque_value.h"
#include "common/values/parsed_json_list_value.h"
#include "common/values/parsed_json_map_value.h"
#include "common/values/parsed_map_field_value.h"
#include "common/values/parsed_message_value.h"
#include "common/values/parsed_repeated_field_value.h"
#include "common/values/string_value.h"
#include "common/values/timestamp_value.h"
#include "common/values/type_value.h"
#include "common/values/uint_value.h"
#include "common/values/unknown_value.h"
#include "common/values/values.h"

namespace cel {

class Value;

namespace common_internal {

// Used by ValueVariant to indicate the active alternative.
//
// Each enumerator corresponds to exactly one alternative type (see the
// ValueAlternative specializations); ValueVariant stores the active one in its
// `index_` member and switches on it to reach the live object.
enum class ValueIndex : uint8_t {
  kNull = 0,
  kBool,
  kInt,
  kUint,
  kDouble,
  kDuration,
  kTimestamp,
  kType,
  kLegacyList,
  kParsedJsonList,
  kParsedRepeatedField,
  kCustomList,
  kLegacyMap,
  kParsedJsonMap,
  kParsedMapField,
  kCustomMap,
  kLegacyStruct,
  kParsedMessage,
  kCustomStruct,
  kOpaque,
  // Keep non-trivial alternatives together to aid in compiling optimizations.
  // Every alternative above this point is always trivially copyable; the four
  // below may require ValueVariant's out-of-line slow paths.
  kBytes,
  kString,
  kError,
  kUnknown,
};

// Used by ValueVariant to indicate pre-computed behaviors.
//
// Treated as a bit set: individual bits are tested with `operator&` below,
// e.g. `(flags & ValueFlags::kNonTrivial) == ValueFlags::kNone`.
enum class ValueFlags : uint32_t {
  kNone = 0,
  // Set when the active alternative cannot be copied, moved or destroyed with
  // raw byte operations, which forces ValueVariant onto its slow paths.
  kNonTrivial = 1,
};

// Bitwise AND of two ValueFlags, computed on their underlying integer
// representation and converted back to ValueFlags.
ABSL_ATTRIBUTE_ALWAYS_INLINE inline constexpr ValueFlags operator&(
    ValueFlags lhs, ValueFlags rhs) {
  return static_cast(
      static_cast>(lhs) &
      static_cast>(rhs));
}

// Traits specialized by each alternative.
//
//   ValueIndex ValueAlternative::kIndex
//
//     Indicates the alternative index corresponding to T.
//
//   ValueKind ValueAlternative::kKind
//
//     Indicates the kind corresponding to T.
//
//   bool ValueAlternative::kAlwaysTrivial
//
//     True if T is always trivially copyable, false otherwise.
//
//   ValueFlags ValueAlternative::Flags(const T* absl_nonnull )
//
//     Returns the flags for the corresponding instance of T.
// Primary template. It is only declared, never defined; every type that
// ValueVariant can hold supplies an explicit specialization below.
template struct ValueAlternative;

// ---- Alternatives that are always trivially copyable ----
// Each specialization in this group sets `kAlwaysTrivial = true` and its
// `Flags()` unconditionally returns ValueFlags::kNone. ValueVariant therefore
// copies and moves these with raw byte copies and never runs their
// destructors.

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kNull;
  static constexpr ValueKind kKind = NullValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const NullValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kBool;
  static constexpr ValueKind kKind = BoolValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const BoolValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kInt;
  static constexpr ValueKind kKind = IntValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const IntValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kUint;
  static constexpr ValueKind kKind = UintValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const UintValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kDouble;
  static constexpr ValueKind kKind = DoubleValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const DoubleValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kDuration;
  static constexpr ValueKind kKind = DurationValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const DurationValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kTimestamp;
  static constexpr ValueKind kKind = TimestampValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const TimestampValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kType;
  static constexpr ValueKind kKind = TypeValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const TypeValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kLegacyList;
  static constexpr ValueKind kKind = LegacyListValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const LegacyListValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kParsedJsonList;
  static constexpr ValueKind kKind = ParsedJsonListValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const ParsedJsonListValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kParsedRepeatedField;
  static constexpr ValueKind kKind = ParsedRepeatedFieldValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(
      const ParsedRepeatedFieldValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kCustomList;
  static constexpr ValueKind kKind = CustomListValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const CustomListValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kLegacyMap;
  static constexpr ValueKind kKind = LegacyMapValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const LegacyMapValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kParsedJsonMap;
  static constexpr ValueKind kKind = ParsedJsonMapValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const ParsedJsonMapValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kParsedMapField;
  static constexpr ValueKind kKind = ParsedMapFieldValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const ParsedMapFieldValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kCustomMap;
  static constexpr ValueKind kKind = CustomMapValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const CustomMapValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kLegacyStruct;
  static constexpr ValueKind kKind = LegacyStructValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const LegacyStructValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kParsedMessage;
  static constexpr ValueKind kKind = ParsedMessageValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const ParsedMessageValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kCustomStruct;
  static constexpr ValueKind kKind = CustomStructValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const CustomStructValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kOpaque;
  static constexpr ValueKind kKind = OpaqueValue::kKind;
  static constexpr bool kAlwaysTrivial = true;
  static constexpr ValueFlags Flags(const OpaqueValue* absl_nonnull) {
    return ValueFlags::kNone;
  }
};

// ---- Alternatives that may need the slow paths ----
// BytesValue, StringValue and ErrorValue are decided per instance: `Flags()`
// reports kNonTrivial only when `ArenaTraits::trivially_destructible` is false
// for the given object. Because the answer depends on the instance, `Flags()`
// is not constexpr for these three.
// NOTE(review): presumably instances owned by an arena report trivially
// destructible -- confirm against ArenaTraits.

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kBytes;
  static constexpr ValueKind kKind = BytesValue::kKind;
  static constexpr bool kAlwaysTrivial = false;
  static ValueFlags Flags(const BytesValue* absl_nonnull alternative) {
    return ArenaTraits::trivially_destructible(*alternative)
               ? ValueFlags::kNone
               : ValueFlags::kNonTrivial;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kString;
  static constexpr ValueKind kKind = StringValue::kKind;
  static constexpr bool kAlwaysTrivial = false;
  static ValueFlags Flags(const StringValue* absl_nonnull alternative) {
    return ArenaTraits::trivially_destructible(*alternative)
               ? ValueFlags::kNone
               : ValueFlags::kNonTrivial;
  }
};

template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kError;
  static constexpr ValueKind kKind = ErrorValue::kKind;
  static constexpr bool kAlwaysTrivial = false;
  static ValueFlags Flags(const ErrorValue* absl_nonnull alternative) {
    return ArenaTraits::trivially_destructible(*alternative)
               ? ValueFlags::kNone
               : ValueFlags::kNonTrivial;
  }
};

// UnknownValue always reports kNonTrivial, regardless of the instance.
template <>
struct ValueAlternative {
  static constexpr ValueIndex kIndex = ValueIndex::kUnknown;
  static constexpr ValueKind kKind = UnknownValue::kKind;
  static constexpr bool kAlwaysTrivial = false;
  static constexpr ValueFlags Flags(const UnknownValue* absl_nonnull) {
    return ValueFlags::kNonTrivial;
  }
};

// Type trait: std::true_type when T is one of the ValueVariant alternatives,
// std::false_type otherwise.
// NOTE(review): presumably detected via SFINAE on the ValueAlternative
// specialization -- confirm.
template
struct IsValueAlternative : std::false_type {};

template
struct IsValueAlternative{})>> : std::true_type {};

// Convenience variable template for IsValueAlternative.
template
inline constexpr bool IsValueAlternativeV = IsValueAlternative::value;

// Alignment and size of the storage inside ValueVariant, not for ValueVariant
// itself.
// Every alternative must satisfy `alignof(T) <= kValueVariantAlign`; the
// static_asserts in the in-place constructor, `Assign` and `At` enforce this.
inline constexpr size_t kValueVariantAlign = 8;
// Every alternative must satisfy `sizeof(T) <= kValueVariantSize` (bytes).
inline constexpr size_t kValueVariantSize = 24;

// Hand-rolled variant used by cel::Value which exhibits up to a 25% performance
// improvement compared to using std::variant.
//
// The implementation abuses the fact that most alternatives are trivially
// copyable and some are conditionally trivially copyable at runtime. For the
// fast path, we perform raw byte copying. For the slow path, we fall back to a
// non-inlined function. The compiler is typically smart enough to inline the
// fast path and emit efficient instructions for the raw byte copying (usually
// two instructions). It also uses switch for visiting, which most compilers can
// optimize better compared to a function pointer table (which libc++ currently
// uses and Clang currently does not optimize well).
//
// Representation: `index_` names the live alternative, `kind_` caches its
// ValueKind, `flags_` records whether it needs the slow paths (kNonTrivial),
// and the object itself lives in `raw_`. All fast-path decisions below are
// made by testing `flags_` for kNonTrivial.
class alignas(kValueVariantAlign) CEL_COMMON_INTERNAL_VALUE_VARIANT_TRIVIAL_ABI
    ValueVariant final {
 public:
  // The default state is the null alternative, established by the member
  // initializers of `index_`, `kind_` and `flags_`; `raw_` is left
  // uninitialized.
  ValueVariant() = default;

  // Copies `other`. Trivial alternatives are duplicated with a byte copy of
  // the storage; an alternative flagged kNonTrivial goes through
  // SlowCopyConstruct.
  ValueVariant(const ValueVariant& other) noexcept
      : index_(other.index_), kind_(other.kind_), flags_(other.flags_) {
    if ((flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNone) {
      std::memcpy(raw_, other.raw_, sizeof(raw_));
    } else {
      SlowCopyConstruct(other);
    }
  }

  // Moves from `other`. On the fast path the storage is byte-copied and
  // `other` is left untouched; otherwise SlowMoveConstruct does the work.
  ValueVariant(ValueVariant&& other) noexcept
      : index_(other.index_), kind_(other.kind_), flags_(other.flags_) {
    if ((flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNone) {
      std::memcpy(raw_, other.raw_, sizeof(raw_));
    } else {
      SlowMoveConstruct(other);
    }
  }

  // Only an alternative flagged kNonTrivial needs its destructor run.
  ~ValueVariant() {
    if ((flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNonTrivial) {
      SlowDestruct();
    }
  }

  // Copy assignment. When both sides are trivial this is a plain byte copy;
  // every other combination is delegated to SlowCopyAssign, which is told
  // which side(s) are trivial. Self-assignment is a no-op.
  ValueVariant& operator=(const ValueVariant& other) noexcept {
    if (this != &other) {
      const bool trivial =
          (flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNone;
      const bool other_trivial =
          (other.flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNone;
      if (trivial && other_trivial) {
        FastCopyAssign(other);
      } else {
        SlowCopyAssign(other, trivial, other_trivial);
      }
    }
    return *this;
  }

  // Move assignment; same structure as copy assignment, using FastMoveAssign /
  // SlowMoveAssign. Self-assignment is a no-op.
  ValueVariant& operator=(ValueVariant&& other) noexcept {
    if (this != &other) {
      const bool trivial =
          (flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNone;
      const bool other_trivial =
          (other.flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNone;
      if (trivial && other_trivial) {
        FastMoveAssign(other);
      } else {
        SlowMoveAssign(other, trivial, other_trivial);
      }
    }
    return *this;
  }

  // Constructs the alternative `T` in place in `raw_` from `args`, then
  // derives `flags_` from the freshly constructed object via
  // ValueAlternative::Flags.
  template
  explicit ValueVariant(absl::in_place_type_t, Args&&... args)
      : index_(ValueAlternative::kIndex), kind_(ValueAlternative::kKind) {
    static_assert(alignof(T) <= kValueVariantAlign);
    static_assert(sizeof(T) <= kValueVariantSize);
    flags_ = ValueAlternative::Flags(::new (static_cast(&raw_[0]))
                                         T(std::forward(args)...));
  }

  // Converting constructor: delegates to the in-place constructor above,
  // forwarding `value`.
  template >>>
  explicit ValueVariant(T&& value)
      : ValueVariant(absl::in_place_type>,
                     std::forward(value)) {}

  // Returns the cached ValueKind of the live alternative.
  ValueKind kind() const { return kind_; }

  // Replaces the contents with `value`.
  template
  void Assign(T&& value) {
    using U = absl::remove_cvref_t;
    static_assert(alignof(U) <= kValueVariantAlign);
    static_assert(sizeof(U) <= kValueVariantSize);
    if constexpr (ValueAlternative::kAlwaysTrivial) {
      // The incoming alternative never needs cleanup, so only the currently
      // held one may need destroying before the storage is overwritten.
      if ((flags_ & ValueFlags::kNonTrivial) != ValueFlags::kNone) {
        SlowDestruct();
      }
      index_ = ValueAlternative::kIndex;
      kind_ = ValueAlternative::kKind;
      flags_ = ValueAlternative::Flags(::new (static_cast(&raw_[0]))
                                           U(std::forward(value)));
    } else {
      // U is not always trivial. See if the current active alternative is U. If
      // it is, we can just do a simple assignment without having to destruct
      // first. Otherwise fallback to destruct and construct.
      if (index_ == ValueAlternative::kIndex) {
        *At() = std::forward(value);
        // Flags are instance-dependent for these alternatives, so they are
        // recomputed after the assignment.
        flags_ = ValueAlternative::Flags(At());
      } else {
        if ((flags_ & ValueFlags::kNonTrivial) != ValueFlags::kNone) {
          SlowDestruct();
        }
        index_ = ValueAlternative::kIndex;
        kind_ = ValueAlternative::kKind;
        flags_ = ValueAlternative::Flags(::new (static_cast(&raw_[0]))
                                             U(std::forward(value)));
      }
    }
  }

  // True when the live alternative is T (compares `index_` only).
  template
  bool Is() const {
    return index_ == ValueAlternative::kIndex;
  }

  // Accessors for the live alternative. The type is checked only by
  // ABSL_DCHECK; in builds where DCHECKs are disabled, asking for the wrong
  // alternative is not detected.
  template
  T& Get() & ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(Is());
    return *At();
  }

  template
  const T& Get() const& ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(Is());
    return *At();
  }

  template
  T&& Get() && ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(Is());
    return std::move(*At());
  }

  template
  const T&& Get() const&& ABSL_ATTRIBUTE_LIFETIME_BOUND {
    ABSL_DCHECK(Is());
    return std::move(*At());
  }

  // Checked accessors: a pointer to the live alternative when it is T,
  // nullptr otherwise.
  template
  T* absl_nullable As() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    if (Is()) {
      return At();
    }
    return nullptr;
  }

  template
  const T* absl_nullable As() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    if (Is()) {
      return At();
    }
    return nullptr;
  }

  // Non-const lvalue visit forwards to the const& overload through
  // std::as_const, so the visitor always receives a const reference here.
  template
  ABSL_ATTRIBUTE_ALWAYS_INLINE decltype(auto) Visit(Visitor&& visitor) & {
    return std::as_const(*this).Visit(std::forward(visitor));
  }

  // Invokes `visitor` with a const reference to the live alternative and
  // returns its result.
  // NOTE(review): the switch has no default and nothing follows it, so an
  // out-of-range `index_` would fall off the end of this non-void function;
  // the out-of-line slow paths use ABSL_UNREACHABLE() for that situation.
  template
  decltype(auto) Visit(Visitor&& visitor) const& {
    switch (index_) {
      case ValueIndex::kNull:
        return std::forward(visitor)(Get());
      case ValueIndex::kBool:
        return std::forward(visitor)(Get());
      case ValueIndex::kInt:
        return std::forward(visitor)(Get());
      case ValueIndex::kUint:
        return std::forward(visitor)(Get());
      case ValueIndex::kDouble:
        return std::forward(visitor)(Get());
      case ValueIndex::kDuration:
        return std::forward(visitor)(Get());
      case ValueIndex::kTimestamp:
        return std::forward(visitor)(Get());
      case ValueIndex::kType:
        return std::forward(visitor)(Get());
      case ValueIndex::kLegacyList:
        return std::forward(visitor)(Get());
      case ValueIndex::kParsedJsonList:
        return std::forward(visitor)(Get());
      case ValueIndex::kParsedRepeatedField:
        return std::forward(visitor)(Get());
      case ValueIndex::kCustomList:
        return std::forward(visitor)(Get());
      case ValueIndex::kLegacyMap:
        return std::forward(visitor)(Get());
      case ValueIndex::kParsedJsonMap:
        return std::forward(visitor)(Get());
      case ValueIndex::kParsedMapField:
        return std::forward(visitor)(Get());
      case ValueIndex::kCustomMap:
        return std::forward(visitor)(Get());
      case ValueIndex::kLegacyStruct:
        return std::forward(visitor)(Get());
      case ValueIndex::kParsedMessage:
        return std::forward(visitor)(Get());
      case ValueIndex::kCustomStruct:
        return std::forward(visitor)(Get());
      case ValueIndex::kOpaque:
        return std::forward(visitor)(Get());
      case ValueIndex::kBytes:
        return std::forward(visitor)(Get());
      case ValueIndex::kString:
        return std::forward(visitor)(Get());
      case ValueIndex::kError:
        return std::forward(visitor)(Get());
      case ValueIndex::kUnknown:
        return std::forward(visitor)(Get());
    }
  }

  // Invokes `visitor` with an rvalue of the live alternative (obtained via
  // `std::move(*this).Get()`) and returns its result. Same switch structure
  // as the const& overload.
  template
  decltype(auto) Visit(Visitor&& visitor) && {
    switch (index_) {
      case ValueIndex::kNull:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kBool:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kInt:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kUint:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kDouble:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kDuration:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kTimestamp:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kType:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kLegacyList:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kParsedJsonList:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kParsedRepeatedField:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kCustomList:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kLegacyMap:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kParsedJsonMap:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kParsedMapField:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kCustomMap:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kLegacyStruct:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kParsedMessage:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kCustomStruct:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kOpaque:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kBytes:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kString:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kError:
        return std::forward(visitor)(std::move(*this).Get());
      case ValueIndex::kUnknown:
        return std::forward(visitor)(std::move(*this).Get());
    }
  }

  // Const rvalue visit: inside this member `*this` is an lvalue, so the
  // unqualified call resolves to the const& overload.
  template
  ABSL_ATTRIBUTE_ALWAYS_INLINE decltype(auto) Visit(Visitor&& visitor) const&& {
    return Visit(std::forward(visitor));
  }

  // Swaps two variants. When both are trivial, the whole objects (including
  // index_/kind_/flags_) are exchanged with three byte copies through a stack
  // buffer; otherwise SlowSwap handles it. Swapping with itself is a no-op.
  friend void swap(ValueVariant& lhs, ValueVariant& rhs) noexcept {
    if (&lhs != &rhs) {
      const bool lhs_trivial =
          (lhs.flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNone;
      const bool rhs_trivial =
          (rhs.flags_ & ValueFlags::kNonTrivial) == ValueFlags::kNone;
      if (lhs_trivial && rhs_trivial) {
        // We validated the instances can be copied byte-wise at runtime, but
        // compilers warn since this is not safe in the general case.
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wclass-memaccess"
#elif defined(__clang__) && __clang_major__ >= 20
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wnontrivial-memcall"
#endif
        alignas(ValueVariant) std::byte tmp[sizeof(ValueVariant)];
        // NOLINTNEXTLINE(bugprone-undefined-memory-manipulation)
        std::memcpy(tmp, std::addressof(lhs), sizeof(ValueVariant));
        // NOLINTNEXTLINE(bugprone-undefined-memory-manipulation)
        std::memcpy(std::addressof(lhs), std::addressof(rhs),
                    sizeof(ValueVariant));
        // NOLINTNEXTLINE(bugprone-undefined-memory-manipulation)
        std::memcpy(std::addressof(rhs), tmp, sizeof(ValueVariant));
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#elif defined(__clang__) && __clang_major__ >= 20
#pragma clang diagnostic pop
#endif
      } else {
        SlowSwap(lhs, rhs, lhs_trivial, rhs_trivial);
      }
    }
  }

 private:
  // Lets the ArenaTraits specialization after this class read `flags_`.
  friend struct cel::ArenaTraits;

  // Views `raw_` as the alternative T. The caller must already know that T is
  // the live alternative; std::launder is applied because the object was
  // created in the byte buffer with placement new.
  template
  ABSL_ATTRIBUTE_ALWAYS_INLINE T* absl_nonnull At()
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    static_assert(alignof(T) <= kValueVariantAlign);
    static_assert(sizeof(T) <= kValueVariantSize);
    return std::launder(reinterpret_cast(&raw_[0]));
  }

  template
  ABSL_ATTRIBUTE_ALWAYS_INLINE const T* absl_nonnull At() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    static_assert(alignof(T) <= kValueVariantAlign);
    static_assert(sizeof(T) <= kValueVariantSize);
    return std::launder(reinterpret_cast(&raw_[0]));
  }

  // Overwrites every member with `other`'s, byte-copying the storage. Only
  // called when both sides are trivial.
  ABSL_ATTRIBUTE_ALWAYS_INLINE void FastCopyAssign(
      const ValueVariant& other) noexcept {
    index_ = other.index_;
    kind_ = other.kind_;
    flags_ = other.flags_;
    std::memcpy(raw_, other.raw_, sizeof(raw_));
  }

  // Same as FastCopyAssign: `other` is left as-is, which is acceptable because
  // trivial alternatives need no cleanup.
  ABSL_ATTRIBUTE_ALWAYS_INLINE void FastMoveAssign(
      ValueVariant& other) noexcept {
    FastCopyAssign(other);
  }

  // Out-of-line slow paths, used whenever at least one side is flagged
  // kNonTrivial. The `*_trivial` arguments tell them which side(s) can still be
  // handled byte-wise.
  void SlowCopyConstruct(const ValueVariant& other) noexcept;

  void SlowMoveConstruct(ValueVariant& other) noexcept;

  void SlowDestruct() noexcept;

  void SlowCopyAssign(const ValueVariant& other, bool trivial,
                      bool other_trivial) noexcept;

  // NOTE(review): the second parameter is spelled `ntrivial` here but
  // `trivial` in SlowCopyAssign and at the call site in
  // operator=(ValueVariant&&); looks like a typo (declaration only, so it has
  // no effect on behavior).
  void SlowMoveAssign(ValueVariant& other, bool ntrivial,
                      bool other_trivial) noexcept;

  static void SlowSwap(ValueVariant& lhs, ValueVariant& rhs, bool lhs_trivial,
                       bool rhs_trivial) noexcept;

  // Which alternative currently lives in `raw_`.
  ValueIndex index_ = ValueIndex::kNull;
  // Cached ValueKind of the live alternative; returned by kind().
  ValueKind kind_ = ValueKind::kNull;
  // kNonTrivial when the live alternative requires the slow paths.
  ValueFlags flags_ = ValueFlags::kNone;
  // Raw, suitably aligned storage for the live alternative.
  alignas(kValueVariantAlign) std::byte raw_[kValueVariantSize];
};

}  // namespace common_internal

// Reports a ValueVariant as trivially destructible exactly when its
// kNonTrivial flag is clear.
template <>
struct ArenaTraits {
  static bool trivially_destructible(
      const common_internal::ValueVariant& value) {
    return (value.flags_ & common_internal::ValueFlags::kNonTrivial) ==
           common_internal::ValueFlags::kNone;
  }
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_VARIANT_H_
================================================ FILE: common/values/value_variant_test.cc ================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include
#include "absl/strings/cord.h"
#include "common/value.h"
#include "internal/testing.h"

namespace cel::common_internal {
namespace {

// Typed test fixture. Each TypeParam is a pair type; the tests below read its
// `first_type` as `Left` and its `second_type` as `Right`.
template
class ValueVariantTest : public ::testing::Test {};

// Expands to 24 pair types (one per ValueVariant alternative) for the given
// `T`, i.e. one row of the type cross product.
// NOTE(review): the entries are presumably std::pair<T, X> for each
// alternative X in ValueIndex order -- confirm.
#define VALUE_VARIANT_TYPES(T)         \
  std::pair, std::pair, std::pair,     \
  std::pair, std::pair,                \
  std::pair, std::pair,                \
  std::pair, std::pair,                \
  std::pair,                           \
  std::pair, std::pair,                \
  std::pair, std::pair,                \
  std::pair, std::pair,                \
  std::pair, std::pair,                \
  std::pair, std::pair,                \
  std::pair, std::pair,                \
  std::pair, std::pair

// One macro row per alternative: 24 rows x 24 pairs, presumably covering every
// (Left, Right) combination of alternatives.
using ValueVariantTypes = ::testing::Types<
    VALUE_VARIANT_TYPES(NullValue), VALUE_VARIANT_TYPES(BoolValue),
    VALUE_VARIANT_TYPES(IntValue), VALUE_VARIANT_TYPES(UintValue),
    VALUE_VARIANT_TYPES(DoubleValue), VALUE_VARIANT_TYPES(DurationValue),
    VALUE_VARIANT_TYPES(TimestampValue), VALUE_VARIANT_TYPES(TypeValue),
    VALUE_VARIANT_TYPES(LegacyListValue),
    VALUE_VARIANT_TYPES(ParsedJsonListValue),
    VALUE_VARIANT_TYPES(ParsedRepeatedFieldValue),
    VALUE_VARIANT_TYPES(CustomListValue), VALUE_VARIANT_TYPES(LegacyMapValue),
    VALUE_VARIANT_TYPES(ParsedJsonMapValue),
    VALUE_VARIANT_TYPES(ParsedMapFieldValue),
    VALUE_VARIANT_TYPES(CustomMapValue), VALUE_VARIANT_TYPES(LegacyStructValue),
    VALUE_VARIANT_TYPES(ParsedMessageValue),
    VALUE_VARIANT_TYPES(CustomStructValue), VALUE_VARIANT_TYPES(OpaqueValue),
    VALUE_VARIANT_TYPES(BytesValue), VALUE_VARIANT_TYPES(StringValue),
    VALUE_VARIANT_TYPES(ErrorValue), VALUE_VARIANT_TYPES(UnknownValue)>;

// Produces the value used to seed a ValueVariant of alternative T; by default
// a value-initialized T.
template
struct DefaultValue {
  T operator()() const { return T(); }
};

// The bytes and string specializations use a long Cord-backed string that,
// per the literal's own wording, cannot be stored inline -- presumably so that
// the variant holds a non-trivial instance and the slow paths get exercised.
template <>
struct DefaultValue {
  BytesValue operator()() const {
    return BytesValue(
        absl::Cord("Some somewhat large string that is not storable inline!"));
  }
};

template <>
struct DefaultValue {
  StringValue operator()() const {
    return StringValue(
        absl::Cord("Some somewhat large string that is not storable inline!"));
  }
};

#undef VALUE_VARIANT_TYPES

TYPED_TEST_SUITE(ValueVariantTest, ValueVariantTypes);

// In each test, `lhs` is presumably seeded with a Left value and `rhs` with a
// Right value; the checks verify which alternative each variant holds.

// Copy assignment: afterwards lhs holds rhs's alternative and rhs is
// unchanged.
TYPED_TEST(ValueVariantTest, CopyAssign) {
  using Left = typename TypeParam::first_type;
  using Right = typename TypeParam::second_type;
  ValueVariant lhs(DefaultValue{}());
  ValueVariant rhs(DefaultValue{}());
  EXPECT_TRUE(lhs.Is());
  lhs = rhs;
  EXPECT_TRUE(lhs.Is());
  EXPECT_TRUE(rhs.Is());
}

// Move assignment: afterwards lhs holds rhs's alternative. The moved-from rhs
// is intentionally not inspected.
TYPED_TEST(ValueVariantTest, MoveAssign) {
  using Left = typename TypeParam::first_type;
  using Right = typename TypeParam::second_type;
  ValueVariant lhs(DefaultValue{}());
  ValueVariant rhs(DefaultValue{}());
  EXPECT_TRUE(lhs.Is());
  lhs = std::move(rhs);
  EXPECT_TRUE(lhs.Is());
}

// swap: afterwards the two variants hold each other's former alternatives.
TYPED_TEST(ValueVariantTest, Swap) {
  using Left = typename TypeParam::first_type;
  using Right = typename TypeParam::second_type;
  ValueVariant lhs(DefaultValue{}());
  ValueVariant rhs(DefaultValue{}());
  swap(lhs, rhs);
  EXPECT_TRUE(lhs.Is());
  EXPECT_TRUE(rhs.Is());
}

}  // namespace
}  // namespace cel::common_internal
================================================ FILE: common/values/values.h ================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// IWYU pragma: private

#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUES_H_
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUES_H_

// NOTE(review): the header names of the next four directives are missing in
// this copy of the file (presumably standard-library headers); restore them
// from version control.
#include
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "base/attribute.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

// absl::Cord is trivially relocatable IFF we are not using ASan or MSan. When
// using ASan or MSan absl::Cord will poison/unpoison its inline storage.
// Consequently the `ValueVariant` forward declaration below only carries
// ABSL_ATTRIBUTE_TRIVIAL_ABI when neither sanitizer is enabled.
#if defined(ABSL_HAVE_ADDRESS_SANITIZER) || defined(ABSL_HAVE_MEMORY_SANITIZER)
#define CEL_COMMON_INTERNAL_VALUE_VARIANT_TRIVIAL_ABI
#else
#define CEL_COMMON_INTERNAL_VALUE_VARIANT_TRIVIAL_ABI ABSL_ATTRIBUTE_TRIVIAL_ABI
#endif

namespace cel {

// Forward declarations of the value types referenced below.
class ValueInterface;
class ListValueInterface;
class StructValueInterface;
class Value;
class BoolValue;
class BytesValue;
class DoubleValue;
class DurationValue;
class ABSL_ATTRIBUTE_TRIVIAL_ABI ErrorValue;
class IntValue;
class ListValue;
class MapValue;
class NullValue;
class OpaqueValue;
class OptionalValue;
class StringValue;
class StructValue;
class TimestampValue;
class TypeValue;
class UintValue;
class UnknownValue;
class ParsedMessageValue;
class ParsedMapFieldValue;
class ParsedRepeatedFieldValue;
class ParsedJsonListValue;
class ParsedJsonMapValue;
class CustomListValue;
class CustomListValueInterface;
class CustomMapValue;
class CustomMapValueInterface;
class CustomStructValue;
class CustomStructValueInterface;
class ValueIterator;

// Owning pointer to a ValueIterator (template argument missing in this copy;
// presumably `std::unique_ptr<ValueIterator>`).
using ValueIteratorPtr = std::unique_ptr;

// Abstract iterator over the elements of a list or the entries of a map.
// Every producing method takes the descriptor pool, message factory and arena
// the implementation may need to materialize values.
class ValueIterator {
 public:
  virtual ~ValueIterator() = default;

  // Presumably reports whether another element can be produced by Next*.
  virtual bool HasNext() = 0;

  // Stores the next value in `*result`; a non-OK status signals failure.
  virtual absl::Status Next(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) = 0;

  // Convenience overload that returns the next value instead of writing it to
  // an out-parameter (return type's template argument missing in this copy).
  absl::StatusOr Next(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena);

  // Next1 returns values for lists and keys for maps.
  // This form writes the element into `*key_or_value`. NOTE(review): the
  // meaning of its StatusOr payload is not visible here (template argument
  // missing); confirm against the implementations.
  virtual absl::StatusOr Next1(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull key_or_value);

  // Value-returning form of Next1.
  absl::StatusOr> Next1(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena);

  // Next2 returns indices (in ascending order) and values for lists and keys
  // (in any order) and values for maps.
  // Both out-parameters are nullable, so callers may request only one half.
  virtual absl::StatusOr Next2(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena, Value* absl_nullable key,
      Value* absl_nullable value) = 0;

  // Value-returning form of Next2.
  absl::StatusOr>> Next2(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena);
};

namespace common_internal {

class SharedByteString;
class SharedByteStringView;
class LegacyListValue;
class LegacyMapValue;
class LegacyStructValue;
class ListValueVariant;
class MapValueVariant;
class StructValueVariant;
class CEL_COMMON_INTERNAL_VALUE_VARIANT_TRIVIAL_ABI ValueVariant;

// Accessors for default / empty singleton-like values.
ErrorValue GetDefaultErrorValue();

CustomListValue GetEmptyDynListValue();

CustomMapValue GetEmptyDynDynMapValue();

OptionalValue GetEmptyDynOptionalValue();

// Equality helpers for lists, maps and structs. Each takes the two operands
// plus the evaluation context and presumably reports the comparison outcome
// through `*result`; the returned status signals evaluation failure.
absl::Status ListValueEqual(
    const ListValue& lhs, const ListValue& rhs,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result);

absl::Status ListValueEqual(
    const CustomListValueInterface& lhs, const ListValue& rhs,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result);

absl::Status MapValueEqual(
    const MapValue& lhs, const MapValue& rhs,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result);

absl::Status MapValueEqual(
    const CustomMapValueInterface& lhs, const MapValue& rhs,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result);

absl::Status StructValueEqual(
    const StructValue& lhs, const StructValue& rhs,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result);

absl::Status StructValueEqual(
    const CustomStructValueInterface& lhs, const StructValue& rhs,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result);

// Presumably expose the shared byte storage behind a BytesValue/StringValue.
const SharedByteString& AsSharedByteString(const BytesValue& value);

const SharedByteString& AsSharedByteString(const StringValue& value);

// Callback aliases for list iteration (template arguments partially missing
// in this copy): value-only and (index, value) forms.
using ListValueForEachCallback =
    absl::FunctionRef(const Value&)>;
using ListValueForEach2Callback =
    absl::FunctionRef(size_t, const Value&)>;

// Mixins below appear to follow the CRTP pattern: the template parameter
// (named `Base`, see `friend Base;`) is the derived value type, and inline
// helpers static_cast `this` to reach it.
template
class ValueMixin {
 public:
  absl::StatusOr Equal(
      const Value& other,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  friend Base;
};

// Convenience overloads shared by list value types.
template
class ListValueMixin : public ValueMixin {
 public:
  using ValueMixin::Equal;

  absl::StatusOr Get(
      size_t index,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  using ForEachCallback = absl::FunctionRef(const Value&)>;

  // Adapts a value-only callback to the (index, value) ForEach provided by
  // the derived type, discarding the index.
  absl::Status ForEach(
      ForEachCallback callback,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const {
    return static_cast(this)->ForEach(
        [callback](size_t, const Value& value) -> absl::StatusOr {
          return callback(value);
        },
        descriptor_pool, message_factory, arena);
  }

  absl::StatusOr Contains(
      const Value& other,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  friend Base;
};

// Convenience overloads shared by map value types.
template
class MapValueMixin : public ValueMixin {
 public:
  using ValueMixin::Equal;

  absl::StatusOr Get(
      const Value& key,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  absl::StatusOr> Find(
      const Value& other,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  absl::StatusOr Has(
      const Value& key,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  absl::StatusOr ListKeys(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  friend Base;
};

// Convenience overloads shared by struct value types.
template
class StructValueMixin : public ValueMixin {
 public:
  using ValueMixin::Equal;

  absl::StatusOr GetFieldByName(
      absl::string_view name,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  // Out-parameter form without explicit unboxing options: forwards to the
  // derived type with ProtoWrapperTypeOptions::kUnsetNull.
  absl::Status GetFieldByName(
      absl::string_view name,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const {
    return static_cast(this)->GetFieldByName(
        name, ProtoWrapperTypeOptions::kUnsetNull, descriptor_pool,
        message_factory, arena, result);
  }

  absl::StatusOr GetFieldByName(
      absl::string_view name, ProtoWrapperTypeOptions unboxing_options,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  absl::StatusOr GetFieldByNumber(
      int64_t number,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  // Out-parameter form without explicit unboxing options: forwards to the
  // derived type with ProtoWrapperTypeOptions::kUnsetNull.
  absl::Status GetFieldByNumber(
      int64_t number,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena,
      Value* absl_nonnull result) const {
    return static_cast(this)->GetFieldByNumber(
        number, ProtoWrapperTypeOptions::kUnsetNull, descriptor_pool,
        message_factory, arena, result);
  }

  absl::StatusOr GetFieldByNumber(
      int64_t number, ProtoWrapperTypeOptions unboxing_options,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  absl::StatusOr> Qualify(
      absl::Span qualifiers, bool presence_test,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const;

  friend Base;
};

// Opaque values only share the Equal convenience overload.
template
class OpaqueValueMixin : public ValueMixin {
 public:
  using ValueMixin::Equal;

  friend Base;
};

}  // namespace common_internal

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUES_H_
================================================
FILE: compiler/BUILD
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

package(default_visibility = ["//visibility:public"])

# Header-only target declaring the Compiler / CompilerBuilder interfaces.
cc_library(
    name = "compiler",
    hdrs = ["compiler.h"],
    deps = [
        "//checker:checker_options",
        "//checker:type_checker",
        "//checker:type_checker_builder",
        "//checker:validation_result",
        "//parser:options",
        "//parser:parser_interface",
        "//validator",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

# Default implementation of the interfaces plus the NewCompilerBuilder factory.
cc_library(
    name = "compiler_factory",
    srcs = ["compiler_factory.cc"],
    hdrs = ["compiler_factory.h"],
    deps = [
        ":compiler",
        "//checker:type_checker",
        "//checker:type_checker_builder",
        "//checker:type_checker_builder_factory",
        "//checker:validation_result",
        "//common:source",
        "//internal:noop_delete",
        "//internal:status_macros",
        "//parser",
        "//parser:parser_interface",
        "//validator",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "compiler_factory_test",
    srcs = ["compiler_factory_test.cc"],
    deps = [
        ":compiler",
        ":compiler_factory",
        ":optional",
        ":standard_library",
        "//checker:optional",
        "//checker:standard_library",
        "//checker:type_check_issue",
        "//checker:type_checker",
        "//checker:validation_result",
        "//common:decl",
        "//common:source",
        "//common:type",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//parser:macro",
        "//parser:parser_interface",
        "//testutil:baseline_tests",
        "//validator:timestamp_literal_validator",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/strings",
        "@com_google_protobuf//:protobuf",
    ],
)

# Compiler library bundling the optional-type parser and checker configuration.
cc_library(
    name = "optional",
    srcs = ["optional.cc"],
    hdrs = ["optional.h"],
    deps = [
        ":compiler",
        "//checker:optional",
        "//parser:macro",
        "//parser:parser_interface",
        "@com_google_absl//absl/status",
    ],
)

cc_test(
    name = "optional_test",
    srcs = ["optional_test.cc"],
    deps = [
        ":compiler",
        ":compiler_factory",
        ":optional",
        ":standard_library",
        "//checker:optional",
        "//checker:standard_library",
        "//checker:type_check_issue",
        "//checker:validation_result",
        "//common:decl",
        "//common:source",
        "//common:type",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//testutil:baseline_tests",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto",
    ],
)

# Compiler library bundling the standard parser macros and checker library.
cc_library(
    name = "standard_library",
    srcs = ["standard_library.cc"],
    hdrs = ["standard_library.h"],
    deps = [
        ":compiler",
        "//checker:standard_library",
        "//internal:status_macros",
        "//parser:macro",
        "//parser:parser_interface",
        "@com_google_absl//absl/status",
    ],
)

# Helpers for building CompilerLibrarySubset values.
cc_library(
    name = "compiler_library_subset_factory",
    srcs = ["compiler_library_subset_factory.cc"],
    hdrs = ["compiler_library_subset_factory.h"],
    deps = [
        ":compiler",
        "//checker:type_checker_subset_factory",
        "//parser:parser_subset_factory",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
    ],
)

cc_test(
    name = "compiler_library_subset_factory_test",
    srcs = ["compiler_library_subset_factory_test.cc"],
    deps = [
        ":compiler",
        ":compiler_factory",
        ":compiler_library_subset_factory",
        ":standard_library",
        "//checker:validation_result",
        "//common:standard_definitions",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
    ],
)
================================================
FILE: compiler/compiler.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the guard name says COMPILER_INTERFACE while the file is
// compiler.h; harmless, but it does not follow the path-based convention used
// by the other headers.
#ifndef THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_INTERFACE_H_
#define THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_INTERFACE_H_

// NOTE(review): the header names of the next three directives are missing in
// this copy of the file; restore them from version control.
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "checker/checker_options.h"
#include "checker/type_checker.h"
#include "checker/type_checker_builder.h"
#include "checker/validation_result.h"
#include "parser/options.h"
#include "parser/parser_interface.h"
#include "validator/validator.h"
#include "google/protobuf/arena.h"

namespace cel {

class Compiler;
class CompilerBuilder;

// A CompilerLibrary represents a package of CEL configuration that can be
// added to a Compiler.
//
// It may contain either or both of a Parser configuration and a
// TypeChecker configuration.
struct CompilerLibrary { // Optional identifier to avoid collisions re-adding the same library. // If id is empty, it is not considered. std::string id; // Optional callback for configuring the parser. ParserBuilderConfigurer configure_parser; // Optional callback for configuring the type checker. TypeCheckerBuilderConfigurer configure_checker; CompilerLibrary(std::string id, ParserBuilderConfigurer configure_parser, TypeCheckerBuilderConfigurer configure_checker = nullptr) : id(std::move(id)), configure_parser(std::move(configure_parser)), configure_checker(std::move(configure_checker)) {} CompilerLibrary(std::string id, TypeCheckerBuilderConfigurer configure_checker) : id(std::move(id)), configure_parser(std::move(nullptr)), configure_checker(std::move(configure_checker)) {} // Convenience conversion from the CheckerLibrary type. // // Note: if a related CompilerLibrary exists, prefer to use that to // include expected parser configuration. static CompilerLibrary FromCheckerLibrary(CheckerLibrary checker_library) { return CompilerLibrary(std::move(checker_library.id), /*configure_parser=*/nullptr, std::move(checker_library.configure)); } // For backwards compatibility. To be removed. // NOLINTNEXTLINE(google-explicit-constructor) CompilerLibrary(CheckerLibrary checker_library) : id(std::move(checker_library.id)), configure_parser(nullptr), configure_checker(std::move(checker_library.configure)) {} }; struct CompilerLibrarySubset { // The id of the library to subset. Only one subset can be applied per // library id. // // Must be non-empty. std::string library_id; ParserLibrarySubset::MacroPredicate should_include_macro; TypeCheckerSubset::FunctionPredicate should_include_overload; // TODO(uncreated-issue/71): to faithfully report the subset back, we need to track // the default (include or exclude) behavior for each of the predicates. }; // General options for configuring the underlying parser and checker. 
struct CompilerOptions {
  ParserOptions parser_options;
  CheckerOptions checker_options;
};

// Interface for CEL CompilerBuilder objects.
//
// Builder implementations do not provide any synchronization themselves,
// but create thread-compatible Compiler instances.
class CompilerBuilder {
 public:
  virtual ~CompilerBuilder() = default;

  // Registers a library's parser and/or checker configuration.
  virtual absl::Status AddLibrary(CompilerLibrary library) = 0;

  // Registers a subset restriction for the library named by
  // `subset.library_id`.
  virtual absl::Status AddLibrarySubset(CompilerLibrarySubset subset) = 0;

  // Direct access to the underlying component builders and validator.
  virtual TypeCheckerBuilder& GetCheckerBuilder() = 0;
  virtual ParserBuilder& GetParserBuilder() = 0;
  virtual Validator& GetValidator() = 0;

  // Builds a Compiler from the accumulated configuration (the return type's
  // template arguments are missing in this copy).
  virtual absl::StatusOr> Build() = 0;
};

// Interface for CEL Compiler objects.
//
// For CEL, compilation is the process of bundling the parse and type-check
// passes.
//
// Compiler instances should be thread-compatible.
class Compiler {
 public:
  virtual ~Compiler() = default;

  // Parses and type-checks `source`. `description` labels the source text;
  // `arena` is nullable.
  virtual absl::StatusOr Compile(
      absl::string_view source, absl::string_view description,
      google::protobuf::Arena* absl_nullable arena) const = 0;

  // Shorthand: empty description, null arena.
  absl::StatusOr Compile(absl::string_view source) const {
    return Compile(source, "", nullptr);
  }

  // Shorthand: null arena.
  absl::StatusOr Compile(
      absl::string_view source, absl::string_view description) const {
    return Compile(source, description, nullptr);
  }

  // Accessor for the underlying type checker.
  virtual const TypeChecker& GetTypeChecker() const = 0;

  // Accessor for the underlying parser.
  virtual const Parser& GetParser() const = 0;

  // Accessor for the underlying validator.
  virtual const Validator& GetValidator() const = 0;

  // Returns a builder initialized with the configuration of this compiler.
  //
  // The returned builder is a copy of the validated environment and may
  // behave differently than the builder that created this compiler.
  //
  // The returned builder does not share state with the compiler and may be
  // modified independently.
  virtual std::unique_ptr ToBuilder() const = 0;
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_INTERFACE_H_
================================================
FILE: compiler/compiler_factory.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "compiler/compiler_factory.h"

// NOTE(review): the header names of the next three directives are missing in
// this copy of the file; restore them from version control.
#include
#include
#include

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "checker/type_checker.h"
#include "checker/type_checker_builder.h"
#include "checker/type_checker_builder_factory.h"
#include "checker/validation_result.h"
#include "common/source.h"
#include "compiler/compiler.h"
#include "internal/status_macros.h"
#include "parser/parser.h"
#include "parser/parser_interface.h"
#include "validator/validator.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"

namespace cel {

namespace {

// Compiler that chains parse -> type-check -> (optional) validation.
class CompilerImpl : public Compiler {
 public:
  CompilerImpl(std::unique_ptr type_checker,
               std::unique_ptr parser,
               // Copy the validator in case builder is reused.
               Validator validator)
      : type_checker_(std::move(type_checker)),
        parser_(std::move(parser)),
        validator_(std::move(validator)) {}

  // Wraps `expression` in a Source labelled `description`, parses it, passes
  // the AST and `arena` to the type checker, attaches the Source to the
  // result, and runs the validator only when it has validations registered.
  // Errors from source creation, parsing or checking are returned as-is.
  absl::StatusOr Compile(
      absl::string_view expression, absl::string_view description,
      google::protobuf::Arena* arena) const override {
    CEL_ASSIGN_OR_RETURN(auto source,
                         cel::NewSource(expression, std::string(description)));
    CEL_ASSIGN_OR_RETURN(auto ast, parser_->Parse(*source));
    CEL_ASSIGN_OR_RETURN(ValidationResult result,
                         type_checker_->Check(std::move(ast), arena));

    result.SetSource(std::move(source));
    if (!validator_.validations().empty()) {
      validator_.UpdateValidationResult(result);
    }
    return result;
  }

  std::unique_ptr ToBuilder() const override;

  const TypeChecker& GetTypeChecker() const override { return *type_checker_; }
  const Parser& GetParser() const override { return *parser_; }
  const Validator& GetValidator() const override { return validator_; }

 private:
  std::unique_ptr type_checker_;
  std::unique_ptr parser_;
  Validator validator_;
};

// CompilerBuilder that fans library registrations out to a parser builder and
// a type-checker builder, and tracks ids to reject duplicates at this level.
class CompilerBuilderImpl : public CompilerBuilder {
 public:
  CompilerBuilderImpl(std::unique_ptr type_checker_builder,
                      std::unique_ptr parser_builder,
                      Validator validator = Validator())
      : type_checker_builder_(std::move(type_checker_builder)),
        parser_builder_(std::move(parser_builder)),
        validator_(std::move(validator)) {}

  // Rejects a non-empty id that was already added through this builder, then
  // forwards whichever of the checker/parser callbacks are set.
  // NOTE(review): the id is recorded before the sub-builder calls; if one of
  // them fails, the id stays registered and a retry reports AlreadyExists.
  absl::Status AddLibrary(CompilerLibrary library) override {
    if (!library.id.empty()) {
      auto [it, inserted] = library_ids_.insert(library.id);

      if (!inserted) {
        return absl::AlreadyExistsError(
            absl::StrCat("library already exists: ", library.id));
      }
    }

    if (library.configure_checker) {
      CEL_RETURN_IF_ERROR(type_checker_builder_->AddLibrary({
          .id = library.id,
          .configure = std::move(library.configure_checker),
      }));
    }
    if (library.configure_parser) {
      CEL_RETURN_IF_ERROR(parser_builder_->AddLibrary({
          .id = library.id,
          .configure = std::move(library.configure_parser),
      }));
    }
    return absl::OkStatus();
  }

  // Requires a non-empty id and allows at most one subset per id; forwards
  // each predicate only if it is set. Same ordering caveat as AddLibrary.
  absl::Status AddLibrarySubset(CompilerLibrarySubset subset) override {
    if (subset.library_id.empty()) {
      return absl::InvalidArgumentError("library id must not be empty");
    }
    std::string library_id = subset.library_id;
    auto [it, inserted] = subsets_.insert(library_id);
    if (!inserted) {
      return absl::AlreadyExistsError(
          absl::StrCat("library subset already exists for: ", library_id));
    }

    if (subset.should_include_macro) {
      CEL_RETURN_IF_ERROR(parser_builder_->AddLibrarySubset({
          library_id,
          std::move(subset.should_include_macro),
      }));
    }
    if (subset.should_include_overload) {
      CEL_RETURN_IF_ERROR(type_checker_builder_->AddLibrarySubset(
          {library_id, std::move(subset.should_include_overload)}));
    }
    return absl::OkStatus();
  }

  ParserBuilder& GetParserBuilder() override { return *parser_builder_; }

  TypeCheckerBuilder& GetCheckerBuilder() override {
    return *type_checker_builder_;
  }

  Validator& GetValidator() override { return validator_; }

  // Builds the parser, then the type checker; the validator is copied so the
  // builder stays reusable.
  absl::StatusOr> Build() override {
    CEL_ASSIGN_OR_RETURN(auto parser, parser_builder_->Build());
    CEL_ASSIGN_OR_RETURN(auto type_checker, type_checker_builder_->Build());
    return std::make_unique(std::move(type_checker),
                                          std::move(parser), validator_);
  }

 private:
  std::unique_ptr type_checker_builder_;
  std::unique_ptr parser_builder_;
  Validator validator_;

  // Non-empty library ids added via AddLibrary on this builder.
  absl::flat_hash_set library_ids_;
  // Library ids that already have a subset registered.
  absl::flat_hash_set subsets_;
};

// NOTE(review): the returned builder starts with empty `library_ids_` and
// `subsets_`, so duplicate detection at this level does not carry over from
// the original builder (the sub-builders may still reject duplicates).
std::unique_ptr CompilerImpl::ToBuilder() const {
  auto builder = std::make_unique(
      type_checker_->ToBuilder(), parser_->ToBuilder(), validator_);
  return builder;
}

}  // namespace

// Returns InvalidArgument for a null pool; otherwise creates the checker
// builder from the pool and checker options, and the parser builder from the
// parser options.
absl::StatusOr> NewCompilerBuilder(
    std::shared_ptr descriptor_pool,
    CompilerOptions options) {
  if (descriptor_pool == nullptr) {
    return absl::InvalidArgumentError("descriptor_pool must not be null");
  }

  CEL_ASSIGN_OR_RETURN(auto type_checker_builder,
                       CreateTypeCheckerBuilder(std::move(descriptor_pool),
                                                options.checker_options));
  auto parser_builder = NewParserBuilder(options.parser_options);

  return std::make_unique(std::move(type_checker_builder),
                                               std::move(parser_builder));
}

}  // namespace cel
================================================
FILE: compiler/compiler_factory.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_FACTORY_H_
#define THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_FACTORY_H_

// NOTE(review): the header names of the next two directives are missing in
// this copy of the file; restore them from version control.
#include
#include

#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
#include "compiler/compiler.h"
#include "internal/noop_delete.h"
#include "google/protobuf/descriptor.h"

namespace cel {

// Creates a new unconfigured CompilerBuilder for creating a new CEL Compiler
// instance.
//
// The builder is thread-hostile and intended to be configured by a single
// thread, but the created Compiler instances are thread-compatible (and
// effectively immutable).
//
// The descriptor pool must include the standard definitions for the protobuf
// well-known types:
//  - google.protobuf.NullValue
//  - google.protobuf.BoolValue
//  - google.protobuf.Int32Value
//  - google.protobuf.Int64Value
//  - google.protobuf.UInt32Value
//  - google.protobuf.UInt64Value
//  - google.protobuf.FloatValue
//  - google.protobuf.DoubleValue
//  - google.protobuf.BytesValue
//  - google.protobuf.StringValue
//  - google.protobuf.Any
//  - google.protobuf.Duration
//  - google.protobuf.Timestamp
//
// A null `descriptor_pool` yields an InvalidArgument error.
absl::StatusOr> NewCompilerBuilder(
    std::shared_ptr descriptor_pool,
    CompilerOptions options = {});

// Convenience overload for non-owning pointers (such as the generated pool).
// The descriptor pool must outlive the compiler builder and any compiler
// instances it builds.
//
// The raw pointer is wrapped in a shared_ptr with internal::NoopDeleteFor as
// its deleter, so ownership is not taken.
inline absl::StatusOr> NewCompilerBuilder(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    CompilerOptions options = {}) {
  return NewCompilerBuilder(
      std::shared_ptr(
          descriptor_pool,
          internal::NoopDeleteFor()),
      std::move(options));
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_FACTORY_H_
================================================
FILE: compiler/compiler_factory_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compiler/compiler_factory.h"

#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/strings/match.h"
#include "checker/optional.h"
#include "checker/standard_library.h"
#include "checker/type_check_issue.h"
#include "checker/type_checker.h"
#include "checker/validation_result.h"
#include "common/decl.h"
#include "common/source.h"
#include "common/type.h"
#include "compiler/compiler.h"
#include "compiler/optional.h"
#include "compiler/standard_library.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "parser/macro.h"
#include "parser/parser_interface.h"
#include "testutil/baseline_tests.h"
#include "validator/timestamp_literal_validator.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::test::FormatBaselineAst;
using ::testing::Contains;
using ::testing::HasSubstr;
using ::testing::Property;
using ::testing::Truly;

// Compiles an expression with the standard library and compares the
// type-annotated AST against a baseline.
TEST(CompilerFactoryTest, Works) {
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool()));

  ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
  ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build());

  ASSERT_OK_AND_ASSIGN(
      ValidationResult result,
      compiler->Compile("['a', 'b', 'c'].exists(x, x in ['c', 'd', 'e']) && 10 "
                        "< (5 % 3 * 2 + 1 - 2)"));

  ASSERT_TRUE(result.IsValid());

  EXPECT_EQ(FormatBaselineAst(*result.GetAst()),
            R"(_&&_( __comprehension__( // Variable x, // Target [ "a"~string, "b"~string, "c"~string ]~list(string), // Accumulator @result, // Init false~bool, // LoopCondition @not_strictly_false( !_( @result~bool^@result )~bool^logical_not )~bool^not_strictly_false, // LoopStep _||_( @result~bool^@result, @in( x~string^x, [ "c"~string, "d"~string, "e"~string ]~list(string) )~bool^in_list )~bool^logical_or, // Result @result~bool^@result)~bool, _<_( 10~int,
_-_( _+_( _*_( _%_( 5~int, 3~int )~int^modulo_int64, 2~int )~int^multiply_int64, 1~int )~int^add_int64, 2~int )~int^subtract_int64 )~bool^less_int64 )~bool^logical_and)");
}

// A custom parser-only library can disable the standard macros and register
// just the ones it wants (here: has()).
TEST(CompilerFactoryTest, ParserLibrary) {
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool()));

  ASSERT_THAT(
      builder->AddLibrary(
          {"test", [](ParserBuilder& parser_builder) -> absl::Status {
             parser_builder.GetOptions().disable_standard_macros = true;
             return parser_builder.AddMacro(cel::HasMacro());
           }}),
      IsOk());

  ASSERT_THAT(builder->GetCheckerBuilder().AddVariable(
                  MakeVariableDecl("a", MapType())),
              IsOk());

  ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build());

  ASSERT_THAT(compiler->Compile("has(a.b)"), IsOk());

  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       compiler->Compile("[].map(x, x)"));

  EXPECT_FALSE(result.IsValid());
  // On failure, report every issue rather than a fixed index that may not
  // exist.
  EXPECT_THAT(result.GetIssues(),
              Contains(Property(&TypeCheckIssue::message,
                                HasSubstr("undeclared reference to 'map'"))))
      << result.FormatError();
}

// Parser options can be set directly on the parser builder.
TEST(CompilerFactoryTest, ParserOptions) {
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool()));

  builder->GetParserBuilder().GetOptions().enable_optional_syntax = true;
  ASSERT_THAT(builder->AddLibrary(OptionalCheckerLibrary()), IsOk());

  ASSERT_THAT(builder->GetCheckerBuilder().AddVariable(
                  MakeVariableDecl("a", MapType())),
              IsOk());

  ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build());

  ASSERT_THAT(compiler->Compile("a.?b.orValue('foo')"), IsOk());
}

TEST(CompilerFactoryTest, GetParser) {
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool()));

  ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build());

  const cel::Parser& parser = compiler->GetParser();

  ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("Or(a, b)"));
  ASSERT_OK_AND_ASSIGN(auto ast, parser.Parse(*source));
}

// The type checker exposed by the compiler can check an AST produced by the
// exposed parser.
TEST(CompilerFactoryTest, GetTypeChecker) {
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool()));

  absl::Status s;
  s.Update(builder->GetCheckerBuilder().AddVariable(
      MakeVariableDecl("a", BoolType())));

  s.Update(builder->GetCheckerBuilder().AddVariable(
      MakeVariableDecl("b", BoolType())));

  ASSERT_OK_AND_ASSIGN(
      auto or_decl,
      MakeFunctionDecl("Or", MakeOverloadDecl("Or_bool_bool", BoolType(),
                                              BoolType(), BoolType())));
  s.Update(builder->GetCheckerBuilder().AddFunction(std::move(or_decl)));

  ASSERT_THAT(s, IsOk());

  ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build());

  const cel::Parser& parser = compiler->GetParser();

  ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("Or(a, b)"));
  ASSERT_OK_AND_ASSIGN(auto ast, parser.Parse(*source));

  const cel::TypeChecker& checker = compiler->GetTypeChecker();
  ASSERT_OK_AND_ASSIGN(cel::ValidationResult result,
                       checker.Check(std::move(ast)));

  EXPECT_TRUE(result.IsValid());
}

TEST(CompilerFactoryTest, DisableStandardMacros) {
  CompilerOptions options;
  options.parser_options.disable_standard_macros = true;

  ASSERT_OK_AND_ASSIGN(
      auto builder,
      NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool(),
                         options));

  // Add the type checker library, but not the parser library for CEL standard.
  ASSERT_THAT(builder->AddLibrary(CompilerLibrary::FromCheckerLibrary(
                  StandardCheckerLibrary())),
              IsOk());
  ASSERT_THAT(builder->GetParserBuilder().AddMacro(cel::ExistsMacro()),
              IsOk());

  // a: map(dyn, dyn)
  ASSERT_THAT(builder->GetCheckerBuilder().AddVariable(
                  MakeVariableDecl("a", MapType())),
              IsOk());

  ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build());

  ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile("a.b"));

  EXPECT_TRUE(result.IsValid());

  // The has macro is disabled, so has(...) is parsed as a plain function call.
  ASSERT_OK_AND_ASSIGN(result, compiler->Compile("has(a.b)"));

  EXPECT_FALSE(result.IsValid());
  EXPECT_THAT(result.GetIssues(),
              Contains(Truly([](const TypeCheckIssue& issue) {
                return absl::StrContains(issue.message(),
                                         "undeclared reference to 'has'");
              })));

  ASSERT_OK_AND_ASSIGN(result, compiler->Compile("a.exists(x, x == 'foo')"));
  EXPECT_TRUE(result.IsValid());
}

TEST(CompilerFactoryTest, DisableStandardMacrosWithStdlib) {
  CompilerOptions options;
  options.parser_options.disable_standard_macros = true;

  ASSERT_OK_AND_ASSIGN(
      auto builder,
      NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool(),
                         options));

  ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
  ASSERT_THAT(builder->GetParserBuilder().AddMacro(cel::ExistsMacro()),
              IsOk());

  // a: map(dyn, dyn)
  ASSERT_THAT(builder->GetCheckerBuilder().AddVariable(
                  MakeVariableDecl("a", MapType())),
              IsOk());

  ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build());

  ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile("a.b"));

  EXPECT_TRUE(result.IsValid());

  // The has macro is disabled, so has(...) is parsed as a plain function call.
  ASSERT_OK_AND_ASSIGN(result, compiler->Compile("has(a.b)"));

  EXPECT_FALSE(result.IsValid());
  EXPECT_THAT(result.GetIssues(),
              Contains(Truly([](const TypeCheckIssue& issue) {
                return absl::StrContains(issue.message(),
                                         "undeclared reference to 'has'");
              })));

  ASSERT_OK_AND_ASSIGN(result, compiler->Compile("a.exists(x, x == 'foo')"));
  EXPECT_TRUE(result.IsValid());
}

TEST(CompilerFactoryTest, AddValidator) {
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool()));

  ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
  builder->GetValidator().AddValidation(TimestampLiteralValidator());

  ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build());

  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       compiler->Compile("timestamp('invalid')"));

  EXPECT_FALSE(result.IsValid());

  ASSERT_OK_AND_ASSIGN(result,
                       compiler->Compile("timestamp('2024-01-01T00:00:00Z')"));

  EXPECT_TRUE(result.IsValid());
}

TEST(CompilerFactoryTest, FailsIfLibraryAddedTwice) {
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool()));

  ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
  ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()),
              StatusIs(absl::StatusCode::kAlreadyExists,
                       HasSubstr("library already exists: stdlib")));
}

TEST(CompilerFactoryTest, FailsIfLibrarySubsetAddedTwice) {
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool()));

  ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());

  ASSERT_THAT(builder->AddLibrarySubset({
                  .library_id = "stdlib",
                  .should_include_macro = nullptr,
                  .should_include_overload = nullptr,
              }),
              IsOk());

  ASSERT_THAT(builder->AddLibrarySubset({
                  .library_id = "stdlib",
                  .should_include_macro = nullptr,
                  .should_include_overload = nullptr,
              }),
              StatusIs(absl::StatusCode::kAlreadyExists,
                       HasSubstr("library subset already exists for: stdlib")));
}

TEST(CompilerFactoryTest, FailsIfLibrarySubsetHasNoId) {
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool()));

  ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());

  ASSERT_THAT(builder->AddLibrarySubset({
                  .library_id = "",
                  .should_include_macro = nullptr,
                  .should_include_overload = nullptr,
              }),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("library id must not be empty")));
}

TEST(CompilerFactoryTest, FailsIfNullDescriptorPool) {
  std::shared_ptr<const google::protobuf::DescriptorPool> pool =
      internal::GetSharedTestingDescriptorPool();
  pool.reset();
  ASSERT_THAT(NewCompilerBuilder(std::move(pool)),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("descriptor_pool must not be null")));
}

// A builder derived from an existing compiler keeps its configuration and can
// be extended with more libraries.
TEST(CompilerFactoryTest, ToBuilderWorks) {
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool()));

  ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
  ASSERT_THAT(builder->GetCheckerBuilder().AddVariable(
                  MakeVariableDecl("a", MapType())),
              IsOk());

  ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build());

  auto derived_builder = compiler->ToBuilder();
  ASSERT_THAT(derived_builder->AddLibrary(OptionalCompilerLibrary()), IsOk());

  ASSERT_OK_AND_ASSIGN(auto derived_compiler, derived_builder->Build());

  ASSERT_OK_AND_ASSIGN(
      ValidationResult result,
      derived_compiler->Compile("has(a.b) && a.?b.orValue('foo') == 'foo'"));

  EXPECT_TRUE(result.IsValid());
}

// Compiling into a caller-provided arena keeps the resolved type map usable
// after the AST is released.
TEST(CompilerFactoryTest, SpecifyArenaKeepsResolvedTypes) {
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool()));

  ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
  ASSERT_THAT(builder->AddLibrary(OptionalCompilerLibrary()), IsOk());

  ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build());

  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       compiler->Compile("[[1, 2, 3]][?0]", "", &arena));

  ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst());

  auto it = result.GetResolvedTypeMap().find(ast->root_expr().id());
  ASSERT_TRUE(it != result.GetResolvedTypeMap().end());
  EXPECT_TRUE(
      it->second.IsOptional() &&
      it->second.GetOptional().GetParameter().IsList() &&
      it->second.GetOptional().GetParameter().GetList().GetElement().IsInt());
}

}  // namespace
}  // namespace cel
================================================
FILE: compiler/compiler_library_subset_factory.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compiler/compiler_library_subset_factory.h"

#include <string>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "checker/type_checker_subset_factory.h"
#include "compiler/compiler.h"
#include "parser/parser_subset_factory.h"

namespace cel {

// Builds a subset of the library registered as "stdlib". Each name set is
// applied according to the matching ListKind in `options`:
//   kInclude -> only the listed entries are kept,
//   kExclude -> the listed entries are dropped,
//   kIgnore  -> the corresponding predicate is left unset (nullptr).
CompilerLibrarySubset MakeStdlibSubset(
    absl::flat_hash_set<std::string> macro_names,
    absl::flat_hash_set<std::string> function_overload_ids,
    StdlibSubsetOptions options) {
  CompilerLibrarySubset subset;
  subset.library_id = "stdlib";

  switch (options.macro_list) {
    case StdlibSubsetOptions::ListKind::kInclude:
      subset.should_include_macro =
          IncludeMacrosByNamePredicate(std::move(macro_names));
      break;
    case StdlibSubsetOptions::ListKind::kExclude:
      subset.should_include_macro =
          ExcludeMacrosByNamePredicate(std::move(macro_names));
      break;
    case StdlibSubsetOptions::ListKind::kIgnore:
      subset.should_include_macro = nullptr;
      break;
  }

  switch (options.function_list) {
    case StdlibSubsetOptions::ListKind::kInclude:
      subset.should_include_overload =
          IncludeOverloadsByIdPredicate(std::move(function_overload_ids));
      break;
    case StdlibSubsetOptions::ListKind::kExclude:
      subset.should_include_overload =
          ExcludeOverloadsByIdPredicate(std::move(function_overload_ids));
      break;
    case StdlibSubsetOptions::ListKind::kIgnore:
      subset.should_include_overload = nullptr;
      break;
  }

  return subset;
}

// View-based overload: copies the names into owned sets and forwards to the
// flat_hash_set overload.
CompilerLibrarySubset MakeStdlibSubset(
    absl::Span<const absl::string_view> macro_names,
    absl::Span<const absl::string_view> function_overload_ids,
    StdlibSubsetOptions options) {
  return MakeStdlibSubset(
      absl::flat_hash_set<std::string>(macro_names.begin(), macro_names.end()),
      absl::flat_hash_set<std::string>(function_overload_ids.begin(),
                                       function_overload_ids.end()),
      options);
}

// Subsets functions only; the macro list is ignored regardless of `options`.
CompilerLibrarySubset MakeStdlibSubsetByOverloadId(
    absl::Span<const absl::string_view> function_overload_ids,
    StdlibSubsetOptions options) {
  options.macro_list = StdlibSubsetOptions::ListKind::kIgnore;
  return MakeStdlibSubset({}, function_overload_ids, options);
}

// Subsets macros only; the function list is ignored regardless of `options`.
CompilerLibrarySubset MakeStdlibSubsetByMacroName(
    absl::Span<const absl::string_view> macro_names,
    StdlibSubsetOptions options) {
  options.function_list = StdlibSubsetOptions::ListKind::kIgnore;
  return MakeStdlibSubset(macro_names, {}, options);
}

}  // namespace cel
================================================
FILE: compiler/compiler_library_subset_factory.h
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_LIBRARY_SUBSET_FACTORY_H_
#define THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_LIBRARY_SUBSET_FACTORY_H_

#include <string>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "compiler/compiler.h"

namespace cel {

// Controls how MakeStdlibSubset interprets its macro and function lists.
struct StdlibSubsetOptions {
  enum class ListKind {
    // Include the given list of macros or functions, default to exclude.
    kInclude,
    // Exclude the given list of macros or functions, default to include.
    kExclude,
    // Ignore the given list of macros or functions. This is used to clarify
    // intent of an empty list.
    kIgnore
  };

  // How the macro name list is applied.
  ListKind macro_list = ListKind::kInclude;
  // How the function overload id list is applied.
  ListKind function_list = ListKind::kInclude;
};

// Creates a subset of the CEL standard library.
//
// Example usage:
//
//   // Include only the core boolean operators, and exists/all.
//
//   std::unique_ptr<CompilerBuilder> builder = ...;
//   builder->AddLibrary(StandardCompilerLibrary());
//   // Add the subset.
// builder->AddLibrarySubset(MakeStdlibSubset( // {"exists", "all"}, // {"logical_and", "logical_or", "logical_not", "not_strictly_false", // "equal", "inequal"}); // // // Exclude list concatenation and map macros. // builder->AddLibrarySubset(MakeStdlibSubset( // {"map"}, // {"add_list"}, // { .macro_list = StdlibSubsetOptions::ListKind::kExclude, // .function_list = StdlibSubsetOptions::ListKind::kExclude // })); CompilerLibrarySubset MakeStdlibSubset( absl::flat_hash_set macro_names, absl::flat_hash_set function_overload_ids, StdlibSubsetOptions options = {}); CompilerLibrarySubset MakeStdlibSubset( absl::Span macro_names, absl::Span function_overload_ids, StdlibSubsetOptions options = {}); CompilerLibrarySubset MakeStdlibSubsetByOverloadId( absl::Span function_overload_ids, StdlibSubsetOptions options = {}); CompilerLibrarySubset MakeStdlibSubsetByMacroName( absl::Span macro_names, StdlibSubsetOptions options = {}); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMPILER_COMPILER_LIBRARY_SUBSET_FACTORY_H_ ================================================ FILE: compiler/compiler_library_subset_factory_test.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "compiler/compiler_library_subset_factory.h"

#include <memory>
#include <string>

#include "absl/container/flat_hash_set.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "checker/validation_result.h"
#include "common/standard_definitions.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"

using ::absl_testing::IsOk;
using ::testing::Not;

namespace cel {
namespace {

// Matches a StatusOr<ValidationResult> that is OK and valid. On a failed
// status it explains the status; whenever issues are present it appends the
// formatted issues to the explanation.
MATCHER(IsValid, "") {
  const absl::StatusOr<ValidationResult>& result = arg;
  if (!result.ok()) {
    (*result_listener) << "compilation failed: " << result.status();
    return false;
  }
  if (!result->GetIssues().empty()) {
    (*result_listener) << "compilation issues: \n" << result->FormatError();
  }
  return result->IsValid();
}

TEST(CompilerLibrarySubsetFactoryTest, MakeStdlibSubsetInclude) {
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<CompilerBuilder> builder,
      NewCompilerBuilder(internal::GetSharedTestingDescriptorPool()));

  ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
  ASSERT_THAT(
      builder->AddLibrarySubset(MakeStdlibSubset(
          {"exists", "all"},
          {StandardOverloadIds::kAnd, StandardOverloadIds::kOr,
           StandardOverloadIds::kNot, StandardOverloadIds::kNotStrictlyFalse,
           StandardOverloadIds::kEquals, StandardOverloadIds::kNotEquals})),
      IsOk());

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<Compiler> compiler, builder->Build());

  EXPECT_THAT(
      compiler->Compile(
          "[1, 2, 3].exists(x, x != 1 || x == 2 && !(x == 4 || x == 5) )"),
      IsValid());
  EXPECT_THAT(compiler->Compile("1+2"), Not(IsValid()));
  EXPECT_THAT(compiler->Compile("[1, 2, 3].map(x, x)"), Not(IsValid()));
}

TEST(CompilerLibrarySubsetFactoryTest, MakeStdlibSubsetExclude) {
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<CompilerBuilder> builder,
      NewCompilerBuilder(internal::GetSharedTestingDescriptorPool()));

  ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
  ASSERT_THAT(builder->AddLibrarySubset(MakeStdlibSubset(
                  absl::flat_hash_set<std::string>({"map"}), {"add_list"},
                  {.macro_list = StdlibSubsetOptions::ListKind::kExclude,
                   .function_list =
                       StdlibSubsetOptions::ListKind::kExclude})),
              IsOk());

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<Compiler> compiler, builder->Build());

  EXPECT_THAT(
      compiler->Compile(
          "[1, 2, 3].exists(x, x != 1 || x == 2 && !(x == 4 || x == 5) )"),
      IsValid());
  EXPECT_THAT(compiler->Compile("1+2"), IsValid());
  EXPECT_THAT(compiler->Compile("[1, 2, 3].map(x, x)"), Not(IsValid()));
  EXPECT_THAT(compiler->Compile("[2] + [1]"), Not(IsValid()));
}

TEST(CompilerLibrarySubsetFactoryTest, MakeStdlibSubsetByMacroName) {
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<CompilerBuilder> builder,
      NewCompilerBuilder(internal::GetSharedTestingDescriptorPool()));

  ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
  constexpr absl::string_view kMacroNames[] = {"map"};
  ASSERT_THAT(builder->AddLibrarySubset(MakeStdlibSubsetByMacroName(
                  kMacroNames,
                  {.macro_list = StdlibSubsetOptions::ListKind::kExclude})),
              IsOk());

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<Compiler> compiler, builder->Build());

  EXPECT_THAT(
      compiler->Compile(
          "[1, 2, 3].exists(x, x != 1 || x == 2 && !(x == 4 || x == 5) )"),
      IsValid());
  EXPECT_THAT(compiler->Compile("1+2"), IsValid());
  EXPECT_THAT(compiler->Compile("[1, 2, 3].map(x, x)"), Not(IsValid()));
  EXPECT_THAT(compiler->Compile("[2] + [1]"), IsValid());
}

TEST(CompilerLibrarySubsetFactoryTest, MakeStdlibSubsetByOverloadId) {
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<CompilerBuilder> builder,
      NewCompilerBuilder(internal::GetSharedTestingDescriptorPool()));

  ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
  constexpr absl::string_view kOverloadIds[] = {"add_list", "add_string"};
  ASSERT_THAT(
      builder->AddLibrarySubset(MakeStdlibSubsetByOverloadId(
          kOverloadIds,
          {// Ignored: MakeStdlibSubsetByOverloadId forces macro_list to
           // kIgnore.
           .macro_list = StdlibSubsetOptions::ListKind::kInclude,
           .function_list = StdlibSubsetOptions::ListKind::kExclude})),
      IsOk());

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<Compiler> compiler, builder->Build());

  EXPECT_THAT(
      compiler->Compile(
          "[1, 2, 3].exists(x, x != 1 || x == 2 && !(x == 4 || x == 5) )"),
      IsValid());
  EXPECT_THAT(compiler->Compile("1+2"), IsValid());
  EXPECT_THAT(compiler->Compile("[1, 2, 3].map(x, x)"), Not(IsValid()));
  EXPECT_THAT(compiler->Compile("[2] + [1]"), Not(IsValid()));
}

}  // namespace
}  // namespace cel
================================================
FILE: compiler/optional.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "compiler/optional.h"

#include "absl/status/status.h"
#include "checker/optional.h"
#include "compiler/compiler.h"
#include "parser/macro.h"
#include "parser/parser_interface.h"

namespace cel {

// Builds the optional-types CompilerLibrary for `version`. The checker side
// comes from OptionalCheckerLibrary(version). The parser side enables optional
// syntax and registers optMap, plus optFlatMap for any version other than 0.
CompilerLibrary OptionalCompilerLibrary(int version) {
  CompilerLibrary library =
      CompilerLibrary::FromCheckerLibrary(OptionalCheckerLibrary(version));

  library.configure_parser = [version](ParserBuilder& builder) {
    builder.GetOptions().enable_optional_syntax = true;

    absl::Status status;
    status.Update(builder.AddMacro(OptMapMacro()));
    // optFlatMap is not part of version 0 of the extension.
    if (version == 0) {
      return status;
    }
    status.Update(builder.AddMacro(OptFlatMapMacro()));
    return status;
  };

  return library;
}

}  // namespace cel
================================================
FILE: compiler/optional.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMPILER_OPTIONALS_H_ #define THIRD_PARTY_CEL_CPP_COMPILER_OPTIONALS_H_ #include "checker/optional.h" #include "compiler/compiler.h" namespace cel { // CompilerLibrary that enables support for CEL optional types. CompilerLibrary OptionalCompilerLibrary( int version = kOptionalExtensionLatestVersion); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMPILER_OPTIONALS_H_ ================================================ FILE: compiler/optional_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "compiler/optional.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "checker/optional.h"
#include "checker/standard_library.h"
#include "checker/type_check_issue.h"
#include "checker/validation_result.h"
#include "common/decl.h"
#include "common/source.h"
#include "common/type.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "testutil/baseline_tests.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::cel::expr::conformance::proto3::TestAllTypes;
using ::cel::test::FormatBaselineAst;
using ::testing::Contains;
using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::Property;
using ::testing::ValuesIn;

// An expression and its expected type-annotated baseline AST.
struct TestCase {
  std::string expr;
  std::string expected_ast;
};

class OptionalTest : public testing::TestWithParam<TestCase> {};

// Joins all issues of `result`, one per line. Uses the source-aware display
// form when the result carries its Source, otherwise the bare message.
std::string FormatIssues(const ValidationResult& result) {
  const Source* source = result.GetSource();
  return absl::StrJoin(
      result.GetIssues(), "\n",
      [=](std::string* out, const TypeCheckIssue& issue) {
        absl::StrAppend(
            out, (source) ? issue.ToDisplayString(*source) : issue.message());
      });
}

TEST_P(OptionalTest, OptionalsEnabled) {
  const TestCase& test_case = GetParam();

  ASSERT_OK_AND_ASSIGN(
      auto builder,
      NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool()));

  ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk());
  ASSERT_THAT(builder->AddLibrary(OptionalCompilerLibrary()), IsOk());
  ASSERT_THAT(builder->GetCheckerBuilder().AddVariable(MakeVariableDecl(
                  "msg", MessageType(TestAllTypes::descriptor()))),
              IsOk());

  ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build());

  absl::StatusOr<ValidationResult> maybe_result =
      compiler->Compile(test_case.expr);

  ASSERT_OK_AND_ASSIGN(ValidationResult result, std::move(maybe_result));
  ASSERT_TRUE(result.IsValid()) << FormatIssues(result);
  EXPECT_EQ(FormatBaselineAst(*result.GetAst()),
            absl::StripAsciiWhitespace(test_case.expected_ast))
      << test_case.expr;
}

INSTANTIATE_TEST_SUITE_P(
    OptionalTest, OptionalTest,
    ::testing::Values(
        TestCase{
            .expr = "msg.?single_int64",
            .expected_ast = R"( _?._( msg~cel.expr.conformance.proto3.TestAllTypes^msg, "single_int64" )~optional_type(int)^select_optional_field)",
        },
        TestCase{
            .expr = "optional.of('foo')",
            .expected_ast = R"( optional.of( "foo"~string )~optional_type(string)^optional_of)",
        },
        TestCase{
            .expr = "optional.of('foo').optMap(x, x)",
            .expected_ast = R"( _?_:_( optional.of( "foo"~string )~optional_type(string)^optional_of.hasValue()~bool^optional_hasValue, optional.of( __comprehension__( // Variable #unused, // Target []~list(dyn), // Accumulator x, // Init optional.of( "foo"~string )~optional_type(string)^optional_of.value()~string^optional_value, // LoopCondition false~bool, // LoopStep x~string^x, // Result x~string^x)~string )~optional_type(string)^optional_of, optional.none()~optional_type(string)^optional_none )~optional_type(string)^conditional )",
        },
        TestCase{
            .expr = "optional.of('foo').optFlatMap(x, optional.of(x))",
            .expected_ast = R"( _?_:_( optional.of( "foo"~string
)~optional_type(string)^optional_of.hasValue()~bool^optional_hasValue, __comprehension__( // Variable #unused, // Target []~list(dyn), // Accumulator x, // Init optional.of( "foo"~string )~optional_type(string)^optional_of.value()~string^optional_value, // LoopCondition false~bool, // LoopStep x~string^x, // Result optional.of( x~string^x )~optional_type(string)^optional_of)~optional_type(string), optional.none()~optional_type(string)^optional_none )~optional_type(string)^conditional )",
        },
        TestCase{
            .expr = "optional.ofNonZeroValue(1)",
            .expected_ast = R"( optional.ofNonZeroValue( 1~int )~optional_type(int)^optional_ofNonZeroValue )",
        },
        TestCase{
            .expr = "[0][?1]",
            .expected_ast = R"( _[?_]( [ 0~int ]~list(int), 1~int )~optional_type(int)^list_optindex_optional_int )",
        },
        TestCase{
            .expr = "{0: 2}[?1]",
            .expected_ast = R"( _[?_]( { 0~int:2~int }~map(int, int), 1~int )~optional_type(int)^map_optindex_optional_value )",
        },
        TestCase{
            .expr = "msg.?repeated_int64[1]",
            .expected_ast = R"( _[_]( _?._( msg~cel.expr.conformance.proto3.TestAllTypes^msg, "repeated_int64" )~optional_type(list(int))^select_optional_field, 1~int )~optional_type(int)^optional_list_index_int )",
        },
        TestCase{
            .expr = "msg.?map_int64_int64[1]",
            .expected_ast = R"( _[_]( _?._( msg~cel.expr.conformance.proto3.TestAllTypes^msg, "map_int64_int64" )~optional_type(map(int, int))^select_optional_field, 1~int )~optional_type(int)^optional_map_index_value )",
        },
        TestCase{
            .expr = "optional.of(1).or(optional.of(2))",
            .expected_ast = R"( optional.of( 1~int )~optional_type(int)^optional_of.or( optional.of( 2~int )~optional_type(int)^optional_of )~optional_type(int)^optional_or_optional)",
        },
        TestCase{
            .expr = "optional.of(1).orValue(2)",
            .expected_ast = R"( optional.of( 1~int )~optional_type(int)^optional_of.orValue( 2~int )~int^optional_orValue_value )",
        },
        TestCase{
            .expr = "optional.of(1).value()",
            .expected_ast = R"( optional.of( 1~int )~optional_type(int)^optional_of.value()~int^optional_value )",
        },
        TestCase{
            .expr = "optional.of(1).hasValue()",
            .expected_ast = R"( optional.of( 1~int )~optional_type(int)^optional_of.hasValue()~bool^optional_hasValue )",
        }));

// Without the optional library, the `optional` namespace is undeclared.
TEST(OptionalTest, NotEnabled) {
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool()));

  ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk());
  ASSERT_THAT(builder->GetCheckerBuilder().AddVariable(MakeVariableDecl(
                  "msg", MessageType(TestAllTypes::descriptor()))),
              IsOk());

  ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build());

  ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("optional.of(1)"));

  EXPECT_THAT(FormatIssues(result),
              HasSubstr("undeclared reference to 'optional'"));
}

// An expression and the optional-extension versions that should accept it.
struct OptionalExtensionVersionTestCase {
  std::string expr;
  std::vector<int> expected_supported_versions;
};

class OptionalExtensionVersionTest
    : public ::testing::TestWithParam<OptionalExtensionVersionTestCase> {};

// Compiles the expression against every extension version up to the latest:
// supported versions must yield no issues, others an undeclared reference.
TEST_P(OptionalExtensionVersionTest, OptionalExtensionVersions) {
  const OptionalExtensionVersionTestCase& test_case = GetParam();
  for (int version = 0; version <= cel::kOptionalExtensionLatestVersion;
       ++version) {
    CompilerLibrary compiler_library = OptionalCompilerLibrary(version);
    CompilerOptions compiler_options;
    compiler_options.parser_options.enable_optional_syntax = true;
    ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<CompilerBuilder> builder,
        cel::NewCompilerBuilder(internal::GetTestingDescriptorPool(),
                                compiler_options));
    ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
    ASSERT_THAT(builder->AddLibrary(std::move(compiler_library)), IsOk());
    ASSERT_OK_AND_ASSIGN(std::unique_ptr<Compiler> compiler, builder->Build());

    ASSERT_OK_AND_ASSIGN(ValidationResult result,
                         compiler->Compile(test_case.expr));
    if (absl::c_contains(test_case.expected_supported_versions, version)) {
      EXPECT_THAT(result.GetIssues(), IsEmpty())
          << "Expected no issues for expr: " << test_case.expr
          << " at version: " << version << " but got: " << result.FormatError();
    } else {
      EXPECT_THAT(result.GetIssues(),
                  Contains(Property(&TypeCheckIssue::message,
                                    HasSubstr("undeclared reference"))))
          << "Expected undeclared reference for expr: " << test_case.expr
          << " at version: " << version;
    }
  }
}

std::vector<OptionalExtensionVersionTestCase>
CreateOptionalExtensionVersionParams() {
  return {
      OptionalExtensionVersionTestCase{
          .expr = "optional_type",
          .expected_supported_versions = {0, 1, 2},
      },
      OptionalExtensionVersionTestCase{
          .expr = "optional.of('foo').optMap(x, x)",
          .expected_supported_versions = {0, 1, 2},
      },
      OptionalExtensionVersionTestCase{
          .expr = "optional.of('foo')",
          .expected_supported_versions = {0, 1, 2},
      },
      OptionalExtensionVersionTestCase{
          .expr = "optional.ofNonZeroValue(1)",
          .expected_supported_versions = {0, 1, 2},
      },
      OptionalExtensionVersionTestCase{
          .expr = "optional.of('foo').value()",
          .expected_supported_versions = {0, 1, 2},
      },
      OptionalExtensionVersionTestCase{
          .expr = "optional.of('foo').hasValue()",
          .expected_supported_versions = {0, 1, 2},
      },
      OptionalExtensionVersionTestCase{
          .expr = "optional.of(1).or(optional.of(2))",
          .expected_supported_versions = {0, 1, 2},
      },
      OptionalExtensionVersionTestCase{
          .expr = "optional.of(1).orValue(2)",
          .expected_supported_versions = {0, 1, 2},
      },
      OptionalExtensionVersionTestCase{
          .expr = "[1, 2, 3][?5]",
          .expected_supported_versions = {0, 1, 2},
      },
      OptionalExtensionVersionTestCase{
          .expr = "dyn(1).?bar",
          .expected_supported_versions = {0, 1, 2},
      },
      OptionalExtensionVersionTestCase{
          .expr = "optional.of('foo').optFlatMap(x, optional.of(x))",
          .expected_supported_versions = {1, 2},
      },
      OptionalExtensionVersionTestCase{
          .expr = "[1, 2, 3].first()",
          .expected_supported_versions = {2},
      },
      OptionalExtensionVersionTestCase{
          .expr = "[1, 2, 3].last()",
          .expected_supported_versions = {2},
      },
  };
}

INSTANTIATE_TEST_SUITE_P(OptionalExtensionVersionTest,
                         OptionalExtensionVersionTest,
                         ValuesIn(CreateOptionalExtensionVersionParams()));

}  // namespace
}  // namespace cel
================================================
FILE: compiler/standard_library.cc
================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "compiler/standard_library.h" #include "absl/status/status.h" #include "checker/standard_library.h" #include "compiler/compiler.h" #include "internal/status_macros.h" #include "parser/macro.h" #include "parser/parser_interface.h" namespace cel { namespace { absl::Status AddStandardLibraryMacros(ParserBuilder& builder) { // For consistency with the Parse free functions, follow the convenience // option to disable all the standard macros. if (builder.GetOptions().disable_standard_macros) { return absl::OkStatus(); } for (const auto& macro : Macro::AllMacros()) { CEL_RETURN_IF_ERROR(builder.AddMacro(macro)); } return absl::OkStatus(); } } // namespace CompilerLibrary StandardCompilerLibrary() { CompilerLibrary library = CompilerLibrary::FromCheckerLibrary(StandardCheckerLibrary()); library.configure_parser = AddStandardLibraryMacros; return library; } } // namespace cel ================================================ FILE: compiler/standard_library.h ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_COMPILER_STANDARD_LIBRARY_H_ #define THIRD_PARTY_CEL_CPP_COMPILER_STANDARD_LIBRARY_H_ #include "compiler/compiler.h" namespace cel { // Returns a CompilerLibrary containing all of the standard CEL declarations // and macros. CompilerLibrary StandardCompilerLibrary(); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMPILER_STANDARD_LIBRARY_H_ ================================================ FILE: conformance/BUILD ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("//conformance:run.bzl", "gen_conformance_tests")

package(default_visibility = ["//visibility:public"])

licenses(["notice"])

# The conformance service implementation (service.cc / service.h) that the
# test runner parses, checks, and evaluates expressions against.
cc_library(
    name = "service",
    testonly = True,
    srcs = ["service.cc"],
    hdrs = ["service.h"],
    deps = [
        "//checker:optional",
        "//checker:standard_library",
        "//checker:type_checker_builder",
        "//checker:type_checker_builder_factory",
        "//common:ast",
        "//common:ast_proto",
        "//common:decl_proto_v1alpha1",
        "//common:expr",
        "//common:source",
        "//common:value",
        "//common/internal:value_conversion",
        "//eval/public:activation",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "//eval/public:transform_utility",
        "//extensions:bindings_ext",
        "//extensions:comprehensions_v2",
        "//extensions:comprehensions_v2_functions",
        "//extensions:comprehensions_v2_macros",
        "//extensions:encoders",
        "//extensions:math_ext",
        "//extensions:math_ext_decls",
        "//extensions:math_ext_macros",
        "//extensions:proto_ext",
        "//extensions:select_optimization",
        "//extensions:strings",
        "//extensions/protobuf:enum_adapter",
        "//internal:status_macros",
        "//parser",
        "//parser:macro",
        "//parser:macro_expr_factory",
        "//parser:macro_registry",
        "//parser:options",
        "//parser:standard_macros",
        "//runtime",
        "//runtime:activation",
        "//runtime:constant_folding",
        "//runtime:optional_types",
        "//runtime:reference_resolver",
        "//runtime:runtime_options",
        "//runtime:standard_runtime_builder_factory",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto",
        "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto",
        "@com_google_googleapis//google/api/expr/conformance/v1alpha1:conformance_cc_proto",
        "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto",
        "@com_google_googleapis//google/rpc:code_cc_proto",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:empty_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
        "@com_google_protobuf//:timestamp_cc_proto",
    ],
)

# The test-runner main (run.cc). It is linked into every generated cc_test via
# gen_conformance_tests. alwayslink keeps its flag definitions and main().
cc_library(
    name = "run",
    testonly = True,
    srcs = ["run.cc"],
    deps = [
        ":service",
        ":utils",
        "//internal:testing_no_main",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_cel_spec//proto/cel/expr:value_cc_proto",
        "@com_google_cel_spec//proto/cel/expr/conformance/test:simple_cc_proto",
        "@com_google_googleapis//google/api/expr/conformance/v1alpha1:conformance_cc_proto",
        "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto",
        "@com_google_googleapis//google/rpc:code_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//src/google/protobuf/io",
    ],
    alwayslink = True,
)

# Header-only helpers (utils.h) used by the runner to compare results.
cc_library(
    name = "utils",
    testonly = True,
    hdrs = ["utils.h"],
    deps = [
        "//internal:testing_no_main",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_cel_spec//proto/cel/expr:value_cc_proto",
        "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto",
        "@com_google_protobuf//:differencer",
        "@com_google_protobuf//:protobuf",
    ],
)

# cel-spec simple-test textproto files run by every configuration below.
_ALL_TESTS = [
    "@com_google_cel_spec//tests/simple:testdata/basic.textproto",
    "@com_google_cel_spec//tests/simple:testdata/bindings_ext.textproto",
    "@com_google_cel_spec//tests/simple:testdata/comparisons.textproto",
    "@com_google_cel_spec//tests/simple:testdata/conversions.textproto",
    "@com_google_cel_spec//tests/simple:testdata/dynamic.textproto",
    "@com_google_cel_spec//tests/simple:testdata/encoders_ext.textproto",
    "@com_google_cel_spec//tests/simple:testdata/enums.textproto",
    "@com_google_cel_spec//tests/simple:testdata/fields.textproto",
    "@com_google_cel_spec//tests/simple:testdata/fp_math.textproto",
    "@com_google_cel_spec//tests/simple:testdata/integer_math.textproto",
    "@com_google_cel_spec//tests/simple:testdata/lists.textproto",
    "@com_google_cel_spec//tests/simple:testdata/logic.textproto",
    "@com_google_cel_spec//tests/simple:testdata/macros.textproto",
    "@com_google_cel_spec//tests/simple:testdata/macros2.textproto",
    "@com_google_cel_spec//tests/simple:testdata/math_ext.textproto",
    "@com_google_cel_spec//tests/simple:testdata/namespace.textproto",
    "@com_google_cel_spec//tests/simple:testdata/optionals.textproto",
    "@com_google_cel_spec//tests/simple:testdata/parse.textproto",
    "@com_google_cel_spec//tests/simple:testdata/plumbing.textproto",
    "@com_google_cel_spec//tests/simple:testdata/proto2.textproto",
    "@com_google_cel_spec//tests/simple:testdata/proto2_ext.textproto",
    "@com_google_cel_spec//tests/simple:testdata/proto3.textproto",
    "@com_google_cel_spec//tests/simple:testdata/string.textproto",
    "@com_google_cel_spec//tests/simple:testdata/string_ext.textproto",
    "@com_google_cel_spec//tests/simple:testdata/timestamps.textproto",
    "@com_google_cel_spec//tests/simple:testdata/unknowns.textproto",
    "@com_google_cel_spec//tests/simple:testdata/wrappers.textproto",
    "@com_google_cel_spec//tests/simple:testdata/block_ext.textproto",
    "@com_google_cel_spec//tests/simple:testdata/type_deduction.textproto",
]

# Skips shared by all non-dashboard configurations.
# Entries name `file`, `file/section`, or `file/section/test` and match that
# test or anything nested under it. A trailing comma list such as `a/b/x,y`
# expands to `a/b/x` and `a/b/y`; see _expand_tests_to_skip in run.bzl.
_TESTS_TO_SKIP = [
    # Tests which require spec changes.
    # TODO(issues/93): Deprecate Duration.getMilliseconds.
    "timestamps/duration_converters/get_milliseconds",

    # Broken test cases which should be supported.
    # TODO(issues/112): Unbound functions result in empty eval response.
    "basic/functions/unbound",
    "basic/functions/unbound_is_runtime_error",

    # TODO(issues/97): Parse-only qualified variable lookup "x.y" with binding "x.y" or "y" within container "x" fails
    "fields/qualified_identifier_resolution/qualified_ident,map_field_select,ident_with_longest_prefix_check,qualified_identifier_resolution_unchecked",
    "namespace/qualified/self_eval_qualified_lookup",
    "namespace/namespace/self_eval_container_lookup,self_eval_container_lookup_unchecked",
    # TODO(issues/117): Integer overflow on enum assignments should error.
    "enums/legacy_proto2/select_big,select_neg",

    # Skip until fixed.
    "wrappers/field_mask/to_json",
    "wrappers/empty/to_json",
    "fields/qualified_identifier_resolution/map_value_repeat_key_heterogeneous",
    "parse/receiver_function_names",

    # Future features for CEL 1.0
    # TODO(issues/119): Strong typing support for enums, specified but not implemented.
    "enums/strong_proto2",
    "enums/strong_proto3",

    # These depend on legacy US/ timezones. It's spotty if these are included with a normally
    # configured timezone database.
    "timestamps/timestamp_selectors_tz/getDayOfMonth_name_pos",
    "timestamps/timestamp_selectors_tz/getDayOfYear",

    # These depend on using charconv (or equivalent) to format doubles with shortest possible
    # precision to preserve value. Not available on older compilers where we just use absl::Format.
    # We should probably update the spec to allow different formats that parse to the same value.
    "conversions/string/double_hard",
]

# The modern (cel::Value) configurations currently need no skips beyond the
# common list; this aliases the same list object.
_TESTS_TO_SKIP_MODERN = _TESTS_TO_SKIP

# Dashboard runs only skip features that are not implemented at all; other
# known failures are run and reported, since dashboard mode ignores failures.
_TESTS_TO_SKIP_MODERN_DASHBOARD = [
    # Future features for CEL 1.0
    # TODO(issues/119): Strong typing support for enums, specified but not implemented.
    "enums/strong_proto2",
    "enums/strong_proto3",
]

# Extra skips for the legacy (CelValue) configurations.
_TESTS_TO_SKIP_LEGACY = _TESTS_TO_SKIP + [
    # Legacy value does not support optional_type.
    "optionals/optionals",

    # TODO(uncreated-issue/81): Fix null assignment to a field
    "proto2/set_null/list_value",
    "proto2/set_null/single_struct",
    "proto3/set_null/list_value",
    "proto3/set_null/single_struct",

    # cel.@block
    "block_ext/basic/optional_list",
    "block_ext/basic/optional_map",
    "block_ext/basic/optional_map_chained",
    "block_ext/basic/optional_message",
]

# Extra skips for configurations that run the type checker.
_TESTS_TO_SKIP_CHECKED = [
    # block is a post-check optimization that inserts internal variables. The C++ type checker
    # needs support for a proper optimizer for this to work.
    "block_ext",
]

_TESTS_TO_SKIP_LEGACY_DASHBOARD = [
    # Future features for CEL 1.0
    # TODO(issues/119): Strong typing support for enums, specified but not implemented.
    "enums/strong_proto2",
    "enums/strong_proto3",

    # Legacy value does not support optional_type.
    "optionals/optionals",
]

# Each call generates four `cc_test` targets named
# `<name>_{optimized|unoptimized}_{recursive|iterative}`, plus a `test_suite`
# named `<name>` that groups them.
# Parse-only configurations also skip the `type_deductions` test file.
# NOTE(review): this is presumably the name declared inside
# type_deduction.textproto, whose tests need a type-checked expression; confirm.
gen_conformance_tests(
    name = "conformance_parse_only",
    data = _ALL_TESTS,
    modern = True,
    skip_tests = _TESTS_TO_SKIP_MODERN + ["type_deductions"],
)

gen_conformance_tests(
    name = "conformance_legacy_parse_only",
    data = _ALL_TESTS,
    modern = False,
    skip_tests = _TESTS_TO_SKIP_LEGACY + ["type_deductions"],
)

gen_conformance_tests(
    name = "conformance_checked",
    checked = True,
    data = _ALL_TESTS,
    modern = True,
    skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED,
)

gen_conformance_tests(
    name = "conformance_legacy_checked",
    checked = True,
    data = _ALL_TESTS,
    modern = False,
    skip_tests = _TESTS_TO_SKIP_LEGACY + _TESTS_TO_SKIP_CHECKED,
)

# select optimization is only supported for checked expressions.
gen_conformance_tests(
    name = "conformance_legacy_select_opt",
    checked = True,
    data = _ALL_TESTS,
    modern = False,
    select_opt = True,
    skip_tests = _TESTS_TO_SKIP_LEGACY + _TESTS_TO_SKIP_CHECKED,
)

gen_conformance_tests(
    name = "conformance_select_opt",
    checked = True,
    data = _ALL_TESTS,
    modern = True,
    select_opt = True,
    skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED,
)

# Each call generates four `cc_test` targets named
# `conformance_dashboard_..._{optimized|unoptimized}_{recursive|iterative}`,
# plus a `test_suite` with the bare name that groups them. Dashboard mode
# reports failures but the test binary always exits successfully.
gen_conformance_tests(
    name = "conformance_dashboard_parse_only",
    dashboard = True,
    data = _ALL_TESTS,
    modern = True,
    skip_tests = _TESTS_TO_SKIP_MODERN_DASHBOARD + ["type_deductions"],
    tags = [
        "guitar",
        "notap",
    ],
)

gen_conformance_tests(
    name = "conformance_dashboard_checked",
    checked = True,
    dashboard = True,
    data = _ALL_TESTS,
    modern = True,
    skip_tests = _TESTS_TO_SKIP_MODERN_DASHBOARD,
    tags = [
        "guitar",
        "notap",
    ],
)

gen_conformance_tests(
    name = "conformance_dashboard_legacy_parse_only",
    dashboard = True,
    data = _ALL_TESTS,
    modern = False,
    skip_tests = _TESTS_TO_SKIP_LEGACY_DASHBOARD + ["type_deductions"],
    tags = [
        "guitar",
        "notap",
    ],
)
================================================
FILE: conformance/run.bzl
================================================
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module contains build rules for generating the conformance test targets.
"""

load("@rules_cc//cc:cc_test.bzl", "cc_test")

_TESTS_TO_SKIP_WINDOWS = [
    # These tests depend on configuring a timezone database which isn't available in our windows
    # test environment.
    "timestamps/timestamp_selectors_tz/getDate",
    "timestamps/timestamp_selectors_tz/getDayOfMonth_name_pos",
    "timestamps/timestamp_selectors_tz/getDayOfMonth_name_neg",
    "timestamps/timestamp_selectors_tz/getDayOfYear",
    "timestamps/timestamp_selectors_tz/getMinutes",
]

# Expands the list of tests to skip from the format used by the original Go test runner into a
# flat list with one test per entry. The last path component of an entry may hold several
# comma-separated siblings. For example, `foo/bar,baz` becomes two entries, `foo/bar` and
# `foo/baz`. Expanded entries never contain a comma, so applying this twice is a no-op. Callers
# join the result with commas to build the `--skip_tests` flag value.
def _expand_tests_to_skip(tests_to_skip):
    result = []
    for test_to_skip in tests_to_skip:
        comma = test_to_skip.find(",")
        if comma == -1:
            result.append(test_to_skip)
            continue
        slash = test_to_skip.rfind("/", 0, comma)
        if slash == -1:
            slash = 0
        else:
            slash = slash + 1
        for part in test_to_skip[slash:].split(","):
            result.append(test_to_skip[0:slash] + part)
    return result

def _conformance_test_name(name, optimize, recursive):
    return "_".join(
        [
            name,
            "optimized" if optimize else "unoptimized",
            "recursive" if recursive else "iterative",
        ],
    )

# Returns the platform-independent runner flags. `--skip_tests` is deliberately not emitted here.
# Its value differs per platform, so `_conformance_test` adds it exactly once through `select()`.
# Emitting it here as well produced a second, dead copy, because the last occurrence of an absl
# flag wins.
def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard):
    args = []
    if modern:
        args.append("--modern")
    if optimize:
        args.append("--opt")
    if select_opt:
        args.append("--select_optimization")
    if recursive:
        args.append("--recursive")
    if skip_check:
        args.append("--skip_check")
    else:
        args.append("--noskip_check")
    if dashboard:
        args.append("--dashboard")
    return args

# Declares one conformance cc_test. `skip_tests` must already be expanded with
# `_expand_tests_to_skip`.
def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard):
    windows_skip_tests = _expand_tests_to_skip(skip_tests + _TESTS_TO_SKIP_WINDOWS)
    cc_test(
        name = _conformance_test_name(name, optimize, recursive),
        args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard) +
               ["$(location " + test + ")" for test in data] +
               select(
                   {
                       "@platforms//os:windows": ["--skip_tests={}".format(",".join(windows_skip_tests))],
                       "//conditions:default": ["--skip_tests={}".format(",".join(skip_tests))],
                   },
               ),
        data = data,
        deps = ["//conformance:run"],
        tags = tags,
    )

def gen_conformance_tests(name, data, modern = False, checked = False, select_opt = False, dashboard = False, skip_tests = [], tags = []):
    """Generates conformance tests.

    Creates one `cc_test` per (optimize, recursive) combination, named
    `<name>_{optimized|unoptimized}_{recursive|iterative}`, plus a
    `test_suite` named `name` that groups them.

    Args:
      name: prefix for all tests
      modern: run using modern APIs
      checked: whether to apply type checking
      select_opt: enable the select optimization; select optimization is only
        supported for checked expressions, so pair it with `checked = True`
      data: textproto targets describing conformance tests
      skip_tests: tests to skip in the format of the cel-spec test runner. See
        documentation in github.com/google/cel-spec/tests/simple/simple_test.go
      tags: tags added to the generated targets
      dashboard: enable dashboard mode
    """
    skip_check = not checked

    # Loop-invariant: expand the skip list once rather than once per generated target.
    expanded_skip_tests = _expand_tests_to_skip(skip_tests)
    tests = []
    for optimize in (True, False):
        for recursive in (True, False):
            test_name = _conformance_test_name(name, optimize, recursive)
            tests.append(test_name)
            _conformance_test(
                name,
                data,
                modern = modern,
                optimize = optimize,
                recursive = recursive,
                select_opt = select_opt,
                skip_check = skip_check,
                skip_tests = expanded_skip_tests,
                tags = tags,
                dashboard = dashboard,
            )
    native.test_suite(
        name = name,
        tests = tests,
        tags = tags,
    )
================================================
FILE: conformance/run.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This file is a native C++ implementation of the original Go conformance test
// runner located at
// https://github.com/google/cel-spec/tree/master/tests/simple. It was ported to
// C++ to avoid having to pull in Go, gRPC, and others just to run C++
// conformance tests; as well as integrating better with C++ testing
// infrastructure.

#include
#include
#include
#include
#include
#include
#include

#include "cel/expr/checked.pb.h"
#include "google/api/expr/conformance/v1alpha1/conformance_service.pb.h"
#include "cel/expr/eval.pb.h"
#include "google/api/expr/v1alpha1/checked.pb.h"  // IWYU pragma: keep
#include "google/api/expr/v1alpha1/eval.pb.h"
#include "google/api/expr/v1alpha1/syntax.pb.h"  // IWYU pragma: keep
#include "google/api/expr/v1alpha1/value.pb.h"
#include "cel/expr/value.pb.h"
#include "google/rpc/code.pb.h"
#include "absl/flags/flag.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "absl/types/span.h"
#include "conformance/service.h"
#include "conformance/utils.h"
#include "internal/testing.h"
#include "cel/expr/conformance/test/simple.pb.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

ABSL_FLAG(bool, opt, false, "Enable optimizations (constant folding)");
ABSL_FLAG(
    bool, modern, false,
    "Use modern cel::Value APIs implementation of the conformance service.");
ABSL_FLAG(bool, recursive, false,
          "Enable recursive plans. Depth limited to slightly more than the "
          "default nesting limit.");
ABSL_FLAG(std::vector, skip_tests, {}, "Tests to skip");
ABSL_FLAG(bool, dashboard, false, "Dashboard mode, ignore test failures");
ABSL_FLAG(bool, skip_check, true, "Skip type checking the expressions");
ABSL_FLAG(bool, select_optimization, false, "Enable select optimization.");

namespace {

using ::testing::IsEmpty;

using cel::expr::conformance::test::SimpleTest;
using cel::expr::conformance::test::SimpleTestFile;
using google::api::expr::conformance::v1alpha1::CheckRequest;
using google::api::expr::conformance::v1alpha1::CheckResponse;
using google::api::expr::conformance::v1alpha1::EvalRequest;
using google::api::expr::conformance::v1alpha1::EvalResponse;
using google::api::expr::conformance::v1alpha1::ParseRequest;
using google::api::expr::conformance::v1alpha1::ParseResponse;

// Converts an absl status code to the google.rpc code stored on eval issues.
// This is a plain numeric cast.
// NOTE(review): it assumes the two enums share numeric values; confirm.
google::rpc::Code ToGrpcCode(absl::StatusCode code) {
  return static_cast(code);
}

// Returns true if `name` (formatted as "file/section/test") equals an entry of
// `tests_to_skip` or is nested under one, i.e. the entry is a prefix that is
// followed by '/' or by nothing. "foo/bar" therefore skips "foo/bar/baz" but
// not "foo/barbaz".
bool ShouldSkipTest(absl::Span tests_to_skip,
                    absl::string_view name) {
  for (absl::string_view test_to_skip : tests_to_skip) {
    auto consumed_name = name;
    if (absl::ConsumePrefix(&consumed_name, test_to_skip) &&
        (consumed_name.empty() || absl::StartsWith(consumed_name, "/"))) {
      return true;
    }
  }
  return false;
}

// Returns a copy of `test`. If the test declares no result matcher, the copy
// expects the expression to evaluate to the bool `true`.
SimpleTest DefaultTestMatcherToTrueIfUnset(const SimpleTest& test) {
  auto test_copy = test;
  if (test_copy.result_matcher_case() == SimpleTest::RESULT_MATCHER_NOT_SET) {
    test_copy.mutable_value()->set_bool_value(true);
  }
  return test_copy;
}

// A single dynamically-registered gtest that runs one SimpleTest through the
// conformance service. The flow is: parse, then type check (unless disabled),
// then evaluate (unless check_only), then match the result.
class ConformanceTest : public testing::Test {
 public:
  explicit ConformanceTest(
      std::shared_ptr service,
      const SimpleTest& test, bool skip)
      : service_(std::move(service)),
        test_(DefaultTestMatcherToTrueIfUnset(test)),
        skip_(skip) {}

  void TestBody() override {
    if (skip_) {
      GTEST_SKIP();
    }
    // Parse: any parse issue fails the test immediately.
    ParseRequest parse_request;
    parse_request.set_cel_source(test_.expr());
    parse_request.set_source_location(test_.name());
    parse_request.set_disable_macros(test_.disable_macros());
    ParseResponse parse_response;
    service_->Parse(parse_request, parse_response);
    ASSERT_THAT(parse_response.issues(), IsEmpty());

    EvalRequest eval_request;
    if (!test_.container().empty()) {
      eval_request.set_container(test_.container());
    }
    if (!test_.bindings().empty()) {
      // The test file (cel.expr) and the service request
      // (google.api.expr.v1alpha1) use different proto packages. Each binding
      // is therefore copied with a serialize/parse round trip.
      for (const auto& binding : test_.bindings()) {
        absl::Cord serialized;
        ABSL_CHECK(binding.second.SerializePartialToString(&serialized));
        ABSL_CHECK((*eval_request.mutable_bindings())[binding.first]
                       .ParsePartialFromString(serialized));
      }
    }
    // Type check unless disabled globally (--skip_check) or by the test.
    if (absl::GetFlag(FLAGS_skip_check) || test_.disable_check()) {
      eval_request.set_allocated_parsed_expr(
          parse_response.release_parsed_expr());
    } else {
      CheckRequest check_request;
      check_request.set_allocated_parsed_expr(
          parse_response.release_parsed_expr());
      check_request.set_container(test_.container());
      for (const auto& type_env : test_.type_env()) {
        absl::Cord serialized;
        ABSL_CHECK(type_env.SerializePartialToString(&serialized));
        ABSL_CHECK(
            check_request.add_type_env()->ParsePartialFromString(serialized));
      }
      CheckResponse check_response;
      service_->Check(check_request, check_response);
      ASSERT_THAT(check_response.issues(), IsEmpty()) << absl::StrCat(
          "unexpected type check issues for: '", test_.expr(), "'\n");
      eval_request.set_allocated_checked_expr(
          check_response.release_checked_expr());
    }
    // check_only tests compare the deduced result type and never evaluate.
    if (test_.check_only()) {
      ASSERT_TRUE(test_.has_typed_result())
          << "test must specify a typed result if check_only is set";
      EXPECT_THAT(eval_request.checked_expr(),
                  cel_conformance::ResultTypeMatches(
                      test_.typed_result().deduced_type()));
      return;
    }
    // Evaluate. A failed Eval status is recorded as an issue on the
    // response, and a result must still be present afterwards.
    EvalResponse eval_response;
    if (auto status = service_->Eval(eval_request, eval_response);
        !status.ok()) {
      auto* issue = eval_response.add_issues();
      issue->set_message(status.message());
      issue->set_code(ToGrpcCode(status.code()));
    }
    ASSERT_TRUE(eval_response.has_result()) << eval_response;
    // Compare the result against whichever matcher the test declares.
    switch (test_.result_matcher_case()) {
      case SimpleTest::kValue: {
        absl::Cord serialized;
        ABSL_CHECK(
            eval_response.result().SerializePartialToString(&serialized));
        cel::expr::ExprValue test_value;
        ABSL_CHECK(test_value.ParsePartialFromString(serialized));
        EXPECT_THAT(test_value,
                    cel_conformance::MatchesConformanceValue(test_.value()));
        break;
      }
      case SimpleTest::kTypedResult: {
        ASSERT_TRUE(eval_request.has_checked_expr())
            << "expression was not type checked";
        absl::Cord serialized;
        ABSL_CHECK(
            eval_response.result().SerializePartialToString(&serialized));
        cel::expr::ExprValue test_value;
        ABSL_CHECK(test_value.ParsePartialFromString(serialized));
        EXPECT_THAT(test_value, cel_conformance::MatchesConformanceValue(
                                    test_.typed_result().result()));
        EXPECT_THAT(eval_request.checked_expr(),
                    cel_conformance::ResultTypeMatches(
                        test_.typed_result().deduced_type()));
        break;
      }
      case SimpleTest::kEvalError:
        EXPECT_TRUE(eval_response.result().has_error())
            << eval_response.result();
        break;
      default:
        ADD_FAILURE() << "unexpected matcher kind: "
                      << test_.result_matcher_case();
        break;
    }
  }

 private:
  const std::shared_ptr service_;
  const SimpleTest test_;
  const bool skip_;
};

// Reads the SimpleTestFile textproto at `path` and registers one gtest per
// test. The suite name is the file's name and the test name is
// "section/test". Tests matched by `tests_to_skip` are still registered but
// call GTEST_SKIP(). Returns UnknownError if the file cannot be opened or
// parsed.
absl::Status RegisterTestsFromFile(
    const std::shared_ptr& service,
    absl::Span tests_to_skip, absl::string_view path) {
  SimpleTestFile file;
  {
    std::ifstream in;
    in.open(std::string(path), std::ios_base::in | std::ios_base::binary);
    if (!in.is_open()) {
      return absl::UnknownError(absl::StrCat("failed to open file: ", path));
    }
    google::protobuf::io::IstreamInputStream stream(&in);
    if (!google::protobuf::TextFormat::Parse(&stream, &file)) {
      return absl::UnknownError(absl::StrCat("failed to parse file: ", path));
    }
  }
  for (const auto& section : file.section()) {
    for (const auto& test : section.test()) {
      const bool skip = ShouldSkipTest(
          tests_to_skip,
          absl::StrCat(file.name(), "/", section.name(), "/", test.name()));
      // The factory captures `service` by value, which keeps the shared
      // service alive for as long as the registered test exists.
      testing::RegisterTest(
          file.name().c_str(),
          absl::StrCat(section.name(), "/", test.name()).c_str(), nullptr,
          nullptr, __FILE__, __LINE__, [=]() -> ConformanceTest* {
            return new ConformanceTest(service, test, skip);
          });
    }
  }
  return absl::OkStatus();
}

// Builds the single conformance service shared by all tests from the command
// line flags, CHECK-failing if construction fails.
// We could push this to be done per test or suite, but to avoid changing more
// than necessary we do it once to mimic the previous runner.
std::shared_ptr NewConformanceServiceFromFlags() {
  auto status_or_service = cel_conformance::NewConformanceService(
      cel_conformance::ConformanceServiceOptions{
          .optimize = absl::GetFlag(FLAGS_opt),
          .modern = absl::GetFlag(FLAGS_modern),
          .recursive = absl::GetFlag(FLAGS_recursive),
          .select_optimization = absl::GetFlag(FLAGS_select_optimization),
      });
  ABSL_CHECK_OK(status_or_service);
  return std::shared_ptr(
      std::move(*status_or_service));
}

}  // namespace

// Treats every argv entry left after InitGoogleTest as a test textproto path
// and registers its tests, then runs them.
// NOTE(review): this assumes InitGoogleTest has already removed the recognized
// flags from argv; confirm.
// In --dashboard mode the process always exits with EXIT_SUCCESS, whatever
// the test results.
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  {
    auto service = NewConformanceServiceFromFlags();
    auto tests_to_skip = absl::GetFlag(FLAGS_skip_tests);
    for (int argi = 1; argi < argc; argi++) {
      ABSL_CHECK_OK(RegisterTestsFromFile(service, tests_to_skip,
                                          absl::string_view(argv[argi])));
    }
  }
  int exit_code = RUN_ALL_TESTS();
  if (absl::GetFlag(FLAGS_dashboard)) {
    exit_code = EXIT_SUCCESS;
  }
  return exit_code;
}
================================================
FILE: conformance/service.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "conformance/service.h"

#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <utility>

#include "google/api/expr/conformance/v1alpha1/conformance_service.pb.h"
#include "cel/expr/syntax.pb.h"
#include "google/api/expr/v1alpha1/checked.pb.h"
#include "google/api/expr/v1alpha1/eval.pb.h"
#include "google/api/expr/v1alpha1/syntax.pb.h"
#include "google/api/expr/v1alpha1/value.pb.h"
#include "google/protobuf/duration.pb.h"
#include "google/protobuf/empty.pb.h"
#include "google/protobuf/struct.pb.h"
#include "google/protobuf/timestamp.pb.h"
#include "google/rpc/code.pb.h"
#include "absl/log/absl_check.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "checker/optional.h"
#include "checker/standard_library.h"
#include "checker/type_checker_builder.h"
#include "checker/type_checker_builder_factory.h"
#include "common/ast.h"
#include "common/ast_proto.h"
#include "common/decl_proto_v1alpha1.h"
#include "common/expr.h"
#include "common/internal/value_conversion.h"
#include "common/source.h"
#include "common/value.h"
#include "eval/public/activation.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_expr_builder_factory.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_options.h"
#include "eval/public/cel_value.h"
#include "eval/public/transform_utility.h"
#include "extensions/bindings_ext.h"
#include "extensions/comprehensions_v2.h"
#include "extensions/comprehensions_v2_functions.h"
#include "extensions/comprehensions_v2_macros.h"
#include "extensions/encoders.h"
#include "extensions/math_ext.h"
#include "extensions/math_ext_decls.h"
#include "extensions/math_ext_macros.h"
#include "extensions/proto_ext.h"
#include "extensions/protobuf/enum_adapter.h"
#include "extensions/select_optimization.h"
#include "extensions/strings.h"
#include "internal/status_macros.h"
#include "parser/macro.h"
#include "parser/macro_expr_factory.h"
#include "parser/macro_registry.h"
#include "parser/options.h"
#include "parser/parser.h"
#include "parser/standard_macros.h"
#include "runtime/activation.h"
#include "runtime/constant_folding.h"
#include "runtime/optional_types.h"
#include "runtime/reference_resolver.h"
#include "runtime/runtime.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "cel/expr/conformance/proto2/test_all_types.pb.h"
#include "cel/expr/conformance/proto2/test_all_types_extensions.pb.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

using ::cel::CreateStandardRuntimeBuilder;
using ::cel::Runtime;
using ::cel::RuntimeOptions;
using ::cel::extensions::RegisterProtobufEnum;
using ::cel::test::ConvertWireCompatProto;
using ::cel::test::FromExprValue;
using ::cel::test::ToExprValue;
using ::google::protobuf::Arena;

namespace google::api::expr::runtime {
namespace {

// Returns true if `target` is the bare identifier `cel`, i.e. the receiver of
// a `cel.<name>(...)` macro call.
bool IsCelNamespace(const cel::Expr& target) {
  return target.has_ident_expr() && target.ident_expr().name() == "cel";
}

// Returns true if `expr` is an int literal whose value is >= 0.
bool IsNonNegativeIntConstant(const cel::Expr& expr) {
  return expr.has_const_expr() && expr.const_expr().has_int_value() &&
         expr.const_expr().int_value() >= 0;
}

// Expands `cel.block(bindings, body)` into a call to `cel.@block`, requiring
// the bindings argument to be a list literal.
absl::optional<cel::Expr> CelBlockMacroExpander(cel::MacroExprFactory& factory,
                                                cel::Expr& target,
                                                absl::Span<cel::Expr> args) {
  if (!IsCelNamespace(target)) {
    return absl::nullopt;
  }
  cel::Expr& bindings_arg = args[0];
  if (!bindings_arg.has_list_expr()) {
    return factory.ReportErrorAt(
        bindings_arg, "cel.block requires the first arg to be a list literal");
  }
  return factory.NewCall("cel.@block", args);
}

// Expands `cel.index(N)` into the identifier `@indexN`.
absl::optional<cel::Expr> CelIndexMacroExpander(cel::MacroExprFactory& factory,
                                                cel::Expr& target,
                                                absl::Span<cel::Expr> args) {
  if (!IsCelNamespace(target)) {
    return absl::nullopt;
  }
  cel::Expr& index_arg = args[0];
  if (!IsNonNegativeIntConstant(index_arg)) {
    return factory.ReportErrorAt(
        index_arg, "cel.index requires a single non-negative int constant arg");
  }
  return factory.NewIdent(
      absl::StrCat("@index", index_arg.const_expr().int_value()));
}

// Shared expansion for `cel.iterVar(depth, unique)` and
// `cel.accuVar(depth, unique)`: validates that both arguments are
// non-negative int literals and produces the identifier
// `<ident_prefix><depth>:<unique>`. `macro_name` is only used to build the
// error message.
absl::optional<cel::Expr> ExpandComprehensionVarMacro(
    cel::MacroExprFactory& factory, cel::Expr& target,
    absl::Span<cel::Expr> args, absl::string_view macro_name,
    absl::string_view ident_prefix) {
  if (!IsCelNamespace(target)) {
    return absl::nullopt;
  }
  const std::string error_message =
      absl::StrCat(macro_name, " requires two non-negative int constant args");
  cel::Expr& depth_arg = args[0];
  if (!IsNonNegativeIntConstant(depth_arg)) {
    return factory.ReportErrorAt(depth_arg, error_message);
  }
  cel::Expr& unique_arg = args[1];
  if (!IsNonNegativeIntConstant(unique_arg)) {
    return factory.ReportErrorAt(unique_arg, error_message);
  }
  return factory.NewIdent(absl::StrCat(ident_prefix,
                                       depth_arg.const_expr().int_value(), ":",
                                       unique_arg.const_expr().int_value()));
}

// Expands `cel.iterVar(depth, unique)` into the identifier `@it:depth:unique`.
absl::optional<cel::Expr> CelIterVarMacroExpander(
    cel::MacroExprFactory& factory, cel::Expr& target,
    absl::Span<cel::Expr> args) {
  return ExpandComprehensionVarMacro(factory, target, args, "cel.iterVar",
                                     "@it:");
}

// Expands `cel.accuVar(depth, unique)` into the identifier `@ac:depth:unique`.
absl::optional<cel::Expr> CelAccuVarMacroExpander(
    cel::MacroExprFactory& factory, cel::Expr& target,
    absl::Span<cel::Expr> args) {
  return ExpandComprehensionVarMacro(factory, target, args, "cel.accuVar",
                                     "@ac:");
}

// Registers the `cel.block`, `cel.index`, `cel.iterVar` and `cel.accuVar`
// receiver macros into `registry`.
absl::Status RegisterCelBlockMacros(cel::MacroRegistry& registry) {
  CEL_ASSIGN_OR_RETURN(auto block_macro,
                       cel::Macro::Receiver("block", 2, CelBlockMacroExpander));
  CEL_RETURN_IF_ERROR(registry.RegisterMacro(block_macro));
  CEL_ASSIGN_OR_RETURN(auto index_macro,
                       cel::Macro::Receiver("index", 1, CelIndexMacroExpander));
  CEL_RETURN_IF_ERROR(registry.RegisterMacro(index_macro));
  CEL_ASSIGN_OR_RETURN(
      auto iter_var_macro,
      cel::Macro::Receiver("iterVar", 2, CelIterVarMacroExpander));
  CEL_RETURN_IF_ERROR(registry.RegisterMacro(iter_var_macro));
  CEL_ASSIGN_OR_RETURN(
      auto accu_var_macro,
      cel::Macro::Receiver("accuVar", 2, CelAccuVarMacroExpander));
  CEL_RETURN_IF_ERROR(registry.RegisterMacro(accu_var_macro));
  return absl::OkStatus();
}

// Maps an absl status code onto the google.rpc code with the same value.
google::rpc::Code ToGrpcCode(absl::StatusCode code) {
  return static_cast<google::rpc::Code>(code);
}

// Appends `status` as an issue (code + message) to a Parse/Check response.
template <typename Response>
void AddStatusIssue(const absl::Status& status, Response& response) {
  auto* issue = response.add_issues();
  issue->set_code(ToGrpcCode(status.code()));
  issue->set_message(status.message());
}

using ConformanceServiceInterface =
    ::cel_conformance::ConformanceServiceInterface;

// Parses `request.cel_source()` with the standard macros plus the
// comprehensions-v2, bindings, math, proto and cel.block macro sets, and
// stores the result in `response.parsed_expr`.
absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request,
                         conformance::v1alpha1::ParseResponse& response,
                         bool enable_optional_syntax) {
  if (request.cel_source().empty()) {
    return absl::InvalidArgumentError("no source code");
  }
  cel::ParserOptions options;
  options.enable_optional_syntax = enable_optional_syntax;
  options.enable_quoted_identifiers = true;
  cel::MacroRegistry macros;
  CEL_RETURN_IF_ERROR(cel::RegisterStandardMacros(macros, options));
  CEL_RETURN_IF_ERROR(
      cel::extensions::RegisterComprehensionsV2Macros(macros, options));
  CEL_RETURN_IF_ERROR(cel::extensions::RegisterBindingsMacros(macros, options));
  CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathMacros(macros, options));
  CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtoMacros(macros, options));
  CEL_RETURN_IF_ERROR(RegisterCelBlockMacros(macros));
  CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(request.cel_source(),
                                                   request.source_location()));
  CEL_ASSIGN_OR_RETURN(auto parsed_expr,
                       parser::Parse(*source, macros, options));
  ABSL_CHECK(  // Crash OK
      ConvertWireCompatProto(parsed_expr, response.mutable_parsed_expr()));
  return absl::OkStatus();
}

// Type-checks `request.parsed_expr()` against the declarations in
// `request.type_env()` (plus the standard/extension libraries unless
// `no_std_env` is set). Checker issues are appended to `response.issues`;
// `response.checked_expr` is only populated for a valid result. A non-OK
// return signals a failure to set up or run the checker.
absl::Status CheckImpl(google::protobuf::Arena* arena,
                       const conformance::v1alpha1::CheckRequest& request,
                       conformance::v1alpha1::CheckResponse& response) {
  cel::expr::ParsedExpr parsed_expr;
  ABSL_CHECK(ConvertWireCompatProto(request.parsed_expr(),  // Crash OK
                                    &parsed_expr));
  CEL_ASSIGN_OR_RETURN(std::unique_ptr<cel::Ast> ast,
                       cel::CreateAstFromParsedExpr(parsed_expr));

  // When the location carries the original source text (prefixed with
  // "Source: "), rebuild a Source so issues can be rendered with context.
  absl::string_view location = parsed_expr.source_info().location();
  std::unique_ptr<cel::Source> source;
  if (absl::ConsumePrefix(&location, "Source: ")) {
    CEL_ASSIGN_OR_RETURN(source, cel::NewSource(location));
  }

  CEL_ASSIGN_OR_RETURN(
      std::unique_ptr<cel::TypeCheckerBuilder> builder,
      cel::CreateTypeCheckerBuilder(google::protobuf::DescriptorPool::generated_pool()));

  if (!request.no_std_env()) {
    CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCheckerLibrary()));
    CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::OptionalCheckerLibrary()));
    CEL_RETURN_IF_ERROR(
        builder->AddLibrary(cel::extensions::StringsCheckerLibrary()));
    CEL_RETURN_IF_ERROR(
        builder->AddLibrary(cel::extensions::MathCheckerLibrary()));
    CEL_RETURN_IF_ERROR(
        builder->AddLibrary(cel::extensions::EncodersCheckerLibrary()));
    CEL_RETURN_IF_ERROR(
        builder->AddLibrary(cel::extensions::ComprehensionsV2CheckerLibrary()));
  }

  for (const auto& decl : request.type_env()) {
    const auto& name = decl.name();
    if (decl.has_function()) {
      CEL_ASSIGN_OR_RETURN(
          auto fn_decl,
          cel::FunctionDeclFromV1Alpha1Proto(
              name, decl.function(), google::protobuf::DescriptorPool::generated_pool(),
              arena));
      CEL_RETURN_IF_ERROR(builder->AddFunction(std::move(fn_decl)));
    } else if (decl.has_ident()) {
      CEL_ASSIGN_OR_RETURN(
          auto var_decl,
          cel::VariableDeclFromV1Alpha1Proto(
              name, decl.ident(), google::protobuf::DescriptorPool::generated_pool(),
              arena));
      CEL_RETURN_IF_ERROR(builder->AddVariable(std::move(var_decl)));
    }
  }
  builder->set_container(request.container());

  CEL_ASSIGN_OR_RETURN(auto checker, std::move(*builder).Build());
  CEL_ASSIGN_OR_RETURN(auto validation_result, checker->Check(std::move(ast)));

  for (const auto& checker_issue : validation_result.GetIssues()) {
    auto* issue = response.add_issues();
    issue->set_code(ToGrpcCode(absl::StatusCode::kInvalidArgument));
    if (source) {
      issue->set_message(checker_issue.ToDisplayString(*source));
    } else {
      issue->set_message(checker_issue.message());
    }
  }

  const cel::Ast* checked_ast = validation_result.GetAst();
  if (!validation_result.IsValid() || checked_ast == nullptr) {
    return absl::OkStatus();
  }
  cel::expr::CheckedExpr pb_checked_ast;
  CEL_RETURN_IF_ERROR(cel::AstToCheckedExpr(*checked_ast, &pb_checked_ast));
  ABSL_CHECK(ConvertWireCompatProto(pb_checked_ast,  // Crash OK
                                    response.mutable_checked_expr()));
  return absl::OkStatus();
}

// Forces the conformance test messages and extensions to be linked into the
// binary so that they can be resolved by reflection (e.g. through the
// generated descriptor pool) even though nothing references them directly.
void LinkConformanceTestProtos() {
  google::protobuf::LinkMessageReflection<
      cel::expr::conformance::proto3::TestAllTypes>();
  google::protobuf::LinkMessageReflection<
      cel::expr::conformance::proto2::TestAllTypes>();
  google::protobuf::LinkMessageReflection<
      cel::expr::conformance::proto3::NestedTestAllTypes>();
  google::protobuf::LinkMessageReflection<
      cel::expr::conformance::proto2::NestedTestAllTypes>();
  google::protobuf::LinkExtensionReflection(cel::expr::conformance::proto2::int32_ext);
  google::protobuf::LinkExtensionReflection(cel::expr::conformance::proto2::nested_ext);
  google::protobuf::LinkExtensionReflection(
      cel::expr::conformance::proto2::test_all_types_ext);
  google::protobuf::LinkExtensionReflection(
      cel::expr::conformance::proto2::nested_enum_ext);
  google::protobuf::LinkExtensionReflection(
      cel::expr::conformance::proto2::repeated_test_all_types);
  google::protobuf::LinkExtensionReflection(
      cel::expr::conformance::proto2::Proto2ExtensionScopedMessage::int64_ext);
  google::protobuf::LinkExtensionReflection(
      cel::expr::conformance::proto2::Proto2ExtensionScopedMessage::
          message_scoped_nested_ext);
  google::protobuf::LinkExtensionReflection(
      cel::expr::conformance::proto2::Proto2ExtensionScopedMessage::
          nested_enum_ext);
  google::protobuf::LinkExtensionReflection(
      cel::expr::conformance::proto2::Proto2ExtensionScopedMessage::
          message_scoped_repeated_test_all_types);
}

// Conformance service backed by the legacy CelExpressionBuilder runtime.
class LegacyConformanceServiceImpl : public ConformanceServiceInterface {
 public:
  static absl::StatusOr<std::unique_ptr<LegacyConformanceServiceImpl>> Create(
      bool optimize, bool recursive, bool select_optimization) {
    // Intentionally leaked: constant-folded values must outlive every
    // expression built by this process.
    static auto* constant_arena = new Arena();
    LinkConformanceTestProtos();

    InterpreterOptions options;
    options.enable_qualified_type_identifiers = true;
    options.enable_timestamp_duration_overflow_errors = true;
    options.enable_heterogeneous_equality = true;
    options.enable_empty_wrapper_null_unboxing = true;
    options.enable_qualified_identifier_rewrites = true;
    options.fail_on_warnings = false;

    if (optimize) {
      std::cerr << "Enabling optimizations" << std::endl;
      options.constant_folding = true;
      options.constant_arena = constant_arena;
    }

    if (select_optimization) {
      std::cerr << "Enabling select optimizations" << std::endl;
      options.enable_select_optimization = true;
    }

    if (recursive) {
      options.max_recursion_depth = 48;
    }

    std::unique_ptr<CelExpressionBuilder> builder =
        CreateCelExpressionBuilder(options);
    auto type_registry = builder->GetTypeRegistry();
    type_registry->Register(
        cel::expr::conformance::proto2::GlobalEnum_descriptor());
    type_registry->Register(
        cel::expr::conformance::proto3::GlobalEnum_descriptor());
    type_registry->Register(
        cel::expr::conformance::proto2::TestAllTypes::NestedEnum_descriptor());
    type_registry->Register(
        cel::expr::conformance::proto3::TestAllTypes::NestedEnum_descriptor());
    CEL_RETURN_IF_ERROR(
        RegisterBuiltinFunctions(builder->GetRegistry(), options));
    CEL_RETURN_IF_ERROR(cel::extensions::RegisterComprehensionsV2Functions(
        builder->GetRegistry(), options));
    CEL_RETURN_IF_ERROR(cel::extensions::RegisterEncodersFunctions(
        builder->GetRegistry(), options));
    CEL_RETURN_IF_ERROR(cel::extensions::RegisterStringsFunctions(
        builder->GetRegistry(), options));
    CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathExtensionFunctions(
        builder->GetRegistry(), options));

    return absl::WrapUnique(
        new LegacyConformanceServiceImpl(std::move(builder)));
  }

  // Parses without optional syntax; failures are reported as issues.
  void Parse(const conformance::v1alpha1::ParseRequest& request,
             conformance::v1alpha1::ParseResponse& response) override {
    absl::Status status =
        LegacyParse(request, response, /*enable_optional_syntax=*/false);
    if (!status.ok()) {
      AddStatusIssue(status, response);
    }
  }

  // Type-checks the request; setup failures are reported as issues.
  void Check(const conformance::v1alpha1::CheckRequest& request,
             conformance::v1alpha1::CheckResponse& response) override {
    google::protobuf::Arena arena;
    absl::Status status = CheckImpl(&arena, request, response);
    if (!status.ok()) {
      AddStatusIssue(status, response);
    }
  }

  // Plans and evaluates the parsed or checked expression in `request`.
  // Evaluation errors are written to `response.result().error()`; a non-OK
  // return means the expression could not be converted, planned, or bound.
  absl::Status Eval(const conformance::v1alpha1::EvalRequest& request,
                    conformance::v1alpha1::EvalResponse& response) override {
    Arena arena;
    builder_->set_container(request.container());

    absl::StatusOr<std::unique_ptr<CelExpression>> cel_expression_status =
        absl::InternalError(
            "no expression provided in ConformanceService::Eval");
    if (request.has_parsed_expr()) {
      cel::expr::ParsedExpr parsed_expr;
      if (!ConvertWireCompatProto(request.parsed_expr(), &parsed_expr)) {
        return absl::InternalError(
            "failed to convert versioned ParsedExpr to unversioned");
      }
      cel_expression_status = builder_->CreateExpression(
          &parsed_expr.expr(), &parsed_expr.source_info());
    } else if (request.has_checked_expr()) {
      cel::expr::CheckedExpr checked_expr;
      if (!ConvertWireCompatProto(request.checked_expr(), &checked_expr)) {
        return absl::InternalError(
            "failed to convert versioned CheckedExpr to unversioned");
      }
      cel_expression_status = builder_->CreateExpression(&checked_expr);
    }

    if (!cel_expression_status.ok()) {
      return absl::InternalError(cel_expression_status.status().ToString(
          absl::StatusToStringMode::kWithEverything));
    }

    auto cel_expression = std::move(cel_expression_status.value());
    Activation activation;

    for (const auto& pair : request.bindings()) {
      auto* import_value = Arena::Create<cel::expr::Value>(&arena);
      ABSL_CHECK(ConvertWireCompatProto(pair.second.value(),  // Crash OK
                                        import_value));
      auto import_status = ValueToCelValue(*import_value, &arena);
      if (!import_status.ok()) {
        return absl::InternalError(import_status.status().ToString(
            absl::StatusToStringMode::kWithEverything));
      }
      activation.InsertValue(pair.first, import_status.value());
    }

    auto eval_status = cel_expression->Evaluate(activation, &arena);
    if (!eval_status.ok()) {
      *response.mutable_result()
           ->mutable_error()
           ->add_errors()
           ->mutable_message() = eval_status.status().ToString(
          absl::StatusToStringMode::kWithEverything);
      return absl::OkStatus();
    }

    CelValue result = eval_status.value();
    if (result.IsError()) {
      *response.mutable_result()
           ->mutable_error()
           ->add_errors()
           ->mutable_message() = std::string(result.ErrorOrDie()->ToString(
          absl::StatusToStringMode::kWithEverything));
    } else {
      cel::expr::Value export_value;
      auto export_status = CelValueToValue(result, &export_value);
      if (!export_status.ok()) {
        return absl::InternalError(
            export_status.ToString(absl::StatusToStringMode::kWithEverything));
      }
      auto* result_value = response.mutable_result()->mutable_value();
      ABSL_CHECK(  // Crash OK
          ConvertWireCompatProto(export_value, result_value));
    }
    return absl::OkStatus();
  }

 private:
  explicit LegacyConformanceServiceImpl(
      std::unique_ptr<CelExpressionBuilder> builder)
      : builder_(std::move(builder)) {}

  std::unique_ptr<CelExpressionBuilder> builder_;
};

// Conformance service backed by the modern cel::Runtime. A fresh runtime is
// built per Eval request so the request's container can be applied.
class ModernConformanceServiceImpl : public ConformanceServiceInterface {
 public:
  static absl::StatusOr<std::unique_ptr<ModernConformanceServiceImpl>> Create(
      bool optimize, bool recursive, bool select_optimization) {
    LinkConformanceTestProtos();

    RuntimeOptions options;
    options.enable_qualified_type_identifiers = true;
    options.enable_timestamp_duration_overflow_errors = true;
    options.enable_heterogeneous_equality = true;
    options.enable_empty_wrapper_null_unboxing = true;
    // Planning warnings are expected in conformance tests, but the test expects
    // failure to happen at evaluation time so we ignore them.
    options.fail_on_warnings = false;
    if (recursive) {
      options.max_recursion_depth = 48;
    }

    return absl::WrapUnique(new ModernConformanceServiceImpl(
        options, optimize, select_optimization));
  }

  // Builds a runtime for `container` with the standard library, the optional
  // constant-folding / select optimizations, the conformance enums, and the
  // extension function libraries.
  absl::StatusOr<std::unique_ptr<const Runtime>> Setup(
      absl::string_view container) {
    RuntimeOptions options(options_);
    options.container = std::string(container);
    CEL_ASSIGN_OR_RETURN(
        auto builder, CreateStandardRuntimeBuilder(
                          google::protobuf::DescriptorPool::generated_pool(), options));

    if (enable_optimizations_) {
      CEL_RETURN_IF_ERROR(cel::extensions::EnableConstantFolding(
          builder, google::protobuf::MessageFactory::generated_factory()));
    }
    CEL_RETURN_IF_ERROR(cel::EnableReferenceResolver(
        builder, cel::ReferenceResolverEnabled::kAlways));
    if (enable_select_optimization_) {
      CEL_RETURN_IF_ERROR(cel::extensions::EnableSelectOptimization(builder));
    }

    auto& type_registry = builder.type_registry();
    // Use linked pbs in the generated descriptor pool.
    CEL_RETURN_IF_ERROR(RegisterProtobufEnum(
        type_registry,
        cel::expr::conformance::proto2::GlobalEnum_descriptor()));
    CEL_RETURN_IF_ERROR(RegisterProtobufEnum(
        type_registry,
        cel::expr::conformance::proto3::GlobalEnum_descriptor()));
    CEL_RETURN_IF_ERROR(RegisterProtobufEnum(
        type_registry,
        cel::expr::conformance::proto2::TestAllTypes::NestedEnum_descriptor()));
    CEL_RETURN_IF_ERROR(RegisterProtobufEnum(
        type_registry,
        cel::expr::conformance::proto3::TestAllTypes::NestedEnum_descriptor()));

    CEL_RETURN_IF_ERROR(cel::extensions::RegisterComprehensionsV2Functions(
        builder.function_registry(), options));
    CEL_RETURN_IF_ERROR(cel::extensions::EnableOptionalTypes(builder));
    CEL_RETURN_IF_ERROR(cel::extensions::RegisterEncodersFunctions(
        builder.function_registry(), options));
    CEL_RETURN_IF_ERROR(cel::extensions::RegisterStringsFunctions(
        builder.function_registry(), options));
    CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathExtensionFunctions(
        builder.function_registry(), options));

    return std::move(builder).Build();
  }

  // Parses with optional syntax enabled; failures are reported as issues.
  void Parse(const conformance::v1alpha1::ParseRequest& request,
             conformance::v1alpha1::ParseResponse& response) override {
    absl::Status status =
        LegacyParse(request, response, /*enable_optional_syntax=*/true);
    if (!status.ok()) {
      AddStatusIssue(status, response);
    }
  }

  // Type-checks the request; setup failures are reported as issues.
  void Check(const conformance::v1alpha1::CheckRequest& request,
             conformance::v1alpha1::CheckResponse& response) override {
    google::protobuf::Arena arena;
    absl::Status status = CheckImpl(&arena, request, response);
    if (!status.ok()) {
      AddStatusIssue(status, response);
    }
  }

  // Plans and evaluates the parsed or checked expression in `request`.
  // Evaluation errors are written to `response.result().error()`; a non-OK
  // return means the runtime/program could not be built or a binding could
  // not be converted.
  absl::Status Eval(const conformance::v1alpha1::EvalRequest& request,
                    conformance::v1alpha1::EvalResponse& response) override {
    google::protobuf::Arena arena;
    auto runtime_status = Setup(request.container());
    if (!runtime_status.ok()) {
      return absl::InternalError(runtime_status.status().ToString(
          absl::StatusToStringMode::kWithEverything));
    }
    std::unique_ptr<const Runtime> runtime =
        std::move(runtime_status).value();

    auto program_status = Plan(*runtime, request);
    if (!program_status.ok()) {
      return absl::InternalError(program_status.status().ToString(
          absl::StatusToStringMode::kWithEverything));
    }
    std::unique_ptr<cel::TraceableProgram> program =
        std::move(program_status).value();

    cel::Activation activation;
    for (const auto& pair : request.bindings()) {
      cel::expr::Value import_value;
      ABSL_CHECK(ConvertWireCompatProto(pair.second.value(),  // Crash OK
                                        &import_value));
      auto import_status =
          FromExprValue(import_value, runtime->GetDescriptorPool(),
                        runtime->GetMessageFactory(), &arena);
      if (!import_status.ok()) {
        return absl::InternalError(import_status.status().ToString(
            absl::StatusToStringMode::kWithEverything));
      }
      activation.InsertOrAssignValue(pair.first,
                                     std::move(import_status).value());
    }

    auto eval_status = program->Evaluate(&arena, activation);
    if (!eval_status.ok()) {
      *response.mutable_result()
           ->mutable_error()
           ->add_errors()
           ->mutable_message() = eval_status.status().ToString(
          absl::StatusToStringMode::kWithEverything);
      return absl::OkStatus();
    }

    cel::Value result = std::move(eval_status).value();
    if (result->Is<cel::ErrorValue>()) {
      const absl::Status& error = result.GetError().NativeValue();
      *response.mutable_result()
           ->mutable_error()
           ->add_errors()
           ->mutable_message() = std::string(
          error.ToString(absl::StatusToStringMode::kWithEverything));
    } else {
      auto export_status =
          ToExprValue(result, runtime->GetDescriptorPool(),
                      runtime->GetMessageFactory(), &arena);
      if (!export_status.ok()) {
        return absl::InternalError(export_status.status().ToString(
            absl::StatusToStringMode::kWithEverything));
      }
      auto* result_value = response.mutable_result()->mutable_value();
      ABSL_CHECK(  // Crash OK
          ConvertWireCompatProto(*export_status, result_value));
    }
    return absl::OkStatus();
  }

 private:
  ModernConformanceServiceImpl(const RuntimeOptions& options,
                               bool enable_optimizations,
                               bool enable_select_optimization)
      : options_(options),
        enable_optimizations_(enable_optimizations),
        enable_select_optimization_(enable_select_optimization) {}

  // Converts the request's parsed or checked expression into an AST and plans
  // it with `runtime`.
  static absl::StatusOr<std::unique_ptr<cel::TraceableProgram>> Plan(
      const cel::Runtime& runtime,
      const conformance::v1alpha1::EvalRequest& request) {
    std::unique_ptr<cel::Ast> ast;
    if (request.has_parsed_expr()) {
      cel::expr::ParsedExpr unversioned;
      ABSL_CHECK(ConvertWireCompatProto(request.parsed_expr(),  // Crash OK
                                        &unversioned));
      CEL_ASSIGN_OR_RETURN(
          ast, cel::CreateAstFromParsedExpr(std::move(unversioned)));
    } else if (request.has_checked_expr()) {
      cel::expr::CheckedExpr unversioned;
      ABSL_CHECK(ConvertWireCompatProto(request.checked_expr(),  // Crash OK
                                        &unversioned));
      CEL_ASSIGN_OR_RETURN(
          ast, cel::CreateAstFromCheckedExpr(std::move(unversioned)));
    }
    if (ast == nullptr) {
      return absl::InternalError("no expression provided");
    }

    return runtime.CreateTraceableProgram(std::move(ast));
  }

  RuntimeOptions options_;
  bool enable_optimizations_;
  bool enable_select_optimization_;
};

}  // namespace

}  // namespace google::api::expr::runtime

namespace cel_conformance {

absl::StatusOr<std::unique_ptr<ConformanceServiceInterface>>
NewConformanceService(const ConformanceServiceOptions& options) {
  if (options.modern) {
    return google::api::expr::runtime::ModernConformanceServiceImpl::Create(
        options.optimize, options.recursive, options.select_optimization);
  } else {
    return google::api::expr::runtime::LegacyConformanceServiceImpl::Create(
        options.optimize, options.recursive, options.select_optimization);
  }
}

}  // namespace cel_conformance

================================================
FILE: conformance/service.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_CONFORMANCE_SERVICE_H_ #define THIRD_PARTY_CEL_CPP_CONFORMANCE_SERVICE_H_ #include #include "google/api/expr/conformance/v1alpha1/conformance_service.pb.h" #include "absl/status/status.h" #include "absl/status/statusor.h" namespace cel_conformance { class ConformanceServiceInterface { public: virtual ~ConformanceServiceInterface() = default; virtual void Parse( const google::api::expr::conformance::v1alpha1::ParseRequest& request, google::api::expr::conformance::v1alpha1::ParseResponse& response) = 0; virtual void Check( const google::api::expr::conformance::v1alpha1::CheckRequest& request, google::api::expr::conformance::v1alpha1::CheckResponse& response) = 0; virtual absl::Status Eval( const google::api::expr::conformance::v1alpha1::EvalRequest& request, google::api::expr::conformance::v1alpha1::EvalResponse& response) = 0; }; struct ConformanceServiceOptions { bool optimize; bool modern; bool arena; bool recursive; bool select_optimization; }; absl::StatusOr> NewConformanceService(const ConformanceServiceOptions&); } // namespace cel_conformance #endif // THIRD_PARTY_CEL_CPP_CONFORMANCE_SERVICE_H_ ================================================ FILE: conformance/utils.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_CONFORMANCE_UTILS_H_ #define THIRD_PARTY_CEL_CPP_CONFORMANCE_UTILS_H_ #include #include #include "cel/expr/checked.pb.h" #include "cel/expr/eval.pb.h" #include "google/api/expr/v1alpha1/checked.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" #include "cel/expr/value.pb.h" #include "absl/log/absl_check.h" #include "internal/testing.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" #include "google/protobuf/util/field_comparator.h" #include "google/protobuf/util/message_differencer.h" namespace cel_conformance { inline std::string DescribeMessage(const google::protobuf::Message& message) { std::string string; ABSL_CHECK(google::protobuf::TextFormat::PrintToString(message, &string)); if (string.empty()) { string = "\"\"\n"; } return string; } MATCHER_P(MatchesConformanceValue, expected, "") { static auto* kFieldComparator = []() { auto* field_comparator = new google::protobuf::util::DefaultFieldComparator(); field_comparator->set_treat_nan_as_equal(true); return field_comparator; }(); static auto* kDifferencer = []() { auto* differencer = new google::protobuf::util::MessageDifferencer(); differencer->set_message_field_comparison( google::protobuf::util::MessageDifferencer::EQUIVALENT); differencer->set_field_comparator(kFieldComparator); const auto* descriptor = cel::expr::MapValue::descriptor(); const auto* entries_field = descriptor->FindFieldByName("entries"); const auto* key_field = entries_field->message_type()->FindFieldByName("key"); differencer->TreatAsMap(entries_field, key_field); 
return differencer; }(); const cel::expr::ExprValue& got = arg; const cel::expr::Value& want = expected; cel::expr::ExprValue test_value; (*test_value.mutable_value()) = want; if (kDifferencer->Compare(got, test_value)) { return true; } (*result_listener) << "got: " << DescribeMessage(got); (*result_listener) << "\n"; (*result_listener) << "wanted: " << DescribeMessage(test_value); return false; } MATCHER_P(ResultTypeMatches, expected, "") { static auto* kDifferencer = []() { auto* differencer = new google::protobuf::util::MessageDifferencer(); differencer->set_message_field_comparison( google::protobuf::util::MessageDifferencer::EQUIVALENT); return differencer; }(); const cel::expr::Type& want = expected; const google::api::expr::v1alpha1::CheckedExpr& checked_expr = arg; int64_t root_id = checked_expr.expr().id(); auto it = checked_expr.type_map().find(root_id); if (it == checked_expr.type_map().end()) { (*result_listener) << "type map does not contain root id: " << root_id; return false; } auto got_versioned = it->second; std::string serialized; cel::expr::Type got; if (!got_versioned.SerializeToString(&serialized) || !got.ParseFromString(serialized)) { (*result_listener) << "type cannot be converted from versioned type: " << DescribeMessage(got_versioned); return false; } if (kDifferencer->Compare(got, want)) { return true; } (*result_listener) << "got: " << DescribeMessage(got); (*result_listener) << "\n"; (*result_listener) << "wanted: " << DescribeMessage(want); return false; } } // namespace cel_conformance #endif // THIRD_PARTY_CEL_CPP_CONFORMANCE_UTILS_H_ ================================================ FILE: env/BUILD ================================================ # Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

package(default_visibility = ["//visibility:public"])

# The Config data model (config.h) together with the TypeInfo helpers
# (type_info.h).
cc_library(
    name = "config",
    srcs = [
        "config.cc",
        "type_info.cc",
    ],
    hdrs = [
        "config.h",
        "type_info.h",
    ],
    deps = [
        "//common:constant",
        "//common:type",
        "//common:type_kind",
        "//internal:status_macros",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
        "@com_google_protobuf//:protobuf",
    ],
)

# Compile-time environment (Env): turns a Config into a CompilerBuilder or
# Compiler.
cc_library(
    name = "env",
    srcs = ["env.cc"],
    hdrs = ["env.h"],
    deps = [
        ":config",
        "//checker:type_checker_builder",
        "//common:constant",
        "//common:decl",
        "//common:type",
        "//compiler",
        "//compiler:compiler_factory",
        "//compiler:standard_library",
        "//env/internal:ext_registry",
        "//internal:status_macros",
        "//parser:macro",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_protobuf//:protobuf",
    ],
)

# Runtime-side environment (EnvRuntime): creates runtime builders from a
# Config.
cc_library(
    name = "env_runtime",
    srcs = ["env_runtime.cc"],
    hdrs = ["env_runtime.h"],
    deps = [
        ":config",
        "//env/internal:runtime_ext_registry",
        "//internal:status_macros",
        "//runtime",
        "//runtime:runtime_builder",
        "//runtime:runtime_builder_factory",
        "//runtime:runtime_options",
        "//runtime:standard_functions",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

# Standard CEL extensions for :env (compile-time side); presumably registers
# them as compiler libraries -- confirm in env_std_extensions.h.
cc_library(
    name = "env_std_extensions",
    srcs = ["env_std_extensions.cc"],
    hdrs = ["env_std_extensions.h"],
    deps = [
        ":env",
        "//checker:optional",
        "//compiler:optional",
        "//extensions:bindings_ext",
        "//extensions:comprehensions_v2",
        "//extensions:encoders",
        "//extensions:lists_functions",
        "//extensions:math_ext_decls",
        "//extensions:proto_ext",
        "//extensions:regex_ext",
        "//extensions:sets_functions",
        "//extensions:strings",
    ],
)

# YAML support for Config, built on yaml-cpp.
# -fexceptions is enabled for this target only; presumably because yaml-cpp
# reports errors via C++ exceptions -- confirm before removing it.
cc_library(
    name = "env_yaml",
    srcs = ["env_yaml.cc"],
    hdrs = ["env_yaml.h"],
    copts = [
        "-fexceptions",
    ],
    features = ["-use_header_modules"],
    deps = [
        ":config",
        "//common:constant",
        "//internal:status_macros",
        "//internal:strings",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/time",
        "@yaml-cpp",
    ],
)

# Standard CEL extensions for :env_runtime (runtime side).
# NOTE(review): this runtime-side target depends on //checker:optional;
# confirm the checker dependency is actually needed here.
cc_library(
    name = "runtime_std_extensions",
    srcs = ["runtime_std_extensions.cc"],
    hdrs = ["runtime_std_extensions.h"],
    deps = [
        ":env_runtime",
        "//checker:optional",
        "//env/internal:runtime_ext_registry",
        "//extensions:encoders",
        "//extensions:lists_functions",
        "//extensions:math_ext",
        "//extensions:math_ext_decls",
        "//extensions:regex_ext",
        "//extensions:sets_functions",
        "//extensions:strings",
        "//runtime:optional_types",
        "//runtime:runtime_builder",
        "//runtime:runtime_options",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
)

# Unit tests for the libraries above.

cc_test(
    name = "config_test",
    srcs = ["config_test.cc"],
    deps = [
        ":config",
        "//common:constant",
        "//internal:testing",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
    ],
)

cc_test(
    name = "type_info_test",
    srcs = ["type_info_test.cc"],
    deps = [
        ":config",
        "//common:type",
        "//common:type_proto",
        "//internal:proto_matchers",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "env_test",
    srcs = ["env_test.cc"],
    deps = [
        ":config",
        ":env",
        "//checker:type_check_issue",
        "//checker:type_checker_builder",
        "//checker:validation_result",
        "//common:ast",
        "//common:constant",
        "//common:decl",
        "//common:expr",
        "//common:type",
        "//common:value",
        "//compiler",
        "//internal:proto_matchers",
        "//internal:status_macros",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//parser:macro",
        "//parser:macro_expr_factory",
        "//parser:parser_interface",
        "//runtime",
        "//runtime:activation",
        "//runtime:reference_resolver",
        "//runtime:runtime_builder",
        "//runtime:runtime_options",
        "//runtime:standard_runtime_builder_factory",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:span",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "env_runtime_test",
    srcs = ["env_runtime_test.cc"],
    deps = [
        ":config",
        ":env",
        ":env_runtime",
        ":env_std_extensions",
        ":env_yaml",
        ":runtime_std_extensions",
        "//checker:validation_result",
        "//common:ast",
        "//common:source",
        "//common:value",
        "//compiler",
        "//extensions:math_ext",
        "//internal:status_macros",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//runtime",
        "//runtime:activation",
        "//runtime:runtime_builder",
        "//runtime:runtime_options",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "env_std_extensions_test",
    srcs = ["env_std_extensions_test.cc"],
    deps = [
        ":config",
        ":env",
        ":env_std_extensions",
        "//compiler",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "@com_google_absl//absl/strings:string_view",
    ],
)

cc_test(
    name = "env_yaml_test",
    srcs = ["env_yaml_test.cc"],
    deps = [
        ":config",
        ":env_yaml",
        "//common:constant",
        "//internal:status_macros",
        "//internal:testing",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/time",
    ],
)

cc_test(
    name = "runtime_std_extensions_test",
    srcs = ["runtime_std_extensions_test.cc"],
    deps = [
        ":config",
        ":env",
        ":env_runtime",
        ":env_std_extensions",
        ":runtime_std_extensions",
        "//checker:optional",
        "//checker:validation_result",
        "//common:ast",
        "//common:value",
        "//compiler",
        "//extensions:lists_functions",
        "//extensions:math_ext_decls",
        "//extensions:strings",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//runtime",
        "//runtime:activation",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_protobuf//:protobuf",
    ],
)

================================================
FILE: env/config.cc
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "env/config.h"

#include
#include
#include
#include
#include
#include

#include "absl/container/flat_hash_set.h"
#include "absl/functional/overload.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "common/constant.h"
#include "internal/status_macros.h"

namespace cel {
namespace {

// Maps the active alternative of a ConstantKind to the CEL type name that is
// compared against Config::TypeInfo::name ("int", "string", ...). The
// monostate alternative maps to "dyn".
const char* ConstantKindToTypeName(const ConstantKind& kind) {
  return std::visit(absl::Overload{
                        [](const std::monostate&) { return "dyn"; },
                        [](const std::nullptr_t&) { return "null"; },
                        [](bool) { return "bool"; },
                        [](int64_t) { return "int"; },
                        [](uint64_t) { return "uint"; },
                        [](double) { return "double"; },
                        [](const BytesConstant&) { return "bytes"; },
                        [](const StringConstant&) { return "string"; },
                        [](absl::Duration) { return "duration"; },
                        [](absl::Time) { return "timestamp"; },
                    },
                    kind);
}

// Renders an extension version for error messages. The kLatest sentinel is
// shown as 'latest' rather than as its raw integer value.
std::string FormatExtensionVersion(int version) {
  if (version == Config::ExtensionConfig::kLatest) {
    return "'latest'";
  }
  return absl::StrCat(version);
}

}  // namespace

// Adds an extension by name and version. Re-adding the same (name, version)
// is a no-op. Adding a different version of an already-included extension
// fails with AlreadyExists.
absl::Status Config::AddExtensionConfig(std::string name, int version) {
  for (const ExtensionConfig& extension_config : extension_configs_) {
    if (extension_config.name != name) {
      continue;
    }
    if (extension_config.version == version) {
      return absl::OkStatus();
    }
    // Format both versions the same way, so an existing kLatest entry is
    // reported as 'latest' instead of the sentinel's numeric value.
    return absl::AlreadyExistsError(absl::StrCat(
        "Extension '", name, "' version ",
        FormatExtensionVersion(extension_config.version),
        " is already included. Cannot also include version ",
        FormatExtensionVersion(version)));
  }
  extension_configs_.push_back(
      ExtensionConfig{.name = std::move(name), .version = version});
  return absl::OkStatus();
}

// Validates and stores the standard-library config.
//
// Returns InvalidArgument in either of these cases:
// - Macros (or functions) are both included and excluded.
// - Within one include or exclude set, a function is listed both as a whole
//   (empty overload id) and with a specific overload.
//
// The stored config is replaced only when validation succeeds.
absl::Status Config::SetStandardLibraryConfig(
    const Config::StandardLibraryConfig& standard_library_config) {
  if (!standard_library_config.included_macros.empty() &&
      !standard_library_config.excluded_macros.empty()) {
    return absl::InvalidArgumentError(
        "Cannot set both included and excluded macros.");
  }

  if (!standard_library_config.included_functions.empty() &&
      !standard_library_config.excluded_functions.empty()) {
    return absl::InvalidArgumentError(
        "Cannot set both included and excluded functions.");
  }

  // Names of functions included in full (entries with an empty overload id).
  absl::flat_hash_set included_function_names;
  for (const auto& function : standard_library_config.included_functions) {
    if (function.second.empty()) {
      included_function_names.insert(function.first);
    }
  }
  for (const auto& function : standard_library_config.included_functions) {
    if (included_function_names.contains(function.first) &&
        !function.second.empty()) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Cannot include function '", function.first,
          "' and also its specific overload '", function.second, "'"));
    }
  }

  // Names of functions excluded in full (entries with an empty overload id).
  absl::flat_hash_set excluded_function_names;
  for (const auto& function : standard_library_config.excluded_functions) {
    if (function.second.empty()) {
      excluded_function_names.insert(function.first);
    }
  }
  for (const auto& function : standard_library_config.excluded_functions) {
    if (excluded_function_names.contains(function.first) &&
        !function.second.empty()) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Cannot exclude function '", function.first,
          "' and also its specific overload '", function.second, "'"));
    }
  }

  standard_library_config_ = standard_library_config;
  return absl::OkStatus();
}

// Appends a variable config. Fails with AlreadyExists on a duplicate name.
// Fails with InvalidArgument when a provided constant value's kind does not
// match the declared type name.
absl::Status Config::AddVariableConfig(const VariableConfig& variable_config) {
  for (const VariableConfig& existing_variable_config : variable_configs_) {
    if (existing_variable_config.name == variable_config.name) {
      return absl::AlreadyExistsError(absl::StrCat(
          "Variable '", variable_config.name, "' is already included."));
    }
  }

  if (variable_config.value.has_value()) {
    absl::string_view constant_type_name =
        ConstantKindToTypeName(variable_config.value.kind());
    if (constant_type_name != variable_config.type_info.name) {
      return absl::InvalidArgumentError(
          absl::StrCat("Variable '", variable_config.name, "' has type ",
                       variable_config.type_info.name,
                       " but is assigned a constant value of type ",
                       constant_type_name, "."));
    }
  }
  variable_configs_.push_back(variable_config);
  return absl::OkStatus();
}

// Rejects member-function overloads that declare no parameters. A member
// function needs at least the receiver (target) parameter.
absl::Status Config::ValidateFunctionConfig(
    const FunctionConfig& function_config) {
  for (const auto& overload : function_config.overload_configs) {
    if (overload.is_member_function && overload.parameters.empty()) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Function '", function_config.name, "' overload '",
          overload.overload_id,
          "' is marked as a member function but has no parameters. Member "
          "functions must have at least one parameter (target)."));
    }
  }
  return absl::OkStatus();
}

// Validates `function_config` and appends it. Function names are not checked
// for duplicates here.
absl::Status Config::AddFunctionConfig(const FunctionConfig& function_config) {
  CEL_RETURN_IF_ERROR(ValidateFunctionConfig(function_config));
  function_configs_.push_back(function_config);
  return absl::OkStatus();
}

// Debug rendering of a StandardLibraryConfig.
//
// Only non-empty macro and function sets are printed, and each function entry
// is shown as `name:overload_id`. Element order follows hash-set iteration
// order and is unspecified.
// NOTE(review): `disable` and `disable_macros` are not printed.
std::ostream& operator<<(std::ostream& os,
                         const Config::StandardLibraryConfig& config) {
  os << "StandardLibraryConfig(";
  if (!config.included_macros.empty()) {
    os << "\n included_macros=" << absl::StrJoin(config.included_macros, ", ");
  }
  if (!config.excluded_macros.empty()) {
    os << "\n excluded_macros=" << absl::StrJoin(config.excluded_macros, ", ");
  }
  if (!config.included_functions.empty()) {
    os << "\n included_functions="
       << absl::StrJoin(config.included_functions, ", ",
                        [](std::string* out, const std::pair& p) {
                          absl::StrAppend(out, p.first, ":", p.second);
                        });
  }
  if (!config.excluded_functions.empty()) {
    os << "\n excluded_functions="
       << absl::StrJoin(config.excluded_functions, ", ",
                        [](std::string* out, const std::pair& p) {
                          absl::StrAppend(out, p.first, ":", p.second);
                        });
  }
  os << "\n)";
  return os;
}

}  // namespace cel

================================================
FILE: env/config.h
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_
#define THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_

#include
#include
#include
#include
#include

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "common/constant.h"

namespace cel {

// Declarative description of a CEL environment: container, extensions,
// standard-library subset, variables and functions. Env (env/env.h) reads it
// to configure a compiler.
//
// The Add*Config methods and SetStandardLibraryConfig validate their input
// and return a non-OK status rather than storing an invalid configuration.
class Config {
 public:
  // Name of this configuration.
  void SetName(std::string name) { name_ = std::move(name); }
  std::string GetName() const { return name_; }

  // Expression container. Env passes `name` to the type checker's
  // set_container().
  struct ContainerConfig {
    std::string name;
    // TODO(uncreated-issue/87): add support for aliases and abbreviations.

    // True when no container name is set.
    bool IsEmpty() const { return name.empty(); }
  };

  void SetContainerConfig(ContainerConfig container_config) {
    container_config_ = std::move(container_config);
  }

  const ContainerConfig& GetContainerConfig() const {
    return container_config_;
  }

  // An extension library to load, identified by name and version.
  struct ExtensionConfig {
    // Sentinel version, reported as 'latest' in errors. Presumably resolved
    // to the newest registered version by the extension registry -- confirm.
    static constexpr int kLatest = std::numeric_limits::max();

    std::string name;
    int version = kLatest;
  };

  // Adds an extension. Re-adding an identical (name, version) pair is a
  // no-op. A different version of an already-added extension yields
  // AlreadyExists.
  absl::Status AddExtensionConfig(std::string name,
                                  int version = ExtensionConfig::kLatest);

  const std::vector& GetExtensionConfigs() const {
    return extension_configs_;
  }

  // Which parts of the CEL standard library to expose.
  struct StandardLibraryConfig {
    // Exclude the entire standard library.
    bool disable = false;

    // Exclude all standard library macros.
    bool disable_macros = false;

    // Either included or excluded macros can be set, not both. If neither are
    // set, all standard library macros are included.
    absl::flat_hash_set included_macros;
    absl::flat_hash_set excluded_macros;

    // Sets of pairs of function name and overload id to include or exclude.
    // Either included or excluded functions can be set, not both. If neither
    // are set, all standard library functions are included.
    // If an overload is specified, only that overload is included or excluded.
    // If no overload is specified (empty second element of pair), all overloads
    // are included or excluded.
    absl::flat_hash_set> included_functions;
    absl::flat_hash_set> excluded_functions;

    // True when no option above deviates from its default.
    bool IsEmpty() const {
      return !disable && !disable_macros && included_macros.empty() &&
             excluded_macros.empty() && included_functions.empty() &&
             excluded_functions.empty();
    }
  };

  // Replaces the standard-library config.
  //
  // Returns InvalidArgument in either of these cases:
  // - Both included and excluded macros (or functions) are set.
  // - A function is listed both as a whole and with a specific overload in
  //   the same set.
  //
  // On error the stored config is left unchanged.
  absl::Status SetStandardLibraryConfig(
      const StandardLibraryConfig& standard_library_config);

  const StandardLibraryConfig& GetStandardLibraryConfig() const {
    return standard_library_config_;
  }

  // A CEL type referenced by name, with optional type parameters.
  // `is_type_param` marks the name as a type parameter (e.g. `A`) rather than
  // a concrete type.
  struct TypeInfo {
    std::string name;
    std::vector params;
    bool is_type_param = false;
  };

  // A variable declaration. When `value` is set, its constant kind must
  // match `type_info.name` (checked by AddVariableConfig).
  struct VariableConfig {
    std::string name;
    std::string description;
    TypeInfo type_info;
    Constant value;
  };

  // Adds a variable config to the environment. The variable name and type
  // are used by the CEL type checker to validate expressions. The variable
  // value is used as an input value at runtime.
  //
  // Returns an error if a variable with the same name already exists, or if the
  // type of the constant value does not match the specified type.
  absl::Status AddVariableConfig(const VariableConfig& variable_config);

  const std::vector& GetVariableConfigs() const {
    return variable_configs_;
  }

  // One overload of a function. For member functions the receiver (target)
  // is part of `parameters`, so at least one parameter is required.
  struct FunctionOverloadConfig {
    std::string overload_id;
    std::vector examples;
    bool is_member_function = false;
    std::vector parameters;
    TypeInfo return_type;
  };

  // A function declaration with its overloads.
  struct FunctionConfig {
    std::string name;
    std::string description;
    std::vector overload_configs;
  };

  // Validates and appends a function config. Returns InvalidArgument for a
  // member-function overload without parameters. Duplicate function names are
  // not rejected.
  absl::Status AddFunctionConfig(const FunctionConfig& function_config);

  const std::vector& GetFunctionConfigs() const {
    return function_configs_;
  }

 private:
  std::string name_;
  ContainerConfig container_config_;
  std::vector extension_configs_;
  StandardLibraryConfig standard_library_config_;
  std::vector variable_configs_;
  std::vector function_configs_;

  // Shared validation used by AddFunctionConfig.
  absl::Status ValidateFunctionConfig(const FunctionConfig& function_config);
};

// Debug printing of a StandardLibraryConfig. Set contents are printed in hash
// iteration order, which is unspecified.
std::ostream& operator<<(std::ostream& os,
                         const Config::StandardLibraryConfig& config);

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_
================================================ FILE: env/config_test.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "env/config.h" #include #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "common/constant.h" #include "internal/testing.h" namespace cel { namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::testing::AllOf; using ::testing::ElementsAre; using ::testing::Field; using ::testing::HasSubstr; using ::testing::UnorderedElementsAre; TEST(EnvConfigTest, ExtensionConfigs) { Config config; ASSERT_THAT( config.AddExtensionConfig("math", Config::ExtensionConfig::kLatest), IsOk()); ASSERT_THAT(config.AddExtensionConfig("optional", 2), IsOk()); ASSERT_THAT(config.AddExtensionConfig("strings"), IsOk()); EXPECT_THAT(config.GetExtensionConfigs(), UnorderedElementsAre( AllOf(Field(&Config::ExtensionConfig::name, "math"), Field(&Config::ExtensionConfig::version, Config::ExtensionConfig::kLatest)), AllOf(Field(&Config::ExtensionConfig::name, "optional"), Field(&Config::ExtensionConfig::version, 2)), AllOf(Field(&Config::ExtensionConfig::name, "strings"), Field(&Config::ExtensionConfig::version, Config::ExtensionConfig::kLatest)))); } TEST(EnvConfigTest, ExtensionConfigConflict) { Config config; ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); 
ASSERT_THAT(config.AddExtensionConfig("math", 3), StatusIs(absl::StatusCode::kAlreadyExists)); } struct StandardLibraryConfigTestCase { Config::StandardLibraryConfig standard_library_config; std::string expected_error; // Empty if no error is expected. }; class StandardLibraryConfigTest : public testing::TestWithParam {}; TEST_P(StandardLibraryConfigTest, StandardLibraryConfig) { const StandardLibraryConfigTestCase& param = GetParam(); Config config; absl::Status status = config.SetStandardLibraryConfig(param.standard_library_config); if (param.expected_error.empty()) { EXPECT_THAT(status, IsOk()); } else { EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(param.expected_error))); } } INSTANTIATE_TEST_SUITE_P( StandardLibraryConfigTest, StandardLibraryConfigTest, ::testing::Values( StandardLibraryConfigTestCase{ .standard_library_config = {}, }, StandardLibraryConfigTestCase{ .standard_library_config = { .included_macros = {"all", "exists"}, .excluded_macros = {"map", "filter"}, }, .expected_error = "Cannot set both included and excluded macros.", }, StandardLibraryConfigTestCase{ .standard_library_config = { .included_functions = {{"_+_", "add_int64"}, {"_+_", "add_list"}}, .excluded_functions = {{"_-_", ""}}, }, .expected_error = "Cannot set both included and excluded functions.", }, StandardLibraryConfigTestCase{ .standard_library_config = { .included_functions = {{"_+_", ""}, {"_+_", "add_list"}}, }, .expected_error = "Cannot include function '_+_' and also its " "specific overload 'add_list'", }, StandardLibraryConfigTestCase{ .standard_library_config = { .excluded_functions = {{"_+_", ""}, {"_+_", "add_list"}}, }, .expected_error = "Cannot exclude function '_+_' and also its " "specific overload 'add_list'", })); TEST(VariableConfigTest, VariableConfig) { Config config; Config::VariableConfig variable_config{ .name = "test", .type_info = { .name = "mytype", .params = {{.name = "int"}, {.name = "A", .is_type_param = true}}, }, }; 
ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); ASSERT_EQ(config.GetVariableConfigs().size(), 1); const auto& added_config = config.GetVariableConfigs()[0]; EXPECT_EQ(added_config.type_info.name, "mytype"); ASSERT_THAT(added_config.type_info.params.size(), 2); EXPECT_EQ(added_config.type_info.params[0].name, "int"); EXPECT_FALSE(added_config.type_info.params[0].is_type_param); EXPECT_EQ(added_config.type_info.params[1].name, "A"); EXPECT_TRUE(added_config.type_info.params[1].is_type_param); } TEST(VariableConfigTest, VariableConfigConflict) { Config config; Config::VariableConfig variable_config{ .name = "test", .type_info = {.name = "int"}, }; EXPECT_THAT(config.AddVariableConfig(variable_config), IsOk()); EXPECT_THAT(config.AddVariableConfig(variable_config), StatusIs(absl::StatusCode::kAlreadyExists)); } TEST(VariableConfigTest, VariableConfigValueTypeMismatch) { Config config; Config::VariableConfig variable_config{ .name = "test", .type_info = {.name = "int"}, .value = Constant(StringConstant("hello")), }; EXPECT_THAT(config.AddVariableConfig(variable_config), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Variable 'test' has type int but is assigned " "a constant value of type string."))); } TEST(FunctionConfigTest, FunctionConfig) { Config config; Config::FunctionConfig function_config; function_config.name = "test"; function_config.description = "Ultimate test"; function_config.overload_configs.push_back(Config::FunctionOverloadConfig{ .overload_id = "test_with_pill", .examples = {"oracle.isTheOne('Neo', RED)"}, .is_member_function = true, .parameters = {{.name = "string"}, {.name = "Choice"}}, .return_type = {.name = "bool"}, }); ASSERT_THAT(config.AddFunctionConfig(function_config), IsOk()); ASSERT_EQ(config.GetFunctionConfigs().size(), 1); const auto& added_config = config.GetFunctionConfigs()[0]; EXPECT_EQ(added_config.name, "test"); EXPECT_EQ(added_config.description, "Ultimate test"); 
EXPECT_EQ(added_config.overload_configs.size(), 1); const auto& overload_config = added_config.overload_configs[0]; EXPECT_EQ(overload_config.overload_id, "test_with_pill"); EXPECT_THAT(overload_config.examples, ElementsAre("oracle.isTheOne('Neo', RED)")); EXPECT_TRUE(overload_config.is_member_function); EXPECT_THAT( overload_config.parameters, ElementsAre(AllOf(Field(&Config::TypeInfo::name, "string"), Field(&Config::TypeInfo::is_type_param, false)), AllOf(Field(&Config::TypeInfo::name, "Choice"), Field(&Config::TypeInfo::is_type_param, false)))); EXPECT_THAT(overload_config.return_type, Field(&Config::TypeInfo::name, "bool")); } TEST(FunctionConfigTest, FunctionConfigInvalidMember) { Config config; Config::FunctionConfig function_config; function_config.name = "test"; function_config.overload_configs.push_back(Config::FunctionOverloadConfig{ .overload_id = "test_member_no_params", .is_member_function = true, .parameters = {}, }); EXPECT_THAT(config.AddFunctionConfig(function_config), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("is marked as a member function but has no " "parameters"))); } } // namespace } // namespace cel ================================================ FILE: env/env.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "env/env.h"

#include
#include
#include
#include

#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "checker/type_checker_builder.h"
#include "common/constant.h"
#include "common/decl.h"
#include "common/type.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "env/config.h"
#include "env/type_info.h"
#include "internal/status_macros.h"
#include "parser/macro.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"

namespace cel {
namespace {

// Returns whether the standard-library macro `macro` passes the macro filters
// in `config`:
// - Nothing passes when macros are disabled.
// - An excluded macro never passes.
// - A non-empty include list admits only the macros it names.
bool ShouldIncludeMacro(const Config::StandardLibraryConfig& config,
                        absl::string_view macro) {
  if (config.disable_macros) {
    return false;
  }
  if (config.excluded_macros.contains(macro)) {
    return false;
  }
  if (!config.included_macros.empty() &&
      !config.included_macros.contains(macro)) {
    return false;
  }
  return true;
}

// Returns whether overload `overload_id` of standard-library function
// `function` passes the function filters in `config`.
//
// An entry with an empty overload id covers every overload of that function,
// both in the exclude set and in the include set.
bool ShouldIncludeFunction(const Config::StandardLibraryConfig& config,
                           absl::string_view function,
                           absl::string_view overload_id) {
  // Build each lookup key once instead of re-allocating it per lookup.
  const auto specific_key =
      std::make_pair(std::string(function), std::string(overload_id));
  const auto all_overloads_key =
      std::make_pair(std::string(function), std::string());
  if (config.excluded_functions.contains(specific_key) ||
      config.excluded_functions.contains(all_overloads_key)) {
    return false;
  }
  if (!config.included_functions.empty() &&
      !config.included_functions.contains(all_overloads_key) &&
      !config.included_functions.contains(specific_key)) {
    return false;
  }
  return true;
}

// Builds the "stdlib" CompilerLibrarySubset that applies the macro and
// function filters from `standard_library_config`.
//
// Each filter callback owns a copy of the config. The subset is attached to
// the builder returned by the public Env::NewCompilerBuilder(), and the
// caller may call Build() after the Env has been reconfigured (SetConfig) or
// destroyed. Capturing a reference to the Env-owned config would then
// dangle.
absl::StatusOr MakeStdlibSubset(
    const Config::StandardLibraryConfig& standard_library_config) {
  CompilerLibrarySubset subset;
  subset.library_id = "stdlib";
  subset.should_include_macro =
      [config = standard_library_config](const Macro& macro) {
        return ShouldIncludeMacro(config, macro.function());
      };
  subset.should_include_overload =
      [config = standard_library_config](absl::string_view function,
                                         absl::string_view overload_id) {
        return ShouldIncludeFunction(config, function, overload_id);
      };
  return subset;
}

// Converts a FunctionConfig into a FunctionDecl. Parameter and result types
// are resolved with TypeInfoToType against `descriptor_pool`, allocating in
// `arena`. The first type-resolution or AddOverload error is returned.
absl::StatusOr FunctionConfigToFunctionDecl(
    const Config::FunctionConfig& function_config, google::protobuf::Arena* arena,
    const google::protobuf::DescriptorPool* descriptor_pool) {
  FunctionDecl function_decl;
  function_decl.set_name(function_config.name);
  for (const Config::FunctionOverloadConfig& overload_config :
       function_config.overload_configs) {
    OverloadDecl overload_decl;
    overload_decl.set_id(overload_config.overload_id);
    overload_decl.set_member(overload_config.is_member_function);
    for (const Config::TypeInfo& parameter : overload_config.parameters) {
      CEL_ASSIGN_OR_RETURN(Type parameter_type,
                           TypeInfoToType(parameter, descriptor_pool, arena));
      overload_decl.mutable_args().push_back(parameter_type);
    }
    CEL_ASSIGN_OR_RETURN(
        Type return_type,
        TypeInfoToType(overload_config.return_type, descriptor_pool, arena));
    overload_decl.set_result(return_type);
    CEL_RETURN_IF_ERROR(function_decl.AddOverload(overload_decl));
  }
  return function_decl;
}

}  // namespace

// Quoted identifiers are enabled in the parser by default.
Env::Env() {
  compiler_options_.parser_options.enable_quoted_identifiers = true;
}

// Creates a CompilerBuilder from descriptor_pool_ and compiler_options_, then
// applies config_ in this order:
// 1. The container name.
// 2. The standard library and its subset filter, unless disabled.
// 3. Each configured extension, fetched from the extension registry.
// 4. The configured variables and functions.
//
// Types are resolved against descriptor_pool_ using the checker builder's
// arena.
absl::StatusOr> Env::NewCompilerBuilder() {
  CEL_ASSIGN_OR_RETURN(
      std::unique_ptr compiler_builder,
      cel::NewCompilerBuilder(descriptor_pool_, compiler_options_));
  cel::TypeCheckerBuilder& checker_builder =
      compiler_builder->GetCheckerBuilder();
  checker_builder.set_container(config_.GetContainerConfig().name);
  if (!config_.GetStandardLibraryConfig().disable) {
    CEL_RETURN_IF_ERROR(
        compiler_builder->AddLibrary(StandardCompilerLibrary()));
    CEL_ASSIGN_OR_RETURN(CompilerLibrarySubset standard_library_subset,
                         MakeStdlibSubset(config_.GetStandardLibraryConfig()));
    CEL_RETURN_IF_ERROR(
        compiler_builder->AddLibrarySubset(std::move(standard_library_subset)));
  }
  for (const Config::ExtensionConfig& extension_config :
       config_.GetExtensionConfigs()) {
    CEL_ASSIGN_OR_RETURN(CompilerLibrary library,
                         extension_registry_.GetCompilerLibrary(
                             extension_config.name, extension_config.version));
    CEL_RETURN_IF_ERROR(compiler_builder->AddLibrary(std::move(library)));
  }

  google::protobuf::Arena* arena = checker_builder.arena();
  for (const Config::VariableConfig& variable_config :
       config_.GetVariableConfigs()) {
    VariableDecl variable_decl;
    variable_decl.set_name(variable_config.name);
    CEL_ASSIGN_OR_RETURN(Type type,
                         TypeInfoToType(variable_config.type_info,
                                        descriptor_pool_.get(), arena));
    variable_decl.set_type(type);
    if (variable_config.value.has_value()) {
      variable_decl.set_value(variable_config.value);
    }
    CEL_RETURN_IF_ERROR(checker_builder.AddVariable(variable_decl));
  }
  for (const Config::FunctionConfig& function_config :
       config_.GetFunctionConfigs()) {
    CEL_ASSIGN_OR_RETURN(FunctionDecl function_decl,
                         FunctionConfigToFunctionDecl(function_config, arena,
                                                      descriptor_pool_.get()));
    CEL_RETURN_IF_ERROR(checker_builder.AddFunction(function_decl));
  }
  return compiler_builder;
}

// NewCompilerBuilder() followed by Build().
absl::StatusOr> Env::NewCompiler() {
  CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler_builder,
                       NewCompilerBuilder());
  return compiler_builder->Build();
}

}  // namespace cel

================================================
FILE: env/env.h
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_H_ #define THIRD_PARTY_CEL_CPP_ENV_ENV_H_ #include #include #include #include "absl/functional/any_invocable.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "compiler/compiler.h" #include "env/config.h" #include "env/internal/ext_registry.h" #include "google/protobuf/descriptor.h" namespace cel { // Env class establishes the environment for compiling CEL expressions. // // It is used to configure compiler options, extension functions, and other // customizable CEL features. class Env { public: Env(); // Registers a `CompilerLibrary` with the environment. Note that the library // does not automatically get added to a `Compiler`. `NewCompiler` relies // on `Config` to determine which libraries to load. void RegisterCompilerLibrary( absl::string_view name, absl::string_view alias, int version, absl::AnyInvocable library_factory) { extension_registry_.RegisterCompilerLibrary(name, alias, version, std::move(library_factory)); } void SetDescriptorPool( std::shared_ptr descriptor_pool) { descriptor_pool_ = std::move(descriptor_pool); } const google::protobuf::DescriptorPool* GetDescriptorPool() const { return descriptor_pool_.get(); } void SetConfig(const Config& config) { config_ = config; } absl::StatusOr> NewCompilerBuilder(); // Shortcut for NewCompilerBuilder() followed by Build(). 
absl::StatusOr> NewCompiler(); private: cel::env_internal::ExtensionRegistry extension_registry_; std::shared_ptr descriptor_pool_; CompilerOptions compiler_options_; Config config_; }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_ENV_ENV_H_ ================================================ FILE: env/env_runtime.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "env/env_runtime.h" #include #include #include #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "env/config.h" #include "internal/status_macros.h" #include "runtime/runtime.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_builder_factory.h" #include "runtime/runtime_options.h" #include "runtime/standard_functions.h" namespace cel { void EnvRuntime::RegisterExtensionFunctions( absl::string_view name, absl::string_view alias, int version, absl::AnyInvocable function_registration_callback) { extension_registry_.AddFunctionRegistration( name, alias, version, std::move(function_registration_callback)); } absl::StatusOr EnvRuntime::CreateRuntimeBuilder() { const std::vector& extension_configs = config_.GetExtensionConfigs(); const Config::ExtensionConfig* optional_extension_config = nullptr; for (const Config::ExtensionConfig& extension_config : extension_configs) { if (extension_config.name == "optional") 
{ optional_extension_config = &extension_config; runtime_options_.enable_qualified_type_identifiers = true; break; } } CEL_ASSIGN_OR_RETURN( RuntimeBuilder runtime_builder, cel::CreateRuntimeBuilder(descriptor_pool_, runtime_options_)); if (!config_.GetStandardLibraryConfig().disable) { CEL_RETURN_IF_ERROR(RegisterStandardFunctions( runtime_builder.function_registry(), runtime_options_)); } // Register optional extension functions first, because other extensions // depend on it (e.g. regex). if (optional_extension_config != nullptr) { CEL_RETURN_IF_ERROR(extension_registry_.RegisterExtensionFunctions( runtime_builder, runtime_options_, optional_extension_config->name, optional_extension_config->version)); } for (const Config::ExtensionConfig& extension_config : extension_configs) { if (&extension_config == optional_extension_config) { continue; } CEL_RETURN_IF_ERROR(extension_registry_.RegisterExtensionFunctions( runtime_builder, runtime_options_, extension_config.name, extension_config.version)); } return runtime_builder; } absl::StatusOr> EnvRuntime::NewRuntime() { CEL_ASSIGN_OR_RETURN(RuntimeBuilder runtime_builder, CreateRuntimeBuilder()); return std::move(runtime_builder).Build(); } } // namespace cel ================================================ FILE: env/env_runtime.h ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_
#define THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_

#include
#include

#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "env/config.h"
#include "env/internal/runtime_ext_registry.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/descriptor.h"

namespace cel {

// EnvRuntime class establishes the environment for creating CEL runtimes.
//
// It is used to configure runtime options, extension functions, and other
// customizable CEL runtime features.
//
// EnvRuntime is separate from Env to avoid a dependency on the compiler for
// binaries that only use the runtime.
//
// Even though EnvRuntime is separate from Env, the Config and DescriptorPool
// passed to EnvRuntime are expected to be the same as those passed to Env for
// compilation. This ensures consistency between compilation and runtime.
class EnvRuntime {
 public:
  // Registers a function registration callback for an extension. The callback
  // is invoked when a runtime is created, if the corresponding functions are
  // enabled in the runtime config.
  //
  // `name` and `alias` are alternative identifiers for the extension, and
  // `version` identifies which registration this is. Extension entries in the
  // config select a registration by name (or alias) and version.
  void RegisterExtensionFunctions(
      absl::string_view name, absl::string_view alias, int version,
      absl::AnyInvocable function_registration_callback);

  // Sets the descriptor pool used when creating runtime builders. Per the class
  // comment, this should be the pool that was given to Env.
  void SetDescriptorPool(
      std::shared_ptr descriptor_pool) {
    descriptor_pool_ = std::move(descriptor_pool);
  }

  // Stores a copy of `config`; it determines whether standard functions are
  // registered and which extensions' functions are registered.
  void SetConfig(const Config& config) { config_ = config; }

  // Mutable access to the runtime options used when creating runtime builders.
  RuntimeOptions& mutable_runtime_options() { return runtime_options_; }

  // Creates a runtime builder with the standard and configured extension
  // functions registered according to the current config.
  absl::StatusOr CreateRuntimeBuilder();

  // Shortcut for CreateRuntimeBuilder() followed by Build().
  absl::StatusOr> NewRuntime();

 private:
  // Private accessor; presumably used by the friend declared below.
  cel::env_internal::RuntimeExtensionRegistry& GetRuntimeExtensionRegistry() {
    return extension_registry_;
  }

  friend void RegisterStandardExtensions(EnvRuntime& env_runtime);

  cel::env_internal::RuntimeExtensionRegistry extension_registry_;
  std::shared_ptr descriptor_pool_;
  Config config_;
  RuntimeOptions runtime_options_;
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_
================================================
FILE: env/env_runtime_test.cc
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "env/env_runtime.h"

#include
#include
#include
#include

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "checker/validation_result.h"
#include "common/ast.h"
#include "common/source.h"
#include "common/value.h"
#include "compiler/compiler.h"
#include "env/config.h"
#include "env/env.h"
#include "env/env_std_extensions.h"
#include "env/env_yaml.h"
#include "env/runtime_std_extensions.h"
#include "extensions/math_ext.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "runtime/activation.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::testing::IsEmpty;
using ::testing::ValuesIn;

// One end-to-end scenario: a YAML environment config, an expression, and
// whether program creation is expected to fail.
struct TestCase {
  std::string config_yaml;
  std::string expr;
  bool expected_to_fail = false;
};

class EnvRuntimeTest : public testing::TestWithParam {};

// Builds a compiler (Env) and a runtime (EnvRuntime) from the same config and
// descriptor pool. Then it either evaluates the expression and expects `true`,
// or expects CreateProgram to fail with kInvalidArgument.
TEST_P(EnvRuntimeTest, EndToEnd) {
  const TestCase& param = GetParam();
  auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool();
  ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(param.config_yaml));

  Env env;
  env.SetDescriptorPool(descriptor_pool);
  RegisterStandardExtensions(env);
  env.SetConfig(config);

  EnvRuntime env_runtime;
  env_runtime.SetDescriptorPool(descriptor_pool);
  RegisterStandardExtensions(env_runtime);
  env_runtime.SetConfig(config);

  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler());
  std::unique_ptr ast;
  if (!param.expected_to_fail) {
    ASSERT_OK_AND_ASSIGN(ValidationResult result,
                         compiler->Compile(param.expr));
    EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError();
    ASSERT_OK_AND_ASSIGN(ast, result.ReleaseAst());
  } else {
    // Bypass type checking to allow compilation to succeed since we expect the
    // runtime to fail.
    ASSERT_OK_AND_ASSIGN(std::unique_ptr source, NewSource(param.expr, ""));
    ASSERT_OK_AND_ASSIGN(ast, compiler->GetParser().Parse(*source));
  }

  ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, env_runtime.NewRuntime());
  absl::StatusOr> program_or =
      runtime->CreateProgram(std::move(ast));
  if (param.expected_to_fail) {
    EXPECT_THAT(program_or, StatusIs(absl::StatusCode::kInvalidArgument))
        << " expr: " << param.expr;
    return;
  }

  ASSERT_THAT(program_or, IsOk()) << " expr: " << param.expr;
  std::unique_ptr program = *std::move(program_or);
  ASSERT_NE(program, nullptr);
  google::protobuf::Arena arena;
  Activation activation;
  ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation));
  // Check the kind first, so a non-bool (e.g. error) result fails with the
  // actual value printed instead of calling GetBool() on the wrong kind.
  ASSERT_TRUE(value.IsBool())
      << " expr: " << param.expr << " result: " << value.DebugString();
  EXPECT_TRUE(value.GetBool()) << " expr: " << param.expr;
}

std::vector GetEnvRuntimeTestCases() {
  return {
      TestCase{
          .config_yaml = R"yaml(
            extensions:
              - name: "encoders"
          )yaml",
          .expr = "base64.encode(b'hello') == 'aGVsbG8='",
      },
      TestCase{
          .config_yaml = R"yaml(
            extensions:
              - name: "encoders"
              - name: "optional"
          )yaml",
          .expr = "base64.encode(b'hello') == 'aGVsbG8=' && "
                  "optional.of(1).hasValue()",
      },
      TestCase{
          .config_yaml = R"yaml(
            extensions:
              - name: "encoders"
          )yaml",
          .expr = "base64.encode(b'hello') == 'aGVsbG8=' && "
                  "optional.of(1).hasValue()",
          .expected_to_fail = true,
      },
      TestCase{
          .config_yaml = R"yaml(
            stdlib:
              disable: true
          )yaml",
          .expr = "1 + 2 == 3",
          .expected_to_fail = true,
      },
      TestCase{
          .config_yaml = R"yaml(
            stdlib:
              disable: true
            extensions:
              - name: "encoders"
          )yaml",
          .expr = "base64.encode(b'hello') == 'aGVsbG8=' && "
                  "1 + 2 == 3",
          .expected_to_fail = true,
      },
  };
}

INSTANTIATE_TEST_SUITE_P(EnvRuntimeTest, EnvRuntimeTest,
                         ValuesIn(GetEnvRuntimeTestCases()));

// Verifies that a callback registered via RegisterExtensionFunctions is invoked
// for an extension the config names by its alias ("math").
TEST(EnvRuntimeTest, RegisterExtensionFunctions) {
  auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool();
  Config config;
  ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk());

  Env env;
  env.SetDescriptorPool(descriptor_pool);
  RegisterStandardExtensions(env);
  env.SetConfig(config);
  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler());
  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       compiler->Compile("math.sqrt(4) == 2.0"));
  EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError();
  ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, result.ReleaseAst());

  EnvRuntime env_runtime;
  env_runtime.SetDescriptorPool(descriptor_pool);
  env_runtime.RegisterExtensionFunctions(
      "cel.lib.math", "math", 2,
      [](cel::RuntimeBuilder& runtime_builder,
         const cel::RuntimeOptions& opts) -> absl::Status {
        return cel::extensions::RegisterMathExtensionFunctions(
            runtime_builder.function_registry(), opts, 2);
      });
  env_runtime.SetConfig(config);
  ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, env_runtime.NewRuntime());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr program,
                       runtime->CreateProgram(std::move(ast)));
  ASSERT_NE(program, nullptr);
  google::protobuf::Arena arena;
  Activation activation;
  ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation));
  ASSERT_TRUE(value.IsBool()) << " result: " << value.DebugString();
  EXPECT_TRUE(value.GetBool());
}

}  // namespace
}  // namespace cel
================================================
FILE: env/env_std_extensions.cc
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "env/env_std_extensions.h"

#include "checker/optional.h"
#include "compiler/optional.h"
#include "env/env.h"
#include "extensions/bindings_ext.h"
#include "extensions/comprehensions_v2.h"
#include "extensions/encoders.h"
#include "extensions/lists_functions.h"
#include "extensions/math_ext_decls.h"
#include "extensions/proto_ext.h"
#include "extensions/regex_ext.h"
#include "extensions/sets_functions.h"
#include "extensions/strings.h"

namespace cel {

void RegisterStandardExtensions(Env& env) {
  // Unversioned libraries are registered once, at version 0. Versioned
  // libraries are registered once per version, from 0 up to and including
  // their latest version.
  env.RegisterCompilerLibrary("cel.lib.ext.bindings", "bindings", 0, [] {
    return extensions::BindingsCompilerLibrary();
  });
  env.RegisterCompilerLibrary("cel.lib.ext.encoders", "encoders", 0, [] {
    return extensions::EncodersCompilerLibrary();
  });
  for (int v = 0; v <= extensions::kListsExtensionLatestVersion; ++v) {
    env.RegisterCompilerLibrary("cel.lib.ext.lists", "lists", v, [v] {
      return extensions::ListsCompilerLibrary(v);
    });
  }
  for (int v = 0; v <= extensions::kMathExtensionLatestVersion; ++v) {
    env.RegisterCompilerLibrary("cel.lib.ext.math", "math", v, [v] {
      return extensions::MathCompilerLibrary(v);
    });
  }
  // "optional" is registered under its bare name and has no alias.
  for (int v = 0; v <= kOptionalExtensionLatestVersion; ++v) {
    env.RegisterCompilerLibrary("optional", "", v,
                                [v] { return OptionalCompilerLibrary(v); });
  }
  env.RegisterCompilerLibrary("cel.lib.ext.protos", "protos", 0, [] {
    return extensions::ProtoExtCompilerLibrary();
  });
  env.RegisterCompilerLibrary("cel.lib.ext.sets", "sets", 0, [] {
    return extensions::SetsCompilerLibrary();
  });
  for (int v = 0; v <= extensions::kStringsExtensionLatestVersion; ++v) {
    env.RegisterCompilerLibrary("cel.lib.ext.strings", "strings", v, [v] {
      return extensions::StringsCompilerLibrary(v);
    });
  }
  env.RegisterCompilerLibrary("cel.lib.ext.comprev2", "two-var-comprehensions",
                              0, [] {
                                return extensions::
                                    ComprehensionsV2CompilerLibrary();
                              });
  env.RegisterCompilerLibrary("cel.lib.ext.regex", "regex", 0, [] {
    return extensions::RegexExtCompilerLibrary();
  });
}

}  // namespace cel
================================================
FILE: env/env_std_extensions.h
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_
#define THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_

#include "env/env.h"

namespace cel {

// Registers the standard CEL extensions with the given environment. This makes
// them available but does not enable them. A compiler created by the
// environment only loads an extension when the extension is listed in the
// `Config` installed with `Env::SetConfig`.
//
// Extensions are registered under the following names:
//
//   - cel.lib.ext.bindings (alias: "bindings")
//   - cel.lib.ext.encoders (alias: "encoders")
//   - cel.lib.ext.lists (alias: "lists")
//   - cel.lib.ext.math (alias: "math")
//   - optional
//   - cel.lib.ext.protos (alias: "protos")
//   - cel.lib.ext.sets (alias: "sets")
//   - cel.lib.ext.strings (alias: "strings")
//   - cel.lib.ext.comprev2 (alias: "two-var-comprehensions")
//   - cel.lib.ext.regex (alias: "regex")
void RegisterStandardExtensions(Env& env);

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_
================================================
FILE: env/env_std_extensions_test.cc
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "env/env_std_extensions.h"

#include
#include

#include "absl/strings/string_view.h"
#include "compiler/compiler.h"
#include "env/config.h"
#include "env/env.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::testing::TestWithParam;

// An extension to enable (by official name or alias) and an expression that
// must type-check once that extension is enabled.
struct TestCase {
  std::string extension;
  std::string expr;
};

class EnvStdExtensions : public TestWithParam {};

// Enables exactly one standard extension through the config and checks that
// an expression using it compiles without issues.
TEST_P(EnvStdExtensions, RegistrationTest) {
  const TestCase& param = GetParam();
  Env env;
  RegisterStandardExtensions(env);
  env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool());
  Config config;
  ASSERT_THAT(config.AddExtensionConfig(param.extension), IsOk());
  env.SetConfig(config);
  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler());
  ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(param.expr));
  ASSERT_TRUE(result.IsValid())
      << "Expected no issues for expr: " << param.expr
      << " but got: " << result.FormatError();
}

INSTANTIATE_TEST_SUITE_P(
    RegistrationTest, EnvStdExtensions,
    ::testing::Values(
        TestCase{
            .extension = "cel.lib.ext.bindings",  // official name
            .expr = "cel.bind(t, true, t)",
        },
        TestCase{
            .extension = "bindings",  // alias
            .expr = "cel.bind(t, true, t)",
        },
        TestCase{
            .extension = "encoders",
            .expr = "base64.encode(b'hello')",
        },
        TestCase{
            .extension = "lists",
            .expr = "[1, 2, 3].sort()",
        },
        TestCase{
            .extension = "lists",
            .expr = "['a'].sortBy(e, e)",
        },
        TestCase{
            .extension = "math",
            .expr = "math.sqrt(-1)",
        },
        TestCase{
            .extension = "optional",
            .expr = "[1, 2].first()",
        },
        TestCase{
            .extension = "optional",
            .expr = "[0][?1]",  // optional syntax auto-enabled
        },
        TestCase{
            .extension = "protos",
            .expr = "!proto.hasExt(cel.expr.conformance.proto2.TestAllTypes{}, "
                    "cel.expr.conformance.proto2.nested_ext)",
        },
        TestCase{
            .extension = "sets",
            .expr = "sets.contains([1], [1])",
        },
        TestCase{
            .extension = "strings",
            .expr = "'foo'.reverse()",
        },
        TestCase{
            .extension = "two-var-comprehensions",
            .expr = "[1, 2, 3, 4].all(i, v, i < v)",
        },
        TestCase{
            .extension = "regex",
            .expr = "regex.replace('abc', '$', '_end')",
        }));

}  // namespace
}  // namespace cel
================================================
FILE: env/env_test.cc
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "env/env.h"

#include
#include
#include
#include

#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "checker/type_check_issue.h"
#include "checker/type_checker_builder.h"
#include "checker/validation_result.h"
#include "common/ast.h"
#include "common/constant.h"
#include "common/decl.h"
#include "common/expr.h"
#include "common/type.h"
#include "common/value.h"
#include "compiler/compiler.h"
#include "env/config.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "parser/macro.h"
#include "parser/macro_expr_factory.h"
#include "parser/parser_interface.h"
#include "runtime/activation.h"
#include "runtime/reference_resolver.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "google/protobuf/arena.h"
namespace cel { namespace { using ::absl_testing::IsOk; using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::Property; using ::testing::UnorderedElementsAre; using ::testing::Values; using ::testing::ValuesIn; Expr TestMacroExpander(MacroExprFactory& factory, absl::Span args) { return factory.NewStringConst("Hello"); } class TestLibrary : public CompilerLibrary { public: explicit TestLibrary(int version) : CompilerLibrary( "testlib", [version](ParserBuilder& builder) { absl::Status status; CEL_ASSIGN_OR_RETURN( auto macro1, cel::Macro::Global("testMacro1", 0, TestMacroExpander)); status.Update(builder.AddMacro(macro1)); if (version == 2) { CEL_ASSIGN_OR_RETURN( auto macro2, cel::Macro::Global("testMacro2", 0, TestMacroExpander)); status.Update(builder.AddMacro(macro2)); } return status; }, [version](TypeCheckerBuilder& builder) { absl::Status status; CEL_ASSIGN_OR_RETURN( auto func1, cel::MakeFunctionDecl( "testFunc1", MakeOverloadDecl(StringType()))); status.Update(builder.AddFunction(func1)); if (version == 2) { CEL_ASSIGN_OR_RETURN( auto func2, cel::MakeFunctionDecl("testFunc2", MakeOverloadDecl(StringType()))); status.Update(builder.AddFunction(func2)); } return status; }) {}; }; absl::StatusOr CompileAndEvalExpr( Env& env, absl::string_view expr, const Activation& activation = Activation()) { CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, env.NewCompiler()); if (compiler == nullptr) { return absl::InternalError("Failed to create compiler"); } CEL_ASSIGN_OR_RETURN(ValidationResult result, compiler->Compile(expr)); if (!result.GetIssues().empty()) { return absl::InvalidArgumentError(result.FormatError()); } cel::RuntimeOptions opts; CEL_ASSIGN_OR_RETURN( cel::RuntimeBuilder rt_builder, cel::CreateStandardRuntimeBuilder(env.GetDescriptorPool(), opts)); CEL_RETURN_IF_ERROR(cel::EnableReferenceResolver( rt_builder, cel::ReferenceResolverEnabled::kAlways)); CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, std::move(rt_builder).Build()); if (runtime == 
nullptr) { return absl::InternalError("Failed to create runtime"); } CEL_ASSIGN_OR_RETURN(std::unique_ptr ast, result.ReleaseAst()); if (ast == nullptr) { return absl::InternalError("Failed to create AST"); } google::protobuf::Arena arena; CEL_ASSIGN_OR_RETURN(std::unique_ptr program, runtime->CreateProgram(std::move(ast))); if (program == nullptr) { return absl::InternalError("Failed to create program"); } CEL_ASSIGN_OR_RETURN(Value value, program->Evaluate(&arena, activation)); return value; } absl::StatusOr CompileAndEvalBooleanExpr( Env& env, absl::string_view expr, const Activation& activation = Activation()) { CEL_ASSIGN_OR_RETURN(auto value, CompileAndEvalExpr(env, expr, activation)); return value.GetBool(); } class LibraryConfigTest : public testing::Test { protected: void SetUp() override { env_.RegisterCompilerLibrary("testlib", "ml", 1, []() { return TestLibrary(1); }); env_.RegisterCompilerLibrary("testlib", "ml", 2, []() { return TestLibrary(2); }); env_.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); } Env env_; }; TEST_F(LibraryConfigTest, DefaultVersion) { Config config; ASSERT_THAT(config.AddExtensionConfig("testlib"), IsOk()); env_.SetConfig(config); ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env_.NewCompiler()); ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile("testMacro1()")); ASSERT_OK_AND_ASSIGN(auto result2, compiler->Compile("testFunc1()")); ASSERT_OK_AND_ASSIGN(auto result3, compiler->Compile("testMacro2()")); ASSERT_OK_AND_ASSIGN(auto result4, compiler->Compile("testFunc2()")); EXPECT_THAT(result1.GetIssues(), IsEmpty()); EXPECT_THAT(result2.GetIssues(), IsEmpty()); EXPECT_THAT(result3.GetIssues(), IsEmpty()); EXPECT_THAT(result4.GetIssues(), IsEmpty()); } TEST_F(LibraryConfigTest, SpecificVersion) { Config config; ASSERT_THAT(config.AddExtensionConfig("testlib", 1), IsOk()); env_.SetConfig(config); ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env_.NewCompiler()); ASSERT_OK_AND_ASSIGN(auto result1, 
compiler->Compile("testMacro1()")); ASSERT_OK_AND_ASSIGN(auto result2, compiler->Compile("testFunc1()")); ASSERT_OK_AND_ASSIGN(auto result3, compiler->Compile("testMacro2()")); ASSERT_OK_AND_ASSIGN(auto result4, compiler->Compile("testFunc2()")); EXPECT_THAT(result1.GetIssues(), IsEmpty()); EXPECT_THAT(result2.GetIssues(), IsEmpty()); EXPECT_THAT(result3.GetIssues(), UnorderedElementsAre( Property(&TypeCheckIssue::message, HasSubstr("undeclared reference to 'testMacro2'")))); EXPECT_THAT(result4.GetIssues(), UnorderedElementsAre( Property(&TypeCheckIssue::message, HasSubstr("undeclared reference to 'testFunc2'")))); } struct StandardLibraryConfigTestCase { Config::StandardLibraryConfig standard_library_config; std::vector expected_valid_expressions; std::vector expected_invalid_expressions; }; class StandardLibraryConfigTest : public testing::TestWithParam {}; TEST_P(StandardLibraryConfigTest, StandardLibraryConfig) { const StandardLibraryConfigTestCase& param = GetParam(); Env env; env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); Config config; ASSERT_THAT(config.SetStandardLibraryConfig(param.standard_library_config), IsOk()); env.SetConfig(config); ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); for (const std::string& expr : param.expected_valid_expressions) { ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile(expr)); EXPECT_THAT(result1.GetIssues(), IsEmpty()) << "With config: " << param.standard_library_config << ", expected no issues for expr: " << expr << " but got: " << result1.FormatError(); } for (const std::string& expr : param.expected_invalid_expressions) { ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile(expr)); EXPECT_THAT(result1.GetIssues(), Not(IsEmpty())) << "With config: " << param.standard_library_config << ", expected compilation error for expr: " << expr << " but got: \'" << result1.FormatError() << "\'"; } } INSTANTIATE_TEST_SUITE_P( StandardLibraryConfigTest, StandardLibraryConfigTest, Values( 
StandardLibraryConfigTestCase{ .standard_library_config = {}, .expected_valid_expressions = {"1 + 2", "[1, 2, 3].exists(x, x == 1)", "[1, 2, 3].all(x, x == 1)", "[1, 2, 3].map(x, x)"}, }, StandardLibraryConfigTestCase{ .standard_library_config = {.disable = true}, .expected_invalid_expressions = {"1 + 2", "[1, 2, 3].exists(x, x == 1)", "[1, 2, 3].all(x, x == 1)", "[1, 2, 3].map(x, x)"}, }, StandardLibraryConfigTestCase{ .standard_library_config = {.disable_macros = true}, .expected_valid_expressions = {"1 + 2"}, .expected_invalid_expressions = {"[1, 2, 3].exists(x, x == 1)", "[1, 2, 3].all(x, x == 1)", "[1, 2, 3].map(x, x)"}, }, StandardLibraryConfigTestCase{ .standard_library_config = {.excluded_macros = {"map", "all"}}, .expected_valid_expressions = {"[1, 2, 3].exists(x, x == 1)"}, .expected_invalid_expressions = {"[1, 2, 3].all(x, x == 1)", "[1, 2, 3].map(x, x)"}, }, StandardLibraryConfigTestCase{ .standard_library_config = {.included_macros = {"map", "all"}}, .expected_valid_expressions = {"[1, 2, 3].all(x, x == 1)", "[1, 2, 3].map(x, x)"}, .expected_invalid_expressions = {"[1, 2, 3].exists(x, x == 1)"}, }, StandardLibraryConfigTestCase{ .standard_library_config = {.excluded_functions = {{"_+_", ""}}}, .expected_invalid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]", "'hello' + 'world'"}, }, StandardLibraryConfigTestCase{ .standard_library_config = {.excluded_functions = {{"_+_", "add_bytes"}, {"_+_", "add_list"}, {"_+_", "add_string"}}}, .expected_valid_expressions = {"1 + 2"}, .expected_invalid_expressions = {"[1, 2, 3] + [4, 5, 6]", "'hello' + 'world'"}, }, StandardLibraryConfigTestCase{ .standard_library_config = {.included_functions = {{"_+_", ""}}}, .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]", "'hello' + 'world'"}, }, StandardLibraryConfigTestCase{ .standard_library_config = {.included_functions = {{"_+_", "add_int64"}, {"_+_", "add_list"}}}, .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]"}, 
.expected_invalid_expressions = {"'hello' + 'world'"}, })); TEST(ContainerConfigTest, ContainerConfig) { Env env; env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); Config config; config.SetContainerConfig({.name = "cel.expr.conformance.proto2"}); env.SetConfig(config); ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("TestAllTypes{}")); EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); } struct VariableConfigWithValueTestCase { Config::VariableConfig variable_config; std::string validate_type_expr; std::string validate_value_expr; }; class VariableConfigWithValueTest : public testing::TestWithParam {}; TEST_P(VariableConfigWithValueTest, VariableConfigWithValue) { const VariableConfigWithValueTestCase& param = GetParam(); Env env; env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); Config config; ASSERT_THAT(config.AddVariableConfig(param.variable_config), IsOk()); env.SetConfig(config); ASSERT_OK_AND_ASSIGN( bool type_as_expected, CompileAndEvalBooleanExpr(env, param.validate_type_expr)); ASSERT_TRUE(type_as_expected) << " expr: " << param.validate_type_expr; if (!param.validate_value_expr.empty()) { ASSERT_OK_AND_ASSIGN( bool value_as_expected, CompileAndEvalBooleanExpr(env, param.validate_value_expr)); ASSERT_TRUE(value_as_expected) << " expr: " << param.validate_value_expr; } } Config::VariableConfig MakeConstant( absl::string_view variable_name, absl::string_view type_name, absl::AnyInvocable setter) { Config::VariableConfig variable_config; variable_config.name = variable_name; Constant c; setter(c); variable_config.type_info.name = type_name; variable_config.value = c; return variable_config; } std::vector GetVariableConfigWithValueTestCases() { return { VariableConfigWithValueTestCase{ .variable_config = MakeConstant( "x", "null", [](auto& c) { c.set_null_value(nullptr); }), .validate_type_expr = "type(x) == type(null)", }, 
VariableConfigWithValueTestCase{ .variable_config = MakeConstant( "x", "bool", [](auto& c) { c.set_bool_value(true); }), .validate_type_expr = "type(x) == bool", .validate_value_expr = "x == true", }, VariableConfigWithValueTestCase{ .variable_config = MakeConstant( "x", "int", [](Constant& c) { c.set_int_value(42); }), .validate_type_expr = "type(x) == int", .validate_value_expr = "x == 42", }, VariableConfigWithValueTestCase{ .variable_config = MakeConstant( "x", "uint", [](Constant& c) { c.set_uint_value(777); }), .validate_type_expr = "type(x) == uint", .validate_value_expr = "x == 777u", }, VariableConfigWithValueTestCase{ .variable_config = MakeConstant("x", "double", [](Constant& c) { c.set_double_value(1.0 / 3.0); }), .validate_type_expr = "type(x) == double", .validate_value_expr = "x > 0.333 && x < 0.334", }, VariableConfigWithValueTestCase{ .variable_config = MakeConstant("x", "bytes", [](Constant& c) { c.set_bytes_value(absl::string_view( "\xff\x00\x01", 3)); }), .validate_type_expr = "type(x) == bytes", .validate_value_expr = "x == b'\\xff\\x00\\x01'", }, VariableConfigWithValueTestCase{ .variable_config = MakeConstant( "x", "string", [](Constant& c) { c.set_string_value("hello"); }), .validate_type_expr = "type(x) == string", .validate_value_expr = "x == 'hello'", }, VariableConfigWithValueTestCase{ .variable_config = MakeConstant( "x", "timestamp", [](Constant& c) { // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) c.set_timestamp_value(absl::FromUnixSeconds(1767323045)); }), .validate_type_expr = "type(x) == type(timestamp('2026-01-02T03:04:05Z'))", .validate_value_expr = "x == timestamp('2026-01-02T03:04:05Z')", }, VariableConfigWithValueTestCase{ .variable_config = MakeConstant( "x", "duration", [](Constant& c) { // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) c.set_duration_value(absl::Hours(1) + absl::Minutes(2) + absl::Seconds(3)); }), .validate_type_expr = "type(x) == type(duration('1h2m3s'))", .validate_value_expr = "x 
== duration('1h2m3s')", }, }; } INSTANTIATE_TEST_SUITE_P(VariableConfigTest, VariableConfigWithValueTest, ValuesIn(GetVariableConfigWithValueTestCases())); struct FunctionConfigTestCase { Config::FunctionConfig function_config; std::vector variable_configs; std::string expr; std::string expected_error; }; class FunctionConfigTest : public testing::TestWithParam {}; TEST_P(FunctionConfigTest, FunctionConfig) { const FunctionConfigTestCase& param = GetParam(); Env env; env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); Config config; for (const Config::VariableConfig& variable_config : param.variable_configs) { ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); } ASSERT_THAT(config.AddFunctionConfig(param.function_config), IsOk()); env.SetConfig(config); ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(param.expr)); if (param.expected_error.empty()) { EXPECT_TRUE(result.GetIssues().empty()) << " expr: " << param.expr << " error: " << result.FormatError(); } else { EXPECT_THAT(result.GetIssues(), UnorderedElementsAre(Property(&TypeCheckIssue::message, HasSubstr(param.expected_error)))) << " expr: " << param.expr << " error: " << result.FormatError(); } } std::vector GetFunctionConfigTestCases() { return {{ FunctionConfigTestCase{ .function_config = { .name = "add", .overload_configs = { { .overload_id = "plus(int,int)", .examples = {"add(1, 2) -> 3"}, .parameters = {{.name = "int"}, {.name = "int"}}, .return_type = {.name = "int"}, }, }, }, .expr = "add(1, 2)", }, FunctionConfigTestCase{ .function_config = { .name = "add", .overload_configs = { { .overload_id = "int.plus(int)", .examples = {"1.add(2) -> 3"}, .is_member_function = true, .parameters = {{.name = "int"}, {.name = "int"}}, .return_type = {.name = "int"}, }, }, }, .expr = "1.add(2) == 3", }, FunctionConfigTestCase{ .function_config = { .name = "add", .overload_configs = { { .overload_id = 
"plus(string,string)", .examples = {"add('hello', 'world') -> 'hello world'"}, .parameters = {{.name = "int"}, {.name = "int"}}, .return_type = {.name = "string"}, }, }, }, .expr = "add('hello', 'world')", .expected_error = "found no matching overload for 'add' applied to " "'(string, string)'", }, FunctionConfigTestCase{ .function_config = { .name = "add", .overload_configs = { { .overload_id = "int.plus(int)", .examples = {"1.add(2) -> 'three'"}, .is_member_function = true, .parameters = {{.name = "int"}, {.name = "int"}}, .return_type = {.name = "string"}, }, }, }, .expr = "1.add(2) == 3", .expected_error = "found no matching overload for '_==_' applied to " "'(string, int)'", }, FunctionConfigTestCase{ .function_config = { .name = "sum", .description = "Sum a collection, which is an opaque type.", .overload_configs = { { .overload_id = "sum(collection)", .examples = {"sum(my_collection) -> 100"}, .parameters = {{.name = "collection", .params = {{.name = "double"}}}}, .return_type = {.name = "double"}, }, }, }, .variable_configs = { {.name = "my_collection", .description = "Matching opaque type.", .type_info = {.name = "collection", .params = {{.name = "double"}}}}, }, .expr = "sum(my_collection) / 3.0", }, FunctionConfigTestCase{ .function_config = { .name = "sum", .description = "Sum a collection, which is an opaque type.", .overload_configs = { { .overload_id = "sum(collection)", .examples = {"sum(my_collection) -> 100"}, .parameters = {{.name = "collection", .params = {{.name = "int"}}}}, .return_type = {.name = "double"}, }, }, }, .variable_configs = { {.name = "my_collection", .description = "Mismatched opaque type.", .type_info = {.name = "collection", .params = {{.name = "double"}}}}, }, .expr = "sum(my_collection) / 3.0", .expected_error = "found no matching overload for 'sum' applied to " "'(collection(double))'", }, }}; } INSTANTIATE_TEST_SUITE_P(FunctionConfigTest, FunctionConfigTest, ::testing::ValuesIn(GetFunctionConfigTestCases())); } // namespace 
} // namespace cel ================================================ FILE: env/env_yaml.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "env/env_yaml.h" #include #include #include #include #include #include #include #include "absl/algorithm/container.h" #include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/escaping.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "common/constant.h" #include "env/config.h" #include "internal/status_macros.h" #include "internal/strings.h" #include "yaml-cpp/emitter.h" #include "yaml-cpp/emittermanip.h" #include "yaml-cpp/exceptions.h" #include "yaml-cpp/mark.h" #include "yaml-cpp/node/node.h" #include "yaml-cpp/node/parse.h" #include "yaml-cpp/null.h" #include "yaml-cpp/yaml.h" // IWYU pragma: keep namespace cel { namespace { std::string FormatYamlErrorMessage(absl::string_view yaml, absl::string_view error, const YAML::Mark& mark) { if (mark.is_null()) { return std::string(error); } std::string message; absl::StrAppend(&message, mark.line + 1, ":", mark.column + 1, ": ", error, "\n|"); size_t start = mark.pos - mark.column; size_t 
end = yaml.find('\n', mark.pos); if (end == std::string::npos) { end = yaml.size(); } absl::StrAppend(&message, yaml.substr(start, end - start), "\n|", std::string(mark.column, ' '), "^"); return message; } absl::StatusOr LoadYaml(const std::string& yaml) { try { return YAML::Load(yaml); } catch (YAML::ParserException& e) { return absl::InvalidArgumentError( FormatYamlErrorMessage(yaml, e.msg, e.mark)); } } absl::Status YamlError(absl::string_view yaml, const YAML::Node& node, absl::string_view error) { return absl::InvalidArgumentError( FormatYamlErrorMessage(yaml, error, node.Mark())); } std::string GetString(absl::string_view yaml, const YAML::Node& node) { if (!node.IsDefined() || !node.IsScalar()) { return ""; } try { return node.as(); } catch (YAML::Exception& e) { // This should never happen since we already checked that the node is a // scalar and all scalars can be converted to strings. return ""; } } bool IsBinary(const YAML::Node& node) { return node.Tag() == "!!binary" || node.Tag() == "tag:yaml.org,2002:binary"; } absl::StatusOr GetBinary(absl::string_view yaml, const YAML::Node& node) { if (!node.IsDefined() || !node.IsScalar() || !IsBinary(node)) { return ""; } std::string binary; // Instead of using the YAML::Binary type, we use absl::Base64Unescape // because YAML::Binary is lenient to Base64 decoding errors. 
if (absl::Base64Unescape(GetString(yaml, node), &binary)) { return binary; } else { return YamlError(yaml, node, "Node '" + GetString(yaml, node) + "' is not a valid Base64 encoded binary"); } } absl::StatusOr GetBool(absl::string_view yaml, absl::string_view key, const YAML::Node& node) { if (!node.IsDefined() || !node.IsScalar()) { return false; } try { return node.as(); } catch (YAML::Exception& e) { return YamlError(yaml, node, "Node '" + std::string(key) + "' is not a boolean"); } } absl::Status ParseName(Config& config, absl::string_view yaml, const YAML::Node& root) { const YAML::Node name = root["name"]; if (name.IsDefined()) { if (!name.IsScalar()) { return YamlError(yaml, name, "Node 'name' is not a string"); } config.SetName(GetString(yaml, name)); } return absl::OkStatus(); } absl::Status ParseContainerConfig(Config& config, absl::string_view yaml, const YAML::Node& root) { const YAML::Node container = root["container"]; if (container.IsDefined()) { if (!container.IsScalar()) { return YamlError(yaml, container, "Node 'container' is not a string"); } config.SetContainerConfig({.name = GetString(yaml, container)}); } return absl::OkStatus(); } absl::Status ParseExtensionConfigs(Config& config, absl::string_view yaml, const YAML::Node& root) { const YAML::Node extensions = root["extensions"]; if (!extensions.IsDefined()) { return absl::OkStatus(); } if (!extensions.IsSequence()) { return YamlError(yaml, extensions, "Node 'extensions' is not a sequence"); } for (const YAML::Node& extension : extensions) { if (!extension || !extension.IsMap()) { return YamlError(yaml, extension, "Extension is not a map"); } const YAML::Node name = extension["name"]; if (!name || !name.IsScalar()) { return YamlError(yaml, name, "Extension name is not a string"); } std::string name_str = GetString(yaml, name); const YAML::Node version = extension["version"]; std::string version_str = GetString(yaml, version); int extension_version; if (version.IsDefined()) { bool 
is_valid_version = false; if (version.IsScalar()) { if (version_str == "latest") { extension_version = Config::ExtensionConfig::kLatest; is_valid_version = true; } else { if (absl::SimpleAtoi(version_str, &extension_version) && extension_version >= 0) { is_valid_version = true; } } } if (!is_valid_version) { return YamlError( yaml, version, absl::StrCat("Extension '", name_str, "' version is not a valid number or 'latest'")); } } else { extension_version = Config::ExtensionConfig::kLatest; } absl::Status add_status = config.AddExtensionConfig(name_str, extension_version); if (!add_status.ok()) { return YamlError(yaml, extension, add_status.message()); } } return absl::OkStatus(); } absl::StatusOr> ParseMacroList( absl::string_view yaml, const YAML::Node& standard_library, absl::string_view key) { absl::flat_hash_set macro_set; const YAML::Node macros = standard_library[std::string(key)]; if (!macros.IsDefined()) { return macro_set; } if (!macros.IsSequence()) { return YamlError(yaml, macros, absl::StrCat("Node '", key, "' is not a sequence")); } for (const YAML::Node& macro : macros) { if (!macro.IsScalar()) { return YamlError(yaml, macro, absl::StrCat("Entry in '", key, "' is not a string")); } macro_set.insert(GetString(yaml, macro)); } return macro_set; } absl::StatusOr>> ParseFunctionList(absl::string_view yaml, const YAML::Node& standard_library, absl::string_view key) { absl::flat_hash_set> function_set; const YAML::Node functions = standard_library[std::string(key)]; if (!functions.IsDefined()) { return function_set; } if (!functions.IsSequence()) { return YamlError(yaml, functions, absl::StrCat("Node '", key, "' is not a sequence")); } for (const YAML::Node& function : functions) { if (!function.IsMap()) { return YamlError(yaml, function, absl::StrCat("Entry in '", key, "' is not a map")); } const YAML::Node name = function["name"]; if (!name.IsDefined()) { return YamlError( yaml, function, absl::StrCat("Function name in not specified in '", key, "'")); } 
if (!name.IsScalar()) { return YamlError( yaml, name, absl::StrCat("Function name in '", key, "' entry is not a string")); } std::string name_str = GetString(yaml, name); const YAML::Node overloads = function["overloads"]; if (!overloads.IsDefined()) { function_set.insert(std::make_pair(name_str, "")); } else { if (!overloads.IsSequence()) { return YamlError( yaml, overloads, absl::StrCat("Overloads in '", key, "' entry is not a sequence")); } for (const YAML::Node& overload : overloads) { if (!overload.IsMap()) { return YamlError( yaml, overload, absl::StrCat("Overload in '", key, "' entry is not a map")); } const YAML::Node id = overload["id"]; if (!id || !id.IsScalar()) { return YamlError( yaml, id, absl::StrCat("Overload id in '", key, "' entry is not a string")); } function_set.insert(std::make_pair(name_str, GetString(yaml, id))); } } } return function_set; } absl::Status ParseStandardLibraryConfig(Config& config, absl::string_view yaml, const YAML::Node& root) { const YAML::Node standard_library = root["stdlib"]; if (!standard_library.IsDefined()) { return absl::OkStatus(); } if (!standard_library.IsMap()) { return YamlError(yaml, standard_library, "Standard library config ('stdlib') is not a map"); } Config::StandardLibraryConfig standard_library_config; const YAML::Node disable = standard_library["disable"]; if (disable.IsDefined()) { if (!disable.IsScalar()) { return YamlError(yaml, disable, "Node 'disable' is not a boolean"); } CEL_ASSIGN_OR_RETURN(standard_library_config.disable, GetBool(yaml, "disable", disable)); } const YAML::Node disable_macros = standard_library["disable_macros"]; if (disable_macros.IsDefined()) { if (!disable_macros.IsScalar()) { return YamlError(yaml, disable_macros, "Node 'disable_macros' is not a boolean"); } CEL_ASSIGN_OR_RETURN(standard_library_config.disable_macros, GetBool(yaml, "disable_macros", disable_macros)); } CEL_ASSIGN_OR_RETURN( standard_library_config.included_macros, ParseMacroList(yaml, standard_library, 
"include_macros")); CEL_ASSIGN_OR_RETURN( standard_library_config.excluded_macros, ParseMacroList(yaml, standard_library, "exclude_macros")); CEL_ASSIGN_OR_RETURN( standard_library_config.included_functions, ParseFunctionList(yaml, standard_library, "include_functions")); CEL_ASSIGN_OR_RETURN( standard_library_config.excluded_functions, ParseFunctionList(yaml, standard_library, "exclude_functions")); return config.SetStandardLibraryConfig(standard_library_config); } absl::StatusOr ParseTypeInfo(const YAML::Node& node, absl::string_view yaml) { Config::TypeInfo type_config; const YAML::Node type_name = node["type_name"]; if (!type_name.IsDefined()) { return type_config; } if (!type_name || !type_name.IsScalar()) { return YamlError(yaml, type_name, "Node 'type_name' is not a string"); } type_config.name = GetString(yaml, type_name); const YAML::Node is_type_param = node["is_type_param"]; if (is_type_param.IsDefined()) { if (!is_type_param.IsScalar()) { return YamlError(yaml, is_type_param, "Node 'is_type_param' is not a boolean"); } CEL_ASSIGN_OR_RETURN(type_config.is_type_param, GetBool(yaml, "is_type_param", is_type_param)); } const YAML::Node params = node["params"]; if (!params.IsDefined()) { return type_config; } if (!params.IsSequence()) { return YamlError(yaml, params, "Node 'params' is not a sequence"); } for (const YAML::Node& param : params) { CEL_ASSIGN_OR_RETURN(Config::TypeInfo param_config, ParseTypeInfo(param, yaml)); type_config.params.push_back(param_config); } return type_config; } bool CompareTypeInfo(const Config::TypeInfo& a, const Config::TypeInfo& b) { if (a.name != b.name) { return a.name < b.name; } if (a.params.size() != b.params.size()) { return a.params.size() < b.params.size(); } for (size_t i = 0; i < a.params.size(); ++i) { if (CompareTypeInfo(a.params[i], b.params[i])) { return true; } if (CompareTypeInfo(b.params[i], a.params[i])) { return false; } } return false; // They are equal } ConstantKindCase 
GetConstantKindCase(absl::string_view type_name) { static const auto kTypeNameToConstantKindCase = absl::NoDestructor>({ {"null", ConstantKindCase::kNull}, {"bool", ConstantKindCase::kBool}, {"int", ConstantKindCase::kInt}, {"uint", ConstantKindCase::kUint}, {"double", ConstantKindCase::kDouble}, {"string", ConstantKindCase::kString}, {"bytes", ConstantKindCase::kBytes}, {"duration", ConstantKindCase::kDuration}, {"timestamp", ConstantKindCase::kTimestamp}, }); if (auto it = kTypeNameToConstantKindCase->find(type_name); it != kTypeNameToConstantKindCase->end()) { return it->second; } return ConstantKindCase::kUnspecified; } absl::StatusOr ParseConstantValue(absl::string_view yaml, const YAML::Node& node, ConstantKindCase constant_kind_case, absl::string_view value) { switch (constant_kind_case) { case ConstantKindCase::kNull: if (!value.empty()) { return YamlError(yaml, node, "Failed to parse null constant"); } return Constant(nullptr); case ConstantKindCase::kBool: if (absl::EqualsIgnoreCase(value, "true")) { return Constant(true); } else if (absl::EqualsIgnoreCase(value, "false")) { return Constant(false); } else { return YamlError(yaml, node, "Failed to parse bool constant"); } case ConstantKindCase::kInt: int64_t int_value; if (!absl::SimpleAtoi(value, &int_value)) { return YamlError(yaml, node, "Failed to parse int constant"); } return Constant(int_value); case ConstantKindCase::kUint: uint64_t uint_value; if (absl::EndsWith(value, "u")) { value = value.substr(0, value.size() - 1); } if (!absl::SimpleAtoi(value, &uint_value)) { return YamlError(yaml, node, "Failed to parse uint constant"); } return Constant(uint_value); case ConstantKindCase::kDouble: double double_value; if (!absl::SimpleAtod(value, &double_value)) { return YamlError(yaml, node, "Failed to parse double constant"); } return Constant(double_value); case ConstantKindCase::kBytes: { if (!IsBinary(node)) { absl::StatusOr bytes_literal = internal::ParseBytesLiteral(value); if (bytes_literal.ok()) { 
return Constant(BytesConstant(*bytes_literal)); } } return Constant(BytesConstant(value)); } case ConstantKindCase::kString: return Constant(StringConstant(value)); case ConstantKindCase::kDuration: { // Duration is deprecated as a builtin type, but still supported for // compatibility. absl::Duration duration_value; if (!absl::ParseDuration(value, &duration_value)) { return YamlError(yaml, node, "Failed to parse duration constant"); } return Constant(duration_value); } case ConstantKindCase::kTimestamp: { // Timestamp is deprecated as a builtin type, but still supported for // compatibility. absl::Time timestamp_value; std::string error; // Format: YYYY-MM-DDThh:mm:ssZ if (!absl::ParseTime("%Y-%m-%d%ET%H:%M:%E*SZ", value, ×tamp_value, &error)) { return YamlError( yaml, node, absl::StrCat("Failed to parse timestamp constant: ", error, " supported format: YYYY-MM-DDThh:mm:ssZ")); } return Constant(timestamp_value); } default: // This should never happen. return YamlError(yaml, node, "Constant type is not supported"); } } absl::Status ParseVariableConfigs(Config& config, absl::string_view yaml, const YAML::Node& root) { const YAML::Node variables = root["variables"]; if (!variables.IsDefined()) { return absl::OkStatus(); } if (!variables.IsSequence()) { return YamlError(yaml, variables, "Node 'variables' is not a sequence"); } for (const YAML::Node& variable : variables) { Config::VariableConfig variable_config; if (!variable || !variable.IsMap()) { return YamlError(yaml, variable, "Variable is not a map"); } const YAML::Node name = variable["name"]; if (!name || !name.IsScalar()) { return YamlError(yaml, name, "Variable name is not a string"); } variable_config.name = GetString(yaml, name); const YAML::Node description = variable["description"]; if (description.IsDefined()) { if (!description.IsScalar()) { return YamlError(yaml, description, "Variable description is not a string"); } variable_config.description = GetString(yaml, description); } 
CEL_ASSIGN_OR_RETURN(auto type_info, ParseTypeInfo(variable, yaml)); ConstantKindCase constant_kind_case = GetConstantKindCase(type_info.name); std::string value_str; YAML::Node value = variable["value"]; if (value.IsDefined()) { if (constant_kind_case == ConstantKindCase::kUnspecified) { return YamlError(yaml, value, absl::StrCat("Constant type '", type_info.name, "' is not supported")); } if (!value.IsScalar()) { return YamlError(yaml, value, "Variable value is not a scalar"); } if (IsBinary(value)) { CEL_ASSIGN_OR_RETURN(value_str, GetBinary(yaml, value)); } else { value_str = GetString(yaml, value); } } variable_config.type_info = type_info; if (constant_kind_case != ConstantKindCase::kUnspecified && !value_str.empty()) { CEL_ASSIGN_OR_RETURN( variable_config.value, ParseConstantValue(yaml, value, constant_kind_case, value_str)); } else if (constant_kind_case == ConstantKindCase::kNull) { variable_config.value = Constant(nullptr); } CEL_RETURN_IF_ERROR(config.AddVariableConfig(variable_config)); } return absl::OkStatus(); } absl::StatusOr ParseFunctionOverloadConfig( absl::string_view yaml, const YAML::Node& overload) { Config::FunctionOverloadConfig overload_config; if (!overload || !overload.IsMap()) { return YamlError(yaml, overload, "Function overload is not a map"); } const YAML::Node id = overload["id"]; if (id.IsDefined()) { if (!id.IsScalar()) { return YamlError(yaml, id, "Function overload id is not a string"); } overload_config.overload_id = GetString(yaml, id); } const YAML::Node examples = overload["examples"]; if (examples.IsDefined()) { if (!examples.IsSequence()) { return YamlError(yaml, examples, "Function overload examples is not a sequence"); } for (const YAML::Node& example : examples) { if (!example.IsScalar()) { return YamlError(yaml, example, "Function overload example is not a string"); } overload_config.examples.push_back(GetString(yaml, example)); } } const YAML::Node target = overload["target"]; if (target.IsDefined()) { if 
(!target.IsMap()) { return YamlError(yaml, target, "Function overload target is not a map"); } CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, ParseTypeInfo(target, yaml)); overload_config.is_member_function = true; overload_config.parameters.push_back(type_info); } const YAML::Node args = overload["args"]; if (args.IsDefined()) { if (!args.IsSequence()) { return YamlError(yaml, args, "Function overload args is not a sequence"); } for (const YAML::Node& arg : args) { if (!arg.IsMap()) { return YamlError(yaml, arg, "Function overload arg is not a map"); } CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, ParseTypeInfo(arg, yaml)); overload_config.parameters.push_back(type_info); } } const YAML::Node return_type = overload["return"]; if (return_type.IsDefined()) { if (!return_type.IsMap()) { return YamlError(yaml, return_type, "Function overload return type is not a map"); } CEL_ASSIGN_OR_RETURN(overload_config.return_type, ParseTypeInfo(return_type, yaml)); } return overload_config; } absl::Status ParseFunctionConfigs(Config& config, absl::string_view yaml, const YAML::Node& root) { const YAML::Node functions = root["functions"]; if (!functions.IsDefined()) { return absl::OkStatus(); } if (!functions.IsSequence()) { return YamlError(yaml, functions, "Node 'functions' is not a sequence"); } for (const YAML::Node& function : functions) { Config::FunctionConfig function_config; if (!function || !function.IsMap()) { return YamlError(yaml, function, "Function is not a map"); } const YAML::Node name = function["name"]; if (!name || !name.IsScalar()) { return YamlError(yaml, name, "Function name is not a string"); } function_config.name = GetString(yaml, name); const YAML::Node description = function["description"]; if (description.IsDefined()) { if (!description.IsScalar()) { return YamlError(yaml, description, "Function description is not a string"); } function_config.description = GetString(yaml, description); } const YAML::Node overloads = function["overloads"]; if 
(overloads.IsDefined()) { if (!overloads.IsSequence()) { return YamlError(yaml, overloads, "Function 'overloads' item is not a sequence"); } for (const YAML::Node& overload : overloads) { CEL_ASSIGN_OR_RETURN(Config::FunctionOverloadConfig overload_config, ParseFunctionOverloadConfig(yaml, overload)); function_config.overload_configs.push_back(std::move(overload_config)); } } CEL_RETURN_IF_ERROR(config.AddFunctionConfig(function_config)); } return absl::OkStatus(); } void EmitContainerConfig(const Config& env_config, YAML::Emitter& out) { const auto& container_config = env_config.GetContainerConfig(); if (container_config.IsEmpty()) { return; } out << YAML::Key << "container"; out << YAML::Value << YAML::DoubleQuoted << container_config.name; } void EmitExtensionConfigs(const Config& env_config, YAML::Emitter& out) { if (env_config.GetExtensionConfigs().empty()) { return; } // Sort the extensions to make the output deterministic. std::vector sorted_extensions = env_config.GetExtensionConfigs(); absl::c_sort(sorted_extensions, [](const Config::ExtensionConfig& a, const Config::ExtensionConfig& b) { return a.name < b.name; }); out << YAML::Key << "extensions"; out << YAML::Value << YAML::BeginSeq; for (const Config::ExtensionConfig& extension_config : sorted_extensions) { out << YAML::BeginMap; out << YAML::Key << "name"; out << YAML::Value << YAML::DoubleQuoted << extension_config.name; if (extension_config.version != Config::ExtensionConfig::kLatest) { out << YAML::Key << "version"; out << YAML::Value << extension_config.version; } out << YAML::EndMap; } out << YAML::EndSeq; } void EmitMacroList(YAML::Emitter& out, absl::string_view key, const absl::flat_hash_set& macros) { if (macros.empty()) { return; } out << YAML::Key << std::string(key); out << YAML::Value << YAML::BeginSeq; std::vector sorted_macros(macros.begin(), macros.end()); absl::c_sort(sorted_macros); for (const std::string& macro : sorted_macros) { out << YAML::Value << YAML::DoubleQuoted << macro; } 
out << YAML::EndSeq; } void EmitFunctionList( YAML::Emitter& out, absl::string_view key, const absl::flat_hash_set>& functions) { if (functions.empty()) { return; } // Build a map from function name to a vector of overload ids. // Using std::map ensures function names are sorted. std::map> function_overloads; for (const auto& pair : functions) { function_overloads[pair.first].push_back(pair.second); } out << YAML::Key << std::string(key) << YAML::Value << YAML::BeginSeq; for (auto const& [name, overloads] : function_overloads) { out << YAML::BeginMap; out << YAML::Key << "name"; out << YAML::Value << YAML::DoubleQuoted << name; // If the only overload is the empty string, it signifies that all overloads // of the function are included/excluded. In this case, we don't emit the // "overloads" key. Otherwise, emit the specific overloads. if (!(overloads.size() == 1 && overloads[0].empty())) { // Sort overloads for deterministic output. std::vector sorted_overloads = overloads; absl::c_sort(sorted_overloads); out << YAML::Key << "overloads" << YAML::Value << YAML::BeginSeq; for (const std::string& overload : sorted_overloads) { out << YAML::BeginMap; out << YAML::Key << "id"; out << YAML::Value << YAML::DoubleQuoted << overload; out << YAML::EndMap; } out << YAML::EndSeq; } out << YAML::EndMap; } out << YAML::EndSeq; } void EmitStandardLibraryConfig(const Config& env_config, YAML::Emitter& out) { const Config::StandardLibraryConfig& standard_library_config = env_config.GetStandardLibraryConfig(); if (standard_library_config.IsEmpty()) { return; } out << YAML::Key << "stdlib" << YAML::Value << YAML::BeginMap; if (standard_library_config.disable) { out << YAML::Key << "disable" << YAML::Value << true; } if (standard_library_config.disable_macros) { out << YAML::Key << "disable_macros" << YAML::Value << true; } EmitMacroList(out, "include_macros", standard_library_config.included_macros); EmitMacroList(out, "exclude_macros", standard_library_config.excluded_macros); 
EmitFunctionList(out, "include_functions", standard_library_config.included_functions); EmitFunctionList(out, "exclude_functions", standard_library_config.excluded_functions); out << YAML::EndMap; } void EmitTypeInfo(const Config::TypeInfo& type_info, YAML::Emitter& out) { // Note: the map is already started when this is called, so we don't emit // BeginMap here or EndMap at the end. out << YAML::Key << "type_name"; out << YAML::Value << YAML::DoubleQuoted << type_info.name; if (type_info.is_type_param) { out << YAML::Key << "is_type_param" << YAML::Value << true; } if (!type_info.params.empty()) { out << YAML::Key << "params" << YAML::Value << YAML::BeginSeq; for (const Config::TypeInfo& param : type_info.params) { out << YAML::BeginMap; EmitTypeInfo(param, out); out << YAML::EndMap; } out << YAML::EndSeq; } } void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out) { const auto& variable_configs = env_config.GetVariableConfigs(); if (variable_configs.empty()) { return; } // Sort variable_configs by name to ensure deterministic output. 
std::vector sorted_variable_configs = variable_configs; absl::c_sort(sorted_variable_configs, [](const Config::VariableConfig& a, const Config::VariableConfig& b) { return a.name < b.name; }); out << YAML::Key << "variables"; out << YAML::Value << YAML::BeginSeq; for (const Config::VariableConfig& variable_config : sorted_variable_configs) { out << YAML::BeginMap; out << YAML::Key << "name"; out << YAML::Value << YAML::DoubleQuoted << variable_config.name; if (!variable_config.description.empty()) { out << YAML::Key << "description"; out << YAML::Value << YAML::DoubleQuoted << variable_config.description; } EmitTypeInfo(variable_config.type_info, out); if (variable_config.value.has_value()) { const Constant& constant = variable_config.value; switch (constant.kind_case()) { case ConstantKindCase::kUnspecified: case ConstantKindCase::kNull: break; case ConstantKindCase::kBool: out << YAML::Key << "value" << YAML::Value << constant.bool_value(); break; case ConstantKindCase::kInt: out << YAML::Key << "value" << YAML::Value << constant.int_value(); break; case ConstantKindCase::kUint: out << YAML::Key << "value" << YAML::Value << constant.uint_value(); break; case ConstantKindCase::kDouble: out << YAML::Key << "value" << YAML::Value << constant.double_value(); break; case ConstantKindCase::kBytes: { out << YAML::Key << "value"; const std::string& bytes_value = constant.bytes_value(); std::string hex_escaped = "b\""; for (unsigned char byte : bytes_value) { absl::StrAppend(&hex_escaped, "\\x"); absl::StrAppendFormat(&hex_escaped, "%02x", byte); } absl::StrAppend(&hex_escaped, "\""); out << YAML::Value << hex_escaped; break; } case ConstantKindCase::kString: out << YAML::Key << "value"; out << YAML::Value << YAML::DoubleQuoted << constant.string_value(); break; case ConstantKindCase::kDuration: out << YAML::Key << "value" << YAML::Value; // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) out << absl::FormatDuration(constant.duration_value()); break; case 
ConstantKindCase::kTimestamp: out << YAML::Key << "value" << YAML::Value; out << absl::FormatTime( "%Y-%m-%d%ET%H:%M:%E*SZ", // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) constant.timestamp_value(), absl::UTCTimeZone()); break; } } out << YAML::EndMap; } out << YAML::EndSeq; } void EmitFunctionOverloadConfig( const Config::FunctionOverloadConfig& overload_config, YAML::Emitter& out) { out << YAML::BeginMap; out << YAML::Key << "id"; out << YAML::Value << YAML::DoubleQuoted << overload_config.overload_id; if (overload_config.is_member_function) { out << YAML::Key << "target" << YAML::Value; out << YAML::BeginMap; if (overload_config.parameters.empty()) { // This should never happen, but if it does, emit a dynamic type. EmitTypeInfo({.name = "dyn"}, out); } else { EmitTypeInfo(overload_config.parameters[0], out); } out << YAML::EndMap; if (overload_config.parameters.size() > 1) { out << YAML::Key << "args"; out << YAML::Value << YAML::BeginSeq; for (size_t i = 1; i < overload_config.parameters.size(); ++i) { out << YAML::BeginMap; EmitTypeInfo(overload_config.parameters[i], out); out << YAML::EndMap; } out << YAML::EndSeq; } } else { if (!overload_config.parameters.empty()) { out << YAML::Key << "args"; out << YAML::Value << YAML::BeginSeq; for (const Config::TypeInfo& parameter : overload_config.parameters) { out << YAML::BeginMap; EmitTypeInfo(parameter, out); out << YAML::EndMap; } out << YAML::EndSeq; } } out << YAML::Key << "return"; out << YAML::Value << YAML::BeginMap; EmitTypeInfo(overload_config.return_type, out); out << YAML::EndMap; out << YAML::EndMap; } void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out) { const std::vector& function_configs = env_config.GetFunctionConfigs(); if (function_configs.empty()) { return; } // Sort function_configs by name to ensure deterministic output. 
std::vector sorted_function_configs = function_configs; absl::c_sort(sorted_function_configs, [](const Config::FunctionConfig& a, const Config::FunctionConfig& b) { return a.name < b.name; }); out << YAML::Key << "functions"; out << YAML::Value << YAML::BeginSeq; for (const Config::FunctionConfig& function_config : sorted_function_configs) { out << YAML::BeginMap; out << YAML::Key << "name"; out << YAML::Value << YAML::DoubleQuoted << function_config.name; if (!function_config.description.empty()) { out << YAML::Key << "description"; out << YAML::Value << YAML::DoubleQuoted << function_config.description; } if (!function_config.overload_configs.empty()) { // Sort overloads for deterministic output. std::vector sorted_overloads = function_config.overload_configs; absl::c_sort(sorted_overloads, [](const Config::FunctionOverloadConfig& a, const Config::FunctionOverloadConfig& b) { for (size_t i = 0; i < a.parameters.size(); ++i) { // Order like this: foo(a), foo(a, b) if (i >= b.parameters.size()) { return false; } if (CompareTypeInfo(a.parameters[i], b.parameters[i])) { return true; } if (CompareTypeInfo(b.parameters[i], a.parameters[i])) { return false; } } return false; }); out << YAML::Key << "overloads" << YAML::Value << YAML::BeginSeq; for (const Config::FunctionOverloadConfig& overload_config : sorted_overloads) { EmitFunctionOverloadConfig(overload_config, out); } out << YAML::EndSeq; } out << YAML::EndMap; } out << YAML::EndSeq; } } // namespace absl::StatusOr EnvConfigFromYaml(const std::string& yaml) { Config config; CEL_ASSIGN_OR_RETURN(YAML::Node root, LoadYaml(yaml)); if (!root.IsDefined() || root.IsNull()) { return config; } if (!root.IsMap()) { return absl::InvalidArgumentError(FormatYamlErrorMessage( yaml, "Invalid CEL environment config YAML", root.Mark())); } CEL_RETURN_IF_ERROR(ParseName(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseContainerConfig(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseExtensionConfigs(config, yaml, root)); 
CEL_RETURN_IF_ERROR(ParseStandardLibraryConfig(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseVariableConfigs(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseFunctionConfigs(config, yaml, root)); return config; } void EnvConfigToYaml(const Config& env_config, std::ostream& os) { YAML::Emitter out(os); out.SetIndent(2); out << YAML::BeginMap; if (!env_config.GetName().empty()) { out << YAML::Key << "name"; out << YAML::Value << YAML::DoubleQuoted << env_config.GetName(); } EmitContainerConfig(env_config, out); EmitExtensionConfigs(env_config, out); EmitStandardLibraryConfig(env_config, out); EmitVariableConfigs(env_config, out); EmitFunctionConfigs(env_config, out); out << YAML::EndMap; } } // namespace cel ================================================ FILE: env/env_yaml.h ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_ #define THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_ #include #include #include "absl/status/statusor.h" #include "env/config.h" namespace cel { // EnvConfigFromYaml creates an environment configuration from a YAML string. // // To ensure safety, only pass trusted YAML input. yaml-cpp has some fuzz // coverage, but its security model is unclear. Additionally, callers should be // aware that improper CEL configuration can lead to unsafe or unpredictably // expensive expressions. 
absl::StatusOr EnvConfigFromYaml(const std::string& yaml); // EnvConfigToYaml serializes an environment configuration as a YAML string. void EnvConfigToYaml(const Config& env_config, std::ostream& os); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_ ================================================ FILE: env/env_yaml_test.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "env/env_yaml.h" #include #include #include #include #include #include #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "common/constant.h" #include "env/config.h" #include "internal/status_macros.h" #include "internal/testing.h" namespace cel { namespace { using ::absl_testing::StatusIs; using ::testing::AllOf; using ::testing::ElementsAreArray; using ::testing::Field; using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::SizeIs; using ::testing::UnorderedElementsAre; TEST(EnvYamlTest, ParseContainerConfig) { ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( container: "test.container" )yaml")); EXPECT_THAT(config.GetContainerConfig(), Field(&Config::ContainerConfig::name, "test.container")); } TEST(EnvYamlTest, 
ParseExtensionConfigs) { ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( extensions: - name: "math" version: latest - name: "optional" version: 2 - name: "strings" )yaml")); EXPECT_THAT(config.GetExtensionConfigs(), UnorderedElementsAre( AllOf(Field(&Config::ExtensionConfig::name, "math"), Field(&Config::ExtensionConfig::version, Config::ExtensionConfig::kLatest)), AllOf(Field(&Config::ExtensionConfig::name, "optional"), Field(&Config::ExtensionConfig::version, 2)), AllOf(Field(&Config::ExtensionConfig::name, "strings"), Field(&Config::ExtensionConfig::version, Config::ExtensionConfig::kLatest)))); } TEST(EnvYamlTest, DefaultExtensionConfigs) { ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( )yaml")); EXPECT_THAT(config.GetExtensionConfigs(), IsEmpty()); } TEST(EnvYamlTest, ParseStdlibConfig_ExclusionStyle) { ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( stdlib: disable: true disable_macros: true exclude_macros: - map - filter exclude_functions: - name: "_+_" overloads: - id: add_bytes - id: add_list - name: "matches" - name: "timestamp" overloads: - id: "string_to_timestamp" )yaml")); const auto& stdlib_config = config.GetStandardLibraryConfig(); EXPECT_TRUE(stdlib_config.disable); EXPECT_TRUE(stdlib_config.disable_macros); EXPECT_THAT(stdlib_config.excluded_macros, UnorderedElementsAre("map", "filter")); EXPECT_THAT(stdlib_config.included_macros, IsEmpty()); EXPECT_THAT( stdlib_config.excluded_functions, UnorderedElementsAre(std::make_pair("_+_", "add_bytes"), std::make_pair("_+_", "add_list"), std::make_pair("matches", ""), std::make_pair("timestamp", "string_to_timestamp"))) << " Actual stdlib config: " << stdlib_config; } TEST(EnvYamlTest, ParseStdlibConfig_InclusionStyle) { ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( stdlib: include_macros: - map - filter include_functions: - name: "_+_" overloads: - id: add_bytes - id: add_list - name: "matches" - name: "timestamp" overloads: - id: 
"string_to_timestamp" )yaml")); const auto& stdlib_config = config.GetStandardLibraryConfig(); EXPECT_THAT(stdlib_config.included_macros, UnorderedElementsAre("map", "filter")); EXPECT_THAT( stdlib_config.included_functions, UnorderedElementsAre(std::make_pair("_+_", "add_bytes"), std::make_pair("_+_", "add_list"), std::make_pair("matches", ""), std::make_pair("timestamp", "string_to_timestamp"))) << " Actual stdlib config: " << stdlib_config; } TEST(EnvYamlTest, ParseVariableConfigs) { ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( variables: - name: "msg" type_name: "google.expr.proto3.test.TestAllTypes" description: >- msg represents all possible type permutation which CEL understands from a proto perspective )yaml")); const Config::VariableConfig& variable_config = config.GetVariableConfigs()[0]; EXPECT_EQ(variable_config.name, "msg"); const auto& type_info = variable_config.type_info; EXPECT_EQ(type_info.name, "google.expr.proto3.test.TestAllTypes"); EXPECT_FALSE(type_info.is_type_param); EXPECT_THAT(type_info.params, IsEmpty()); EXPECT_EQ(variable_config.description, "msg represents all possible type permutation which CEL " "understands from a proto perspective"); } TEST(EnvYamlTest, ParseVariableConfigWithTypeParams) { ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( variables: - name: "dict" type_name: "map" params: - type_name: "string" - type_name: "A" is_type_param: true )yaml")); const Config::VariableConfig& variable_config = config.GetVariableConfigs()[0]; EXPECT_EQ(variable_config.name, "dict"); const auto& type_info = variable_config.type_info; EXPECT_EQ(type_info.name, "map"); EXPECT_FALSE(type_info.is_type_param); EXPECT_THAT(type_info.params, SizeIs(2)); EXPECT_EQ(type_info.params[0].name, "string"); EXPECT_FALSE(type_info.params[0].is_type_param); EXPECT_THAT(type_info.params[0].params, IsEmpty()); EXPECT_EQ(type_info.params[1].name, "A"); EXPECT_TRUE(type_info.params[1].is_type_param); 
EXPECT_THAT(type_info.params[1].params, IsEmpty()); } struct ParseConstantTestCase { std::string type_name; std::string value; std::string expected_error; // Empty if no error. Constant expected_constant; }; class EnvYamlParseConstantTest : public testing::TestWithParam {}; TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) { const ParseConstantTestCase& param = GetParam(); const std::string yaml = absl::StrFormat( R"yaml( variables: - name: "const" type_name: "%s" value: %s )yaml", param.type_name, param.value); absl::StatusOr status_or_config = EnvConfigFromYaml(yaml); if (!param.expected_error.empty()) { EXPECT_THAT(status_or_config, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(param.expected_error))); return; } ASSERT_OK_AND_ASSIGN(Config config, status_or_config); const Config::VariableConfig& variable_config = config.GetVariableConfigs()[0]; EXPECT_EQ(variable_config.name, "const"); EXPECT_EQ(variable_config.type_info.name, param.type_name) << " yaml: " << yaml; EXPECT_EQ(variable_config.value, param.expected_constant) << " yaml: " << yaml; } std::vector GetParseConstantTestCases() { return { ParseConstantTestCase{ .type_name = "null", .value = "\"\"", .expected_constant = Constant(nullptr), }, ParseConstantTestCase{ .type_name = "null", .value = "anything", .expected_error = "Failed to parse null constant", }, ParseConstantTestCase{ .type_name = "bool", .value = "TRUE", .expected_constant = Constant(true), }, ParseConstantTestCase{ .type_name = "bool", .value = "false", .expected_constant = Constant(false), }, ParseConstantTestCase{ .type_name = "bool", .value = "yes", .expected_error = "Failed to parse bool constant", }, ParseConstantTestCase{ .type_name = "int", .value = "42", .expected_constant = Constant(int64_t{42}), }, ParseConstantTestCase{ .type_name = "int", .value = "41.999", .expected_error = "Failed to parse int constant", }, ParseConstantTestCase{ .type_name = "uint", .value = "42", .expected_constant = Constant(uint64_t{42}), }, 
ParseConstantTestCase{ .type_name = "uint", .value = "42u", .expected_constant = Constant(uint64_t{42}), }, ParseConstantTestCase{ .type_name = "uint", .value = "-1", .expected_error = "Failed to parse uint constant", }, ParseConstantTestCase{ .type_name = "double", .value = "42.42", .expected_constant = Constant(42.42), }, ParseConstantTestCase{ .type_name = "double", .value = "abc", .expected_error = "Failed to parse double constant", }, ParseConstantTestCase{ .type_name = "bytes", .value = "abc", .expected_constant = Constant(BytesConstant("abc")), }, ParseConstantTestCase{ .type_name = "bytes", .value = "b\"\\xFF\\x00\\x01\"", .expected_constant = Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))), }, ParseConstantTestCase{ .type_name = "bytes", .value = "!!binary /wAB", .expected_constant = Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))), }, ParseConstantTestCase{ .type_name = "bytes", .value = "!!binary YWJj=", .expected_error = "Node 'YWJj=' is not a valid Base64 encoded binary", }, ParseConstantTestCase{ .type_name = "bytes", .value = "abc", .expected_constant = Constant(BytesConstant("abc")), }, ParseConstantTestCase{ .type_name = "string", .value = "abc", .expected_constant = Constant(StringConstant("abc")), }, ParseConstantTestCase{ .type_name = "string", .value = "\"\\\"abc\\\"\"", .expected_constant = Constant(StringConstant("\"abc\"")), }, ParseConstantTestCase{ .type_name = "duration", .value = "1s", .expected_constant = Constant(absl::Seconds(1)), }, ParseConstantTestCase{ .type_name = "duration", .value = "abc", .expected_error = "Failed to parse duration constant", }, ParseConstantTestCase{ .type_name = "timestamp", .value = "2023-01-01T00:00:00Z", .expected_constant = Constant(absl::FromUnixSeconds(1672531200)), }, ParseConstantTestCase{ .type_name = "timestamp", .value = "abc", .expected_error = "Failed to parse timestamp constant", }, }; } INSTANTIATE_TEST_SUITE_P(EnvYamlParseConstantTest, EnvYamlParseConstantTest, 
::testing::ValuesIn(GetParseConstantTestCases())); struct ParseFunctionTestCase { std::string yaml; Config::FunctionConfig expected_function_config; }; class EnvYamlParseFunctionTest : public testing::TestWithParam {}; void ExpectTypeInfoEqual(const Config::TypeInfo& actual, const Config::TypeInfo& expected) { EXPECT_EQ(actual.name, expected.name); EXPECT_EQ(actual.is_type_param, expected.is_type_param); ASSERT_THAT(actual.params, SizeIs(expected.params.size())); for (size_t i = 0; i < expected.params.size(); ++i) { ExpectTypeInfoEqual(actual.params[i], expected.params[i]); } } TEST_P(EnvYamlParseFunctionTest, EnvYamlParseFunction) { const ParseFunctionTestCase& param = GetParam(); ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(param.yaml)); ASSERT_THAT(config.GetFunctionConfigs(), SizeIs(1)); const Config::FunctionConfig& function_config = config.GetFunctionConfigs()[0]; const Config::FunctionConfig& expected = param.expected_function_config; EXPECT_EQ(function_config.name, expected.name); EXPECT_EQ(function_config.description, expected.description); ASSERT_THAT(function_config.overload_configs, SizeIs(expected.overload_configs.size())); for (size_t i = 0; i < expected.overload_configs.size(); ++i) { const auto& actual_overload = function_config.overload_configs[i]; const auto& expected_overload = expected.overload_configs[i]; EXPECT_EQ(actual_overload.overload_id, expected_overload.overload_id); EXPECT_THAT(actual_overload.examples, ElementsAreArray(expected_overload.examples)); EXPECT_EQ(actual_overload.is_member_function, expected_overload.is_member_function); ASSERT_THAT(actual_overload.parameters, SizeIs(expected_overload.parameters.size())); for (size_t j = 0; j < expected_overload.parameters.size(); ++j) { ExpectTypeInfoEqual(actual_overload.parameters[j], expected_overload.parameters[j]); } ExpectTypeInfoEqual(actual_overload.return_type, expected_overload.return_type); } } std::vector GetParseFunctionTestCases() { return { ParseFunctionTestCase{ 
.yaml = R"yaml( functions: - name: "isEmpty" description: |- determines whether a list is empty, or a string has no characters overloads: - id: "wrapper_string_isEmpty" examples: - "''.isEmpty() // true" target: type_name: "google.protobuf.StringValue" return: type_name: "bool" - id: "list_isEmpty" examples: - "[].isEmpty() // true" - "[1].isEmpty() // false" target: type_name: "list" params: - type_name: "T" is_type_param: true return: type_name: "bool" )yaml", .expected_function_config = { .name = "isEmpty", .description = "determines whether a list is empty,\nor a " "string has no characters", .overload_configs = { Config::FunctionOverloadConfig{ .overload_id = "wrapper_string_isEmpty", .examples = {"''.isEmpty() // true"}, .is_member_function = true, .parameters = {{.name = "google.protobuf.StringValue"}}, .return_type = {.name = "bool"}, }, Config::FunctionOverloadConfig{ .overload_id = "list_isEmpty", .examples = {"[].isEmpty() // true", "[1].isEmpty() // false"}, .is_member_function = true, .parameters = {{.name = "list", .params = {{.name = "T", .is_type_param = true}}}}, .return_type = {.name = "bool"}, }, }, }, }, ParseFunctionTestCase{ .yaml = R"yaml( functions: - name: "contains" overloads: - id: "global_contains" examples: - "contains([1, 2, 3], 2) // true" args: - type_name: "list" params: - type_name: "T" is_type_param: true - type_name: "T" is_type_param: true return: type_name: "bool" )yaml", .expected_function_config = { .name = "contains", .overload_configs = { Config::FunctionOverloadConfig{ .overload_id = "global_contains", .examples = {"contains([1, 2, 3], 2) // true"}, .is_member_function = false, .parameters = {{.name = "list", .params = {{.name = "T", .is_type_param = true}}}, {.name = "T", .is_type_param = true}}, .return_type = {.name = "bool"}, }, }, }, }, }; } INSTANTIATE_TEST_SUITE_P(EnvYamlParseFunctionTest, EnvYamlParseFunctionTest, ::testing::ValuesIn(GetParseFunctionTestCases())); struct ParseTestCase { std::string yaml; 
std::string expected_error; }; class EnvYamlParseTest : public testing::TestWithParam {}; TEST_P(EnvYamlParseTest, EnvYamlSyntaxError) { const ParseTestCase& param = GetParam(); absl::StatusOr config = EnvConfigFromYaml(param.yaml); EXPECT_THAT(config, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(param.expected_error))); } INSTANTIATE_TEST_SUITE_P( EnvYamlParseTest, EnvYamlParseTest, ::testing::Values( ParseTestCase{ .yaml = R"yaml( invalid yaml )yaml", .expected_error = "1:2: Invalid CEL environment config YAML\n" "| invalid yaml \n" "| ^", }, ParseTestCase{ .yaml = R"yaml( name: - error: "error" )yaml", .expected_error = "3:19: Node 'name' is not a string\n" "| - error: \"error\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( container: - error: "error" )yaml", .expected_error = "3:19: Node 'container' is not a string\n" "| - error: \"error\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( extensions: - name: "math" -name: "optional" - name: "other" )yaml", .expected_error = "5:21: end of map not found\n" "| - name: \"other\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( extensions: "bar" )yaml", .expected_error = "2:27: Node 'extensions' is not a sequence\n" "| extensions: \"bar\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( extensions: - name: - something: "bar" )yaml", .expected_error = "4:19: Extension name is not a string\n" "| - something: \"bar\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( extensions: - name: "math" version: last )yaml", .expected_error = "4:28: Extension 'math' version is not a valid " "number or 'latest'\n" "| version: last\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( extensions: - name: "math" version: -15 )yaml", .expected_error = "4:28: Extension 'math' version is not a valid " "number or 'latest'\n" "| version: -15\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( extensions: - name: "math" version: 1 - name: "math" version: 2 )yaml", .expected_error = "5:19: Extension 'math' version 1 is already " "included. 
Cannot also include version 2\n" "| - name: \"math\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( stdlib: "error" )yaml", .expected_error = "2:23: Standard library config ('stdlib') " "is not a map\n" "| stdlib: \"error\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( stdlib: disable: "error" )yaml", .expected_error = "3:26: Node 'disable' is not a boolean\n" "| disable: \"error\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( stdlib: disable_macros: "error" )yaml", .expected_error = "3:33: Node 'disable_macros' is not a boolean\n" "| disable_macros: \"error\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( stdlib: exclude_macros: "error" )yaml", .expected_error = "3:33: Node 'exclude_macros' is not a sequence\n" "| exclude_macros: \"error\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( stdlib: exclude_macros: - foo: "error" )yaml", .expected_error = "4:19: Entry in 'exclude_macros' " "is not a string\n" "| - foo: \"error\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( stdlib: include_functions: "error" )yaml", .expected_error = "3:36: Node 'include_functions' " "is not a sequence\n" "| include_functions: \"error\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( stdlib: include_functions: - "error" )yaml", .expected_error = "4:19: Entry in 'include_functions' " "is not a map\n" "| - \"error\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( stdlib: include_functions: - foo: "error" )yaml", .expected_error = "4:19: Function name in not specified in " "'include_functions'\n" "| - foo: \"error\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( stdlib: include_functions: - name: "foo" overloads: "error" )yaml", .expected_error = "5:30: Overloads in 'include_functions' entry " "is not a sequence\n" "| overloads: \"error\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( stdlib: include_functions: - name: "foo" overloads: - foo_string )yaml", .expected_error = "6:21: Overload in 'include_functions' entry " "is not a map\n" "| - foo_string\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( stdlib: 
include_functions: - name: "foo" overloads: - id: - foo_int64 )yaml", .expected_error = "7:21: Overload id in 'include_functions' entry " "is not a string\n" "| - foo_int64\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( variables: - name: - type_name: "opaque" )yaml", .expected_error = "4:19: Variable name is not a string\n" "| - type_name: \"opaque\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( variables: - name: "foo" type_name: - params: )yaml", .expected_error = "5:21: Node 'type_name' is not a string\n" "| - params:\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( variables: - name: "foo" type_name: "opaque" params: - type_name: "int" - type_name: "A" is_type_param: maybe )yaml", .expected_error = "8:38: Node 'is_type_param' is not a boolean\n" "| is_type_param: maybe\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( variables: - name: "foo" type_name: "uint" value: -1 )yaml", .expected_error = "5:26: Failed to parse uint constant\n" "| value: -1\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( functions: many )yaml", .expected_error = "2:26: Node 'functions' is not a sequence\n" "| functions: many\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( functions: - name: - overloads: )yaml", .expected_error = "4:19: Function name is not a string\n" "| - overloads:\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( functions: - name: "foo" overloads: "error" )yaml", .expected_error = "4:30: Function 'overloads' item " "is not a sequence\n" "| overloads: \"error\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( functions: - name: "foo" overloads: - id: - "error" )yaml", .expected_error = "6:25: Function overload id is not a string\n" "| - \"error\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( functions: - name: "foo" overloads: - id: "foo_int64" target: - "error" )yaml", .expected_error = "7:25: Function overload target is not a map\n" "| - \"error\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( functions: - name: "foo" overloads: - id: "foo_int64" target: type_name: "Foo" params: - type_name: - 
is_type_param: true )yaml", .expected_error = "10:31: Node 'type_name' is not a string\n" "| " "- is_type_param: true\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( functions: - name: "foo" overloads: - id: "foo_int64" args: "a bunch" )yaml", .expected_error = "6:29: Function overload args is not a sequence\n" "| args: \"a bunch\"\n" "| ^", }, ParseTestCase{ .yaml = R"yaml( functions: - name: "foo" overloads: - id: "foo_int64" return: "to sender" )yaml", .expected_error = "6:31: Function overload return type" " is not a map\n" "| return: \"to sender\"\n" "| ^", })); std::string Unindent(std::string_view yaml) { absl::string_view yaml_view = yaml; std::vector lines = absl::StrSplit(yaml_view, '\n'); int indent = -1; std::vector unindented_lines; for (auto& line : lines) { std::size_t pos = line.find_first_not_of(" \t"); if (pos == std::string::npos) { // Skip blank lines. continue; } if (indent == -1) { indent = pos; } if (pos >= indent) { unindented_lines.push_back(line.substr(indent)); } else { unindented_lines.push_back(line); } } return absl::StrJoin(unindented_lines, "\n"); } struct ExportTestCase { absl::StatusOr config; std::string expected_yaml; }; class EnvYamlExportTest : public testing::TestWithParam {}; TEST_P(EnvYamlExportTest, EnvYamlExport) { const ExportTestCase& param = GetParam(); ASSERT_OK_AND_ASSIGN(Config config, param.config); std::stringstream ss; EnvConfigToYaml(config, ss); std::string yaml_output = Unindent(ss.str()); std::string expected_yaml = Unindent(param.expected_yaml); EXPECT_EQ(yaml_output, expected_yaml); } std::vector GetExportTestCases() { return { ExportTestCase{ .config = []() { Config config; config.SetName("test.env"); config.SetContainerConfig({.name = "test.container"}); return config; }(), .expected_yaml = R"yaml( name: "test.env" container: "test.container" )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; config.SetName("test.env"); config.SetContainerConfig({.name = "test.container"}); return 
config; }(), .expected_yaml = R"yaml( name: "test.env" container: "test.container" )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR(config.AddExtensionConfig("math")); CEL_RETURN_IF_ERROR(config.AddExtensionConfig("optional", 2)); CEL_RETURN_IF_ERROR(config.AddExtensionConfig("bindings")); return config; }(), .expected_yaml = R"yaml( extensions: - name: "bindings" - name: "math" - name: "optional" version: 2 )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR( config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ .disable = true, })); return config; }(), .expected_yaml = R"yaml( stdlib: disable: true )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR( config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ .disable_macros = true, })); return config; }(), .expected_yaml = R"yaml( stdlib: disable_macros: true )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR( config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ .excluded_macros = {"map", "filter"}, })); return config; }(), .expected_yaml = R"yaml( stdlib: exclude_macros: - "filter" - "map" )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR( config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ .included_macros = {"map", "filter"}, })); return config; }(), .expected_yaml = R"yaml( stdlib: include_macros: - "filter" - "map" )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR( config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ .excluded_functions = { std::make_pair("timestamp", "string_to_timestamp"), std::make_pair("_+_", "add_list"), std::make_pair("matches", ""), std::make_pair("_+_", "add_bytes"), }, })); return config; }(), .expected_yaml = R"yaml( stdlib: exclude_functions: - name: "_+_" 
overloads: - id: "add_bytes" - id: "add_list" - name: "matches" - name: "timestamp" overloads: - id: "string_to_timestamp" )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR( config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ .included_functions = { std::make_pair("timestamp", "string_to_timestamp"), std::make_pair("_+_", "add_list"), std::make_pair("matches", ""), std::make_pair("_+_", "add_bytes"), }, })); return config; }(), .expected_yaml = R"yaml( stdlib: include_functions: - name: "_+_" overloads: - id: "add_bytes" - id: "add_list" - name: "matches" - name: "timestamp" overloads: - id: "string_to_timestamp" )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR( config.AddVariableConfig({.name = "foo", .type_info = {.name = "null"}, .value = Constant(nullptr)})); return config; }(), .expected_yaml = R"yaml( variables: - name: "foo" type_name: "null" )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR( config.AddVariableConfig({.name = "foo", .type_info = {.name = "bool"}, .value = Constant(true)})); return config; }(), .expected_yaml = R"yaml( variables: - name: "foo" type_name: "bool" value: true )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR( config.AddVariableConfig({.name = "foo", .type_info = {.name = "int"}, .value = Constant(int64_t{42})})); return config; }(), .expected_yaml = R"yaml( variables: - name: "foo" type_name: "int" value: 42 )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR( config.AddVariableConfig({.name = "foo", .type_info = {.name = "uint"}, .value = Constant(uint64_t{777})})); return config; }(), .expected_yaml = R"yaml( variables: - name: "foo" type_name: "uint" value: 777 )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR( 
config.AddVariableConfig({.name = "foo", .type_info = {.name = "double"}, .value = Constant(0.75)})); return config; }(), .expected_yaml = R"yaml( variables: - name: "foo" type_name: "double" value: 0.75 )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR(config.AddVariableConfig( {.name = "foo", .type_info = {.name = "bytes"}, .value = Constant( BytesConstant(absl::string_view("\xff\x00\x01", 3)))})); return config; }(), .expected_yaml = R"yaml( variables: - name: "foo" type_name: "bytes" value: b"\xff\x00\x01" )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; Constant c; c.set_string_value("'single' \"double\""); CEL_RETURN_IF_ERROR(config.AddVariableConfig( {.name = "foo", .type_info = {.name = "string"}, .value = Constant(StringConstant("'single' \"double\""))})); return config; }(), .expected_yaml = R"yaml( variables: - name: "foo" type_name: "string" value: "'single' \"double\"" )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR(config.AddVariableConfig( {.name = "foo", .type_info = {.name = "duration"}, .value = Constant(absl::Hours(1) + absl::Minutes(2) + absl::Seconds(3))})); return config; }(), .expected_yaml = R"yaml( variables: - name: "foo" type_name: "duration" value: 1h2m3s )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR(config.AddVariableConfig( {.name = "foo", .type_info = {.name = "timestamp"}, .value = Constant(absl::FromUnixSeconds(1767323045))})); return config; }(), .expected_yaml = R"yaml( variables: - name: "foo" type_name: "timestamp" value: 2026-01-02T03:04:05Z )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR(config.AddVariableConfig( {.name = "foo", .type_info = {.name = "google.expr.proto3.test.TestAllTypes"}})); return config; }(), .expected_yaml = R"yaml( variables: - name: "foo" type_name: 
"google.expr.proto3.test.TestAllTypes" )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR(config.AddVariableConfig( {.name = "foo", .type_info = { .name = "A", .params = {{.name = "int"}, {.name = "B", .is_type_param = true}}}})); return config; }(), .expected_yaml = R"yaml( variables: - name: "foo" type_name: "A" params: - type_name: "int" - type_name: "B" is_type_param: true )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR(config.AddFunctionConfig({.name = "foo"})); return config; }(), .expected_yaml = R"yaml( functions: - name: "foo" )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR(config.AddFunctionConfig( {.name = "foo", .overload_configs = { {.overload_id = "foo_overload_id", .is_member_function = true, .parameters = {{.name = "timestamp"}, {.name = "A", .params = {{.name = "B"}}}}, .return_type = {.name = "int"}}, }})); return config; }(), .expected_yaml = R"yaml( functions: - name: "foo" overloads: - id: "foo_overload_id" target: type_name: "timestamp" args: - type_name: "A" params: - type_name: "B" return: type_name: "int" )yaml", }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; CEL_RETURN_IF_ERROR(config.AddFunctionConfig( {.name = "foo", .overload_configs = { {.overload_id = "foo_overload_a", .parameters = {{.name = "timestamp"}}, .return_type = {.name = "list", .params = {{.name = "int"}}}}, {.overload_id = "foo_overload_b", .parameters = {{.name = "double"}, {.name = "A", .params = {{.name = "B"}}}}, .return_type = {.name = "string"}}, }})); return config; }(), .expected_yaml = R"yaml( functions: - name: "foo" overloads: - id: "foo_overload_b" args: - type_name: "double" - type_name: "A" params: - type_name: "B" return: type_name: "string" - id: "foo_overload_a" args: - type_name: "timestamp" return: type_name: "list" params: - type_name: "int" )yaml", }, }; }; 
// Runs EnvYamlExportTest once for every case returned by GetExportTestCases().
INSTANTIATE_TEST_SUITE_P(EnvYamlExportTest, EnvYamlExportTest, ::testing::ValuesIn(GetExportTestCases()));
// Round-trip test: each parameter is a YAML document. It is unindented with
// Unindent(), parsed into a Config with EnvConfigFromYaml(), serialized back
// with EnvConfigToYaml(), and the serialized text must equal the unindented
// input exactly.
class EnvYamlRoundTripTest : public testing::TestWithParam {}; TEST_P(EnvYamlRoundTripTest, EnvYamlRoundTrip) { const std::string& yaml = Unindent(GetParam()); ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(yaml)); std::stringstream ss; EnvConfigToYaml(config, ss); EXPECT_EQ(ss.str(), yaml); }
// YAML documents that must survive a parse/serialize round trip unchanged.
// NOTE(review): because the comparison above is exact, each document must
// already be in the form EnvConfigToYaml() emits.
std::vector GetRoundTripTestCases() { return { R"yaml( stdlib: disable: true disable_macros: true )yaml", R"yaml( name: "test.env" container: "common.proto.prefix" extensions: - name: "math" version: 0 - name: "optional" version: 2 stdlib: include_macros: - "filter" - "map" include_functions: - name: "_+_" overloads: - id: "add_bytes" - id: "add_list" - name: "matches" - name: "timestamp" overloads: - id: "string_to_timestamp" )yaml", R"yaml( extensions: - name: "bindings" - name: "math" stdlib: exclude_macros: - "filter" - "map" exclude_functions: - name: "_+_" overloads: - id: "add_bytes" - id: "add_list" - name: "matches" - name: "timestamp" overloads: - id: "string_to_timestamp" )yaml", R"yaml( variables: - name: "a" type_name: "null" - name: "b" type_name: "bool" value: true - name: "c" type_name: "int" value: 42 - name: "d" type_name: "uint" value: 777 - name: "e" type_name: "double" value: 0.75 - name: "f" type_name: "bytes" value: b"\xff\x00\x01" - name: "g" type_name: "string" value: "plain 'single' \"double\"" - name: "h" type_name: "duration" value: 1h2m3s - name: "i" type_name: "timestamp" value: 2026-01-02T03:04:05Z )yaml", R"yaml( functions: - name: "bar" - name: "foo" )yaml", R"yaml( functions: - name: "foo" overloads: - id: "foo_overload_id" target: type_name: "timestamp" args: - type_name: "A" params: - type_name: "B" return: type_name: "int" )yaml", R"yaml( functions: - name: "foo" overloads: - id: "foo_overload_id" args: - type_name: "timestamp" - type_name: "A" params: - type_name: "B" return: type_name: "list" params: - type_name: "int"
)yaml", }; } INSTANTIATE_TEST_SUITE_P(EnvYamlRoundTripTest, EnvYamlRoundTripTest, ::testing::ValuesIn(GetRoundTripTestCases())); } // namespace } // namespace cel ================================================ FILE: env/internal/BUILD ================================================ # Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") package(default_visibility = ["//visibility:public"]) cc_library( name = "ext_registry", srcs = ["ext_registry.cc"], hdrs = ["ext_registry.h"], deps = [ "//compiler", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], ) cc_library( name = "runtime_ext_registry", srcs = ["runtime_ext_registry.cc"], hdrs = ["runtime_ext_registry.h"], deps = [ "//runtime:runtime_builder", "//runtime:runtime_options", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], ) cc_test( name = "ext_registry_test", srcs = ["ext_registry_test.cc"], deps = [ ":ext_registry", "//checker:type_checker_builder", "//compiler", "//internal:testing", "//parser:parser_interface", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", ], ) cc_test( name = "runtime_ext_registry_test", srcs = ["runtime_ext_registry_test.cc"], deps = [ ":runtime_ext_registry",
"//common:ast", "//common:source", "//common:value", "//common:value_testing", "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", "//parser", "//parser:options", "//parser:parser_interface", "//runtime", "//runtime:activation", "//runtime:function", "//runtime:function_adapter", "//runtime:function_registry", "//runtime:runtime_builder", "//runtime:runtime_builder_factory", "//runtime:runtime_options", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], ) ================================================ FILE: env/internal/ext_registry.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "env/internal/ext_registry.h"

#include 

#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "compiler/compiler.h"

namespace cel {
namespace env_internal {

// Appends a new registration. No de-duplication is performed; when several
// registrations share a name/alias and version, lookups use the earliest one.
void ExtensionRegistry::RegisterCompilerLibrary(
    absl::string_view name, absl::string_view alias, int version,
    absl::AnyInvocable library_factory) {
  library_registry_.emplace_back(name, alias, version,
                                 std::move(library_factory));
}

// Resolves `name` (matched against the canonical name or the alias) at
// `version` and builds the library from its factory. For kLatest the highest
// registered version is used; registrations with negative versions are never
// chosen in that mode. Returns NotFound when nothing matches.
absl::StatusOr ExtensionRegistry::GetCompilerLibrary(
    absl::string_view name, int version) const {
  const auto answers_to = [name](const LibraryRegistration& entry) {
    return entry.name_ == name || entry.alias_ == name;
  };

  if (version != kLatest) {
    for (const LibraryRegistration& entry : library_registry_) {
      if (answers_to(entry) && entry.version_ == version) {
        return entry.GetLibrary();
      }
    }
    return absl::NotFoundError(
        absl::StrCat("CompilerLibrary not registered: ", name, "#", version));
  }

  // Keep the first registration carrying the highest version seen so far;
  // the strict comparison means earlier registrations win ties.
  const LibraryRegistration* newest = nullptr;
  int newest_version = -1;
  for (const LibraryRegistration& entry : library_registry_) {
    if (answers_to(entry) && entry.version_ > newest_version) {
      newest = &entry;
      newest_version = entry.version_;
    }
  }
  if (newest == nullptr) {
    return absl::NotFoundError(
        absl::StrCat("CompilerLibrary not registered: ", name));
  }
  return newest->GetLibrary();
}

}  // namespace env_internal
}  // namespace cel

================================================
FILE: env/internal/ext_registry.h
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ #define THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ #include #include #include #include #include "absl/functional/any_invocable.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "compiler/compiler.h" namespace cel { namespace env_internal { // A registry for CEL compiler extension libraries. // // Used to register and retrieve CompilerLibraries by name (or alias) and // version. class ExtensionRegistry { public: static constexpr int kLatest = std::numeric_limits::max(); void RegisterCompilerLibrary( absl::string_view name, absl::string_view alias, int version, absl::AnyInvocable library_factory); absl::StatusOr GetCompilerLibrary(absl::string_view name, int version) const; private: class LibraryRegistration final { public: LibraryRegistration( absl::string_view name, absl::string_view alias, int version, absl::AnyInvocable library_factory) : name_(name), alias_(!alias.empty() ? alias : name), version_(version), factory_(std::move(library_factory)) {} CompilerLibrary GetLibrary() const { return factory_(); } private: std::string name_; std::string alias_; int version_; absl::AnyInvocable factory_; friend class ExtensionRegistry; }; std::vector library_registry_; }; } // namespace env_internal } // namespace cel #endif // THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ ================================================ FILE: env/internal/ext_registry_test.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "env/internal/ext_registry.h" #include #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "checker/type_checker_builder.h" #include "compiler/compiler.h" #include "internal/testing.h" #include "parser/parser_interface.h" namespace cel::env_internal { namespace { using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::testing::Field; using ::testing::HasSubstr; TEST(ExtensionRegistryTest, GetCompilerLibrary) { ExtensionRegistry registry; registry.RegisterCompilerLibrary("foo1", "f", 1, []() { return CompilerLibrary("foo1_1", nullptr, nullptr); }); registry.RegisterCompilerLibrary("foo1", "f", 2, []() { return CompilerLibrary("foo1_2", nullptr, nullptr); }); registry.RegisterCompilerLibrary("foo2", "", 1, []() { return CompilerLibrary("foo2_1", nullptr, nullptr); }); EXPECT_THAT(registry.GetCompilerLibrary("foo1", 1), IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_1"))); EXPECT_THAT(registry.GetCompilerLibrary("f", 1), IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_1"))); EXPECT_THAT(registry.GetCompilerLibrary("foo1", 2), IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); EXPECT_THAT(registry.GetCompilerLibrary("foo1", ExtensionRegistry::kLatest), IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); EXPECT_THAT(registry.GetCompilerLibrary("f", ExtensionRegistry::kLatest), IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); EXPECT_THAT(registry.GetCompilerLibrary("foo2", 1), IsOkAndHolds(Field(&CompilerLibrary::id, "foo2_1"))); EXPECT_THAT(registry.GetCompilerLibrary("foo2", 
ExtensionRegistry::kLatest), IsOkAndHolds(Field(&CompilerLibrary::id, "foo2_1"))); EXPECT_THAT(registry.GetCompilerLibrary("foo1", 3), StatusIs(absl::StatusCode::kNotFound, HasSubstr("CompilerLibrary not registered: foo1#3"))); EXPECT_THAT(registry.GetCompilerLibrary("foo3", 1), StatusIs(absl::StatusCode::kNotFound, HasSubstr("CompilerLibrary not registered: foo3"))); EXPECT_THAT(registry.GetCompilerLibrary("foo3", ExtensionRegistry::kLatest), StatusIs(absl::StatusCode::kNotFound, HasSubstr("CompilerLibrary not registered: foo3"))); } } // namespace } // namespace cel::env_internal ================================================ FILE: env/internal/runtime_ext_registry.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "env/internal/runtime_ext_registry.h"

#include 

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"

namespace cel {
namespace env_internal {

// Appends a new registration; duplicates are not rejected, and lookups use
// the earliest matching entry.
void RuntimeExtensionRegistry::AddFunctionRegistration(
    absl::string_view name, absl::string_view alias, int version,
    FunctionRegistrationCallback function_registration_callback) {
  registry_.push_back(Registration(name, alias, version,
                                   std::move(function_registration_callback)));
}

// Finds the registration for `name` (canonical name or alias) at `version`
// and runs its callback against `runtime_builder`. With `version == kLatest`
// the highest registered version is used. Returns NotFound when nothing
// matches; otherwise returns whatever status the callback produces.
absl::Status RuntimeExtensionRegistry::RegisterExtensionFunctions(
    RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options,
    absl::string_view name, int version) const {
  if (version == kLatest) {
    int max_version = -1;
    for (const Registration& registration : registry_) {
      if ((registration.name_ == name || registration.alias_ == name) &&
          registration.version_ > max_version) {
        max_version = registration.version_;
      }
    }
    if (max_version == -1) {
      return absl::NotFoundError(absl::StrCat(
          "Runtime functions are not registered for extension: ", name));
    }
    version = max_version;
  }
  for (const Registration& registration : registry_) {
    if ((registration.name_ == name || registration.alias_ == name) &&
        registration.version_ == version) {
      return registration.RegisterExtensionFunctions(runtime_builder,
                                                     runtime_options);
    }
  }
  // Report the requested version (as ExtensionRegistry::GetCompilerLibrary
  // does) so that an extension registered only at other versions can be told
  // apart from an entirely unknown extension.
  return absl::NotFoundError(
      absl::StrCat("Runtime functions are not registered for extension: ",
                   name, "#", version));
}

}  // namespace env_internal
}  // namespace cel

================================================
FILE: env/internal/runtime_ext_registry.h
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ #define THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ #include #include #include #include #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" namespace cel { namespace env_internal { using FunctionRegistrationCallback = absl::AnyInvocable; // A registry for CEL runtime extension functions. // // Used to register runtime functions for extensions by name (or alias) and // version. class RuntimeExtensionRegistry { public: static constexpr int kLatest = std::numeric_limits::max(); void AddFunctionRegistration( absl::string_view name, absl::string_view alias, int version, FunctionRegistrationCallback function_registration_callback); absl::Status RegisterExtensionFunctions(RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options, absl::string_view name, int version) const; private: class Registration final { public: Registration(absl::string_view name, absl::string_view alias, int version, FunctionRegistrationCallback function_registration_callback) : name_(name), alias_(!alias.empty() ? 
alias : name), version_(version), function_registration_callback_( std::move(function_registration_callback)) {} absl::Status RegisterExtensionFunctions( RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options) const { return function_registration_callback_(runtime_builder, runtime_options); } private: std::string name_; std::string alias_; int version_; FunctionRegistrationCallback function_registration_callback_; friend class RuntimeExtensionRegistry; }; std::vector registry_; }; } // namespace env_internal } // namespace cel #endif // THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ ================================================ FILE: env/internal/runtime_ext_registry_test.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "env/internal/runtime_ext_registry.h"

#include 
#include 
#include 
#include 

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "common/ast.h"
#include "common/source.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "parser/options.h"
#include "parser/parser.h"
#include "parser/parser_interface.h"
#include "runtime/activation.h"
#include "runtime/function.h"
#include "runtime/function_adapter.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_builder_factory.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"

namespace cel::env_internal {
namespace {

using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::cel::test::StringValueIs;

// `hello` implementation registered by version 1 of the test extension.
Value Hello1(const StringValue& input, const Function::InvokeContext& context) {
  return StringValue::From("Hello, old " + input.ToString() + "!",
                           context.arena());
}

// `hello` implementation registered by version 2 of the test extension.
Value Hello2(const StringValue& input, const Function::InvokeContext& context) {
  return StringValue::From("Hello, new " + input.ToString() + "!",
                           context.arena());
}

// Builds a registry with two versions of "hello_extension" (alias
// "hello_extension_alias"): version 1 registers `hello` as a global overload
// and version 2 registers it as a member overload.
RuntimeExtensionRegistry GetRuntimeExtensionRegistry() {
  RuntimeExtensionRegistry registry;
  registry.AddFunctionRegistration(
      "hello_extension", "hello_extension_alias", 1,
      [](RuntimeBuilder& runtime_builder,
         const RuntimeOptions& runtime_options) -> absl::Status {
        return cel::UnaryFunctionAdapter::RegisterGlobalOverload(
            "hello", &Hello1, runtime_builder.function_registry());
      });
  registry.AddFunctionRegistration(
      "hello_extension", "hello_extension_alias", 2,
      [](RuntimeBuilder& runtime_builder,
         const RuntimeOptions& runtime_options) -> absl::Status {
        return cel::UnaryFunctionAdapter::RegisterMemberOverload(
            "hello", &Hello2, runtime_builder.function_registry());
      });
  return registry;
}

class RuntimeExtensionRegistryTest : public testing::Test {
 protected:
  // Parses `expr`, registers `extension_name` at `version` from
  // GetRuntimeExtensionRegistry() into a fresh runtime, then evaluates the
  // program with an empty activation. Any failing step's status is returned.
  absl::StatusOr Run(std::string_view extension_name, int version,
                     std::string_view expr) {
    const RuntimeExtensionRegistry registry = GetRuntimeExtensionRegistry();
    CEL_ASSIGN_OR_RETURN(std::unique_ptr parser,
                         NewParserBuilder(ParserOptions())->Build());
    CEL_ASSIGN_OR_RETURN(std::unique_ptr source, NewSource(expr, ""));
    CEL_ASSIGN_OR_RETURN(std::unique_ptr ast, parser->Parse(*source));
    auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool();
    cel::RuntimeOptions runtime_options;
    CEL_ASSIGN_OR_RETURN(
        cel::RuntimeBuilder runtime_builder,
        cel::CreateRuntimeBuilder(descriptor_pool, runtime_options));
    CEL_RETURN_IF_ERROR(registry.RegisterExtensionFunctions(
        runtime_builder, runtime_options, extension_name, version));
    CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime,
                         std::move(runtime_builder).Build());
    CEL_ASSIGN_OR_RETURN(std::unique_ptr program,
                         runtime->CreateProgram(std::move(ast)));
    Activation activation;
    return program->Evaluate(&arena_, activation);
  }

 private:
  google::protobuf::Arena arena_;
};

TEST_F(RuntimeExtensionRegistryTest, SpecificExtensionVersion) {
  EXPECT_THAT(Run("hello_extension", 1, "hello('world')"),
              IsOkAndHolds(StringValueIs("Hello, old world!")));
}

TEST_F(RuntimeExtensionRegistryTest, LatestExtensionVersion) {
  EXPECT_THAT(Run("hello_extension_alias", RuntimeExtensionRegistry::kLatest,
                  "'world'.hello()"),
              IsOkAndHolds(StringValueIs("Hello, new world!")));
}

// An extension name that was never registered is rejected for both explicit
// versions and kLatest.
TEST_F(RuntimeExtensionRegistryTest, UnknownExtensionIsNotFound) {
  EXPECT_THAT(Run("no_such_extension", 1, "true"),
              StatusIs(absl::StatusCode::kNotFound));
  EXPECT_THAT(
      Run("no_such_extension", RuntimeExtensionRegistry::kLatest, "true"),
      StatusIs(absl::StatusCode::kNotFound));
}

// A known extension requested at a version that was never registered.
TEST_F(RuntimeExtensionRegistryTest, UnregisteredVersionIsNotFound) {
  EXPECT_THAT(Run("hello_extension", 3, "true"),
              StatusIs(absl::StatusCode::kNotFound));
}

}  // namespace
}  // namespace cel::env_internal

================================================
FILE: env/runtime_std_extensions.cc
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "env/runtime_std_extensions.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "checker/optional.h" #include "env/env_runtime.h" #include "env/internal/runtime_ext_registry.h" #include "extensions/encoders.h" #include "extensions/lists_functions.h" #include "extensions/math_ext.h" #include "extensions/math_ext_decls.h" #include "extensions/regex_ext.h" #include "extensions/sets_functions.h" #include "extensions/strings.h" #include "runtime/optional_types.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" namespace cel { void RegisterStandardExtensions(EnvRuntime& env_runtime) { env_internal::RuntimeExtensionRegistry& registry = env_runtime.GetRuntimeExtensionRegistry(); registry.AddFunctionRegistration( "cel.lib.ext.bindings", "bindings", 0, [](RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options) -> absl::Status { // No runtime functions to register. 
return absl::OkStatus(); }); registry.AddFunctionRegistration( "cel.lib.ext.encoders", "encoders", 0, [](RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options) -> absl::Status { return cel::extensions::RegisterEncodersFunctions( runtime_builder.function_registry(), runtime_options); }); for (int version = 0; version <= extensions::kListsExtensionLatestVersion; ++version) { registry.AddFunctionRegistration( "cel.lib.ext.lists", "lists", version, [version](RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options) -> absl::Status { return cel::extensions::RegisterListsFunctions( runtime_builder.function_registry(), runtime_options, version); }); } for (int version = 0; version <= extensions::kMathExtensionLatestVersion; ++version) { registry.AddFunctionRegistration( "cel.lib.ext.math", "math", version, [version](RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options) -> absl::Status { return cel::extensions::RegisterMathExtensionFunctions( runtime_builder.function_registry(), runtime_options, version); }); } for (int version = 0; version <= cel::kOptionalExtensionLatestVersion; ++version) { registry.AddFunctionRegistration( "cel.lib.ext.optional", "optional", version, [](RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options) -> absl::Status { return cel::extensions::EnableOptionalTypes(runtime_builder); }); } registry.AddFunctionRegistration( "cel.lib.ext.protos", "protos", 0, [](RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options) -> absl::Status { // No runtime functions to register. 
return absl::OkStatus(); }); registry.AddFunctionRegistration( "cel.lib.ext.sets", "sets", 0, [](RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options) -> absl::Status { return cel::extensions::RegisterSetsFunctions( runtime_builder.function_registry(), runtime_options); }); for (int version = 0; version <= extensions::kStringsExtensionLatestVersion; ++version) { registry.AddFunctionRegistration( "cel.lib.ext.strings", "strings", version, [version](RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options) -> absl::Status { cel::extensions::StringsExtensionOptions strings_options; strings_options.version = version; return cel::extensions::RegisterStringsFunctions( runtime_builder.function_registry(), runtime_options, strings_options); }); } registry.AddFunctionRegistration( "cel.lib.ext.comprev2", "two-var-comprehensions", 0, [](RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options) -> absl::Status { // No runtime functions to register. return absl::OkStatus(); }); registry.AddFunctionRegistration( "cel.lib.ext.regex", "regex", 0, [](RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options) -> absl::Status { return cel::extensions::RegisterRegexExtensionFunctions( runtime_builder); }); } } // namespace cel ================================================ FILE: env/runtime_std_extensions.h ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_ENV_RUNTIME_STD_EXTENSIONS_H_ #define THIRD_PARTY_CEL_CPP_ENV_RUNTIME_STD_EXTENSIONS_H_ #include "env/env_runtime.h" namespace cel { // Registers the standard CEL extension functions with the given environment // runtime. This makes them available, but does not enable them. See Env::Config // for how to enable extensions. // // Included in the standard runtime environment: // // - cel.lib.ext.bindings (alias: "bindings") // - cel.lib.ext.encoders (alias: "encoders") // - cel.lib.ext.lists (alias: "lists") // - cel.lib.ext.math (alias: "math") // - optional // - cel.lib.ext.protos (alias: "protos") // - cel.lib.ext.sets (alias: "sets") // - cel.lib.ext.strings (alias: "strings") // - cel.lib.ext.comprev2 (alias: "two-var-comprehensions") // // NOTE: Not included in the standard runtime environment yet - include manually // if needed: // - cel.lib.ext.regex (alias: "regex") // void RegisterStandardExtensions(EnvRuntime& env_runtime); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_ENV_RUNTIME_STD_EXTENSIONS_H_ ================================================ FILE: env/runtime_std_extensions_test.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "env/runtime_std_extensions.h" #include #include #include #include #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "checker/optional.h" #include "checker/validation_result.h" #include "common/ast.h" #include "common/value.h" #include "compiler/compiler.h" #include "env/config.h" #include "env/env.h" #include "env/env_runtime.h" #include "env/env_std_extensions.h" #include "extensions/lists_functions.h" #include "extensions/math_ext_decls.h" #include "extensions/strings.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "runtime/activation.h" #include "runtime/runtime.h" #include "google/protobuf/arena.h" namespace cel { namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::testing::IsEmpty; using ::testing::ValuesIn; struct TestCase { std::string extension_name; std::vector extension_versions = {0}; int latest_extension_version = 0; std::string expr; bool requires_optional_extension = false; }; using RuntimeStdExtensionTest = testing::TestWithParam; TEST_P(RuntimeStdExtensionTest, RegisterStandardExtensions) { const TestCase& param = GetParam(); Env env; env.SetDescriptorPool(cel::internal::GetSharedTestingDescriptorPool()); RegisterStandardExtensions(env); Config compiler_config; // For the compilation step, assume latest version of the extension to ensure // a successful compilation. Later, we will test the runtime with different // extension versions. 
ASSERT_THAT(compiler_config.AddExtensionConfig( param.extension_name, Config::ExtensionConfig::kLatest), IsOk()); env.SetConfig(compiler_config); ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(param.expr)); EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, result.ReleaseAst()); for (int version = 0; version <= param.latest_extension_version; ++version) { Config runtime_config; // Request a specific version of the extension to be configured in the // runtime. ASSERT_THAT( runtime_config.AddExtensionConfig(param.extension_name, version), IsOk()); if (param.requires_optional_extension) { ASSERT_THAT(runtime_config.AddExtensionConfig("optional"), IsOk()); } EnvRuntime env_runtime; env_runtime.SetDescriptorPool( cel::internal::GetSharedTestingDescriptorPool()); RegisterStandardExtensions(env_runtime); env_runtime.SetConfig(runtime_config); ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, env_runtime.NewRuntime()); absl::StatusOr> program_or = runtime->CreateProgram(std::make_unique(*ast)); // If the function is not supported in this extension version, check that // the program creation returned an error. if (!absl::c_contains(param.extension_versions, version)) { EXPECT_THAT(program_or, StatusIs(absl::StatusCode::kInvalidArgument)) << " expr: " << param.expr << " version: " << version; continue; } ASSERT_THAT(program_or, IsOk()) << " expr: " << param.expr << " version: " << version; std::unique_ptr program = *std::move(program_or); google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); EXPECT_TRUE(value.GetBool()) << " expr: " << param.expr << " version: " << version; } } std::vector GetRuntimeStdExtensionTestCases() { return { TestCase{ // The "bindings" extension does not register any runtime functions - // only macros. 
.extension_name = "bindings", .expr = "cel.bind(t, 42, t + 1) == 43", }, TestCase{ .extension_name = "encoders", .expr = "base64.encode(b'hello') == 'aGVsbG8='", }, TestCase{ .extension_name = "lists", .extension_versions = {0, 1, 2}, .latest_extension_version = extensions::kListsExtensionLatestVersion, .expr = "[3, 2, 1].slice(0, 1) == [3]", }, TestCase{ .extension_name = "lists", .extension_versions = {1, 2}, .latest_extension_version = extensions::kListsExtensionLatestVersion, .expr = "[[1, 2], 3].flatten() == [1, 2, 3]", }, TestCase{ .extension_name = "lists", .extension_versions = {2}, .latest_extension_version = extensions::kListsExtensionLatestVersion, .expr = "[3, 2, 1].sort() == [1, 2, 3]", }, TestCase{ .extension_name = "math", .extension_versions = {0, 1, 2}, .latest_extension_version = extensions::kMathExtensionLatestVersion, .expr = "math.least([1, -2, 3]) == -2", }, TestCase{ .extension_name = "math", .extension_versions = {1, 2}, .latest_extension_version = extensions::kMathExtensionLatestVersion, .expr = "math.floor(42.9) == 42.0", }, TestCase{ .extension_name = "math", .extension_versions = {2}, .latest_extension_version = extensions::kMathExtensionLatestVersion, .expr = "math.sqrt(4) == 2.0", }, TestCase{ .extension_name = "optional", .extension_versions = {0, 1, 2}, .latest_extension_version = kOptionalExtensionLatestVersion, .expr = "optional.of(1).hasValue()", }, TestCase{ // No runtime functions. 
.extension_name = "protos", .expr = "!proto.hasExt(cel.expr.conformance.proto2.TestAllTypes{}, " "cel.expr.conformance.proto2.nested_ext)", }, TestCase{ .extension_name = "sets", .expr = "sets.contains([1], [1])", }, TestCase{ .extension_name = "strings", .extension_versions = {0, 1, 2, 3, 4}, .latest_extension_version = extensions::kStringsExtensionLatestVersion, .expr = "'Hello, who!'.replace('who', 'World') == 'Hello, World!'", }, TestCase{ .extension_name = "strings", .extension_versions = {1, 2, 3, 4}, .latest_extension_version = extensions::kStringsExtensionLatestVersion, .expr = "strings.quote('hello') == '\"hello\"'", }, TestCase{ .extension_name = "strings", .extension_versions = {2, 3, 4}, .latest_extension_version = extensions::kStringsExtensionLatestVersion, .expr = "['hello', 'world'].join(', ') == 'hello, world'", }, TestCase{ .extension_name = "strings", .extension_versions = {3, 4}, .latest_extension_version = extensions::kStringsExtensionLatestVersion, .expr = "'stressed'.reverse() == 'desserts'", }, TestCase{ // No runtime functions. .extension_name = "cel.lib.ext.comprev2", .expr = "[1, 2, 3].map(i, i * 2) == [2, 4, 6]", }, TestCase{ .extension_name = "cel.lib.ext.regex", .expr = "regex.replace('abc', '$', '_end') == 'abc_end'", .requires_optional_extension = true, }, }; } INSTANTIATE_TEST_SUITE_P(RuntimeStdExtensionTest, RuntimeStdExtensionTest, ValuesIn(GetRuntimeStdExtensionTestCases())); } // namespace } // namespace cel ================================================ FILE: env/type_info.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "env/type_info.h" #include #include #include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "common/type.h" #include "common/type_kind.h" #include "env/config.h" #include "internal/status_macros.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { namespace { std::optional TypeNameToTypeKind(absl::string_view type_name) { // Excluded types: // kUnknown // kError // kTypeParam // kFunction // kEnum static const absl::NoDestructor< absl::flat_hash_map> kTypeNameToTypeKind({ {"null", TypeKind::kNull}, {"bool", TypeKind::kBool}, {"int", TypeKind::kInt}, {"uint", TypeKind::kUint}, {"double", TypeKind::kDouble}, {"string", TypeKind::kString}, {"bytes", TypeKind::kBytes}, {"timestamp", TypeKind::kTimestamp}, {TimestampType::kName, TypeKind::kTimestamp}, {"duration", TypeKind::kDuration}, {DurationType::kName, TypeKind::kDuration}, {"list", TypeKind::kList}, {"map", TypeKind::kMap}, {"", TypeKind::kDyn}, {"any", TypeKind::kAny}, {"dyn", TypeKind::kDyn}, {BoolWrapperType::kName, TypeKind::kBoolWrapper}, {"bool_wrapper", TypeKind::kBoolWrapper}, {IntWrapperType::kName, TypeKind::kIntWrapper}, {"int_wrapper", TypeKind::kIntWrapper}, {UintWrapperType::kName, TypeKind::kUintWrapper}, {"uint_wrapper", TypeKind::kUintWrapper}, {DoubleWrapperType::kName, TypeKind::kDoubleWrapper}, {"double_wrapper", TypeKind::kDoubleWrapper}, {StringWrapperType::kName, TypeKind::kStringWrapper}, {"string_wrapper", 
TypeKind::kStringWrapper}, {BytesWrapperType::kName, TypeKind::kBytesWrapper}, {"bytes_wrapper", TypeKind::kBytesWrapper}, {"type", TypeKind::kType}, }); if (auto it = kTypeNameToTypeKind->find(type_name); it != kTypeNameToTypeKind->end()) { return it->second; } return std::nullopt; } } // namespace absl::StatusOr TypeInfoToType( const Config::TypeInfo& type_info, const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::Arena* arena) { if (type_info.is_type_param) { return TypeParamType(type_info.name); } std::optional type_kind = TypeNameToTypeKind(type_info.name); if (!type_kind.has_value()) { if (type_info.params.empty() && descriptor_pool != nullptr) { const google::protobuf::Descriptor* type = descriptor_pool->FindMessageTypeByName(type_info.name); if (type != nullptr) { return Type::Message(type); } } // TODO(uncreated-issue/88): use a TypeIntrospector to validate opaque types std::vector parameter_types; for (const Config::TypeInfo& param : type_info.params) { CEL_ASSIGN_OR_RETURN(Type parameter_type, TypeInfoToType(param, descriptor_pool, arena)); parameter_types.push_back(parameter_type); } return OpaqueType(arena, type_info.name, parameter_types); } switch (*type_kind) { case TypeKind::kNull: return NullType(); case TypeKind::kBool: return BoolType(); case TypeKind::kInt: return IntType(); case TypeKind::kUint: return UintType(); case TypeKind::kDouble: return DoubleType(); case TypeKind::kString: return StringType(); case TypeKind::kBytes: return BytesType(); case TypeKind::kDuration: return DurationType(); case TypeKind::kTimestamp: return TimestampType(); case TypeKind::kList: { Type element_type; if (!type_info.params.empty()) { CEL_ASSIGN_OR_RETURN( element_type, TypeInfoToType(type_info.params[0], descriptor_pool, arena)); } else { element_type = DynType(); } return ListType(arena, element_type); } case TypeKind::kMap: { Type key_type = DynType(); Type value_type = DynType(); if (!type_info.params.empty()) { 
CEL_ASSIGN_OR_RETURN(key_type, TypeInfoToType(type_info.params[0], descriptor_pool, arena)); } if (type_info.params.size() > 1) { CEL_ASSIGN_OR_RETURN( value_type, TypeInfoToType(type_info.params[1], descriptor_pool, arena)); } return MapType(arena, key_type, value_type); } case TypeKind::kDyn: return DynType(); case TypeKind::kAny: return AnyType(); case TypeKind::kBoolWrapper: return BoolWrapperType(); case TypeKind::kIntWrapper: return IntWrapperType(); case TypeKind::kUintWrapper: return UintWrapperType(); case TypeKind::kDoubleWrapper: return DoubleWrapperType(); case TypeKind::kStringWrapper: return StringWrapperType(); case TypeKind::kBytesWrapper: return BytesWrapperType(); case TypeKind::kType: { if (type_info.params.empty()) { return TypeType(arena, DynType()); } CEL_ASSIGN_OR_RETURN(Type type, TypeInfoToType(type_info.params[0], descriptor_pool, arena)); return TypeType(arena, type); } default: return DynType(); } } } // namespace cel ================================================ FILE: env/type_info.h ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ #define THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ #include "absl/status/statusor.h" #include "common/type.h" #include "env/config.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { // Converts a Config::TypeInfo to a cel::Type. 
Returns an error if the type_info // cannot be converted to a known cel::Type, a list configured with more than // one parameter. absl::StatusOr TypeInfoToType( const Config::TypeInfo& type_info, const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::Arena* arena); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ ================================================ FILE: env/type_info_test.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "env/type_info.h"

#include <string>
#include <vector>

#include "common/type.h"
#include "common/type_proto.h"
#include "env/config.h"
#include "internal/proto_matchers.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/text_format.h"

namespace cel {
namespace {

using absl_testing::IsOk;
using testing::ValuesIn;

// A conversion scenario: the config TypeInfo to convert, and the
// text-format cel.expr.Type proto the converted type must serialize to.
struct TestCase {
  Config::TypeInfo type_info;
  std::string expected_type_pb;
};

using TypeInfoTest = testing::TestWithParam<TestCase>;

// Converts the parameter's TypeInfo against the testing descriptor pool and
// compares the proto form of the result with the expected text proto.
TEST_P(TypeInfoTest, TypeInfo) {
  const TestCase& tc = GetParam();

  cel::expr::Type want;
  ASSERT_TRUE(
      google::protobuf::TextFormat::ParseFromString(tc.expected_type_pb, &want));

  const google::protobuf::DescriptorPool* pool =
      cel::internal::GetTestingDescriptorPool();
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(cel::Type converted,
                       cel::TypeInfoToType(tc.type_info, pool, &arena));

  cel::expr::Type got;
  ASSERT_THAT(cel::TypeToProto(converted, &got), IsOk());
  EXPECT_THAT(got, cel::internal::test::EqualsProto(want));
}

std::vector<TestCase> GetTestCases() {
  std::vector<TestCase> cases;

  // Primitive name.
  cases.push_back(TestCase{
      .type_info = {.name = "int"},
      .expected_type_pb = "primitive: INT64",
  });
  // Parameterized list.
  cases.push_back(TestCase{
      .type_info = {.name = "list",
                    .params = {Config::TypeInfo{.name = "int"}}},
      .expected_type_pb = "list_type { elem_type { primitive: INT64 } }",
  });
  // Unparameterized list defaults to list(dyn).
  cases.push_back(TestCase{
      .type_info = {.name = "list"},
      .expected_type_pb = "list_type { elem_type { dyn {} }}",
  });
  // Fully parameterized map.
  cases.push_back(TestCase{
      .type_info = {.name = "map",
                    .params = {Config::TypeInfo{.name = "string"},
                               Config::TypeInfo{.name = "int"}}},
      .expected_type_pb = "map_type { key_type { primitive: STRING } "
                          "value_type { primitive: INT64 }}",
  });
  // Name resolved through the descriptor pool.
  cases.push_back(TestCase{
      .type_info = {.name = "cel.expr.conformance.proto2.TestAllTypes"},
      .expected_type_pb =
          "message_type: 'cel.expr.conformance.proto2.TestAllTypes'",
  });
  // Unknown name with a type-parameter argument becomes an abstract type.
  cases.push_back(TestCase{
      .type_info = {.name = "A",
                    .params = {Config::TypeInfo{.name = "B",
                                                .is_type_param = true}}},
      .expected_type_pb =
          "abstract_type { name: 'A' parameter_types { type_param: 'B' } }",
  });
  // Well-known types.
  cases.push_back(TestCase{
      .type_info = {.name = "any"},
      .expected_type_pb = "well_known: ANY",
  });
  cases.push_back(TestCase{
      .type_info = {.name = "timestamp"},
      .expected_type_pb = "well_known: TIMESTAMP",
  });
  // Wrapper types, by protobuf name and by short name.
  cases.push_back(TestCase{
      .type_info = {.name = "google.protobuf.DoubleValue"},
      .expected_type_pb = "wrapper: DOUBLE",
  });
  cases.push_back(TestCase{
      .type_info = {.name = "double_wrapper"},
      .expected_type_pb = "wrapper: DOUBLE",
  });
  // Type-of-type.
  cases.push_back(TestCase{
      .type_info = {.name = "type",
                    .params = {Config::TypeInfo{.name = "duration"}}},
      .expected_type_pb = "type: { well_known: DURATION }",
  });
  // Abstract type mixing a type parameter and a concrete parameter.
  cases.push_back(TestCase{
      .type_info = {.name = "parameterized",
                    .params = {{.name = "A", .is_type_param = true},
                               {.name = "double"}}},
      .expected_type_pb = "abstract_type { name: 'parameterized' "
                          "parameter_types { type_param: 'A' } "
                          "parameter_types { primitive: DOUBLE } }",
  });

  return cases;
}

INSTANTIATE_TEST_SUITE_P(TypeInfoTest, TypeInfoTest, ValuesIn(GetTestCases()));

}  // namespace
}  // namespace cel
================================================
FILE: eval/BUILD
================================================
# Description
# CEL evaluator performs evaluation of CEL expressions provided in AST form
# (google.api.CheckedExpr)

package(default_visibility = ["//visibility:public"])

licenses(["notice"])

exports_files(["LICENSE"])
================================================
FILE: eval/LICENSE
================================================
Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: eval/README.md ================================================ # CEL Evaluator A C++ implementation of a [Common Expression Language][1] evaluator. 
[1]: https://github.com/google/cel-spec ================================================ FILE: eval/compiler/BUILD ================================================ load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") DEFAULT_VISIBILITY = [ "//eval:__subpackages__", "//runtime:__subpackages__", "//extensions:__subpackages__", "//testing:__subpackages__", ] # This package contains code # that compiles Expr object into evaluatable CelExpression package(default_visibility = ["//visibility:public"]) licenses(["notice"]) exports_files(["LICENSE"]) package_group( name = "coverage_visibility", packages = [ "//tools/...", ], ) cc_library( name = "flat_expr_builder_extensions", srcs = ["flat_expr_builder_extensions.cc"], hdrs = ["flat_expr_builder_extensions.h"], deps = [ ":resolver", "//base:ast", "//base:data", "//common:expr", "//common:native_type", "//common:value", "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", "//eval/eval:trace_step", "//internal:casts", "//runtime:runtime_options", "//runtime/internal:issue_collector", "//runtime/internal:runtime_env", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "flat_expr_builder_extensions_test", srcs = ["flat_expr_builder_extensions_test.cc"], deps = [ ":flat_expr_builder_extensions", ":resolver", "//common:expr", "//common:native_type", "//common:value", "//eval/eval:const_value_step", "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", "//eval/eval:function_step", 
# Remainder of the deps list of the preceding target (its start lies before this chunk).
"//internal:status_macros", "//internal:testing", "//runtime:function_registry", "//runtime:runtime_issue", "//runtime:runtime_options", "//runtime:type_registry", "//runtime/internal:issue_collector", "//runtime/internal:runtime_env", "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], )

# FlatExprBuilder: plans a cel::Ast into a FlatExpression assembled from the
# //eval/eval step libraries listed in deps.
cc_library( name = "flat_expr_builder", srcs = [ "flat_expr_builder.cc", ], hdrs = [ "flat_expr_builder.h", ], deps = [ ":check_ast_extensions", ":flat_expr_builder_extensions", ":resolver", "//base:ast", "//base:builtins", "//base:data", "//common:allocator", "//common:ast", "//common:ast_traverse", "//common:ast_visitor", "//common:constant", "//common:expr", "//common:kind", "//common:type", "//common:value", "//eval/eval:comprehension_step", "//eval/eval:const_value_step", "//eval/eval:container_access_step", "//eval/eval:create_list_step", "//eval/eval:create_map_step", "//eval/eval:create_struct_step", "//eval/eval:direct_expression_step", "//eval/eval:equality_steps", "//eval/eval:evaluator_core", "//eval/eval:function_step", "//eval/eval:ident_step", "//eval/eval:jump_step", "//eval/eval:lazy_init_step", "//eval/eval:logic_step", "//eval/eval:optional_or_step", "//eval/eval:select_step", "//eval/eval:shadowable_value_step", "//eval/eval:ternary_step", "//eval/eval:trace_step", "//internal:status_macros", "//runtime:function_registry", "//runtime:runtime_issue", "//runtime:runtime_options", "//runtime:type_registry", "//runtime/internal:convert_constant", "//runtime/internal:issue_collector", "//runtime/internal:runtime_env", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "@com_google_protobuf//:protobuf", ], )

# Unit tests for :flat_expr_builder (also links the legacy-API adapter and
# planner extensions such as :constant_folding).
cc_test( name = "flat_expr_builder_test", srcs = [ "flat_expr_builder_test.cc", ], deps = [ ":cel_expression_builder_flat_impl", ":constant_folding", ":flat_expr_builder", ":qualified_reference_resolver", "//base:builtins", "//common:function_descriptor", "//common:kind", "//common:value", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_attribute", "//eval/public:cel_builtins", "//eval/public:cel_expr_builder_factory", "//eval/public:cel_expression", "//eval/public:cel_function", "//eval/public:cel_function_adapter", "//eval/public:cel_function_registry", "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public:portable_cel_function_adapter", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_descriptor_pool_builder", "//eval/public/structs:cel_proto_wrapper", "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", "//internal:proto_matchers", "//internal:status_macros", "//internal:testing", "//parser", "//runtime:function", "//runtime:function_adapter", "//runtime:runtime_options", "//runtime:standard_functions", "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:field_mask_cc_proto", "@com_google_protobuf//:protobuf", ], )

# Comprehension-focused tests for :flat_expr_builder, including
# :comprehension_vulnerability_check.
cc_test( name = "flat_expr_builder_comprehensions_test", srcs = [ "flat_expr_builder_comprehensions_test.cc", ], deps = [ ":cel_expression_builder_flat_impl", ":comprehension_vulnerability_check", ":flat_expr_builder", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_attribute", "//eval/public:cel_expression", "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public/containers:container_backed_list_impl", "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", "//internal:testing", "//parser", "//runtime:runtime_options", "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:field_mask_cc_proto", "@com_google_protobuf//:protobuf", ], )

# Legacy CelExpressionBuilder implementation: wraps FlatExprBuilder and converts
# proto Expr/CheckedExpr inputs through //extensions/protobuf:ast_converters.
cc_library( name = "cel_expression_builder_flat_impl", srcs = [ "cel_expression_builder_flat_impl.cc", ], hdrs = [ "cel_expression_builder_flat_impl.h", ], deps = [ ":flat_expr_builder", "//base:ast", "//common:native_type", "//eval/eval:cel_expression_flat_impl", "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", "//eval/public:cel_expression", "//eval/public:cel_function_registry", "//eval/public:cel_type_registry", "//extensions/protobuf:ast_converters", "//internal:status_macros", "//runtime:runtime_issue", "//runtime:runtime_options", "//runtime/internal:runtime_env", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], )

# Smoke tests for :cel_expression_builder_flat_impl.
cc_test( name = "cel_expression_builder_flat_impl_test", srcs = [ "cel_expression_builder_flat_impl_test.cc", ], deps = [ ":cel_expression_builder_flat_impl", ":constant_folding", ":regex_precompilation_optimization", "//eval/eval:cel_expression_flat_impl", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_expression", "//eval/public:cel_function", "//eval/public:cel_value", "//eval/public:portable_cel_function_adapter", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", "//eval/public/structs:protobuf_descriptor_type_provider", "//eval/public/testing:matchers", "//extensions:bindings_ext", "//internal:status_macros", "//internal:testing", "//parser", "//parser:macro", "//runtime:runtime_options", "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], )

# Constant-folding planner extension (built on :flat_expr_builder_extensions).
cc_library( name = "constant_folding", srcs = [ "constant_folding.cc", ], hdrs = [ "constant_folding.h", ], deps = [ ":flat_expr_builder_extensions", ":resolver", "//base:builtins", "//base:data", "//common:ast", "//common:constant", "//common:expr", "//common:value", "//eval/eval:const_value_step", "//eval/eval:evaluator_core", "//internal:status_macros", "//runtime:activation", "//runtime/internal:convert_constant", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], )

# Tests for :constant_folding.
cc_test( name = "constant_folding_test", srcs = [ "constant_folding_test.cc", ], deps = [ ":constant_folding", ":flat_expr_builder_extensions", ":resolver", "//base:ast", "//common:expr", "//common:value", "//eval/eval:const_value_step", "//eval/eval:create_list_step", "//eval/eval:create_map_step", "//eval/eval:evaluator_core", "//extensions/protobuf:ast_converters", "//internal:status_macros", "//internal:testing", "//parser", "//runtime:function_registry", "//runtime:runtime_issue", "//runtime:runtime_options", "//runtime:type_registry", "//runtime/internal:issue_collector", "//runtime/internal:runtime_env", "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], )

# Qualified-reference resolution pass over the AST (uses //common:ast_rewrite
# and :resolver; reports problems through the runtime issue collector).
cc_library( name = "qualified_reference_resolver", srcs = [ "qualified_reference_resolver.cc", ], hdrs = [ "qualified_reference_resolver.h", ], deps = [ ":flat_expr_builder_extensions", ":resolver", "//base:ast", "//base:builtins", "//common:ast", "//common:ast_rewrite", "//common:expr", "//common:kind", "//runtime:runtime_issue", "//runtime/internal:issue_collector", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], )

# AST-extension checks used by :flat_expr_builder; see check_ast_extensions.h.
cc_library( name = "check_ast_extensions", srcs = ["check_ast_extensions.cc"], hdrs = ["check_ast_extensions.h"], deps = [ "//common:ast", "//common/ast:metadata", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], )

# Tests for :check_ast_extensions.
cc_test( name = "check_ast_extensions_test", srcs = ["check_ast_extensions_test.cc"], deps = [ ":check_ast_extensions", "//common:ast", "//common:expr", "//common/ast:metadata", "//internal:testing", "@com_google_absl//absl/status", ], )

# Name resolver shared by the planner passes; per its deps it consults the
# runtime function and type registries.
cc_library( name = "resolver", srcs = ["resolver.cc"], hdrs = ["resolver.h"], deps = [ "//common:kind", "//common:type", "//common:value", "//internal:status_macros", "//runtime:function_overload_reference", "//runtime:function_registry", "//runtime:type_registry", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], )

# Tests for :qualified_reference_resolver.
cc_test( name = "qualified_reference_resolver_test", srcs = [ "qualified_reference_resolver_test.cc", ], deps = [ ":qualified_reference_resolver", ":resolver", "//base:ast", "//base:builtins", "//common:ast", "//common:expr", "//common/ast:expr_proto", "//eval/public:builtin_func_registrar", "//eval/public:cel_function", "//eval/public:cel_function_registry", "//eval/public:cel_value", "//extensions/protobuf:ast_converters", "//internal:proto_matchers", "//internal:testing", "//runtime:runtime_issue", "//runtime:type_registry", "//runtime/internal:issue_collector", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], )

# Short-circuiting conformance tests, run through :cel_expression_builder_flat_impl.
cc_test( name = "flat_expr_builder_short_circuiting_conformance_test", srcs = [ "flat_expr_builder_short_circuiting_conformance_test.cc", ], deps = [ ":cel_expression_builder_flat_impl", "//base:builtins", "//eval/public:activation", "//eval/public:cel_attribute", "//eval/public:cel_expression", "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", "//internal:testing", "//runtime:runtime_options", "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], )

# Tests for :resolver.
cc_test( name = "resolver_test", size = "small", srcs = ["resolver_test.cc"], deps = [ ":resolver", "//common:value", "//eval/public:cel_function", "//eval/public:cel_function_registry", "//eval/public:cel_type_registry", "//eval/public:cel_value", "//eval/testutil:test_message_cc_proto", "//internal:testing", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@com_google_protobuf//:protobuf", ], )

# Planner extension that precompiles regular expressions with RE2
# (uses //eval/eval:regex_match_step and //internal:re2_options).
cc_library( name = "regex_precompilation_optimization", srcs = ["regex_precompilation_optimization.cc"], hdrs = ["regex_precompilation_optimization.h"], deps = [ ":flat_expr_builder_extensions", "//base:builtins", "//common:ast", "//common:casting", "//common:expr", "//common:native_type", "//common:value", "//eval/eval:compiler_constant_step", "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", "//eval/eval:regex_match_step", "//internal:casts", "//internal:re2_options", "//internal:status_macros", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_googlesource_code_re2//:re2", ], )

# Tests for :regex_precompilation_optimization.
cc_test( name = "regex_precompilation_optimization_test", srcs = ["regex_precompilation_optimization_test.cc"], deps = [ ":cel_expression_builder_flat_impl", ":constant_folding", ":flat_expr_builder", ":flat_expr_builder_extensions", ":regex_precompilation_optimization", ":resolver", "//common:ast", "//eval/eval:evaluator_core", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_expression", "//eval/public:cel_function_registry", "//eval/public:cel_options", "//eval/public:cel_type_registry", "//eval/public:cel_value", "//internal:testing", "//parser", "//runtime:runtime_issue", "//runtime:runtime_options", "//runtime/internal:issue_collector", "//runtime/internal:runtime_env", "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], )

# Planner extension (built on :flat_expr_builder_extensions) that checks
# comprehensions; see comprehension_vulnerability_check.h.
cc_library( name = "comprehension_vulnerability_check", srcs = ["comprehension_vulnerability_check.cc"], hdrs = ["comprehension_vulnerability_check.h"], deps = [ ":flat_expr_builder_extensions", "//base:builtins", "//common:ast", "//common:constant", "//common:expr", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:variant", ], )

# Planner extension for instrumenting evaluation; see instrumentation.h.
cc_library( name = "instrumentation", srcs = ["instrumentation.cc"], hdrs = ["instrumentation.h"], deps = [ ":flat_expr_builder_extensions", "//common:ast", "//common:expr", "//common:value", "//eval/eval:evaluator_core", "//eval/eval:expression_step_base", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], )

# Tests for :instrumentation.
cc_test( name = "instrumentation_test", srcs = ["instrumentation_test.cc"], deps = [ ":constant_folding", ":flat_expr_builder", ":instrumentation", ":regex_precompilation_optimization", "//common:ast", "//common:value", "//eval/eval:evaluator_core", "//extensions/protobuf:ast_converters", "//internal:testing", "//parser", "//runtime:activation", "//runtime:function_registry", "//runtime:runtime_options", "//runtime:standard_functions", "//runtime:type_registry", "//runtime/internal:runtime_env", "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], )

# gitextract file separator; the file name continues on the next line.
================================================ FILE:
eval/compiler/LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. 
For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: eval/compiler/cel_expression_builder_flat_impl.cc ================================================ /* * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "eval/compiler/cel_expression_builder_flat_impl.h" #include #include #include #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" #include "absl/base/macros.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "base/ast.h" #include "common/native_type.h" #include "eval/eval/cel_expression_flat_impl.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/public/cel_expression.h" #include "extensions/protobuf/ast_converters.h" #include "internal/status_macros.h" #include "runtime/runtime_issue.h" namespace google::api::expr::runtime { using ::cel::Ast; using ::cel::RuntimeIssue; using ::cel::expr::CheckedExpr; using ::cel::expr::Expr; // NOLINT: adjusted in OSS using ::cel::expr::SourceInfo; absl::StatusOr> CelExpressionBuilderFlatImpl::CreateExpression( const Expr* expr, const SourceInfo* source_info, std::vector* warnings) const { ABSL_ASSERT(expr != nullptr); CEL_ASSIGN_OR_RETURN( std::unique_ptr converted_ast, cel::extensions::CreateAstFromParsedExpr(*expr, source_info)); return CreateExpressionImpl(std::move(converted_ast), warnings); } absl::StatusOr> CelExpressionBuilderFlatImpl::CreateExpression( const Expr* expr, const SourceInfo* source_info) const { return CreateExpression(expr, source_info, /*warnings=*/nullptr); } absl::StatusOr> CelExpressionBuilderFlatImpl::CreateExpression( const CheckedExpr* checked_expr, std::vector* warnings) const { ABSL_ASSERT(checked_expr != nullptr); CEL_ASSIGN_OR_RETURN( std::unique_ptr converted_ast, cel::extensions::CreateAstFromCheckedExpr(*checked_expr)); return CreateExpressionImpl(std::move(converted_ast), warnings); } absl::StatusOr> CelExpressionBuilderFlatImpl::CreateExpression( const CheckedExpr* checked_expr) const { return CreateExpression(checked_expr, /*warnings=*/nullptr); } absl::StatusOr> CelExpressionBuilderFlatImpl::CreateExpressionImpl( std::unique_ptr converted_ast, std::vector* 
warnings) const { std::vector issues; auto* issues_ptr = (warnings != nullptr) ? &issues : nullptr; CEL_ASSIGN_OR_RETURN(FlatExpression impl, flat_expr_builder_.CreateExpressionImpl( std::move(converted_ast), issues_ptr)); if (issues_ptr != nullptr) { for (const auto& issue : issues) { warnings->push_back(issue.ToStatus()); } } if (flat_expr_builder_.options().max_recursion_depth != 0 && !impl.subexpressions().empty() && // mainline expression is exactly one recursive step. impl.subexpressions().front().size() == 1 && impl.subexpressions().front().front()->GetNativeTypeId() == cel::NativeTypeId::For()) { return CelExpressionRecursiveImpl::Create(env_, std::move(impl)); } return std::make_unique(env_, std::move(impl)); } } // namespace google::api::expr::runtime ================================================ FILE: eval/compiler/cel_expression_builder_flat_impl.h ================================================ /* * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CEL_EXPRESSION_BUILDER_FLAT_IMPL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CEL_EXPRESSION_BUILDER_FLAT_IMPL_H_ #include #include #include #include #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "base/ast.h" #include "eval/compiler/flat_expr_builder.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_type_registry.h" #include "runtime/internal/runtime_env.h" #include "runtime/runtime_options.h" namespace google::api::expr::runtime { // CelExpressionBuilder implementation. // Builds instances of CelExpressionFlatImpl. class CelExpressionBuilderFlatImpl : public CelExpressionBuilder { public: CelExpressionBuilderFlatImpl( absl_nonnull std::shared_ptr env, const cel::RuntimeOptions& options) : env_(std::move(env)), flat_expr_builder_(env_, options, /*use_legacy_type_provider=*/true) { ABSL_DCHECK(env_->IsInitialized()); } explicit CelExpressionBuilderFlatImpl( absl_nonnull std::shared_ptr env) : CelExpressionBuilderFlatImpl(std::move(env), cel::RuntimeOptions()) {} absl::StatusOr> CreateExpression( const cel::expr::Expr* expr, const cel::expr::SourceInfo* source_info) const override; absl::StatusOr> CreateExpression( const cel::expr::Expr* expr, const cel::expr::SourceInfo* source_info, std::vector* warnings) const override; absl::StatusOr> CreateExpression( const cel::expr::CheckedExpr* checked_expr) const override; absl::StatusOr> CreateExpression( const cel::expr::CheckedExpr* checked_expr, std::vector* warnings) const override; FlatExprBuilder& flat_expr_builder() { return flat_expr_builder_; } void set_container(std::string container) override { flat_expr_builder_.set_container(std::move(container)); } // CelFunction registry. 
Extension function should be registered with it // prior to expression creation. CelFunctionRegistry* GetRegistry() const override { return &env_->legacy_function_registry; } // CEL Type registry. Provides a means to resolve the CEL built-in types to // CelValue instances, and to extend the set of types and enums known to // expressions by registering them ahead of time. CelTypeRegistry* GetTypeRegistry() const override { return &env_->legacy_type_registry; } absl::string_view container() const override { return flat_expr_builder_.container(); } private: absl::StatusOr> CreateExpressionImpl( std::unique_ptr converted_ast, std::vector* warnings) const; absl_nonnull std::shared_ptr env_; FlatExprBuilder flat_expr_builder_; }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CEL_EXPRESSION_BUILDER_FLAT_IMPL_H_ ================================================ FILE: eval/compiler/cel_expression_builder_flat_impl_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Smoke tests for CelExpressionBuilderFlatImpl. This class is a thin wrapper // over FlatExprBuilder, so most of the tests are just covering the conversion // code from the legacy APIs to the implementation. See // flat_expr_builder_test.cc for additional tests. 
#include "eval/compiler/cel_expression_builder_flat_impl.h"

#include <cstdint>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "cel/expr/checked.pb.h"
#include "cel/expr/syntax.pb.h"
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "eval/compiler/constant_folding.h"
#include "eval/compiler/regex_precompilation_optimization.h"
#include "eval/eval/cel_expression_flat_impl.h"
#include "eval/public/activation.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_function.h"
#include "eval/public/cel_value.h"
#include "eval/public/containers/container_backed_map_impl.h"
#include "eval/public/portable_cel_function_adapter.h"
#include "eval/public/structs/cel_proto_wrapper.h"
#include "eval/public/testing/matchers.h"
#include "extensions/bindings_ext.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "parser/macro.h"
#include "parser/parser.h"
#include "runtime/internal/runtime_env_testing.h"
#include "runtime/runtime_options.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

namespace google::api::expr::runtime {
namespace {

using ::absl_testing::StatusIs;
using ::cel::expr::conformance::proto3::NestedTestAllTypes;
using ::cel::expr::conformance::proto3::TestAllTypes;
using ::cel::runtime_internal::NewTestingRuntimeEnv;
using ::cel::expr::CheckedExpr;
using ::cel::expr::Expr;
using ::cel::expr::ParsedExpr;
using ::cel::expr::SourceInfo;
using ::google::api::expr::parser::Macro;
using ::google::api::expr::parser::Parse;
using ::google::api::expr::parser::ParseWithMacros;
using ::testing::_;
using ::testing::Contains;
using ::testing::HasSubstr;
using ::testing::IsNull;
using ::testing::NotNull;

// An empty Expr is rejected at planning time.
TEST(CelExpressionBuilderFlatImplTest, Error) {
  Expr expr;
  SourceInfo source_info;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("Invalid empty expression")));
}

TEST(CelExpressionBuilderFlatImplTest, ParsedExpr) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("1 + 2"));
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry()));
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelExpression> plan,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena));
  EXPECT_THAT(result, test::IsCelInt64(3));
}

// A single recursive-planning case. Exactly one of `expr` (CEL source text)
// or `pb_expr` (a text-format ParsedExpr) is expected to be non-empty.
struct RecursiveTestCase {
  std::string test_name;
  std::string expr;
  test::CelValueMatcher matcher;
  std::string pb_expr;
};

class RecursivePlanTest : public ::testing::TestWithParam<RecursiveTestCase> {
 protected:
  // Registers the TestEnum enum, the builtin functions, and a lazily bound
  // LazilyBoundMult(int, int) overload. The implementation of that overload
  // is supplied per activation by SetupActivation().
  absl::Status SetupBuilder(CelExpressionBuilderFlatImpl& builder) {
    builder.GetTypeRegistry()->RegisterEnum("TestEnum",
                                            {{"FOO", 1}, {"BAR", 2}});
    CEL_RETURN_IF_ERROR(RegisterBuiltinFunctions(builder.GetRegistry()));
    return builder.GetRegistry()->RegisterLazyFunction(CelFunctionDescriptor(
        "LazilyBoundMult", false,
        {CelValue::Type::kInt64, CelValue::Type::kInt64}));
  }

  // Populates the variables referenced by the test cases. "TestEnum.BAR" is
  // bound to -1 here, shadowing the registered enum value (2).
  absl::Status SetupActivation(Activation& activation,
                               google::protobuf::Arena* arena) {
    activation.InsertValue("int_1", CelValue::CreateInt64(1));
    activation.InsertValue("string_abc", CelValue::CreateStringView("abc"));
    activation.InsertValue("string_def", CelValue::CreateStringView("def"));
    auto* map = google::protobuf::Arena::Create<CelMapBuilder>(arena);
    CEL_RETURN_IF_ERROR(
        map->Add(CelValue::CreateStringView("a"), CelValue::CreateInt64(1)));
    CEL_RETURN_IF_ERROR(
        map->Add(CelValue::CreateStringView("b"), CelValue::CreateInt64(2)));
    activation.InsertValue("map_var", CelValue::CreateMap(map));
    auto* msg = google::protobuf::Arena::Create<NestedTestAllTypes>(arena);
    msg->mutable_child()->mutable_payload()->set_single_int64(42);
    activation.InsertValue("struct_var",
                           CelProtoWrapper::CreateMessage(msg, arena));
    activation.InsertValue("TestEnum.BAR", CelValue::CreateInt64(-1));

    CEL_RETURN_IF_ERROR(activation.InsertFunction(
        PortableBinaryFunctionAdapter<int64_t, int64_t, int64_t>::Create(
            "LazilyBoundMult", false,
            [](google::protobuf::Arena*, int64_t lhs, int64_t rhs) -> int64_t {
              return lhs * rhs;
            })));

    return absl::OkStatus();
  }
};

// Builds a ParsedExpr from either the CEL source text or the text-format
// proto of `test_case`. Parsing enables the standard macros plus cel.bind.
absl::StatusOr<ParsedExpr> ParseTestCase(const RecursiveTestCase& test_case) {
  static const std::vector<Macro>* kMacros = []() {
    auto* result = new std::vector<Macro>(Macro::AllMacros());
    absl::c_copy(cel::extensions::bindings_macros(),
                 std::back_inserter(*result));
    return result;
  }();

  if (!test_case.expr.empty()) {
    return ParseWithMacros(test_case.expr, *kMacros, "<input>");
  } else if (!test_case.pb_expr.empty()) {
    ParsedExpr result;
    if (!google::protobuf::TextFormat::ParseFromString(test_case.pb_expr,
                                                       &result)) {
      return absl::InvalidArgumentError("Failed to parse proto");
    }
    return result;
  }
  return absl::InvalidArgumentError("No expression provided");
}

TEST_P(RecursivePlanTest, ParsedExprRecursiveImpl) {
  const RecursiveTestCase& test_case = GetParam();
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseTestCase(test_case));
  cel::RuntimeOptions options;
  options.container = "cel.expr.conformance.proto3";
  google::protobuf::Arena arena;
  // Unbounded.
  options.max_recursion_depth = -1;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);

  ASSERT_OK(SetupBuilder(builder));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelExpression> plan,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  EXPECT_THAT(dynamic_cast<const CelExpressionRecursiveImpl*>(plan.get()),
              NotNull());

  Activation activation;
  ASSERT_OK(SetupActivation(activation, &arena));

  ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena));
  EXPECT_THAT(result, test_case.matcher);
}

TEST_P(RecursivePlanTest, ParsedExprRecursiveOptimizedImpl) {
  const RecursiveTestCase& test_case = GetParam();
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseTestCase(test_case));
  cel::RuntimeOptions options;
  options.container = "cel.expr.conformance.proto3";
  google::protobuf::Arena arena;
  // Unbounded.
  options.max_recursion_depth = -1;
  options.enable_comprehension_list_append = true;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);

  ASSERT_OK(SetupBuilder(builder));

  builder.flat_expr_builder().AddProgramOptimizer(
      cel::runtime_internal::CreateConstantFoldingOptimizer());
  builder.flat_expr_builder().AddProgramOptimizer(
      CreateRegexPrecompilationExtension(options.regex_max_program_size));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelExpression> plan,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  EXPECT_THAT(dynamic_cast<const CelExpressionRecursiveImpl*>(plan.get()),
              NotNull());

  Activation activation;
  ASSERT_OK(SetupActivation(activation, &arena));

  ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena));
  EXPECT_THAT(result, test_case.matcher);
}

TEST_P(RecursivePlanTest, ParsedExprRecursiveTraceSupport) {
  const RecursiveTestCase& test_case = GetParam();
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseTestCase(test_case));
  cel::RuntimeOptions options;
  options.container = "cel.expr.conformance.proto3";
  google::protobuf::Arena arena;
  // Trace callback that accepts every intermediate value. Parameters are left
  // unnamed so they do not shadow the local `arena` above.
  auto cb = [](int64_t, const CelValue&, google::protobuf::Arena*) {
    return absl::OkStatus();
  };
  // Unbounded.
  options.max_recursion_depth = -1;
  options.enable_recursive_tracing = true;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);

  ASSERT_OK(SetupBuilder(builder));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelExpression> plan,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  EXPECT_THAT(dynamic_cast<const CelExpressionRecursiveImpl*>(plan.get()),
              NotNull());

  Activation activation;
  ASSERT_OK(SetupActivation(activation, &arena));

  ASSERT_OK_AND_ASSIGN(CelValue result, plan->Trace(activation, &arena, cb));
  EXPECT_THAT(result, test_case.matcher);
}

TEST_P(RecursivePlanTest, Disabled) {
  // Force-link TestAllTypes reflection; the create_struct case refers to it
  // by name.
  google::protobuf::LinkMessageReflection<TestAllTypes>();

  const RecursiveTestCase& test_case = GetParam();
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseTestCase(test_case));
  cel::RuntimeOptions options;
  options.container = "cel.expr.conformance.proto3";
  google::protobuf::Arena arena;
  // disabled.
  options.max_recursion_depth = 0;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);

  ASSERT_OK(SetupBuilder(builder));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelExpression> plan,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  EXPECT_THAT(dynamic_cast<const CelExpressionRecursiveImpl*>(plan.get()),
              IsNull());

  Activation activation;
  ASSERT_OK(SetupActivation(activation, &arena));

  ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena));
  EXPECT_THAT(result, test_case.matcher);
}

INSTANTIATE_TEST_SUITE_P(
    RecursivePlanTest, RecursivePlanTest,
    testing::ValuesIn(std::vector<RecursiveTestCase>{
        {"constant", "'abc'", test::IsCelString("abc")},
        {"call", "1 + 2", test::IsCelInt64(3)},
        {"nested_call", "1 + 1 + 1 + 1", test::IsCelInt64(4)},
        {"and", "true && false", test::IsCelBool(false)},
        {"or", "true || false", test::IsCelBool(true)},
        {"ternary", "(true || false) ? 2 + 2 : 3 + 3", test::IsCelInt64(4)},
        {"create_list", "3 in [1, 2, 3]", test::IsCelBool(true)},
        {"create_list_complex", "3 in [2 / 2, 4 / 2, 6 / 2]",
         test::IsCelBool(true)},
        {"ident", "int_1 == 1", test::IsCelBool(true)},
        {"ident_complex", "int_1 + 2 > 4 ? string_abc : string_def",
         test::IsCelString("def")},
        {"select", "struct_var.child.payload.single_int64",
         test::IsCelInt64(42)},
        {"nested_select", "[map_var.a, map_var.b].size() == 2",
         test::IsCelBool(true)},
        {"map_index", "map_var['b']", test::IsCelInt64(2)},
        {"list_index", "[1, 2, 3][1]", test::IsCelInt64(2)},
        {"compre_exists", "[1, 2, 3, 4].exists(x, x == 3)",
         test::IsCelBool(true)},
        {"compre_map", "8 in [1, 2, 3, 4].map(x, x * 2)",
         test::IsCelBool(true)},
        {"map_var_compre_exists", "map_var.exists(key, key == 'b')",
         test::IsCelBool(true)},
        {"map_compre_exists", "{'a': 1, 'b': 2}.exists(k, k == 'b')",
         test::IsCelBool(true)},
        {"create_map", "{'a': 42, 'b': 0, 'c': 0}.size()",
         test::IsCelInt64(3)},
        {"create_struct",
         "NestedTestAllTypes{payload: TestAllTypes{single_int64: "
         "-42}}.payload.single_int64",
         test::IsCelInt64(-42)},
        {"bind", R"(cel.bind(x, "1", x + x + x + x))",
         test::IsCelString("1111")},
        {"nested_bind", R"(cel.bind(x, 20, cel.bind(y, 30, x + y)))",
         test::IsCelInt64(50)},
        {"bind_with_comprehensions",
         R"(cel.bind(x, [1, 2], cel.bind(y, x.map(z, z * 2), y.exists(z, z == 4))))",
         test::IsCelBool(true)},
        {"shadowable_value_default", R"(TestEnum.FOO == 1)",
         test::IsCelBool(true)},
        {"shadowable_value_shadowed", R"(TestEnum.BAR == -1)",
         test::IsCelBool(true)},
        {"lazily_resolved_function", "LazilyBoundMult(123, 2) == 246",
         test::IsCelBool(true)},
        {"re_matches", "matches(string_abc, '[ad][be][cf]')",
         test::IsCelBool(true)},
        {"re_matches_receiver",
         "(string_abc + string_def).matches(r'(123)?' + r'abc' + r'def')",
         test::IsCelBool(true)},
        {"block", "", test::IsCelBool(true),
         R"pb( expr { id: 1 call_expr { function: "cel.@block" args { id: 2 list_expr { elements { const_expr { int64_value: 8 } } elements { const_expr { int64_value: 10 } } } } args { id: 3 call_expr { function: "_<_" args { ident_expr { name: "@index0" } } args { ident_expr { name: "@index1" } } } } } })pb"},
        {"block_with_comprehensions", "", test::IsCelBool(true),
         // Something like:
         // variables:
         // - users: {'bob': ['bar'], 'alice': ['foo', 'bar']}
         // - someone_has_bar: users.exists(u, 'bar' in users[u])
         // policy:
         // - someone_has_bar && !users.exists(u, u == 'eve')
         //
         R"pb( expr { call_expr { function: "cel.@block" args { list_expr { elements { struct_expr: { entries: { map_key: { const_expr: { string_value: "bob" } } value: { list_expr: { elements: { const_expr: { string_value: "bar" } } } } } entries: { map_key: { const_expr: { string_value: "alice" } } value: { list_expr: { elements: { const_expr: { string_value: "bar" } } elements: { const_expr: { string_value: "foo" } } } } } } } elements { id: 16 comprehension_expr: { iter_var: "u" iter_range: { id: 1 ident_expr: { name: "@index0" } } accu_var: "__result__" accu_init: { id: 9 const_expr: { bool_value: false } } loop_condition: { id: 12 call_expr: { function: "@not_strictly_false" args: { id: 11 call_expr: { function: "!_" args: { id: 10 ident_expr: { name: "__result__" } } } } } } loop_step: { id: 14 call_expr: { function: "_||_" args: { id: 13 ident_expr: { name: "__result__" } } args: { id: 5 call_expr: { function: "@in" args: { id: 4 const_expr: { string_value: "bar" } } args: { id: 7 call_expr: { function: "_[_]" args: { id: 6 ident_expr: { name: "@index0" } } args: { id: 8 ident_expr: { name: "u" } } } } } } } } result: { id: 15 ident_expr: { name: "__result__" } } } } } } args { id: 17 call_expr: { function: "_&&_" args: { id: 1 ident_expr: { name: "@index1" } } args: { id: 2 call_expr: { function: "!_" args: { id: 16
comprehension_expr: { iter_var: "u" iter_range: { id: 3 ident_expr: { name: "@index0" } } accu_var: "__result__" accu_init: { id: 9 const_expr: { bool_value: false } } loop_condition: { id: 12 call_expr: { function: "@not_strictly_false" args: { id: 11 call_expr: { function: "!_" args: { id: 10 ident_expr: { name: "__result__" } } } } } } loop_step: { id: 14 call_expr: { function: "_||_" args: { id: 13 ident_expr: { name: "__result__" } } args: { id: 7 call_expr: { function: "_==_" args: { id: 6 ident_expr: { name: "u" } } args: { id: 8 const_expr: { string_value: "eve" } } } } } } result: { id: 15 ident_expr: { name: "__result__" } } } } } } } } } })pb"}}),

    [](const testing::TestParamInfo<RecursiveTestCase>& info) -> std::string {
      return info.param.test_name;
    });

// Builtins are intentionally not registered, so `_+_` has no overloads. With
// fail_on_warnings disabled, planning still succeeds, the problem is reported
// as a warning, and evaluation produces an error value.
TEST(CelExpressionBuilderFlatImplTest, ParsedExprWithWarnings) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("1 + 2"));
  cel::RuntimeOptions options;
  options.fail_on_warnings = false;

  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
  std::vector<absl::Status> warnings;

  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<CelExpression> plan,
      builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info(),
                               &warnings));

  EXPECT_THAT(warnings, Contains(StatusIs(absl::StatusCode::kInvalidArgument,
                                          HasSubstr("No overloads"))));

  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena));
  EXPECT_THAT(result, test::IsCelError(
                          StatusIs(_, HasSubstr("No matching overloads"))));
}

TEST(CelExpressionBuilderFlatImplTest, EmptyLegacyTypeViewUnsupported) {
  // Creating type values directly (instead of using the builtin functions and
  // identifiers from the type registry) is not recommended for CEL users. The
  // name is expected to be non-empty.
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("x"));
  cel::RuntimeOptions options;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelExpression> plan,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  Activation activation;
  activation.InsertValue("x", CelValue::CreateCelTypeView(""));
  google::protobuf::Arena arena;
  ASSERT_THAT(plan->Evaluate(activation, &arena),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

TEST(CelExpressionBuilderFlatImplTest, LegacyTypeViewSupported) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("x"));
  cel::RuntimeOptions options;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelExpression> plan,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  Activation activation;
  activation.InsertValue("x", CelValue::CreateCelTypeView("MyType"));
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena));
  ASSERT_TRUE(result.IsCelType());
  EXPECT_EQ(result.CelTypeOrDie().value(), "MyType");
}

TEST(CelExpressionBuilderFlatImplTest, CheckedExpr) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("1 + 2"));
  CheckedExpr checked_expr;
  checked_expr.mutable_expr()->Swap(parsed_expr.mutable_expr());
  checked_expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info());

  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry()));
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelExpression> plan,
                       builder.CreateExpression(&checked_expr));

  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena));
  EXPECT_THAT(result, test::IsCelInt64(3));
}

// Checked-expression counterpart of ParsedExprWithWarnings: no builtins are
// registered and fail_on_warnings is disabled.
TEST(CelExpressionBuilderFlatImplTest, CheckedExprWithWarnings) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("1 + 2"));
  CheckedExpr checked_expr;
  checked_expr.mutable_expr()->Swap(parsed_expr.mutable_expr());
  checked_expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info());
  cel::RuntimeOptions options;
  options.fail_on_warnings = false;

  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
  std::vector<absl::Status> warnings;

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelExpression> plan,
                       builder.CreateExpression(&checked_expr, &warnings));

  EXPECT_THAT(warnings, Contains(StatusIs(absl::StatusCode::kInvalidArgument,
                                          HasSubstr("No overloads"))));

  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena));
  EXPECT_THAT(result, test::IsCelError(
                          StatusIs(_, HasSubstr("No matching overloads"))));
}

}  // namespace
}  // namespace google::api::expr::runtime

================================================
FILE: eval/compiler/check_ast_extensions.cc
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "eval/compiler/check_ast_extensions.h"

#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "common/ast.h"
#include "common/ast/metadata.h"

namespace google::api::expr::runtime {

// Returns copies of the extensions in `ast`'s source info that list
// ExtensionSpec::Component::kRuntime among their affected components. They
// are returned in source-info order.
//
// Returns InvalidArgumentError if two such runtime extensions share an ID.
// Extensions that do not affect the runtime are skipped entirely, so they
// never take part in the duplicate-ID check.
absl::StatusOr<std::vector<cel::ExtensionSpec>>
ExtractAndValidateRuntimeExtensions(const cel::Ast& ast) {
  std::vector<cel::ExtensionSpec> runtime_extensions;
  // IDs of runtime extensions accepted so far.
  absl::flat_hash_set<std::string> seen_extension_ids;

  for (const cel::ExtensionSpec& extension : ast.source_info().extensions()) {
    if (!absl::c_linear_search(extension.affected_components(),
                               cel::ExtensionSpec::Component::kRuntime)) {
      continue;
    }
    if (!seen_extension_ids.insert(extension.id()).second) {
      return absl::InvalidArgumentError(
          absl::StrCat("duplicate extension ID: ", extension.id()));
    }
    runtime_extensions.push_back(extension);
  }
  return runtime_extensions;
}

}  // namespace google::api::expr::runtime

================================================
FILE: eval/compiler/check_ast_extensions.h
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CHECK_AST_EXTENSIONS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CHECK_AST_EXTENSIONS_H_ #include #include "absl/status/statusor.h" #include "common/ast.h" #include "common/ast/metadata.h" namespace google::api::expr::runtime { // Extracts and validates extension tags from the AST `ast` that affect the // runtime component. Returns the validated list of runtime extensions, or an // error if there are multiple runtime extensions with the same ID. absl::StatusOr> ExtractAndValidateRuntimeExtensions(const cel::Ast& ast); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CHECK_AST_EXTENSIONS_H_ ================================================ FILE: eval/compiler/check_ast_extensions_test.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/compiler/check_ast_extensions.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "common/ast.h"
#include "common/ast/metadata.h"
#include "common/expr.h"
#include "internal/testing.h"

namespace google::api::expr::runtime {
namespace {

using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::cel::Ast;
using ::cel::Expr;
using ::cel::ExtensionSpec;
using ::cel::SourceInfo;
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::Property;
using ::testing::SizeIs;

TEST(ExtractAndValidateRuntimeExtensionsTest, EmptyExtensions) {
  Ast ast(Expr(), SourceInfo());

  EXPECT_THAT(ExtractAndValidateRuntimeExtensions(ast),
              IsOkAndHolds(SizeIs(0)));
}

TEST(ExtractAndValidateRuntimeExtensionsTest, FiltersNonRuntimeExtensions) {
  SourceInfo source_info;
  source_info.mutable_extensions().push_back(
      ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kParser}));
  source_info.mutable_extensions().push_back(
      ExtensionSpec("ext2", nullptr, {ExtensionSpec::Component::kTypeChecker}));
  Ast ast(Expr(), std::move(source_info));

  EXPECT_THAT(ExtractAndValidateRuntimeExtensions(ast),
              IsOkAndHolds(SizeIs(0)));
}

// Runtime extensions are returned in source-info order; others are dropped.
TEST(ExtractAndValidateRuntimeExtensionsTest, ExtractsRuntimeExtensions) {
  SourceInfo source_info;
  source_info.mutable_extensions().push_back(
      ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kRuntime}));
  source_info.mutable_extensions().push_back(ExtensionSpec(
      "ext2", nullptr,
      {ExtensionSpec::Component::kParser, ExtensionSpec::Component::kRuntime}));
  source_info.mutable_extensions().push_back(
      ExtensionSpec("ext3", nullptr, {ExtensionSpec::Component::kParser}));
  Ast ast(Expr(), std::move(source_info));

  EXPECT_THAT(
      ExtractAndValidateRuntimeExtensions(ast),
      IsOkAndHolds(ElementsAre(Property(&ExtensionSpec::id, Eq("ext1")),
                               Property(&ExtensionSpec::id, Eq("ext2")))));
}

// An extension listing kRuntime more than once is still returned only once.
TEST(ExtractAndValidateRuntimeExtensionsTest,
     RepeatedRuntimeComponentCountsOnce) {
  SourceInfo source_info;
  source_info.mutable_extensions().push_back(ExtensionSpec(
      "ext1", nullptr,
      {ExtensionSpec::Component::kRuntime, ExtensionSpec::Component::kRuntime}));
  Ast ast(Expr(), std::move(source_info));

  EXPECT_THAT(
      ExtractAndValidateRuntimeExtensions(ast),
      IsOkAndHolds(ElementsAre(Property(&ExtensionSpec::id, Eq("ext1")))));
}

TEST(ExtractAndValidateRuntimeExtensionsTest, FailsOnDuplicateRuntimeID) {
  SourceInfo source_info;
  source_info.mutable_extensions().push_back(
      ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kRuntime}));
  source_info.mutable_extensions().push_back(ExtensionSpec(
      "ext1", nullptr,
      {ExtensionSpec::Component::kParser, ExtensionSpec::Component::kRuntime}));
  Ast ast(Expr(), std::move(source_info));

  EXPECT_THAT(ExtractAndValidateRuntimeExtensions(ast),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "duplicate extension ID: ext1"));
}

// A non-runtime extension may reuse a runtime extension's ID.
TEST(ExtractAndValidateRuntimeExtensionsTest, IgnoresDuplicateNonRuntimeID) {
  SourceInfo source_info;
  source_info.mutable_extensions().push_back(
      ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kRuntime}));
  source_info.mutable_extensions().push_back(
      ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kParser}));
  Ast ast(Expr(), std::move(source_info));

  EXPECT_THAT(
      ExtractAndValidateRuntimeExtensions(ast),
      IsOkAndHolds(ElementsAre(Property(&ExtensionSpec::id, Eq("ext1")))));
}

}  // namespace
}  // namespace google::api::expr::runtime

================================================
FILE: eval/compiler/comprehension_vulnerability_check.cc
================================================
//
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "eval/compiler/comprehension_vulnerability_check.h"

#include <algorithm>
#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/types/variant.h"
#include "base/builtins.h"
#include "common/ast.h"
#include "common/constant.h"
#include "common/expr.h"
#include "eval/compiler/flat_expr_builder_extensions.h"

namespace google::api::expr::runtime {

namespace {

using ::cel::CallExpr;
using ::cel::ComprehensionExpr;
using ::cel::Constant;
using ::cel::Expr;
using ::cel::IdentExpr;
using ::cel::ListExpr;
using ::cel::MapExpr;
using ::cel::SelectExpr;
using ::cel::StructExpr;
using ::cel::UnspecifiedExpr;

// ComprehensionAccumulationReferences recursively walks an expression to count
// the locations where the given accumulation var_name is referenced.
//
// The purpose of this function is to detect cases where the accumulation
// variable might be used in hand-rolled ASTs that cause exponential memory
// consumption. The var_name is generally not accessible by CEL expression
// writers, only by macro authors. However, a hand-rolled AST makes it possible
// to misuse the accumulation variable.
//
// Limitations:
// - This check only covers standard operators and functions.
//   Extension functions may cause the same issue if they allocate an amount of
//   memory that is dependent on the size of the inputs.
//
// - This check is not exhaustive. There may be ways to construct an AST to
//   trigger exponential memory growth not captured by this check.
//
// The algorithm for reference counting is as follows:
//
//  * Calls - If the call is the add (concatenation) operator, sum the number
//            of places where the variable appears within the call's
//            arguments, as this could result in memory explosion if the
//            accumulation variable type is a list or string. For a ternary,
//            take the larger count of the two branches (the condition is not
//            counted). For an index or `dyn` call, count references in the
//            first operand. Otherwise, return 0.
//
//            accu: ["hello"]
//            expr: accu + accu // memory grows exponentially
//
//  * CreateList - If the accumulation var_name appears within multiple elements
//                 of a CreateList call, this means that the accumulation is
//                 generating an ever-expanding tree of values that will likely
//                 exhaust memory.
//
//                 accu: ["hello"]
//                 expr: [accu, accu] // memory grows exponentially
//
//  * CreateStruct - If the accumulation var_name appears as an entry value
//                   within the creation of a map or message value, then it's
//                   possible that the comprehension is accumulating an
//                   ever-expanding tree of values.
//
//                   accu: {"key": "val"}
//                   expr: {1: accu, 2: accu}
//
//  * Comprehension - If the accumulation var_name is not shadowed by a nested
//                    iter_var or accu_var, then it may be accumulating memory
//                    within a nested context. The accumulation may occur on
//                    either the comprehension loop_step or result step.
//
//  * Select - A test-only select (presence test) counts as 0; otherwise the
//             references in the operand are counted.
//
//  * Ident - Counts 1 when the identifier name equals var_name.
//
// Since this behavior generally only occurs within hand-rolled ASTs, it is
// very reasonable to opt-in to this check only when using human authored ASTs.
int ComprehensionAccumulationReferences(const cel::Expr& expr,
                                        absl::string_view var_name) {
  // Visitor over the expression kinds; each overload returns the reference
  // count for that kind of node.
  struct Handler {
    absl::string_view var_name;

    int operator()(const CallExpr& call) {
      const absl::string_view function = call.function();

      // Return the maximum reference count of each side of the ternary branch.
      if (function == cel::builtin::kTernary && call.args().size() == 3) {
        return std::max(
            ComprehensionAccumulationReferences(call.args()[1], var_name),
            ComprehensionAccumulationReferences(call.args()[2], var_name));
      }

      // Return the number of times the accumulator var_name appears in the add
      // expression. There's no arg size check on the add as it may become a
      // variadic add at a future date.
      if (function == cel::builtin::kAdd) {
        int references = 0;
        for (const Expr& arg : call.args()) {
          references += ComprehensionAccumulationReferences(arg, var_name);
        }
        return references;
      }

      // Return whether the accumulator var_name is used as the operand in an
      // index expression or in the identity `dyn` function.
      if ((function == cel::builtin::kIndex && call.args().size() == 2) ||
          (function == cel::builtin::kDyn && call.args().size() == 1)) {
        return ComprehensionAccumulationReferences(call.args()[0], var_name);
      }
      return 0;
    }

    int operator()(const ComprehensionExpr& comprehension) {
      const absl::string_view accu_var = comprehension.accu_var();
      const absl::string_view iter_var = comprehension.iter_var();

      // The accumulation or iteration variable shadows the var_name and so will
      // not manipulate the target var_name in a nested comprehension scope.
      int loop_step_references = 0;
      if (accu_var != var_name && iter_var != var_name) {
        loop_step_references = ComprehensionAccumulationReferences(
            comprehension.loop_step(), var_name);
      }

      // Accumulator variable (but not necessarily iter var) can shadow an
      // outer accumulator variable in the result sub-expression.
      int result_references = 0;
      if (accu_var != var_name) {
        result_references = ComprehensionAccumulationReferences(
            comprehension.result(), var_name);
      }

      // Count the raw number of times the accumulator variable was referenced
      // in accu_init and iter_range, regardless of shadowing. This accounts
      // for cases where the outer accumulator is shadowed by the inner
      // accumulator, while the inner accumulator is being used as the
      // iterable range.
      //
      // An equivalent expression to this problem:
      //
      // outer_accu := outer_accu
      // for y in outer_accu:
      //     outer_accu += input
      // return outer_accu
      //
      // If this is overly restrictive (e.g. once generalized reducers are
      // implemented), we may need to revisit this solution.
      const int sum_of_accumulator_references =
          ComprehensionAccumulationReferences(comprehension.accu_init(),
                                              var_name) +
          ComprehensionAccumulationReferences(comprehension.iter_range(),
                                              var_name);

      // Return the largest of the loop_step, nested result, and
      // accu_init + iter_range reference counts.
      //
      // This doesn't cover cases where the inner accumulator accumulates the
      // outer accumulator then is returned in the inner comprehension result.
      return std::max({loop_step_references, result_references,
                       sum_of_accumulator_references});
    }

    int operator()(const ListExpr& list) {
      // Count the number of times the accumulator var_name appears within a
      // create list expression's elements.
      int references = 0;
      for (const auto& element : list.elements()) {
        references +=
            ComprehensionAccumulationReferences(element.expr(), var_name);
      }
      return references;
    }

    int operator()(const StructExpr& struct_expr) {
      // Count the number of times the accumulation variable occurs within
      // field values.
      int references = 0;
      for (const auto& field : struct_expr.fields()) {
        if (field.has_value()) {
          references +=
              ComprehensionAccumulationReferences(field.value(), var_name);
        }
      }
      return references;
    }

    int operator()(const MapExpr& map) {
      // Count the number of times the accumulation variable occurs within
      // entry values.
      int references = 0;
      for (const auto& entry : map.entries()) {
        if (entry.has_value()) {
          references +=
              ComprehensionAccumulationReferences(entry.value(), var_name);
        }
      }
      return references;
    }

    int operator()(const SelectExpr& select) {
      // Test only expressions have a boolean return and thus cannot easily
      // allocate large amounts of memory.
      if (select.test_only()) {
        return 0;
      }
      // Return whether the accumulator var_name appears within a non-test
      // select operand.
      return ComprehensionAccumulationReferences(select.operand(), var_name);
    }

    int operator()(const IdentExpr& ident) {
      // Return whether the identifier name equals the accumulator var_name.
      return ident.name() == var_name ? 1 : 0;
    }

    int operator()(const Constant&) { return 0; }

    int operator()(const UnspecifiedExpr&) { return 0; }
  } handler{var_name};
  return absl::visit(handler, expr.kind());
}

// A comprehension is flagged when its own accumulator is referenced at least
// twice (per the counting rules above) within its loop_step.
bool ComprehensionHasMemoryExhaustionVulnerability(
    const ComprehensionExpr& comprehension) {
  const absl::string_view accu_var = comprehension.accu_var();
  const auto& loop_step = comprehension.loop_step();
  return ComprehensionAccumulationReferences(loop_step, accu_var) >= 2;
}

// Program optimizer that only validates: it rejects planning with
// InvalidArgumentError when a comprehension is flagged, and otherwise leaves
// the program unchanged.
class ComprehensionVulnerabilityCheck final : public ProgramOptimizer {
 public:
  absl::Status OnPreVisit(PlannerContext& /*context*/,
                          const Expr& node) override {
    if (node.has_comprehension_expr() &&
        ComprehensionHasMemoryExhaustionVulnerability(
            node.comprehension_expr())) {
      return absl::InvalidArgumentError(
          "Comprehension contains memory exhaustion vulnerability");
    }
    return absl::OkStatus();
  }

  absl::Status OnPostVisit(PlannerContext& /*context*/,
                           const cel::Expr& /*node*/) override {
    return absl::OkStatus();
  }
};

}  // namespace

ProgramOptimizerFactory CreateComprehensionVulnerabilityCheck() {
  return [](PlannerContext&, const cel::Ast&) {
    return std::make_unique<ComprehensionVulnerabilityCheck>();
  };
}

}  // namespace google::api::expr::runtime

================================================
FILE: 
eval/compiler/comprehension_vulnerability_check.h
================================================
//
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_COMPREHENSION_VULNERABILITY_CHECK_H_
#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_COMPREHENSION_VULNERABILITY_CHECK_H_

#include "eval/compiler/flat_expr_builder_extensions.h"

namespace google::api::expr::runtime {

// Create a program optimizer that checks for memory consumption vulnerability
// in comprehensions.
//
// Hand-rolled ASTs or custom Macro implementations can reference the implicit
// accumulator variable in comprehensions to generate objects exponential in the
// size of the inputs. Type checked expressions using the built-in macros and
// functions are not susceptible to this.
//
// This check is not exhaustive, but will catch most accidental triggers of
// this behavior in the standard env. It does not consider custom extension
// functions.
//
// This implementation recursively traverses the AST, so it is not safe for
// deeply nested ASTs or in environments with smaller stack limits.
//
// When the implementation counts two or more references to a comprehension's
// accumulator variable in its loop step, the optimizer's pre-visit returns an
// InvalidArgument status.
//
// conceptual example with a generalized reducer macro:
//   [1, 2, 3, 4]
//     .reduce(
//        /*iter_var=*/ unused,
//        /*accu_var=*/ accu,
//        /*accu_init=*/ [1],
//        /*loop_step=*/ accu + accu,
//        /*result=*/ accu)
// resulting list sizes per iteration: 2, 4, 8, 16.
ProgramOptimizerFactory CreateComprehensionVulnerabilityCheck();

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_COMPREHENSION_VULNERABILITY_CHECK_H_
================================================
FILE: eval/compiler/constant_folding.cc
================================================
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "eval/compiler/constant_folding.h"

#include
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "base/builtins.h"
#include "base/type_provider.h"
#include "common/ast.h"
#include "common/constant.h"
#include "common/expr.h"
#include "common/value.h"
#include "eval/compiler/flat_expr_builder_extensions.h"
#include "eval/compiler/resolver.h"
#include "eval/eval/const_value_step.h"
#include "eval/eval/evaluator_core.h"
#include "internal/status_macros.h"
#include "runtime/activation.h"
#include "runtime/internal/convert_constant.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::runtime_internal {

namespace {

using ::cel::Expr;
using ::cel::builtin::kAnd;
using ::cel::builtin::kOr;
using ::cel::builtin::kTernary;
using ::cel::runtime_internal::ConvertConstant;
using ::google::api::expr::runtime::CreateConstValueDirectStep;
using ::google::api::expr::runtime::CreateConstValueStep;
using ::google::api::expr::runtime::EvaluationListener;
using ::google::api::expr::runtime::ExecutionFrame;
using ::google::api::expr::runtime::ExecutionPath;
using ::google::api::expr::runtime::ExecutionPathView;
using ::google::api::expr::runtime::FlatExpressionEvaluatorState;
using ::google::api::expr::runtime::PlannerContext;
using ::google::api::expr::runtime::ProgramOptimizer;
using ::google::api::expr::runtime::ProgramOptimizerFactory;
using ::google::api::expr::runtime::Resolver;

// Folding classification of a single AST node, recorded in OnPreVisit.
//
// kConditional: the node itself may be folded, provided that none of its
//   children is later found to be kNonConst.
// kNonConst: the node must not be folded. OnPostVisit also marks the parent
//   node kNonConst, so the classification propagates toward the root.
enum class IsConst {
  kConditional,
  kNonConst,
};

// Program optimizer that evaluates constant subexpressions at planning time and
// replaces each such subexpression's subplan with a single constant step.
//
// OnPreVisit / OnPostVisit calls are expected to come in matched, nested pairs
// (a node's OnPostVisit after those of its children). An OnPostVisit with no
// pending OnPreVisit yields an InternalError.
class ConstantFoldingExtension : public ProgramOptimizer {
 public:
  ConstantFoldingExtension(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      absl_nullable std::shared_ptr shared_arena,
      google::protobuf::Arena* absl_nonnull arena,
      absl_nullable std::shared_ptr shared_message_factory,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      const TypeProvider& type_provider)
      : shared_arena_(std::move(shared_arena)),
        shared_message_factory_(std::move(shared_message_factory)),
        state_(kDefaultStackLimit, kComprehensionSlotCount, type_provider,
               descriptor_pool, message_factory, arena) {}

  // Records the folding classification of `node` (see IsConstExpr).
  absl::Status OnPreVisit(google::api::expr::runtime::PlannerContext& context,
                          const Expr& node) override;

  // Replaces the subplan of `node` with a single constant step when `node` and
  // all of its descendants were classified as foldable and planning-time
  // evaluation succeeds. Otherwise the plan is left unchanged.
  absl::Status OnPostVisit(google::api::expr::runtime::PlannerContext& context,
                           const Expr& node) override;

 private:
  // Most constant folding evaluations are simple
  // binary operators.
  static constexpr size_t kDefaultStackLimit = 4;
  // Comprehensions are not evaluated -- the current implementation can't detect
  // if the comprehension variables are only used in a const way.
  static constexpr size_t kComprehensionSlotCount = 0;

  // Optional shared owners of the arena / message factory supplied at
  // configuration time. Holding them keeps those objects alive for at least as
  // long as this extension.
  absl_nullable std::shared_ptr shared_arena_;
  ABSL_ATTRIBUTE_UNUSED absl_nullable std::shared_ptr shared_message_factory_;
  // Activation passed to every folding evaluation.
  Activation empty_;
  FlatExpressionEvaluatorState state_;
  // One entry per node whose OnPreVisit has run but whose OnPostVisit has not
  // run yet. The back entry belongs to the innermost such node.
  std::vector is_const_;
};

// Classifies one node for constant folding, based only on the node itself.
// Children are accounted for by the kNonConst propagation in OnPostVisit.
//
// The following are never folded:
//   - identifiers, comprehensions and struct creation;
//   - empty map and empty list literals;
//   - the short-circuiting operators (&&, ||, ?:) and cel.@block;
//   - calls that resolve to lazy or contextual overloads.
IsConst IsConstExpr(const Expr& expr, const Resolver& resolver) {
  switch (expr.kind_case()) {
    case ExprKindCase::kConstant:
      return IsConst::kConditional;
    case ExprKindCase::kIdentExpr:
      return IsConst::kNonConst;
    case ExprKindCase::kComprehensionExpr:
      // Not yet supported, need to identify whether range and
      // iter vars are compatible with const folding.
      return IsConst::kNonConst;
    case ExprKindCase::kStructExpr:
      return IsConst::kNonConst;
    case ExprKindCase::kMapExpr:
      // Empty maps are rare and not currently supported as they may eventually
      // have similar issues to empty list when used within comprehensions or
      // macros.
      if (expr.map_expr().entries().empty()) {
        return IsConst::kNonConst;
      }
      return IsConst::kConditional;
    case ExprKindCase::kListExpr:
      if (expr.list_expr().elements().empty()) {
        // Don't fold for empty list to allow comprehension
        // list append optimization.
        return IsConst::kNonConst;
      }
      return IsConst::kConditional;
    case ExprKindCase::kSelectExpr:
      return IsConst::kConditional;
    case ExprKindCase::kCallExpr: {
      const auto& call = expr.call_expr();
      // Short Circuiting operators not yet supported.
      if (call.function() == kAnd || call.function() == kOr ||
          call.function() == kTernary) {
        return IsConst::kNonConst;
      }

      // For now we skip constant folding for cel.@block. We do not yet setup
      // slots. When we enable constant folding for comprehensions (like
      // cel.bind), we can address cel.@block.
      if (call.function() == "cel.@block") {
        return IsConst::kNonConst;
      }

      // Overload lookup counts a receiver-style target as an extra argument.
      int arg_len = call.args().size() + (call.has_target() ? 1 : 0);
      // Check for any lazy overloads (activation dependent)
      if (!resolver
               .FindLazyOverloads(call.function(), call.has_target(), arg_len)
               .empty()) {
        return IsConst::kNonConst;
      }

      auto overloads = resolver.FindOverloads(call.function(),
                                              call.has_target(), arg_len);
      // Check for any contextual overloads. If there are any, we cowardly
      // avoid constant folding instead of trying to check if one of the
      // overloads would be safe to use.
      for (const auto& overload : overloads) {
        if (overload.descriptor.is_contextual()) {
          return IsConst::kNonConst;
        }
      }

      return IsConst::kConditional;
    }
    case ExprKindCase::kUnspecifiedExpr:
    default:
      return IsConst::kNonConst;
  }
}

// Pushes the classification of `node`. It is popped by the matching
// OnPostVisit.
absl::Status ConstantFoldingExtension::OnPreVisit(PlannerContext& context,
                                                  const Expr& node) {
  IsConst is_const = IsConstExpr(node, context.resolver());
  is_const_.push_back(is_const);
  return absl::OkStatus();
}

absl::Status ConstantFoldingExtension::OnPostVisit(PlannerContext& context,
                                                   const Expr& node) {
  // Each OnPostVisit must be matched by an earlier OnPreVisit.
  if (is_const_.empty()) {
    return absl::InternalError("ConstantFoldingExtension called out of order.");
  }
  IsConst is_const = is_const_.back();
  is_const_.pop_back();

  if (is_const == IsConst::kNonConst) {
    // update parent
    if (!is_const_.empty()) {
      is_const_.back() = IsConst::kNonConst;
    }
    return absl::OkStatus();
  }
  ExecutionPathView subplan = context.GetSubplan(node);
  if (subplan.empty()) {
    // This subexpression is already optimized out or suppressed.
    return absl::OkStatus();
  }
  // copy string to managed handle if backed by the original program.
  Value value;
  if (node.has_const_expr()) {
    // Literals are converted directly; no evaluation is needed.
    CEL_ASSIGN_OR_RETURN(value,
                         ConvertConstant(node.const_expr(), state_.arena()));
  } else {
    ExecutionFrame frame(subplan, empty_, context.options(), state_);
    state_.Reset();
    // Update stack size to accommodate sub expression.
    // This only results in a vector resize if the new maxsize is greater than
    // the current capacity.
    state_.value_stack().SetMaxSize(subplan.size());

    auto result = frame.Evaluate();
    // If this would be a runtime error, then don't adjust the program plan, but
    // rather allow the error to occur at runtime to preserve the evaluation
    // contract with non-constant folding use cases.
    if (!result.ok()) {
      return absl::OkStatus();
    }
    value = *result;
    // NOTE(review): the template argument of `Is` appears to be missing in this
    // copy of the source. Presumably it names a value kind that must not be
    // folded; confirm against upstream.
    if (value->Is()) {
      return absl::OkStatus();
    }
  }

  // If recursive planning enabled (recursion limit unbounded or at least 1),
  // use a recursive (direct) step for the folded constant.
  //
  // Constant folding is applied leaf to root based on the program plan so far,
  // so the planner will have an opportunity to validate that the recursion
  // limit is being followed when visiting parent nodes in the AST.
  if (context.options().max_recursion_depth != 0) {
    return context.ReplaceSubplan(
        node, CreateConstValueDirectStep(std::move(value), node.id()), 1);
  }

  // Otherwise make a stack machine plan.
  ExecutionPath new_plan;
  CEL_ASSIGN_OR_RETURN(
      new_plan.emplace_back(),
      CreateConstValueStep(std::move(value), node.id(), false));

  return context.ReplaceSubplan(node, std::move(new_plan));
}

}  // namespace

// Each invocation of the returned factory creates a new
// ConstantFoldingExtension for the given planning context. The arena and
// message factory are selected per invocation, as described inside the lambda.
ProgramOptimizerFactory CreateConstantFoldingOptimizer(
    absl_nullable std::shared_ptr arena,
    absl_nullable std::shared_ptr message_factory) {
  return [shared_arena = std::move(arena),
          shared_message_factory = std::move(message_factory)](
             PlannerContext& context,
             const Ast&) -> absl::StatusOr> {
    // If one was explicitly provided during planning or none was explicitly
    // provided during configuration, request one from the planning context.
    // Otherwise use the one provided during configuration.
    google::protobuf::Arena* absl_nonnull arena =
        context.HasExplicitArena() || shared_arena == nullptr
            ? context.MutableArena()
            : shared_arena.get();

    google::protobuf::MessageFactory* absl_nonnull message_factory =
        context.HasExplicitMessageFactory() || shared_message_factory == nullptr
            ? context.MutableMessageFactory()
            : shared_message_factory.get();

    return std::make_unique(
        context.descriptor_pool(), shared_arena, arena, shared_message_factory,
        message_factory, context.type_reflector());
  };
}

}  // namespace cel::runtime_internal
================================================
FILE: eval/compiler/constant_folding.h
================================================
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CONSTANT_FOLDING_H_
#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CONSTANT_FOLDING_H_

#include

#include "absl/base/nullability.h"
#include "eval/compiler/flat_expr_builder_extensions.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/message.h"

namespace cel::runtime_internal {

// Create a new constant folding extension.
// Eagerly evaluates sub expressions with all constant inputs, and replaces said
// sub expression with the result.
//
// `arena` and `message_factory` are optional. The planning context's own arena
// or message factory is used instead when:
//   - the corresponding argument is null, or
//   - the planning context was given one explicitly.
//
// Note: the precomputed values may be allocated on the arena that is in use
// (the provided `arena` or the planning context's). That arena must outlive any
// programs created with this extension.
google::api::expr::runtime::ProgramOptimizerFactory
CreateConstantFoldingOptimizer(
    absl_nullable std::shared_ptr arena = nullptr,
    absl_nullable std::shared_ptr message_factory = nullptr);

}  // namespace cel::runtime_internal

#endif  // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CONSTANT_FOLDING_H_
================================================
FILE: eval/compiler/constant_folding_test.cc
================================================
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "eval/compiler/constant_folding.h"

#include
#include
#include

#include "cel/expr/syntax.pb.h"
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "base/ast.h"
#include "common/expr.h"
#include "common/value.h"
#include "eval/compiler/flat_expr_builder_extensions.h"
#include "eval/compiler/resolver.h"
#include "eval/eval/const_value_step.h"
#include "eval/eval/create_list_step.h"
#include "eval/eval/create_map_step.h"
#include "eval/eval/evaluator_core.h"
#include "extensions/protobuf/ast_converters.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "parser/parser.h"
#include "runtime/function_registry.h"
#include "runtime/internal/issue_collector.h"
#include "runtime/internal/runtime_env.h"
#include "runtime/internal/runtime_env_testing.h"
#include "runtime/runtime_issue.h"
#include "runtime/runtime_options.h"
#include "runtime/type_registry.h"
#include "google/protobuf/arena.h"

namespace cel::runtime_internal {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::Expr;
using ::cel::RuntimeIssue;
using ::cel::runtime_internal::IssueCollector;
using ::cel::runtime_internal::NewTestingRuntimeEnv;
using ::cel::expr::ParsedExpr;
using ::google::api::expr::parser::Parse;
using ::google::api::expr::runtime::CreateConstValueStep;
using ::google::api::expr::runtime::CreateCreateListStep;
using ::google::api::expr::runtime::CreateCreateStructStepForMap;
using ::google::api::expr::runtime::ExecutionPath;
using ::google::api::expr::runtime::PlannerContext;
using ::google::api::expr::runtime::ProgramBuilder;
using ::google::api::expr::runtime::ProgramOptimizer;
using ::google::api::expr::runtime::ProgramOptimizerFactory;
using ::google::api::expr::runtime::Resolver;
using ::testing::SizeIs;

// Fixture that owns a testing runtime environment together with the registries,
// options, issue collector and resolver needed to build a PlannerContext.
class UpdatedConstantFoldingTest : public testing::Test {
 public:
  UpdatedConstantFoldingTest()
      : env_(NewTestingRuntimeEnv()),
        function_registry_(env_->function_registry),
        type_registry_(env_->type_registry),
        issue_collector_(RuntimeIssue::Severity::kError),
        resolver_("", function_registry_, type_registry_,
                  type_registry_.GetComposedTypeProvider()) {}

 protected:
  // Declaration order matters: members are initialized in this order, and
  // resolver_ is built from function_registry_ and type_registry_.
  absl_nonnull std::shared_ptr env_;
  google::protobuf::Arena arena_;
  cel::FunctionRegistry& function_registry_;
  cel::TypeRegistry& type_registry_;
  cel::RuntimeOptions options_;
  IssueCollector issue_collector_;
  Resolver resolver_;
};

// Parses `expression` with the CEL parser and converts the result to an AST.
absl::StatusOr> ParseFromCel(
    absl::string_view expression) {
  CEL_ASSIGN_OR_RETURN(ParsedExpr expr, Parse(expression));
  return cel::extensions::CreateAstFromParsedExpr(expr);
}

// While CEL doesn't provide execution order guarantees per se, short circuiting
// operators are treated specially to evaluate to user expectations.
//
// These behaviors aren't easily observable since the flat expression doesn't
// expose any details about the program after building, so a lot of setup is
// needed to simulate what the expression builder does.
TEST_F(UpdatedConstantFoldingTest, SkipsTernary) {
  // Arrange
  ASSERT_OK_AND_ASSIGN(std::unique_ptr ast,
                       ParseFromCel("true ? true : false"));
  const Expr& call = ast->root_expr();
  const Expr& condition = call.call_expr().args()[0];
  const Expr& true_branch = call.call_expr().args()[1];
  const Expr& false_branch = call.call_expr().args()[2];

  ProgramBuilder program_builder;
  program_builder.EnterSubexpression(&call);
  // condition
  program_builder.EnterSubexpression(&condition);
  ASSERT_OK_AND_ASSIGN(auto step,
                       CreateConstValueStep(cel::BoolValue(true), -1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&condition);

  // true
  program_builder.EnterSubexpression(&true_branch);
  ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::BoolValue(true), -1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&true_branch);

  // false: mirrors the `false` literal of the parsed expression.
  program_builder.EnterSubexpression(&false_branch);
  ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::BoolValue(false), -1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&false_branch);

  // ternary.
  ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::NullValue(), -1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&call);

  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  ProgramOptimizerFactory constant_folder_factory =
      CreateConstantFoldingOptimizer();

  // Act
  // Issue the visitation calls.
  ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder,
                       constant_folder_factory(context, *ast));
  ASSERT_OK(constant_folder->OnPreVisit(context, call));
  ASSERT_OK(constant_folder->OnPreVisit(context, condition));
  ASSERT_OK(constant_folder->OnPostVisit(context, condition));
  ASSERT_OK(constant_folder->OnPreVisit(context, true_branch));
  ASSERT_OK(constant_folder->OnPostVisit(context, true_branch));
  ASSERT_OK(constant_folder->OnPreVisit(context, false_branch));
  ASSERT_OK(constant_folder->OnPostVisit(context, false_branch));
  ASSERT_OK(constant_folder->OnPostVisit(context, call));

  // Assert
  // The ternary is not folded, so the plan keeps one step per node.
  ExecutionPath path = std::move(program_builder).FlattenMain();
  EXPECT_THAT(path, SizeIs(4));
}

TEST_F(UpdatedConstantFoldingTest, SkipsOr) {
  // Arrange
  ASSERT_OK_AND_ASSIGN(std::unique_ptr ast,
                       ParseFromCel("false || true"));
  const Expr& call = ast->root_expr();
  const Expr& left_condition = call.call_expr().args()[0];
  const Expr& right_condition = call.call_expr().args()[1];

  ProgramBuilder program_builder;
  program_builder.EnterSubexpression(&call);

  // left
  program_builder.EnterSubexpression(&left_condition);
  ASSERT_OK_AND_ASSIGN(auto step,
                       CreateConstValueStep(cel::BoolValue(false), -1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&left_condition);

  // right
  program_builder.EnterSubexpression(&right_condition);
  ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::BoolValue(true), -1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&right_condition);

  // op
  // Just a placeholder.
  ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::NullValue(), -1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&call);

  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  ProgramOptimizerFactory constant_folder_factory =
      CreateConstantFoldingOptimizer();

  // Act
  // Issue the visitation calls.
  ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder,
                       constant_folder_factory(context, *ast));
  ASSERT_OK(constant_folder->OnPreVisit(context, call));
  ASSERT_OK(constant_folder->OnPreVisit(context, left_condition));
  ASSERT_OK(constant_folder->OnPostVisit(context, left_condition));
  ASSERT_OK(constant_folder->OnPreVisit(context, right_condition));
  ASSERT_OK(constant_folder->OnPostVisit(context, right_condition));
  ASSERT_OK(constant_folder->OnPostVisit(context, call));

  // Assert
  // No changes attempted.
  ExecutionPath path = std::move(program_builder).FlattenMain();
  EXPECT_THAT(path, SizeIs(3));
}

TEST_F(UpdatedConstantFoldingTest, SkipsAnd) {
  // Arrange
  ASSERT_OK_AND_ASSIGN(std::unique_ptr ast,
                       ParseFromCel("true && false"));
  const Expr& call = ast->root_expr();
  const Expr& left_condition = call.call_expr().args()[0];
  const Expr& right_condition = call.call_expr().args()[1];

  ProgramBuilder program_builder;
  program_builder.EnterSubexpression(&call);

  // left
  program_builder.EnterSubexpression(&left_condition);
  ASSERT_OK_AND_ASSIGN(auto step,
                       CreateConstValueStep(cel::BoolValue(true), -1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&left_condition);

  // right
  program_builder.EnterSubexpression(&right_condition);
  ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::BoolValue(false), -1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&right_condition);

  // op
  // Just a placeholder.
  ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::NullValue(), -1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&call);

  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  ProgramOptimizerFactory constant_folder_factory =
      CreateConstantFoldingOptimizer();

  // Act
  // Issue the visitation calls.
  ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder,
                       constant_folder_factory(context, *ast));
  ASSERT_OK(constant_folder->OnPreVisit(context, call));
  ASSERT_OK(constant_folder->OnPreVisit(context, left_condition));
  ASSERT_OK(constant_folder->OnPostVisit(context, left_condition));
  ASSERT_OK(constant_folder->OnPreVisit(context, right_condition));
  ASSERT_OK(constant_folder->OnPostVisit(context, right_condition));
  ASSERT_OK(constant_folder->OnPostVisit(context, call));

  // Assert
  // No changes attempted.
  ExecutionPath path = std::move(program_builder).FlattenMain();
  EXPECT_THAT(path, SizeIs(3));
}

TEST_F(UpdatedConstantFoldingTest, CreatesList) {
  // Arrange
  ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, ParseFromCel("[1, 2]"));
  const Expr& create_list = ast->root_expr();
  const Expr& elem_one = create_list.list_expr().elements()[0].expr();
  const Expr& elem_two = create_list.list_expr().elements()[1].expr();

  ProgramBuilder program_builder;
  // Simulate the visitor order.
  program_builder.EnterSubexpression(&create_list);

  // elem one
  program_builder.EnterSubexpression(&elem_one);
  ASSERT_OK_AND_ASSIGN(auto step, CreateConstValueStep(cel::IntValue(1L), 1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&elem_one);

  // elem two
  program_builder.EnterSubexpression(&elem_two);
  ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(2L), 2));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&elem_two);

  // createlist
  ASSERT_OK_AND_ASSIGN(step, CreateCreateListStep(create_list.list_expr(), 3));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&create_list);

  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  ProgramOptimizerFactory constant_folder_factory =
      CreateConstantFoldingOptimizer();

  // Act
  // Issue the visitation calls.
  ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder,
                       constant_folder_factory(context, *ast));
  ASSERT_OK(constant_folder->OnPreVisit(context, create_list));
  ASSERT_OK(constant_folder->OnPreVisit(context, elem_one));
  ASSERT_OK(constant_folder->OnPostVisit(context, elem_one));
  ASSERT_OK(constant_folder->OnPreVisit(context, elem_two));
  ASSERT_OK(constant_folder->OnPostVisit(context, elem_two));
  ASSERT_OK(constant_folder->OnPostVisit(context, create_list));

  // Assert
  // Single constant value for the two element list.
  ExecutionPath path = std::move(program_builder).FlattenMain();
  EXPECT_THAT(path, SizeIs(1));
}

TEST_F(UpdatedConstantFoldingTest, CreatesLargeList) {
  // Arrange
  ASSERT_OK_AND_ASSIGN(std::unique_ptr ast,
                       ParseFromCel("[1, 2, 3, 4, 5]"));
  const Expr& create_list = ast->root_expr();
  const Expr& elem0 = create_list.list_expr().elements()[0].expr();
  const Expr& elem1 = create_list.list_expr().elements()[1].expr();
  const Expr& elem2 = create_list.list_expr().elements()[2].expr();
  const Expr& elem3 = create_list.list_expr().elements()[3].expr();
  const Expr& elem4 = create_list.list_expr().elements()[4].expr();

  ProgramBuilder program_builder;
  // Simulate the visitor order.
  ASSERT_NE(program_builder.EnterSubexpression(&create_list), nullptr);

  // 0
  ASSERT_NE(program_builder.EnterSubexpression(&elem0), nullptr);
  ASSERT_OK_AND_ASSIGN(auto step, CreateConstValueStep(cel::IntValue(1L), 1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&elem0);

  // 1
  ASSERT_NE(program_builder.EnterSubexpression(&elem1), nullptr);
  ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(2L), 2));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&elem1);

  // 2
  ASSERT_NE(program_builder.EnterSubexpression(&elem2), nullptr);
  ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(3L), 3));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&elem2);

  // 3
  ASSERT_NE(program_builder.EnterSubexpression(&elem3), nullptr);
  ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(4L), 4));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&elem3);

  // 4
  ASSERT_NE(program_builder.EnterSubexpression(&elem4), nullptr);
  ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(5L), 5));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&elem4);

  // createlist
  ASSERT_OK_AND_ASSIGN(step, CreateCreateListStep(create_list.list_expr(), 6));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&create_list);

  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  ProgramOptimizerFactory constant_folder_factory =
      CreateConstantFoldingOptimizer();

  // Act
  // Issue the visitation calls.
  ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder,
                       constant_folder_factory(context, *ast));
  ASSERT_THAT(constant_folder->OnPreVisit(context, create_list), IsOk());
  ASSERT_THAT(constant_folder->OnPreVisit(context, elem0), IsOk());
  ASSERT_THAT(constant_folder->OnPostVisit(context, elem0), IsOk());
  ASSERT_THAT(constant_folder->OnPreVisit(context, elem1), IsOk());
  ASSERT_THAT(constant_folder->OnPostVisit(context, elem1), IsOk());
  ASSERT_THAT(constant_folder->OnPreVisit(context, elem2), IsOk());
  ASSERT_THAT(constant_folder->OnPostVisit(context, elem2), IsOk());
  ASSERT_THAT(constant_folder->OnPreVisit(context, elem3), IsOk());
  ASSERT_THAT(constant_folder->OnPostVisit(context, elem3), IsOk());
  ASSERT_THAT(constant_folder->OnPreVisit(context, elem4), IsOk());
  ASSERT_THAT(constant_folder->OnPostVisit(context, elem4), IsOk());
  ASSERT_THAT(constant_folder->OnPostVisit(context, create_list), IsOk());

  // Assert
  // Single constant value for the five element list.
  ExecutionPath path = std::move(program_builder).FlattenMain();
  EXPECT_THAT(path, SizeIs(1));
}

TEST_F(UpdatedConstantFoldingTest, CreatesMap) {
  // Arrange
  ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, ParseFromCel("{1: 2}"));
  const Expr& create_map = ast->root_expr();
  const Expr& key = create_map.map_expr().entries()[0].key();
  const Expr& value = create_map.map_expr().entries()[0].value();

  ProgramBuilder program_builder;
  program_builder.EnterSubexpression(&create_map);

  // key
  program_builder.EnterSubexpression(&key);
  ASSERT_OK_AND_ASSIGN(auto step, CreateConstValueStep(cel::IntValue(1L), 1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&key);

  // value
  program_builder.EnterSubexpression(&value);
  ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(2L), 2));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&value);

  // create map
  ASSERT_OK_AND_ASSIGN(
      step, CreateCreateStructStepForMap(create_map.map_expr().entries().size(),
                                         {}, 3));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&create_map);

  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  ProgramOptimizerFactory constant_folder_factory =
      CreateConstantFoldingOptimizer();

  // Act
  // Issue the visitation calls.
  ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder,
                       constant_folder_factory(context, *ast));
  ASSERT_OK(constant_folder->OnPreVisit(context, create_map));
  ASSERT_OK(constant_folder->OnPreVisit(context, key));
  ASSERT_OK(constant_folder->OnPostVisit(context, key));
  ASSERT_OK(constant_folder->OnPreVisit(context, value));
  ASSERT_OK(constant_folder->OnPostVisit(context, value));
  ASSERT_OK(constant_folder->OnPostVisit(context, create_map));

  // Assert
  // Single constant value for the map.
  ExecutionPath path = std::move(program_builder).FlattenMain();
  EXPECT_THAT(path, SizeIs(1));
}

TEST_F(UpdatedConstantFoldingTest, CreatesInvalidMap) {
  // Arrange
  ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, ParseFromCel("{1.0: 2}"));
  const Expr& create_map = ast->root_expr();
  const Expr& key = create_map.map_expr().entries()[0].key();
  const Expr& value = create_map.map_expr().entries()[0].value();

  ProgramBuilder program_builder;
  program_builder.EnterSubexpression(&create_map);

  // key
  program_builder.EnterSubexpression(&key);
  ASSERT_OK_AND_ASSIGN(auto step,
                       CreateConstValueStep(cel::DoubleValue(1.0), 1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&key);

  // value
  program_builder.EnterSubexpression(&value);
  ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::IntValue(2L), 2));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&value);

  // create map
  ASSERT_OK_AND_ASSIGN(
      step, CreateCreateStructStepForMap(create_map.map_expr().entries().size(),
                                         {}, 3));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&create_map);

  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  ProgramOptimizerFactory constant_folder_factory =
      CreateConstantFoldingOptimizer();

  // Act
  // Issue the visitation calls.
  ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder,
                       constant_folder_factory(context, *ast));
  ASSERT_OK(constant_folder->OnPreVisit(context, create_map));
  ASSERT_OK(constant_folder->OnPreVisit(context, key));
  ASSERT_OK(constant_folder->OnPostVisit(context, key));
  ASSERT_OK(constant_folder->OnPreVisit(context, value));
  ASSERT_OK(constant_folder->OnPostVisit(context, value));
  ASSERT_OK(constant_folder->OnPostVisit(context, create_map));

  // Assert
  // Even with an invalid (double) key, the map subplan is replaced by a single
  // step.
  ExecutionPath path = std::move(program_builder).FlattenMain();
  EXPECT_THAT(path, SizeIs(1));
}

// Calling OnPostVisit without a matching OnPreVisit must report an internal
// error rather than corrupt the folder's bookkeeping.
TEST_F(UpdatedConstantFoldingTest, ErrorsOnUnexpectedOrder) {
  // Arrange
  ASSERT_OK_AND_ASSIGN(std::unique_ptr ast,
                       ParseFromCel("true && false"));
  const Expr& call = ast->root_expr();
  const Expr& left_condition = call.call_expr().args()[0];
  const Expr& right_condition = call.call_expr().args()[1];

  ProgramBuilder program_builder;
  program_builder.EnterSubexpression(&call);

  // left
  program_builder.EnterSubexpression(&left_condition);
  ASSERT_OK_AND_ASSIGN(auto step,
                       CreateConstValueStep(cel::BoolValue(true), -1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&left_condition);

  // right
  program_builder.EnterSubexpression(&right_condition);
  ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::BoolValue(false), -1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&right_condition);

  // op
  // Just a placeholder.
  ASSERT_OK_AND_ASSIGN(step, CreateConstValueStep(cel::NullValue(), -1));
  program_builder.AddStep(std::move(step));
  program_builder.ExitSubexpression(&call);

  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  ProgramOptimizerFactory constant_folder_factory =
      CreateConstantFoldingOptimizer();

  // Act / Assert
  ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder,
                       constant_folder_factory(context, *ast));
  EXPECT_THAT(constant_folder->OnPostVisit(context, left_condition),
              StatusIs(absl::StatusCode::kInternal));
}

}  // namespace
}  // namespace cel::runtime_internal
================================================
FILE: eval/compiler/flat_expr_builder.cc
================================================
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ #include "eval/compiler/flat_expr_builder.h" #include #include #include #include #include #include #include #include #include #include #include #include #include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/base/optimization.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_map.h" #include "absl/functional/any_invocable.h" #include "absl/log/absl_check.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "absl/types/variant.h" #include "base/ast.h" #include "base/builtins.h" #include "base/type_provider.h" #include "common/allocator.h" #include "common/ast.h" #include "common/ast_traverse.h" #include "common/ast_visitor.h" #include "common/constant.h" #include "common/expr.h" #include "common/kind.h" #include "common/type.h" #include "common/value.h" #include "eval/compiler/check_ast_extensions.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" #include "eval/eval/comprehension_step.h" #include "eval/eval/const_value_step.h" #include "eval/eval/container_access_step.h" #include "eval/eval/create_list_step.h" #include "eval/eval/create_map_step.h" #include "eval/eval/create_struct_step.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/equality_steps.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/function_step.h" #include "eval/eval/ident_step.h" #include "eval/eval/jump_step.h" #include "eval/eval/lazy_init_step.h" #include "eval/eval/logic_step.h" #include "eval/eval/optional_or_step.h" #include "eval/eval/select_step.h" #include "eval/eval/shadowable_value_step.h" #include "eval/eval/ternary_step.h" 
#include "eval/eval/trace_step.h"
#include "internal/status_macros.h"
#include "runtime/internal/convert_constant.h"
#include "runtime/internal/issue_collector.h"
#include "runtime/runtime_issue.h"
#include "runtime/runtime_options.h"
#include "runtime/type_registry.h"
#include "google/protobuf/arena.h"

namespace google::api::expr::runtime {

namespace {

using ::cel::Ast;
using ::cel::AstTraverse;
using ::cel::RuntimeIssue;
using ::cel::StringValue;
using ::cel::Value;
using ::cel::runtime_internal::ConvertConstant;
using ::cel::runtime_internal::GetLegacyRuntimeTypeProvider;
using ::cel::runtime_internal::GetRuntimeTypeProvider;
using ::cel::runtime_internal::IssueCollector;

constexpr absl::string_view kOptionalOrFn = "or";
constexpr absl::string_view kOptionalOrValueFn = "orValue";
constexpr absl::string_view kBlock = "cel.@block";

// Forward declare to resolve circular dependency for short_circuiting visitors.
class FlatExprVisitor;

// Builds the status reported when recursive program planning cannot be
// completed. In practice this usually means one of the configured
// optimizations does not support recursive programs.
absl::Status FailedRecursivePlanning() {
  constexpr absl::string_view kMessage =
      "failed to build recursive program. check for unsupported optimizations";
  return absl::InternalError(kMessage);
}

// Bookkeeping for variables that are mapped to slot indexes.
//
// Slots are handed out from a moving cursor: ReserveSlots advances it,
// ReleaseSlots moves it back, and the largest cursor value ever reached is
// remembered and reported by max_slot_count().
class IndexManager {
 public:
  IndexManager() = default;

  // Reserves `n` consecutive slots and returns the index of the first one.
  size_t ReserveSlots(size_t n) {
    const size_t first_slot = next_free_slot_;
    next_free_slot_ += n;
    max_slot_count_ =
        next_free_slot_ > max_slot_count_ ? next_free_slot_ : max_slot_count_;
    return first_slot;
  }

  // Moves the cursor back by `n` slots and returns the new cursor position.
  // `n` is not validated: releasing more slots than are currently reserved
  // wraps the (unsigned) cursor around.
  size_t ReleaseSlots(size_t n) {
    next_free_slot_ -= n;
    return next_free_slot_;
  }

  // High-water mark: the largest number of slots reserved at the same time.
  size_t max_slot_count() const { return max_slot_count_; }

 private:
  size_t next_free_slot_ = 0;
  size_t max_slot_count_ = 0;
};

// Helper for computing jump offsets.
//
// Jumps should be self-contained to a single expression node -- jumping
// outside that range is a bug.
struct ProgramStepIndex { int index; ProgramBuilder::Subexpression* subexpression; }; // A convenience wrapper for offset-calculating logic. class Jump { public: // Default constructor for empty jump. // // Users must check that jump is non-empty before calling member functions. explicit Jump() : self_index_{-1, nullptr}, jump_step_(nullptr) {} Jump(ProgramStepIndex self_index, JumpStepBase* jump_step) : self_index_(self_index), jump_step_(jump_step) {} static absl::StatusOr CalculateOffset(ProgramStepIndex base, ProgramStepIndex target) { if (target.subexpression != base.subexpression) { return absl::InternalError( "Jump target must be contained in the parent" "subexpression"); } int offset = base.subexpression->CalculateOffset(base.index, target.index); return offset; } absl::Status set_target(ProgramStepIndex target) { CEL_ASSIGN_OR_RETURN(int offset, CalculateOffset(self_index_, target)); jump_step_->set_jump_offset(offset); return absl::OkStatus(); } bool exists() { return jump_step_ != nullptr; } private: ProgramStepIndex self_index_; JumpStepBase* jump_step_; }; class CondVisitor { public: virtual ~CondVisitor() = default; virtual void PreVisit(const cel::Expr* expr) = 0; virtual void PostVisitArg(int arg_num, const cel::Expr* expr) = 0; virtual void PostVisit(const cel::Expr* expr) = 0; virtual void PostVisitTarget(const cel::Expr* expr) {} }; enum class BinaryCond { kAnd = 0, kOr, kOptionalOr, kOptionalOrValue, }; // Visitor managing the "&&" and "||" operatiions. // Implements short-circuiting if enabled. 
// // With short-circuiting enabled, generates a program like: // +-------------+------------------------+-----------------------+ // | PC | Step | Stack | // +-------------+------------------------+-----------------------+ // | i + 0 | | arg1 | // | i + 1 | ConditionalJump i + 4 | arg1 | // | i + 2 | | arg1, arg2 | // | i + 3 | BooleanOperator | Op(arg1, arg2) | // | i + 4 | | arg1 | Op(arg1, arg2) | // +-------------+------------------------+------------------------+ class BinaryCondVisitor : public CondVisitor { public: explicit BinaryCondVisitor(FlatExprVisitor* visitor, BinaryCond cond, bool short_circuiting) : visitor_(visitor), cond_(cond), short_circuiting_(short_circuiting) {} void PreVisit(const cel::Expr* expr) override; void PostVisitArg(int arg_num, const cel::Expr* expr) override; void PostVisit(const cel::Expr* expr) override; void PostVisitTarget(const cel::Expr* expr) override; private: FlatExprVisitor* visitor_; const BinaryCond cond_; Jump jump_step_; bool short_circuiting_; }; class TernaryCondVisitor : public CondVisitor { public: explicit TernaryCondVisitor(FlatExprVisitor* visitor) : visitor_(visitor) {} void PreVisit(const cel::Expr* expr) override; void PostVisitArg(int arg_num, const cel::Expr* expr) override; void PostVisit(const cel::Expr* expr) override; private: FlatExprVisitor* visitor_; Jump jump_to_second_; Jump error_jump_; Jump jump_after_first_; }; class ExhaustiveTernaryCondVisitor : public CondVisitor { public: explicit ExhaustiveTernaryCondVisitor(FlatExprVisitor* visitor) : visitor_(visitor) {} void PreVisit(const cel::Expr* expr) override; void PostVisitArg(int arg_num, const cel::Expr* expr) override {} void PostVisit(const cel::Expr* expr) override; private: FlatExprVisitor* visitor_; }; // Returns a hint for the number of program nodes (steps or subexpressions) that // will be created for this expr. 
size_t SizeHint(const cel::Expr& expr) { switch (expr.kind_case()) { case cel::ExprKindCase::kConstant: return 1; case cel::ExprKindCase::kIdentExpr: return 1; case cel::ExprKindCase::kSelectExpr: return 2; case cel::ExprKindCase::kCallExpr: return expr.call_expr().args().size() + (expr.call_expr().has_target() ? 2 : 1); case cel::ExprKindCase::kListExpr: return expr.list_expr().elements().size() + 1; case cel::ExprKindCase::kStructExpr: return expr.struct_expr().fields().size() + 1; case cel::ExprKindCase::kMapExpr: return 2 * expr.struct_expr().fields().size() + 1; default: return 1; } return 0; } // Returns whether this comprehension appears to be a standard map/filter // macro implementation. It is not exhaustive, so it is unsafe to use with // custom comprehensions outside of the standard macros or hand crafted ASTs. bool IsOptimizableListAppend(const cel::ComprehensionExpr* comprehension, bool enable_comprehension_list_append) { if (!enable_comprehension_list_append) { return false; } absl::string_view accu_var = comprehension->accu_var(); if (accu_var.empty() || comprehension->result().ident_expr().name() != accu_var) { return false; } if (!comprehension->accu_init().has_list_expr() || !comprehension->accu_init().list_expr().elements().empty()) { return false; } if (!comprehension->loop_step().has_call_expr()) { return false; } // Macro loop_step for a filter() will contain a ternary: // filter ? 
accu_var + [elem] : accu_var // Macro loop_step for a map() will contain a list concat operation: // accu_var + [elem] const auto* call_expr = &comprehension->loop_step().call_expr(); if (call_expr->function() == cel::builtin::kTernary && call_expr->args().size() == 3) { if (!call_expr->args()[1].has_call_expr()) { return false; } call_expr = &(call_expr->args()[1].call_expr()); } return call_expr->function() == cel::builtin::kAdd && call_expr->args().size() == 2 && call_expr->args()[0].has_ident_expr() && call_expr->args()[0].ident_expr().name() == accu_var && call_expr->args()[1].has_list_expr() && call_expr->args()[1].list_expr().elements().size() == 1; } // Assuming `IsOptimizableListAppend()` return true, return a pointer to the // call `accu_var + [elem]`. const cel::CallExpr* GetOptimizableListAppendCall( const cel::ComprehensionExpr* comprehension) { ABSL_DCHECK(IsOptimizableListAppend( comprehension, /*enable_comprehension_list_append=*/true)); // Macro loop_step for a filter() will contain a ternary: // filter ? accu_var + [elem] : accu_var // Macro loop_step for a map() will contain a list concat operation: // accu_var + [elem] const auto* call_expr = &comprehension->loop_step().call_expr(); if (call_expr->function() == cel::builtin::kTernary && call_expr->args().size() == 3) { call_expr = &(call_expr->args()[1].call_expr()); } return call_expr; } // Assuming `IsOptimizableListAppend()` return true, return a pointer to the // node `[elem]`. const cel::Expr* GetOptimizableListAppendOperand( const cel::ComprehensionExpr* comprehension) { return &GetOptimizableListAppendCall(comprehension)->args()[1]; } // Returns whether this comprehension appears to be a macro implementation for // map transformations. It is not exhaustive, so it is unsafe to use with custom // comprehensions outside of the standard macros or hand crafted ASTs. 
bool IsOptimizableMapInsert(const cel::ComprehensionExpr* comprehension,
                            bool enable_comprehension_mutable_map) {
  if (!enable_comprehension_mutable_map) return false;

  // Only two-variable comprehensions are candidates.
  const bool has_both_iter_vars =
      !comprehension->iter_var().empty() && !comprehension->iter_var2().empty();
  if (!has_both_iter_vars) return false;

  const absl::string_view accu_var = comprehension->accu_var();
  if (accu_var.empty()) return false;

  // The comprehension must yield the accumulator itself.
  const bool result_is_accu =
      comprehension->has_result() &&
      comprehension->result().has_ident_expr() &&
      comprehension->result().ident_expr().name() == accu_var;
  if (!result_is_accu) return false;

  if (!comprehension->accu_init().has_map_expr() ||
      !comprehension->loop_step().has_call_expr()) {
    return false;
  }

  // A filtering loop step wraps the insert in a ternary; unwrap its first
  // branch.
  const cel::CallExpr* insert_call = &comprehension->loop_step().call_expr();
  const bool is_ternary = insert_call->function() == cel::builtin::kTernary &&
                          insert_call->args().size() == 3;
  if (is_ternary) {
    const cel::Expr& then_branch = insert_call->args()[1];
    if (!then_branch.has_call_expr()) return false;
    insert_call = &then_branch.call_expr();
  }

  if (insert_call->function() != "cel.@mapInsert") return false;
  const size_t arg_count = insert_call->args().size();
  if (arg_count != 2 && arg_count != 3) return false;

  // The map being inserted into must be the accumulator.
  const cel::Expr& destination = insert_call->args()[0];
  return destination.has_ident_expr() &&
         destination.ident_expr().name() == accu_var;
}

// Recognizes comprehensions shaped like a variable binding (presumably the
// expansion of cel.bind): a constant `false` loop condition, the "#unused"
// iteration variable, no second iteration variable, and an empty list literal
// as the iteration range.
bool IsBind(const cel::ComprehensionExpr* comprehension) {
  static constexpr absl::string_view kUnusedIterVar = "#unused";

  const auto& condition_constant =
      comprehension->loop_condition().const_expr();
  if (!condition_constant.has_bool_value() ||
      condition_constant.bool_value()) {
    return false;
  }
  if (comprehension->iter_var() != kUnusedIterVar ||
      !comprehension->iter_var2().empty()) {
    return false;
  }
  const auto& range = comprehension->iter_range();
  return range.has_list_expr() && range.list_expr().elements().empty();
}

// True when `call` invokes the cel.@block function.
bool IsBlock(const cel::CallExpr* call) {
  const auto& function_name = call->function();
  return function_name == kBlock;
}

// Visitor for Comprehension expressions.
class ComprehensionVisitor { public: explicit ComprehensionVisitor(FlatExprVisitor* visitor, bool short_circuiting, bool is_trivial, size_t iter_slot, size_t iter2_slot, size_t accu_slot) : visitor_(visitor), next_step_(nullptr), cond_step_(nullptr), short_circuiting_(short_circuiting), is_trivial_(is_trivial), accu_init_extracted_(false), iter_slot_(iter_slot), iter2_slot_(iter2_slot), accu_slot_(accu_slot) {} void PreVisit(const cel::Expr* expr); absl::Status PostVisitArg(cel::ComprehensionArg arg_num, const cel::Expr* comprehension_expr) { if (is_trivial_) { PostVisitArgTrivial(arg_num, comprehension_expr); return absl::OkStatus(); } else { return PostVisitArgDefault(arg_num, comprehension_expr); } } void PostVisit(const cel::Expr* expr); void MarkAccuInitExtracted() { accu_init_extracted_ = true; } private: void PostVisitArgTrivial(cel::ComprehensionArg arg_num, const cel::Expr* comprehension_expr); absl::Status PostVisitArgDefault(cel::ComprehensionArg arg_num, const cel::Expr* comprehension_expr); FlatExprVisitor* visitor_; ComprehensionInitStep* init_step_; ComprehensionNextStep* next_step_; ComprehensionCondStep* cond_step_; ProgramStepIndex init_step_pos_; ProgramStepIndex next_step_pos_; ProgramStepIndex cond_step_pos_; bool short_circuiting_; bool is_trivial_; bool accu_init_extracted_; size_t iter_slot_; size_t iter2_slot_; size_t accu_slot_; }; absl::flat_hash_set MakeOptionalIndicesSet( const cel::ListExpr& create_list_expr) { absl::flat_hash_set optional_indices; for (size_t i = 0; i < create_list_expr.elements().size(); ++i) { if (create_list_expr.elements()[i].optional()) { optional_indices.insert(static_cast(i)); } } return optional_indices; } absl::flat_hash_set MakeOptionalIndicesSet( const cel::StructExpr& create_struct_expr) { absl::flat_hash_set optional_indices; for (size_t i = 0; i < create_struct_expr.fields().size(); ++i) { if (create_struct_expr.fields()[i].optional()) { optional_indices.insert(static_cast(i)); } } return 
optional_indices; } absl::flat_hash_set MakeOptionalIndicesSet( const cel::MapExpr& map_expr) { absl::flat_hash_set optional_indices; for (size_t i = 0; i < map_expr.entries().size(); ++i) { if (map_expr.entries()[i].optional()) { optional_indices.insert(static_cast(i)); } } return optional_indices; } class FlatExprVisitor : public cel::AstVisitor { public: enum class CallHandlerResult { // The call was intercepted, no additional processing is needed. kIntercepted, // The call was not intercepted, continue with the default processing. kNotIntercepted, }; // Handler for functions with builtin implementations. // This is used to replace the usual dispatcher step that applies // the arguments to a candidate function from the function registry. using CallHandler = absl::AnyInvocable; FlatExprVisitor( const Resolver& resolver, const cel::RuntimeOptions& options, std::vector> program_optimizers, const absl::flat_hash_map& reference_map, const cel::TypeProvider& type_provider, IssueCollector& issue_collector, ProgramBuilder& program_builder, PlannerContext& extension_context, bool enable_optional_types) : resolver_(resolver), type_provider_(type_provider), progress_status_(absl::OkStatus()), resolved_select_expr_(nullptr), options_(options), program_optimizers_(std::move(program_optimizers)), issue_collector_(issue_collector), program_builder_(program_builder), extension_context_(extension_context), enable_optional_types_(enable_optional_types) { constexpr size_t kCallHandlerSizeHint = 11; call_handlers_.reserve(kCallHandlerSizeHint); call_handlers_[cel::builtin::kIndex] = [this](const cel::Expr& expr, const cel::CallExpr& call) { return HandleIndex(expr, call); }; call_handlers_[kBlock] = [this](const cel::Expr& expr, const cel::CallExpr& call) { return HandleBlock(expr, call); }; call_handlers_[cel::builtin::kAdd] = [this](const cel::Expr& expr, const cel::CallExpr& call) { return HandleListAppend(expr, call); }; if (options_.enable_fast_builtins) { 
call_handlers_[cel::builtin::kNotStrictlyFalse] = [this](const cel::Expr& expr, const cel::CallExpr& call) { return HandleNotStrictlyFalse(expr, call); }; call_handlers_[cel::builtin::kNotStrictlyFalseDeprecated] = [this](const cel::Expr& expr, const cel::CallExpr& call) { return HandleNotStrictlyFalse(expr, call); }; call_handlers_[cel::builtin::kNot] = [this](const cel::Expr& expr, const cel::CallExpr& call) { return HandleNot(expr, call); }; if (options_.enable_heterogeneous_equality) { for (const auto& in_op : {cel::builtin::kIn, cel::builtin::kInDeprecated, cel::builtin::kInFunction}) { call_handlers_[in_op] = [this](const cel::Expr& expr, const cel::CallExpr& call) { return HandleHeterogeneousEqualityIn(expr, call); }; } // Try to detect if the environment is setup with a custom equality // implementation. if (resolver_ .FindOverloads(cel::builtin::kEqual, /*receiver_style=*/false, {cel::Kind::kAny, cel::Kind::kAny}) .empty()) { call_handlers_[cel::builtin::kEqual] = [this](const cel::Expr& expr, const cel::CallExpr& call) { return HandleHeterogeneousEquality(expr, call, /*inequality=*/false); }; call_handlers_[cel::builtin::kInequal] = [this](const cel::Expr& expr, const cel::CallExpr& call) { return HandleHeterogeneousEquality(expr, call, /*inequality=*/true); }; } } } } void SetMaxRecursionDepth(int max_recursion_depth) { max_recursion_depth_ = max_recursion_depth; } bool PlanRecursiveProgram() const { return max_recursion_depth_ > 0; } void PreVisitExpr(const cel::Expr& expr) override { ValidateOrError(!absl::holds_alternative(expr.kind()), "Invalid empty expression"); if (!progress_status_.ok()) { return; } if (resume_from_suppressed_branch_ == nullptr && suppressed_branches_.find(&expr) != suppressed_branches_.end()) { resume_from_suppressed_branch_ = &expr; } if (block_.has_value()) { BlockInfo& block = *block_; if (block.in && block.bindings_set.contains(&expr)) { block.current_binding = &expr; } } auto* subexpression = 
program_builder_.EnterSubexpression(&expr, SizeHint(expr)); if (subexpression == nullptr) { progress_status_.Update( absl::InternalError("same CEL expr visited twice")); return; } for (const std::unique_ptr& optimizer : program_optimizers_) { absl::Status status = optimizer->OnPreVisit(extension_context_, expr); if (!status.ok()) { SetProgressStatusError(status); } } } void PostVisitExpr(const cel::Expr& expr) override { if (!progress_status_.ok()) { return; } if (&expr == resume_from_suppressed_branch_) { resume_from_suppressed_branch_ = nullptr; } for (const std::unique_ptr& optimizer : program_optimizers_) { absl::Status status = optimizer->OnPostVisit(extension_context_, expr); if (!status.ok()) { SetProgressStatusError(status); return; } } auto* subexpression = program_builder_.current(); if (subexpression != nullptr && options_.enable_recursive_tracing && subexpression->IsRecursive()) { auto program = subexpression->ExtractRecursiveProgram(); subexpression->set_recursive_program( std::make_unique(std::move(program.step)), program.depth); } program_builder_.ExitSubexpression(&expr); if (!comprehension_stack_.empty() && comprehension_stack_.back().is_optimizable_bind && (&comprehension_stack_.back().comprehension->accu_init() == &expr)) { SetProgressStatusError( MaybeExtractSubexpression(&expr, comprehension_stack_.back())); } if (block_.has_value()) { BlockInfo& block = *block_; if (block.current_binding == &expr) { int index = program_builder_.ExtractSubexpression(&expr); if (index == -1) { SetProgressStatusError( absl::InvalidArgumentError("failed to extract subexpression")); return; } block.subexpressions[block.current_index++] = index; block.current_binding = nullptr; } } } void PostVisitConst(const cel::Expr& expr, const cel::Constant& const_expr) override { if (!progress_status_.ok()) { return; } absl::StatusOr converted_value = ConvertConstant(const_expr, cel::NewDeleteAllocator()); if (!converted_value.ok()) { 
SetProgressStatusError(converted_value.status()); return; } if (options_.max_recursion_depth > 0 || options_.max_recursion_depth < 0) { SetRecursiveStep(CreateConstValueDirectStep( std::move(converted_value).value(), expr.id()), 1); return; } AddStep( CreateConstValueStep(std::move(converted_value).value(), expr.id())); } struct SlotLookupResult { int slot; int subexpression; }; // Helper to lookup a variable mapped to a slot. // // If lazy evaluation enabled and ided as a lazy expression, // subexpression and slot will be set. SlotLookupResult LookupSlot(absl::string_view path) { // If there's a leading dot, it cannot resolve to a local variable. if (absl::StartsWith(path, ".")) { return {-1, -1}; } if (block_.has_value()) { const BlockInfo& block = *block_; if (block.in) { absl::string_view index_suffix = path; if (absl::ConsumePrefix(&index_suffix, "@index")) { size_t index; if (!absl::SimpleAtoi(index_suffix, &index)) { SetProgressStatusError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError("bad @index")))); return {-1, -1}; } if (index >= block.size) { SetProgressStatusError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError(absl::StrCat( "invalid @index greater than number of bindings: ", index, " >= ", block.size))))); return {-1, -1}; } if (index >= block.current_index) { SetProgressStatusError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError(absl::StrCat( "@index references current or future binding: ", index, " >= ", block.current_index))))); return {-1, -1}; } return {static_cast(block.index + index), block.subexpressions[index]}; } } } if (!comprehension_stack_.empty()) { for (int i = comprehension_stack_.size() - 1; i >= 0; i--) { const ComprehensionStackRecord& record = comprehension_stack_[i]; if (record.iter_var_in_scope && record.comprehension->iter_var() == path) { if (record.is_optimizable_bind) { SetProgressStatusError(issue_collector_.AddIssue( 
RuntimeIssue::CreateWarning(absl::InvalidArgumentError( "Unexpected iter_var access in trivial comprehension")))); return {-1, -1}; } return {static_cast(record.iter_slot), -1}; } if (record.iter_var2_in_scope && record.comprehension->iter_var2() == path) { return {static_cast(record.iter2_slot), -1}; } if (record.accu_var_in_scope && record.comprehension->accu_var() == path) { int slot = record.accu_slot; int subexpression = -1; if (record.is_optimizable_bind) { subexpression = record.subexpression; } return {slot, subexpression}; } } } if (absl::StartsWith(path, "@it:") || absl::StartsWith(path, "@it2:") || absl::StartsWith(path, "@ac:")) { // If we see a CSE generated comprehension variable that was not // resolvable through the normal comprehension scope resolution, reject it // now rather than surfacing errors at activation time. SetProgressStatusError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError("out of scope reference to CSE " "generated comprehension variable")))); } return {-1, -1}; } // Ident node handler. // Invoked after child nodes are processed. void PostVisitIdent(const cel::Expr& expr, const cel::IdentExpr& ident_expr) override { if (!progress_status_.ok()) { return; } absl::string_view path = ident_expr.name(); if (!ValidateOrError( !path.empty(), "Invalid expression: identifier 'name' must not be empty")) { return; } // Check if this is a local variable first (since it should shadow most // other interpretations). 
SlotLookupResult slot = LookupSlot(path); if (slot.subexpression >= 0) { auto* subexpression = program_builder_.GetExtractedSubexpression(slot.subexpression); if (subexpression == nullptr) { SetProgressStatusError( absl::InternalError("bad subexpression reference")); return; } if (subexpression->IsRecursive()) { const auto& program = subexpression->recursive_program(); SetRecursiveStep( CreateDirectLazyInitStep(slot.slot, program.step.get(), expr.id()), program.depth + 1); } else { // Off by one since mainline expression will be index 0. AddStep( CreateLazyInitStep(slot.slot, slot.subexpression + 1, expr.id())); } return; } else if (slot.slot >= 0) { if (options_.max_recursion_depth != 0) { SetRecursiveStep( CreateDirectSlotIdentStep(ident_expr.name(), slot.slot, expr.id()), 1); } else { AddStep( CreateIdentStepForSlot(ident_expr.name(), slot.slot, expr.id())); } return; } // Attempt to resolve a select expression as a namespaced identifier for an // enum or type constant value. absl::optional const_value; int64_t select_root_id = -1; std::string path_candidate; while (!namespace_stack_.empty()) { const auto& select_node = namespace_stack_.front(); // Generate path in format ".....". const cel::Expr* select_expr = select_node.first; path_candidate = absl::StrCat(path, ".", select_node.second); // Attempt to find a constant enum or type value which matches the // qualified path present in the expression. Whether the identifier // can be resolved to a type instance depends on whether the option to // 'enable_qualified_type_identifiers' is set to true. const_value = resolver_.FindConstant(path_candidate, select_expr->id()); if (const_value) { resolved_select_expr_ = select_expr; select_root_id = select_expr->id(); path = path_candidate; namespace_stack_.clear(); break; } namespace_stack_.pop_front(); } if (!const_value) { // Attempt to resolve a simple identifier as an enum or type constant // value. 
const_value = resolver_.FindConstant(path, expr.id()); select_root_id = expr.id(); } // TODO(issues/97): Need to add support for resolving packaged names at // runtime if Parse-only. For checked, checker should have reported the // expected interpretation. if (const_value) { // If the path starts with a dot, strip it. absl::string_view name = absl::StripPrefix(path, "."); if (options_.max_recursion_depth != 0) { SetRecursiveStep( CreateDirectShadowableValueStep( name, std::move(const_value).value(), select_root_id), 1); return; } AddStep(CreateShadowableValueStep(name, std::move(const_value).value(), select_root_id)); return; } absl::string_view ident_name = absl::StripPrefix(ident_expr.name(), "."); if (options_.max_recursion_depth != 0) { SetRecursiveStep(CreateDirectIdentStep(ident_name, expr.id()), 1); } else { AddStep(CreateIdentStep(ident_name, expr.id())); } } void PreVisitSelect(const cel::Expr& expr, const cel::SelectExpr& select_expr) override { if (!progress_status_.ok()) { return; } if (!ValidateOrError( !select_expr.field().empty(), "invalid expression: select 'field' must not be empty")) { return; } if (!ValidateOrError( select_expr.has_operand() && select_expr.operand().kind_case() != cel::ExprKindCase::kUnspecifiedExpr, "invalid expression: select must specify an operand")) { return; } // Not exactly the cleanest solution - we peek into child of // select_expr. // Chain of multiple SELECT ending with IDENT can represent namespaced // entity. 
if (!select_expr.test_only() && (select_expr.operand().has_ident_expr() || select_expr.operand().has_select_expr())) { // select expressions are pushed in reverse order: // google.type.Expr is pushed as: // - field: 'Expr' // - field: 'type' // - id: 'google' // // The search order though is as follows: // - id: 'google.type.Expr' // - id: 'google.type', field: 'Expr' // - id: 'google', field: 'type', field: 'Expr' for (size_t i = 0; i < namespace_stack_.size(); i++) { auto ns = namespace_stack_[i]; namespace_stack_[i] = { ns.first, absl::StrCat(select_expr.field(), ".", ns.second)}; } namespace_stack_.push_back({&expr, select_expr.field()}); } else { namespace_stack_.clear(); } } // Select node handler. // Invoked after child nodes are processed. void PostVisitSelect(const cel::Expr& expr, const cel::SelectExpr& select_expr) override { if (!progress_status_.ok()) { return; } // Check if we are "in the middle" of namespaced name. // This is currently enum specific. Constant expression that corresponds // to resolved enum value has been already created, thus preceding chain // of selects is no longer relevant. if (resolved_select_expr_) { if (&expr == resolved_select_expr_) { resolved_select_expr_ = nullptr; } return; } if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 1) { SetProgressStatusError(absl::InternalError( "unexpected number of dependencies for select operation.")); return; } StringValue field = cel::StringValue(select_expr.field()); SetRecursiveStep( CreateDirectSelectStep(std::move(deps[0]), std::move(field), select_expr.test_only(), expr.id(), options_.enable_empty_wrapper_null_unboxing, enable_optional_types_), *depth + 1); return; } AddStep(CreateSelectStep(select_expr, expr.id(), options_.enable_empty_wrapper_null_unboxing, enable_optional_types_)); } // Call node handler group. 
// We provide finer granularity for Call node callbacks to allow special
// handling for short-circuiting
// PreVisitCall is invoked before child nodes are processed.
//
// Conditional visitor cases:
// - `&&`, `||` and the ternary operator.
// - When optional types are enabled, a receiver-style `or`/`orValue` call
//   with exactly one argument.
// In these cases a conditional visitor is created and PreVisit() is called on
// it. It is then pushed onto `cond_visitor_stack_` paired with `expr`, so the
// arg/target/call post-visit handlers can find it.
//
// For `cel.@block`, the call shape is validated and `block_` is initialized,
// which includes reserving one slot per binding. Every other call is ignored
// here.
void PreVisitCall(const cel::Expr& expr,
                  const cel::CallExpr& call_expr) override {
  if (!progress_status_.ok()) {
    return;
  }

  std::unique_ptr cond_visitor;
  if (call_expr.function() == cel::builtin::kAnd) {
    cond_visitor = std::make_unique(
        this, BinaryCond::kAnd, options_.short_circuiting);
  } else if (call_expr.function() == cel::builtin::kOr) {
    cond_visitor = std::make_unique(
        this, BinaryCond::kOr, options_.short_circuiting);
  } else if (call_expr.function() == cel::builtin::kTernary) {
    // NOTE(review): as shown here both branches read `std::make_unique(this)`.
    // The template argument that distinguishes the short-circuiting visitor
    // from the exhaustive one appears to be missing from this copy of the
    // file -- confirm against the source.
    if (options_.short_circuiting) {
      cond_visitor = std::make_unique(this);
    } else {
      cond_visitor = std::make_unique(this);
    }
  } else if (enable_optional_types_ &&
             call_expr.function() == kOptionalOrFn &&
             call_expr.has_target() && call_expr.args().size() == 1) {
    cond_visitor = std::make_unique(
        this, BinaryCond::kOptionalOr, options_.short_circuiting);
  } else if (enable_optional_types_ &&
             call_expr.function() == kOptionalOrValueFn &&
             call_expr.has_target() && call_expr.args().size() == 1) {
    cond_visitor = std::make_unique(
        this, BinaryCond::kOptionalOrValue, options_.short_circuiting);
  } else if (IsBlock(&call_expr)) {
    // cel.@block
    if (block_.has_value()) {
      // There can only be one for now.
      SetProgressStatusError(
          absl::InvalidArgumentError("multiple cel.@block are not allowed"));
      return;
    }
    // Note: `block_` is created and marked as entered before the argument
    // checks below. On a malformed block only the error status is recorded;
    // the partially filled BlockInfo stays in place.
    block_ = BlockInfo();
    BlockInfo& block = *block_;
    block.in = true;
    if (call_expr.args().empty()) {
      SetProgressStatusError(absl::InvalidArgumentError(
          "malformed cel.@block: missing list of bound expressions"));
      return;
    }
    // Also reached when more than two arguments are given.
    if (call_expr.args().size() != 2) {
      SetProgressStatusError(absl::InvalidArgumentError(
          "malformed cel.@block: missing bound expression"));
      return;
    }
    if (!call_expr.args()[0].has_list_expr()) {
      SetProgressStatusError(
          absl::InvalidArgumentError("malformed cel.@block: first argument "
                                     "is not a list of bound expressions"));
      return;
    }
    const auto& list_expr = call_expr.args().front().list_expr();
    block.size = list_expr.elements().size();
    if (block.size == 0) {
      SetProgressStatusError(absl::InvalidArgumentError(
          "malformed cel.@block: list of bound expressions is empty"));
      return;
    }
    block.bindings_set.reserve(block.size);
    for (const auto& list_expr_element : list_expr.elements()) {
      if (list_expr_element.optional()) {
        SetProgressStatusError(
            absl::InvalidArgumentError("malformed cel.@block: list of bound "
                                       "expressions contains an optional"));
        return;
      }
      block.bindings_set.insert(&list_expr_element.expr());
    }
    // One slot per binding, starting at `block.index`.
    block.index = index_manager().ReserveSlots(block.size);
    block.slot_count = block.size;
    block.expr = &expr;
    block.bindings = &call_expr.args()[0];
    block.bound = &call_expr.args()[1];
    // -1 marks bindings whose subexpression has not been processed yet.
    block.subexpressions.resize(block.size, -1);
  } else {
    return;
  }

  if (cond_visitor) {
    cond_visitor->PreVisit(&expr);
    cond_visitor_stack_.push({&expr, std::move(cond_visitor)});
  }
}

// Returns the maximum recursion depth of the current program if it is
// eligible for recursion, or nullopt if it is not.
absl::optional RecursionEligible() { if (!PlanRecursiveProgram() || program_builder_.current() == nullptr) { return absl::nullopt; } return program_builder_.current()->RecursiveDependencyDepth(); } std::vector> ExtractRecursiveDependencies() { // Must check recursion eligibility before calling. ABSL_DCHECK(program_builder_.current() != nullptr); return program_builder_.current()->ExtractRecursiveDependencies(); } void MakeTernaryRecursive(const cel::Expr* expr) { if (expr->call_expr().args().size() != 3) { SetProgressStatusError(absl::InvalidArgumentError( "unexpected number of args for builtin ternary")); return; } const cel::Expr* condition_expr = &expr->call_expr().args()[0]; const cel::Expr* left_expr = &expr->call_expr().args()[1]; const cel::Expr* right_expr = &expr->call_expr().args()[2]; auto* condition_plan = program_builder_.GetSubexpression(condition_expr); auto* left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); if (condition_plan == nullptr || !condition_plan->IsRecursive() || left_plan == nullptr || !left_plan->IsRecursive() || right_plan == nullptr || !right_plan->IsRecursive()) { SetProgressStatusError(FailedRecursivePlanning()); return; } int max_depth = std::max({0, condition_plan->recursive_program().depth, left_plan->recursive_program().depth, right_plan->recursive_program().depth}); SetRecursiveStep( CreateDirectTernaryStep(condition_plan->ExtractRecursiveProgram().step, left_plan->ExtractRecursiveProgram().step, right_plan->ExtractRecursiveProgram().step, expr->id(), options_.short_circuiting), max_depth + 1); } void MakeShortcircuitRecursive(const cel::Expr* expr, bool is_or) { if (expr->call_expr().args().size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( "unexpected number of args for builtin boolean operator &&/||")); return; } const cel::Expr* left_expr = &expr->call_expr().args()[0]; const cel::Expr* right_expr = &expr->call_expr().args()[1]; auto* 
left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); if (left_plan == nullptr || !left_plan->IsRecursive() || right_plan == nullptr || !right_plan->IsRecursive()) { SetProgressStatusError(FailedRecursivePlanning()); return; } int max_depth = std::max({0, left_plan->recursive_program().depth, right_plan->recursive_program().depth}); if (is_or) { SetRecursiveStep( CreateDirectOrStep(left_plan->ExtractRecursiveProgram().step, right_plan->ExtractRecursiveProgram().step, expr->id(), options_.short_circuiting), max_depth + 1); } else { SetRecursiveStep( CreateDirectAndStep(left_plan->ExtractRecursiveProgram().step, right_plan->ExtractRecursiveProgram().step, expr->id(), options_.short_circuiting), max_depth + 1); } } void MakeOptionalShortcircuit(const cel::Expr* expr, bool is_or_value) { if (!expr->call_expr().has_target() || expr->call_expr().args().size() != 1) { SetProgressStatusError(absl::InvalidArgumentError( "unexpected number of args for optional.or{Value}")); return; } const cel::Expr* left_expr = &expr->call_expr().target(); const cel::Expr* right_expr = &expr->call_expr().args()[0]; auto* left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); if (left_plan == nullptr || !left_plan->IsRecursive() || right_plan == nullptr || !right_plan->IsRecursive()) { SetProgressStatusError(FailedRecursivePlanning()); return; } int max_depth = std::max({0, left_plan->recursive_program().depth, right_plan->recursive_program().depth}); SetRecursiveStep(CreateDirectOptionalOrStep( expr->id(), left_plan->ExtractRecursiveProgram().step, right_plan->ExtractRecursiveProgram().step, is_or_value, options_.short_circuiting), max_depth + 1); } void MaybeMakeBindRecursive(const cel::Expr* expr, const cel::ComprehensionExpr* comprehension, size_t accu_slot) { if (!PlanRecursiveProgram()) { return; } auto* result_plan = 
program_builder_.GetSubexpression(&comprehension->result()); if (result_plan == nullptr || !result_plan->IsRecursive()) { SetProgressStatusError(FailedRecursivePlanning()); return; } int result_depth = result_plan->recursive_program().depth; auto program = result_plan->ExtractRecursiveProgram(); SetRecursiveStep( CreateDirectBindStep(accu_slot, std::move(program.step), expr->id()), result_depth + 1); } void MaybeMakeComprehensionRecursive( const cel::Expr* expr, const cel::ComprehensionExpr* comprehension, size_t iter_slot, size_t iter2_slot, size_t accu_slot) { if (!PlanRecursiveProgram()) { return; } auto* accu_plan = program_builder_.GetSubexpression(&comprehension->accu_init()); auto* range_plan = program_builder_.GetSubexpression(&comprehension->iter_range()); auto* loop_plan = program_builder_.GetSubexpression(&comprehension->loop_step()); auto* condition_plan = program_builder_.GetSubexpression(&comprehension->loop_condition()); auto* result_plan = program_builder_.GetSubexpression(&comprehension->result()); if (accu_plan == nullptr || !accu_plan->IsRecursive() || range_plan == nullptr || !range_plan->IsRecursive() || loop_plan == nullptr || !loop_plan->IsRecursive() || condition_plan == nullptr || !condition_plan->IsRecursive() || result_plan == nullptr || !result_plan->IsRecursive()) { SetProgressStatusError(FailedRecursivePlanning()); return; } int max_depth = 0; max_depth = std::max(max_depth, accu_plan->recursive_program().depth); max_depth = std::max(max_depth, range_plan->recursive_program().depth); max_depth = std::max(max_depth, loop_plan->recursive_program().depth); max_depth = std::max(max_depth, condition_plan->recursive_program().depth); max_depth = std::max(max_depth, result_plan->recursive_program().depth); auto step = CreateDirectComprehensionStep( iter_slot, iter2_slot, accu_slot, range_plan->ExtractRecursiveProgram().step, accu_plan->ExtractRecursiveProgram().step, loop_plan->ExtractRecursiveProgram().step, 
condition_plan->ExtractRecursiveProgram().step, result_plan->ExtractRecursiveProgram().step, options_.short_circuiting, expr->id()); SetRecursiveStep(std::move(step), max_depth + 1); } // Invoked after all child nodes are processed. void PostVisitCall(const cel::Expr& expr, const cel::CallExpr& call_expr) override { if (!progress_status_.ok()) { return; } auto cond_visitor = FindCondVisitor(&expr); if (cond_visitor) { cond_visitor->PostVisit(&expr); cond_visitor_stack_.pop(); return; } // Check if the call is intercepted by a custom handler. if (auto handler = call_handlers_.find(call_expr.function()); handler != call_handlers_.end()) { CallHandlerResult result = handler->second(expr, call_expr); if (result == CallHandlerResult::kIntercepted) { return; } // otherwise, apply default function handling. } AddResolvedFunctionStep(&call_expr, &expr, call_expr.function()); } void PreVisitComprehension( const cel::Expr& expr, const cel::ComprehensionExpr& comprehension) override { if (!progress_status_.ok()) { return; } if (!ValidateOrError(options_.enable_comprehension, "Comprehension support is disabled")) { return; } const auto& accu_var = comprehension.accu_var(); const auto& iter_var = comprehension.iter_var(); const auto& iter_var2 = comprehension.iter_var2(); ValidateOrError(!accu_var.empty(), "Invalid comprehension: 'accu_var' must not be empty"); ValidateOrError(!iter_var.empty(), "Invalid comprehension: 'iter_var' must not be empty"); ValidateOrError( accu_var != iter_var, "Invalid comprehension: 'accu_var' must not be the same as 'iter_var'"); ValidateOrError(accu_var != iter_var2, "Invalid comprehension: 'accu_var' must not be the same as " "'iter_var2'"); ValidateOrError(iter_var2 != iter_var, "Invalid comprehension: 'iter_var2' must not be the same " "as 'iter_var'"); ValidateOrError(comprehension.has_accu_init(), "Invalid comprehension: 'accu_init' must be set"); ValidateOrError(comprehension.has_loop_condition(), "Invalid comprehension: 'loop_condition' 
must be set"); ValidateOrError(comprehension.has_loop_step(), "Invalid comprehension: 'loop_step' must be set"); ValidateOrError(comprehension.has_result(), "Invalid comprehension: 'result' must be set"); size_t iter_slot, iter2_slot, accu_slot, slot_count; bool is_bind = IsBind(&comprehension); if (is_bind) { accu_slot = iter_slot = iter2_slot = index_manager_.ReserveSlots(1); slot_count = 1; } else if (comprehension.iter_var2().empty()) { iter_slot = iter2_slot = index_manager_.ReserveSlots(2); accu_slot = iter_slot + 1; slot_count = 2; } else { iter_slot = index_manager_.ReserveSlots(3); iter2_slot = iter_slot + 1; accu_slot = iter2_slot + 1; slot_count = 3; } if (block_.has_value()) { BlockInfo& block = *block_; if (block.in) { block.slot_count += slot_count; slot_count = 0; } } // If this is in the scope of an optimized bind accu-init, account the slots // to the outermost bind-init scope. // // The init expression is effectively inlined at the first usage in the // critical path (which is unknown at plan time), so the used slots need to // be dedicated for the entire scope of that bind. for (ComprehensionStackRecord& record : comprehension_stack_) { if (record.in_accu_init && record.is_optimizable_bind) { record.slot_count += slot_count; slot_count = 0; break; } // If no bind init subexpression, account normally. 
} comprehension_stack_.push_back( {&expr, &comprehension, iter_slot, iter2_slot, accu_slot, slot_count, /*subexpression=*/-1, /*.is_optimizable_list_append=*/ IsOptimizableListAppend(&comprehension, options_.enable_comprehension_list_append), /*.is_optimizable_map_insert=*/ IsOptimizableMapInsert(&comprehension, options_.enable_comprehension_mutable_map), /*.is_optimizable_bind=*/is_bind, /*.iter_var_in_scope=*/false, /*.iter_var2_in_scope=*/false, /*.accu_var_in_scope=*/false, /*.in_accu_init=*/false, std::make_unique(this, options_.short_circuiting, is_bind, iter_slot, iter2_slot, accu_slot)}); comprehension_stack_.back().visitor->PreVisit(&expr); } // Invoked after all child nodes are processed. void PostVisitComprehension( const cel::Expr& expr, const cel::ComprehensionExpr& comprehension_expr) override { if (!progress_status_.ok()) { return; } ComprehensionStackRecord& record = comprehension_stack_.back(); if (comprehension_stack_.empty() || record.comprehension != &comprehension_expr) { return; } record.visitor->PostVisit(&expr); index_manager_.ReleaseSlots(record.slot_count); comprehension_stack_.pop_back(); } void PreVisitComprehensionSubexpression( const cel::Expr& expr, const cel::ComprehensionExpr& compr, cel::ComprehensionArg comprehension_arg) override { if (!progress_status_.ok()) { return; } if (comprehension_stack_.empty() || comprehension_stack_.back().comprehension != &compr) { return; } ComprehensionStackRecord& record = comprehension_stack_.back(); switch (comprehension_arg) { case cel::ITER_RANGE: { record.in_accu_init = false; record.iter_var_in_scope = false; record.iter_var2_in_scope = false; record.accu_var_in_scope = false; break; } case cel::ACCU_INIT: { record.in_accu_init = true; record.iter_var_in_scope = false; record.iter_var2_in_scope = false; record.accu_var_in_scope = false; break; } case cel::LOOP_CONDITION: { record.in_accu_init = false; record.iter_var_in_scope = true; record.iter_var2_in_scope = true; record.accu_var_in_scope 
= true; break; } case cel::LOOP_STEP: { record.in_accu_init = false; record.iter_var_in_scope = true; record.iter_var2_in_scope = true; record.accu_var_in_scope = true; break; } case cel::RESULT: { record.in_accu_init = false; record.iter_var_in_scope = false; record.iter_var2_in_scope = false; record.accu_var_in_scope = true; break; } } } void PostVisitComprehensionSubexpression( const cel::Expr& expr, const cel::ComprehensionExpr& compr, cel::ComprehensionArg comprehension_arg) override { if (!progress_status_.ok()) { return; } if (comprehension_stack_.empty() || comprehension_stack_.back().comprehension != &compr) { return; } SetProgressStatusError(comprehension_stack_.back().visitor->PostVisitArg( comprehension_arg, comprehension_stack_.back().expr)); } // Invoked after each argument node processed. void PostVisitArg(const cel::Expr& expr, int arg_num) override { if (!progress_status_.ok()) { return; } auto cond_visitor = FindCondVisitor(&expr); if (cond_visitor) { cond_visitor->PostVisitArg(arg_num, &expr); } } void PostVisitTarget(const cel::Expr& expr) override { if (!progress_status_.ok()) { return; } auto cond_visitor = FindCondVisitor(&expr); if (cond_visitor) { cond_visitor->PostVisitTarget(&expr); } } // CreateList node handler. // Invoked after child nodes are processed. void PostVisitList(const cel::Expr& expr, const cel::ListExpr& list_expr) override { if (!progress_status_.ok()) { return; } if (block_.has_value()) { BlockInfo& block = *block_; if (block.bindings == &expr) { // Do nothing, this is the cel.@block bindings list. 
return; } } if (!comprehension_stack_.empty()) { const ComprehensionStackRecord& comprehension = comprehension_stack_.back(); if (comprehension.is_optimizable_list_append) { if (&(comprehension.comprehension->accu_init()) == &expr) { if (PlanRecursiveProgram()) { SetRecursiveStep(CreateDirectMutableListStep(expr.id()), 1); return; } AddStep(CreateMutableListStep(expr.id())); return; } if (GetOptimizableListAppendOperand(comprehension.comprehension) == &expr) { return; } } } if (absl::optional depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != list_expr.elements().size()) { SetProgressStatusError(absl::InternalError( "Unexpected number of plan elements for CreateList expr")); return; } auto step = CreateDirectListStep( std::move(deps), MakeOptionalIndicesSet(list_expr), expr.id()); SetRecursiveStep(std::move(step), *depth + 1); return; } AddStep(CreateCreateListStep(list_expr, expr.id())); } // CreateStruct node handler. // Invoked after child nodes are processed. 
void PostVisitStruct(const cel::Expr& expr, const cel::StructExpr& struct_expr) override { if (!progress_status_.ok()) { return; } auto status_or_resolved_fields = ResolveCreateStructFields(struct_expr, expr.id()); if (!status_or_resolved_fields.ok()) { SetProgressStatusError(status_or_resolved_fields.status()); return; } std::string resolved_name = std::move(status_or_resolved_fields.value().first); std::vector fields = std::move(status_or_resolved_fields.value().second); if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != struct_expr.fields().size()) { SetProgressStatusError(absl::InternalError( "Unexpected number of plan elements for CreateStruct expr")); return; } auto step = CreateDirectCreateStructStep( std::move(resolved_name), std::move(fields), std::move(deps), MakeOptionalIndicesSet(struct_expr), expr.id()); SetRecursiveStep(std::move(step), *depth + 1); return; } AddStep(CreateCreateStructStep(std::move(resolved_name), std::move(fields), MakeOptionalIndicesSet(struct_expr), expr.id())); } void PostVisitMap(const cel::Expr& expr, const cel::MapExpr& map_expr) override { for (const auto& entry : map_expr.entries()) { ValidateOrError(entry.has_key(), "Map entry missing key"); ValidateOrError(entry.has_value(), "Map entry missing value"); } if (!comprehension_stack_.empty()) { const ComprehensionStackRecord& comprehension = comprehension_stack_.back(); if (comprehension.is_optimizable_map_insert) { if (&(comprehension.comprehension->accu_init()) == &expr) { if (PlanRecursiveProgram()) { SetRecursiveStep(CreateDirectMutableMapStep(expr.id()), 1); return; } AddStep(CreateMutableMapStep(expr.id())); return; } } } if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 2 * map_expr.entries().size()) { SetProgressStatusError(absl::InternalError( "Unexpected number of plan elements for CreateStruct expr")); return; } auto step = 
CreateDirectCreateMapStep( std::move(deps), MakeOptionalIndicesSet(map_expr), expr.id()); SetRecursiveStep(std::move(step), *depth + 1); return; } AddStep(CreateCreateStructStepForMap(map_expr.entries().size(), MakeOptionalIndicesSet(map_expr), expr.id())); } absl::Status progress_status() const { return progress_status_; } // Mark a branch as suppressed. The visitor will continue as normal, but // any emitted program steps are ignored. // // Only applies to branches that have not yet been visited (pre-order). void SuppressBranch(const cel::Expr* expr) { suppressed_branches_.insert(expr); } void AddResolvedFunctionStep(const cel::CallExpr* call_expr, const cel::Expr* expr, absl::string_view function) { // Establish the search criteria for a given function. bool receiver_style = call_expr->has_target(); size_t num_args = call_expr->args().size() + (receiver_style ? 1 : 0); // First, search for lazily defined function overloads. // Lazy functions shadow eager functions with the same signature. auto lazy_overloads = resolver_.FindLazyOverloads( function, call_expr->has_target(), num_args, expr->id()); if (!lazy_overloads.empty()) { if (auto depth = RecursionEligible(); depth.has_value()) { auto args = program_builder_.current()->ExtractRecursiveDependencies(); SetRecursiveStep(CreateDirectLazyFunctionStep( expr->id(), *call_expr, std::move(args), std::move(lazy_overloads)), *depth + 1); return; } AddStep(CreateFunctionStep(*call_expr, expr->id(), std::move(lazy_overloads))); return; } // Second, search for eagerly defined function overloads. auto overloads = resolver_.FindOverloads(function, receiver_style, num_args, expr->id()); if (overloads.empty()) { // Create a warning that the overload could not be found. Depending on the // builder_warnings configuration, this could result in termination of the // CelExpression creation or an inspectable warning for use within runtime // logging. 
auto status = issue_collector_.AddIssue(RuntimeIssue::CreateWarning( absl::InvalidArgumentError( "No overloads provided for FunctionStep creation"), RuntimeIssue::ErrorCode::kNoMatchingOverload)); if (!status.ok()) { SetProgressStatusError(status); return; } } if (auto recursion_depth = RecursionEligible(); recursion_depth.has_value()) { // Nonnull while active -- nullptr indicates logic error elsewhere in the // builder. ABSL_DCHECK(program_builder_.current() != nullptr); auto args = program_builder_.current()->ExtractRecursiveDependencies(); SetRecursiveStep( CreateDirectFunctionStep(expr->id(), *call_expr, std::move(args), std::move(overloads)), *recursion_depth + 1); return; } AddStep(CreateFunctionStep(*call_expr, expr->id(), std::move(overloads))); } // Add a step to the program, taking ownership. If successful, returns the // pointer to the step. Otherwise, returns nullptr. // // Note: the pointer is only guaranteed to stay valid until the parent // subexpression is finalized. Optimizers may modify the program plan which // may free the step at that point. 
ExpressionStep* AddStep(
    absl::StatusOr> step) {
  // Unwraps a fallible step factory result. On error the status becomes the
  // planning error and nothing is added.
  if (step.ok()) {
    return AddStep(*std::move(step));
  } else {
    SetProgressStatusError(step.status());
  }
  return nullptr;
}

// Typed overload: adds `step` to the current subexpression and returns it
// cast back to the caller's step type. If planning has already failed, or a
// suppressed branch is being visited, the step is discarded and nullptr is
// returned.
template std::enable_if_t, T*> AddStep(
    std::unique_ptr step) {
  if (progress_status_.ok() && !PlanningSuppressed()) {
    return static_cast(program_builder_.AddStep(std::move(step)));
  }
  return nullptr;
}

// Sets `step` as the recursive program of the current subexpression, with
// the given `depth`. This is a no-op once planning has failed or while a
// branch is suppressed.
// The step is stored before the depth check. Exceeding
// `max_recursion_depth_` records an InvalidArgumentError but does not undo
// the assignment.
// NOTE(review): the comparison uses `max_recursion_depth_`, but the message
// reports `options_.max_recursion_depth`. Confirm the two are kept in sync;
// both are initialized outside this view.
void SetRecursiveStep(std::unique_ptr step, int depth) {
  if (!progress_status_.ok() || PlanningSuppressed()) {
    return;
  }
  if (program_builder_.current() == nullptr) {
    SetProgressStatusError(absl::InternalError(
        "CEL AST traversal out of order in flat_expr_builder."));
    return;
  }
  program_builder_.current()->set_recursive_program(std::move(step), depth);
  if (depth > max_recursion_depth_) {
    SetProgressStatusError(absl::InvalidArgumentError(
        absl::StrCat("Maximum recursion depth of ",
                     options_.max_recursion_depth, " exceeded")));
  }
}

// Records `status` as the planning error. Only the first non-OK status is
// kept; later errors and OK statuses are ignored.
void SetProgressStatusError(const absl::Status& status) {
  if (progress_status_.ok() && !status.ok()) {
    progress_status_ = status;
  }
}

// Index of the next step to be inserted, in terms of the current
// subexpression
ProgramStepIndex GetCurrentIndex() const {
  // Nonnull while active -- nullptr indicates logic error in the builder.
  ABSL_DCHECK(program_builder_.current() != nullptr);
  return {static_cast(program_builder_.current()->elements().size()),
          program_builder_.current()};
}

// Returns the conditional visitor on top of `cond_visitor_stack_` if it was
// pushed for `expr`; otherwise nullptr. Only the top entry is examined.
CondVisitor* FindCondVisitor(const cel::Expr* expr) const {
  if (cond_visitor_stack_.empty()) {
    return nullptr;
  }

  const auto& latest = cond_visitor_stack_.top();

  return (latest.first == expr) ? latest.second.get() : nullptr;
}

IndexManager& index_manager() { return index_manager_; }

// Returns `index_manager_.max_slot_count()`, presumably the peak number of
// slots reserved at once -- confirm against IndexManager.
size_t slot_count() const { return index_manager_.max_slot_count(); }

void AddOptimizer(std::unique_ptr optimizer) {
  program_optimizers_.push_back(std::move(optimizer));
}

// Tests the boolean predicate, and if false produces an InvalidArgumentError
// which concatenates the error_message and any optional message_parts as the
// error status message.
// Returns `valid_expression`. The error is recorded through
// SetProgressStatusError, so only the first failure is retained.
template bool ValidateOrError(bool valid_expression,
                              absl::string_view error_message,
                              MP... message_parts) {
  if (valid_expression) {
    return true;
  }

  SetProgressStatusError(absl::InvalidArgumentError(
      absl::StrCat(error_message, message_parts...)));
  return false;
}

private:
// Per-comprehension state kept on `comprehension_stack_` while the
// comprehension is being visited.
struct ComprehensionStackRecord {
  const cel::Expr* expr;
  const cel::ComprehensionExpr* comprehension;
  // Slot indices reserved for the iteration variable(s) and the accumulator.
  size_t iter_slot;
  size_t iter2_slot;
  size_t accu_slot;
  // Number of slots released when this record is popped.
  size_t slot_count;
  // -1 indicates this shouldn't be used.
  int subexpression;
  bool is_optimizable_list_append;
  bool is_optimizable_map_insert;
  bool is_optimizable_bind;
  // Which variables are in scope for the comprehension argument currently
  // being visited.
  bool iter_var_in_scope;
  bool iter_var2_in_scope;
  bool accu_var_in_scope;
  // True while the accu init argument is being visited.
  bool in_accu_init;
  std::unique_ptr visitor;
};

struct BlockInfo {
  // True if we are currently visiting the `cel.@block` node or any of its
  // children.
  bool in = false;
  // Pointer to the `cel.@block` node.
  const cel::Expr* expr = nullptr;
  // Pointer to the `cel.@block` bindings, that is the first argument to the
  // function.
  const cel::Expr* bindings = nullptr;
  // Set of pointers to the elements of `bindings` above.
  absl::flat_hash_set bindings_set;
  // Pointer to the `cel.@block` bound expression, that is the second argument
  // to the function.
  const cel::Expr* bound = nullptr;
  // The number of entries in the `cel.@block`.
  size_t size = 0;
  // Starting slot index for `cel.@block`. We occupy the slot indices `index`
  // through `index + size + (var_size * 2)`.
  size_t index = 0;
  // The total number of slots needed for evaluating the bound expressions.
  size_t slot_count = 0;
  // The current slot index we are processing, any index references must be
  // less than this to be valid.
  size_t current_index = 0;
  // Pointer to the current `cel.@block` being processed, that is one of the
  // elements within the first argument.
  const cel::Expr* current_binding = nullptr;
  // Mapping between block indices and their subexpressions, fixed size with
  // exactly `size` elements. Unprocessed indices are set to `-1`.
  std::vector subexpressions;
};

// True while `resume_from_suppressed_branch_` is set. AddStep and
// SetRecursiveStep discard steps in that state.
bool PlanningSuppressed() const {
  return resume_from_suppressed_branch_ != nullptr;
}

// For a bind comprehension (`record.is_optimizable_bind`), this:
// - moves the plan for `expr` into a separate subexpression,
// - stores that subexpression's index on `record`,
// - notifies the comprehension visitor that its accu init was extracted.
// Returns OK without doing anything for other comprehensions, and an
// InternalError if the extraction fails.
absl::Status MaybeExtractSubexpression(const cel::Expr* expr,
                                       ComprehensionStackRecord& record) {
  if (!record.is_optimizable_bind) {
    return absl::OkStatus();
  }

  int index = program_builder_.ExtractSubexpression(expr);
  if (index == -1) {
    return absl::InternalError("Failed to extract subexpression");
  }

  record.subexpression = index;

  record.visitor->MarkAccuInitExtracted();

  return absl::OkStatus();
}

// Resolve the name of the message type being created and the names of set
// fields.
// On success, returns the resolved type name paired with the field names, in
// the order they appear in `create_struct_expr`. Returns InvalidArgumentError
// in any of these cases:
// - the type cannot be resolved,
// - a field has no name or no value,
// - the type provider does not know a field.
// Errors from the resolver or type provider lookups are propagated via
// CEL_ASSIGN_OR_RETURN.
absl::StatusOr>>
ResolveCreateStructFields(const cel::StructExpr& create_struct_expr,
                          int64_t expr_id) {
  absl::string_view ast_name = create_struct_expr.name();

  absl::optional> type;
  CEL_ASSIGN_OR_RETURN(type, resolver_.FindType(ast_name, expr_id));

  if (!type.has_value()) {
    return absl::InvalidArgumentError(absl::StrCat(
        "Invalid struct creation: missing type info for '", ast_name, "'"));
  }

  std::string resolved_name = std::move(type).value().first;

  std::vector fields;
  fields.reserve(create_struct_expr.fields().size());
  for (const auto& entry : create_struct_expr.fields()) {
    if (entry.name().empty()) {
      return absl::InvalidArgumentError("Struct field missing name");
    }
    if (!entry.has_value()) {
      return absl::InvalidArgumentError("Struct field missing value");
    }
    // The looked-up field is only used to check that it exists.
    CEL_ASSIGN_OR_RETURN(
        auto field, type_provider_.FindStructTypeFieldByName(resolved_name,
                                                             entry.name()));
    if (!field.has_value()) {
      return absl::InvalidArgumentError(
          absl::StrCat("Invalid message creation: field '", entry.name(),
                       "' not found in '", resolved_name, "'"));
    }
    fields.push_back(entry.name());
  }

  return std::make_pair(std::move(resolved_name), std::move(fields));
}

// Special-case handlers for particular builtin calls; they are defined after
// the class.
CallHandlerResult HandleIndex(const cel::Expr& expr,
                              const cel::CallExpr& call);
CallHandlerResult HandleBlock(const cel::Expr& expr,
                              const cel::CallExpr& call);
CallHandlerResult HandleListAppend(const cel::Expr& expr,
                                   const cel::CallExpr& call);
CallHandlerResult HandleNot(const cel::Expr& expr, const cel::CallExpr& call);
CallHandlerResult HandleNotStrictlyFalse(const cel::Expr& expr,
                                         const cel::CallExpr& call);
CallHandlerResult HandleHeterogeneousEquality(const cel::Expr& expr,
                                              const cel::CallExpr& call,
                                              bool inequality);
CallHandlerResult HandleHeterogeneousEqualityIn(const cel::Expr& expr,
                                                const cel::CallExpr& call);

const Resolver& resolver_;
const cel::TypeProvider& type_provider_;
// First planning error encountered; see SetProgressStatusError.
absl::Status progress_status_;
// Function name -> special-case handler, consulted by PostVisitCall.
absl::flat_hash_map call_handlers_;

// (call expr, conditional visitor) pairs for calls whose arguments are still
// being visited; see PreVisitCall and FindCondVisitor.
std::stack>> cond_visitor_stack_;

// Tracks SELECT-...SELECT-IDENT chains.
std::deque> namespace_stack_;

// When multiple SELECT-...SELECT-IDENT chain is resolved as namespace, this
// field is used as marker suppressing CelExpression creation for SELECTs.
// NOTE(review): there is no in-class initializer here. Presumably the
// constructor (outside this view) sets it -- confirm.
const cel::Expr* resolved_select_expr_;

const cel::RuntimeOptions& options_;

// Records for the comprehensions currently being visited; the innermost one
// is at the back.
std::vector comprehension_stack_;
absl::flat_hash_set suppressed_branches_;
// Non-null while planning is suppressed; see PlanningSuppressed().
const cel::Expr* resume_from_suppressed_branch_ = nullptr;
std::vector> program_optimizers_;
IssueCollector& issue_collector_;
ProgramBuilder& program_builder_;
PlannerContext& extension_context_;
IndexManager index_manager_;

bool enable_optional_types_;
// State for the single supported `cel.@block`, once one has been visited.
absl::optional block_;
// Depth limit checked by SetRecursiveStep().
int max_recursion_depth_ = 0;
};

// Special-case handler for the builtin index operator.
//
// Accepts the two-argument form without a receiver and, for some custom
// ASTs, the one-argument form with a receiver. Plans a direct container
// access step when recursion is eligible, otherwise an iterative one. Always
// intercepts the call.
FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex(
    const cel::Expr& expr, const cel::CallExpr& call_expr) {
  ABSL_DCHECK(call_expr.function() == cel::builtin::kIndex);
  if (!ValidateOrError(
          (call_expr.args().size() == 2 && !call_expr.has_target()) ||
              // TODO(uncreated-issue/79): A few clients use the index operator with a
              // target in custom ASTs.
              (call_expr.args().size() == 1 && call_expr.has_target()),
          "unexpected number of args for builtin index operator")) {
    return CallHandlerResult::kIntercepted;
  }

  if (auto depth = RecursionEligible(); depth.has_value()) {
    auto args = ExtractRecursiveDependencies();
    if (args.size() != 2) {
      SetProgressStatusError(absl::InvalidArgumentError(
          "unexpected number of args for builtin index operator"));
      return CallHandlerResult::kIntercepted;
    }
    SetRecursiveStep(
        CreateDirectContainerAccessStep(std::move(args[0]), std::move(args[1]),
                                        enable_optional_types_, expr.id()),
        *depth + 1);
    return CallHandlerResult::kIntercepted;
  }
  AddStep(
      CreateContainerAccessStep(call_expr, expr.id(), enable_optional_types_));
  return CallHandlerResult::kIntercepted;
}

// Special-case handler for logical not. It requires exactly one argument and
// no receiver, and always intercepts the call.
FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot(
    const cel::Expr& expr, const cel::CallExpr& call_expr) {
  ABSL_DCHECK(call_expr.function() == cel::builtin::kNot);

  if (!ValidateOrError(call_expr.args().size() == 1 && !call_expr.has_target(),
                       "unexpected number of args for builtin not operator")) {
    return CallHandlerResult::kIntercepted;
  }

  if (auto depth = RecursionEligible(); depth.has_value()) {
    auto args = ExtractRecursiveDependencies();
    if (args.size() != 1) {
      SetProgressStatusError(absl::InvalidArgumentError(
          "unexpected number of args for builtin not operator"));
      return CallHandlerResult::kIntercepted;
    }
    SetRecursiveStep(CreateDirectNotStep(std::move(args[0]), expr.id()),
                     *depth + 1);
    return CallHandlerResult::kIntercepted;
  }

  AddStep(CreateNotStep(expr.id()));
  return CallHandlerResult::kIntercepted;
}

// Special-case handler for not_strictly_false. It requires exactly one
// argument and no receiver, and always intercepts the call.
// NOTE(review): unlike its neighbours this handler has no ABSL_DCHECK on the
// function name, presumably because it is registered under more than one
// name -- confirm. Its two error messages also spell the operator
// differently ("not_strictly_false" vs "@not_strictly_false").
FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNotStrictlyFalse(
    const cel::Expr& expr, const cel::CallExpr& call_expr) {
  if (!ValidateOrError(call_expr.args().size() == 1 && !call_expr.has_target(),
                       "unexpected number of args for builtin "
                       "not_strictly_false operator")) {
    return CallHandlerResult::kIntercepted;
  }

  if (auto depth = RecursionEligible(); depth.has_value()) {
    auto args = ExtractRecursiveDependencies();
    if (args.size() != 1) {
      SetProgressStatusError(
          absl::InvalidArgumentError("unexpected number of args for builtin "
                                     "@not_strictly_false operator"));
      return CallHandlerResult::kIntercepted;
    }
    SetRecursiveStep(
        CreateDirectNotStrictlyFalseStep(std::move(args[0]), expr.id()),
        *depth + 1);
    return CallHandlerResult::kIntercepted;
  }

  AddStep(CreateNotStrictlyFalseStep(expr.id()));
  return CallHandlerResult::kIntercepted;
}

// Special-case handler for the internal `cel.@block` call, run after its
// arguments have been visited.
//
// It leaves the block scope and releases the block's slots with the index
// manager. The body's recursive program is then wrapped in a direct block
// step when both hold:
// - recursive planning is enabled, and
// - the body is within the depth limit (or the limit is negative).
// Otherwise a clear-slots step over the block's slot range is appended.
FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleBlock(
    const cel::Expr& expr, const cel::CallExpr& call_expr) {
  ABSL_DCHECK(call_expr.function() == kBlock);
  if (!block_.has_value() || block_->expr != &expr ||
      call_expr.args().size() != 2 || call_expr.has_target()) {
    SetProgressStatusError(
        absl::InvalidArgumentError("unexpected call to internal cel.@block"));
    return CallHandlerResult::kIntercepted;
  }

  BlockInfo& block = *block_;
  block.in = false;
  index_manager().ReleaseSlots(block.slot_count);

  // Check if eligible for recursion and update the plan if so.
  //
  // The first argument to @block is the list of initializers. These don't
  // generate a plan in the main program (they are tracked separately to support
  // lazy evaluation) so we only need to extract the second argument -- the body
  // of the block that uses the initializers.
  ProgramBuilder::Subexpression* body_subexpression =
      program_builder_.GetSubexpression(&call_expr.args()[1]);
  if (options_.max_recursion_depth != 0 && body_subexpression != nullptr &&
      body_subexpression->IsRecursive() &&
      (options_.max_recursion_depth < 0 ||
       body_subexpression->recursive_program().depth <
           options_.max_recursion_depth)) {
    auto recursive_program = body_subexpression->ExtractRecursiveProgram();
    SetRecursiveStep(
        CreateDirectBlockStep(block.index, block.slot_count,
                              std::move(recursive_program.step), expr.id()),
        recursive_program.depth + 1);
    return CallHandlerResult::kIntercepted;
  }

  // Otherwise, iterative plan.
  AddStep(CreateClearSlotsStep(block.index, block.slot_count, expr.id()));
  return CallHandlerResult::kIntercepted;
}

// Special-case handler for `+`. It plans the runtime list-append function
// instead of `+` when the call is the accumulation of an optimizable
// list-building comprehension. That is either the map() loop step itself or
// the true branch of a filter() ternary loop step. Otherwise it declines, so
// the default function handling applies.
FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleListAppend(
    const cel::Expr& expr, const cel::CallExpr& call_expr) {
  ABSL_DCHECK(call_expr.function() == cel::builtin::kAdd);
  // Check to see if this is a special case of add that should really be
  // treated as a list append
  if (!comprehension_stack_.empty() &&
      comprehension_stack_.back().is_optimizable_list_append) {
    // Already checked that this is an optimizable comprehension,
    // check that this is the correct list append node.
    const cel::ComprehensionExpr* comprehension =
        comprehension_stack_.back().comprehension;
    const cel::Expr& loop_step = comprehension->loop_step();
    // Macro loop_step for a map() will contain a list concat operation:
    //   accu_var + [elem]
    if (&loop_step == &expr) {
      AddResolvedFunctionStep(&call_expr, &expr,
                              cel::builtin::kRuntimeListAppend);
      return CallHandlerResult::kIntercepted;
    }
    // Macro loop_step for a filter() will contain a ternary:
    //   filter ? accu_var + [elem] : accu_var
    if (loop_step.has_call_expr() &&
        loop_step.call_expr().function() == cel::builtin::kTernary &&
        loop_step.call_expr().args().size() == 3 &&
        &(loop_step.call_expr().args()[1]) == &expr) {
      AddResolvedFunctionStep(&call_expr, &expr,
                              cel::builtin::kRuntimeListAppend);
      return CallHandlerResult::kIntercepted;
    }
  }

  return CallHandlerResult::kNotIntercepted;
}

// Special-case handler for the builtin equality operators. `inequality`
// selects the negated form (presumably `!=`). Requires exactly two arguments
// and no receiver; always intercepts the call.
FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleHeterogeneousEquality(
    const cel::Expr& expr, const cel::CallExpr& call, bool inequality) {
  if (!ValidateOrError(
          call.args().size() == 2 && !call.has_target(),
          "unexpected number of args for builtin equality operator")) {
    return CallHandlerResult::kIntercepted;
  }

  if (auto depth = RecursionEligible(); depth.has_value()) {
    auto args = ExtractRecursiveDependencies();
    if (args.size() != 2) {
      SetProgressStatusError(absl::InvalidArgumentError(
          "unexpected number of args for builtin equality operator"));
      return CallHandlerResult::kIntercepted;
    }
    SetRecursiveStep(
        CreateDirectEqualityStep(std::move(args[0]), std::move(args[1]),
                                 inequality, expr.id()),
        *depth + 1);
    return CallHandlerResult::kIntercepted;
  }

  AddStep(CreateEqualityStep(inequality, expr.id()));
  return CallHandlerResult::kIntercepted;
}

// Special-case handler for the builtin `in` operator. Requires exactly two
// arguments and no receiver.
FlatExprVisitor::CallHandlerResult
FlatExprVisitor::HandleHeterogeneousEqualityIn(const cel::Expr& expr,
                                               const cel::CallExpr& call) {
  if (!ValidateOrError(call.args().size() == 2 && !call.has_target(),
                       "unexpected number of args for builtin 'in' operator")) {
    return CallHandlerResult::kIntercepted;
  }

  if (auto depth = RecursionEligible(); depth.has_value()) {
    auto args = ExtractRecursiveDependencies();
    if (args.size() != 2) {
      SetProgressStatusError(absl::InvalidArgumentError(
          "unexpected number of args for builtin 'in' operator"));
      return CallHandlerResult::kIntercepted;
    }
    SetRecursiveStep(
        CreateDirectInStep(std::move(args[0]), std::move(args[1]), expr.id()),
        *depth + 1);
    return CallHandlerResult::kIntercepted;
  }

  AddStep(CreateInStep(expr.id()));
return CallHandlerResult::kIntercepted; } void BinaryCondVisitor::PreVisit(const cel::Expr* expr) { switch (cond_) { case BinaryCond::kAnd: ABSL_FALLTHROUGH_INTENDED; case BinaryCond::kOr: visitor_->ValidateOrError( !expr->call_expr().has_target() && expr->call_expr().args().size() == 2, "Invalid argument count for a binary function call."); break; case BinaryCond::kOptionalOr: ABSL_FALLTHROUGH_INTENDED; case BinaryCond::kOptionalOrValue: visitor_->ValidateOrError(expr->call_expr().has_target() && expr->call_expr().args().size() == 1, "Invalid argument count for or/orValue call."); break; } } void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { if (visitor_->PlanRecursiveProgram()) { return; } if (short_circuiting_ && arg_num == 0 && (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)) { // If first branch evaluation result is enough to determine output, // jump over the second branch and provide result of the first argument as // final output. // Retain a pointer to the jump step so we can update the target after // planning the second argument. std::unique_ptr jump_step; switch (cond_) { case BinaryCond::kAnd: jump_step = CreateCondJumpStep(false, true, {}, expr->id()); break; case BinaryCond::kOr: jump_step = CreateCondJumpStep(true, true, {}, expr->id()); break; default: ABSL_UNREACHABLE(); } ProgramStepIndex index = visitor_->GetCurrentIndex(); if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); jump_step_ptr) { jump_step_ = Jump(index, jump_step_ptr); } } } void BinaryCondVisitor::PostVisitTarget(const cel::Expr* expr) { if (visitor_->PlanRecursiveProgram()) { return; } if (short_circuiting_ && (cond_ == BinaryCond::kOptionalOr || cond_ == BinaryCond::kOptionalOrValue)) { // If first branch evaluation result is enough to determine output, // jump over the second branch and provide result of the first argument as // final output. 
// Retain a pointer to the jump step so we can update the target after // planning the second argument. std::unique_ptr jump_step; switch (cond_) { case BinaryCond::kOptionalOr: jump_step = CreateOptionalHasValueJumpStep(false, expr->id()); break; case BinaryCond::kOptionalOrValue: jump_step = CreateOptionalHasValueJumpStep(true, expr->id()); break; default: ABSL_UNREACHABLE(); } ProgramStepIndex index = visitor_->GetCurrentIndex(); if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); jump_step_ptr) { jump_step_ = Jump(index, jump_step_ptr); } } } void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { if (visitor_->PlanRecursiveProgram()) { switch (cond_) { case BinaryCond::kAnd: visitor_->MakeShortcircuitRecursive(expr, /*is_or=*/false); break; case BinaryCond::kOr: visitor_->MakeShortcircuitRecursive(expr, /*is_or=*/true); break; case BinaryCond::kOptionalOr: visitor_->MakeOptionalShortcircuit(expr, /*is_or_value=*/false); break; case BinaryCond::kOptionalOrValue: visitor_->MakeOptionalShortcircuit(expr, /*is_or_value=*/true); break; default: ABSL_UNREACHABLE(); } return; } switch (cond_) { case BinaryCond::kAnd: visitor_->AddStep(CreateAndStep(expr->id())); break; case BinaryCond::kOr: visitor_->AddStep(CreateOrStep(expr->id())); break; case BinaryCond::kOptionalOr: visitor_->AddStep( CreateOptionalOrStep(/*is_or_value=*/false, expr->id())); break; case BinaryCond::kOptionalOrValue: visitor_->AddStep(CreateOptionalOrStep(/*is_or_value=*/true, expr->id())); break; default: ABSL_UNREACHABLE(); } if (short_circuiting_) { // If short-circuiting is enabled, point the conditional jump past the // boolean operator step. 
visitor_->SetProgressStatusError( jump_step_.set_target(visitor_->GetCurrentIndex())); } } void TernaryCondVisitor::PreVisit(const cel::Expr* expr) { visitor_->ValidateOrError( !expr->call_expr().has_target() && expr->call_expr().args().size() == 3, "Invalid argument count for a ternary function call."); } void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { if (visitor_->PlanRecursiveProgram()) { return; } // Ternary operator "_?_:_" requires a special handing. // In contrary to regular function call, its execution affects the control // flow of the overall CEL expression. // If condition value (argument 0) is True, then control flow is unaffected // as it is passed to the first conditional branch. Then, at the end of this // branch, the jump is performed over the second conditional branch. // If condition value is False, then jump is performed and control is passed // to the beginning of the second conditional branch. // If condition value is Error, then jump is peformed to bypass both // conditional branches and provide Error as result of ternary operation. // condition argument for ternary operator if (arg_num == 0) { // Jump in case of error or non-bool ProgramStepIndex error_jump_pos = visitor_->GetCurrentIndex(); auto* error_jump = visitor_->AddStep(CreateBoolCheckJumpStep({}, expr->id())); if (error_jump) { error_jump_ = Jump(error_jump_pos, error_jump); } // Jump to the second branch of execution // Value is to be removed from the stack. ProgramStepIndex cond_jump_pos = visitor_->GetCurrentIndex(); auto* jump_to_second = visitor_->AddStep(CreateCondJumpStep(false, false, {}, expr->id())); if (jump_to_second) { jump_to_second_ = Jump(cond_jump_pos, static_cast(jump_to_second)); } } else if (arg_num == 1) { // Jump after the first and over the second branch of execution. // Value is to be removed from the stack. 
ProgramStepIndex jump_pos = visitor_->GetCurrentIndex(); auto* jump_after_first = visitor_->AddStep(CreateJumpStep({}, expr->id())); if (!jump_after_first) { return; } jump_after_first_ = Jump(jump_pos, jump_after_first); if (visitor_->ValidateOrError( jump_to_second_.exists(), "Error configuring ternary operator: jump_to_second_ is null")) { visitor_->SetProgressStatusError( jump_to_second_.set_target(visitor_->GetCurrentIndex())); } } // Code executed after traversing the final branch of execution // (arg_num == 2) is placed in PostVisitCall, to make this method less // clattered. } void TernaryCondVisitor::PostVisit(const cel::Expr* expr) { if (visitor_->PlanRecursiveProgram()) { visitor_->MakeTernaryRecursive(expr); return; } // Determine and set jump offset in jump instruction. if (visitor_->ValidateOrError( error_jump_.exists(), "Error configuring ternary operator: error_jump_ is null")) { visitor_->SetProgressStatusError( error_jump_.set_target(visitor_->GetCurrentIndex())); } if (visitor_->ValidateOrError( jump_after_first_.exists(), "Error configuring ternary operator: jump_after_first_ is null")) { visitor_->SetProgressStatusError( jump_after_first_.set_target(visitor_->GetCurrentIndex())); } } void ExhaustiveTernaryCondVisitor::PreVisit(const cel::Expr* expr) { visitor_->ValidateOrError( !expr->call_expr().has_target() && expr->call_expr().args().size() == 3, "Invalid argument count for a ternary function call."); } void ExhaustiveTernaryCondVisitor::PostVisit(const cel::Expr* expr) { if (visitor_->PlanRecursiveProgram()) { visitor_->MakeTernaryRecursive(expr); return; } visitor_->AddStep(CreateTernaryStep(expr->id())); } void ComprehensionVisitor::PreVisit(const cel::Expr* expr) { if (is_trivial_) { visitor_->SuppressBranch(&expr->comprehension_expr().iter_range()); visitor_->SuppressBranch(&expr->comprehension_expr().loop_condition()); visitor_->SuppressBranch(&expr->comprehension_expr().loop_step()); } } absl::Status 
ComprehensionVisitor::PostVisitArgDefault( cel::ComprehensionArg arg_num, const cel::Expr* expr) { if (visitor_->PlanRecursiveProgram()) { return absl::OkStatus(); } switch (arg_num) { case cel::ITER_RANGE: { init_step_pos_ = visitor_->GetCurrentIndex(); init_step_ = visitor_->AddStep( std::make_unique(expr->id())); break; } case cel::ACCU_INIT: { next_step_pos_ = visitor_->GetCurrentIndex(); next_step_ = visitor_->AddStep(std::make_unique( iter_slot_, iter2_slot_, accu_slot_, expr->id())); break; } case cel::LOOP_CONDITION: { cond_step_pos_ = visitor_->GetCurrentIndex(); cond_step_ = visitor_->AddStep(std::make_unique( iter_slot_, iter2_slot_, accu_slot_, short_circuiting_, expr->id())); break; } case cel::LOOP_STEP: { ProgramStepIndex index = visitor_->GetCurrentIndex(); auto* jump_to_next = visitor_->AddStep(CreateJumpStep({}, expr->id())); if (!jump_to_next) { break; } Jump jump_helper(index, jump_to_next); visitor_->SetProgressStatusError(jump_helper.set_target(next_step_pos_)); // Set offsets jumping to the result step. if (cond_step_) { CEL_ASSIGN_OR_RETURN( int jump_from_cond, Jump::CalculateOffset(cond_step_pos_, visitor_->GetCurrentIndex())); cond_step_->set_jump_offset(jump_from_cond); } if (next_step_) { CEL_ASSIGN_OR_RETURN( int jump_from_next, Jump::CalculateOffset(next_step_pos_, visitor_->GetCurrentIndex())); next_step_->set_jump_offset(jump_from_next); } break; } case cel::RESULT: { if (!init_step_ || !next_step_ || !cond_step_) { // Encountered an error earlier. Can't determine where to jump. break; } visitor_->AddStep(CreateComprehensionFinishStep(accu_slot_, expr->id())); // Set offsets jumping past the result step in case of errors. 
CEL_ASSIGN_OR_RETURN( int jump_from_init, Jump::CalculateOffset(init_step_pos_, visitor_->GetCurrentIndex())); init_step_->set_error_jump_offset(jump_from_init); CEL_ASSIGN_OR_RETURN( int jump_from_next, Jump::CalculateOffset(next_step_pos_, visitor_->GetCurrentIndex())); next_step_->set_error_jump_offset(jump_from_next); CEL_ASSIGN_OR_RETURN( int jump_from_cond, Jump::CalculateOffset(cond_step_pos_, visitor_->GetCurrentIndex())); cond_step_->set_error_jump_offset(jump_from_cond); break; } } return absl::OkStatus(); } void ComprehensionVisitor::PostVisitArgTrivial(cel::ComprehensionArg arg_num, const cel::Expr* expr) { if (visitor_->PlanRecursiveProgram()) { return; } switch (arg_num) { case cel::ITER_RANGE: { break; } case cel::ACCU_INIT: { if (!accu_init_extracted_) { visitor_->AddStep(CreateAssignSlotAndPopStep(accu_slot_)); } break; } case cel::LOOP_CONDITION: { break; } case cel::LOOP_STEP: { break; } case cel::RESULT: { visitor_->AddStep(CreateClearSlotStep(accu_slot_, expr->id())); break; } } } void ComprehensionVisitor::PostVisit(const cel::Expr* expr) { if (is_trivial_) { visitor_->MaybeMakeBindRecursive(expr, &expr->comprehension_expr(), accu_slot_); return; } visitor_->MaybeMakeComprehensionRecursive( expr, &expr->comprehension_expr(), iter_slot_, iter2_slot_, accu_slot_); } // Flattens the expression table into the end of the mainline expression vector // and returns an index to the individual sub expressions. 
std::vector FlattenExpressionTable( ProgramBuilder& program_builder, ExecutionPath& main) { std::vector> ranges; main = program_builder.FlattenMain(); ranges.push_back(std::make_pair(0, main.size())); std::vector subexpressions = program_builder.FlattenSubexpressions(); for (auto& subexpression : subexpressions) { ranges.push_back(std::make_pair(main.size(), subexpression.size())); absl::c_move(subexpression, std::back_inserter(main)); } std::vector subexpression_indexes; subexpression_indexes.reserve(ranges.size()); for (const auto& range : ranges) { subexpression_indexes.push_back( absl::MakeSpan(main).subspan(range.first, range.second)); } return subexpression_indexes; } absl::Status CheckAstExtensions( const std::vector& extensions) { for (const cel::ExtensionSpec& extension : extensions) { if (extension.id() == "cel_block" && extension.version().major() == 1) { // cel_block v1 is always supported. continue; } // TODO(uncreated-issue/89): Add support for json field names. return absl::InvalidArgumentError(absl::StrCat( "unsupported CEL extension: ", extension.id(), "@", extension.version().major(), ".", extension.version().minor())); } return absl::OkStatus(); } } // namespace absl::StatusOr FlatExprBuilder::CreateExpressionImpl( std::unique_ptr ast, std::vector* issues) const { if (absl::StartsWith(container_, ".") || absl::EndsWith(container_, ".")) { return absl::InvalidArgumentError( absl::StrCat("Invalid expression container: '", container_, "'")); } RuntimeIssue::Severity max_severity = options_.fail_on_warnings ? 
RuntimeIssue::Severity::kWarning : RuntimeIssue::Severity::kError; IssueCollector issue_collector(max_severity); absl::StatusOr> runtime_extensions = ExtractAndValidateRuntimeExtensions(*ast); if (!runtime_extensions.ok()) { CEL_RETURN_IF_ERROR(issue_collector.AddIssue( RuntimeIssue::CreateError(runtime_extensions.status()))); } auto status = CheckAstExtensions(*runtime_extensions); if (!status.ok()) { CEL_RETURN_IF_ERROR( issue_collector.AddIssue(RuntimeIssue::CreateError(status))); } Resolver resolver(container_, function_registry_, type_registry_, GetTypeProvider(), options_.enable_qualified_type_identifiers); std::shared_ptr arena; ProgramBuilder program_builder; PlannerContext extension_context(env_, resolver, options_, GetTypeProvider(), issue_collector, program_builder, arena); for (const std::unique_ptr& transform : ast_transforms_) { CEL_RETURN_IF_ERROR(transform->UpdateAst(extension_context, *ast)); } std::vector> optimizers; for (const ProgramOptimizerFactory& optimizer_factory : program_optimizers_) { CEL_ASSIGN_OR_RETURN(auto optimizer, optimizer_factory(extension_context, *ast)); if (optimizer != nullptr) { optimizers.push_back(std::move(optimizer)); } } // These objects are expected to remain scoped to one build call -- references // to them shouldn't be persisted in any part of the result expression. FlatExprVisitor visitor(resolver, options_, std::move(optimizers), ast->reference_map(), GetTypeProvider(), issue_collector, program_builder, extension_context, enable_optional_types_); if (options_.max_recursion_depth == -1 || options_.max_recursion_depth > 0) { int depth_limit = options_.max_recursion_depth == -1 ? 
std::numeric_limits::max() : options_.max_recursion_depth; visitor.SetMaxRecursionDepth(depth_limit); } cel::TraversalOptions opts; opts.use_comprehension_callbacks = true; AstTraverse(ast->root_expr(), visitor, opts); if (!visitor.progress_status().ok()) { return visitor.progress_status(); } if (issues != nullptr) { (*issues) = issue_collector.ExtractIssues(); } ExecutionPath execution_path; std::vector subexpressions = FlattenExpressionTable(program_builder, execution_path); return FlatExpression(std::move(execution_path), std::move(subexpressions), visitor.slot_count(), GetTypeProvider(), options_, std::move(arena)); } const cel::TypeProvider& FlatExprBuilder::GetTypeProvider() const { return use_legacy_type_provider_ ? static_cast( *GetLegacyRuntimeTypeProvider(type_registry_)) : GetRuntimeTypeProvider(type_registry_); } } // namespace google::api::expr::runtime ================================================ FILE: eval/compiler/flat_expr_builder.h ================================================ /* * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/

#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_H_
#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_H_

// NOTE(review): the header names of the four bare #include lines below are
// missing in this copy (angle-bracketed text appears to have been stripped);
// restore them from upstream.
#include
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "base/ast.h"
#include "base/type_provider.h"
#include "eval/compiler/flat_expr_builder_extensions.h"
#include "eval/eval/evaluator_core.h"
#include "runtime/function_registry.h"
#include "runtime/internal/runtime_env.h"
#include "runtime/runtime_issue.h"
#include "runtime/runtime_options.h"
#include "runtime/type_registry.h"

namespace google::api::expr::runtime {

// CelExpressionBuilder implementation.
// Builds instances of CelExpressionFlatImpl.
//
// Holds the planner configuration (runtime env, options, expression container,
// AST transforms and program optimizers). CreateExpressionImpl plans an AST
// into an executable expression using that configuration.
// NOTE(review): the definition of CreateExpressionImpl constructs a
// FlatExpression; confirm the "CelExpressionFlatImpl" wording is still current.
class FlatExprBuilder {
 public:
  // `env` is annotated non-null. The function and type registries are borrowed
  // by reference from it. `options.container` seeds the expression container;
  // set_container() can replace it later.
  FlatExprBuilder(
      absl_nonnull std::shared_ptr env,
      const cel::RuntimeOptions& options,
      bool use_legacy_type_provider = false)
      : env_(std::move(env)),
        options_(options),
        container_(options.container),
        // These read through env_ (not the moved-from `env` parameter). This is
        // valid because env_ is declared, and therefore initialized, first.
        function_registry_(env_->function_registry),
        type_registry_(env_->type_registry),
        use_legacy_type_provider_(use_legacy_type_provider) {}

  // Registers a transform that CreateExpressionImpl applies to the AST before
  // planning, in registration order.
  void AddAstTransform(std::unique_ptr transform) {
    ast_transforms_.push_back(std::move(transform));
  }

  // Registers a factory invoked once per CreateExpressionImpl call. A factory
  // may return a null optimizer to opt out for that AST.
  void AddProgramOptimizer(ProgramOptimizerFactory optimizer) {
    program_optimizers_.push_back(std::move(optimizer));
  }

  void set_container(std::string container) {
    container_ = std::move(container);
  }

  absl::string_view container() const { return container_; }

  // TODO(uncreated-issue/45): Add overload for cref AST. At the moment, all the users
  // can pass ownership of a freshly converted AST.
  absl::StatusOr CreateExpressionImpl(
      std::unique_ptr ast, std::vector* issues) const;

  const cel::runtime_internal::RuntimeEnv& env() const { return *env_; }

  const cel::RuntimeOptions& options() const { return options_; }

  // Called by `cel::extensions::EnableOptionalTypes` to indicate that special
  // `optional_type` handling is needed.
  void enable_optional_types() { enable_optional_types_ = true; }

  bool optional_types_enabled() const { return enable_optional_types_; }

 private:
  const cel::TypeProvider& GetTypeProvider() const;

  // Must stay declared before function_registry_ and type_registry_: the
  // constructor initializes those references through env_.
  const absl_nonnull std::shared_ptr env_;
  cel::RuntimeOptions options_;
  std::string container_;
  bool enable_optional_types_ = false;
  // TODO(uncreated-issue/45): evaluate whether we should use a shared_ptr here to
  // allow built expressions to keep the registries alive.
  const cel::FunctionRegistry& function_registry_;
  const cel::TypeRegistry& type_registry_;
  bool use_legacy_type_provider_;
  std::vector> ast_transforms_;
  std::vector program_optimizers_;
};

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_H_

================================================
FILE: eval/compiler/flat_expr_builder_comprehensions_test.cc
================================================
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/

// NOTE(review): the header name of the bare #include below is missing in this
// copy (angle-bracketed text appears to have been stripped).
#include

#include "cel/expr/syntax.pb.h"
#include "google/protobuf/field_mask.pb.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "eval/compiler/cel_expression_builder_flat_impl.h"
#include "eval/compiler/comprehension_vulnerability_check.h"
#include "eval/compiler/flat_expr_builder.h"
#include "eval/public/activation.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_attribute.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_options.h"
#include "eval/public/cel_value.h"
#include "eval/public/containers/container_backed_list_impl.h"
#include "eval/public/testing/matchers.h"
#include "eval/testutil/test_message.pb.h"
#include "internal/testing.h"
#include "parser/parser.h"
#include "runtime/internal/runtime_env_testing.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/text_format.h"

namespace google::api::expr::runtime {

namespace {

using ::absl_testing::StatusIs;
using ::cel::runtime_internal::NewTestingRuntimeEnv;
using ::cel::expr::CheckedExpr;
using ::cel::expr::ParsedExpr;
using ::testing::HasSubstr;

// Runs every comprehension test twice, selected by the test parameter:
//  - with recursive planning (max_recursion_depth = -1, i.e. no depth limit);
//  - with the default iterative plan.
// Comprehension list-append optimization is enabled in both modes.
class CelExpressionBuilderFlatImplComprehensionsTest
    : public testing::TestWithParam {
 public:
  CelExpressionBuilderFlatImplComprehensionsTest() = default;

  bool enable_recursive_planning() { return GetParam(); }

  cel::RuntimeOptions GetRuntimeOptions() {
    cel::RuntimeOptions options;
    if (enable_recursive_planning()) {
      options.max_recursion_depth = -1;
    }
    options.enable_comprehension_list_append = true;
    return options;
  }
};

// A filter() whose predicate is itself an all() comprehension keeps both
// elements of [1, 2] (each is less than 3 and 4).
TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, NestedComp) {
  cel::RuntimeOptions options = GetRuntimeOptions();
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);

  ASSERT_OK_AND_ASSIGN(auto parsed_expr,
                       parser::Parse("[1, 2].filter(x, [3, 4].all(y, x < y))"));
  ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry()));
  ASSERT_OK_AND_ASSIGN(auto cel_expr,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena));
  ASSERT_TRUE(result.IsList());
  EXPECT_THAT(*result.ListOrDie(), testing::SizeIs(2));
}

// map() produces the doubled elements in order.
TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, MapComp) {
  cel::RuntimeOptions options = GetRuntimeOptions();
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);

  ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("[1, 2].map(x, x * 2)"));
  ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry()));
  ASSERT_OK_AND_ASSIGN(auto cel_expr,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena));
  ASSERT_TRUE(result.IsList());
  EXPECT_THAT(*result.ListOrDie(), testing::SizeIs(2));
  EXPECT_THAT((*result.ListOrDie())[0],
              test::EqualsCelValue(CelValue::CreateInt64(2)));
  EXPECT_THAT((*result.ListOrDie())[1],
              test::EqualsCelValue(CelValue::CreateInt64(4)));
}

// exists_one() is true when exactly one element matches.
TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneTrue) {
  cel::RuntimeOptions options = GetRuntimeOptions();
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);

  ASSERT_OK_AND_ASSIGN(auto parsed_expr,
                       parser::Parse("[7].exists_one(a, a == 7)"));
  ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry()));
  ASSERT_OK_AND_ASSIGN(auto cel_expr,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena));
  EXPECT_THAT(result, test::IsCelBool(true));
}

// exists_one() is false when more than one element matches.
TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneFalse) {
  cel::RuntimeOptions options = GetRuntimeOptions();
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);

  ASSERT_OK_AND_ASSIGN(auto parsed_expr,
                       parser::Parse("[7, 7].exists_one(a, a == 7)"));
  ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry()));
  ASSERT_OK_AND_ASSIGN(auto cel_expr,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena));
  EXPECT_THAT(result, test::IsCelBool(false));
}

// With attribute/function unknown tracking enabled, an unknown list element
// yields an unknown set that names exactly that attribute (items[1]).
TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ListCompWithUnknowns) {
  cel::RuntimeOptions options = GetRuntimeOptions();
  options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);

  ASSERT_OK_AND_ASSIGN(auto parsed_expr,
                       parser::Parse("items.exists(i, i < 0)"));
  ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry()));
  ASSERT_OK_AND_ASSIGN(auto cel_expr,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  Activation activation;
  activation.set_unknown_attribute_patterns({CelAttributePattern{
      "items",
      {CreateCelAttributeQualifierPattern(CelValue::CreateInt64(1))}}});
  ContainerBackedListImpl list_impl = ContainerBackedListImpl({
      CelValue::CreateInt64(1),
      // element items[1] is marked unknown, so the computation should produce
      // an unknown set.
      CelValue::CreateInt64(-1),
      CelValue::CreateInt64(2),
  });
  activation.InsertValue("items", CelValue::CreateList(&list_impl));

  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena));
  ASSERT_TRUE(result.IsUnknownSet()) << result.DebugString();

  const auto& attrs = result.UnknownSetOrDie()->unknown_attributes();
  EXPECT_THAT(attrs, testing::SizeIs(1));
  EXPECT_THAT(attrs.begin()->variable_name(), testing::Eq("items"));
  EXPECT_THAT(attrs.begin()->qualifier_path(), testing::SizeIs(1));
  EXPECT_THAT(attrs.begin()->qualifier_path().at(0).GetInt64Key().value(),
              testing::Eq(1));
}

// A malformed comprehension must still be rejected after the reference-map
// rewrite gives its iteration range a qualified name.
TEST_P(CelExpressionBuilderFlatImplComprehensionsTest,
       InvalidComprehensionWithRewrite) {
  CheckedExpr expr;
  // The rewrite step which occurs when an identifier gets a more qualified name
  // from the reference map has the potential to make invalid comprehensions
  // appear valid, by populating missing fields with default values.
  // Shape: a comprehension over `var` (iter_var x, accu_var y) with no
  // loop_condition, loop_step or result.
  google::protobuf::TextFormat::ParseFromString(
      R"pb( reference_map { key: 1 value { name: "qualified.var" } } expr { comprehension_expr { iter_var: "x" iter_range { id: 1 ident_expr { name: "var" } } accu_var: "y" accu_init { id: 1 const_expr { bool_value: true } } } })pb",
      &expr);

  cel::RuntimeOptions options = GetRuntimeOptions();
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
  ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry()));
  EXPECT_THAT(builder.CreateExpression(&expr).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       testing::AnyOf(HasSubstr("Invalid comprehension"),
                                      HasSubstr("Invalid empty expression"))));
}

// The comprehension vulnerability check rejects a loop step that concatenates
// the accumulator with (a projection of) itself.
TEST_P(CelExpressionBuilderFlatImplComprehensionsTest,
       ComprehensionWithConcatVulernability) {
  CheckedExpr expr;
  // The comprehension loop step performs an unsafe concatenation of the
  // accumulation variable with itself or one of its children.
  google::protobuf::TextFormat::ParseFromString(
      R"pb( expr { comprehension_expr { iter_var: "x" iter_range { ident_expr { name: "var" } } accu_var: "y" accu_init { list_expr {} } result { ident_expr { name: "y" } } loop_condition { const_expr { bool_value: true } } loop_step { call_expr { function: "_?_:_" args { const_expr { bool_value: true } } args { ident_expr { name: "y" } } args { call_expr { function: "_+_" args { call_expr { function: "dyn" args { ident_expr { name: "y" } } } } args { call_expr { function: "_[_]" args { ident_expr { name: "y" } } args { const_expr { int64_value: 0 } } } } } } } } } })pb",
      &expr);

  cel::RuntimeOptions options = GetRuntimeOptions();
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
  builder.flat_expr_builder().AddProgramOptimizer(
      CreateComprehensionVulnerabilityCheck());
  ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry()));
  EXPECT_THAT(builder.CreateExpression(&expr).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("memory exhaustion vulnerability")));
}

// The vulnerability check rejects a loop step that rebuilds the accumulator
// as a list containing itself.
TEST_P(CelExpressionBuilderFlatImplComprehensionsTest,
       ComprehensionWithListVulernability) {
  CheckedExpr expr;
  // The comprehension loop step builds a list containing the accumulator
  // itself plus a nested list of a field selected from it, so the accumulated
  // value can grow without bound.
  google::protobuf::TextFormat::ParseFromString(
      R"pb( expr { comprehension_expr { iter_var: "x" iter_range { ident_expr { name: "var" } } accu_var: "y" accu_init { list_expr {} } result { ident_expr { name: "y" } } loop_condition { const_expr { bool_value: true } } loop_step { list_expr { elements { ident_expr { name: "y" } } elements { list_expr { elements { select_expr { operand { ident_expr { name: "y" } } field: "z" } } } } } } } } )pb",
      &expr);

  cel::RuntimeOptions options = GetRuntimeOptions();
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
  builder.flat_expr_builder().AddProgramOptimizer(
      CreateComprehensionVulnerabilityCheck());
  ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry()));
  EXPECT_THAT(builder.CreateExpression(&expr).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("memory exhaustion vulnerability")));
}

// The vulnerability check rejects a loop step that nests the accumulator in a
// map/struct literal.
TEST_P(CelExpressionBuilderFlatImplComprehensionsTest,
       ComprehensionWithStructVulernability) {
  CheckedExpr expr;
  // The comprehension loop step builds a deeply nested struct which expands
  // exponentially.
  google::protobuf::TextFormat::ParseFromString(
      R"pb( expr { comprehension_expr { iter_var: "x" iter_range { ident_expr { name: "var" } } accu_var: "y" accu_init { list_expr {} } result { ident_expr { name: "y" } } loop_condition { const_expr { bool_value: true } } loop_step { struct_expr { entries { map_key { const_expr { string_value: "key" } } value { ident_expr { name: "y" } } } entries { map_key { const_expr { string_value: "present" } } value { select_expr { test_only: true operand { ident_expr { name: "y" } } field: "z" } } } entries { map_key { const_expr { string_value: "key_subset" } } value { select_expr { operand { ident_expr { name: "y" } } field: "z" } } } } } } } )pb",
      &expr);

  cel::RuntimeOptions options = GetRuntimeOptions();
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
  builder.flat_expr_builder().AddProgramOptimizer(
      CreateComprehensionVulnerabilityCheck());
  ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry()));
  EXPECT_THAT(builder.CreateExpression(&expr).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("memory exhaustion vulnerability")));
}

// The vulnerability check inspects nested comprehensions' result expressions
// for unsafe use of an outer accumulator.
TEST_P(CelExpressionBuilderFlatImplComprehensionsTest,
       ComprehensionWithNestedComprehensionResultVulernability) {
  CheckedExpr expr;
  // The nested comprehension performs an unsafe concatenation on the parent
  // accumulator variable within its 'result' expression.
  //
  // The inner-most comprehension shadows its parent, but still refers to its
  // oldest ancestor. It, however, does not do anything unsafe.
  google::protobuf::TextFormat::ParseFromString(
      R"pb( expr { comprehension_expr { iter_var: "x" iter_range { ident_expr { name: "var" } } accu_var: "y" accu_init { list_expr {} } result { ident_expr { name: "y" } } loop_condition { const_expr { bool_value: true } } loop_step { comprehension_expr { iter_var: "x" iter_range { ident_expr { name: "y" } } accu_var: "z" accu_init { list_expr {} } result { call_expr { function: "_+_" args { ident_expr { name: "y" } } args { ident_expr { name: "y" } } } } loop_condition { const_expr { bool_value: true } } loop_step { comprehension_expr { iter_var: "x" iter_range { ident_expr { name: "y" } } accu_var: "z" accu_init { list_expr {} } result { call_expr { function: "dyn" args { ident_expr { name: "y" } } } } loop_condition { const_expr { bool_value: true } } loop_step { call_expr { function: "dyn" args { ident_expr { name: "y" } } } } } } } } } )pb",
      &expr);

  cel::RuntimeOptions options = GetRuntimeOptions();
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
  builder.flat_expr_builder().AddProgramOptimizer(
      CreateComprehensionVulnerabilityCheck());
  ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry()));
  EXPECT_THAT(builder.CreateExpression(&expr).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("memory exhaustion vulnerability")));
}

// The vulnerability check inspects nested comprehensions' loop steps for
// unsafe use of an outer accumulator.
TEST_P(CelExpressionBuilderFlatImplComprehensionsTest,
       ComprehensionWithNestedComprehensionLoopStepVulernability) {
  CheckedExpr expr;
  // The nested comprehension performs an unsafe concatenation on the parent
  // accumulator variable within its 'loop_step'.
  google::protobuf::TextFormat::ParseFromString(
      R"pb( expr { comprehension_expr { iter_var: "x" iter_range { ident_expr { name: "var" } } accu_var: "y" accu_init { list_expr {} } result { ident_expr { name: "y" } } loop_condition { const_expr { bool_value: true } } loop_step { comprehension_expr { iter_var: "x" iter_range { ident_expr { name: "y" } } accu_var: "z" accu_init { list_expr {} } result { ident_expr { name: "z" } } loop_condition { const_expr { bool_value: true } } loop_step { call_expr { function: "_+_" args { ident_expr { name: "y" } } args { ident_expr { name: "y" } } } } } } } } )pb",
      &expr);

  cel::RuntimeOptions options = GetRuntimeOptions();
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
  builder.flat_expr_builder().AddProgramOptimizer(
      CreateComprehensionVulnerabilityCheck());
  ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry()));
  EXPECT_THAT(builder.CreateExpression(&expr).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("memory exhaustion vulnerability")));
}

TEST_P(CelExpressionBuilderFlatImplComprehensionsTest,
       ComprehensionWithNestedComprehensionLoopStepVulernabilityResult) {
  CheckedExpr expr;
  // The nested comprehension performs an unsafe concatenation on the parent
  // accumulator.
  google::protobuf::TextFormat::ParseFromString(
      R"pb( expr { comprehension_expr { iter_var: "outer_iter" iter_range { ident_expr { name: "input_list" } } accu_var: "outer_accu" accu_init { ident_expr { name: "input_list" } } loop_condition { id: 3 const_expr { bool_value: true } } loop_step { comprehension_expr { # the iter_var shadows the outer accumulator on the loop step # but not the result step.
iter_var: "outer_accu" iter_range { list_expr {} } accu_var: "inner_accu" accu_init { list_expr {} } loop_condition { const_expr { bool_value: true } } loop_step { list_expr {} } result { call_expr { function: "_+_" args { ident_expr { name: "outer_accu" } } args { ident_expr { name: "outer_accu" } } } } } } result { list_expr {} } } } )pb", &expr); cel::RuntimeOptions options = GetRuntimeOptions(); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("memory exhaustion vulnerability"))); } TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ComprehensionWithNestedComprehensionLoopStepIterRangeVulnerability) { CheckedExpr expr; // The nested comprehension unsafely modifies the parent accumulator // (outer_accu) being used as a iterable range google::protobuf::TextFormat::ParseFromString( R"pb( expr { comprehension_expr { iter_var: "x" iter_range { ident_expr { name: "input_list" } } accu_var: "outer_accu" accu_init { ident_expr { name: "input_list" } } loop_condition { const_expr { bool_value: true } } loop_step { comprehension_expr { iter_var: "y" iter_range { ident_expr { name: "outer_accu" } } accu_var: "inner_accu" accu_init { ident_expr { name: "outer_accu" } } loop_condition { const_expr { bool_value: true } } loop_step { call_expr { function: "_+_" args { ident_expr { name: "inner_accu" } } args { const_expr { string_value: "12345" } } } } result { ident_expr { name: "inner_accu" } } } } result { ident_expr { name: "outer_accu" } } } } )pb", &expr); cel::RuntimeOptions options = GetRuntimeOptions(); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); 
ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("memory exhaustion vulnerability"))); } TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, InvalidBindComprehension) { ParsedExpr expr; // Trivial comprehensions (such as cel.bind), are optimized by skipping the // planning for the loop step, however the planner will still warn if the // loop step references the unused var. ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr { comprehension_expr { iter_var: "#unused" iter_range { id: 1 list_expr {} } accu_var: "bind_var" accu_init { id: 1 const_expr { bool_value: true } } loop_step { call_expr { function: "_&&_" args { ident_expr { name: "#unused" } } args { ident_expr { name: "bind_var" } } } } loop_condition { const_expr { bool_value: false } } result { ident_expr { name: "bind_var" } } } })pb", &expr)); cel::RuntimeOptions options = GetRuntimeOptions(); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT( builder.CreateExpression(&(expr.expr()), nullptr).status(), StatusIs( absl::StatusCode::kInvalidArgument, HasSubstr("Unexpected iter_var access in trivial comprehension"))); } INSTANTIATE_TEST_SUITE_P(TestSuite, CelExpressionBuilderFlatImplComprehensionsTest, testing::Bool(), [](const testing::TestParamInfo& info) { return info.param ? "recursive" : "default"; }); } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/compiler/flat_expr_builder_extensions.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/compiler/flat_expr_builder_extensions.h" #include #include #include #include #include #include #include "absl/algorithm/container.h" #include "absl/base/nullability.h" #include "absl/base/optimization.h" #include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/variant.h" #include "common/expr.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { namespace { using Subexpression = google::api::expr::runtime::ProgramBuilder::Subexpression; // Remap a recursive program to its parent if the parent is a transparent // wrapper. 
void MaybeReassignChildRecursiveProgram(Subexpression* parent) { if (parent->IsFlattened() || parent->IsRecursive()) { return; } if (parent->elements().size() != 1) { return; } auto* child_alternative = absl::get_if(&parent->elements()[0]); if (child_alternative == nullptr) { return; } auto& child_subexpression = *child_alternative; if (!child_subexpression->IsRecursive()) { return; } auto child_program = child_subexpression->ExtractRecursiveProgram(); parent->set_recursive_program(std::move(child_program.step), child_program.depth); } } // namespace Subexpression::Subexpression(const cel::Expr* self, ProgramBuilder* owner) : self_(self), parent_(nullptr), owner_(owner) {} size_t Subexpression::ComputeSize() const { if (IsFlattened()) { return flattened_elements().size(); } else if (IsRecursive()) { return 1; } std::vector to_expand{this}; size_t size = 0; while (!to_expand.empty()) { const auto* expr = to_expand.back(); to_expand.pop_back(); if (expr->IsFlattened()) { size += expr->flattened_elements().size(); continue; } else if (expr->IsRecursive()) { size += 1; continue; } for (const auto& elem : expr->elements()) { if (auto* child = absl::get_if(&elem); child != nullptr) { to_expand.push_back(*child); } else { size += 1; } } } return size; } absl::optional Subexpression::RecursiveDependencyDepth() const { auto* tree = absl::get_if(&program_); int depth = 0; if (tree == nullptr) { return absl::nullopt; } for (const auto& element : *tree) { auto* subexpression = absl::get_if(&element); if (subexpression == nullptr) { return absl::nullopt; } if (!(*subexpression)->IsRecursive()) { return absl::nullopt; } depth = std::max(depth, (*subexpression)->recursive_program().depth); } return depth; } std::vector> Subexpression::ExtractRecursiveDependencies() const { auto* tree = absl::get_if(&program_); std::vector> dependencies; if (tree == nullptr) { return {}; } for (const auto& element : *tree) { auto* subexpression = absl::get_if(&element); if (subexpression == 
nullptr) { return {}; } if (!(*subexpression)->IsRecursive()) { return {}; } dependencies.push_back((*subexpression)->ExtractRecursiveProgram().step); } return dependencies; } Subexpression* absl_nullable Subexpression::ExtractChild(Subexpression* child) { ABSL_DCHECK(child != nullptr); if (IsFlattened()) { return nullptr; } for (auto iter = elements().begin(); iter != elements().end(); ++iter) { Subexpression::Element& element = *iter; if (!absl::holds_alternative(element)) { continue; } Subexpression* candidate = absl::get(element); if (candidate != child) { continue; } elements().erase(iter); return candidate; } return nullptr; } // Compute the offset for moving the pc from after the base step to before the // target step. int Subexpression::CalculateOffset(int base, int target) const { ABSL_DCHECK(!IsFlattened()); ABSL_DCHECK(!IsRecursive()); int sign = 1; int start = base + 1; int end = target; if (end <= start) { // When target is before base we have to consider the size of the base step // and target (offset is from after base to before target). start = target; end = base + 1; sign = -1; } ABSL_DCHECK_GE(start, 0); ABSL_DCHECK_GE(end, 0); ABSL_DCHECK_LE(start, elements().size()); ABSL_DCHECK_LE(end, elements().size()); int sum = 0; for (int i = start; i < end; ++i) { const auto& element = elements()[i]; if (auto* subexpr = absl::get_if(&element); subexpr != nullptr) { sum += (*subexpr)->ComputeSize(); } else { // Individual step or wrapped recursive program. 
sum += 1; } } return sign * sum; } void Subexpression::Flatten() { struct Record { Subexpression* subexpr; size_t offset; }; if (IsFlattened()) { return; } std::vector> flat; std::vector flatten_stack; flatten_stack.push_back({this, 0}); while (!flatten_stack.empty()) { Record top = flatten_stack.back(); flatten_stack.pop_back(); size_t offset = top.offset; auto* subexpr = top.subexpr; if (subexpr->IsFlattened()) { auto& elements = subexpr->flattened_elements(); absl::c_move(elements, std::back_inserter(flat)); elements.clear(); continue; } else if (subexpr->IsRecursive()) { flat.push_back(std::make_unique( std::move(subexpr->ExtractRecursiveProgram().step), subexpr->self_->id())); continue; } auto& elements = subexpr->elements(); size_t size = elements.size(); size_t i = offset; for (; i < size; ++i) { auto& element = elements[i]; if (auto* child = absl::get_if(&element); child != nullptr) { // push resume then child so child elements are processed first. flatten_stack.push_back({subexpr, i + 1}); flatten_stack.push_back({*child, 0}); break; } else if (auto* step = absl::get_if>(&element); step != nullptr) { flat.push_back(std::move(*step)); } else { ABSL_UNREACHABLE(); } } if (i == size) { elements.clear(); } } program_ = std::move(flat); } Subexpression::RecursiveProgram Subexpression::ExtractRecursiveProgram() { ABSL_DCHECK(IsRecursive()); auto result = std::move(absl::get(program_)); program_.emplace>(); return result; } bool Subexpression::ExtractTo( std::vector>& out) { if (!IsFlattened()) { return false; } out.reserve(out.size() + flattened_elements().size()); absl::c_move(flattened_elements(), std::back_inserter(out)); program_.emplace>(); return true; } std::vector> ProgramBuilder::FlattenSubexpression(Subexpression* expr) { std::vector> out; if (!expr) { return out; } expr->Flatten(); expr->ExtractTo(out); return out; } ProgramBuilder::ProgramBuilder() : root_(nullptr), current_(nullptr), subprogram_map_() {} ExecutionPath ProgramBuilder::FlattenMain() { 
auto out = FlattenSubexpression(root_); root_ = nullptr; return out; } std::vector ProgramBuilder::FlattenSubexpressions() { std::vector out; out.reserve(extracted_subexpressions_.size()); for (auto& subexpression : extracted_subexpressions_) { out.push_back(FlattenSubexpression(subexpression)); } extracted_subexpressions_.clear(); return out; } Subexpression* absl_nullable ProgramBuilder::EnterSubexpression( const cel::Expr* expr, size_t size_hint) { Subexpression* subexpr = MakeSubexpression(expr); if (subexpr == nullptr) { return subexpr; } subexpr->elements().reserve(size_hint); if (current_ == nullptr) { root_ = subexpr; current_ = subexpr; return subexpr; } current_->AddSubexpression(subexpr); subexpr->parent_ = current_->self_; current_ = subexpr; return subexpr; } Subexpression* absl_nullable ProgramBuilder::ExitSubexpression( const cel::Expr* expr) { ABSL_DCHECK(expr == current_->self_); ABSL_DCHECK(GetSubexpression(expr) == current_); MaybeReassignChildRecursiveProgram(current_); Subexpression* result = GetSubexpression(current_->parent_); ABSL_DCHECK(result != nullptr || current_ == root_); current_ = result; return result; } Subexpression* absl_nullable ProgramBuilder::GetSubexpression( const cel::Expr* expr) { auto it = subprogram_map_.find(expr); if (it == subprogram_map_.end()) { return nullptr; } return it->second.get(); } ExpressionStep* absl_nullable ProgramBuilder::AddStep( std::unique_ptr step) { if (current_ == nullptr) { return nullptr; } auto* step_ptr = step.get(); return current_->AddStep(std::move(step)) ? 
step_ptr : nullptr; } int ProgramBuilder::ExtractSubexpression(const cel::Expr* expr) { auto it = subprogram_map_.find(expr); if (it == subprogram_map_.end()) { return -1; } auto* subexpression = it->second.get(); auto parent_it = subprogram_map_.find(subexpression->parent_); if (parent_it == subprogram_map_.end()) { return -1; } auto* parent = parent_it->second.get(); auto* child = parent->ExtractChild(subexpression); if (child == nullptr) { return -1; } extracted_subexpressions_.push_back(child); return extracted_subexpressions_.size() - 1; } Subexpression* absl_nullable ProgramBuilder::MakeSubexpression( const cel::Expr* expr) { auto [it, inserted] = subprogram_map_.try_emplace( expr, absl::WrapUnique(new Subexpression(expr, this))); if (!inserted) { return nullptr; } return it->second.get(); } bool PlannerContext::IsSubplanInspectable(const cel::Expr& node) const { return program_builder_.GetSubexpression(&node) != nullptr; } ExecutionPathView PlannerContext::GetSubplan(const cel::Expr& node) { auto* subexpression = program_builder_.GetSubexpression(&node); if (subexpression == nullptr) { return ExecutionPathView(); } subexpression->Flatten(); return subexpression->flattened_elements(); } absl::StatusOr PlannerContext::ExtractSubplan( const cel::Expr& node) { auto* subexpression = program_builder_.GetSubexpression(&node); if (subexpression == nullptr) { return absl::InternalError( "attempted to update program step for untracked expr node"); } subexpression->Flatten(); ExecutionPath out; subexpression->ExtractTo(out); return out; } absl::Status PlannerContext::ReplaceSubplan(const cel::Expr& node, ExecutionPath path) { auto* subexpression = program_builder_.GetSubexpression(&node); if (subexpression == nullptr) { return absl::InternalError( "attempted to update program step for untracked expr node"); } // Make sure structure for descendents is erased. 
if (!subexpression->IsFlattened()) { subexpression->Flatten(); } subexpression->flattened_elements() = std::move(path); return absl::OkStatus(); } void ProgramBuilder::Reset() { root_ = nullptr; current_ = nullptr; extracted_subexpressions_.clear(); subprogram_map_.clear(); } absl::Status PlannerContext::ReplaceSubplan( const cel::Expr& node, std::unique_ptr step, int depth) { auto* subexpression = program_builder_.GetSubexpression(&node); if (subexpression == nullptr) { return absl::InternalError( "attempted to update program step for untracked expr node"); } subexpression->set_recursive_program(std::move(step), depth); return absl::OkStatus(); } absl::Status PlannerContext::AddSubplanStep( const cel::Expr& node, std::unique_ptr step) { auto* subexpression = program_builder_.GetSubexpression(&node); if (subexpression == nullptr) { return absl::InternalError( "attempted to update program step for untracked expr node"); } subexpression->AddStep(std::move(step)); return absl::OkStatus(); } } // namespace google::api::expr::runtime ================================================ FILE: eval/compiler/flat_expr_builder_extensions.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // API definitions for planner extensions. 
//
// These are provided to indirect build dependencies for optional features and
// require detailed understanding of how the flat expression builder works and
// its assumptions.
//
// These interfaces should not be implemented directly by CEL users.

#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_EXTENSIONS_H_
#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_EXTENSIONS_H_

// NOTE(review): the header names of the next four #include directives are
// missing in this copy of the file; restore them from upstream.
#include
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/functional/any_invocable.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "base/ast.h"
#include "base/type_provider.h"
#include "common/expr.h"
#include "common/native_type.h"
#include "common/type_reflector.h"
#include "eval/compiler/resolver.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/trace_step.h"
#include "internal/casts.h"
#include "runtime/internal/issue_collector.h"
#include "runtime/internal/runtime_env.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {

// Class representing a CEL program being built.
//
// Maintains tree structure and mapping from the AST representation to
// subexpressions. Maintains an insertion point for new steps and
// subexpressions.
//
// This class is thread-hostile and not intended for direct access outside of
// the Expression builder. Extensions should interact with this through the
// PlannerContext member functions.
class ProgramBuilder {
 public:
  class Subexpression;

 private:
  // Owning map from AST node to the subexpression planned for it.
  using SubprogramMap =
      absl::flat_hash_map>;

 public:
  // Represents a subexpression.
  //
  // Steps apply operations on the stack machine for the C++ runtime.
  // For most expression types, this maps to a post order traversal -- for all
  // nodes, evaluate dependencies (pushing their results to stack) then evaluate
  // self.
  //
  // A subexpression is in one of three states, distinguished by IsFlattened()
  // and IsRecursive():
  //   - tree: a list of elements, each either a single step or a child
  //     subexpression (see elements());
  //   - flattened: a contiguous list of steps (see flattened_elements());
  //   - recursive: a single direct step and its depth (see
  //     recursive_program()).
  //
  // Must be tied to a ProgramBuilder to coordinate relationships.
  class Subexpression {
   private:
    // A tree-plan element: an owned step or a (non-owned) child subexpression.
    using Element = absl::variant,
                                  Subexpression* absl_nonnull>;
    using TreePlan = std::vector;
    using FlattenedPlan = std::vector>;

   public:
    // A direct (recursive) program together with its recorded depth.
    struct RecursiveProgram {
      std::unique_ptr step;
      int depth;
    };

    ~Subexpression() = default;

    // Not copyable or movable.
    Subexpression(const Subexpression&) = delete;
    Subexpression& operator=(const Subexpression&) = delete;
    Subexpression(Subexpression&&) = delete;
    Subexpression& operator=(Subexpression&&) = delete;

    // Add a program step at the current end of the subexpression.
    //
    // Returns false, dropping the step, if the subexpression holds a
    // recursive program.
    bool AddStep(std::unique_ptr step) {
      if (IsRecursive()) {
        return false;
      }

      if (IsFlattened()) {
        flattened_elements().push_back(std::move(step));
        return true;
      }

      elements().push_back({std::move(step)});
      return true;
    }

    // Append a child subexpression. Only valid in the tree state, and the
    // child must belong to the same ProgramBuilder (both are DCHECKed).
    void AddSubexpression(Subexpression* absl_nonnull expr) {
      ABSL_DCHECK(absl::holds_alternative(program_));
      ABSL_DCHECK(owner_ == expr->owner_);
      elements().push_back(expr);
    }

    // Accessor for elements (either simple steps or subexpressions).
    //
    // Value is undefined if the expression has already been flattened.
    std::vector& elements() {
      ABSL_DCHECK(absl::holds_alternative(program_));
      return absl::get(program_);
    }

    const std::vector& elements() const {
      ABSL_DCHECK(absl::holds_alternative(program_));
      return absl::get(program_);
    }

    // Accessor for program steps.
    //
    // Value is undefined if the expression has not yet been flattened.
    std::vector>& flattened_elements() {
      ABSL_DCHECK(IsFlattened());
      return absl::get(program_);
    }

    const std::vector>& flattened_elements()
        const {
      ABSL_DCHECK(IsFlattened());
      return absl::get(program_);
    }

    // Replace whatever plan this subexpression holds with a recursive program.
    void set_recursive_program(std::unique_ptr step,
                               int depth) {
      program_ = RecursiveProgram{std::move(step), depth};
    }

    const RecursiveProgram& recursive_program() const {
      ABSL_DCHECK(IsRecursive());
      return absl::get(program_);
    }

    // If this is a tree whose every element is a recursive child
    // subexpression, returns the maximum depth among those children;
    // otherwise nullopt.
    absl::optional RecursiveDependencyDepth() const;

    // Moves the recursive programs out of the child subexpressions, in element
    // order. Returns an empty vector unless this is a tree whose every element
    // is a recursive child subexpression.
    //
    // NOTE(review): extraction is destructive for the children; callers should
    // check RecursiveDependencyDepth() first.
    std::vector>
    ExtractRecursiveDependencies() const;

    // Moves out the recursive program (DCHECKs IsRecursive()), leaving this
    // subexpression with an empty, non-recursive plan.
    RecursiveProgram ExtractRecursiveProgram();

    bool IsRecursive() const {
      return absl::holds_alternative(program_);
    }

    // Compute the current number of program steps in this subexpression and
    // its dependencies.
    size_t ComputeSize() const;

    // Calculate the number of steps from the end of base to before target,
    // (including negative offsets).
    int CalculateOffset(int base, int target) const;

    // Extract a child subexpression.
    //
    // The expression is removed from the elements array.
    //
    // Returns nullptr if child is not an element of this subexpression.
    Subexpression* absl_nullable ExtractChild(Subexpression* child);

    // Flatten the subexpression.
    //
    // This removes the structure tracking for subexpressions, but makes the
    // subprogram evaluable on the runtime's stack machine.
    void Flatten();

    bool IsFlattened() const {
      return absl::holds_alternative(program_);
    }

    // Extract a flattened subexpression into the given vector, transferring
    // ownership of the given steps.
    //
    // Returns false if the subexpression is not currently flattened.
    bool ExtractTo(std::vector>& out);

   private:
    Subexpression(const cel::Expr* self, ProgramBuilder* owner);

    friend class ProgramBuilder;

    // Some extensions expect the program plan to be contiguous mid-planning.
    //
    // This adds complexity, but supports swapping to a flat representation as
    // needed.
    absl::variant program_;

    // The AST node this subexpression plans.
    const cel::Expr* self_;
    // AST node of the enclosing subexpression; nullptr for the root.
    const cel::Expr* absl_nullable parent_;
    // The ProgramBuilder that created this subexpression.
    ProgramBuilder* owner_;
  };

  ProgramBuilder();

  // Flatten the main subexpression and return its value.
  //
  // This transfers ownership of the program, returning the builder to starting
  // state. (See FlattenSubexpressions).
  ExecutionPath FlattenMain();

  // Flatten extracted subprograms.
  //
  // This transfers ownership of the subprograms, returning the extracted
  // programs table to starting state.
  std::vector FlattenSubexpressions();

  // Returns the current subexpression where steps and new subexpressions are
  // added.
  //
  // May return null if the builder is not currently planning an expression.
  Subexpression* absl_nullable current() { return current_; }

  // Enter a subexpression context.
  //
  // Adds a subexpression at the current insertion point and moves insertion
  // to the subexpression.
  //
  // Returns the new current() value.
  //
  // May return nullptr if the expression is already indexed in the program
  // builder.
  Subexpression* absl_nullable EnterSubexpression(const cel::Expr* expr,
                                                  size_t size_hint = 0);

  // Exit a subexpression context.
  //
  // Sets insertion point to parent.
  //
  // Returns the new current() value or nullptr if called out of order.
  // Exiting the root also returns nullptr, since the root has no parent.
  // NOTE(review): the implementation checks the ordering with ABSL_DCHECK.
  Subexpression* absl_nullable ExitSubexpression(const cel::Expr* expr);

  // Return the subexpression mapped to the given expression.
  //
  // Returns nullptr if the mapping doesn't exist either due to the
  // program being overwritten or not encountering the expression.
  Subexpression* absl_nullable GetSubexpression(const cel::Expr* expr);

  // Return the extracted subexpression mapped to the given index.
  //
  // Returns nullptr if the mapping doesn't exist (index out of range).
  Subexpression* absl_nullable GetExtractedSubexpression(size_t index) {
    if (index >= extracted_subexpressions_.size()) {
      return nullptr;
    }

    return extracted_subexpressions_[index];
  }

  // Return index to the extracted subexpression.
  //
  // Returns -1 if the subexpression is not found.
  int ExtractSubexpression(const cel::Expr* expr);

  // Add a program step to the current subexpression.
  // If successful, returns the step pointer.
  //
  // Note: If successful, the pointer should remain valid until the parent
  // expression is finalized. Optimizers may modify the program plan which may
  // free the step at that point.
  ExpressionStep* absl_nullable AddStep(std::unique_ptr step);

  // Discard all planning state (root, insertion point, extracted and mapped
  // subexpressions).
  void Reset();

 private:
  static std::vector>
  FlattenSubexpression(Subexpression* absl_nonnull expr);

  // Creates and indexes a subexpression for expr; nullptr if already indexed.
  Subexpression* absl_nullable MakeSubexpression(const cel::Expr* expr);

  // Root of the expression being planned; nullptr when idle.
  Subexpression* absl_nullable root_;
  // Subexpressions detached by ExtractSubexpression, indexed by its result.
  std::vector extracted_subexpressions_;
  // Insertion point for AddStep / EnterSubexpression.
  Subexpression* absl_nullable current_;
  // Owns every Subexpression created by this builder, keyed by AST node.
  SubprogramMap subprogram_map_;
};

// Attempt to downcast a specific type of recursive step.
//
// Returns `step` as a `Subclass` if its native type id matches. A TraceStep
// that wraps exactly one dependency is looked through first. Returns nullptr
// otherwise, including when `step` is null.
template 
const Subclass* TryDowncastDirectStep(const DirectExpressionStep* step) {
  if (step == nullptr) {
    return nullptr;
  }

  auto type_id = step->GetNativeTypeId();
  if (type_id == cel::NativeTypeId::For()) {
    const auto* trace_step = cel::internal::down_cast(step);
    auto deps = trace_step->GetDependencies();
    if (!deps.has_value() || deps->size() != 1) {
      return nullptr;
    }
    step = deps->at(0);
    type_id = step->GetNativeTypeId();
  }

  if (type_id == cel::NativeTypeId::For()) {
    return cel::internal::down_cast(step);
  }

  return nullptr;
}

// Class representing FlatExpr internals exposed to extensions.
class PlannerContext {
  // Lifetime: resolver, options, type_reflector, issue_collector,
  // program_builder and arena are stored by reference and must outlive this
  // context. environment and message_factory are held by shared_ptr.
 public:
  PlannerContext(
      std::shared_ptr environment,
      const Resolver& resolver, const cel::RuntimeOptions& options,
      const cel::TypeReflector& type_reflector,
      cel::runtime_internal::IssueCollector& issue_collector,
      ProgramBuilder& program_builder,
      std::shared_ptr& arena ABSL_ATTRIBUTE_LIFETIME_BOUND,
      std::shared_ptr message_factory = nullptr)
      : environment_(std::move(environment)),
        resolver_(resolver),
        type_reflector_(type_reflector),
        options_(options),
        issue_collector_(issue_collector),
        program_builder_(program_builder),
        arena_(arena),
        // Recorded before any lazy creation in MutableArena().
        explicit_arena_(arena_ != nullptr),
        message_factory_(std::move(message_factory)) {}

  ProgramBuilder& program_builder() { return program_builder_; }

  // Returns true if the subplan is inspectable.
  //
  // If false, the node is not mapped to a subexpression in the program builder.
  bool IsSubplanInspectable(const cel::Expr& node) const;

  // Return a view to the current subplan representing node.
  //
  // Note: this is invalidated after a sibling or parent is updated.
  //
  // This operation forces the subexpression to flatten which removes the
  // expr->program mapping for any descendants.
  ExecutionPathView GetSubplan(const cel::Expr& node);

  // Extract the plan steps for the given expr.
  //
  // After successful extraction, the subexpression is still inspectable, but
  // empty.
  absl::StatusOr ExtractSubplan(const cel::Expr& node);

  // Replace the subplan associated with node with a new subplan.
  //
  // This operation forces the subexpression to flatten which removes the
  // expr->program mapping for any descendants.
  absl::Status ReplaceSubplan(const cel::Expr& node, ExecutionPath path);

  // Replace the subplan associated with node with a new recursive subplan.
  //
  // This operation clears any existing plan, which removes the
  // expr->program mapping for any descendants.
  absl::Status ReplaceSubplan(const cel::Expr& node,
                              std::unique_ptr step,
                              int depth);

  // Extend the current subplan with the given expression step.
  //
  // Returns an error if node is not tracked by the program builder.
  absl::Status AddSubplanStep(const cel::Expr& node,
                              std::unique_ptr step);

  const Resolver& resolver() const { return resolver_; }
  const cel::TypeReflector& type_reflector() const { return type_reflector_; }
  const cel::RuntimeOptions& options() const { return options_; }
  cel::runtime_internal::IssueCollector& issue_collector() {
    return issue_collector_;
  }

  // The environment's descriptor pool.
  const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const {
    return environment_->descriptor_pool.get();
  }

  // Returns `true` if an arena was explicitly provided during planning.
  bool HasExplicitArena() const { return explicit_arena_; }

  // Returns the planning arena. If none was provided, one is created on first
  // use and stored through the `arena` reference given to the constructor.
  google::protobuf::Arena* absl_nonnull MutableArena() {
    if (!explicit_arena_ && arena_ == nullptr) {
      arena_ = std::make_shared();
    }
    ABSL_DCHECK(arena_ != nullptr);
    return arena_.get();
  }

  // Returns `true` if a message factory was explicitly provided during
  // planning.
  bool HasExplicitMessageFactory() const { return message_factory_ != nullptr; }

  // Returns the explicitly provided message factory, otherwise the
  // environment's.
  google::protobuf::MessageFactory* absl_nonnull MutableMessageFactory() {
    return HasExplicitMessageFactory() ? message_factory_.get()
                                       : environment_->MutableMessageFactory();
  }

 private:
  const std::shared_ptr environment_;
  const Resolver& resolver_;
  const cel::TypeReflector& type_reflector_;
  const cel::RuntimeOptions& options_;
  cel::runtime_internal::IssueCollector& issue_collector_;
  ProgramBuilder& program_builder_;
  // Reference to the caller's arena handle; MutableArena() may populate it.
  std::shared_ptr& arena_;
  const bool explicit_arena_;
  const std::shared_ptr message_factory_;
};

// Interface for Ast Transforms.
// If any are present, the FlatExprBuilder will apply the Ast Transforms in
// order on a copy of the relevant input expressions before planning the
// program.
class AstTransform {
 public:
  virtual ~AstTransform() = default;

  virtual absl::Status UpdateAst(PlannerContext& context,
                                 cel::Ast& ast) const = 0;
};

// Interface for program optimizers.
//
// If any are present, the FlatExprBuilder will notify the implementations in
// order as it traverses the input ast.
//
// Note: implementations must correctly check that subprograms are available
// before accessing (i.e. they have not already been edited).
class ProgramOptimizer {
 public:
  virtual ~ProgramOptimizer() = default;

  // Called before planning the given expr node.
  virtual absl::Status OnPreVisit(PlannerContext& context,
                                  const cel::Expr& node) = 0;

  // Called after planning the given expr node.
  virtual absl::Status OnPostVisit(PlannerContext& context,
                                   const cel::Expr& node) = 0;
};

// Type definition for ProgramOptimizer factories.
//
// The expression builder must remain thread compatible, but ProgramOptimizers
// are often stateful for a given expression. To avoid requiring the optimizer
// implementation to handle concurrent planning, the builder creates a new
// instance per expression planned.
//
// The factory must be thread safe, but the returned instance may assume
// it is called from a synchronous context.
using ProgramOptimizerFactory =
    absl::AnyInvocable>(
        PlannerContext&, const cel::Ast&) const>;

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_EXTENSIONS_H_
================================================
FILE: eval/compiler/flat_expr_builder_extensions_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "eval/compiler/flat_expr_builder_extensions.h"

// NOTE(review): the header names of the three standard-library includes below
// were lost in this text extraction; restore them from upstream.
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "common/expr.h"
#include "common/native_type.h"
#include "common/value.h"
#include "eval/compiler/resolver.h"
#include "eval/eval/const_value_step.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/function_step.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "runtime/function_registry.h"
#include "runtime/internal/issue_collector.h"
#include "runtime/internal/runtime_env.h"
#include "runtime/internal/runtime_env_testing.h"
#include "runtime/runtime_issue.h"
#include "runtime/runtime_options.h"
#include "runtime/type_registry.h"
#include "google/protobuf/arena.h"

namespace google::api::expr::runtime {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::cel::Expr;
using ::cel::RuntimeIssue;
using ::cel::runtime_internal::IssueCollector;
using ::cel::runtime_internal::NewTestingRuntimeEnv;
using ::cel::runtime_internal::RuntimeEnv;
using ::testing::ElementsAre;
using ::testing::IsEmpty;
using ::testing::Optional;

using Subexpression = ProgramBuilder::Subexpression;

// Fixture that owns a testing runtime environment and exposes its type and
// function registries, a Resolver built over those registries, default
// RuntimeOptions, and an IssueCollector.
class PlannerContextTest : public testing::Test {
 public:
  PlannerContextTest()
      : env_(NewTestingRuntimeEnv()),
        type_registry_(env_->type_registry),
        function_registry_(env_->function_registry),
        resolver_("", function_registry_, type_registry_,
                  type_registry_.GetComposedTypeProvider()),
        issue_collector_(RuntimeIssue::Severity::kError) {}

 protected:
  // Members are initialized in declaration order. env_ must come first
  // because the two registry references bind into it, and resolver_ must
  // follow the registries because its constructor reads them.
  //
  // NOTE(review): the shared_ptr template argument was lost in extraction;
  // presumably std::shared_ptr<RuntimeEnv> (RuntimeEnv is imported above).
  absl_nonnull std::shared_ptr env_;
  // Aliases of env_->type_registry / env_->function_registry; they stay valid
  // only while env_ keeps the environment alive.
  cel::TypeRegistry& type_registry_;
  cel::FunctionRegistry& function_registry_;
  cel::RuntimeOptions options_;
  Resolver resolver_;
  IssueCollector issue_collector_;
};

// Matches a smart pointer whose get() equals `ptr`. This checks the identity
// of the owned object (same address), not value equality, so tests can follow
// a specific step after ownership has moved into a plan.
MATCHER_P(UniquePtrHolds, ptr, "") {
  const auto& got = arg;
  return ptr == got.get();
}
// Non-owning pointers to the three steps created by InitSimpleTree. They are
// captured before ownership moves into the ProgramBuilder, so tests can check
// step identity (via UniquePtrHolds) after subplans are rearranged.
struct SimpleTreeSteps {
  const ExpressionStep* a;
  const ExpressionStep* b;
  const ExpressionStep* c;
};

// Simulates planning a program whose subexpression tree is
//
//       a
//      / \   (b is planned before c)
//     b   c
//
// Each node gets one constant-null step. The flattened plan for `a` is then
// b's step, c's step, and finally a's own step.
//
// NOTE(review): the StatusOr template argument was lost in this text
// extraction; it is presumably absl::StatusOr<SimpleTreeSteps>, matching
// `return result;` and the ASSERT_OK_AND_ASSIGN(SimpleTreeSteps ...) callers.
// Other stripped template arguments in this file (e.g. `std::shared_ptr
// arena;`) must likewise be restored from upstream.
absl::StatusOr InitSimpleTree(
    const Expr& a, const Expr& b, const Expr& c,
    ProgramBuilder& program_builder) {
  CEL_ASSIGN_OR_RETURN(auto a_step,
                       CreateConstValueStep(cel::NullValue(), -1));
  CEL_ASSIGN_OR_RETURN(auto b_step,
                       CreateConstValueStep(cel::NullValue(), -1));
  CEL_ASSIGN_OR_RETURN(auto c_step,
                       CreateConstValueStep(cel::NullValue(), -1));

  // Record the raw pointers before the unique_ptrs are moved into the builder.
  SimpleTreeSteps result{a_step.get(), b_step.get(), c_step.get()};

  program_builder.EnterSubexpression(&a);
  program_builder.EnterSubexpression(&b);
  program_builder.AddStep(std::move(b_step));
  program_builder.ExitSubexpression(&b);

  program_builder.EnterSubexpression(&c);
  program_builder.AddStep(std::move(c_step));
  program_builder.ExitSubexpression(&c);

  // a's own step is added after both children (post-order).
  program_builder.AddStep(std::move(a_step));
  program_builder.ExitSubexpression(&a);

  return result;
}

// GetSubplan returns the steps planned for a node, including those of its
// children. A node the builder never saw is not inspectable and yields an
// empty plan.
TEST_F(PlannerContextTest, GetPlan) {
  Expr a;
  Expr b;
  Expr c;
  ProgramBuilder program_builder;

  ASSERT_OK_AND_ASSIGN(auto step_ptrs,
                       InitSimpleTree(a, b, c, program_builder));

  // Default-constructed, i.e. null: no arena is supplied to the context.
  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  EXPECT_THAT(context.GetSubplan(b), ElementsAre(UniquePtrHolds(step_ptrs.b)));

  EXPECT_THAT(context.GetSubplan(c), ElementsAre(UniquePtrHolds(step_ptrs.c)));

  EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(step_ptrs.b),
                                                 UniquePtrHolds(step_ptrs.c),
                                                 UniquePtrHolds(step_ptrs.a)));

  // `d` was never entered into the program builder.
  Expr d;
  EXPECT_FALSE(context.IsSubplanInspectable(d));
  EXPECT_THAT(context.GetSubplan(d), IsEmpty());
}

// Replacing the root's subplan swaps in the new steps and discards the
// children's subplans (b's plan is empty afterwards).
TEST_F(PlannerContextTest, ReplacePlan) {
  Expr a;
  Expr b;
  Expr c;
  ProgramBuilder program_builder;

  ASSERT_OK_AND_ASSIGN(auto step_ptrs,
                       InitSimpleTree(a, b, c, program_builder));

  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(step_ptrs.b),
                                                 UniquePtrHolds(step_ptrs.c),
                                                 UniquePtrHolds(step_ptrs.a)));

  ExecutionPath new_a;

  ASSERT_OK_AND_ASSIGN(auto new_a_step,
                       CreateConstValueStep(cel::NullValue(), -1));
  const ExpressionStep* new_a_step_ptr = new_a_step.get();
  new_a.push_back(std::move(new_a_step));

  ASSERT_THAT(context.ReplaceSubplan(a, std::move(new_a)), IsOk());

  EXPECT_THAT(context.GetSubplan(a),
              ElementsAre(UniquePtrHolds(new_a_step_ptr)));
  EXPECT_THAT(context.GetSubplan(b), IsEmpty());
}

// ExtractSubplan moves a child's steps out of the plan and returns them.
TEST_F(PlannerContextTest, ExtractPlan) {
  Expr a;
  Expr b;
  Expr c;
  ProgramBuilder program_builder;

  ASSERT_OK_AND_ASSIGN(auto plan_steps,
                       InitSimpleTree(a, b, c, program_builder));

  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  EXPECT_TRUE(context.IsSubplanInspectable(a));
  EXPECT_TRUE(context.IsSubplanInspectable(b));

  ASSERT_OK_AND_ASSIGN(ExecutionPath extracted, context.ExtractSubplan(b));

  EXPECT_THAT(extracted, ElementsAre(UniquePtrHolds(plan_steps.b)));
}

// NOTE(review): despite the name, this test expects ExtractSubplan(b) to
// succeed and return an empty plan once the parent `a` has been replaced.
TEST_F(PlannerContextTest, ExtractFailsOnReplacedNode) {
  Expr a;
  Expr b;
  Expr c;
  ProgramBuilder program_builder;

  ASSERT_THAT(InitSimpleTree(a, b, c, program_builder).status(), IsOk());

  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  ASSERT_THAT(context.ReplaceSubplan(a, {}), IsOk());

  EXPECT_THAT(context.ExtractSubplan(b), IsOkAndHolds(IsEmpty()));
}

// Replacing a child's subplan is reflected in the parent's flattened plan.
TEST_F(PlannerContextTest, ReplacePlanUpdatesParent) {
  Expr a;
  Expr b;
  Expr c;
  ProgramBuilder program_builder;

  ASSERT_OK_AND_ASSIGN(auto plan_steps,
                       InitSimpleTree(a, b, c, program_builder));

  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  EXPECT_TRUE(context.IsSubplanInspectable(a));

  ASSERT_THAT(context.ReplaceSubplan(c, {}), IsOk());

  EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(plan_steps.b),
                                                 UniquePtrHolds(plan_steps.a)));
  EXPECT_THAT(context.GetSubplan(c), IsEmpty());
}

// Growing one child's subplan (b: 1 step -> 2 steps) leaves the sibling c
// intact and keeps the parent's plan ordered correctly.
TEST_F(PlannerContextTest, ReplacePlanUpdatesSibling) {
  Expr a;
  Expr b;
  Expr c;
  ProgramBuilder program_builder;

  ASSERT_OK_AND_ASSIGN(auto plan_steps,
                       InitSimpleTree(a, b, c, program_builder));

  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  ExecutionPath new_b;

  ASSERT_OK_AND_ASSIGN(auto b1_step,
                       CreateConstValueStep(cel::NullValue(), -1));
  const ExpressionStep* b1_step_ptr = b1_step.get();
  new_b.push_back(std::move(b1_step));
  ASSERT_OK_AND_ASSIGN(auto b2_step,
                       CreateConstValueStep(cel::NullValue(), -1));
  const ExpressionStep* b2_step_ptr = b2_step.get();
  new_b.push_back(std::move(b2_step));

  ASSERT_THAT(context.ReplaceSubplan(b, std::move(new_b)), IsOk());

  EXPECT_THAT(context.GetSubplan(c), ElementsAre(UniquePtrHolds(plan_steps.c)));
  EXPECT_THAT(context.GetSubplan(b), ElementsAre(UniquePtrHolds(b1_step_ptr),
                                                 UniquePtrHolds(b2_step_ptr)));
  EXPECT_THAT(
      context.GetSubplan(a),
      ElementsAre(UniquePtrHolds(b1_step_ptr), UniquePtrHolds(b2_step_ptr),
                  UniquePtrHolds(plan_steps.c), UniquePtrHolds(plan_steps.a)));
}

// NOTE(review): despite the name, ReplaceSubplan on `b` after its parent `a`
// has been replaced is expected to succeed (IsOk).
TEST_F(PlannerContextTest, ReplacePlanFailsOnUpdatedNode) {
  Expr a;
  Expr b;
  Expr c;
  ProgramBuilder program_builder;

  ASSERT_OK_AND_ASSIGN(auto plan_steps,
                       InitSimpleTree(a, b, c, program_builder));

  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  EXPECT_THAT(context.GetSubplan(a), ElementsAre(UniquePtrHolds(plan_steps.b),
                                                 UniquePtrHolds(plan_steps.c),
                                                 UniquePtrHolds(plan_steps.a)));

  ASSERT_THAT(context.ReplaceSubplan(a, {}), IsOk());
  EXPECT_THAT(context.ReplaceSubplan(b, {}), IsOk());
}

// AddSubplanStep appends a step to a child's subplan; the parent's plan sees
// it right after that child's existing steps.
TEST_F(PlannerContextTest, AddSubplanStep) {
  Expr a;
  Expr b;
  Expr c;
  ProgramBuilder program_builder;

  ASSERT_OK_AND_ASSIGN(auto plan_steps,
                       InitSimpleTree(a, b, c, program_builder));

  ASSERT_OK_AND_ASSIGN(auto b2_step,
                       CreateConstValueStep(cel::NullValue(), -1));

  const ExpressionStep* b2_step_ptr = b2_step.get();

  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  ASSERT_THAT(context.AddSubplanStep(b, std::move(b2_step)), IsOk());

  EXPECT_THAT(context.GetSubplan(b), ElementsAre(UniquePtrHolds(plan_steps.b),
                                                 UniquePtrHolds(b2_step_ptr)));
  EXPECT_THAT(context.GetSubplan(c), ElementsAre(UniquePtrHolds(plan_steps.c)));
  EXPECT_THAT(
      context.GetSubplan(a),
      ElementsAre(UniquePtrHolds(plan_steps.b), UniquePtrHolds(b2_step_ptr),
                  UniquePtrHolds(plan_steps.c), UniquePtrHolds(plan_steps.a)));
}

// Adding a step for a node the builder never saw reports kInternal.
TEST_F(PlannerContextTest, AddSubplanStepFailsOnUnknownNode) {
  Expr a;
  Expr b;
  Expr c;
  Expr d;
  ProgramBuilder program_builder;

  ASSERT_THAT(InitSimpleTree(a, b, c, program_builder).status(), IsOk());

  ASSERT_OK_AND_ASSIGN(auto b2_step,
                       CreateConstValueStep(cel::NullValue(), -1));

  std::shared_ptr arena;
  PlannerContext context(env_, resolver_, options_,
                         type_registry_.GetComposedTypeProvider(),
                         issue_collector_, program_builder, arena);

  EXPECT_THAT(context.GetSubplan(d), IsEmpty());

  EXPECT_THAT(context.AddSubplanStep(d, std::move(b2_step)),
              StatusIs(absl::StatusCode::kInternal));
}

// Fixture for tests that drive ProgramBuilder directly.
//
// NOTE(review): none of the ProgramBuilderTest cases below reference these
// registries; they look like leftovers and could likely be removed.
class ProgramBuilderTest : public testing::Test {
 public:
  ProgramBuilderTest() : type_registry_(), function_registry_() {}

 protected:
  cel::TypeRegistry type_registry_;
  cel::FunctionRegistry function_registry_;
};

// ExtractSubexpression moves a subexpression out of the main program. Per the
// expectations, the returned value is its index in FlattenSubexpressions()
// (c first at 0, then b at 1).
TEST_F(ProgramBuilderTest, ExtractSubexpression) {
  Expr a;
  Expr b;
  Expr c;
  ProgramBuilder program_builder;

  ASSERT_OK_AND_ASSIGN(
      SimpleTreeSteps step_ptrs,
      InitSimpleTree(a, b, c, program_builder));
  EXPECT_EQ(program_builder.ExtractSubexpression(&c), 0);
  EXPECT_EQ(program_builder.ExtractSubexpression(&b), 1);

  EXPECT_THAT(program_builder.FlattenMain(),
              ElementsAre(UniquePtrHolds(step_ptrs.a)));
  EXPECT_THAT(program_builder.FlattenSubexpressions(),
              ElementsAre(ElementsAre(UniquePtrHolds(step_ptrs.c)),
                          ElementsAre(UniquePtrHolds(step_ptrs.b))));
}

// After b is flattened, its former child c can no longer be extracted from
// it.
TEST_F(ProgramBuilderTest, FlattenRemovesChildrenReferences) {
  Expr a;
  Expr b;
  Expr c;
  ProgramBuilder program_builder;

  program_builder.EnterSubexpression(&a);
  program_builder.EnterSubexpression(&b);
  program_builder.EnterSubexpression(&c);
  program_builder.ExitSubexpression(&c);
  program_builder.ExitSubexpression(&b);
  program_builder.ExitSubexpression(&a);

  auto subexpr_b = program_builder.GetSubexpression(&b);
  ASSERT_TRUE(subexpr_b != nullptr);
  subexpr_b->Flatten();

  auto* subexpr_c = program_builder.GetSubexpression(&c);
  EXPECT_EQ(subexpr_b->ExtractChild(subexpr_c), nullptr);
}

// Once the parent is flattened, neither ExtractChild nor ExtractSubexpression
// can reach the child: they return nullptr and -1 respectively.
// NOTE(review): "Flattend" in the test name is a typo for "Flattened".
TEST_F(ProgramBuilderTest, ExtractReturnsNullOnFlattendExpr) {
  Expr a;
  Expr b;
  ProgramBuilder program_builder;

  program_builder.EnterSubexpression(&a);
  program_builder.EnterSubexpression(&b);
  program_builder.ExitSubexpression(&b);
  program_builder.ExitSubexpression(&a);

  auto* subexpr_a = program_builder.GetSubexpression(&a);
  auto* subexpr_b = program_builder.GetSubexpression(&b);

  ASSERT_TRUE(subexpr_a != nullptr);
  ASSERT_TRUE(subexpr_b != nullptr);

  subexpr_a->Flatten();

  // subexpr_b is now freed.
  // It is only passed by address below (to be matched against a's children),
  // so it must not be dereferenced from here on.
  EXPECT_EQ(subexpr_a->ExtractChild(subexpr_b), nullptr);
  EXPECT_EQ(program_builder.ExtractSubexpression(&b), -1);
}

// ExtractChild only works for direct children: c is a grandchild of a.
TEST_F(ProgramBuilderTest, ExtractReturnsNullOnNonChildren) {
  Expr a;
  Expr b;
  Expr c;
  ProgramBuilder program_builder;

  program_builder.EnterSubexpression(&a);
  program_builder.EnterSubexpression(&b);
  program_builder.EnterSubexpression(&c);
  program_builder.ExitSubexpression(&c);
  program_builder.ExitSubexpression(&b);
  program_builder.ExitSubexpression(&a);

  auto* subexpr_a = program_builder.GetSubexpression(&a);
  auto* subexpr_c = program_builder.GetSubexpression(&c);

  ASSERT_TRUE(subexpr_a != nullptr);
  ASSERT_TRUE(subexpr_c != nullptr);

  EXPECT_EQ(subexpr_a->ExtractChild(subexpr_c), nullptr);
}

// Reset() drops every recorded subexpression.
TEST_F(ProgramBuilderTest, ResetWorks) {
  Expr a;
  Expr b;
  Expr c;
  ProgramBuilder program_builder;

  program_builder.EnterSubexpression(&a);
  program_builder.EnterSubexpression(&b);
  program_builder.EnterSubexpression(&c);
  program_builder.ExitSubexpression(&c);
  program_builder.ExitSubexpression(&b);
  program_builder.ExitSubexpression(&a);

  auto* subexpr_a = program_builder.GetSubexpression(&a);
  auto* subexpr_c = program_builder.GetSubexpression(&c);

  ASSERT_TRUE(subexpr_a != nullptr);
  ASSERT_TRUE(subexpr_c != nullptr);

  program_builder.Reset();

  subexpr_a = program_builder.GetSubexpression(&a);
  subexpr_c = program_builder.GetSubexpression(&c);

  ASSERT_TRUE(subexpr_a == nullptr);
  ASSERT_TRUE(subexpr_c == nullptr);
}

// ExtractChild returns the child itself when it is a direct, unflattened child
// (here c, which follows one of a's own steps).
TEST_F(ProgramBuilderTest, ExtractWorks) {
  Expr a;
  Expr b;
  Expr c;
  ProgramBuilder program_builder;

  program_builder.EnterSubexpression(&a);
  program_builder.EnterSubexpression(&b);
  program_builder.ExitSubexpression(&b);

  ASSERT_OK_AND_ASSIGN(auto a_step,
                       CreateConstValueStep(cel::NullValue(), -1));
  program_builder.AddStep(std::move(a_step));
  program_builder.EnterSubexpression(&c);
  program_builder.ExitSubexpression(&c);
  program_builder.ExitSubexpression(&a);

  auto* subexpr_a = program_builder.GetSubexpression(&a);
  auto* subexpr_c = program_builder.GetSubexpression(&c);

  ASSERT_TRUE(subexpr_a != nullptr);
  ASSERT_TRUE(subexpr_c != nullptr);

  EXPECT_EQ(subexpr_a->ExtractChild(subexpr_c), subexpr_c);
}

// ExtractTo fails until the subexpression is flattened. Afterwards it yields
// the steps in b, c, a order.
TEST_F(ProgramBuilderTest, ExtractToRequiresFlatten) {
  Expr a;
  Expr b;
  Expr c;
  ProgramBuilder program_builder;

  ASSERT_OK_AND_ASSIGN(
      SimpleTreeSteps step_ptrs,
      InitSimpleTree(a, b, c, program_builder));

  auto* subexpr_a = program_builder.GetSubexpression(&a);
  ExecutionPath path;

  EXPECT_FALSE(subexpr_a->ExtractTo(path));

  subexpr_a->Flatten();
  EXPECT_TRUE(subexpr_a->ExtractTo(path));

  EXPECT_THAT(path, ElementsAre(UniquePtrHolds(step_ptrs.b),
                                UniquePtrHolds(step_ptrs.c),
                                UniquePtrHolds(step_ptrs.a)));
}

// Builds a recursive (direct-step) program. Children b and c each get a
// depth-1 recursive program. The parent a is then assigned a function step
// over the children's extracted programs, with depth max(child depth) + 1.
TEST_F(ProgramBuilderTest, Recursive) {
  Expr a;
  Expr b;
  Expr c;
  ProgramBuilder program_builder;

  program_builder.EnterSubexpression(&a);
  program_builder.EnterSubexpression(&b);
  program_builder.current()->set_recursive_program(
      CreateConstValueDirectStep(cel::NullValue()), 1);
  program_builder.ExitSubexpression(&b);
  program_builder.EnterSubexpression(&c);
  program_builder.current()->set_recursive_program(
      CreateConstValueDirectStep(cel::NullValue()), 1);
  program_builder.ExitSubexpression(&c);

  // `a` (the current subexpression) has no program of its own yet.
  ASSERT_FALSE(program_builder.current()->IsFlattened());
  ASSERT_FALSE(program_builder.current()->IsRecursive());
  ASSERT_TRUE(program_builder.GetSubexpression(&b)->IsRecursive());
  ASSERT_TRUE(program_builder.GetSubexpression(&c)->IsRecursive());

  EXPECT_EQ(program_builder.GetSubexpression(&b)->recursive_program().depth, 1);
  EXPECT_EQ(program_builder.GetSubexpression(&c)->recursive_program().depth, 1);

  // A two-argument `_==_` call expression; it is passed to
  // CreateDirectFunctionStep together with the children's extracted programs.
  cel::CallExpr call_expr;
  call_expr.set_function("_==_");
  call_expr.mutable_args().emplace_back();
  call_expr.mutable_args().emplace_back();

  auto max_depth = program_builder.current()->RecursiveDependencyDepth();

  EXPECT_THAT(max_depth, Optional(1));

  auto deps = program_builder.current()->ExtractRecursiveDependencies();

  program_builder.current()->set_recursive_program(
      CreateDirectFunctionStep(-1, call_expr, std::move(deps), {}),
      *max_depth + 1);

  program_builder.ExitSubexpression(&a);

  // The whole recursive program flattens to a single top-level step.
  auto path = program_builder.FlattenMain();

  ASSERT_THAT(path, testing::SizeIs(1));
  // NOTE(review): the type argument to NativeTypeId::For was lost in this
  // text extraction; presumably it names the step type that wraps a direct
  // (recursive) program -- restore from upstream.
  EXPECT_TRUE(path[0]->GetNativeTypeId() ==
              cel::NativeTypeId::For());
}

}  // namespace
}  // namespace google::api::expr::runtime
================================================
FILE: eval/compiler/flat_expr_builder_short_circuiting_conformance_test.cc
================================================
// A collection of tests that confirm that short-circuit and non-short-circuit
// produce expressions with the same outputs.

// NOTE(review): the standard-library header name on the next include was lost
// in this text extraction; restore it from upstream.
#include

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "base/builtins.h"
#include "eval/compiler/cel_expression_builder_flat_impl.h"
#include "eval/public/activation.h"
#include "eval/public/cel_attribute.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_value.h"
#include "eval/public/unknown_attribute_set.h"
#include "eval/public/unknown_set.h"
#include "internal/testing.h"
#include "runtime/internal/runtime_env_testing.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/text_format.h"

namespace google::api::expr::runtime {

namespace {

using ::cel::runtime_internal::NewTestingRuntimeEnv;
using ::cel::expr::Expr;
using ::testing::Eq;
using ::testing::SizeIs;

// Text-format Expr for `var1 OP (var2 OP var3)`. `$0` is replaced with the
// operator's function name via absl::Substitute.
constexpr char kTwoLogicalOp[] = R"cel( id: 1 call_expr { function: "$0" args { id: 2 ident_expr { name: "var1", } } args { id: 3 call_expr { function: "$0" args { id: 4 ident_expr { name: "var2" } } args { id: 5 ident_expr { name: "var3" } } } } } )cel";

// Text-format Expr for the ternary `cond ? arg1 : arg2`.
constexpr char kTernaryExpr[] = R"cel( id: 1 call_expr { function: "_?_:_" args { id: 2 ident_expr { name: "cond" } } args { id: 3 ident_expr { name: "arg1" } } args { id: 4 ident_expr { name: "arg2" } } })cel";

// Builds `expr` with `builder`, evaluates it against `activation` using
// `arena`, and stores the resulting value in `*result`.
//
// Uses gtest ASSERT_* macros inside a void function, so callers wrap the call
// in ASSERT_NO_FATAL_FAILURE to stop the test when building or evaluation
// fails. The nullptr passed to CreateExpression is presumably the (absent)
// source info -- confirm against CelExpressionBuilder.
void BuildAndEval(CelExpressionBuilder* builder, const Expr& expr,
                  const Activation& activation, google::protobuf::Arena* arena,
                  CelValue* result) {
  ASSERT_OK_AND_ASSIGN(auto expression,
                       builder->CreateExpression(&expr, nullptr));
  auto value = expression->Evaluate(activation, arena);
  ASSERT_OK(value);
  *result = *value;
}

// Parameterized on RuntimeOptions::short_circuiting (GetParam()). Every test
// below must produce the same result with short-circuiting on and off.
//
// NOTE(review): template arguments were lost in this text extraction; the
// parameter type is presumably bool (see testing::Values(false, true)). The
// builder types come from cel_expression_builder_flat_impl.h -- restore them
// from upstream.
class ShortCircuitingTest : public testing::TestWithParam {
 public:
  // Returns a fresh builder configured for the current parameter. When
  // `enable_unknowns` is true, attribute and function unknown tracking is
  // turned on (UnknownProcessingOptions::kAttributeAndFunction).
  std::unique_ptr GetBuilder(
      bool enable_unknowns = false) {
    cel::RuntimeOptions options;
    options.short_circuiting = GetParam();
    if (enable_unknowns) {
      options.unknown_processing =
          cel::UnknownProcessingOptions::kAttributeAndFunction;
    }
    auto result = std::make_unique(
        NewTestingRuntimeEnv(), options);
    return result;
  }
};

// var1 && (var2 && var3) with plain booleans.
TEST_P(ShortCircuitingTest, BasicAnd) {
  Expr expr;
  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
      absl::Substitute(kTwoLogicalOp, ::cel::builtin::kAnd), &expr));
  auto builder = GetBuilder();

  activation.InsertValue("var1", CelValue::CreateBool(true));
  activation.InsertValue("var2", CelValue::CreateBool(true));
  activation.InsertValue("var3", CelValue::CreateBool(false));

  CelValue result;
  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsBool());
  EXPECT_FALSE(result.BoolOrDie());

  // Rebind var3 to true (the old entry is removed first).
  ASSERT_TRUE(activation.RemoveValueEntry("var3"));
  activation.InsertValue("var3", CelValue::CreateBool(true));

  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsBool());
  EXPECT_TRUE(result.BoolOrDie());
}

// var1 || (var2 || var3) with plain booleans.
TEST_P(ShortCircuitingTest, BasicOr) {
  Expr expr;
  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
      absl::Substitute(kTwoLogicalOp, ::cel::builtin::kOr), &expr));
  auto builder = GetBuilder();

  activation.InsertValue("var1", CelValue::CreateBool(false));
  activation.InsertValue("var2", CelValue::CreateBool(false));
  activation.InsertValue("var3", CelValue::CreateBool(true));

  CelValue result;
  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsBool());
  EXPECT_TRUE(result.BoolOrDie());

  ASSERT_TRUE(activation.RemoveValueEntry("var3"));
  activation.InsertValue("var3", CelValue::CreateBool(false));

  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsBool());
  EXPECT_FALSE(result.BoolOrDie());
}

// With var2 bound to an error: a false var3 still makes the conjunction
// false, while a true var3 lets the error surface as the result.
TEST_P(ShortCircuitingTest, ErrorAnd) {
  Expr expr;
  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
      absl::Substitute(kTwoLogicalOp, ::cel::builtin::kAnd), &expr));
  auto builder = GetBuilder();
  absl::Status error = absl::InternalError("error");

  activation.InsertValue("var1", CelValue::CreateBool(true));
  activation.InsertValue("var2", CelValue::CreateError(&error));
  activation.InsertValue("var3", CelValue::CreateBool(false));

  CelValue result;
  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsBool());
  EXPECT_FALSE(result.BoolOrDie());

  ASSERT_TRUE(activation.RemoveValueEntry("var3"));
  activation.InsertValue("var3", CelValue::CreateBool(true));

  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsError());
  EXPECT_THAT(*result.ErrorOrDie(),
              Eq(absl::Status(absl::StatusCode::kInternal, "error")));
}

// Mirror of ErrorAnd for ||: a true var3 wins over the error in var2, and a
// false var3 yields the error.
TEST_P(ShortCircuitingTest, ErrorOr) {
  Expr expr;
  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
      absl::Substitute(kTwoLogicalOp, ::cel::builtin::kOr), &expr));
  auto builder = GetBuilder();
  absl::Status error = absl::InternalError("error");

  activation.InsertValue("var1", CelValue::CreateBool(false));
  activation.InsertValue("var2", CelValue::CreateError(&error));
  activation.InsertValue("var3", CelValue::CreateBool(true));

  CelValue result;
  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsBool());
  EXPECT_TRUE(result.BoolOrDie());

  ASSERT_TRUE(activation.RemoveValueEntry("var3"));
  activation.InsertValue("var3", CelValue::CreateBool(false));

  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsError());
  EXPECT_THAT(*result.ErrorOrDie(),
              Eq(absl::Status(absl::StatusCode::kInternal, "error")));
}

// var1 is unknown and var2 is an error. A false var3 makes the result false.
// With var3 true, the result is an unknown set naming only var1, i.e. the
// unknown is reported instead of the error.
TEST_P(ShortCircuitingTest, UnknownAnd) {
  Expr expr;
  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
      absl::Substitute(kTwoLogicalOp, ::cel::builtin::kAnd), &expr));
  auto builder = GetBuilder(/* enable_unknowns=*/true);
  absl::Status error = absl::InternalError("error");

  activation.set_unknown_attribute_patterns({CelAttributePattern("var1", {})});
  activation.InsertValue("var2", CelValue::CreateError(&error));
  activation.InsertValue("var3", CelValue::CreateBool(false));

  CelValue result;
  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsBool());
  EXPECT_FALSE(result.BoolOrDie());

  ASSERT_TRUE(activation.RemoveValueEntry("var3"));
  activation.InsertValue("var3", CelValue::CreateBool(true));

  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsUnknownSet());
  const UnknownAttributeSet& attrs =
      result.UnknownSetOrDie()->unknown_attributes();
  ASSERT_THAT(attrs, testing::SizeIs(1));
  EXPECT_THAT(attrs.begin()->variable_name(), testing::Eq("var1"));
}

// Mirror of UnknownAnd for ||: a true var3 gives true, and a false var3 gives
// an unknown set naming only var1.
TEST_P(ShortCircuitingTest, UnknownOr) {
  Expr expr;
  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
      absl::Substitute(kTwoLogicalOp, ::cel::builtin::kOr), &expr));
  auto builder = GetBuilder(/* enable_unknowns=*/true);
  absl::Status error = absl::InternalError("error");

  activation.set_unknown_attribute_patterns({CelAttributePattern("var1", {})});
  activation.InsertValue("var2", CelValue::CreateError(&error));
  activation.InsertValue("var3", CelValue::CreateBool(true));

  CelValue result;
  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsBool());
  EXPECT_TRUE(result.BoolOrDie());

  ASSERT_TRUE(activation.RemoveValueEntry("var3"));
  activation.InsertValue("var3", CelValue::CreateBool(false));

  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsUnknownSet());
  const UnknownAttributeSet& attrs =
      result.UnknownSetOrDie()->unknown_attributes();
  ASSERT_THAT(attrs, testing::SizeIs(1));
  EXPECT_THAT(attrs.begin()->variable_name(), testing::Eq("var1"));
}

// cond ? arg1 : arg2 selects arg1 (uint 1) or arg2 (int -1). The branches
// deliberately have different types.
TEST_P(ShortCircuitingTest, BasicTernary) {
  Expr expr;
  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_TRUE(
      google::protobuf::TextFormat::ParseFromString(kTernaryExpr, &expr));
  auto builder = GetBuilder();

  activation.InsertValue("cond", CelValue::CreateBool(true));
  activation.InsertValue("arg1", CelValue::CreateUint64(1));
  activation.InsertValue("arg2", CelValue::CreateInt64(-1));

  CelValue result;
  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsUint64());
  EXPECT_EQ(result.Uint64OrDie(), 1);

  ASSERT_TRUE(activation.RemoveValueEntry("cond"));
  activation.InsertValue("cond", CelValue::CreateBool(false));

  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsInt64());
  EXPECT_EQ(result.Int64OrDie(), -1);
}

// An error condition yields the condition's error (error1, not arg1's
// error2). An unselected erroneous branch does not affect the result.
TEST_P(ShortCircuitingTest, TernaryErrorHandling) {
  Expr expr;
  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_TRUE(
      google::protobuf::TextFormat::ParseFromString(kTernaryExpr, &expr));
  auto builder = GetBuilder();

  absl::Status error1 = absl::InternalError("error1");
  absl::Status error2 = absl::InternalError("error2");
  activation.InsertValue("cond", CelValue::CreateError(&error1));
  activation.InsertValue("arg1", CelValue::CreateError(&error2));
  activation.InsertValue("arg2", CelValue::CreateInt64(-1));

  CelValue result;
  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsError());
  EXPECT_EQ(*result.ErrorOrDie(), error1);

  ASSERT_TRUE(activation.RemoveValueEntry("cond"));
  activation.InsertValue("cond", CelValue::CreateBool(false));

  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsInt64());
  EXPECT_EQ(result.Int64OrDie(), -1);
}

// An unknown condition yields an unknown set naming only `cond`. The unknown
// pattern takes precedence over the bound value `false`, as the first
// expectation shows.
TEST_P(ShortCircuitingTest, TernaryUnknownCondHandling) {
  Expr expr;
  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_TRUE(
      google::protobuf::TextFormat::ParseFromString(kTernaryExpr, &expr));
  auto builder = GetBuilder(/*enable_unknowns=*/true);

  absl::Status error = absl::InternalError("error1");
  activation.InsertValue("cond", CelValue::CreateBool(false));
  activation.InsertValue("arg1", CelValue::CreateError(&error));
  activation.InsertValue("arg2", CelValue::CreateInt64(-1));
  activation.set_unknown_attribute_patterns({CelAttributePattern("cond", {})});

  CelValue result;
  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsUnknownSet());
  const auto& attrs = result.UnknownSetOrDie()->unknown_attributes();
  ASSERT_THAT(attrs, SizeIs(1));
  EXPECT_THAT(attrs.begin()->variable_name(), Eq("cond"));

  // Unknown branches are discarded if condition is unknown
  activation.set_unknown_attribute_patterns({CelAttributePattern("cond", {}),
                                             CelAttributePattern("arg1", {}),
                                             CelAttributePattern("arg2", {})});

  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsUnknownSet());
  const auto& attrs2 = result.UnknownSetOrDie()->unknown_attributes();
  ASSERT_THAT(attrs2, SizeIs(1));
  EXPECT_THAT(attrs2.begin()->variable_name(), Eq("cond"));
}

// With a known `false` condition, only the selected branch (arg2) matters.
TEST_P(ShortCircuitingTest, TernaryUnknownArgsHandling) {
  Expr expr;
  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_TRUE(
      google::protobuf::TextFormat::ParseFromString(kTernaryExpr, &expr));
  auto builder = GetBuilder(/*enable_unknowns=*/true);

  absl::Status error = absl::InternalError("error1");
  activation.InsertValue("cond", CelValue::CreateBool(false));
  activation.InsertValue("arg1", CelValue::CreateError(&error));
  activation.InsertValue("arg2", CelValue::CreateInt64(-1));

  // Unknown arg is discarded if condition chooses other branch.
  activation.set_unknown_attribute_patterns({CelAttributePattern("arg1", {})});

  CelValue result;
  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsInt64());
  EXPECT_EQ(result.Int64OrDie(), -1);

  // Branches won't merge if both are unknown.
  activation.set_unknown_attribute_patterns(
      {CelAttributePattern("arg1", {}), CelAttributePattern("arg2", {})});

  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsUnknownSet());
  const auto& attrs3 = result.UnknownSetOrDie()->unknown_attributes();
  ASSERT_THAT(attrs3, SizeIs(1));
  EXPECT_EQ(attrs3.begin()->variable_name(), "arg2");
}

// Interaction of errors and unknowns between the condition and the branches.
TEST_P(ShortCircuitingTest, TernaryUnknownAndErrorHandling) {
  Expr expr;
  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_TRUE(
      google::protobuf::TextFormat::ParseFromString(kTernaryExpr, &expr));
  auto builder = GetBuilder(/*enable_unknowns=*/true);

  absl::Status error = absl::InternalError("error1");
  activation.InsertValue("cond", CelValue::CreateError(&error));
  activation.InsertValue("arg1", CelValue::CreateInt64(1));
  activation.InsertValue("arg2", CelValue::CreateInt64(-1));

  // Error cond discards args
  activation.set_unknown_attribute_patterns(
      {CelAttributePattern("arg1", {}), CelAttributePattern("arg2", {})});

  CelValue result;
  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsError());
  EXPECT_EQ(*result.ErrorOrDie(), error);

  // Error arg discarded if condition unknown
  activation.set_unknown_attribute_patterns({CelAttributePattern("cond", {})});
  ASSERT_TRUE(activation.RemoveValueEntry("arg1"));
  activation.InsertValue("arg1", CelValue::CreateError(&error));

  ASSERT_NO_FATAL_FAILURE(
      BuildAndEval(builder.get(), expr, activation, &arena, &result));
  ASSERT_TRUE(result.IsUnknownSet());
  const auto& attrs = result.UnknownSetOrDie()->unknown_attributes();
  ASSERT_THAT(attrs, SizeIs(1));
  EXPECT_EQ(attrs.begin()->variable_name(), "cond");
}

// Names each instantiation after the short_circuiting parameter, e.g.
// "Test/ShortCircuitingTest.BasicAnd/short_circuit_enabled".
//
// NOTE(review): the TestParamInfo template argument was lost in extraction;
// presumably testing::TestParamInfo<bool>, given testing::Values(false, true).
const char* TestName(testing::TestParamInfo info) {
  if (info.param) {
    return "short_circuit_enabled";
  } else {
    return "short_circuit_disabled";
  }
}

// Runs every ShortCircuitingTest case with short-circuiting disabled and
// enabled.
INSTANTIATE_TEST_SUITE_P(Test, ShortCircuitingTest,
                         testing::Values(false, true), &TestName);

}  // namespace

}  // namespace google::api::expr::runtime
================================================
FILE: eval/compiler/flat_expr_builder_test.cc
================================================
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "eval/compiler/flat_expr_builder.h"

// NOTE(review): in this copy of the file the header names of the six #include
// directives below are missing (presumably angle-bracket standard headers
// whose "<...>" text was stripped). Restore them from the upstream source
// before building.
#include
#include
#include
#include
#include
#include
#include "cel/expr/checked.pb.h"
#include "cel/expr/syntax.pb.h"
#include "google/protobuf/field_mask.pb.h"
#include "google/protobuf/descriptor.pb.h"
#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "base/builtins.h"
#include "common/function_descriptor.h"
#include "common/kind.h"
#include "common/value.h"
#include "eval/compiler/cel_expression_builder_flat_impl.h"
#include "eval/compiler/constant_folding.h"
#include "eval/compiler/qualified_reference_resolver.h"
#include "eval/public/activation.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_attribute.h"
#include "eval/public/cel_builtins.h"
#include "eval/public/cel_expr_builder_factory.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_function.h"
#include "eval/public/cel_function_adapter.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"
#include "eval/public/cel_value.h"
#include "eval/public/containers/container_backed_map_impl.h"
#include "eval/public/portable_cel_function_adapter.h"
#include "eval/public/structs/cel_proto_descriptor_pool_builder.h"
#include "eval/public/structs/cel_proto_wrapper.h"
#include "eval/public/testing/matchers.h"
#include "eval/public/unknown_attribute_set.h"
#include "eval/public/unknown_set.h"
#include "eval/testutil/test_message.pb.h"
#include "internal/proto_matchers.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "parser/parser.h"
#include "runtime/function.h"
#include "runtime/function_adapter.h"
#include "runtime/internal/runtime_env_testing.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_functions.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

namespace google::api::expr::runtime {

namespace {

// NOTE(review): template argument lists also appear to have been stripped
// from this copy, e.g. `absl::Span args`, `Arena::Create(arena, ...)`,
// `std::make_unique()`, `std::vector warnings` and
// `using FunctionAdapterT = FunctionAdapter;`. They must be restored from the
// upstream source before this file compiles.

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::BytesValue;
using ::cel::Value;
using ::cel::expr::conformance::proto3::TestAllTypes;
using ::cel::internal::test::EqualsProto;
using ::cel::runtime_internal::NewTestingRuntimeEnv;
using ::cel::expr::CheckedExpr;
using ::cel::expr::Expr;
using ::cel::expr::ParsedExpr;
using ::cel::expr::SourceInfo;
using ::testing::_;
using ::testing::Eq;
using ::testing::HasSubstr;
using ::testing::SizeIs;
using ::testing::Truly;

// Test-only CelFunction named "concat" that takes two string arguments and
// returns their concatenation.
class ConcatFunction : public CelFunction {
 public:
  explicit ConcatFunction() : CelFunction(CreateDescriptor()) {}

  // Descriptor for concat(string, string). The `false` field is presumably the
  // receiver-style flag -- confirm against CelFunctionDescriptor.
  static CelFunctionDescriptor CreateDescriptor() {
    return CelFunctionDescriptor{
        "concat", false, {CelValue::Type::kString, CelValue::Type::kString}};
  }

  // Returns InvalidArgument unless exactly two arguments are passed. Argument
  // types are not re-checked here (StringOrDie is called directly); the
  // descriptor above declares both as kString.
  absl::Status Evaluate(absl::Span args, CelValue* result,
                        google::protobuf::Arena* arena) const override {
    if (args.size() != 2) {
      return absl::InvalidArgumentError("Bad arguments number");
    }
    std::string concat = std::string(args[0].StringOrDie().value()) +
                         std::string(args[1].StringOrDie().value());

    // CelValue::CreateString receives a pointer, so the concatenated string is
    // allocated on `arena` rather than on this stack frame.
    auto* concatenated =
        google::protobuf::Arena::Create(arena, std::move(concat));

    *result = CelValue::CreateString(concatenated);

    return absl::OkStatus();
  }
};

// Test-only zero-argument CelFunction registered under `name`. Each call
// increments `*count` and returns true, letting tests observe whether (and how
// often) a call was actually evaluated, e.g. under short-circuiting.
class RecorderFunction : public CelFunction {
 public:
  explicit RecorderFunction(const std::string& name, int* count)
      : CelFunction(CelFunctionDescriptor{name, false, {}}), count_(count) {}

  // Returns InvalidArgument if any argument is passed.
  absl::Status Evaluate(absl::Span args, CelValue* result,
                        google::protobuf::Arena* arena) const override {
    if (!args.empty()) {
      return absl::Status(absl::StatusCode::kInvalidArgument,
                          "Bad arguments number");
    }
    (*count_)++;
    *result = CelValue::CreateBool(true);
    return absl::OkStatus();
  }

  // Not owned: the counter is only dereferenced, never freed, so it must
  // outlive every Evaluate call.
  int* count_;
};

// Hand-builds concat("prefix", value), registers ConcatFunction, and checks
// that evaluation with `value` bound to "test" yields "prefixtest".
TEST(FlatExprBuilderTest, SimpleEndToEnd) {
  Expr expr;
  SourceInfo source_info;
  auto call_expr = expr.mutable_call_expr();
  call_expr->set_function("concat");
  auto arg1 = call_expr->add_args();
  arg1->mutable_const_expr()->set_string_value("prefix");
  auto arg2 = call_expr->add_args();
  arg2->mutable_ident_expr()->set_name("value");

  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  ASSERT_THAT(
      builder.GetRegistry()->Register(std::make_unique()),
      IsOk());
  ASSERT_OK_AND_ASSIGN(auto cel_expr,
                       builder.CreateExpression(&expr, &source_info));

  std::string variable = "test";

  Activation activation;
  activation.InsertValue("value", CelValue::CreateString(&variable));

  google::protobuf::Arena arena;

  ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena));
  ASSERT_TRUE(result.IsString());
  EXPECT_THAT(result.StringOrDie().value(), Eq("prefixtest"));
}

// An Expr with no expression kind set is rejected when the program is built.
TEST(FlatExprBuilderTest, ExprUnset) {
  Expr expr;
  SourceInfo source_info;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("Invalid empty expression")));
}

// A SourceInfo extension ("ext1") that lists COMPONENT_RUNTIME as an affected
// component makes CreateExpression fail with InvalidArgument.
TEST(FlatExprBuilderTest, RuntimeExtensionsError) {
  Expr expr;
  SourceInfo source_info;
  auto* ext = source_info.add_extensions();
  ext->set_id("ext1");
  ext->add_affected_components(
      cel::expr::SourceInfo_Extension_Component_COMPONENT_RUNTIME);
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("unsupported CEL extension: ext1")));
}

// A const_expr with no constant kind set is rejected at build time.
TEST(FlatExprBuilderTest, ConstValueUnset) {
  Expr expr;
  SourceInfo source_info;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  // Create an empty constant expression to ensure that it triggers an error.
  expr.mutable_const_expr();

  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("unspecified constant")));
}

// Map-literal entries missing a key, then missing only a value, are each
// rejected at build time.
TEST(FlatExprBuilderTest, MapKeyValueUnset) {
  Expr expr;
  SourceInfo source_info;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());

  // Don't set either the key or the value for the map creation step.
  auto* entry = expr.mutable_struct_expr()->add_entries();
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("Map entry missing key")));

  // Set the entry key, but not the value.
  entry->mutable_map_key()->mutable_const_expr()->set_bool_value(true);
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("Map entry missing value")));
}

// Message-creation entries missing a field name, then missing only a value,
// are each rejected at build time.
TEST(FlatExprBuilderTest, MessageFieldValueUnset) {
  Expr expr;
  SourceInfo source_info;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());

  // Don't set either the field or the value for the message creation step.
  auto* create_message = expr.mutable_struct_expr();
  create_message->set_message_name("google.protobuf.Value");
  auto* entry = create_message->add_entries();
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("Struct field missing name")));

  // Set the entry field, but not the value.
  entry->set_field_key("bool_value");
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("Struct field missing value")));
}

// The logical-and builtin (builtin::kAnd) called with a target plus two args,
// i.e. one operand more than a binary operator takes, must be rejected.
TEST(FlatExprBuilderTest, BinaryCallTooManyArguments) {
  Expr expr;
  SourceInfo source_info;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());

  auto* call = expr.mutable_call_expr();
  call->set_function(builtin::kAnd);
  call->mutable_target()->mutable_const_expr()->set_string_value("random");
  call->add_args()->mutable_const_expr()->set_bool_value(false);
  call->add_args()->mutable_const_expr()->set_bool_value(true);

  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("Invalid argument count")));
}

// The ternary builtin called with a target plus three args must be rejected,
// both with short-circuiting enabled and disabled.
TEST(FlatExprBuilderTest, TernaryCallTooManyArguments) {
  Expr expr;
  SourceInfo source_info;

  auto* call = expr.mutable_call_expr();
  call->set_function(builtin::kTernary);
  call->mutable_target()->mutable_const_expr()->set_string_value("random");
  call->add_args()->mutable_const_expr()->set_bool_value(false);
  call->add_args()->mutable_const_expr()->set_int64_value(1);
  call->add_args()->mutable_const_expr()->set_int64_value(2);

  {
    cel::RuntimeOptions options;
    options.short_circuiting = true;
    CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);

    EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
                StatusIs(absl::StatusCode::kInvalidArgument,
                         HasSubstr("Invalid argument count")));
  }

  // Disable short-circuiting to ensure that a different visitor is used.
  {
    cel::RuntimeOptions options;
    options.short_circuiting = false;
    CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);

    EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
                StatusIs(absl::StatusCode::kInvalidArgument,
                         HasSubstr("Invalid argument count")));
  }
}

// With fail_on_warnings disabled, a call to an unregistered function
// ("concat") still builds: the missing overload is reported as a build warning,
// and evaluation produces an error value instead.
TEST(FlatExprBuilderTest, DelayedFunctionResolutionErrors) {
  Expr expr;
  SourceInfo source_info;
  auto call_expr = expr.mutable_call_expr();
  call_expr->set_function("concat");
  auto arg1 = call_expr->add_args();
  arg1->mutable_const_expr()->set_string_value("prefix");
  auto arg2 = call_expr->add_args();
  arg2->mutable_ident_expr()->set_name("value");

  cel::RuntimeOptions options;
  options.fail_on_warnings = false;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
  std::vector warnings;

  // Concat function not registered.

  ASSERT_OK_AND_ASSIGN(
      auto cel_expr, builder.CreateExpression(&expr, &source_info, &warnings));

  std::string variable = "test";
  Activation activation;
  activation.InsertValue("value", CelValue::CreateString(&variable));

  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena));
  ASSERT_TRUE(result.IsError());

  EXPECT_THAT(result.ErrorOrDie()->message(),
              Eq("No matching overloads found : concat(string, string)"));

  ASSERT_THAT(warnings, testing::SizeIs(1));
  EXPECT_EQ(warnings[0].code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(std::string(warnings[0].message()),
              testing::HasSubstr("No overloads provided"));
}

// Evaluates recorder1() || recorder2(). recorder1 returns true, so with
// short_circuiting enabled recorder2 is never called; with it disabled both
// recorders are called exactly once.
TEST(FlatExprBuilderTest, Shortcircuiting) {
  Expr expr;
  SourceInfo source_info;
  auto call_expr = expr.mutable_call_expr();
  call_expr->set_function("_||_");

  auto arg1 = call_expr->add_args();
  arg1->mutable_call_expr()->set_function("recorder1");

  auto arg2 = call_expr->add_args();
  arg2->mutable_call_expr()->set_function("recorder2");

  Activation activation;
  google::protobuf::Arena arena;

  // Shortcircuiting on
  {
    cel::RuntimeOptions options;
    options.short_circuiting = true;
    CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
    // NOTE(review): the status returned by RegisterBuiltinFunctions is stored
    // in `builtin` but never checked; other tests in this file wrap the call
    // in ASSERT_THAT(..., IsOk()).
    auto builtin = RegisterBuiltinFunctions(builder.GetRegistry());

    int count1 = 0;
    int count2 = 0;

    ASSERT_THAT(builder.GetRegistry()->Register(
                    std::make_unique("recorder1", &count1)),
                IsOk());
    ASSERT_THAT(builder.GetRegistry()->Register(
                    std::make_unique("recorder2", &count2)),
                IsOk());

    ASSERT_OK_AND_ASSIGN(auto cel_expr_on,
                         builder.CreateExpression(&expr, &source_info));
    ASSERT_THAT(cel_expr_on->Evaluate(activation, &arena), IsOk());

    EXPECT_THAT(count1, Eq(1));
    EXPECT_THAT(count2, Eq(0));
  }

  // Shortcircuiting off.
  {
    cel::RuntimeOptions options;
    options.short_circuiting = false;
    CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
    // NOTE(review): status stored in `builtin` is never checked (see above).
    auto builtin = RegisterBuiltinFunctions(builder.GetRegistry());

    int count1 = 0;
    int count2 = 0;

    ASSERT_THAT(builder.GetRegistry()->Register(
                    std::make_unique("recorder1", &count1)),
                IsOk());
    ASSERT_THAT(builder.GetRegistry()->Register(
                    std::make_unique("recorder2", &count2)),
                IsOk());

    ASSERT_OK_AND_ASSIGN(auto cel_expr_off,
                         builder.CreateExpression(&expr, &source_info));
    ASSERT_THAT(cel_expr_off->Evaluate(activation, &arena), IsOk());

    EXPECT_THAT(count1, Eq(1));
    EXPECT_THAT(count2, Eq(1));
  }
}

// A comprehension over [1, 2, 3] with loop_condition set to the constant false
// and a loop_step that calls recorder_function1. With short_circuiting enabled
// the loop step never runs (count 0); with it disabled the step runs once per
// element (count 3).
TEST(FlatExprBuilderTest, ShortcircuitingComprehension) {
  Expr expr;
  SourceInfo source_info;
  auto comprehension_expr = expr.mutable_comprehension_expr();
  comprehension_expr->set_iter_var("x");
  auto list_expr =
      comprehension_expr->mutable_iter_range()->mutable_list_expr();
  list_expr->add_elements()->mutable_const_expr()->set_int64_value(1);
  list_expr->add_elements()->mutable_const_expr()->set_int64_value(2);
  list_expr->add_elements()->mutable_const_expr()->set_int64_value(3);
  comprehension_expr->set_accu_var("accu");
  comprehension_expr->mutable_accu_init()->mutable_const_expr()->set_bool_value(
      false);
  comprehension_expr->mutable_loop_condition()
      ->mutable_const_expr()
      ->set_bool_value(false);
  comprehension_expr->mutable_loop_step()->mutable_call_expr()->set_function(
      "recorder_function1");
  comprehension_expr->mutable_result()->mutable_const_expr()->set_bool_value(
      false);

  Activation activation;
  google::protobuf::Arena arena;

  // shortcircuiting on
  {
    cel::RuntimeOptions options;
    options.short_circuiting = true;
    CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
    // NOTE(review): status stored in `builtin` is never checked.
    auto builtin = RegisterBuiltinFunctions(builder.GetRegistry());
    int count = 0;
    ASSERT_THAT(
        builder.GetRegistry()->Register(
            std::make_unique("recorder_function1", &count)),
        IsOk());

    ASSERT_OK_AND_ASSIGN(auto cel_expr_on,
                         builder.CreateExpression(&expr, &source_info));

    ASSERT_THAT(cel_expr_on->Evaluate(activation, &arena), IsOk());

    EXPECT_THAT(count, Eq(0));
  }

  // shortcircuiting off
  {
    cel::RuntimeOptions options;
    options.short_circuiting = false;
    CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
    // NOTE(review): status stored in `builtin` is never checked.
    auto builtin = RegisterBuiltinFunctions(builder.GetRegistry());
    int count = 0;
    ASSERT_THAT(
        builder.GetRegistry()->Register(
            std::make_unique("recorder_function1", &count)),
        IsOk());
    ASSERT_OK_AND_ASSIGN(auto cel_expr_off,
                         builder.CreateExpression(&expr, &source_info));

    ASSERT_THAT(cel_expr_off->Evaluate(activation, &arena), IsOk());
    EXPECT_THAT(count, Eq(3));
  }
}

// NOTE(review): this test and the *Unset* / MapComprehension / InvalidContainer
// tests below ignore the bool returned by TextFormat::ParseFromString, so a
// typo in the text proto would not fail the test at the parse step. Consider
// wrapping each call in ASSERT_TRUE(...).
TEST(FlatExprBuilderTest, IdentExprUnsetName) {
  Expr expr;
  SourceInfo source_info;
  // An empty ident without the name set should error.
  google::protobuf::TextFormat::ParseFromString(R"(ident_expr {})", &expr);

  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk());
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("'name' must not be empty")));
}

TEST(FlatExprBuilderTest, SelectExprUnsetField) {
  Expr expr;
  SourceInfo source_info;
  // A select with an operand but no field name should error.
  google::protobuf::TextFormat::ParseFromString(R"(select_expr{ operand{ ident_expr {name: 'var'} } })",
                                      &expr);

  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk());
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("'field' must not be empty")));
}

TEST(FlatExprBuilderTest, SelectExprUnsetOperand) {
  Expr expr;
  SourceInfo source_info;
  // A select whose operand has an id but no expression kind should error.
  google::protobuf::TextFormat::ParseFromString(R"(select_expr{ field: 'field' operand { id: 1 } })",
                                      &expr);

  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk());
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("must specify an operand")));
}

TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuVar) {
  Expr expr;
  SourceInfo source_info;
  // An empty comprehension should fail on the missing accu_var.
  google::protobuf::TextFormat::ParseFromString(R"(comprehension_expr{})", &expr);

  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk());
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("'accu_var' must not be empty")));
}

TEST(FlatExprBuilderTest, ComprehensionExprUnsetIterVar) {
  Expr expr;
  SourceInfo source_info;
  // accu_var is set but iter_var is not, which should error.
  google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{accu_var: "a"} )", &expr);

  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk());
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("'iter_var' must not be empty")));
}

TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuInit) {
  Expr expr;
  SourceInfo source_info;
  // accu_var and iter_var are set but accu_init is not, which should error.
  google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: "a" iter_var: "b"} )", &expr);

  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk());
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("'accu_init' must be set")));
}

TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopCondition) {
  Expr expr;
  SourceInfo source_info;
  // accu_init is set but loop_condition is not, which should error.
  google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: 'a' iter_var: 'b' accu_init { const_expr {bool_value: true} }} )", &expr);

  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk());
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("'loop_condition' must be set")));
}

TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopStep) {
  Expr expr;
  SourceInfo source_info;
  // loop_condition is set but loop_step is not, which should error.
  google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: 'a' iter_var: 'b' accu_init { const_expr {bool_value: true} } loop_condition { const_expr {bool_value: true} }} )", &expr);

  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk());
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("'loop_step' must be set")));
}

TEST(FlatExprBuilderTest, ComprehensionExprUnsetResult) {
  Expr expr;
  SourceInfo source_info;
  // loop_step is set but result is not, which should error.
  google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: 'a' iter_var: 'b' accu_init { const_expr {bool_value: true} } loop_condition { const_expr {bool_value: true} } loop_step { const_expr {bool_value: false} }} )", &expr);

  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk());
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("'result' must be set")));
}

// A hand-written `all`-style comprehension over a map literal evaluates to
// true.
TEST(FlatExprBuilderTest, MapComprehension) {
  Expr expr;
  SourceInfo source_info;
  // {1: "", 2: ""}.all(k, k > 0)  (the comprehension's iter_var is "k")
  google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr { iter_var: "k" accu_var: "accu" accu_init { const_expr { bool_value: true } } loop_condition { ident_expr { name: "accu" } } result { ident_expr { name: "accu" } } loop_step { call_expr { function: "_&&_" args { ident_expr { name: "accu" } } args { call_expr { function: "_>_" args { ident_expr { name: "k" } } args { const_expr { int64_value: 0 } } } } } } iter_range { struct_expr { entries { map_key { const_expr { int64_value: 1 } } value { const_expr { string_value: "" } } } entries { map_key { const_expr { int64_value: 2 } } value { const_expr { string_value: "" } } } } } })", &expr);

  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk());
  ASSERT_OK_AND_ASSIGN(auto cel_expr,
                       builder.CreateExpression(&expr, &source_info));

  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena));
  ASSERT_TRUE(result.IsBool());
  EXPECT_TRUE(result.BoolOrDie());
}

// Container names with a leading or a trailing '.' are rejected when the
// expression is built.
TEST(FlatExprBuilderTest, InvalidContainer) {
  Expr expr;
  SourceInfo source_info;
  // foo && bar
  google::protobuf::TextFormat::ParseFromString(R"( call_expr { function: "_&&_" args { ident_expr { name: "foo" } } args { ident_expr { name: "bar" } } })", &expr);

  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk());

  builder.set_container(".bad");
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("container: '.bad'")));

  builder.set_container("bad.");
  EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("container: 'bad.'")));
}

// With the reference resolver set to kAlways, a parsed `ext.XOr(a, b)`
// resolves to the function registered under the qualified name "ext.XOr".
TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupport) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("ext.XOr(a, b)"));
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  builder.flat_expr_builder().AddAstTransform(
      NewReferenceResolverExtension(ReferenceResolverOption::kAlways));
  using FunctionAdapterT = FunctionAdapter;

  ASSERT_OK(FunctionAdapterT::CreateAndRegister(
      "ext.XOr", /*receiver_style=*/false,
      [](google::protobuf::Arena*, bool a, bool b) { return a != b; },
      builder.GetRegistry()));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(
                                          &expr.expr(), &expr.source_info()));

  google::protobuf::Arena arena;
  Activation act1;
  act1.InsertValue("a", CelValue::CreateBool(false));
  act1.InsertValue("b", CelValue::CreateBool(true));

  ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena));
  EXPECT_THAT(result, test::IsCelBool(true));

  Activation act2;
  act2.InsertValue("a", CelValue::CreateBool(true));
  act2.InsertValue("b", CelValue::CreateBool(true));

  ASSERT_OK_AND_ASSIGN(result, cel_expr->Evaluate(act2, &arena));
  EXPECT_THAT(result, test::IsCelBool(false));
}

// With container "ext", an unqualified `XOr(a, b)` resolves to the function
// registered as "ext.XOr".
TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportWithContainer) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("XOr(a, b)"));
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  builder.flat_expr_builder().AddAstTransform(
      NewReferenceResolverExtension(ReferenceResolverOption::kAlways));
  builder.set_container("ext");
  using FunctionAdapterT = FunctionAdapter;

  ASSERT_OK(FunctionAdapterT::CreateAndRegister(
      "ext.XOr", /*receiver_style=*/false,
      [](google::protobuf::Arena*, bool a, bool b) { return a != b; },
      builder.GetRegistry()));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  google::protobuf::Arena arena;
  Activation act1;
  act1.InsertValue("a", CelValue::CreateBool(false));
  act1.InsertValue("b", CelValue::CreateBool(true));

  ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena));
  EXPECT_THAT(result, test::IsCelBool(true));

  Activation act2;
  act2.InsertValue("a", CelValue::CreateBool(true));
  act2.InsertValue("b", CelValue::CreateBool(true));

  ASSERT_OK_AND_ASSIGN(result, cel_expr->Evaluate(act2, &arena));
  EXPECT_THAT(result, test::IsCelBool(false));
}

// With container "a.b", `c.d.Get()` must resolve to "a.b.c.d.Get" (returns
// true) in preference to "c.d.Get" and the receiver-style "Get" (both return
// false).
TEST(FlatExprBuilderTest, ParsedNamespacedFunctionResolutionOrder) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("c.d.Get()"));
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  builder.flat_expr_builder().AddAstTransform(
      NewReferenceResolverExtension(ReferenceResolverOption::kAlways));
  builder.set_container("a.b");
  using FunctionAdapterT = FunctionAdapter;

  ASSERT_OK(FunctionAdapterT::CreateAndRegister(
      "a.b.c.d.Get", /*receiver_style=*/false,
      [](google::protobuf::Arena*) { return true; }, builder.GetRegistry()));
  ASSERT_OK(FunctionAdapterT::CreateAndRegister(
      "c.d.Get", /*receiver_style=*/false, [](google::protobuf::Arena*) { return false; },
      builder.GetRegistry()));
  ASSERT_OK((FunctionAdapter::CreateAndRegister(
      "Get",
      /*receiver_style=*/true, [](google::protobuf::Arena*, bool) { return false; },
      builder.GetRegistry())));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  google::protobuf::Arena arena;
  Activation act1;
  ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena));
  EXPECT_THAT(result, test::IsCelBool(true));
}

// With container "a.b" and no "a.b.c.d.Get" registered, `c.d.Get()` falls back
// to the parent container's "a.c.d.Get" (true), ahead of "c.d.Get" and the
// receiver-style "Get" (both false).
TEST(FlatExprBuilderTest,
     ParsedNamespacedFunctionResolutionOrderParentContainer) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("c.d.Get()"));
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  builder.flat_expr_builder().AddAstTransform(
      NewReferenceResolverExtension(ReferenceResolverOption::kAlways));
  builder.set_container("a.b");
  using FunctionAdapterT = FunctionAdapter;

  ASSERT_OK(FunctionAdapterT::CreateAndRegister(
      "a.c.d.Get", /*receiver_style=*/false,
      [](google::protobuf::Arena*) { return true; }, builder.GetRegistry()));
  ASSERT_OK(FunctionAdapterT::CreateAndRegister(
      "c.d.Get", /*receiver_style=*/false, [](google::protobuf::Arena*) { return false; },
      builder.GetRegistry()));
  ASSERT_OK((FunctionAdapter::CreateAndRegister(
      "Get",
      /*receiver_style=*/true, [](google::protobuf::Arena*, bool) { return false; },
      builder.GetRegistry())));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  google::protobuf::Arena arena;
  Activation act1;
  ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena));
  EXPECT_THAT(result, test::IsCelBool(true));
}

// A leading '.' (`.c.d.Get()`) forces root-scoped resolution to "c.d.Get"
// (true), even though "a.c.d.Get" (false) exists in a parent of container
// "a.b".
TEST(FlatExprBuilderTest,
     ParsedNamespacedFunctionResolutionOrderExplicitGlobal) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(".c.d.Get()"));
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  builder.flat_expr_builder().AddAstTransform(
      NewReferenceResolverExtension(ReferenceResolverOption::kAlways));
  builder.set_container("a.b");
  using FunctionAdapterT = FunctionAdapter;

  ASSERT_OK(FunctionAdapterT::CreateAndRegister(
      "a.c.d.Get", /*receiver_style=*/false,
      [](google::protobuf::Arena*) { return false; }, builder.GetRegistry()));
  ASSERT_OK(FunctionAdapterT::CreateAndRegister(
      "c.d.Get", /*receiver_style=*/false, [](google::protobuf::Arena*) { return true; },
      builder.GetRegistry()));
  ASSERT_OK((FunctionAdapter::CreateAndRegister(
      "Get",
      /*receiver_style=*/true, [](google::protobuf::Arena*, bool) { return false; },
      builder.GetRegistry())));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  google::protobuf::Arena arena;
  Activation act1;
  ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena));
  EXPECT_THAT(result, test::IsCelBool(true));
}

// `e.Get()`, where `e` is a bound bool variable, resolves to the receiver-style
// "Get" overload (true) rather than to any namespaced Get function (false).
TEST(FlatExprBuilderTest,
     ParsedNamespacedFunctionResolutionOrderReceiverCall) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("e.Get()"));
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv());
  builder.flat_expr_builder().AddAstTransform(
      NewReferenceResolverExtension(ReferenceResolverOption::kAlways));
  builder.set_container("a.b");
  using FunctionAdapterT = FunctionAdapter;

  ASSERT_OK(FunctionAdapterT::CreateAndRegister(
      "a.c.d.Get", /*receiver_style=*/false,
      [](google::protobuf::Arena*) { return false; }, builder.GetRegistry()));
  ASSERT_OK(FunctionAdapterT::CreateAndRegister(
      "c.d.Get", /*receiver_style=*/false, [](google::protobuf::Arena*) { return false; },
      builder.GetRegistry()));
  ASSERT_OK((FunctionAdapter::CreateAndRegister(
      "Get",
      /*receiver_style=*/true, [](google::protobuf::Arena*, bool) { return true; },
      builder.GetRegistry())));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  google::protobuf::Arena arena;
  Activation act1;
  act1.InsertValue("e", CelValue::CreateBool(false));
  ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena));
  EXPECT_THAT(result, test::IsCelBool(true));
}
TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportDisabled) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("ext.XOr(a, b)")); cel::RuntimeOptions options; options.fail_on_warnings = false; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); std::vector build_warnings; builder.set_container("ext"); using FunctionAdapterT = FunctionAdapter; ASSERT_OK(FunctionAdapterT::CreateAndRegister( "ext.XOr", /*receiver_style=*/false, [](google::protobuf::Arena*, bool a, bool b) { return a != b; }, builder.GetRegistry())); ASSERT_OK_AND_ASSIGN( auto cel_expr, builder.CreateExpression(&expr.expr(), &expr.source_info(), &build_warnings)); google::protobuf::Arena arena; Activation act1; act1.InsertValue("a", CelValue::CreateBool(false)); act1.InsertValue("b", CelValue::CreateBool(true)); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(act1, &arena)); EXPECT_THAT(result, test::IsCelError(StatusIs(absl::StatusCode::kUnknown, HasSubstr("ext")))); } TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { CheckedExpr expr; // foo && bar google::protobuf::TextFormat::ParseFromString(R"( expr { id: 1 call_expr { function: "_&&_" args { id: 2 ident_expr { name: "foo" } } args { id: 3 ident_expr { name: "bar" } } } })", &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); Activation activation; activation.InsertValue("foo", CelValue::CreateBool(true)); activation.InsertValue("bar", CelValue::CreateBool(true)); google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); EXPECT_TRUE(result.BoolOrDie()); } TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { CheckedExpr expr; // `foo.var1` && `bar.var2` google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 2 value { name: "foo.var1" } } reference_map { 
key: 4 value { name: "bar.var2" } } expr { id: 1 call_expr { function: "_&&_" args { id: 2 select_expr { field: "var1" operand { id: 3 ident_expr { name: "foo" } } } } args { id: 4 select_expr { field: "var2" operand { ident_expr { name: "bar" } } } } } })", &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); Activation activation; activation.InsertValue("foo.var1", CelValue::CreateBool(true)); activation.InsertValue("bar.var2", CelValue::CreateBool(true)); google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); EXPECT_TRUE(result.BoolOrDie()); } TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { CheckedExpr expr; // ext.and(var1, bar.var2) google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 1 value { overload_id: "com.foo.ext.and" } } reference_map { key: 3 value { name: "com.foo.var1" } } reference_map { key: 4 value { name: "bar.var2" } } expr { id: 1 call_expr { function: "and" target { id: 2 ident_expr { name: "ext" } } args { id: 3 ident_expr { name: "var1" } } args { id: 4 select_expr { field: "var2" operand { id: 5 ident_expr { name: "bar" } } } } } })", &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); builder.set_container("com.foo"); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK((FunctionAdapter::CreateAndRegister( "com.foo.ext.and", false, [](google::protobuf::Arena*, bool lhs, bool rhs) { return lhs && rhs; }, builder.GetRegistry()))); ASSERT_OK_AND_ASSIGN(auto cel_expr, 
builder.CreateExpression(&expr)); Activation activation; activation.InsertValue("com.foo.var1", CelValue::CreateBool(true)); activation.InsertValue("bar.var2", CelValue::CreateBool(true)); google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); EXPECT_TRUE(result.BoolOrDie()); } TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { CheckedExpr expr; // && . google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 2 value { name: "foo.var1" } } reference_map { key: 5 value { name: "bar" } } expr { id: 1 call_expr { function: "_&&_" args { id: 2 select_expr { field: "var1" operand { id: 3 ident_expr { name: "foo" } } } } args { id: 4 select_expr { field: "var2" operand { id: 5 ident_expr { name: "bar" } } } } } })", &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); Activation activation; activation.InsertValue("foo.var1", CelValue::CreateBool(true)); // Activation tries to bind a namespaced variable but the reference map refers // to the container 'bar'. 
activation.InsertValue("bar.var2", CelValue::CreateBool(true)); google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsError()); EXPECT_THAT(*(result.ErrorOrDie()), StatusIs(absl::StatusCode::kUnknown, HasSubstr("No value with name \"bar\" found"))); // Re-run with the expected interpretation of `bar`.`var2` std::vector> map_pairs{ {CelValue::CreateStringView("var2"), CelValue::CreateBool(false)}}; std::unique_ptr map_value = *CreateContainerBackedMap(absl::MakeSpan(map_pairs)); activation.InsertValue("bar", CelValue::CreateMap(map_value.get())); ASSERT_OK_AND_ASSIGN(result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); EXPECT_FALSE(result.BoolOrDie()); } TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapAndConstantFolding) { CheckedExpr expr; // {`var1`: 'hello'} google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 3 value { name: "var1" value { int64_value: 1 } } } expr { id: 1 struct_expr { entries { id: 2 map_key { id: 3 ident_expr { name: "var1" } } value { id: 4 const_expr { string_value: "hello" } } } } })", &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); google::protobuf::Arena arena; builder.flat_expr_builder().AddProgramOptimizer( cel::runtime_internal::CreateConstantFoldingOptimizer()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsMap()); auto m = result.MapOrDie(); auto v = m->Get(&arena, CelValue::CreateInt64(1L)); EXPECT_THAT(v->StringOrDie().value(), Eq("hello")); } TEST(FlatExprBuilderTest, ComprehensionWorksForError) { Expr expr; SourceInfo source_info; // {}[0].all(x, x) should 
evaluate OK but return an error value google::protobuf::TextFormat::ParseFromString(R"( id: 4 comprehension_expr { iter_var: "x" iter_range { id: 2 call_expr { function: "_[_]" args { id: 1 struct_expr { } } args { id: 3 const_expr { int64_value: 0 } } } } accu_var: "__result__" accu_init { id: 7 const_expr { bool_value: true } } loop_condition { id: 8 call_expr { function: "__not_strictly_false__" args { id: 9 ident_expr { name: "__result__" } } } } loop_step { id: 10 call_expr { function: "_&&_" args { id: 11 ident_expr { name: "__result__" } } args { id: 6 ident_expr { name: "x" } } } } result { id: 12 ident_expr { name: "__result__" } } })", &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsError()); } TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { Expr expr; SourceInfo source_info; // 0.all(x, x) should evaluate OK but return an error value. 
google::protobuf::TextFormat::ParseFromString(R"( id: 4 comprehension_expr { iter_var: "x" iter_range { id: 2 const_expr { int64_value: 0 } } accu_var: "__result__" accu_init { id: 7 const_expr { bool_value: true } } loop_condition { id: 8 call_expr { function: "__not_strictly_false__" args { id: 9 ident_expr { name: "__result__" } } } } loop_step { id: 10 call_expr { function: "_&&_" args { id: 11 ident_expr { name: "__result__" } } args { id: 6 ident_expr { name: "x" } } } } result { id: 12 ident_expr { name: "__result__" } } })", &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsError()); EXPECT_THAT(result.ErrorOrDie()->message(), Eq("No matching overloads found : ")); } TEST(FlatExprBuilderTest, ComprehensionBudget) { Expr expr; SourceInfo source_info; // [1, 2].all(x, x > 0) ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr { iter_var: "k" accu_var: "accu" accu_init { const_expr { bool_value: true } } loop_condition { ident_expr { name: "accu" } } result { ident_expr { name: "accu" } } loop_step { call_expr { function: "_&&_" args { ident_expr { name: "accu" } } args { call_expr { function: "_>_" args { ident_expr { name: "k" } } args { const_expr { int64_value: 0 } } } } } } iter_range { list_expr { elements { const_expr { int64_value: 1 } } elements { const_expr { int64_value: 2 } } } } })", &expr)); cel::RuntimeOptions options; options.comprehension_max_iterations = 1; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); Activation 
activation; google::protobuf::Arena arena; EXPECT_THAT(cel_expr->Evaluate(activation, &arena).status(), StatusIs(absl::StatusCode::kInternal, HasSubstr("Iteration budget exceeded"))); } TEST(FlatExprBuilderTest, SimpleEnumTest) { TestMessage message; Expr expr; SourceInfo source_info; constexpr char enum_name[] = "google.api.expr.runtime.TestMessage.TestEnum.TEST_ENUM_1"; std::vector enum_name_parts = absl::StrSplit(enum_name, '.'); Expr* cur_expr = &expr; for (int i = enum_name_parts.size() - 1; i > 0; i--) { auto select_expr = cur_expr->mutable_select_expr(); select_expr->set_field(enum_name_parts[i]); cur_expr = select_expr->mutable_operand(); } cur_expr->mutable_ident_expr()->set_name(enum_name_parts[0]); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); } TEST(FlatExprBuilderTest, SimpleEnumIdentTest) { TestMessage message; Expr expr; SourceInfo source_info; constexpr char enum_name[] = "google.api.expr.runtime.TestMessage.TestEnum.TEST_ENUM_1"; Expr* cur_expr = &expr; cur_expr->mutable_ident_expr()->set_name(enum_name); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); } TEST(FlatExprBuilderTest, ContainerStringFormat) { Expr expr; SourceInfo source_info; 
expr.mutable_ident_expr()->set_name("ident"); { CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.set_container(""); ASSERT_THAT(builder.CreateExpression(&expr, &source_info), IsOk()); } { CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.set_container("random.namespace"); ASSERT_THAT(builder.CreateExpression(&expr, &source_info), IsOk()); } { CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); // Leading '.' builder.set_container(".random.namespace"); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Invalid expression container"))); } { CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); // Trailing '.' builder.set_container("random.namespace."); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Invalid expression container"))); } } // Builder with google.api.expr.runtime.TestMessage and TestEnum types // linked in and the standard functions registered. 
CelExpressionBuilderFlatImpl BuilderForNameResolutionTest( absl::string_view container) { cel::RuntimeOptions options; options.enable_qualified_type_identifiers = true; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); builder.GetTypeRegistry()->Register(TestEnum_descriptor()); builder.set_container(std::string(container)); ABSL_CHECK_OK(cel::RegisterStandardFunctions( builder.GetRegistry()->InternalGetRegistry(), options)); return builder; } TEST(FlatExprBuilderTest, ShortEnumResolution) { google::protobuf::Arena arena; CelExpressionBuilderFlatImpl builder = BuilderForNameResolutionTest("google.api.expr.runtime.TestMessage"); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("TestMessage.TestEnum.TEST_ENUM_1")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( &expr.expr(), &expr.source_info())); Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); } TEST(FlatExprBuilderTest, EnumResolutionHonorsLeadingDot) { google::protobuf::Arena arena; CelExpressionBuilderFlatImpl builder = BuilderForNameResolutionTest("google.api.expr.runtime"); // Leading dot disables container resolution. 
ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(".TestMessage.TestEnum.TEST_ENUM_1")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( &expr.expr(), &expr.source_info())); Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsError()); EXPECT_THAT( result.ErrorOrDie()->message(), HasSubstr("No value with name \"TestMessage\" found in Activation")); } TEST(FlatExprBuilderTest, EnumResolutionComprehensionShadowing) { google::protobuf::Arena arena; CelExpressionBuilderFlatImpl builder = BuilderForNameResolutionTest("google.api.expr.runtime"); // Prefer the interpretation that it's a comprehension var if there's a // collision. ASSERT_OK_AND_ASSIGN( ParsedExpr expr, parser::Parse("[{'TestEnum': {'TEST_ENUM_1': 42}}].map(TestMessage, " "TestMessage.TestEnum.TEST_ENUM_1)[0] == 42")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( &expr.expr(), &expr.source_info())); Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); EXPECT_TRUE(result.BoolOrDie()); } TEST(FlatExprBuilderTest, EnumResolutionComprehensionShadowingLeadingDot) { google::protobuf::Arena arena; CelExpressionBuilderFlatImpl builder = BuilderForNameResolutionTest("google.api.expr.runtime"); // Prefer the interpretation that it's a comprehension var if there's a // collision. 
ASSERT_OK_AND_ASSIGN( ParsedExpr expr, parser::Parse("[0].map(google, " ".google.api.expr.runtime.TestMessage.TestEnum.TEST_ENUM_1)" "[0] == TestMessage.TestEnum.TEST_ENUM_1")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( &expr.expr(), &expr.source_info())); Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); EXPECT_TRUE(result.BoolOrDie()); } TEST(FlatExprBuilderTest, FullEnumNameWithContainerResolution) { google::protobuf::Arena arena; CelExpressionBuilderFlatImpl builder = BuilderForNameResolutionTest("very.random.Namespace"); // Fully qualified name should work. ASSERT_OK_AND_ASSIGN( ParsedExpr expr, parser::Parse( "google.api.expr.runtime.TestMessage.TestEnum.TEST_ENUM_1")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( &expr.expr(), &expr.source_info())); Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); } TEST(FlatExprBuilderTest, SameShortNameEnumResolution) { google::protobuf::Arena arena; // This precondition validates that // TestMessage::TestEnum::TEST_ENUM1 and TestEnum::TEST_ENUM1 are compiled and // linked in and their values are different. 
ASSERT_TRUE(static_cast(TestEnum::TEST_ENUM_1) != static_cast(TestMessage::TEST_ENUM_1)); { CelExpressionBuilderFlatImpl builder = BuilderForNameResolutionTest("google.api.expr.runtime.TestMessage"); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("TestEnum.TEST_ENUM_1")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( &expr.expr(), &expr.source_info())); Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); } // TEST_ENUM3 is present in google.api.expr.runtime.TestEnum, is absent in // google.api.expr.runtime.TestMessage.TestEnum. { CelExpressionBuilderFlatImpl builder = BuilderForNameResolutionTest("google.api.expr.runtime.TestMessage"); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("TestEnum.TEST_ENUM_3")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( &expr.expr(), &expr.source_info())); Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(TestEnum::TEST_ENUM_3)); } { CelExpressionBuilderFlatImpl builder = BuilderForNameResolutionTest("google.api.expr.runtime"); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("TestEnum.TEST_ENUM_1")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( &expr.expr(), &expr.source_info())); Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(TestEnum::TEST_ENUM_1)); } } TEST(FlatExprBuilderTest, PartialQualifiedEnumResolution) { google::protobuf::Arena arena; CelExpressionBuilderFlatImpl builder = BuilderForNameResolutionTest("google.api.expr"); ASSERT_OK_AND_ASSIGN( ParsedExpr expr, parser::Parse("runtime.TestMessage.TestEnum.TEST_ENUM_1")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( 
&expr.expr(), &expr.source_info())); Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); } TEST(FlatExprBuilderTest, NameCollisionWithComprehensionVar) { google::protobuf::Arena arena; CelExpressionBuilderFlatImpl builder = BuilderForNameResolutionTest("google"); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[0].map(x, x)[0]")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( &expr.expr(), &expr.source_info())); Activation activation; activation.InsertValue("x", CelValue::CreateInt64(1)); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(0)); } TEST(FlatExprBuilderTest, NameCollisionWithComprehensionVarLeadingDot) { google::protobuf::Arena arena; CelExpressionBuilderFlatImpl builder = BuilderForNameResolutionTest("google"); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[0].map(x, .x)[0]")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression( &expr.expr(), &expr.source_info())); Activation activation; activation.InsertValue("x", CelValue::CreateInt64(1)); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(1)); } TEST(FlatExprBuilderTest, MapFieldPresence) { Expr expr; SourceInfo source_info; google::protobuf::TextFormat::ParseFromString(R"( id: 1, select_expr{ operand { id: 2 ident_expr{ name: "msg" } } field: "string_int32_map" test_only: true })", &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); google::protobuf::Arena arena; { TestMessage message; auto strMap = message.mutable_string_int32_map(); strMap->insert({"key", 1}); Activation activation; activation.InsertValue("msg", 
CelProtoWrapper::CreateMessage(&message, &arena)); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); } { TestMessage message; Activation activation; activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&message, &arena)); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); } } TEST(FlatExprBuilderTest, RepeatedFieldPresence) { Expr expr; SourceInfo source_info; google::protobuf::TextFormat::ParseFromString(R"( id: 1, select_expr{ operand { id: 2 ident_expr{ name: "msg" } } field: "int32_list" test_only: true })", &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); google::protobuf::Arena arena; { TestMessage message; message.add_int32_list(1); Activation activation; activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&message, &arena)); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); } { TestMessage message; Activation activation; activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&message, &arena)); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); } } absl::Status RunTernaryExpression(CelValue selector, CelValue value1, CelValue value2, google::protobuf::Arena* arena, CelValue* result) { Expr expr; SourceInfo source_info; auto call_expr = expr.mutable_call_expr(); call_expr->set_function(builtin::kTernary); auto arg0 = call_expr->add_args(); arg0->mutable_ident_expr()->set_name("selector"); auto arg1 = call_expr->add_args(); arg1->mutable_ident_expr()->set_name("value1"); auto arg2 = call_expr->add_args(); arg2->mutable_ident_expr()->set_name("value2"); 
CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); CEL_ASSIGN_OR_RETURN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); std::string variable = "test"; Activation activation; activation.InsertValue("selector", selector); activation.InsertValue("value1", value1); activation.InsertValue("value2", value2); CEL_ASSIGN_OR_RETURN(auto eval, cel_expr->Evaluate(activation, arena)); *result = eval; return absl::OkStatus(); } TEST(FlatExprBuilderTest, Ternary) { Expr expr; SourceInfo source_info; auto call_expr = expr.mutable_call_expr(); call_expr->set_function(builtin::kTernary); auto arg0 = call_expr->add_args(); arg0->mutable_ident_expr()->set_name("selector"); auto arg1 = call_expr->add_args(); arg1->mutable_ident_expr()->set_name("value1"); auto arg2 = call_expr->add_args(); arg2->mutable_ident_expr()->set_name("value1"); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); google::protobuf::Arena arena; // On True, value 1 { CelValue result; ASSERT_THAT(RunTernaryExpression(CelValue::CreateBool(true), CelValue::CreateInt64(1), CelValue::CreateInt64(2), &arena, &result), IsOk()); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(1)); // Unknown handling UnknownSet unknown_set; ASSERT_THAT(RunTernaryExpression(CelValue::CreateBool(true), CelValue::CreateUnknownSet(&unknown_set), CelValue::CreateInt64(2), &arena, &result), IsOk()); ASSERT_TRUE(result.IsUnknownSet()); ASSERT_THAT(RunTernaryExpression( CelValue::CreateBool(true), CelValue::CreateInt64(1), CelValue::CreateUnknownSet(&unknown_set), &arena, &result), IsOk()); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(1)); } // On False, value 2 { CelValue result; ASSERT_THAT(RunTernaryExpression(CelValue::CreateBool(false), CelValue::CreateInt64(1), CelValue::CreateInt64(2), &arena, &result), IsOk()); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), 
Eq(2)); // Unknown handling UnknownSet unknown_set; ASSERT_THAT(RunTernaryExpression(CelValue::CreateBool(false), CelValue::CreateUnknownSet(&unknown_set), CelValue::CreateInt64(2), &arena, &result), IsOk()); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(2)); ASSERT_THAT(RunTernaryExpression( CelValue::CreateBool(false), CelValue::CreateInt64(1), CelValue::CreateUnknownSet(&unknown_set), &arena, &result), IsOk()); ASSERT_TRUE(result.IsUnknownSet()); } // On Error, surface error { CelValue result; ASSERT_THAT(RunTernaryExpression(CreateErrorValue(&arena, "error"), CelValue::CreateInt64(1), CelValue::CreateInt64(2), &arena, &result), IsOk()); ASSERT_TRUE(result.IsError()); } // On Unknown, surface Unknown { UnknownSet unknown_set; CelValue result; ASSERT_THAT(RunTernaryExpression(CelValue::CreateUnknownSet(&unknown_set), CelValue::CreateInt64(1), CelValue::CreateInt64(2), &arena, &result), IsOk()); ASSERT_TRUE(result.IsUnknownSet()); EXPECT_THAT(unknown_set, Eq(*result.UnknownSetOrDie())); } // We should not merge unknowns { CelAttribute selector_attr("selector", {}); CelAttribute value1_attr("value1", {}); CelAttribute value2_attr("value2", {}); UnknownSet unknown_selector(UnknownAttributeSet({selector_attr})); UnknownSet unknown_value1(UnknownAttributeSet({value1_attr})); UnknownSet unknown_value2(UnknownAttributeSet({value2_attr})); CelValue result; ASSERT_THAT( RunTernaryExpression(CelValue::CreateUnknownSet(&unknown_selector), CelValue::CreateUnknownSet(&unknown_value1), CelValue::CreateUnknownSet(&unknown_value2), &arena, &result), IsOk()); ASSERT_TRUE(result.IsUnknownSet()); const UnknownSet* result_set = result.UnknownSetOrDie(); EXPECT_THAT(result_set->unknown_attributes().size(), Eq(1)); EXPECT_THAT(result_set->unknown_attributes().begin()->variable_name(), Eq("selector")); } } TEST(FlatExprBuilderTest, EmptyCallList) { std::vector operators = {"_&&_", "_||_", "_?_:_"}; for (const auto& op : operators) { Expr expr; SourceInfo 
source_info; auto call_expr = expr.mutable_call_expr(); call_expr->set_function(op); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); auto build = builder.CreateExpression(&expr, &source_info); ASSERT_FALSE(build.ok()); } } // Note: this should not be allowed by default, but updating is a breaking // change. TEST(FlatExprBuilderTest, HeterogeneousListsAllowed) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("[17, 'seventeen']")); cel::RuntimeOptions options; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue result, expression->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsList()) << result.DebugString(); const auto& list = *result.ListOrDie(); ASSERT_EQ(list.size(), 2); CelValue elem0 = list.Get(&arena, 0); CelValue elem1 = list.Get(&arena, 1); EXPECT_THAT(elem0, test::IsCelInt64(17)); EXPECT_THAT(elem1, test::IsCelString("seventeen")); } TEST(FlatExprBuilderTest, NullUnboxingEnabled) { TestMessage message; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("message.int32_wrapper_value")); cel::RuntimeOptions options; options.enable_empty_wrapper_null_unboxing = true; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); Activation activation; google::protobuf::Arena arena; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena)); ASSERT_OK_AND_ASSIGN(CelValue result, expression->Evaluate(activation, &arena)); EXPECT_TRUE(result.IsNull()); } TEST(FlatExprBuilderTest, TypeResolve) { TestMessage message; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("type(message) == 
runtime.TestMessage")); cel::RuntimeOptions options; options.enable_qualified_type_identifiers = true; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.set_container("google.api.expr"); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); Activation activation; google::protobuf::Arena arena; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena)); ASSERT_OK_AND_ASSIGN(CelValue result, expression->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()) << result.DebugString(); EXPECT_TRUE(result.BoolOrDie()); } TEST(FlatExprBuilderTest, FastEquality) { TestMessage message; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("'foo' == 'bar'")); cel::RuntimeOptions options; options.enable_fast_builtins = true; InterpreterOptions legacy_options; legacy_options.enable_fast_builtins = true; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), IsOk()); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue result, expression->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()) << result.DebugString(); EXPECT_FALSE(result.BoolOrDie()); } TEST(FlatExprBuilderTest, FastEqualityFiltersBadCalls) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("'foo' == 'bar'")); parsed_expr.mutable_expr() ->mutable_call_expr() ->mutable_target() ->mutable_const_expr() ->set_string_value("foo"); cel::RuntimeOptions options; options.enable_fast_builtins = true; InterpreterOptions legacy_options; legacy_options.enable_fast_builtins = true; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); 
ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), IsOk()); ASSERT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr( "unexpected number of args for builtin equality operator"))); } TEST(FlatExprBuilderTest, FastInequalityFiltersBadCalls) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("'foo' != 'bar'")); parsed_expr.mutable_expr() ->mutable_call_expr() ->mutable_target() ->mutable_const_expr() ->set_string_value("foo"); cel::RuntimeOptions options; options.enable_fast_builtins = true; InterpreterOptions legacy_options; legacy_options.enable_fast_builtins = true; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), IsOk()); ASSERT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr( "unexpected number of args for builtin equality operator"))); } TEST(FlatExprBuilderTest, FastInFiltersBadCalls) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("a in b")); parsed_expr.mutable_expr() ->mutable_call_expr() ->mutable_target() ->mutable_const_expr() ->set_string_value("foo"); cel::RuntimeOptions options; options.enable_fast_builtins = true; InterpreterOptions legacy_options; legacy_options.enable_fast_builtins = true; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), IsOk()); ASSERT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs( absl::StatusCode::kInvalidArgument, HasSubstr("unexpected number of args for builtin 'in' operator"))); } TEST(FlatExprBuilderTest, IndexFiltersBadCalls) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("a[b]")); parsed_expr.mutable_expr() ->mutable_call_expr() ->mutable_target() 
->mutable_const_expr() ->set_string_value("foo"); cel::RuntimeOptions options; options.enable_fast_builtins = true; InterpreterOptions legacy_options; legacy_options.enable_fast_builtins = true; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), IsOk()); ASSERT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs( absl::StatusCode::kInvalidArgument, HasSubstr("unexpected number of args for builtin index operator"))); } // TODO(uncreated-issue/79): temporarily allow index operator with a target. TEST(FlatExprBuilderTest, IndexWithTarget) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("a[b]")); parsed_expr.mutable_expr() ->mutable_call_expr() ->mutable_target() ->mutable_ident_expr() ->set_name("a"); parsed_expr.mutable_expr() ->mutable_call_expr() ->mutable_args() ->DeleteSubrange(0, 1); cel::RuntimeOptions options; options.enable_fast_builtins = true; InterpreterOptions legacy_options; legacy_options.enable_fast_builtins = true; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), IsOk()); ASSERT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), IsOk()); } TEST(FlatExprBuilderTest, NotFiltersBadCalls) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("!a")); parsed_expr.mutable_expr() ->mutable_call_expr() ->mutable_target() ->mutable_const_expr() ->set_string_value("foo"); cel::RuntimeOptions options; options.enable_fast_builtins = true; InterpreterOptions legacy_options; legacy_options.enable_fast_builtins = true; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), IsOk()); ASSERT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs( 
absl::StatusCode::kInvalidArgument, HasSubstr("unexpected number of args for builtin not operator"))); } TEST(FlatExprBuilderTest, NotStrictlyFalseFiltersBadCalls) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("!a")); auto* call = parsed_expr.mutable_expr()->mutable_call_expr(); call->mutable_target()->mutable_const_expr()->set_string_value("foo"); call->set_function("@not_strictly_false"); cel::RuntimeOptions options; options.enable_fast_builtins = true; InterpreterOptions legacy_options; legacy_options.enable_fast_builtins = true; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), IsOk()); ASSERT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("unexpected number of args for builtin " "not_strictly_false operator"))); } TEST(FlatExprBuilderTest, FastEqualityDisabledWithCustomEquality) { TestMessage message; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("1 == b'\001'")); cel::RuntimeOptions options; options.enable_fast_builtins = true; InterpreterOptions legacy_options; legacy_options.enable_fast_builtins = true; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry(), legacy_options), IsOk()); auto& registry = builder.GetRegistry()->InternalGetRegistry(); auto status = cel::BinaryFunctionAdapter:: RegisterGlobalOverload( "_==_", [](int64_t lhs, const cel::BytesValue& rhs) -> bool { return true; }, registry); ASSERT_THAT(status, IsOk()); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue result, expression->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()) << result.DebugString(); EXPECT_TRUE(result.BoolOrDie()); } 
// A CEL list literal assigned to a google.protobuf.Any field is expected to be
// packed as a google.protobuf.ListValue.
TEST(FlatExprBuilderTest, AnyPackingList) {
  // Presumably ensures reflection for the message types used below is linked
  // into the test binary (the template argument is not visible here).
  google::protobuf::LinkMessageReflection();
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,
                       parser::Parse("TestAllTypes{single_any: [1, 2, 3]}"));

  cel::RuntimeOptions options;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
  // Resolve `TestAllTypes` relative to the conformance proto3 package.
  builder.set_container("cel.expr.conformance.proto3");

  ASSERT_OK_AND_ASSIGN(auto expression,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       expression->Evaluate(activation, &arena));

  EXPECT_THAT(result,
              test::IsCelMessage(EqualsProto(
                  R"pb(single_any { [type.googleapis.com/google.protobuf.ListValue] { values { number_value: 1 } values { number_value: 2 } values { number_value: 3 } } })pb")))
      << result.DebugString();
}

// Mixed int/double list elements are both expected to be packed as
// ListValue number_value entries.
TEST(FlatExprBuilderTest, AnyPackingNestedNumbers) {
  // Presumably ensures reflection for the message types used below is linked
  // into the test binary (the template argument is not visible here).
  google::protobuf::LinkMessageReflection();
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,
                       parser::Parse("TestAllTypes{single_any: [1, 2.3]}"));

  cel::RuntimeOptions options;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
  builder.set_container("cel.expr.conformance.proto3");

  ASSERT_OK_AND_ASSIGN(auto expression,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       expression->Evaluate(activation, &arena));

  EXPECT_THAT(result,
              test::IsCelMessage(EqualsProto(
                  R"pb(single_any { [type.googleapis.com/google.protobuf.ListValue] { values { number_value: 1 } values { number_value: 2.3 } } })pb")))
      << result.DebugString();
}

// An int literal assigned to an Any field is expected to be packed as
// google.protobuf.Int64Value.
TEST(FlatExprBuilderTest, AnyPackingInt) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,
                       parser::Parse("TestAllTypes{single_any: 1}"));

  cel::RuntimeOptions options;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
  builder.set_container("cel.expr.conformance.proto3");

  ASSERT_OK_AND_ASSIGN(auto expression,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       expression->Evaluate(activation, &arena));

  EXPECT_THAT(
      result,
      test::IsCelMessage(EqualsProto(
          R"pb(single_any { [type.googleapis.com/google.protobuf.Int64Value] { value: 1 } })pb")))
      << result.DebugString();
}

// A map literal assigned to an Any field is expected to be packed as
// google.protobuf.Struct.
TEST(FlatExprBuilderTest, AnyPackingMap) {
  ASSERT_OK_AND_ASSIGN(
      ParsedExpr parsed_expr,
      parser::Parse("TestAllTypes{single_any: {'key': 'value'}}"));

  cel::RuntimeOptions options;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);
  builder.set_container("cel.expr.conformance.proto3");

  ASSERT_OK_AND_ASSIGN(auto expression,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       expression->Evaluate(activation, &arena));

  EXPECT_THAT(result,
              test::IsCelMessage(EqualsProto(
                  R"pb(single_any { [type.googleapis.com/google.protobuf.Struct] { fields { key: "key" value { string_value: "value" } } } })pb")))
      << result.DebugString();
}

// With enable_empty_wrapper_null_unboxing disabled, reading the unset
// int32_wrapper_value field is expected to evaluate to the int 0.
TEST(FlatExprBuilderTest, NullUnboxingDisabled) {
  TestMessage message;
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,
                       parser::Parse("message.int32_wrapper_value"));

  cel::RuntimeOptions options;
  options.enable_empty_wrapper_null_unboxing = false;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);

  ASSERT_OK_AND_ASSIGN(auto expression,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  Activation activation;
  google::protobuf::Arena arena;
  activation.InsertValue("message",
                         CelProtoWrapper::CreateMessage(&message, &arena));

  ASSERT_OK_AND_ASSIGN(CelValue result,
                       expression->Evaluate(activation, &arena));

  EXPECT_THAT(result, test::IsCelInt64(0));
}

// With heterogeneous equality enabled, the double key 1.0 is expected to find
// the entry stored under the int key 1.
TEST(FlatExprBuilderTest, HeterogeneousEqualityEnabled) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,
                       parser::Parse("{1: 2, 2u: 3}[1.0]"));

  cel::RuntimeOptions options;
  options.enable_heterogeneous_equality = true;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);

  ASSERT_OK_AND_ASSIGN(auto expression,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       expression->Evaluate(activation, &arena));

  EXPECT_THAT(result, test::IsCelInt64(2));
}

// With heterogeneous equality disabled, indexing the map with a double key is
// expected to produce an "Invalid map key type" error value.
TEST(FlatExprBuilderTest, HeterogeneousEqualityDisabled) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,
                       parser::Parse("{1: 2, 2u: 3}[1.0]"));

  cel::RuntimeOptions options;
  options.enable_heterogeneous_equality = false;
  CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options);

  ASSERT_OK_AND_ASSIGN(auto expression,
                       builder.CreateExpression(&parsed_expr.expr(),
                                                &parsed_expr.source_info()));

  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       expression->Evaluate(activation, &arena));

  EXPECT_THAT(result,
              test::IsCelError(StatusIs(absl::StatusCode::kInvalidArgument,
                                        HasSubstr("Invalid map key type"))));
}

// Creates a new instance of the message type `name` from `descriptor_pool`
// via `message_factory`, and returns it together with its reflection.
//
// The message comes from Message::New() on the prototype; the caller owns it
// and is responsible for deleting it.
// NOTE(review): `desc` is not null-checked, so an unknown type name would pass
// nullptr to GetPrototype().
std::pair CreateTestMessage(
    const google::protobuf::DescriptorPool& descriptor_pool,
    google::protobuf::MessageFactory& message_factory, absl::string_view name) {
  const google::protobuf::Descriptor* desc =
      descriptor_pool.FindMessageTypeByName(name);
  const google::protobuf::Message* message_prototype =
      message_factory.GetPrototype(desc);
  google::protobuf::Message* message = message_prototype->New();
  const google::protobuf::Reflection* refl = message->GetReflection();
  return std::make_pair(message, refl);
}

// One CustomDescriptorPoolTest case: `setter` writes the field `field_name`
// on a fresh `message_type` message, and `matcher` checks the CEL value that
// the message converts to.
struct CustomDescriptorPoolTestParam final {
  using SetterFunction = std::function;
  std::string message_type;
  std::string field_name;
  SetterFunction setter;
  test::CelValueMatcher matcher;
};

// Parameterized fixture over CustomDescriptorPoolTestParam.
class CustomDescriptorPoolTest
    : public ::testing::TestWithParam {};

// This test in particular checks for conversion errors in cel_proto_wrapper.cc.
TEST_P(CustomDescriptorPoolTest, TestType) { const CustomDescriptorPoolTestParam& p = GetParam(); google::protobuf::DescriptorPool descriptor_pool; google::protobuf::Arena arena; // Setup descriptor pool and builder ASSERT_THAT(AddStandardMessageTypesToDescriptorPool(descriptor_pool), IsOk()); google::protobuf::DynamicMessageFactory message_factory(&descriptor_pool); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("m")); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); // Create test subject, invoke custom setter for message auto [message, reflection] = CreateTestMessage(descriptor_pool, message_factory, p.message_type); const google::protobuf::FieldDescriptor* field = message->GetDescriptor()->FindFieldByName(p.field_name); p.setter(message, reflection, field); ASSERT_OK_AND_ASSIGN(std::unique_ptr expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); // Evaluate expression, verify expectation with custom matcher Activation activation; activation.InsertValue("m", CelProtoWrapper::CreateMessage(message, &arena)); ASSERT_OK_AND_ASSIGN(CelValue result, expression->Evaluate(activation, &arena)); EXPECT_THAT(result, p.matcher); delete message; } INSTANTIATE_TEST_SUITE_P( ValueTypes, CustomDescriptorPoolTest, ::testing::ValuesIn(std::vector{ {"google.protobuf.Duration", "seconds", [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, const google::protobuf::FieldDescriptor* field) { reflection->SetInt64(message, field, 10); }, test::IsCelDuration(absl::Seconds(10))}, {"google.protobuf.DoubleValue", "value", [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, const google::protobuf::FieldDescriptor* field) { reflection->SetDouble(message, field, 1.2); }, test::IsCelDouble(1.2)}, {"google.protobuf.Int64Value", "value", [](google::protobuf::Message* message, const 
google::protobuf::Reflection* reflection, const google::protobuf::FieldDescriptor* field) { reflection->SetInt64(message, field, -23); }, test::IsCelInt64(-23)}, {"google.protobuf.UInt64Value", "value", [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, const google::protobuf::FieldDescriptor* field) { reflection->SetUInt64(message, field, 42); }, test::IsCelUint64(42)}, {"google.protobuf.BoolValue", "value", [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, const google::protobuf::FieldDescriptor* field) { reflection->SetBool(message, field, true); }, test::IsCelBool(true)}, {"google.protobuf.StringValue", "value", [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, const google::protobuf::FieldDescriptor* field) { reflection->SetString(message, field, "foo"); }, test::IsCelString("foo")}, {"google.protobuf.BytesValue", "value", [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, const google::protobuf::FieldDescriptor* field) { reflection->SetString(message, field, "bar"); }, test::IsCelBytes("bar")}, {"google.protobuf.Timestamp", "seconds", [](google::protobuf::Message* message, const google::protobuf::Reflection* reflection, const google::protobuf::FieldDescriptor* field) { reflection->SetInt64(message, field, 20); }, test::IsCelTimestamp(absl::FromUnixSeconds(20))}})); struct ConstantFoldingTestCase { std::string test_name; std::string expr; test::CelValueMatcher matcher; absl::flat_hash_map values; }; class UnknownFunctionImpl : public cel::Function { absl::StatusOr Invoke(absl::Span args, const InvokeContext& context) const override { return cel::UnknownValue(); } }; absl::StatusOr> CreateConstantFoldingConformanceTestExprBuilder( const InterpreterOptions& options) { auto builder = google::api::expr::runtime::CreateCelExpressionBuilder(options); CEL_RETURN_IF_ERROR( RegisterBuiltinFunctions(builder->GetRegistry(), 
options)); CEL_RETURN_IF_ERROR(builder->GetRegistry()->RegisterLazyFunction( cel::FunctionDescriptor("LazyFunction", false, {}))); CEL_RETURN_IF_ERROR(builder->GetRegistry()->RegisterLazyFunction( cel::FunctionDescriptor("LazyFunction", false, {cel::Kind::kBool}))); CEL_RETURN_IF_ERROR(builder->GetRegistry()->Register( cel::FunctionDescriptor("UnknownFunction", false, {}), std::make_unique())); return builder; } class ConstantFoldingConformanceTest : public ::testing::TestWithParam { protected: google::protobuf::Arena arena_; }; TEST_P(ConstantFoldingConformanceTest, Updated) { InterpreterOptions options; options.constant_folding = true; options.constant_arena = &arena_; // Check interaction between const folding and list append optimizations. options.enable_comprehension_list_append = true; const ConstantFoldingTestCase& p = GetParam(); ASSERT_OK_AND_ASSIGN( auto builder, CreateConstantFoldingConformanceTestExprBuilder(options)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(p.expr)); ASSERT_OK_AND_ASSIGN( auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); Activation activation; ASSERT_OK(activation.InsertFunction( PortableUnaryFunctionAdapter::Create( "LazyFunction", false, [](google::protobuf::Arena* arena, bool val) { return val; }))); for (auto iter = p.values.begin(); iter != p.values.end(); ++iter) { activation.InsertValue(iter->first, CelValue::CreateInt64(iter->second)); } ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena_)); // Check that none of the memoized constants are being mutated. 
ASSERT_OK_AND_ASSIGN(result, plan->Evaluate(activation, &arena_)); EXPECT_THAT(result, p.matcher); } INSTANTIATE_TEST_SUITE_P( Exprs, ConstantFoldingConformanceTest, ::testing::ValuesIn(std::vector{ {"simple_add", "1 + 2 + 3", test::IsCelInt64(6)}, {"add_with_var", "1 + (2 + (3 + id))", test::IsCelInt64(10), {{"id", 4}}}, {"const_list", "[1, 2, 3, 4]", test::IsCelList(_)}, {"mixed_const_list", "[1, 2, 3, 4] + [id]", test::IsCelList(_), {{"id", 5}}}, {"create_struct", "{'abc': 'def', 'def': 'efg', 'efg': 'hij'}", Truly([](const CelValue& v) { return v.IsMap(); })}, {"field_selection", "{'abc': 123}.abc == 123", test::IsCelBool(true)}, {"type_coverage", // coverage for constant literals, type() is used to make the list // homogenous. R"cel( [type(bool), type(123), type(123u), type(12.3), type(b'123'), type('123'), type(null), type(timestamp(0)), type(duration('1h')) ])cel", test::IsCelList(SizeIs(9))}, {"lazy_function", "true || LazyFunction()", test::IsCelBool(true)}, {"lazy_function_called", "LazyFunction(true) || false", test::IsCelBool(true)}, {"unknown_function", "UnknownFunction() && false", test::IsCelBool(false)}, {"nested_comprehension", "[1, 2, 3, 4].all(x, [5, 6, 7, 8].all(y, x < y))", test::IsCelBool(true)}, // Implementation detail: map and filter use replace the accu_init // expr with a special mutable list to avoid quadratic memory usage // building the projected list. 
{"map", "[1, 2, 3, 4].map(x, x * 2).size() == 4", test::IsCelBool(true)}, {"str_cat", "'1234567890' + '1234567890' + '1234567890' + '1234567890' + " "'1234567890'", test::IsCelString( "12345678901234567890123456789012345678901234567890")}})); // Check that list literals are pre-computed TEST(UpdatedConstantFolding, FoldsLists) { InterpreterOptions options; google::protobuf::Arena arena; options.constant_folding = true; options.constant_arena = &arena; ASSERT_OK_AND_ASSIGN( auto builder, CreateConstantFoldingConformanceTestExprBuilder(options)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1] + [2] + [3] + [4] + [5] + [6] + [7] " "+ [8] + [9] + [10] + [11] + [12]")); ASSERT_OK_AND_ASSIGN( auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); Activation activation; int before_size = arena.SpaceUsed(); ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); // Some incidental allocations are expected related to interop. // 128 is less than the expected allocations for allocating the list terms and // any intermediates in the unoptimized case. 
EXPECT_LE(arena.SpaceUsed() - before_size, 512); EXPECT_THAT(result, test::IsCelList(SizeIs(12))); } TEST(FlatExprBuilderTest, BlockBadIndex) { ParsedExpr parsed_expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr: { call_expr: { function: "cel.@block" args { list_expr: { elements { const_expr: { string_value: "foo" } } } } args { ident_expr: { name: "@index-1" } } } } )pb", &parsed_expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("bad @index"))); } TEST(FlatExprBuilderTest, OutOfRangeBlockIndex) { ParsedExpr parsed_expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr: { call_expr: { function: "cel.@block" args { list_expr: { elements { const_expr: { string_value: "foo" } } } } args { ident_expr: { name: "@index1" } } } } )pb", &parsed_expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("invalid @index greater than number of bindings:"))); } TEST(FlatExprBuilderTest, EarlyBlockIndex) { ParsedExpr parsed_expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr: { call_expr: { function: "cel.@block" args { list_expr: { elements { ident_expr: { name: "@index0" } } } } args { ident_expr: { name: "@index0" } } } } )pb", &parsed_expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("@index references current or future binding:"))); } TEST(FlatExprBuilderTest, OutOfScopeCSE) { ParsedExpr parsed_expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr: { ident_expr: { name: "@ac:0:0" } } )pb", &parsed_expr)); 
CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("out of scope reference to CSE generated " "comprehension variable"))); } TEST(FlatExprBuilderTest, BlockMissingBindings) { ParsedExpr parsed_expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr: { call_expr: { function: "cel.@block" } } )pb", &parsed_expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr( "malformed cel.@block: missing list of bound expressions"))); } TEST(FlatExprBuilderTest, BlockMissingExpression) { ParsedExpr parsed_expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr: { call_expr: { function: "cel.@block" args { list_expr: {} } } } )pb", &parsed_expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("malformed cel.@block: missing bound expression"))); } TEST(FlatExprBuilderTest, BlockNotListOfBoundExpressions) { ParsedExpr parsed_expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr: { call_expr: { function: "cel.@block" args { ident_expr: { name: "@index0" } } args { ident_expr: { name: "@index0" } } } } )pb", &parsed_expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("malformed cel.@block: first argument is not a list " "of bound expressions"))); } TEST(FlatExprBuilderTest, BlockEmptyListOfBoundExpressions) { ParsedExpr parsed_expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr: { 
call_expr: { function: "cel.@block" args { list_expr: {} } args { ident_expr: { name: "@index0" } } } } )pb", &parsed_expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs( absl::StatusCode::kInvalidArgument, HasSubstr( "malformed cel.@block: list of bound expressions is empty"))); } TEST(FlatExprBuilderTest, BlockOptionalListOfBoundExpressions) { ParsedExpr parsed_expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr: { call_expr: { function: "cel.@block" args { list_expr: { elements { const_expr: { string_value: "foo" } } optional_indices: [ 0 ] } } args { ident_expr: { name: "@index0" } } } } )pb", &parsed_expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("malformed cel.@block: list of bound expressions " "contains an optional"))); } TEST(FlatExprBuilderTest, BlockNested) { ParsedExpr parsed_expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr: { call_expr: { function: "cel.@block" args { list_expr: { elements { const_expr: { string_value: "foo" } } } } args { call_expr: { function: "cel.@block" args { list_expr: { elements { const_expr: { string_value: "foo" } } } } args { ident_expr: { name: "@index1" } } } } } } )pb", &parsed_expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("multiple cel.@block are not allowed"))); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/compiler/instrumentation.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 
2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/compiler/instrumentation.h" #include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "common/ast.h" #include "common/expr.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" namespace google::api::expr::runtime { namespace { class InstrumentStep : public ExpressionStepBase { public: explicit InstrumentStep(int64_t expr_id, Instrumentation instrumentation) : ExpressionStepBase(/*expr_id=*/expr_id, /*comes_from_ast=*/false), expr_id_(expr_id), instrumentation_(std::move(instrumentation)) {} absl::Status Evaluate(ExecutionFrame* frame) const override { if (!frame->value_stack().HasEnough(1)) { return absl::InternalError("stack underflow in instrument step."); } return instrumentation_(expr_id_, frame->value_stack().Peek()); return absl::OkStatus(); } private: int64_t expr_id_; Instrumentation instrumentation_; }; class InstrumentOptimizer : public ProgramOptimizer { public: explicit InstrumentOptimizer(Instrumentation instrumentation) : instrumentation_(std::move(instrumentation)) {} absl::Status OnPreVisit(PlannerContext& context, const cel::Expr& node) override { return absl::OkStatus(); } absl::Status OnPostVisit(PlannerContext& context, const cel::Expr& node) override { if (context.GetSubplan(node).empty()) { return absl::OkStatus(); } return context.AddSubplanStep( node, std::make_unique(node.id(), instrumentation_)); } 
private: Instrumentation instrumentation_; }; } // namespace ProgramOptimizerFactory CreateInstrumentationExtension( InstrumentationFactory factory) { return [fac = std::move(factory)](PlannerContext&, const cel::Ast& ast) -> absl::StatusOr> { Instrumentation ins = fac(ast); if (ins) { return std::make_unique(std::move(ins)); } return nullptr; }; } } // namespace google::api::expr::runtime ================================================ FILE: eval/compiler/instrumentation.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Definitions for instrumenting a CEL expression at the planner level. // // CEL users should not use this directly. #ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_INSTRUMENTATION_H_ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_INSTRUMENTATION_H_ #include #include #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "common/ast.h" #include "common/value.h" #include "eval/compiler/flat_expr_builder_extensions.h" namespace google::api::expr::runtime { // Instrumentation inspects intermediate values after the evaluation of an // expression node. // // Unlike traceable expressions, this callback is applied across all // evaluations of an expression. Implementations must be thread safe if the // expression is evaluated concurrently. using Instrumentation = std::function; // A factory for creating Instrumentation instances. 
// // This allows the extension implementations to map from a given ast to a // specific instrumentation instance. // // An empty function object may be returned to skip instrumenting the given // expression. using InstrumentationFactory = absl::AnyInvocable; // Create a new Instrumentation extension. // // These should typically be added last if any program optimizations are // applied. ProgramOptimizerFactory CreateInstrumentationExtension( InstrumentationFactory factory); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_INSTRUMENTATION_H_ ================================================ FILE: eval/compiler/instrumentation_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/compiler/instrumentation.h"

#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "cel/expr/syntax.pb.h"
#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "common/ast.h"
#include "common/value.h"
#include "eval/compiler/constant_folding.h"
#include "eval/compiler/flat_expr_builder.h"
#include "eval/compiler/regex_precompilation_optimization.h"
#include "eval/eval/evaluator_core.h"
#include "extensions/protobuf/ast_converters.h"
#include "internal/testing.h"
#include "parser/parser.h"
#include "runtime/activation.h"
#include "runtime/function_registry.h"
#include "runtime/internal/runtime_env.h"
#include "runtime/internal/runtime_env_testing.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_functions.h"
#include "runtime/type_registry.h"
#include "google/protobuf/arena.h"

namespace google::api::expr::runtime {
namespace {

using ::cel::IntValue;
using ::cel::Value;
using ::cel::runtime_internal::NewTestingRuntimeEnv;
using ::cel::runtime_internal::RuntimeEnv;
using ::cel::expr::ParsedExpr;
using ::google::api::expr::parser::Parse;
using ::testing::ElementsAre;
using ::testing::Pair;
using ::testing::UnorderedElementsAre;

// Fixture providing a testing runtime environment with the standard
// functions registered.
class InstrumentationTest : public ::testing::Test {
 public:
  InstrumentationTest()
      : env_(NewTestingRuntimeEnv()),
        function_registry_(env_->function_registry),
        type_registry_(env_->type_registry) {}

  void SetUp() override {
    ASSERT_OK(cel::RegisterStandardFunctions(function_registry_, options_));
  }

 protected:
  // Declared first so it is initialized before the registry references bound
  // to its members.
  absl_nonnull std::shared_ptr<RuntimeEnv> env_;
  cel::RuntimeOptions options_;
  cel::FunctionRegistry& function_registry_;
  cel::TypeRegistry& type_registry_;
  google::protobuf::Arena arena_;
};

// Matches a cel::Value holding an int equal to `expected`.
MATCHER_P(IsIntValue, expected, "") {
  const Value& got = arg;

  return got.Is<IntValue>() && got.GetInt().NativeValue() == expected;
}

// Every planned node is reported, children before their parent.
TEST_F(InstrumentationTest, Basic) {
  FlatExprBuilder builder(env_, options_);

  std::vector<int64_t> expr_ids;
  Instrumentation expr_id_recorder =
      [&expr_ids](int64_t expr_id, const cel::Value&) -> absl::Status {
    expr_ids.push_back(expr_id);
    return absl::OkStatus();
  };

  builder.AddProgramOptimizer(CreateInstrumentationExtension(
      [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; }));

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("1 + 2 + 3"));
  ASSERT_OK_AND_ASSIGN(auto ast,
                       cel::extensions::CreateAstFromParsedExpr(expr));
  ASSERT_OK_AND_ASSIGN(auto plan,
                       builder.CreateExpressionImpl(std::move(ast),
                                                    /*issues=*/nullptr));

  auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(),
                                       env_->MutableMessageFactory(), &arena_);
  cel::Activation activation;

  ASSERT_OK_AND_ASSIGN(
      auto value,
      plan.EvaluateWithCallback(activation, /*embedder_context=*/nullptr,
                                EvaluationListener(), state));

  // AST for the test expression, with expr ids in angle brackets:
  //   _+_<4>
  //   |-- _+_<2>
  //   |   |-- 1<1>
  //   |   `-- 2<3>
  //   `-- 3<5>
  EXPECT_THAT(expr_ids, ElementsAre(1, 3, 2, 5, 4));
}

// With constant folding enabled, the folded subexpressions are reported while
// the plan is built; evaluation then reports only the root.
TEST_F(InstrumentationTest, BasicWithConstFolding) {
  FlatExprBuilder builder(env_, options_);

  absl::flat_hash_map<int64_t, cel::Value> expr_id_to_value;
  Instrumentation expr_id_recorder = [&expr_id_to_value](
                                         int64_t expr_id,
                                         const cel::Value& v) -> absl::Status {
    expr_id_to_value[expr_id] = v;
    return absl::OkStatus();
  };

  builder.AddProgramOptimizer(
      cel::runtime_internal::CreateConstantFoldingOptimizer());
  builder.AddProgramOptimizer(CreateInstrumentationExtension(
      [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; }));

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("1 + 2 + 3"));
  ASSERT_OK_AND_ASSIGN(auto ast,
                       cel::extensions::CreateAstFromParsedExpr(expr));
  ASSERT_OK_AND_ASSIGN(auto plan,
                       builder.CreateExpressionImpl(std::move(ast),
                                                    /*issues=*/nullptr));

  // Values observed after planning, before any evaluation.
  EXPECT_THAT(
      expr_id_to_value,
      UnorderedElementsAre(Pair(1, IsIntValue(1)), Pair(3, IsIntValue(2)),
                           Pair(2, IsIntValue(3)), Pair(5, IsIntValue(3))));
  expr_id_to_value.clear();

  auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(),
                                       env_->MutableMessageFactory(), &arena_);
  cel::Activation activation;

  ASSERT_OK_AND_ASSIGN(
      auto value,
      plan.EvaluateWithCallback(activation, /*embedder_context=*/nullptr,
                                EvaluationListener(), state));

  // AST for the test expression, with expr ids in angle brackets:
  //   _+_<4>
  //   |-- _+_<2>
  //   |   |-- 1<1>
  //   |   `-- 2<3>
  //   `-- 3<5>
  EXPECT_THAT(expr_id_to_value, UnorderedElementsAre(Pair(4, IsIntValue(6))));
}

// A short-circuited right operand of && is not reported. expr_ids is not
// cleared between runs, so the second expectation includes the first run.
TEST_F(InstrumentationTest, AndShortCircuit) {
  FlatExprBuilder builder(env_, options_);

  std::vector<int64_t> expr_ids;
  Instrumentation expr_id_recorder =
      [&expr_ids](int64_t expr_id, const cel::Value&) -> absl::Status {
    expr_ids.push_back(expr_id);
    return absl::OkStatus();
  };

  builder.AddProgramOptimizer(CreateInstrumentationExtension(
      [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; }));

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("a && b"));
  ASSERT_OK_AND_ASSIGN(auto ast,
                       cel::extensions::CreateAstFromParsedExpr(expr));
  ASSERT_OK_AND_ASSIGN(auto plan,
                       builder.CreateExpressionImpl(std::move(ast),
                                                    /*issues=*/nullptr));

  auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(),
                                       env_->MutableMessageFactory(), &arena_);
  cel::Activation activation;
  activation.InsertOrAssignValue("a", cel::BoolValue(true));
  activation.InsertOrAssignValue("b", cel::BoolValue(false));

  ASSERT_OK_AND_ASSIGN(
      auto value,
      plan.EvaluateWithCallback(activation, /*embedder_context=*/nullptr,
                                EvaluationListener(), state));

  EXPECT_THAT(expr_ids, ElementsAre(1, 2, 3));

  activation.InsertOrAssignValue("a", cel::BoolValue(false));

  ASSERT_OK_AND_ASSIGN(
      value, plan.EvaluateWithCallback(activation, /*embedder_context=*/nullptr,
                                       EvaluationListener(), state));

  EXPECT_THAT(expr_ids, ElementsAre(1, 2, 3, 1, 3));
}

// A short-circuited right operand of || is not reported.
TEST_F(InstrumentationTest, OrShortCircuit) {
  FlatExprBuilder builder(env_, options_);

  std::vector<int64_t> expr_ids;
  Instrumentation expr_id_recorder =
      [&expr_ids](int64_t expr_id, const cel::Value&) -> absl::Status {
    expr_ids.push_back(expr_id);
    return absl::OkStatus();
  };

  builder.AddProgramOptimizer(CreateInstrumentationExtension(
      [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; }));

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("a || b"));
  ASSERT_OK_AND_ASSIGN(auto ast,
                       cel::extensions::CreateAstFromParsedExpr(expr));
  ASSERT_OK_AND_ASSIGN(auto plan,
                       builder.CreateExpressionImpl(std::move(ast),
                                                    /*issues=*/nullptr));

  auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(),
                                       env_->MutableMessageFactory(), &arena_);
  cel::Activation activation;
  activation.InsertOrAssignValue("a", cel::BoolValue(false));
  activation.InsertOrAssignValue("b", cel::BoolValue(true));

  ASSERT_OK_AND_ASSIGN(
      auto value,
      plan.EvaluateWithCallback(activation, /*embedder_context=*/nullptr,
                                EvaluationListener(), state));

  EXPECT_THAT(expr_ids, ElementsAre(1, 2, 3));
  expr_ids.clear();

  activation.InsertOrAssignValue("a", cel::BoolValue(true));

  ASSERT_OK_AND_ASSIGN(
      value, plan.EvaluateWithCallback(activation, /*embedder_context=*/nullptr,
                                       EvaluationListener(), state));

  EXPECT_THAT(expr_ids, ElementsAre(1, 3));
}

// Only the selected branch of a ternary is reported.
TEST_F(InstrumentationTest, Ternary) {
  FlatExprBuilder builder(env_, options_);

  std::vector<int64_t> expr_ids;
  Instrumentation expr_id_recorder =
      [&expr_ids](int64_t expr_id, const cel::Value&) -> absl::Status {
    expr_ids.push_back(expr_id);
    return absl::OkStatus();
  };

  builder.AddProgramOptimizer(CreateInstrumentationExtension(
      [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; }));

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("(c)? a : b"));
  ASSERT_OK_AND_ASSIGN(auto ast,
                       cel::extensions::CreateAstFromParsedExpr(expr));
  ASSERT_OK_AND_ASSIGN(auto plan,
                       builder.CreateExpressionImpl(std::move(ast),
                                                    /*issues=*/nullptr));

  auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(),
                                       env_->MutableMessageFactory(), &arena_);
  cel::Activation activation;
  activation.InsertOrAssignValue("c", cel::BoolValue(true));
  activation.InsertOrAssignValue("a", cel::IntValue(1));
  activation.InsertOrAssignValue("b", cel::IntValue(2));

  ASSERT_OK_AND_ASSIGN(
      auto value,
      plan.EvaluateWithCallback(activation, /*embedder_context=*/nullptr,
                                EvaluationListener(), state));

  // AST, with expr ids in angle brackets:
  //   _?_:_<2>
  //   |-- c<1>
  //   |-- a<3>
  //   `-- b<4>
  EXPECT_THAT(expr_ids, ElementsAre(1, 3, 2));
  expr_ids.clear();

  activation.InsertOrAssignValue("c", cel::BoolValue(false));

  ASSERT_OK_AND_ASSIGN(
      value, plan.EvaluateWithCallback(activation, /*embedder_context=*/nullptr,
                                       EvaluationListener(), state));

  EXPECT_THAT(expr_ids, ElementsAre(1, 4, 2));
  expr_ids.clear();
}

// Nodes whose steps were absorbed by an earlier optimizer (here, presumably
// the regex pattern argument) are not reported.
TEST_F(InstrumentationTest, OptimizedStepsNotEvaluated) {
  FlatExprBuilder builder(env_, options_);

  builder.AddProgramOptimizer(CreateRegexPrecompilationExtension(0));

  std::vector<int64_t> expr_ids;
  Instrumentation expr_id_recorder =
      [&expr_ids](int64_t expr_id, const cel::Value&) -> absl::Status {
    expr_ids.push_back(expr_id);
    return absl::OkStatus();
  };

  builder.AddProgramOptimizer(CreateInstrumentationExtension(
      [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; }));

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr,
                       Parse("r'test_string'.matches(r'[a-z_]+')"));
  ASSERT_OK_AND_ASSIGN(auto ast,
                       cel::extensions::CreateAstFromParsedExpr(expr));
  ASSERT_OK_AND_ASSIGN(auto plan,
                       builder.CreateExpressionImpl(std::move(ast),
                                                    /*issues=*/nullptr));

  auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(),
                                       env_->MutableMessageFactory(), &arena_);
  cel::Activation activation;

  ASSERT_OK_AND_ASSIGN(
      auto value,
      plan.EvaluateWithCallback(activation, /*embedder_context=*/nullptr,
                                EvaluationListener(), state));

  EXPECT_THAT(expr_ids, ElementsAre(1, 2));
  EXPECT_TRUE(value.Is<cel::BoolValue>() && value.GetBool().NativeValue());
}

// A factory returning an empty Instrumentation disables instrumentation
// without affecting the result.
TEST_F(InstrumentationTest, NoopSkipped) {
  FlatExprBuilder builder(env_, options_);

  builder.AddProgramOptimizer(CreateInstrumentationExtension(
      [=](const cel::Ast&) -> Instrumentation { return Instrumentation(); }));

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("(c)? a : b"));
  ASSERT_OK_AND_ASSIGN(auto ast,
                       cel::extensions::CreateAstFromParsedExpr(expr));
  ASSERT_OK_AND_ASSIGN(auto plan,
                       builder.CreateExpressionImpl(std::move(ast),
                                                    /*issues=*/nullptr));

  auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(),
                                       env_->MutableMessageFactory(), &arena_);
  cel::Activation activation;
  activation.InsertOrAssignValue("c", cel::BoolValue(true));
  activation.InsertOrAssignValue("a", cel::IntValue(1));
  activation.InsertOrAssignValue("b", cel::IntValue(2));

  ASSERT_OK_AND_ASSIGN(
      auto value,
      plan.EvaluateWithCallback(activation, /*embedder_context=*/nullptr,
                                EvaluationListener(), state));

  // AST, with expr ids in angle brackets:
  //   _?_:_<2>
  //   |-- c<1>
  //   |-- a<3>
  //   `-- b<4>
  EXPECT_THAT(value, IsIntValue(1));
}

}  // namespace
}  // namespace google::api::expr::runtime
================================================
FILE: eval/compiler/qualified_reference_resolver.cc
================================================
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "eval/compiler/qualified_reference_resolver.h" #include #include #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "base/ast.h" #include "base/builtins.h" #include "common/ast.h" #include "common/ast_rewrite.h" #include "common/expr.h" #include "common/kind.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" #include "runtime/internal/issue_collector.h" #include "runtime/runtime_issue.h" namespace google::api::expr::runtime { namespace { using ::cel::Expr; using ::cel::Reference; using ::cel::RuntimeIssue; using ::cel::runtime_internal::IssueCollector; // Optional types are opt-in but require special handling in the evaluator. constexpr absl::string_view kOptionalOr = "or"; constexpr absl::string_view kOptionalOrValue = "orValue"; // Determines if function is implemented with custom evaluation step instead of // registered. 
bool IsSpecialFunction(absl::string_view function_name) { return function_name == cel::builtin::kAnd || function_name == cel::builtin::kOr || function_name == cel::builtin::kIndex || function_name == cel::builtin::kTernary || function_name == kOptionalOr || function_name == kOptionalOrValue || function_name == cel::builtin::kEqual || function_name == cel::builtin::kInequal || function_name == cel::builtin::kNot || function_name == cel::builtin::kNotStrictlyFalse || function_name == cel::builtin::kNotStrictlyFalseDeprecated || function_name == cel::builtin::kIn || function_name == cel::builtin::kInDeprecated || function_name == cel::builtin::kInFunction || function_name == "cel.@block"; } bool OverloadExists(const Resolver& resolver, absl::string_view name, const std::vector& arguments_matcher, bool receiver_style = false) { return !resolver.FindOverloads(name, receiver_style, arguments_matcher) .empty() || !resolver.FindLazyOverloads(name, receiver_style, arguments_matcher) .empty(); } // Return the qualified name of the most qualified matching overload, or // nullopt if no matches are found. absl::optional BestOverloadMatch(const Resolver& resolver, absl::string_view base_name, int argument_count) { if (IsSpecialFunction(base_name)) { return std::string(base_name); } auto arguments_matcher = ArgumentsMatcher(argument_count); // Check from most qualified to least qualified for a matching overload. auto names = resolver.FullyQualifiedNames(base_name); for (auto name = names.begin(); name != names.end(); ++name) { if (OverloadExists(resolver, *name, arguments_matcher)) { if (base_name[0] == '.') { // Preserve leading '.' to prevent re-resolving at plan time. return std::string(base_name); } return *name; } } return absl::nullopt; } // Rewriter visitor for resolving references. // // On previsit pass, replace (possibly qualified) identifier branches with the // canonical name in the reference map (most qualified references considered // first). 
// // On post visit pass, update function calls to determine whether the function // target is a namespace for the function or a receiver for the call. class ReferenceResolver : public cel::AstRewriterBase { public: ReferenceResolver( const absl::flat_hash_map& reference_map, const Resolver& resolver, IssueCollector& issue_collector) : reference_map_(reference_map), resolver_(resolver), issues_(issue_collector), progress_status_(absl::OkStatus()) {} // Attempt to resolve references in expr. Return true if part of the // expression was rewritten. // TODO(issues/95): If possible, it would be nice to write a general utility // for running the preprocess steps when traversing the AST instead of having // one pass per transform. bool PreVisitRewrite(Expr& expr) override { const Reference* reference = GetReferenceForId(expr.id()); // Fold compile time constant (e.g. enum values) if (reference != nullptr && reference->has_value()) { if (reference->value().has_int64_value()) { // Replace enum idents with const reference value. expr.mutable_const_expr().set_int64_value( reference->value().int64_value()); return true; } else if (expr.has_ident_expr()) { // "google.protobuf.NullValue.NULL_VALUE" is a special case: sometimes // it is interpreted as null value and sometimes as an enum constant. if (reference->value().has_null_value() && expr.ident_expr().name() == "google.protobuf.NullValue.NULL_VALUE") { return false; } expr.set_const_expr(reference->value()); return true; } else { return false; } } if (reference != nullptr) { if (expr.has_ident_expr()) { return MaybeUpdateIdentNode(&expr, *reference); } else if (expr.has_select_expr()) { return MaybeUpdateSelectNode(&expr, *reference); } else { // Call nodes are updated on post visit so they will see any select // path rewrites. 
return false; } } return false; } bool PostVisitRewrite(Expr& expr) override { const Reference* reference = GetReferenceForId(expr.id()); if (expr.has_call_expr()) { return MaybeUpdateCallNode(&expr, reference); } return false; } const absl::Status& GetProgressStatus() const { return progress_status_; } private: // Attempt to update a function call node. This disambiguates // receiver call verses namespaced names in parse if possible. // // TODO(issues/95): This duplicates some of the overload matching behavior // for parsed expressions. We should refactor to consolidate the code. bool MaybeUpdateCallNode(Expr* out, const Reference* reference) { auto& call_expr = out->mutable_call_expr(); const std::string& function = call_expr.function(); if (reference != nullptr && reference->overload_id().empty()) { UpdateStatus(issues_.AddIssue( RuntimeIssue::CreateWarning(absl::InvalidArgumentError( absl::StrCat("Reference map doesn't provide overloads for ", out->call_expr().function()))))); } bool receiver_style = call_expr.has_target(); int arg_num = call_expr.args().size(); if (receiver_style) { auto maybe_namespace = ToNamespace(call_expr.target()); if (maybe_namespace.has_value()) { std::string resolved_name = absl::StrCat(*maybe_namespace, ".", function); auto resolved_function = BestOverloadMatch(resolver_, resolved_name, arg_num); if (resolved_function.has_value()) { call_expr.set_function(*resolved_function); call_expr.set_target(nullptr); return true; } } } else { // Not a receiver style function call. Check to see if it is a namespaced // function using a shorthand inside the expression container. 
auto maybe_resolved_function = BestOverloadMatch(resolver_, function, arg_num); if (!maybe_resolved_function.has_value()) { UpdateStatus(issues_.AddIssue(RuntimeIssue::CreateWarning( absl::InvalidArgumentError(absl::StrCat( "No overload found in reference resolve step for ", function)), RuntimeIssue::ErrorCode::kNoMatchingOverload))); } else if (maybe_resolved_function.value() != function) { call_expr.set_function(maybe_resolved_function.value()); return true; } } // For parity, if we didn't rewrite the receiver call style function, // check that an overload is provided in the builder. if (call_expr.has_target() && !IsSpecialFunction(function) && !OverloadExists(resolver_, function, ArgumentsMatcher(arg_num + 1), /* receiver_style= */ true)) { UpdateStatus(issues_.AddIssue(RuntimeIssue::CreateWarning( absl::InvalidArgumentError(absl::StrCat( "No overload found in reference resolve step for ", function)), RuntimeIssue::ErrorCode::kNoMatchingOverload))); } return false; } // Attempt to resolve a select node. If reference is valid, // replace the select node with the fully qualified ident node. bool MaybeUpdateSelectNode(Expr* out, const Reference& reference) { if (out->select_expr().test_only()) { UpdateStatus(issues_.AddIssue(RuntimeIssue::CreateWarning( absl::InvalidArgumentError("Reference map points to a presence " "test -- has(container.attr)")))); } else if (!reference.name().empty()) { out->mutable_ident_expr().set_name(reference.name()); rewritten_reference_.insert(out->id()); return true; } return false; } // Attempt to resolve an ident node. If reference is valid, // replace the node with the fully qualified ident node. bool MaybeUpdateIdentNode(Expr* out, const Reference& reference) { if (!reference.name().empty() && reference.name() != out->ident_expr().name()) { out->mutable_ident_expr().set_name(reference.name()); rewritten_reference_.insert(out->id()); return true; } return false; } // Convert a select expr sub tree into a namespace name if possible. 
// If any operand of the top element is a not a select or an ident node, // return nullopt. absl::optional ToNamespace(const Expr& expr) { absl::optional maybe_parent_namespace; if (rewritten_reference_.find(expr.id()) != rewritten_reference_.end()) { // The target expr matches a reference (resolved to an ident decl). // This should not be treated as a function qualifier. return absl::nullopt; } if (expr.has_ident_expr()) { return expr.ident_expr().name(); } else if (expr.has_select_expr()) { if (expr.select_expr().test_only()) { return absl::nullopt; } maybe_parent_namespace = ToNamespace(expr.select_expr().operand()); if (!maybe_parent_namespace.has_value()) { return absl::nullopt; } return absl::StrCat(*maybe_parent_namespace, ".", expr.select_expr().field()); } else { return absl::nullopt; } } // Find a reference for the given expr id. // // Returns nullptr if no reference is available. const Reference* GetReferenceForId(int64_t expr_id) { auto iter = reference_map_.find(expr_id); if (iter == reference_map_.end()) { return nullptr; } if (expr_id == 0) { UpdateStatus(issues_.AddIssue( RuntimeIssue::CreateWarning(absl::InvalidArgumentError( "reference map entries for expression id 0 are not supported")))); return nullptr; } return &iter->second; } void UpdateStatus(absl::Status status) { if (progress_status_.ok() && !status.ok()) { progress_status_ = std::move(status); return; } status.IgnoreError(); } const absl::flat_hash_map& reference_map_; const Resolver& resolver_; IssueCollector& issues_; absl::Status progress_status_; absl::flat_hash_set rewritten_reference_; }; class ReferenceResolverExtension : public AstTransform { public: explicit ReferenceResolverExtension(ReferenceResolverOption opt) : opt_(opt) {} absl::Status UpdateAst(PlannerContext& context, cel::Ast& ast) const override { if (opt_ == ReferenceResolverOption::kCheckedOnly && ast.reference_map().empty()) { return absl::OkStatus(); } return ResolveReferences(context.resolver(), 
context.issue_collector(), ast) .status(); } private: ReferenceResolverOption opt_; }; } // namespace absl::StatusOr ResolveReferences(const Resolver& resolver, IssueCollector& issues, cel::Ast& ast) { ReferenceResolver ref_resolver(ast.reference_map(), resolver, issues); // Rewriting interface doesn't support failing mid traverse propagate first // error encountered if fail fast enabled. bool was_rewritten = cel::AstRewrite(ast.mutable_root_expr(), ref_resolver); if (!ref_resolver.GetProgressStatus().ok()) { return ref_resolver.GetProgressStatus(); } return was_rewritten; } std::unique_ptr NewReferenceResolverExtension( ReferenceResolverOption option) { return std::make_unique(option); } } // namespace google::api::expr::runtime ================================================ FILE: eval/compiler/qualified_reference_resolver.h ================================================ // Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_QUALIFIED_REFERENCE_RESOLVER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_QUALIFIED_REFERENCE_RESOLVER_H_ #include #include "absl/status/statusor.h" #include "common/ast.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" #include "runtime/internal/issue_collector.h" namespace google::api::expr::runtime { // Resolves possibly qualified names in the provided expression, updating // subexpressions with to use the fully qualified name, or a constant // expressions in the case of enums. // // Returns true if updates were applied. // // Will warn or return a non-ok status if references can't be resolved (no // function overload could match a call) or are inconsistent (reference map // points to an expr node that isn't a reference). absl::StatusOr ResolveReferences( const Resolver& resolver, cel::runtime_internal::IssueCollector& issues, cel::Ast& ast); enum class ReferenceResolverOption { // Always attempt to resolve references based on runtime types and functions. kAlways, // Only attempt to resolve for checked expressions with reference metadata. kCheckedOnly, }; std::unique_ptr NewReferenceResolverExtension( ReferenceResolverOption option); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_QUALIFIED_REFERENCE_RESOLVER_H_ ================================================ FILE: eval/compiler/qualified_reference_resolver_test.cc ================================================ // Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "eval/compiler/qualified_reference_resolver.h" #include #include #include #include "cel/expr/syntax.pb.h" #include "absl/container/flat_hash_map.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "base/ast.h" #include "base/builtins.h" #include "common/ast.h" #include "common/ast/expr_proto.h" #include "common/expr.h" #include "eval/compiler/resolver.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_value.h" #include "extensions/protobuf/ast_converters.h" #include "internal/proto_matchers.h" #include "internal/testing.h" #include "runtime/internal/issue_collector.h" #include "runtime/runtime_issue.h" #include "runtime/type_registry.h" #include "google/protobuf/text_format.h" namespace google::api::expr::runtime { namespace { using ::absl_testing::IsOk; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::Ast; using ::cel::Expr; using ::cel::RuntimeIssue; using ::cel::SourceInfo; using ::cel::ast_internal::ExprToProto; using ::cel::internal::test::EqualsProto; using ::cel::runtime_internal::IssueCollector; using ::testing::Contains; using ::testing::ElementsAre; using ::testing::Eq; using ::testing::IsEmpty; using ::testing::UnorderedElementsAre; // foo.bar.var1 && bar.foo.var2 constexpr char kExpr[] = R"( id: 1 call_expr { function: "_&&_" args { id: 2 select_expr { field: "var1" operand { id: 3 select_expr { field: "bar" operand { id: 4 ident_expr { name: "foo" } } } } } } args { id: 5 select_expr { field: "var2" operand { id: 6 select_expr { field: "foo" operand { id: 7 ident_expr { name: "bar" } } } } } } } )"; MATCHER_P(StatusCodeIs, x, "") { const absl::Status& status = arg; return status.code() == x; } std::unique_ptr ParseTestProto(const std::string& 
pb) { cel::expr::Expr expr; EXPECT_TRUE(google::protobuf::TextFormat::ParseFromString(pb, &expr)); return cel::extensions::CreateAstFromParsedExpr(expr).value(); } std::vector ExtractIssuesStatus(const IssueCollector& issues) { std::vector issues_status; for (const auto& issue : issues.issues()) { issues_status.push_back(issue.ToStatus()); } return issues_status; } cel::expr::Expr ExprToProtoOrDie(const Expr& expr) { cel::expr::Expr expr_proto; ABSL_CHECK_OK(ExprToProto(expr, &expr_proto)); return expr_proto; } TEST(ResolveReferences, Basic) { std::unique_ptr expr_ast = ParseTestProto(kExpr); expr_ast->mutable_reference_map()[2].set_name("foo.bar.var1"); expr_ast->mutable_reference_map()[5].set_name("bar.foo.var2"); IssueCollector issues(RuntimeIssue::Severity::kError); CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( id: 1 call_expr { function: "_&&_" args { id: 2 ident_expr { name: "foo.bar.var1" } } args { id: 5 ident_expr { name: "bar.foo.var2" } } })pb")); } TEST(ResolveReferences, ReturnsFalseIfNoChanges) { std::unique_ptr expr_ast = ParseTestProto(kExpr); IssueCollector issues(RuntimeIssue::Severity::kError); CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); // reference to the same name also doesn't count as a rewrite. 
expr_ast->mutable_reference_map()[4].set_name("foo"); expr_ast->mutable_reference_map()[7].set_name("bar"); result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); } TEST(ResolveReferences, NamespacedIdent) { std::unique_ptr expr_ast = ParseTestProto(kExpr); SourceInfo source_info; IssueCollector issues(RuntimeIssue::Severity::kError); CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); expr_ast->mutable_reference_map()[2].set_name("foo.bar.var1"); expr_ast->mutable_reference_map()[7].set_name("namespace_x.bar"); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( id: 1 call_expr { function: "_&&_" args { id: 2 ident_expr { name: "foo.bar.var1" } } args { id: 5 select_expr { field: "var2" operand { id: 6 select_expr { field: "foo" operand { id: 7 ident_expr { name: "namespace_x.bar" } } } } } } })pb")); } TEST(ResolveReferences, WarningOnPresenceTest) { std::unique_ptr expr_ast = ParseTestProto(R"pb( id: 1 select_expr { field: "var1" test_only: true operand { id: 2 select_expr { field: "bar" operand { id: 3 ident_expr { name: "foo" } } } } })pb"); SourceInfo source_info; IssueCollector issues(RuntimeIssue::Severity::kError); CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); expr_ast->mutable_reference_map()[1].set_name("foo.bar.var1"); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT( ExtractIssuesStatus(issues), testing::ElementsAre(Eq(absl::Status( absl::StatusCode::kInvalidArgument, "Reference map points to a presence test -- has(container.attr)")))); } // foo.bar.var1 == 
bar.foo.Enum.ENUM_VAL1 constexpr char kEnumExpr[] = R"( id: 1 call_expr { function: "_==_" args { id: 2 select_expr { field: "var1" operand { id: 3 select_expr { field: "bar" operand { id: 4 ident_expr { name: "foo" } } } } } } args { id: 5 ident_expr { name: "bar.foo.Enum.ENUM_VAL1" } } } )"; TEST(ResolveReferences, EnumConstReferenceUsed) { std::unique_ptr expr_ast = ParseTestProto(kEnumExpr); SourceInfo source_info; CelFunctionRegistry func_registry; ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); expr_ast->mutable_reference_map()[2].set_name("foo.bar.var1"); expr_ast->mutable_reference_map()[5].set_name("bar.foo.Enum.ENUM_VAL1"); expr_ast->mutable_reference_map()[5].mutable_value().set_int64_value(9); IssueCollector issues(RuntimeIssue::Severity::kError); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( id: 1 call_expr { function: "_==_" args { id: 2 ident_expr { name: "foo.bar.var1" } } args { id: 5 const_expr { int64_value: 9 } } })pb")); } TEST(ResolveReferences, EnumConstReferenceUsedSelect) { std::unique_ptr expr_ast = ParseTestProto(kEnumExpr); SourceInfo source_info; CelFunctionRegistry func_registry; ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); expr_ast->mutable_reference_map()[2].set_name("foo.bar.var1"); expr_ast->mutable_reference_map()[2].mutable_value().set_int64_value(2); expr_ast->mutable_reference_map()[5].set_name("bar.foo.Enum.ENUM_VAL1"); expr_ast->mutable_reference_map()[5].mutable_value().set_int64_value(9); IssueCollector issues(RuntimeIssue::Severity::kError); auto result = ResolveReferences(registry, issues, 
*expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( id: 1 call_expr { function: "_==_" args { id: 2 const_expr { int64_value: 2 } } args { id: 5 const_expr { int64_value: 9 } } })pb")); } // foo && bar constexpr char kConstReferenceExpr[] = R"( id: 1 call_expr { function: "_&&_" args { id: 2 ident_expr { name: "foo" } } args { id: 5 ident_expr { name: "bar" } } } )"; TEST(ResolveReferences, ConstReferenceFolded) { std::unique_ptr expr_ast = ParseTestProto(kConstReferenceExpr); SourceInfo source_info; CelFunctionRegistry func_registry; ASSERT_THAT(RegisterBuiltinFunctions(&func_registry), IsOk()); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); expr_ast->mutable_reference_map()[2].set_name("foo"); expr_ast->mutable_reference_map()[2].mutable_value().set_bool_value(true); expr_ast->mutable_reference_map()[5].set_name("bar"); expr_ast->mutable_reference_map()[5].mutable_value().set_bool_value(false); IssueCollector issues(RuntimeIssue::Severity::kError); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( id: 1 call_expr { function: "_&&_" args { id: 2 const_expr { bool_value: true } } args { id: 5 const_expr { bool_value: false } } })pb")); } TEST(ResolveReferences, ConstReferenceSkipped) { std::unique_ptr expr_ast = ParseTestProto(kExpr); SourceInfo source_info; CelFunctionRegistry func_registry; ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); expr_ast->mutable_reference_map()[2].set_name("foo.bar.var1"); expr_ast->mutable_reference_map()[2].mutable_value().set_bool_value(true); 
expr_ast->mutable_reference_map()[5].set_name("bar.foo.var2"); IssueCollector issues(RuntimeIssue::Severity::kError); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( id: 1 call_expr { function: "_&&_" args { id: 2 select_expr { field: "var1" operand { id: 3 select_expr { field: "bar" operand { id: 4 ident_expr { name: "foo" } } } } } } args { id: 5 ident_expr { name: "bar.foo.var2" } } })pb")); } constexpr char kNullValueReferenceExpr[] = R"( id: 1 call_expr { function: "_+_" args { id: 2 ident_expr { name: "google.protobuf.NullValue.NULL_VALUE" } } args { id: 5 const_expr { int64_value: 1 } } } )"; TEST(ResolveReferences, NullValueReferenceSkipped) { std::unique_ptr expr_ast = ParseTestProto(kNullValueReferenceExpr); SourceInfo source_info; CelFunctionRegistry func_registry; ASSERT_THAT(RegisterBuiltinFunctions(&func_registry), IsOk()); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); expr_ast->mutable_reference_map()[2].set_name( "google.protobuf.NullValue.NULL_VALUE"); expr_ast->mutable_reference_map()[2].mutable_value().set_null_value(nullptr); IssueCollector issues(RuntimeIssue::Severity::kError); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(/*was_rewritten=*/false)); } constexpr char kExtensionAndExpr[] = R"( id: 1 call_expr { function: "boolean_and" args { id: 2 const_expr { bool_value: true } } args { id: 3 const_expr { bool_value: false } } })"; TEST(ResolveReferences, FunctionReferenceBasic) { std::unique_ptr expr_ast = ParseTestProto(kExtensionAndExpr); SourceInfo source_info; CelFunctionRegistry func_registry; ASSERT_OK(func_registry.RegisterLazyFunction( CelFunctionDescriptor("boolean_and", false, { CelValue::Type::kBool, CelValue::Type::kBool, }))); cel::TypeRegistry 
type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); IssueCollector issues(RuntimeIssue::Severity::kError); expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); } TEST(ResolveReferences, FunctionReferenceMissingOverloadDetected) { std::unique_ptr expr_ast = ParseTestProto(kExtensionAndExpr); SourceInfo source_info; CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); IssueCollector issues(RuntimeIssue::Severity::kError); expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT(ExtractIssuesStatus(issues), ElementsAre(StatusCodeIs(absl::StatusCode::kInvalidArgument))); } TEST(ResolveReferences, SpecialBuiltinsNotWarned) { std::unique_ptr expr_ast = ParseTestProto(R"pb( id: 1 call_expr { function: "*" args { id: 2 const_expr { bool_value: true } } args { id: 3 const_expr { bool_value: false } } })pb"); SourceInfo source_info; std::vector special_builtins{ cel::builtin::kAnd, cel::builtin::kOr, cel::builtin::kTernary, cel::builtin::kIndex}; for (const char* builtin_fn : special_builtins) { // Builtins aren't in the function registry. 
CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); IssueCollector issues(RuntimeIssue::Severity::kError); expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( absl::StrCat("builtin.", builtin_fn)); expr_ast->mutable_root_expr().mutable_call_expr().set_function(builtin_fn); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT(ExtractIssuesStatus(issues), IsEmpty()); } } TEST(ResolveReferences, FunctionReferenceMissingOverloadDetectedAndMissingReference) { std::unique_ptr expr_ast = ParseTestProto(kExtensionAndExpr); SourceInfo source_info; CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); IssueCollector issues(RuntimeIssue::Severity::kError); expr_ast->mutable_reference_map()[1].set_name("udf_boolean_and"); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT( ExtractIssuesStatus(issues), UnorderedElementsAre( Eq(absl::InvalidArgumentError( "No overload found in reference resolve step for boolean_and")), Eq(absl::InvalidArgumentError( "Reference map doesn't provide overloads for boolean_and")))); } TEST(ResolveReferences, EmulatesEagerFailing) { std::unique_ptr expr_ast = ParseTestProto(kExtensionAndExpr); SourceInfo source_info; CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); IssueCollector issues(RuntimeIssue::Severity::kWarning); expr_ast->mutable_reference_map()[1].set_name("udf_boolean_and"); EXPECT_THAT( ResolveReferences(registry, issues, *expr_ast), StatusIs(absl::StatusCode::kInvalidArgument, "Reference map doesn't 
provide overloads for boolean_and")); } TEST(ResolveReferences, FunctionReferenceToWrongExprKind) { std::unique_ptr expr_ast = ParseTestProto(kExtensionAndExpr); SourceInfo source_info; IssueCollector issues(RuntimeIssue::Severity::kError); CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); expr_ast->mutable_reference_map()[2].mutable_overload_id().push_back( "udf_boolean_and"); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT(ExtractIssuesStatus(issues), ElementsAre(StatusCodeIs(absl::StatusCode::kInvalidArgument))); } constexpr char kReceiverCallExtensionAndExpr[] = R"( id: 1 call_expr { function: "boolean_and" target { id: 2 ident_expr { name: "ext" } } args { id: 3 const_expr { bool_value: false } } })"; TEST(ResolveReferences, FunctionReferenceWithTargetNoChange) { std::unique_ptr expr_ast = ParseTestProto(kReceiverCallExtensionAndExpr); SourceInfo source_info; IssueCollector issues(RuntimeIssue::Severity::kError); CelFunctionRegistry func_registry; ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "boolean_and", true, {CelValue::Type::kBool, CelValue::Type::kBool}))); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT(ExtractIssuesStatus(issues), IsEmpty()); } TEST(ResolveReferences, FunctionReferenceWithTargetNoChangeMissingOverloadDetected) { std::unique_ptr expr_ast = ParseTestProto(kReceiverCallExtensionAndExpr); SourceInfo source_info; IssueCollector issues(RuntimeIssue::Severity::kError); CelFunctionRegistry func_registry; cel::TypeRegistry 
type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT(ExtractIssuesStatus(issues), ElementsAre(StatusCodeIs(absl::StatusCode::kInvalidArgument))); } TEST(ResolveReferences, FunctionReferenceWithTargetToNamespacedFunction) { std::unique_ptr expr_ast = ParseTestProto(kReceiverCallExtensionAndExpr); SourceInfo source_info; IssueCollector issues(RuntimeIssue::Severity::kError); CelFunctionRegistry func_registry; ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "ext.boolean_and", false, {CelValue::Type::kBool}))); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( id: 1 call_expr { function: "ext.boolean_and" args { id: 3 const_expr { bool_value: false } } } )pb")); EXPECT_THAT(ExtractIssuesStatus(issues), IsEmpty()); } TEST(ResolveReferences, FunctionReferenceWithTargetToNamespacedFunctionInContainer) { std::unique_ptr expr_ast = ParseTestProto(kReceiverCallExtensionAndExpr); SourceInfo source_info; expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); IssueCollector issues(RuntimeIssue::Severity::kError); CelFunctionRegistry func_registry; ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "com.google.ext.boolean_and", false, {CelValue::Type::kBool}))); cel::TypeRegistry type_registry; std::vector namespace_prefixes{"com.google.", "google.", ""}; Resolver 
registry("com.google", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( id: 1 call_expr { function: "com.google.ext.boolean_and" args { id: 3 const_expr { bool_value: false } } } )pb")); EXPECT_THAT(ExtractIssuesStatus(issues), IsEmpty()); } // has(ext.option).boolean_and(false) constexpr char kReceiverCallHasExtensionAndExpr[] = R"( id: 1 call_expr { function: "boolean_and" target { id: 2 select_expr { test_only: true field: "option" operand { id: 3 ident_expr { name: "ext" } } } } args { id: 4 const_expr { bool_value: false } } })"; TEST(ResolveReferences, FunctionReferenceWithHasTargetNoChange) { std::unique_ptr expr_ast = ParseTestProto(kReceiverCallHasExtensionAndExpr); SourceInfo source_info; IssueCollector issues(RuntimeIssue::Severity::kError); CelFunctionRegistry func_registry; ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "boolean_and", true, {CelValue::Type::kBool, CelValue::Type::kBool}))); ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "ext.option.boolean_and", true, {CelValue::Type::kBool}))); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); // The target is unchanged because it is a test_only select. 
EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(kReceiverCallHasExtensionAndExpr)); EXPECT_THAT(ExtractIssuesStatus(issues), IsEmpty()); } constexpr char kComprehensionExpr[] = R"( id:17 comprehension_expr: { iter_var:"i" iter_range:{ id:1 list_expr:{ elements:{ id:2 const_expr:{int64_value:1} } elements:{ id:3 ident_expr:{name:"ENUM"} } elements:{ id:4 const_expr:{int64_value:3} } } } accu_var:"__result__" accu_init: { id:10 const_expr:{bool_value:false} } loop_condition:{ id:13 call_expr:{ function:"@not_strictly_false" args:{ id:12 call_expr:{ function:"!_" args:{ id:11 ident_expr:{name:"__result__"} } } } } } loop_step:{ id:15 call_expr: { function:"_||_" args:{ id:14 ident_expr: {name:"__result__"} } args:{ id:8 call_expr:{ function:"_==_" args:{ id:7 ident_expr:{name:"ENUM"} } args:{ id:9 ident_expr:{name:"i"} } } } } } result:{id:16 ident_expr:{name:"__result__"}} } )"; TEST(ResolveReferences, EnumConstReferenceUsedInComprehension) { std::unique_ptr expr_ast = ParseTestProto(kComprehensionExpr); SourceInfo source_info; CelFunctionRegistry func_registry; ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); expr_ast->mutable_reference_map()[3].set_name("ENUM"); expr_ast->mutable_reference_map()[3].mutable_value().set_int64_value(2); expr_ast->mutable_reference_map()[7].set_name("ENUM"); expr_ast->mutable_reference_map()[7].mutable_value().set_int64_value(2); IssueCollector issues(RuntimeIssue::Severity::kError); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( id: 17 comprehension_expr { iter_var: "i" iter_range { id: 1 list_expr { elements { id: 2 const_expr { int64_value: 1 } } elements { id: 3 const_expr { int64_value: 2 } } elements { id: 4 const_expr { int64_value: 3 } 
} } } accu_var: "__result__" accu_init { id: 10 const_expr { bool_value: false } } loop_condition { id: 13 call_expr { function: "@not_strictly_false" args { id: 12 call_expr { function: "!_" args { id: 11 ident_expr { name: "__result__" } } } } } } loop_step { id: 15 call_expr { function: "_||_" args { id: 14 ident_expr { name: "__result__" } } args { id: 8 call_expr { function: "_==_" args { id: 7 const_expr { int64_value: 2 } } args { id: 9 ident_expr { name: "i" } } } } } } result { id: 16 ident_expr { name: "__result__" } } })pb")); } TEST(ResolveReferences, ReferenceToId0Warns) { // ID 0 is unsupported since it is not normally used by parsers and is // ambiguous as an intentional ID or default for unset field. std::unique_ptr expr_ast = ParseTestProto(R"pb( id: 0 select_expr { operand { id: 1 ident_expr { name: "pkg" } } field: "var" })pb"); SourceInfo source_info; CelFunctionRegistry func_registry; ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); expr_ast->mutable_reference_map()[0].set_name("pkg.var"); IssueCollector issues(RuntimeIssue::Severity::kError); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( id: 0 select_expr { operand { id: 1 ident_expr { name: "pkg" } } field: "var" })pb")); EXPECT_THAT( ExtractIssuesStatus(issues), Contains(StatusIs( absl::StatusCode::kInvalidArgument, "reference map entries for expression id 0 are not supported"))); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/compiler/regex_precompilation_optimization.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this 
file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "eval/compiler/regex_precompilation_optimization.h"

#include
#include
#include
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "base/builtins.h"
#include "common/ast.h"
#include "common/casting.h"
#include "common/expr.h"
#include "common/native_type.h"
#include "common/value.h"
#include "eval/compiler/flat_expr_builder_extensions.h"
#include "eval/eval/compiler_constant_step.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/regex_match_step.h"
#include "internal/casts.h"
#include "internal/re2_options.h"
#include "internal/status_macros.h"
#include "re2/re2.h"

namespace google::api::expr::runtime {
namespace {

using ::cel::Ast;
using ::cel::CallExpr;
using ::cel::Cast;
using ::cel::Expr;
using ::cel::InstanceOf;
using ::cel::NativeTypeId;
using ::cel::Reference;
using ::cel::StringValue;
using ::cel::Value;
using ::cel::internal::down_cast;

// Checker reference information, looked up by expression id (see
// IsFunctionOverload, which reads `overload_id()` from the entries).
using ReferenceMap = absl::flat_hash_map;

// Returns true if `expr` is a call to `function` whose operand count (its
// arguments plus the receiver, if any) equals `arity`, and which should be
// treated as the overload `overload`:
//  - with an empty `reference_map` (no checker information), any call that
//    matches on name and arity is accepted;
//  - otherwise the reference recorded for `expr.id()` must list exactly one
//    overload id, and it must equal `overload`.
bool IsFunctionOverload(const Expr& expr, absl::string_view function,
                        absl::string_view overload, size_t arity,
                        const ReferenceMap& reference_map) {
  if (!expr.has_call_expr()) {
    return false;
  }
  const auto& call_expr = expr.call_expr();
  if (call_expr.function() != function) {
    return false;
  }
  // A receiver-style call (`target.f(x)`) counts its target as an operand.
  if (call_expr.args().size() + (call_expr.has_target() ? 1 : 0) != arity) {
    return false;
  }
  // If parse-only and opted in to the optimization, assume this is the intended
  // overload. This will still only change the evaluation plan if the second arg
  // is a constant string.
  if (reference_map.empty()) {
    return true;
  }
  auto reference = reference_map.find(expr.id());
  if (reference != reference_map.end() &&
      reference->second.overload_id().size() == 1 &&
      reference->second.overload_id().front() == overload) {
    return true;
  }
  return false;
}

// Abstraction for deduplicating regular expressions over the course of a single
// create expression call. Should not be used during evaluation. Uses
// std::shared_ptr and std::weak_ptr.
//
// Compiled programs are cached by pattern string but held only weakly: a cached
// program is reused only while some caller still owns it; otherwise the stale
// entry is erased and the pattern is recompiled. Every freshly compiled program
// is validated with cel::internal::CheckRE2 against `max_program_size`; if that
// check fails its error is returned and nothing is cached.
class RegexProgramBuilder final {
 public:
  explicit RegexProgramBuilder(int max_program_size)
      : max_program_size_(max_program_size) {}

  absl::StatusOr> BuildRegexProgram(
      std::string pattern) {
    auto existing = programs_.find(pattern);
    if (existing != programs_.end()) {
      if (auto program = existing->second.lock(); program) {
        return program;
      }
      // The previously compiled program has been released; drop the stale
      // entry before recompiling.
      programs_.erase(existing);
    }

    auto program = std::make_shared(pattern, cel::internal::MakeRE2Options());
    CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(*program, max_program_size_));
    programs_.insert({std::move(pattern), program});
    return program;
  }

 private:
  const int max_program_size_;
  absl::flat_hash_map> programs_;
};

// Program optimizer that rewrites calls to the standard regex match function
// (cel::builtin::kRegexMatch, "matches_string" overload) whose pattern is a
// plan-time constant string. The planned call is replaced by a regex match
// step that embeds a precompiled RE2 program.
class RegexPrecompilationOptimization : public ProgramOptimizer {
 public:
  // `reference_map` is stored by reference, so it must outlive this optimizer.
  explicit RegexPrecompilationOptimization(const ReferenceMap& reference_map,
                                           int regex_max_program_size)
      : reference_map_(reference_map),
        regex_program_builder_(regex_max_program_size) {}

  // Nothing to do before the children of `node` are planned.
  absl::Status OnPreVisit(PlannerContext& context, const Expr& node) override {
    return absl::OkStatus();
  }

  // Runs after `node` and its children have been planned. Returns an error only
  // when a constant pattern fails to build (see RegexProgramBuilder). In every
  // other case the plan is either rewritten or left untouched.
  absl::Status OnPostVisit(PlannerContext& context, const Expr& node) override {
    // Check that this is the correct matches overload instead of a user defined
    // overload.
    if (!IsFunctionOverload(node, cel::builtin::kRegexMatch, "matches_string",
                            2, reference_map_)) {
      return absl::OkStatus();
    }

    ProgramBuilder::Subexpression* subexpression =
        context.program_builder().GetSubexpression(&node);

    const CallExpr& call_expr = node.call_expr();
    // The pattern is the last argument in both the receiver form
    // (`s.matches(re)`) and the two-argument global form.
    const Expr& pattern_expr = call_expr.args().back();

    // Try to check if the regex is valid, whether or not we can actually update
    // the plan.
    absl::optional pattern =
        GetConstantString(context, subexpression, node, pattern_expr);
    if (!pattern.has_value()) {
      return absl::OkStatus();
    }

    CEL_ASSIGN_OR_RETURN(
        std::shared_ptr regex_program,
        regex_program_builder_.BuildRegexProgram(std::move(pattern).value()));

    if (subexpression == nullptr || subexpression->IsFlattened()) {
      // Already modified, can't update further.
      return absl::OkStatus();
    }

    // The string being matched: the receiver, or else the first argument.
    const Expr& subject_expr =
        call_expr.has_target() ? call_expr.target() : call_expr.args().front();

    return RewritePlan(context, subexpression, node, subject_expr,
                       std::move(regex_program));
  }

 private:
  // Returns the pattern if it is known at plan time. That is the case when
  // `re_expr` is a string literal, or when the plan already built for it is a
  // single constant step holding a string value (e.g. after constant folding).
  // Returns nullopt otherwise, including when the subexpression has already
  // been flattened.
  absl::optional GetConstantString(
      PlannerContext& context,
      ProgramBuilder::Subexpression* absl_nullable subexpression,
      const Expr& call_expr, const Expr& re_expr) const {
    if (re_expr.has_const_expr() && re_expr.const_expr().has_string_value()) {
      return re_expr.const_expr().string_value();
    }

    if (subexpression == nullptr || subexpression->IsFlattened()) {
      // Already modified, can't recover the input pattern.
      return absl::nullopt;
    }
    absl::optional constant;
    if (subexpression->IsRecursive()) {
      // Recursive plan: the call step's second dependency is the pattern.
      const auto& program = subexpression->recursive_program();
      auto deps = program.step->GetDependencies();
      if (deps.has_value() && deps->size() == 2) {
        const auto* re_plan =
            TryDowncastDirectStep(deps->at(1));
        if (re_plan != nullptr) {
          constant = re_plan->value();
        }
      }
    } else {
      // otherwise stack-machine program.
      ExecutionPathView re_plan = context.GetSubplan(re_expr);
      if (re_plan.size() == 1 &&
          re_plan[0]->GetNativeTypeId() == NativeTypeId::For()) {
        constant = down_cast(re_plan[0].get())->value();
      }
    }

    if (constant.has_value() && InstanceOf(*constant)) {
      return Cast(*constant).ToString();
    }

    return absl::nullopt;
  }

  // Dispatches to the rewrite that matches how `subexpression` was planned.
  absl::Status RewritePlan(
      PlannerContext& context,
      ProgramBuilder::Subexpression* absl_nonnull subexpression,
      const Expr& call, const Expr& subject,
      std::shared_ptr regex_program) {
    if (subexpression->IsRecursive()) {
      return RewriteRecursivePlan(subexpression, call, subject,
                                  std::move(regex_program));
    }

    return RewriteStackMachinePlan(context, call, subject,
                                   std::move(regex_program));
  }

  // Replaces the recursive call step with a direct regex match step over the
  // subject (dependency 0), keeping the original depth. If the step does not
  // expose exactly two dependencies, the original step is restored unchanged.
  absl::Status RewriteRecursivePlan(
      ProgramBuilder::Subexpression* absl_nonnull subexpression,
      const Expr& call, const Expr& subject,
      std::shared_ptr regex_program) {
    auto program = subexpression->ExtractRecursiveProgram();
    auto deps = program.step->ExtractDependencies();
    if (!deps.has_value() || deps->size() != 2) {
      // Possibly already const-folded, put the plan back.
      subexpression->set_recursive_program(std::move(program.step),
                                           program.depth);
      return absl::OkStatus();
    }

    subexpression->set_recursive_program(
        CreateDirectRegexMatchStep(call.id(), std::move(deps->at(0)),
                                   std::move(regex_program)),
        program.depth);
    return absl::OkStatus();
  }

  // Replaces the call's stack-machine subplan with the subject's subplan
  // followed by a single regex match step. The pattern's steps are not part of
  // the new plan.
  absl::Status RewriteStackMachinePlan(
      PlannerContext& context, const Expr& call, const Expr& subject,
      std::shared_ptr regex_program) {
    if (context.GetSubplan(subject).empty()) {
      // This subexpression was already optimized, nothing to do.
      return absl::OkStatus();
    }

    CEL_ASSIGN_OR_RETURN(ExecutionPath new_plan,
                         context.ExtractSubplan(subject));
    CEL_ASSIGN_OR_RETURN(
        new_plan.emplace_back(),
        CreateRegexMatchStep(std::move(regex_program), call.id()));

    return context.ReplaceSubplan(call, std::move(new_plan));
  }

  const ReferenceMap& reference_map_;
  RegexProgramBuilder regex_program_builder_;
};

}  // namespace

// Each call of the returned factory creates a new optimizer, and therefore a
// new regex cache, for one AST. That optimizer keeps a reference to
// `ast.reference_map()`.
ProgramOptimizerFactory CreateRegexPrecompilationExtension(
    int regex_max_program_size) {
  return [=](PlannerContext& context, const Ast& ast) {
    return std::make_unique(
        ast.reference_map(), regex_max_program_size);
  };
}

}  // namespace google::api::expr::runtime ================================================ FILE: eval/compiler/regex_precompilation_optimization.h ================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_REGEX_PRECOMPILATION_OPTIMIZATION_H_
#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_REGEX_PRECOMPILATION_OPTIMIZATION_H_

#include "eval/compiler/flat_expr_builder_extensions.h"

namespace google::api::expr::runtime {

// Create a new extension for the FlatExprBuilder that precompiles constant
// regular expressions used in the standard `matches` function.
ProgramOptimizerFactory CreateRegexPrecompilationExtension( int regex_max_program_size); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_REGEX_PRECOMPILATION_OPTIMIZATION_H_ ================================================ FILE: eval/compiler/regex_precompilation_optimization_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/compiler/regex_precompilation_optimization.h" #include #include #include #include #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "common/ast.h" #include "eval/compiler/cel_expression_builder_flat_impl.h" #include "eval/compiler/constant_folding.h" #include "eval/compiler/flat_expr_builder.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" #include "eval/eval/evaluator_core.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "eval/public/cel_type_registry.h" #include "eval/public/cel_value.h" #include "internal/testing.h" #include "parser/parser.h" #include "runtime/internal/issue_collector.h" #include "runtime/internal/runtime_env.h" #include "runtime/internal/runtime_env_testing.h" #include 
"runtime/runtime_issue.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { using ::cel::RuntimeIssue; using ::cel::runtime_internal::IssueCollector; using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::runtime_internal::RuntimeEnv; using ::google::api::expr::parser::Parse; using ::testing::ElementsAre; namespace exprpb = cel::expr; class RegexPrecompilationExtensionTest : public testing::TestWithParam { public: RegexPrecompilationExtensionTest() : env_(NewTestingRuntimeEnv()), builder_(env_), type_registry_(*builder_.GetTypeRegistry()), function_registry_(*builder_.GetRegistry()), resolver_("", function_registry_.InternalGetRegistry(), type_registry_.InternalGetModernRegistry(), type_registry_.GetTypeProvider()), issue_collector_(RuntimeIssue::Severity::kError) { if (EnableRecursivePlanning()) { options_.max_recursion_depth = -1; options_.enable_recursive_tracing = true; } options_.enable_regex = true; options_.regex_max_program_size = 100; options_.enable_regex_precompilation = true; runtime_options_ = ConvertToRuntimeOptions(options_); } void SetUp() override { ASSERT_OK(RegisterBuiltinFunctions(&function_registry_, options_)); } bool EnableRecursivePlanning() { return GetParam(); } protected: CelEvaluationListener RecordStringValues() { return [this](int64_t, const CelValue& value, google::protobuf::Arena*) { if (value.IsString()) { string_values_.push_back(std::string(value.StringOrDie().value())); } return absl::OkStatus(); }; } absl_nonnull std::shared_ptr env_; CelExpressionBuilderFlatImpl builder_; CelTypeRegistry& type_registry_; CelFunctionRegistry& function_registry_; InterpreterOptions options_; cel::RuntimeOptions runtime_options_; Resolver resolver_; IssueCollector issue_collector_; std::vector string_values_; }; TEST_P(RegexPrecompilationExtensionTest, SmokeTest) { ProgramOptimizerFactory factory = 
CreateRegexPrecompilationExtension(options_.regex_max_program_size); ExecutionPath path; ProgramBuilder program_builder; cel::Ast ast_impl; ast_impl.set_is_checked(true); std::shared_ptr arena; PlannerContext context(env_, resolver_, runtime_options_, type_registry_.GetTypeProvider(), issue_collector_, program_builder, arena); ASSERT_OK_AND_ASSIGN(std::unique_ptr optimizer, factory(context, ast_impl)); } TEST_P(RegexPrecompilationExtensionTest, OptimizeableExpression) { builder_.flat_expr_builder().AddProgramOptimizer( CreateRegexPrecompilationExtension(options_.regex_max_program_size)); ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr parsed_expr, Parse("input.matches(r'[a-zA-Z]+[0-9]*')")); // Fake reference information for the matches call. exprpb::CheckedExpr expr; expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); (*expr.mutable_reference_map())[2].add_overload_id("matches_string"); ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, builder_.CreateExpression(&expr)); Activation activation; google::protobuf::Arena arena; activation.InsertValue("input", CelValue::CreateStringView("input123")); ASSERT_OK(plan->Trace(activation, &arena, RecordStringValues())); EXPECT_THAT(string_values_, ElementsAre("input123")); } TEST_P(RegexPrecompilationExtensionTest, OptimizeParsedExpr) { builder_.flat_expr_builder().AddProgramOptimizer( CreateRegexPrecompilationExtension(options_.regex_max_program_size)); ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr expr, Parse("input.matches(r'[a-zA-Z]+[0-9]*')")); ASSERT_OK_AND_ASSIGN( std::unique_ptr plan, builder_.CreateExpression(&expr.expr(), &expr.source_info())); Activation activation; google::protobuf::Arena arena; activation.InsertValue("input", CelValue::CreateStringView("input123")); ASSERT_OK(plan->Trace(activation, &arena, RecordStringValues())); EXPECT_THAT(string_values_, ElementsAre("input123")); } TEST_P(RegexPrecompilationExtensionTest, DoesNotOptimizeNonConstRegex) { 
builder_.flat_expr_builder().AddProgramOptimizer( CreateRegexPrecompilationExtension(options_.regex_max_program_size)); ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr parsed_expr, Parse("input.matches(input_re)")); // Fake reference information for the matches call. exprpb::CheckedExpr expr; expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); (*expr.mutable_reference_map())[2].add_overload_id("matches_string"); ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, builder_.CreateExpression(&expr)); Activation activation; google::protobuf::Arena arena; activation.InsertValue("input", CelValue::CreateStringView("input123")); activation.InsertValue("input_re", CelValue::CreateStringView("input_re")); ASSERT_OK(plan->Trace(activation, &arena, RecordStringValues())); EXPECT_THAT(string_values_, ElementsAre("input123", "input_re")); } TEST_P(RegexPrecompilationExtensionTest, DoesNotOptimizeCompoundExpr) { builder_.flat_expr_builder().AddProgramOptimizer( CreateRegexPrecompilationExtension(options_.regex_max_program_size)); ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr parsed_expr, Parse("input.matches('abc' + 'def')")); // Fake reference information for the matches call. 
exprpb::CheckedExpr expr; expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); (*expr.mutable_reference_map())[2].add_overload_id("matches_string"); ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, builder_.CreateExpression(&expr)); Activation activation; google::protobuf::Arena arena; activation.InsertValue("input", CelValue::CreateStringView("input123")); ASSERT_OK(plan->Trace(activation, &arena, RecordStringValues())); EXPECT_THAT(string_values_, ElementsAre("input123", "abc", "def", "abcdef")); } class RegexConstFoldInteropTest : public RegexPrecompilationExtensionTest { public: RegexConstFoldInteropTest() : RegexPrecompilationExtensionTest() { builder_.flat_expr_builder().AddProgramOptimizer( cel::runtime_internal::CreateConstantFoldingOptimizer()); } protected: google::protobuf::Arena arena_; }; TEST_P(RegexConstFoldInteropTest, StringConstantOptimizeable) { builder_.flat_expr_builder().AddProgramOptimizer( CreateRegexPrecompilationExtension(options_.regex_max_program_size)); ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr parsed_expr, Parse("input.matches('abc' + 'def')")); // Fake reference information for the matches call. 
exprpb::CheckedExpr expr; expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); (*expr.mutable_reference_map())[2].add_overload_id("matches_string"); ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, builder_.CreateExpression(&expr)); Activation activation; google::protobuf::Arena arena; activation.InsertValue("input", CelValue::CreateStringView("input123")); ASSERT_OK(plan->Trace(activation, &arena, RecordStringValues())); EXPECT_THAT(string_values_, ElementsAre("input123")); } TEST_P(RegexConstFoldInteropTest, WrongTypeNotOptimized) { builder_.flat_expr_builder().AddProgramOptimizer( CreateRegexPrecompilationExtension(options_.regex_max_program_size)); ASSERT_OK_AND_ASSIGN(exprpb::ParsedExpr parsed_expr, Parse("input.matches(123 + 456)")); // Fake reference information for the matches call. exprpb::CheckedExpr expr; expr.mutable_expr()->Swap(parsed_expr.mutable_expr()); expr.mutable_source_info()->Swap(parsed_expr.mutable_source_info()); (*expr.mutable_reference_map())[2].add_overload_id("matches_string"); ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, builder_.CreateExpression(&expr)); Activation activation; google::protobuf::Arena arena; activation.InsertValue("input", CelValue::CreateStringView("input123")); ASSERT_OK_AND_ASSIGN(CelValue result, plan->Trace(activation, &arena, RecordStringValues())); EXPECT_THAT(string_values_, ElementsAre("input123")); EXPECT_TRUE(result.IsError()); EXPECT_TRUE(CheckNoMatchingOverloadError(result)); } INSTANTIATE_TEST_SUITE_P(RegexPrecompilationExtensionTest, RegexPrecompilationExtensionTest, testing::Bool()); INSTANTIATE_TEST_SUITE_P(RegexConstFoldInteropTest, RegexConstFoldInteropTest, testing::Bool()); } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/compiler/resolver.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 
2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/compiler/resolver.h" #include #include #include #include #include #include #include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "common/kind.h" #include "common/type.h" #include "common/type_reflector.h" #include "common/value.h" #include "internal/status_macros.h" #include "runtime/function_overload_reference.h" #include "runtime/function_registry.h" #include "runtime/type_registry.h" namespace google::api::expr::runtime { namespace { using ::cel::TypeValue; using ::cel::Value; using ::cel::runtime_internal::GetEnumValueTable; std::vector MakeNamespaceCandidates(absl::string_view container) { std::vector namespace_prefixes; std::string prefix = ""; namespace_prefixes.push_back(prefix); auto container_elements = absl::StrSplit(container, '.'); for (const auto& elem : container_elements) { // Tolerate trailing / leading '.'. if (elem.empty()) { continue; } absl::StrAppend(&prefix, elem, "."); // longest prefix first. 
namespace_prefixes.insert(namespace_prefixes.begin(), prefix); } return namespace_prefixes; } } // namespace Resolver::Resolver(absl::string_view container, const cel::FunctionRegistry& function_registry, const cel::TypeRegistry& type_registry, const cel::TypeReflector& type_reflector, bool resolve_qualified_type_identifiers) : namespace_prefixes_(MakeNamespaceCandidates(container)), enum_value_map_(GetEnumValueTable(type_registry)), function_registry_(function_registry), type_reflector_(type_reflector), resolve_qualified_type_identifiers_(resolve_qualified_type_identifiers) {} std::vector Resolver::FullyQualifiedNames(absl::string_view name, int64_t expr_id) const { // TODO(issues/105): refactor the reference resolution into this method. // and handle the case where this id is in the reference map as either a // function name or identifier name. std::vector names; auto prefixes = GetPrefixesFor(name); names.reserve(prefixes.size()); for (const auto& prefix : prefixes) { std::string fully_qualified_name = absl::StrCat(prefix, name); names.push_back(fully_qualified_name); } return names; } absl::Span Resolver::GetPrefixesFor( absl::string_view& name) const { static const absl::NoDestructor kEmptyPrefix(""); if (absl::StartsWith(name, ".")) { name = name.substr(1); return absl::MakeConstSpan(kEmptyPrefix.get(), 1); } return namespace_prefixes_; } absl::optional Resolver::FindConstant(absl::string_view name, int64_t expr_id) const { auto prefixes = GetPrefixesFor(name); for (const auto& prefix : prefixes) { std::string qualified_name = absl::StrCat(prefix, name); // Attempt to resolve the fully qualified name to a known enum. auto enum_entry = enum_value_map_->find(qualified_name); if (enum_entry != enum_value_map_->end()) { return enum_entry->second; } // Attempt to resolve the fully qualified name to a known type. 
if (resolve_qualified_type_identifiers_) { auto type_value = type_reflector_.FindType(qualified_name); if (type_value.ok() && type_value->has_value()) { return TypeValue(**type_value); } } } if (!resolve_qualified_type_identifiers_ && !absl::StrContains(name, '.')) { auto type_value = type_reflector_.FindType(name); if (type_value.ok() && type_value->has_value()) { return TypeValue(**type_value); } } return absl::nullopt; } std::vector Resolver::FindOverloads( absl::string_view name, bool receiver_style, const std::vector& types, int64_t expr_id) const { // Resolve the fully qualified names and then search the function registry // for possible matches. std::vector funcs; auto names = FullyQualifiedNames(name, expr_id); for (auto it = names.begin(); it != names.end(); it++) { // Only one set of overloads is returned along the namespace hierarchy as // the function name resolution follows the same behavior as variable name // resolution, meaning the most specific definition wins. This is different // from how C++ namespaces work, as they will accumulate the overload set // over the namespace hierarchy. funcs = function_registry_.FindStaticOverloads(*it, receiver_style, types); if (!funcs.empty()) { return funcs; } } return funcs; } std::vector Resolver::FindOverloads( absl::string_view name, bool receiver_style, size_t arity, int64_t expr_id) const { std::vector funcs; auto prefixes = GetPrefixesFor(name); for (const auto& prefix : prefixes) { std::string qualified_name = absl::StrCat(prefix, name); // Only one set of overloads is returned along the namespace hierarchy as // the function name resolution follows the same behavior as variable name // resolution, meaning the most specific definition wins. This is different // from how C++ namespaces work, as they will accumulate the overload set // over the namespace hierarchy. 
funcs = function_registry_.FindStaticOverloadsByArity( qualified_name, receiver_style, arity); if (!funcs.empty()) { return funcs; } } return funcs; } std::vector Resolver::FindLazyOverloads( absl::string_view name, bool receiver_style, const std::vector& types, int64_t expr_id) const { // Resolve the fully qualified names and then search the function registry // for possible matches. std::vector funcs; auto names = FullyQualifiedNames(name, expr_id); for (const auto& name : names) { funcs = function_registry_.FindLazyOverloads(name, receiver_style, types); if (!funcs.empty()) { return funcs; } } return funcs; } std::vector Resolver::FindLazyOverloads( absl::string_view name, bool receiver_style, size_t arity, int64_t expr_id) const { std::vector funcs; auto prefixes = GetPrefixesFor(name); for (const auto& prefix : prefixes) { std::string qualified_name = absl::StrCat(prefix, name); funcs = function_registry_.FindLazyOverloadsByArity(name, receiver_style, arity); if (!funcs.empty()) { return funcs; } } return funcs; } absl::StatusOr>> Resolver::FindType(absl::string_view name, int64_t expr_id) const { auto prefixes = GetPrefixesFor(name); for (auto& prefix : prefixes) { std::string qualified_name = absl::StrCat(prefix, name); CEL_ASSIGN_OR_RETURN(auto maybe_type, type_reflector_.FindType(qualified_name)); if (maybe_type.has_value()) { return std::make_pair(std::move(qualified_name), std::move(*maybe_type)); } } return absl::nullopt; } } // namespace google::api::expr::runtime ================================================ FILE: eval/compiler/resolver.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_ #include #include #include #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "common/kind.h" #include "common/type.h" #include "common/type_reflector.h" #include "common/value.h" #include "runtime/function_overload_reference.h" #include "runtime/function_registry.h" #include "runtime/type_registry.h" namespace google::api::expr::runtime { // Resolver assists with finding functions and types from the associated // registries within a container. // // container is used to construct the namespace lookup candidates. // e.g. for "cel.dev" -> {"cel.dev.", "cel.", ""} class Resolver { public: Resolver(absl::string_view container, const cel::FunctionRegistry& function_registry, const cel::TypeRegistry& type_registry, const cel::TypeReflector& type_reflector, bool resolve_qualified_type_identifiers = true); Resolver(const Resolver&) = delete; Resolver& operator=(const Resolver&) = delete; Resolver(Resolver&&) = delete; Resolver& operator=(Resolver&&) = delete; ~Resolver() = default; // FindConstant will return an enum constant value or a type value if one // exists for the given name. An empty handle will be returned if none exists. 
// // Since enums and type identifiers are specified as (potentially) qualified // names within an expression, there is the chance that the name provided // is a variable name which happens to collide with an existing enum or proto // based type name. For this reason, within parsed only expressions, the // constant should be treated as a value that can be shadowed by a runtime // provided value. absl::optional FindConstant(absl::string_view name, int64_t expr_id) const; absl::StatusOr>> FindType( absl::string_view name, int64_t expr_id) const; // FindLazyOverloads returns the set, possibly empty, of lazy overloads // matching the given function signature. std::vector FindLazyOverloads( absl::string_view name, bool receiver_style, const std::vector& types, int64_t expr_id = -1) const; std::vector FindLazyOverloads( absl::string_view name, bool receiver_style, size_t arity, int64_t expr_id = -1) const; // FindOverloads returns the set, possibly empty, of eager function overloads // matching the given function signature. std::vector FindOverloads( absl::string_view name, bool receiver_style, const std::vector& types, int64_t expr_id = -1) const; std::vector FindOverloads( absl::string_view name, bool receiver_style, size_t arity, int64_t expr_id = -1) const; // FullyQualifiedNames returns the set of fully qualified names which may be // derived from the base_name within the specified expression container. std::vector FullyQualifiedNames(absl::string_view base_name, int64_t expr_id = -1) const; private: absl::Span GetPrefixesFor(absl::string_view& name) const; std::vector namespace_prefixes_; std::shared_ptr> enum_value_map_; const cel::FunctionRegistry& function_registry_; const cel::TypeReflector& type_reflector_; bool resolve_qualified_type_identifiers_; }; // ArgumentMatcher generates a function signature matcher for CelFunctions. 
// TODO(issues/91): this is the same behavior as parsed exprs in the CPP // evaluator (just check the right call style and number of arguments), but we // should have enough type information in a checked expr to find a more // specific candidate list. inline std::vector ArgumentsMatcher(int argument_count) { std::vector argument_matcher(argument_count); for (int i = 0; i < argument_count; i++) { argument_matcher[i] = cel::Kind::kAny; } return argument_matcher; } } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_ ================================================ FILE: eval/compiler/resolver_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/compiler/resolver.h"

#include <memory>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/types/span.h"
#include "common/value.h"
#include "eval/public/cel_function.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_type_registry.h"
#include "eval/public/cel_value.h"
#include "eval/testutil/test_message.pb.h"
#include "internal/testing.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {

namespace {

using ::cel::IntValue;
using ::cel::TypeValue;
using ::testing::Eq;

// Eager function that only exists to populate the registry; its result is
// never inspected by these tests.
class FakeFunction : public CelFunction {
 public:
  explicit FakeFunction(const std::string& name)
      : CelFunction(CelFunctionDescriptor{name, false, {}}) {}

  absl::Status Evaluate(absl::Span<const CelValue> args, CelValue* result,
                        google::protobuf::Arena* arena) const override {
    return absl::OkStatus();
  }
};

class ResolverTest : public testing::Test {
 public:
  ResolverTest() = default;

 protected:
  CelTypeRegistry type_registry_;
};

TEST_F(ResolverTest, TestFullyQualifiedNames) {
  CelFunctionRegistry func_registry;
  Resolver resolver("google.api.expr", func_registry.InternalGetRegistry(),
                    type_registry_.InternalGetModernRegistry(),
                    type_registry_.GetTypeProvider());

  auto names = resolver.FullyQualifiedNames("simple_name");
  std::vector<std::string> expected_names(
      {"google.api.expr.simple_name", "google.api.simple_name",
       "google.simple_name", "simple_name"});
  EXPECT_THAT(names, Eq(expected_names));
}

TEST_F(ResolverTest, TestFullyQualifiedNamesPartiallyQualifiedName) {
  CelFunctionRegistry func_registry;
  Resolver resolver("google.api.expr", func_registry.InternalGetRegistry(),
                    type_registry_.InternalGetModernRegistry(),
                    type_registry_.GetTypeProvider());

  auto names = resolver.FullyQualifiedNames("expr.simple_name");
  std::vector<std::string> expected_names(
      {"google.api.expr.expr.simple_name", "google.api.expr.simple_name",
       "google.expr.simple_name", "expr.simple_name"});
  EXPECT_THAT(names, Eq(expected_names));
}

TEST_F(ResolverTest, TestFullyQualifiedNamesAbsoluteName) {
  CelFunctionRegistry func_registry;
  Resolver resolver("google.api.expr", func_registry.InternalGetRegistry(),
                    type_registry_.InternalGetModernRegistry(),
                    type_registry_.GetTypeProvider());

  auto names = resolver.FullyQualifiedNames(".google.api.expr.absolute_name");
  // Fatal: names[0] below must not be evaluated on an empty result.
  ASSERT_THAT(names.size(), Eq(1));
  EXPECT_THAT(names[0], Eq("google.api.expr.absolute_name"));
}

TEST_F(ResolverTest, TestFindConstantEnum) {
  CelFunctionRegistry func_registry;
  type_registry_.Register(TestMessage::TestEnum_descriptor());

  Resolver resolver("google.api.expr.runtime.TestMessage",
                    func_registry.InternalGetRegistry(),
                    type_registry_.InternalGetModernRegistry(),
                    type_registry_.GetTypeProvider());

  auto enum_value = resolver.FindConstant("TestEnum.TEST_ENUM_1", -1);
  ASSERT_TRUE(enum_value);
  ASSERT_TRUE(enum_value->Is<IntValue>());
  EXPECT_THAT(enum_value->GetInt().NativeValue(), Eq(1L));

  enum_value = resolver.FindConstant(
      ".google.api.expr.runtime.TestMessage.TestEnum.TEST_ENUM_2", -1);
  ASSERT_TRUE(enum_value);
  ASSERT_TRUE(enum_value->Is<IntValue>());
  EXPECT_THAT(enum_value->GetInt().NativeValue(), Eq(2L));
}

TEST_F(ResolverTest, TestFindConstantUnqualifiedType) {
  CelFunctionRegistry func_registry;
  Resolver resolver("cel", func_registry.InternalGetRegistry(),
                    type_registry_.InternalGetModernRegistry(),
                    type_registry_.GetTypeProvider());

  auto type_value = resolver.FindConstant("int", -1);
  // Fatal: the optional is dereferenced below.
  ASSERT_TRUE(type_value);
  ASSERT_TRUE(type_value->Is<TypeValue>());
  EXPECT_THAT(type_value->GetType().name(), Eq("int"));
}

TEST_F(ResolverTest, TestFindConstantFullyQualifiedType) {
  google::protobuf::LinkMessageReflection<TestMessage>();
  CelFunctionRegistry func_registry;
  Resolver resolver("cel", func_registry.InternalGetRegistry(),
                    type_registry_.InternalGetModernRegistry(),
                    type_registry_.GetTypeProvider());

  auto type_value =
      resolver.FindConstant(".google.api.expr.runtime.TestMessage", -1);
  ASSERT_TRUE(type_value);
  ASSERT_TRUE(type_value->Is<TypeValue>());
  EXPECT_THAT(type_value->GetType().name(),
              Eq("google.api.expr.runtime.TestMessage"));
}

TEST_F(ResolverTest, TestFindConstantQualifiedTypeDisabled) {
  CelFunctionRegistry func_registry;
  Resolver resolver("", func_registry.InternalGetRegistry(),
                    type_registry_.InternalGetModernRegistry(),
                    type_registry_.GetTypeProvider(), false);
  auto type_value =
      resolver.FindConstant(".google.api.expr.runtime.TestMessage", -1);
  EXPECT_FALSE(type_value);
}

TEST_F(ResolverTest, FindTypeBySimpleName) {
  CelFunctionRegistry func_registry;
  Resolver resolver("google.api.expr.runtime",
                    func_registry.InternalGetRegistry(),
                    type_registry_.InternalGetModernRegistry(),
                    type_registry_.GetTypeProvider());

  ASSERT_OK_AND_ASSIGN(auto type, resolver.FindType("TestMessage", -1));

  // Fatal: the optional is dereferenced below.
  ASSERT_TRUE(type.has_value());
  EXPECT_EQ(type->second.name(), "google.api.expr.runtime.TestMessage");
}

TEST_F(ResolverTest, FindTypeByQualifiedName) {
  CelFunctionRegistry func_registry;
  Resolver resolver("google.api.expr.runtime",
                    func_registry.InternalGetRegistry(),
                    type_registry_.InternalGetModernRegistry(),
                    type_registry_.GetTypeProvider());

  ASSERT_OK_AND_ASSIGN(
      auto type, resolver.FindType(".google.api.expr.runtime.TestMessage", -1));

  ASSERT_TRUE(type.has_value());
  EXPECT_EQ(type->second.name(), "google.api.expr.runtime.TestMessage");
}

TEST_F(ResolverTest, TestFindDescriptorNotFound) {
  CelFunctionRegistry func_registry;
  Resolver resolver("google.api.expr.runtime",
                    func_registry.InternalGetRegistry(),
                    type_registry_.InternalGetModernRegistry(),
                    type_registry_.GetTypeProvider());

  ASSERT_OK_AND_ASSIGN(auto type, resolver.FindType("UndefinedMessage", -1));

  // The streamed message is only evaluated on failure, i.e. when a value is
  // present, so dereferencing here is safe.
  EXPECT_FALSE(type.has_value()) << type->second;
}

TEST_F(ResolverTest, TestFindOverloads) {
  CelFunctionRegistry func_registry;
  auto status =
      func_registry.Register(std::make_unique<FakeFunction>("fake_func"));
  ASSERT_OK(status);
  status = func_registry.Register(
      std::make_unique<FakeFunction>("cel.fake_ns_func"));
  ASSERT_OK(status);

  Resolver resolver("cel", func_registry.InternalGetRegistry(),
                    type_registry_.InternalGetModernRegistry(),
                    type_registry_.GetTypeProvider());

  auto overloads =
      resolver.FindOverloads("fake_func", false, ArgumentsMatcher(0));
  // Fatal: overloads[0] below must not be evaluated on an empty result.
  ASSERT_THAT(overloads.size(), Eq(1));
  EXPECT_THAT(overloads[0].descriptor.name(), Eq("fake_func"));

  overloads =
      resolver.FindOverloads("fake_ns_func", false, ArgumentsMatcher(0));
  ASSERT_THAT(overloads.size(), Eq(1));
  EXPECT_THAT(overloads[0].descriptor.name(), Eq("cel.fake_ns_func"));
}

TEST_F(ResolverTest, TestFindLazyOverloads) {
  CelFunctionRegistry func_registry;
  auto status = func_registry.RegisterLazyFunction(
      CelFunctionDescriptor{"fake_lazy_func", false, {}});
  ASSERT_OK(status);
  status = func_registry.RegisterLazyFunction(
      CelFunctionDescriptor{"cel.fake_lazy_ns_func", false, {}});
  ASSERT_OK(status);

  Resolver resolver("cel", func_registry.InternalGetRegistry(),
                    type_registry_.InternalGetModernRegistry(),
                    type_registry_.GetTypeProvider());

  auto overloads =
      resolver.FindLazyOverloads("fake_lazy_func", false, ArgumentsMatcher(0));
  EXPECT_THAT(overloads.size(), Eq(1));

  overloads = resolver.FindLazyOverloads("fake_lazy_ns_func", false,
                                         ArgumentsMatcher(0));
  EXPECT_THAT(overloads.size(), Eq(1));
}

}  // namespace

}  // namespace google::api::expr::runtime

================================================
FILE: eval/eval/BUILD
================================================
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") # This package contains implementation of expression evaluator # internals. package(default_visibility = ["//visibility:public"]) licenses(["notice"]) exports_files(["LICENSE"]) package_group( name = "internal_eval_visibility", packages = [ "//eval/...", "//extensions", "//runtime/internal", ], ) cc_library( name = "evaluator_core", srcs = [ "evaluator_core.cc", ], hdrs = [ "evaluator_core.h", ], deps = [ ":attribute_utility", ":comprehension_slots", ":evaluator_stack", ":iterator_stack", "//base:data", "//common:native_type", "//common:value", "//runtime", "//runtime:activation_interface", "//runtime:runtime_options", "//runtime/internal:activation_attribute_matcher_access", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "cel_expression_flat_impl", srcs = [ "cel_expression_flat_impl.cc", ], hdrs = [ "cel_expression_flat_impl.h", ], deps = [ ":attribute_trail", ":comprehension_slots", ":direct_expression_step", ":evaluator_core", "//common:native_type", "//common:value", "//eval/internal:adapter_activation_impl", "//eval/internal:interop", "//eval/public:base_activation", "//eval/public:cel_expression", "//eval/public:cel_value", "//internal:casts", "//internal:status_macros", "//runtime/internal:runtime_env", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "comprehension_slots", hdrs = [ "comprehension_slots.h", ], deps = [ ":attribute_trail", 
"//common:value", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/types:optional", ], ) cc_test( name = "comprehension_slots_test", srcs = [ "comprehension_slots_test.cc", ], deps = [ ":attribute_trail", ":comprehension_slots", "//base:attributes", "//base:data", "//common:memory", "//common:value", "//internal:testing", ], ) cc_library( name = "evaluator_stack", srcs = [ "evaluator_stack.cc", ], hdrs = [ "evaluator_stack.h", ], deps = [ ":attribute_trail", "//common:value", "//internal:align", "//internal:new", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) cc_test( name = "evaluator_stack_test", srcs = [ "evaluator_stack_test.cc", ], deps = [ ":evaluator_stack", "//base:attributes", "//common:value", "//internal:testing", ], ) cc_library( name = "expression_step_base", hdrs = [ "expression_step_base.h", ], deps = [":evaluator_core"], ) cc_library( name = "const_value_step", hdrs = [ "const_value_step.h", ], deps = [ ":compiler_constant_step", ":direct_expression_step", ":evaluator_core", "//common:value", "@com_google_absl//absl/status:statusor", ], ) cc_library( name = "container_access_step", srcs = [ "container_access_step.cc", ], hdrs = [ "container_access_step.h", ], deps = [ ":attribute_trail", ":attribute_utility", ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//base:attributes", "//common:casting", "//common:expr", "//common:kind", "//common:value", "//common:value_kind", "//eval/internal:errors", "//internal:number", "//internal:status_macros", 
"//runtime/internal:errors", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) cc_library( name = "regex_match_step", srcs = ["regex_match_step.cc"], hdrs = ["regex_match_step.h"], deps = [ ":attribute_trail", ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//common:value", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:string_view", "@com_googlesource_code_re2//:re2", ], ) cc_library( name = "ident_step", srcs = [ "ident_step.cc", ], hdrs = [ "ident_step.h", ], deps = [ ":attribute_trail", ":comprehension_slots", ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//common:expr", "//common:value", "//eval/internal:errors", "//internal:status_macros", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], ) cc_library( name = "function_step", srcs = [ "function_step.cc", ], hdrs = [ "function_step.h", ], deps = [ ":attribute_trail", ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//common:casting", "//common:expr", "//common:function_descriptor", "//common:kind", "//common:value", "//common:value_kind", "//eval/internal:errors", "//internal:status_macros", "//runtime:activation_interface", "//runtime:function", "//runtime:function_overload_reference", "//runtime:function_provider", "//runtime:function_registry", "//runtime/internal:errors", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) cc_library( name 
= "select_step", srcs = [ "select_step.cc", ], hdrs = [ "select_step.h", ], deps = [ ":attribute_trail", ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//common:expr", "//common:value", "//common:value_kind", "//eval/internal:errors", "//internal:status_macros", "//runtime:runtime_options", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "create_list_step", srcs = [ "create_list_step.cc", ], hdrs = [ "create_list_step.h", ], deps = [ ":attribute_trail", ":attribute_utility", ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//common:casting", "//common:expr", "//common:value", "//internal:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:optional", ], ) cc_library( name = "create_struct_step", srcs = [ "create_struct_step.cc", ], hdrs = [ "create_struct_step.h", ], deps = [ ":attribute_trail", ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//common:casting", "//common:value", "//internal:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) cc_library( name = "create_map_step", srcs = [ "create_map_step.cc", ], hdrs = [ "create_map_step.h", ], deps = [ ":attribute_trail", ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//common:casting", "//common:value", "//internal:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", 
"@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) cc_library( name = "jump_step", srcs = [ "jump_step.cc", ], hdrs = [ "jump_step.h", ], deps = [ ":evaluator_core", ":expression_step_base", "//common:value", "//eval/internal:errors", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:optional", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_library( name = "logic_step", srcs = [ "logic_step.cc", ], hdrs = [ "logic_step.h", ], deps = [ ":attribute_trail", ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//base:builtins", "//common:casting", "//common:value", "//common:value_kind", "//eval/internal:errors", "//internal:status_macros", "//runtime/internal:errors", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) cc_library( name = "equality_steps", srcs = [ "equality_steps.cc", ], hdrs = [ "equality_steps.h", ], deps = [ ":attribute_trail", ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//base:builtins", "//common:value", "//common:value_kind", "//internal:number", "//internal:status_macros", "//runtime/internal:errors", "//runtime/standard:equality_functions", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], ) cc_test( name = "equality_steps_test", srcs = [ "equality_steps_test.cc", ], deps = [ ":attribute_trail", ":direct_expression_step", ":equality_steps", ":evaluator_core", "//base:attributes", "//common:value", "//common:value_kind", "//common:value_testing", "//internal:testing", "//internal:testing_descriptor_pool", "//internal:testing_message_factory", "//runtime:activation", "//runtime:runtime_options", "//runtime/internal:runtime_type_provider", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", 
"@com_google_absl//absl/status:status_matchers", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "comprehension_step", srcs = [ "comprehension_step.cc", ], hdrs = [ "comprehension_step.h", ], deps = [ ":attribute_trail", ":comprehension_slots", ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//base:attributes", "//common:casting", "//common:value", "//common:value_kind", "//eval/internal:errors", "//internal:status_macros", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], ) cc_test( name = "comprehension_step_test", size = "small", srcs = [ "comprehension_step_test.cc", ], deps = [ ":attribute_trail", ":cel_expression_flat_impl", ":comprehension_slots", ":comprehension_step", ":const_value_step", ":direct_expression_step", ":evaluator_core", ":expression_step_base", ":ident_step", "//base:data", "//common:expr", "//common:value", "//common:value_testing", "//eval/public:activation", "//eval/public:cel_attribute", "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", "//internal:testing_message_factory", "//runtime:activation", "//runtime:runtime_options", "//runtime/internal:runtime_env_testing", "//runtime/internal:runtime_type_provider", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", "@com_google_protobuf//:struct_cc_proto", ], ) cc_test( name = "evaluator_core_test", size = "small", srcs = [ "evaluator_core_test.cc", ], deps = [ ":cel_expression_flat_impl", ":evaluator_core", "//base:data", "//common:value", "//eval/compiler:cel_expression_builder_flat_impl", 
"//eval/internal:interop", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_value", "//internal:testing", "//internal:testing_descriptor_pool", "//internal:testing_message_factory", "//runtime:activation", "//runtime:runtime_options", "//runtime/internal:runtime_env_testing", "//runtime/internal:runtime_type_provider", "@com_google_absl//absl/status", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "container_access_step_test", size = "small", srcs = [ "container_access_step_test.cc", ], deps = [ ":cel_expression_flat_impl", ":container_access_step", ":direct_expression_step", ":evaluator_core", ":ident_step", "//base:builtins", "//base:data", "//common:ast", "//common:expr", "//eval/public:activation", "//eval/public:cel_attribute", "//eval/public:cel_expr_builder_factory", "//eval/public:cel_expression", "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public:unknown_set", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", "//eval/public/testing:matchers", "//internal:testing", "//parser", "//runtime/internal:runtime_env", "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", "@com_google_protobuf//:struct_cc_proto", ], ) cc_test( name = "regex_match_step_test", size = "small", srcs = [ "regex_match_step_test.cc", ], deps = [ ":regex_match_step", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_expr_builder_factory", "//eval/public:cel_options", "//internal:testing", "//parser", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", 
"@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "ident_step_test", size = "small", srcs = [ "ident_step_test.cc", ], deps = [ ":attribute_trail", ":cel_expression_flat_impl", ":evaluator_core", ":ident_step", "//base:data", "//common:casting", "//common:memory", "//common:value", "//eval/public:activation", "//eval/public:cel_attribute", "//eval/public:cel_value", "//internal:testing", "//internal:testing_descriptor_pool", "//internal:testing_message_factory", "//runtime:activation", "//runtime:runtime_options", "//runtime/internal:runtime_env_testing", "//runtime/internal:runtime_type_provider", "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "function_step_test", size = "small", srcs = [ "function_step_test.cc", ], deps = [ ":cel_expression_flat_impl", ":const_value_step", ":direct_expression_step", ":evaluator_core", ":function_step", ":ident_step", "//base:builtins", "//base:data", "//common:constant", "//common:expr", "//common:kind", "//common:value", "//eval/internal:interop", "//eval/public:activation", "//eval/public:cel_attribute", "//eval/public:cel_function", "//eval/public:cel_function_registry", "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public:portable_cel_function_adapter", "//eval/public/structs:cel_proto_wrapper", "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", "//internal:testing", "//runtime:function_overload_reference", "//runtime:function_registry", "//runtime:runtime_options", "//runtime:standard_functions", "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "logic_step_test", size = "small", srcs = [ "logic_step_test.cc", ], deps = [ ":attribute_trail", 
":cel_expression_flat_impl", ":const_value_step", ":direct_expression_step", ":evaluator_core", ":ident_step", ":logic_step", "//base:attributes", "//base:data", "//common:casting", "//common:expr", "//common:unknown", "//common:value", "//eval/public:activation", "//eval/public:cel_attribute", "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", "//internal:testing_message_factory", "//runtime:activation", "//runtime:runtime_options", "//runtime/internal:runtime_env", "//runtime/internal:runtime_env_testing", "//runtime/internal:runtime_type_provider", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "select_step_test", size = "small", srcs = [ "select_step_test.cc", ], deps = [ ":attribute_trail", ":cel_expression_flat_impl", ":const_value_step", ":evaluator_core", ":ident_step", ":select_step", "//base:attributes", "//base:data", "//common:casting", "//common:expr", "//common:legacy_value", "//common:value", "//common:value_testing", "//eval/public:activation", "//eval/public:cel_attribute", "//eval/public:cel_value", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", "//eval/public/structs:legacy_type_adapter", "//eval/public/structs:trivial_legacy_type_info", "//eval/public/testing:matchers", "//eval/testutil:test_extensions_cc_proto", "//eval/testutil:test_message_cc_proto", "//extensions/protobuf:value", "//internal:proto_matchers", "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", "//internal:testing_message_factory", "//runtime:activation", "//runtime:runtime_options", "//runtime/internal:runtime_env", 
"//runtime/internal:runtime_env_testing", "//runtime/internal:runtime_type_provider", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:wrappers_cc_proto", ], ) cc_test( name = "create_list_step_test", size = "small", srcs = [ "create_list_step_test.cc", ], deps = [ ":attribute_trail", ":cel_expression_flat_impl", ":const_value_step", ":create_list_step", ":direct_expression_step", ":evaluator_core", ":ident_step", "//base:attributes", "//base:data", "//common:casting", "//common:expr", "//common:value", "//common:value_testing", "//eval/internal:interop", "//eval/public:activation", "//eval/public:cel_attribute", "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public/testing:matchers", "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", "//internal:testing_message_factory", "//runtime:activation", "//runtime:runtime_options", "//runtime/internal:runtime_env", "//runtime/internal:runtime_env_testing", "//runtime/internal:runtime_type_provider", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "create_struct_step_test", size = "small", srcs = [ "create_struct_step_test.cc", ], deps = [ ":cel_expression_flat_impl", ":create_struct_step", ":direct_expression_step", ":evaluator_core", ":ident_step", "//base:data", "//common:expr", "//eval/public:activation", "//eval/public:cel_type_registry", "//eval/public:cel_value", "//eval/public:unknown_set", 
"//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", "//eval/testutil:test_message_cc_proto", "//internal:proto_matchers", "//internal:status_macros", "//internal:testing", "//runtime:runtime_options", "//runtime/internal:runtime_env", "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "create_map_step_test", size = "small", srcs = [ "create_map_step_test.cc", ], deps = [ ":cel_expression_flat_impl", ":create_map_step", ":direct_expression_step", ":evaluator_core", ":ident_step", "//base:data", "//common:expr", "//eval/public:activation", "//eval/public:cel_value", "//eval/public:unknown_set", "//eval/testutil:test_message_cc_proto", "//internal:status_macros", "//internal:testing", "//runtime:runtime_options", "//runtime/internal:runtime_env", "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "attribute_trail", srcs = ["attribute_trail.cc"], hdrs = ["attribute_trail.h"], deps = [ "//base:attributes", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/utility", ], ) cc_test( name = "attribute_trail_test", size = "small", srcs = [ "attribute_trail_test.cc", ], deps = [ ":attribute_trail", "//eval/public:cel_attribute", "//eval/public:cel_value", "//internal:testing", 
"@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_library( name = "attribute_utility", srcs = ["attribute_utility.cc"], hdrs = ["attribute_utility.h"], deps = [ ":attribute_trail", "//base:attributes", "//base:function_result", "//base:function_result_set", "//base/internal:unknown_set", "//common:casting", "//common:function_descriptor", "//common:unknown", "//common:value", "//eval/internal:errors", "//internal:status_macros", "//runtime/internal:attribute_matcher", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) cc_test( name = "attribute_utility_test", size = "small", srcs = [ "attribute_utility_test.cc", ], deps = [ ":attribute_trail", ":attribute_utility", "//base:attributes", "//common:unknown", "//common:value", "//eval/public:cel_attribute", "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", "//internal:testing", "//runtime/internal:attribute_matcher", "@com_google_absl//absl/types:span", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "ternary_step", srcs = [ "ternary_step.cc", ], hdrs = [ "ternary_step.h", ], deps = [ ":attribute_trail", ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//base:builtins", "//common:value", "//eval/internal:errors", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], ) cc_test( name = "ternary_step_test", size = "small", srcs = [ "ternary_step_test.cc", ], deps = [ ":attribute_trail", ":cel_expression_flat_impl", ":const_value_step", ":direct_expression_step", ":evaluator_core", ":ident_step", ":ternary_step", "//base:attributes", "//base:data", "//common:casting", "//common:expr", "//common:value", "//eval/public:activation", "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", "//internal:status_macros", 
"//internal:testing", "//internal:testing_descriptor_pool", "//internal:testing_message_factory", "//runtime:activation", "//runtime:runtime_options", "//runtime/internal:runtime_env", "//runtime/internal:runtime_env_testing", "//runtime/internal:runtime_type_provider", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "shadowable_value_step", srcs = ["shadowable_value_step.cc"], hdrs = ["shadowable_value_step.h"], deps = [ ":attribute_trail", ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//common:value", "//internal:status_macros", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], ) cc_test( name = "shadowable_value_step_test", size = "small", srcs = ["shadowable_value_step_test.cc"], deps = [ ":cel_expression_flat_impl", ":evaluator_core", ":shadowable_value_step", "//base:data", "//common:value", "//eval/internal:interop", "//eval/public:activation", "//eval/public:cel_value", "//internal:status_macros", "//internal:testing", "//runtime:runtime_options", "//runtime/internal:runtime_env", "//runtime/internal:runtime_env_testing", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status:statusor", ], ) cc_library( name = "compiler_constant_step", srcs = ["compiler_constant_step.cc"], hdrs = ["compiler_constant_step.h"], deps = [ ":attribute_trail", ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//common:native_type", "//common:value", "@com_google_absl//absl/status", ], ) cc_test( name = "compiler_constant_step_test", srcs = ["compiler_constant_step_test.cc"], deps = [ ":compiler_constant_step", ":evaluator_core", "//common:native_type", "//common:value", "//internal:testing", "//internal:testing_descriptor_pool", "//internal:testing_message_factory", "//runtime:activation", "//runtime:runtime_options", 
"//runtime/internal:runtime_type_provider", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "lazy_init_step", srcs = ["lazy_init_step.cc"], hdrs = ["lazy_init_step.h"], deps = [ ":attribute_trail", ":comprehension_slots", ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//common:value", "//internal:status_macros", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_cel_spec//proto/cel/expr:value_cc_proto", ], ) cc_test( name = "lazy_init_step_test", srcs = ["lazy_init_step_test.cc"], deps = [ ":const_value_step", ":evaluator_core", ":lazy_init_step", "//base:data", "//common:value", "//internal:testing", "//internal:testing_descriptor_pool", "//internal:testing_message_factory", "//runtime:activation", "//runtime:runtime_options", "//runtime/internal:runtime_type_provider", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "direct_expression_step", srcs = ["direct_expression_step.cc"], hdrs = ["direct_expression_step.h"], deps = [ ":attribute_trail", ":evaluator_core", "//common:native_type", "//common:value", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", ], ) cc_library( name = "trace_step", hdrs = ["trace_step.h"], deps = [ ":attribute_trail", ":direct_expression_step", ":evaluator_core", "//common:native_type", "//common:value", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", ], ) cc_library( name = "optional_or_step", srcs = ["optional_or_step.cc"], hdrs = ["optional_or_step.h"], deps = [ ":attribute_trail", ":direct_expression_step", ":evaluator_core", ":expression_step_base", ":jump_step", "//common:casting", "//common:value", "//internal:status_macros", "//runtime/internal:errors", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:optional", 
"@com_google_absl//absl/types:span", ], ) cc_test( name = "optional_or_step_test", srcs = ["optional_or_step_test.cc"], deps = [ ":attribute_trail", ":const_value_step", ":direct_expression_step", ":evaluator_core", ":optional_or_step", "//common:casting", "//common:value", "//common:value_kind", "//common:value_testing", "//internal:testing", "//internal:testing_descriptor_pool", "//internal:testing_message_factory", "//runtime:activation", "//runtime:runtime_options", "//runtime/internal:errors", "//runtime/internal:runtime_type_provider", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "iterator_stack", hdrs = ["iterator_stack.h"], deps = [ "//common:value", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", ], ) ================================================ FILE: eval/eval/LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: eval/eval/attribute_trail.cc ================================================ #include "eval/eval/attribute_trail.h" #include #include #include #include #include #include "base/attribute.h" namespace google::api::expr::runtime { // Creates AttributeTrail with attribute path incremented by "qualifier". AttributeTrail AttributeTrail::Step(cel::AttributeQualifier qualifier) const { // Cannot continue void trail if (empty()) return AttributeTrail(); std::vector qualifiers; qualifiers.reserve(attribute_->qualifier_path().size() + 1); std::copy_n(attribute_->qualifier_path().begin(), attribute_->qualifier_path().size(), std::back_inserter(qualifiers)); qualifiers.push_back(std::move(qualifier)); return AttributeTrail(cel::Attribute(std::string(attribute_->variable_name()), std::move(qualifiers))); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/attribute_trail.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_ATTRIBUTE_TRAIL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_ATTRIBUTE_TRAIL_H_ #include #include #include "absl/types/optional.h" #include "absl/utility/utility.h" #include "base/attribute.h" namespace google::api::expr::runtime { // AttributeTrail reflects current attribute path. 
// It is functionally similar to cel::Attribute, yet intended to have better // complexity on attribute path increment operations. // TODO(issues/41) Current AttributeTrail implementation is equivalent to // cel::Attribute - improve it. // Intended to be used in conjunction with cel::Value, describing the attribute // value originated from. // Empty AttributeTrail denotes object with attribute path not defined // or supported. class AttributeTrail { public: AttributeTrail() : attribute_(absl::nullopt) {} explicit AttributeTrail(std::string variable_name) : attribute_(absl::in_place, std::move(variable_name)) {} explicit AttributeTrail(cel::Attribute attribute) : attribute_(std::move(attribute)) {} // NOLINTNEXTLINE(google-explicit-constructor) AttributeTrail(absl::nullopt_t) : AttributeTrail() {} AttributeTrail(const AttributeTrail&) = default; AttributeTrail& operator=(const AttributeTrail&) = default; AttributeTrail(AttributeTrail&&) = default; AttributeTrail& operator=(AttributeTrail&&) = default; AttributeTrail& operator=(absl::nullopt_t) { attribute_.reset(); return *this; } // Creates AttributeTrail with attribute path incremented by "qualifier". AttributeTrail Step(cel::AttributeQualifier qualifier) const; // Creates AttributeTrail with attribute path incremented by "qualifier". AttributeTrail Step(const std::string* qualifier) const { return Step(cel::AttributeQualifier::OfString(*qualifier)); } // Returns CelAttribute that corresponds to content of AttributeTrail. 
const cel::Attribute& attribute() const { return attribute_.value(); } bool empty() const { return !attribute_.has_value(); } private: absl::optional attribute_; }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_ATTRIBUTE_TRAIL_H_ ================================================ FILE: eval/eval/attribute_trail_test.cc ================================================ #include "eval/eval/attribute_trail.h" #include #include "cel/expr/syntax.pb.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_value.h" #include "internal/testing.h" namespace google::api::expr::runtime { // Attribute Trail behavior TEST(AttributeTrailTest, AttributeTrailEmptyStep) { std::string step = "step"; CelValue step_value = CelValue::CreateString(&step); AttributeTrail trail; ASSERT_TRUE(trail.Step(&step).empty()); ASSERT_TRUE(trail.Step(CreateCelAttributeQualifier(step_value)).empty()); } TEST(AttributeTrailTest, AttributeTrailStep) { std::string step = "step"; CelValue step_value = CelValue::CreateString(&step); AttributeTrail trail = AttributeTrail("ident").Step(&step); ASSERT_EQ(trail.attribute(), CelAttribute("ident", {CreateCelAttributeQualifier(step_value)})); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/attribute_utility.cc ================================================ #include "eval/eval/attribute_utility.h" #include #include #include #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "base/attribute.h" #include "base/attribute_set.h" #include "base/function_result.h" #include "base/function_result_set.h" #include "base/internal/unknown_set.h" #include "common/casting.h" #include "common/function_descriptor.h" #include "common/unknown.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "eval/internal/errors.h" #include "internal/status_macros.h" #include 
"runtime/internal/attribute_matcher.h"

namespace google::api::expr::runtime {

using ::cel::Attribute;
using ::cel::AttributePattern;
using ::cel::AttributeSet;
using ::cel::Cast;
using ::cel::ErrorValue;
using ::cel::FunctionResult;
using ::cel::FunctionResultSet;
using ::cel::InstanceOf;
using ::cel::UnknownValue;
using ::cel::Value;
using ::cel::base_internal::UnknownSet;
using ::cel::runtime_internal::AttributeMatcher;

using Accumulator = AttributeUtility::Accumulator;
using MatchResult = AttributeMatcher::MatchResult;

// The matcher keeps non-owning views (absl::Span) of both pattern lists; the
// storage backing those patterns must outlive this matcher.
DefaultAttributeMatcher::DefaultAttributeMatcher(
    absl::Span<const AttributePattern> unknown_patterns,
    absl::Span<const AttributePattern> missing_patterns)
    : unknown_patterns_(unknown_patterns),
      missing_patterns_(missing_patterns) {}

// A default-constructed matcher holds empty pattern lists, so every check
// reports MatchResult::NONE.
DefaultAttributeMatcher::DefaultAttributeMatcher() = default;

// Matches `attr` against each pattern in `patterns`.
//
// Returns FULL as soon as any pattern fully matches, PARTIAL if at least one
// pattern matched partially (and none fully), and NONE otherwise.
AttributeMatcher::MatchResult MatchAgainstPatterns(
    absl::Span<const AttributePattern> patterns, const Attribute& attr) {
  MatchResult result = MatchResult::NONE;
  for (const auto& pattern : patterns) {
    auto current_match = pattern.IsMatch(attr);
    if (current_match == cel::AttributePattern::MatchType::FULL) {
      return MatchResult::FULL;
    }
    if (current_match == cel::AttributePattern::MatchType::PARTIAL) {
      result = MatchResult::PARTIAL;
    }
  }
  return result;
}

DefaultAttributeMatcher::MatchResult DefaultAttributeMatcher::CheckForUnknown(
    const Attribute& attr) const {
  return MatchAgainstPatterns(unknown_patterns_, attr);
}

DefaultAttributeMatcher::MatchResult DefaultAttributeMatcher::CheckForMissing(
    const Attribute& attr) const {
  return MatchAgainstPatterns(missing_patterns_, attr);
}

// Returns true if the attribute described by `trail` exactly (FULL) matches a
// missing-attribute pattern. Empty trails never match.
bool AttributeUtility::CheckForMissingAttribute(
    const AttributeTrail& trail) const {
  if (trail.empty()) {
    return false;
  }

  // Missing attributes are only treated as errors if the attribute exactly
  // matches (so no guard against passing partial state to a function as with
  // unknowns). This was initially a design oversight, but is difficult to
  // change now.
  return matcher_->CheckForMissing(trail.attribute()) ==
         AttributeMatcher::MatchResult::FULL;
}

// Returns true if the attribute described by `trail` matches an unknown
// pattern. A FULL match always counts; a PARTIAL match counts only when
// `use_partial` is true. Empty trails never match.
bool AttributeUtility::CheckForUnknown(const AttributeTrail& trail,
                                       bool use_partial) const {
  if (trail.empty()) {
    return false;
  }

  const MatchResult result = matcher_->CheckForUnknown(trail.attribute());
  return result == MatchResult::FULL ||
         (use_partial && result == MatchResult::PARTIAL);
}

// Merges every UnknownValue found in `args` into a single UnknownValue;
// values that are not unknown are ignored.
// Returns nullopt if `args` contains no unknown values.
absl::optional<UnknownValue> AttributeUtility::MergeUnknowns(
    absl::Span<const cel::Value> args) const {
  // Empty unknown value may be used as a sentinel in some tests so need to
  // distinguish unset (nullopt) and empty(engaged empty value).
  absl::optional<UnknownSet> result_set;

  for (const auto& value : args) {
    if (!value.IsUnknown()) continue;
    if (!result_set.has_value()) {
      result_set.emplace();
    }
    const auto& current_set = value.GetUnknown();

    cel::base_internal::UnknownSetAccess::Add(
        *result_set, UnknownSet(current_set.attribute_set(),
                                current_set.function_result_set()));
  }

  if (!result_set.has_value()) {
    return absl::nullopt;
  }

  return UnknownValue(cel::Unknown(result_set->unknown_attributes(),
                                   result_set->unknown_function_results()));
}

// Returns an UnknownValue holding the attributes and function results of both
// `left` and `right`.
UnknownValue AttributeUtility::MergeUnknownValues(
    const UnknownValue& left, const UnknownValue& right) const {
  AttributeSet attributes;
  FunctionResultSet function_results;
  attributes.Add(left.attribute_set());
  function_results.Add(left.function_result_set());
  attributes.Add(right.attribute_set());
  function_results.Add(right.function_result_set());

  return UnknownValue(
      cel::Unknown(std::move(attributes), std::move(function_results)));
}

// Returns the set of attributes from `args` that match an unknown pattern,
// using the same FULL/PARTIAL rules as CheckForUnknown. Empty trails are
// skipped. The result is empty if nothing matched.
AttributeSet AttributeUtility::CheckForUnknowns(
    absl::Span<const AttributeTrail> args, bool use_partial) const {
  AttributeSet attribute_set;

  for (const auto& trail : args) {
    if (CheckForUnknown(trail, use_partial)) {
      attribute_set.Add(trail.attribute());
    }
  }

  return attribute_set;
}

// Combines unknowns from two sources:
//   * UnknownValues already present in `args`, and
//   * attributes in `attrs` that match an unknown pattern.
// Returns nullopt when neither source yields an unknown.
absl::optional<UnknownValue> AttributeUtility::IdentifyAndMergeUnknowns(
    absl::Span<const cel::Value> args, absl::Span<const AttributeTrail> attrs,
    bool use_partial) const {
  absl::optional<UnknownSet> result_set;

  // Identify new unknowns by attribute patterns.
  cel::AttributeSet attr_set = CheckForUnknowns(attrs, use_partial);
  if (!attr_set.empty()) {
    result_set.emplace(std::move(attr_set));
  }

  // merge down existing unknown sets
  absl::optional<UnknownValue> arg_unknowns = MergeUnknowns(args);

  if (!result_set.has_value()) {
    // No new unknowns so no need to check for presence of existing unknowns --
    // just forward.
    return arg_unknowns;
  }

  if (arg_unknowns.has_value()) {
    cel::base_internal::UnknownSetAccess::Add(
        *result_set, UnknownSet(arg_unknowns->attribute_set(),
                                arg_unknowns->function_result_set()));
  }

  return UnknownValue(cel::Unknown(result_set->unknown_attributes(),
                                   result_set->unknown_function_results()));
}

// Creates an UnknownValue whose attribute set contains only `attr`.
UnknownValue AttributeUtility::CreateUnknownSet(cel::Attribute attr) const {
  return UnknownValue(cel::Unknown(AttributeSet({std::move(attr)})));
}

// Builds a missing-attribute error for `attr`. Propagates any error from
// formatting the attribute as a string.
absl::StatusOr<ErrorValue> AttributeUtility::CreateMissingAttributeError(
    const cel::Attribute& attr) const {
  CEL_ASSIGN_OR_RETURN(std::string message, attr.AsString());
  return cel::ErrorValue(
      cel::runtime_internal::CreateMissingAttributeError(message));
}

// Creates an UnknownValue for a single unknown function result identified by
// `fn_descriptor` and `expr_id`. `args` is currently not used.
UnknownValue AttributeUtility::CreateUnknownSet(
    const cel::FunctionDescriptor& fn_descriptor, int64_t expr_id,
    absl::Span<const cel::Value> args) const {
  return UnknownValue(
      cel::Unknown(FunctionResultSet(FunctionResult(fn_descriptor, expr_id))));
}

// Folds the attributes and function results of `v` into the accumulator.
void AttributeUtility::Add(Accumulator& a, const cel::UnknownValue& v) const {
  a.attribute_set_.Add(v.attribute_set());
  a.function_result_set_.Add(v.function_result_set());
}

// Adds the attribute described by `attr` to the accumulator. An empty trail
// carries no attribute (AttributeTrail::attribute() would dereference an
// empty optional), so it is ignored rather than crashing.
void AttributeUtility::Add(Accumulator& a, const AttributeTrail& attr) const {
  if (attr.empty()) {
    return;
  }
  a.attribute_set_.Add(attr.attribute());
}

// Records that an unknown value was seen (even if its sets are empty) and
// merges it into the accumulator.
void Accumulator::Add(const UnknownValue& value) {
  unknown_present_ = true;
  parent_.Add(*this, value);
}

void Accumulator::Add(const AttributeTrail& attr) { parent_.Add(*this, attr); }

// Adds `v` only if it is an unknown value.
void Accumulator::MaybeAdd(const Value& v) {
  if (v.IsUnknown()) {
    Add(v.GetUnknown());
  }
}

// Adds `v` if it is unknown; otherwise adds `attr` if it (partially) matches
// an unknown pattern.
void Accumulator::MaybeAdd(const Value& v, const AttributeTrail& attr) {
  if (v.IsUnknown()) {
    Add(v.GetUnknown());
  } else if (parent_.CheckForUnknown(attr, /*use_partial=*/true)) {
    Add(attr);
  }
}

bool Accumulator::IsEmpty() const {
  return !unknown_present_ && attribute_set_.empty() &&
         function_result_set_.empty();
}

// Consumes the accumulator and returns the merged UnknownValue.
cel::UnknownValue Accumulator::Build() && {
  return cel::UnknownValue(
      cel::Unknown(std::move(attribute_set_), std::move(function_result_set_)));
}

}  // namespace
google::api::expr::runtime ================================================ FILE: eval/eval/attribute_utility.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_UNKNOWNS_UTILITY_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_UNKNOWNS_UTILITY_H_ #include #include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "base/attribute.h" #include "base/attribute_set.h" #include "base/function_result_set.h" #include "common/function_descriptor.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "runtime/internal/attribute_matcher.h" namespace google::api::expr::runtime { // Default implementation of the attribute matcher. // Scans the attribute trail against a list of unknown or missing patterns. class DefaultAttributeMatcher : public cel::runtime_internal::AttributeMatcher { private: using MatchResult = cel::runtime_internal::AttributeMatcher::MatchResult; public: DefaultAttributeMatcher( absl::Span unknown_patterns, absl::Span missing_patterns); DefaultAttributeMatcher(); MatchResult CheckForUnknown(const cel::Attribute& attr) const override; MatchResult CheckForMissing(const cel::Attribute& attr) const override; private: absl::Span unknown_patterns_; absl::Span missing_patterns_; }; // Helper class for handling unknowns and missing attribute logic. Provides // helpers for merging unknown sets from arguments on the stack and for // identifying unknown/missing attributes based on the patterns for a given // Evaluation. // Neither moveable nor copyable. class AttributeUtility { public: class Accumulator { public: Accumulator(const Accumulator&) = delete; Accumulator& operator=(const Accumulator&) = delete; Accumulator(Accumulator&&) = delete; Accumulator& operator=(Accumulator&&) = delete; // Add to the accumulated unknown attributes and functions. 
void Add(const cel::UnknownValue& v); void Add(const AttributeTrail& attr); // Add to the accumulated set of unknowns if value is UnknownValue. void MaybeAdd(const cel::Value& v); // Add to the accumulated set of unknowns if value is UnknownValue or // the attribute trail is (partially) unknown. This version prefers // preserving an already present unknown value over a new one matching the // attribute trail. // // Uses partial matching (a pattern matches the attribute or any // sub-attribute). void MaybeAdd(const cel::Value& v, const AttributeTrail& attr); bool IsEmpty() const; cel::UnknownValue Build() &&; private: explicit Accumulator(const AttributeUtility& parent) : parent_(parent), unknown_present_(false) {} friend class AttributeUtility; const AttributeUtility& parent_; cel::AttributeSet attribute_set_; cel::FunctionResultSet function_result_set_; // Some tests will use an empty unknown set as a sentinel. // Preserve forwarding behavior. bool unknown_present_; }; AttributeUtility(absl::Span unknown_patterns, absl::Span missing_patterns) : default_matcher_(unknown_patterns, missing_patterns), matcher_(&default_matcher_) {} explicit AttributeUtility( const cel::runtime_internal::AttributeMatcher* absl_nonnull matcher) : matcher_(matcher) {} AttributeUtility(const AttributeUtility&) = delete; AttributeUtility& operator=(const AttributeUtility&) = delete; AttributeUtility(AttributeUtility&&) = delete; AttributeUtility& operator=(AttributeUtility&&) = delete; // Checks whether particular corresponds to any patterns that define missing // attribute. bool CheckForMissingAttribute(const AttributeTrail& trail) const; // Checks whether trail corresponds to any patterns that define unknowns. bool CheckForUnknown(const AttributeTrail& trail, bool use_partial) const; // Checks whether trail corresponds to any patterns that identify // unknowns. Only matches exactly (exact attribute match for self or parent). 
bool CheckForUnknownExact(const AttributeTrail& trail) const { return CheckForUnknown(trail, false); } // Checks whether trail corresponds to any patterns that define unknowns. // Matches if a parent or any descendant (select or index of) the attribute. bool CheckForUnknownPartial(const AttributeTrail& trail) const { return CheckForUnknown(trail, true); } // Creates merged UnknownAttributeSet. // Scans over the args collection, determines if there matches to unknown // patterns and returns the (possibly empty) collection. cel::AttributeSet CheckForUnknowns(absl::Span args, bool use_partial) const; // Creates merged UnknownValue. // Scans over the args collection, merges any UnknownValues found. // Returns the merged UnknownValue or nullopt if not found. absl::optional MergeUnknowns( absl::Span args) const; // Creates a merged UnknownValue from two unknown values. cel::UnknownValue MergeUnknownValues(const cel::UnknownValue& left, const cel::UnknownValue& right) const; // Creates merged UnknownValue. // Merges together UnknownValues found in the args // along with attributes from attr that match the configured unknown patterns // Returns returns the merged UnknownValue if available or nullopt. absl::optional IdentifyAndMergeUnknowns( absl::Span args, absl::Span attrs, bool use_partial) const; // Create an initial UnknownSet from a single attribute. cel::UnknownValue CreateUnknownSet(cel::Attribute attr) const; // Factory function for missing attribute errors. absl::StatusOr CreateMissingAttributeError( const cel::Attribute& attr) const; // Create an initial UnknownSet from a single missing function call. 
cel::UnknownValue CreateUnknownSet( const cel::FunctionDescriptor& fn_descriptor, int64_t expr_id, absl::Span args) const; Accumulator CreateAccumulator() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return Accumulator(*this); } void set_matcher( const cel::runtime_internal::AttributeMatcher* absl_nonnull matcher) { matcher_ = matcher; } private: // Workaround friend visibility. void Add(Accumulator& a, const cel::UnknownValue& v) const; void Add(Accumulator& a, const AttributeTrail& attr) const; DefaultAttributeMatcher default_matcher_; const cel::runtime_internal::AttributeMatcher* absl_nonnull matcher_; }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_UNKNOWNS_UTILITY_H_ ================================================ FILE: eval/eval/attribute_utility_test.cc ================================================ #include "eval/eval/attribute_utility.h" #include #include #include "absl/types/span.h" #include "base/attribute.h" #include "base/attribute_set.h" #include "common/unknown.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" #include "internal/testing.h" #include "runtime/internal/attribute_matcher.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { using ::cel::AttributeSet; using ::cel::UnknownValue; using ::cel::Value; using ::testing::Eq; using ::testing::SizeIs; using ::testing::UnorderedPointwise; class AttributeUtilityTest : public ::testing::Test { public: AttributeUtilityTest() = default; protected: google::protobuf::Arena arena_; }; absl::Span NoPatterns() { return {}; } TEST_F(AttributeUtilityTest, UnknownsUtilityCheckUnknowns) { std::vector unknown_patterns = { CelAttributePattern("unknown0", {CreateCelAttributeQualifierPattern( CelValue::CreateInt64(1))}), CelAttributePattern("unknown0", 
{CreateCelAttributeQualifierPattern( CelValue::CreateInt64(2))}), CelAttributePattern("unknown1", {}), CelAttributePattern("unknown2", {}), }; std::vector missing_attribute_patterns; AttributeUtility utility(unknown_patterns, missing_attribute_patterns); // no match for void trail ASSERT_FALSE(utility.CheckForUnknown(AttributeTrail(), true)); ASSERT_FALSE(utility.CheckForUnknown(AttributeTrail(), false)); AttributeTrail unknown_trail0("unknown0"); { ASSERT_FALSE(utility.CheckForUnknown(unknown_trail0, false)); } { ASSERT_TRUE(utility.CheckForUnknown(unknown_trail0, true)); } { ASSERT_TRUE(utility.CheckForUnknown( unknown_trail0.Step( CreateCelAttributeQualifier(CelValue::CreateInt64(1))), false)); } { ASSERT_TRUE(utility.CheckForUnknown( unknown_trail0.Step( CreateCelAttributeQualifier(CelValue::CreateInt64(1))), true)); } } TEST_F(AttributeUtilityTest, UnknownsUtilityMergeUnknownsFromValues) { std::vector unknown_patterns; std::vector missing_attribute_patterns; CelAttribute attribute0("unknown0", {}); CelAttribute attribute1("unknown1", {}); AttributeUtility utility(unknown_patterns, missing_attribute_patterns); UnknownValue unknown_set0 = cel::UnknownValue(cel::Unknown(AttributeSet({attribute0}))); UnknownValue unknown_set1 = cel::UnknownValue(cel::Unknown(AttributeSet({attribute1}))); std::vector values = { unknown_set0, unknown_set1, cel::BoolValue(true), cel::IntValue(1), }; absl::optional unknown_set = utility.MergeUnknowns(values); ASSERT_TRUE(unknown_set.has_value()); EXPECT_THAT((*unknown_set).attribute_set(), UnorderedPointwise( Eq(), std::vector{attribute0, attribute1})); } TEST_F(AttributeUtilityTest, UnknownsUtilityCheckForUnknownsFromAttributes) { std::vector unknown_patterns = { CelAttributePattern("unknown0", {CelAttributeQualifierPattern::CreateWildcard()}), }; std::vector missing_attribute_patterns; AttributeTrail trail0("unknown0"); AttributeTrail trail1("unknown1"); CelAttribute attribute1("unknown1", {}); UnknownSet 
unknown_set1(UnknownAttributeSet({attribute1})); AttributeUtility utility(unknown_patterns, missing_attribute_patterns); UnknownSet unknown_attr_set(utility.CheckForUnknowns( { AttributeTrail(), // To make sure we handle empty trail gracefully. trail0.Step(CreateCelAttributeQualifier(CelValue::CreateInt64(1))), trail0.Step(CreateCelAttributeQualifier(CelValue::CreateInt64(2))), }, false)); UnknownSet unknown_set(unknown_set1, unknown_attr_set); ASSERT_THAT(unknown_set.unknown_attributes(), SizeIs(3)); } TEST_F(AttributeUtilityTest, UnknownsUtilityCheckForMissingAttributes) { std::vector unknown_patterns; std::vector missing_attribute_patterns; AttributeTrail trail("destination"); trail = trail.Step(CreateCelAttributeQualifier(CelValue::CreateStringView("ip"))); AttributeUtility utility0(unknown_patterns, missing_attribute_patterns); EXPECT_FALSE(utility0.CheckForMissingAttribute(trail)); missing_attribute_patterns.push_back(CelAttributePattern( "destination", {CreateCelAttributeQualifierPattern(CelValue::CreateStringView("ip"))})); AttributeUtility utility1(unknown_patterns, missing_attribute_patterns); EXPECT_TRUE(utility1.CheckForMissingAttribute(trail)); } TEST_F(AttributeUtilityTest, CreateUnknownSet) { AttributeTrail trail("destination"); trail = trail.Step(CreateCelAttributeQualifier(CelValue::CreateStringView("ip"))); std::vector empty_patterns; AttributeUtility utility(empty_patterns, empty_patterns); UnknownValue set = utility.CreateUnknownSet(trail.attribute()); ASSERT_THAT(set.attribute_set(), SizeIs(1)); ASSERT_OK_AND_ASSIGN(auto elem, set.attribute_set().begin()->AsString()); EXPECT_EQ(elem, "destination.ip"); } class FakeMatcher : public cel::runtime_internal::AttributeMatcher { private: using MatchResult = cel::runtime_internal::AttributeMatcher::MatchResult; public: MatchResult CheckForUnknown(const cel::Attribute& attr) const override { std::string attr_str = attr.AsString().value_or(""); if (attr_str == "device.foo") { return MatchResult::FULL; } 
else if (attr_str == "device") { return MatchResult::PARTIAL; } return MatchResult::NONE; } MatchResult CheckForMissing(const cel::Attribute& attr) const override { std::string attr_str = attr.AsString().value_or(""); if (attr_str == "device2.foo") { return MatchResult::FULL; } else if (attr_str == "device2") { return MatchResult::PARTIAL; } return MatchResult::NONE; } }; TEST_F(AttributeUtilityTest, CustomMatcher) { AttributeTrail trail("device"); AttributeUtility utility(NoPatterns(), NoPatterns()); FakeMatcher matcher; utility.set_matcher(&matcher); EXPECT_TRUE(utility.CheckForUnknownPartial(trail)); EXPECT_FALSE(utility.CheckForUnknownExact(trail)); trail = trail.Step(cel::AttributeQualifier::OfString("foo")); EXPECT_TRUE(utility.CheckForUnknownExact(trail)); EXPECT_TRUE(utility.CheckForUnknownPartial(trail)); trail = AttributeTrail("device2"); EXPECT_FALSE(utility.CheckForMissingAttribute(trail)); trail = trail.Step(cel::AttributeQualifier::OfString("foo")); EXPECT_TRUE(utility.CheckForMissingAttribute(trail)); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/cel_expression_flat_impl.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/eval/cel_expression_flat_impl.h"

#include <cstdint>
#include <memory>
#include <utility>

#include "absl/base/nullability.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "common/native_type.h"
#include "common/value.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/comprehension_slots.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/internal/adapter_activation_impl.h"
#include "eval/internal/interop.h"
#include "eval/public/base_activation.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_value.h"
#include "internal/casts.h"
#include "internal/status_macros.h"
#include "runtime/internal/runtime_env.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {
namespace {

using ::cel::Value;
using ::cel::runtime_internal::RuntimeEnv;

// Adapts a legacy CelEvaluationListener to the modern EvaluationListener
// interface. Returns an empty listener when `listener` is empty.
//
// The returned callable keeps its own copy of `listener` instead of capturing
// the reference parameter: the callable is returned to (and stored by) the
// caller, so it must not depend on the lifetime of the argument it was built
// from.
EvaluationListener AdaptListener(const CelEvaluationListener& listener) {
  if (!listener) return nullptr;
  return [listener](int64_t expr_id, const Value& value,
                    const google::protobuf::DescriptorPool* absl_nonnull,
                    google::protobuf::MessageFactory* absl_nonnull,
                    google::protobuf::Arena* absl_nonnull arena)
             -> absl::Status {
    if (value->Is<cel::OpaqueValue>()) {
      // Opaque types are used to implement some optimized operations.
      // These aren't representable as legacy values and shouldn't be
      // inspectable by clients.
      return absl::OkStatus();
    }
    CelValue legacy_value =
        cel::interop_internal::ModernValueToLegacyValueOrDie(arena, value);
    return listener(expr_id, legacy_value, arena);
  };
}

}  // namespace

CelExpressionFlatEvaluationState::CelExpressionFlatEvaluationState(
    google::protobuf::Arena* arena,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    const FlatExpression& expression)
    : state_(expression.MakeEvaluatorState(descriptor_pool, message_factory,
                                           arena)) {}

// Evaluates using a state previously produced by InitializeState(). The state
// is reset first, so it may be reused across calls.
absl::StatusOr<CelValue> CelExpressionFlatImpl::Trace(
    const BaseActivation& activation, CelEvaluationState* _state,
    CelEvaluationListener callback) const {
  auto state =
      ::cel::internal::down_cast<CelExpressionFlatEvaluationState*>(_state);
  state->state().Reset();
  cel::interop_internal::AdapterActivationImpl modern_activation(activation);

  CEL_ASSIGN_OR_RETURN(
      cel::Value value,
      flat_expression_.EvaluateWithCallback(
          modern_activation, /*embedder_context=*/nullptr,
          AdaptListener(callback), state->state()));

  return cel::interop_internal::ModernValueToLegacyValueOrDie(state->arena(),
                                                              value);
}

std::unique_ptr<CelEvaluationState> CelExpressionFlatImpl::InitializeState(
    google::protobuf::Arena* arena) const {
  return std::make_unique<CelExpressionFlatEvaluationState>(
      arena, env_->descriptor_pool.get(), env_->MutableMessageFactory(),
      flat_expression_);
}

absl::StatusOr<CelValue> CelExpressionFlatImpl::Evaluate(
    const BaseActivation& activation, CelEvaluationState* state) const {
  return Trace(activation, state, CelEvaluationListener());
}

// Fails with InvalidArgument unless the program's first step is a
// WrappedDirectStep (i.e. the program was planned recursively).
absl::StatusOr<std::unique_ptr<CelExpressionRecursiveImpl>>
CelExpressionRecursiveImpl::Create(
    absl_nonnull std::shared_ptr<const RuntimeEnv> env,
    FlatExpression flat_expr) {
  if (flat_expr.path().empty() ||
      flat_expr.path().front()->GetNativeTypeId() !=
          cel::NativeTypeId::For<WrappedDirectStep>()) {
    return absl::InvalidArgumentError(
        absl::StrCat("Expected a recursive program step, got program of size ",
                     flat_expr.path().size()));
  }

  // The constructor is private, so std::make_unique cannot be used here.
  auto* instance =
      new CelExpressionRecursiveImpl(std::move(env), std::move(flat_expr));

  return absl::WrapUnique(instance);
}

absl::StatusOr<CelValue> CelExpressionRecursiveImpl::Trace(
    const BaseActivation& activation, google::protobuf::Arena* arena,
    CelEvaluationListener callback) const {
  cel::interop_internal::AdapterActivationImpl modern_activation(activation);
  ComprehensionSlots slots(flat_expression_.comprehension_slots_size());
  ExecutionFrameBase execution_frame(
      modern_activation, AdaptListener(callback), flat_expression_.options(),
      flat_expression_.type_provider(), env_->descriptor_pool.get(),
      env_->MutableMessageFactory(), arena, /*embedder_context=*/nullptr,
      slots);

  cel::Value result;
  AttributeTrail trail;
  CEL_RETURN_IF_ERROR(root_->Evaluate(execution_frame, result, trail));

  return cel::interop_internal::ModernValueToLegacyValueOrDie(arena, result);
}

absl::StatusOr<CelValue> CelExpressionRecursiveImpl::Evaluate(
    const BaseActivation& activation, google::protobuf::Arena* arena) const {
  return Trace(activation, arena, /*callback=*/nullptr);
}

}  // namespace google::api::expr::runtime

================================================
FILE: eval/eval/cel_expression_flat_impl.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_CEL_EXPRESSION_FLAT_IMPL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_CEL_EXPRESSION_FLAT_IMPL_H_ #include #include #include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/public/base_activation.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_value.h" #include "internal/casts.h" #include "runtime/internal/runtime_env.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace google::api::expr::runtime { // Wrapper for FlatExpressionEvaluationState used to implement CelExpression. class CelExpressionFlatEvaluationState : public CelEvaluationState { public: CelExpressionFlatEvaluationState( google::protobuf::Arena* arena, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, const FlatExpression& expr); google::protobuf::Arena* arena() { return state_.arena(); } FlatExpressionEvaluatorState& state() { return state_; } private: FlatExpressionEvaluatorState state_; }; // Implementation of the CelExpression that evaluates a flattened representation // of the AST. // // This class adapts FlatExpression to implement the CelExpression interface. class CelExpressionFlatImpl : public CelExpression { public: CelExpressionFlatImpl( absl_nonnull std::shared_ptr env, FlatExpression flat_expression) : env_(std::move(env)), flat_expression_(std::move(flat_expression)) {} // Move-only CelExpressionFlatImpl(const CelExpressionFlatImpl&) = delete; CelExpressionFlatImpl& operator=(const CelExpressionFlatImpl&) = delete; CelExpressionFlatImpl(CelExpressionFlatImpl&&) = default; CelExpressionFlatImpl& operator=(CelExpressionFlatImpl&&) = delete; // Implement CelExpression. 
std::unique_ptr InitializeState( google::protobuf::Arena* arena) const override; absl::StatusOr Evaluate(const BaseActivation& activation, google::protobuf::Arena* arena) const override { return Evaluate(activation, InitializeState(arena).get()); } absl::StatusOr Evaluate(const BaseActivation& activation, CelEvaluationState* state) const override; absl::StatusOr Trace( const BaseActivation& activation, google::protobuf::Arena* arena, CelEvaluationListener callback) const override { return Trace(activation, InitializeState(arena).get(), callback); } absl::StatusOr Trace(const BaseActivation& activation, CelEvaluationState* state, CelEvaluationListener callback) const override; // Exposed for inspection in tests. const FlatExpression& flat_expression() const { return flat_expression_; } private: absl_nonnull std::shared_ptr env_; FlatExpression flat_expression_; }; // Implementation of the CelExpression that evaluates a recursive representation // of the AST. // // This class adapts FlatExpression to implement the CelExpression interface. // // Assumes that the flat expression is wrapping a simple recursive program. class CelExpressionRecursiveImpl : public CelExpression { private: class EvaluationState : public CelEvaluationState { public: explicit EvaluationState(google::protobuf::Arena* arena) : arena_(arena) {} google::protobuf::Arena* arena() { return arena_; } private: google::protobuf::Arena* arena_; }; public: static absl::StatusOr> Create( absl_nonnull std::shared_ptr env, FlatExpression flat_expression); // Move-only CelExpressionRecursiveImpl(const CelExpressionRecursiveImpl&) = delete; CelExpressionRecursiveImpl& operator=(const CelExpressionRecursiveImpl&) = delete; CelExpressionRecursiveImpl(CelExpressionRecursiveImpl&&) = default; CelExpressionRecursiveImpl& operator=(CelExpressionRecursiveImpl&&) = delete; // Implement CelExpression. 
std::unique_ptr InitializeState( google::protobuf::Arena* arena) const override { return std::make_unique(arena); } absl::StatusOr Evaluate(const BaseActivation& activation, google::protobuf::Arena* arena) const override; absl::StatusOr Evaluate(const BaseActivation& activation, CelEvaluationState* state) const override { auto* state_impl = cel::internal::down_cast(state); return Evaluate(activation, state_impl->arena()); } absl::StatusOr Trace(const BaseActivation& activation, google::protobuf::Arena* arena, CelEvaluationListener callback) const override; absl::StatusOr Trace( const BaseActivation& activation, CelEvaluationState* state, CelEvaluationListener callback) const override { auto* state_impl = cel::internal::down_cast(state); return Trace(activation, state_impl->arena(), callback); } // Exposed for inspection in tests. const FlatExpression& flat_expression() const { return flat_expression_; } const DirectExpressionStep* root() const { return root_; } private: explicit CelExpressionRecursiveImpl( absl_nonnull std::shared_ptr env, FlatExpression flat_expression) : env_(std::move(env)), flat_expression_(std::move(flat_expression)), root_(cel::internal::down_cast( flat_expression_.path()[0].get()) ->wrapped()) {} absl_nonnull std::shared_ptr env_; FlatExpression flat_expression_; const DirectExpressionStep* root_; }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_CEL_EXPRESSION_FLAT_IMPL_H_ ================================================ FILE: eval/eval/compiler_constant_step.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/eval/compiler_constant_step.h" #include "absl/status/status.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { using ::cel::Value; absl::Status DirectCompilerConstantStep::Evaluate( ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const { result = value_; return absl::OkStatus(); } absl::Status CompilerConstantStep::Evaluate(ExecutionFrame* frame) const { frame->value_stack().Push(value_); return absl::OkStatus(); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/compiler_constant_step.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPILER_CONSTANT_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPILER_CONSTANT_STEP_H_ #include #include #include "absl/status/status.h" #include "common/native_type.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" namespace google::api::expr::runtime { // DirectExpressionStep implementation that simply assigns a constant value. // // Overrides NativeTypeId() allow the FlatExprBuilder and extensions to // inspect the underlying value. class DirectCompilerConstantStep : public DirectExpressionStep { public: DirectCompilerConstantStep(cel::Value value, int64_t expr_id) : DirectExpressionStep(expr_id), value_(std::move(value)) {} absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, AttributeTrail& attribute) const override; cel::NativeTypeId GetNativeTypeId() const override { return cel::NativeTypeId::For(); } const cel::Value& value() const { return value_; } private: cel::Value value_; }; // ExpressionStep implementation that simply pushes a constant value on the // stack. // // Overrides NativeTypeId ()o allow the FlatExprBuilder and extensions to // inspect the underlying value. 
class CompilerConstantStep : public ExpressionStepBase { public: CompilerConstantStep(cel::Value value, int64_t expr_id, bool comes_from_ast) : ExpressionStepBase(expr_id, comes_from_ast), value_(std::move(value)) {} absl::Status Evaluate(ExecutionFrame* frame) const override; cel::NativeTypeId GetNativeTypeId() const override { return cel::NativeTypeId::For(); } const cel::Value& value() const { return value_; } private: cel::Value value_; }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPILER_CONSTANT_STEP_H_ ================================================ FILE: eval/eval/compiler_constant_step_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/eval/compiler_constant_step.h"

#include <memory>

#include "common/native_type.h"
#include "common/value.h"
#include "eval/eval/evaluator_core.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "internal/testing_message_factory.h"
#include "runtime/activation.h"
#include "runtime/internal/runtime_type_provider.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"

namespace google::api::expr::runtime {

namespace {

class CompilerConstantStepTest : public testing::Test {
 public:
  CompilerConstantStepTest()
      : type_provider_(cel::internal::GetTestingDescriptorPool()),
        state_(2, 0, type_provider_, cel::internal::GetTestingDescriptorPool(),
               cel::internal::GetTestingMessageFactory(), &arena_) {}

 protected:
  // arena_ and type_provider_ are declared before state_ because state_'s
  // constructor receives references to both.
  google::protobuf::Arena arena_;
  cel::runtime_internal::RuntimeTypeProvider type_provider_;

  FlatExpressionEvaluatorState state_;

  cel::Activation empty_activation_;

  cel::RuntimeOptions options_;
};

TEST_F(CompilerConstantStepTest, Evaluate) {
  ExecutionPath path;
  path.push_back(
      std::make_unique<CompilerConstantStep>(cel::IntValue(42), -1, false));

  ExecutionFrame frame(path, empty_activation_, options_, state_);

  ASSERT_OK_AND_ASSIGN(cel::Value result, frame.Evaluate());

  // Check the kind first so an unexpected result fails cleanly instead of
  // tripping GetInt().
  ASSERT_TRUE(result.IsInt());
  EXPECT_EQ(result.GetInt().NativeValue(), 42);
}

TEST_F(CompilerConstantStepTest, TypeId) {
  CompilerConstantStep step(cel::IntValue(42), -1, false);

  ExpressionStep& abstract_step = step;
  EXPECT_EQ(abstract_step.GetNativeTypeId(),
            cel::NativeTypeId::For<CompilerConstantStep>());
}

TEST_F(CompilerConstantStepTest, Value) {
  CompilerConstantStep step(cel::IntValue(42), -1, false);

  EXPECT_EQ(step.value().GetInt().NativeValue(), 42);
}

}  // namespace

}  // namespace google::api::expr::runtime

================================================
FILE: eval/eval/comprehension_slots.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPREHENSION_SLOTS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPREHENSION_SLOTS_H_ #include #include #include "absl/base/attributes.h" #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/container/fixed_array.h" #include "absl/log/absl_check.h" #include "absl/types/optional.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" namespace google::api::expr::runtime { class ComprehensionSlot final { public: ComprehensionSlot() = default; ComprehensionSlot(const ComprehensionSlot&) = delete; ComprehensionSlot(ComprehensionSlot&&) = delete; ComprehensionSlot& operator=(const ComprehensionSlot&) = delete; ComprehensionSlot& operator=(ComprehensionSlot&&) = delete; const cel::Value& value() const ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK(Has()); return value_; } cel::Value* absl_nonnull mutable_value() ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK(Has()); return &value_; } const AttributeTrail& attribute() const ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK(Has()); return attribute_; } AttributeTrail* absl_nonnull mutable_attribute() ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK(Has()); return &attribute_; } bool Has() const { return has_; } void Set() { Set(cel::NullValue(), absl::nullopt); } template void Set(V&& value) { Set(std::forward(value), absl::nullopt); } template void Set(V&& value, A&& attribute) { value_ = std::forward(value); attribute_ = std::forward(attribute); has_ = true; } void Clear() { if (has_) { value_ = cel::NullValue(); attribute_ = absl::nullopt; has_ = 
false; } } private: cel::Value value_; AttributeTrail attribute_; bool has_ = false; }; // Simple manager for comprehension variables. // // At plan time, each comprehension variable is assigned a slot by index. // This is used instead of looking up the variable identifier by name in a // runtime stack. // // Callers must handle range checking. class ComprehensionSlots final { public: using Slot = ComprehensionSlot; // Trivial instance if no slots are needed. // Trivially thread safe since no effective state. static ComprehensionSlots& GetEmptyInstance() { static absl::NoDestructor instance(0); return *instance; } explicit ComprehensionSlots(size_t size) : slots_(size) {} ComprehensionSlots(const ComprehensionSlots&) = delete; ComprehensionSlots& operator=(const ComprehensionSlots&) = delete; ComprehensionSlots(ComprehensionSlots&&) = delete; ComprehensionSlots& operator=(ComprehensionSlots&&) = delete; Slot* absl_nonnull Get(size_t index) ABSL_ATTRIBUTE_LIFETIME_BOUND { ABSL_DCHECK_LT(index, size()); return &slots_[index]; } void Reset() { for (Slot& slot : slots_) { slot.Clear(); } } void ClearSlot(size_t index) { Get(index)->Clear(); } template void Set(size_t index, V&& value) { Set(index, std::forward(value), absl::nullopt); } template void Set(size_t index, V&& value, A&& attribute) { Get(index)->Set(std::forward(value), std::forward(attribute)); } size_t size() const { return slots_.size(); } private: absl::FixedArray slots_; }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPREHENSION_SLOTS_H_ ================================================ FILE: eval/eval/comprehension_slots_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/eval/comprehension_slots.h" #include "base/attribute.h" #include "base/type_provider.h" #include "common/memory.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "internal/testing.h" namespace google::api::expr::runtime { using ::cel::Attribute; using ::absl_testing::IsOkAndHolds; using ::cel::MemoryManagerRef; using ::cel::StringValue; using ::cel::TypeProvider; using ::cel::Value; using ::testing::Truly; TEST(ComprehensionSlots, Basic) { ComprehensionSlots slots(4); ComprehensionSlots::Slot* slot0 = slots.Get(0); EXPECT_FALSE(slot0->Has()); slots.Set(0, cel::StringValue("abcd"), AttributeTrail(Attribute("fake_attr"))); ASSERT_TRUE(slot0->Has()); EXPECT_THAT(slot0->value(), Truly([](const Value& v) { return v.Is() && v.GetString().ToString() == "abcd"; })) << "value is 'abcd'"; EXPECT_THAT(slot0->attribute().attribute().AsString(), IsOkAndHolds("fake_attr")); slots.ClearSlot(0); EXPECT_FALSE(slot0->Has()); slots.Set(3, cel::StringValue("abcd"), AttributeTrail(Attribute("fake_attr"))); auto* slot3 = slots.Get(3); ASSERT_TRUE(slot3->Has()); EXPECT_THAT(slot3->value(), Truly([](const Value& v) { return v.Is() && v.GetString().ToString() == "abcd"; })) << "value is 'abcd'"; slots.Reset(); EXPECT_FALSE(slot0->Has()); EXPECT_FALSE(slot3->Has()); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/comprehension_step.cc ================================================ #include "eval/eval/comprehension_step.h" #include #include #include #include #include 
"absl/base/attributes.h" #include "absl/base/casts.h" #include "absl/base/nullability.h" #include "absl/base/optimization.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "base/attribute.h" #include "common/casting.h" #include "common/value.h" #include "common/value_kind.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/comprehension_slots.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/internal/errors.h" #include "internal/status_macros.h" namespace google::api::expr::runtime { namespace { enum class IterableKind { kList = 1, kMap, }; using ::cel::AttributeQualifier; using ::cel::Cast; using ::cel::InstanceOf; using ::cel::UnknownValue; using ::cel::Value; using ::cel::ValueIterator; using ::cel::ValueIteratorPtr; using ::cel::ValueKind; using ::cel::runtime_internal::CreateNoMatchingOverloadError; AttributeQualifier AttributeQualifierFromValue(const Value& v) { switch (v.kind()) { case ValueKind::kString: return AttributeQualifier::OfString(v.GetString().ToString()); case ValueKind::kInt64: return AttributeQualifier::OfInt(v.GetInt().NativeValue()); case ValueKind::kUint64: return AttributeQualifier::OfUint(v.GetUint().NativeValue()); case ValueKind::kBool: return AttributeQualifier::OfBool(v.GetBool().NativeValue()); default: // Non-matching qualifier. 
return AttributeQualifier(); } } class ComprehensionFinishStep final : public ExpressionStepBase { public: ComprehensionFinishStep(size_t accu_slot, int64_t expr_id) : ExpressionStepBase(expr_id), accu_slot_(accu_slot) {} absl::Status Evaluate(ExecutionFrame* frame) const override { if (!frame->value_stack().HasEnough(2)) { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } frame->value_stack().SwapAndPop(2, 1); frame->comprehension_slots().ClearSlot(accu_slot_); frame->iterator_stack().Pop(); return absl::OkStatus(); } private: const size_t accu_slot_; }; class ComprehensionDirectStep final : public DirectExpressionStep { public: explicit ComprehensionDirectStep( size_t iter_slot, size_t iter2_slot, size_t accu_slot, std::unique_ptr range, std::unique_ptr accu_init, std::unique_ptr loop_step, std::unique_ptr condition_step, std::unique_ptr result_step, bool shortcircuiting, int64_t expr_id) : DirectExpressionStep(expr_id), iter_slot_(iter_slot), iter2_slot_(iter2_slot), accu_slot_(accu_slot), range_(std::move(range)), accu_init_(std::move(accu_init)), loop_step_(std::move(loop_step)), condition_(std::move(condition_step)), result_step_(std::move(result_step)), shortcircuiting_(shortcircuiting) {} absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& trail) const override { return iter_slot_ == iter2_slot_ ? 
Evaluate1(frame, result, trail) : Evaluate2(frame, result, trail); } private: absl::Status Evaluate1(ExecutionFrameBase& frame, Value& result, AttributeTrail& trail) const; absl::StatusOr Evaluate1Unknown( ExecutionFrameBase& frame, IterableKind range_iter_kind, const AttributeTrail& range_iter_attr, ValueIterator* absl_nonnull range_iter, ComprehensionSlots::Slot* absl_nonnull accu_slot, ComprehensionSlots::Slot* absl_nonnull iter_slot, Value& result, AttributeTrail& trail) const; absl::StatusOr Evaluate1Known( ExecutionFrameBase& frame, ValueIterator* absl_nonnull range_iter, ComprehensionSlots::Slot* absl_nonnull accu_slot, ComprehensionSlots::Slot* absl_nonnull iter_slot, Value& result, AttributeTrail& trail) const; absl::Status Evaluate2(ExecutionFrameBase& frame, Value& result, AttributeTrail& trail) const; const size_t iter_slot_; const size_t iter2_slot_; const size_t accu_slot_; const std::unique_ptr range_; const std::unique_ptr accu_init_; const std::unique_ptr loop_step_; const std::unique_ptr condition_; const std::unique_ptr result_step_; const bool shortcircuiting_; }; absl::Status ComprehensionDirectStep::Evaluate1(ExecutionFrameBase& frame, Value& result, AttributeTrail& trail) const { Value range; AttributeTrail range_attr; CEL_RETURN_IF_ERROR(range_->Evaluate(frame, range, range_attr)); if (frame.unknown_processing_enabled() && range.IsMap()) { if (frame.attribute_utility().CheckForUnknownPartial(range_attr)) { result = frame.attribute_utility().CreateUnknownSet(range_attr.attribute()); return absl::OkStatus(); } } absl_nullability_unknown ValueIteratorPtr range_iter; IterableKind iterable_kind; switch (range.kind()) { case ValueKind::kList: { CEL_ASSIGN_OR_RETURN(range_iter, range.GetList().NewIterator()); iterable_kind = IterableKind::kList; } break; case ValueKind::kMap: { CEL_ASSIGN_OR_RETURN(range_iter, range.GetMap().NewIterator()); iterable_kind = IterableKind::kMap; } break; case ValueKind::kError: ABSL_FALLTHROUGH_INTENDED; case 
ValueKind::kUnknown: result = std::move(range); return absl::OkStatus(); default: result = cel::ErrorValue(CreateNoMatchingOverloadError("")); return absl::OkStatus(); } ABSL_DCHECK(range_iter != nullptr); ComprehensionSlots::Slot* accu_slot = frame.comprehension_slots().Get(accu_slot_); ABSL_DCHECK(accu_slot != nullptr); { Value accu_init; AttributeTrail accu_init_attr; CEL_RETURN_IF_ERROR(accu_init_->Evaluate(frame, accu_init, accu_init_attr)); accu_slot->Set(std::move(accu_init), std::move(accu_init_attr)); } ComprehensionSlots::Slot* iter_slot = frame.comprehension_slots().Get(iter_slot_); ABSL_DCHECK(iter_slot != nullptr); iter_slot->Set(); bool should_skip_result; if (frame.unknown_processing_enabled()) { CEL_ASSIGN_OR_RETURN( should_skip_result, Evaluate1Unknown(frame, iterable_kind, range_attr, range_iter.get(), accu_slot, iter_slot, result, trail)); } else { CEL_ASSIGN_OR_RETURN(should_skip_result, Evaluate1Known(frame, range_iter.get(), accu_slot, iter_slot, result, trail)); } frame.comprehension_slots().ClearSlot(iter_slot_); if (!should_skip_result) { CEL_RETURN_IF_ERROR(result_step_->Evaluate(frame, result, trail)); } frame.comprehension_slots().ClearSlot(accu_slot_); return absl::OkStatus(); } absl::StatusOr ComprehensionDirectStep::Evaluate1Unknown( ExecutionFrameBase& frame, IterableKind range_iter_kind, const AttributeTrail& range_iter_attr, ValueIterator* absl_nonnull range_iter, ComprehensionSlots::Slot* absl_nonnull accu_slot, ComprehensionSlots::Slot* absl_nonnull iter_slot, Value& result, AttributeTrail& trail) const { Value condition; AttributeTrail condition_attr; Value key_or_value; Value* key; Value* value; switch (range_iter_kind) { case IterableKind::kList: key = &key_or_value; value = iter_slot->mutable_value(); break; case IterableKind::kMap: key = iter_slot->mutable_value(); value = nullptr; break; default: ABSL_UNREACHABLE(); } while (true) { CEL_ASSIGN_OR_RETURN(bool ok, range_iter->Next2(frame.descriptor_pool(), 
frame.message_factory(), frame.arena(), key, value)); if (!ok) { break; } CEL_RETURN_IF_ERROR(frame.IncrementIterations()); *iter_slot->mutable_attribute() = range_iter_attr.Step(AttributeQualifierFromValue(*key)); if (frame.attribute_utility().CheckForUnknownExact( iter_slot->attribute())) { *iter_slot->mutable_value() = frame.attribute_utility().CreateUnknownSet( iter_slot->attribute().attribute()); } // Evaluate the loop condition. CEL_RETURN_IF_ERROR(condition_->Evaluate(frame, condition, condition_attr)); switch (condition.kind()) { case ValueKind::kBool: break; case ValueKind::kError: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUnknown: result = std::move(condition); return true; default: result = cel::ErrorValue(CreateNoMatchingOverloadError("")); return true; } if (shortcircuiting_ && !absl::implicit_cast(condition.GetBool())) { break; } // Evaluate the loop step. CEL_RETURN_IF_ERROR(loop_step_->Evaluate(frame, *accu_slot->mutable_value(), *accu_slot->mutable_attribute())); } return false; } absl::StatusOr ComprehensionDirectStep::Evaluate1Known( ExecutionFrameBase& frame, ValueIterator* absl_nonnull range_iter, ComprehensionSlots::Slot* absl_nonnull accu_slot, ComprehensionSlots::Slot* absl_nonnull iter_slot, Value& result, AttributeTrail& trail) const { Value condition; AttributeTrail condition_attr; while (true) { CEL_ASSIGN_OR_RETURN( bool ok, range_iter->Next1(frame.descriptor_pool(), frame.message_factory(), frame.arena(), iter_slot->mutable_value())); if (!ok) { break; } CEL_RETURN_IF_ERROR(frame.IncrementIterations()); // Evaluate the loop condition. 
CEL_RETURN_IF_ERROR(condition_->Evaluate(frame, condition, condition_attr)); switch (condition.kind()) { case ValueKind::kBool: break; case ValueKind::kError: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUnknown: result = std::move(condition); return true; default: result = cel::ErrorValue(CreateNoMatchingOverloadError("")); return true; } if (shortcircuiting_ && !absl::implicit_cast(condition.GetBool())) { break; } // Evaluate the loop step. CEL_RETURN_IF_ERROR(loop_step_->Evaluate(frame, *accu_slot->mutable_value(), *accu_slot->mutable_attribute())); } return false; } absl::Status ComprehensionDirectStep::Evaluate2(ExecutionFrameBase& frame, Value& result, AttributeTrail& trail) const { Value range; AttributeTrail range_attr; CEL_RETURN_IF_ERROR(range_->Evaluate(frame, range, range_attr)); if (frame.unknown_processing_enabled() && range.IsMap()) { if (frame.attribute_utility().CheckForUnknownPartial(range_attr)) { result = frame.attribute_utility().CreateUnknownSet(range_attr.attribute()); return absl::OkStatus(); } } absl_nullability_unknown ValueIteratorPtr range_iter; switch (range.kind()) { case ValueKind::kList: { CEL_ASSIGN_OR_RETURN(range_iter, range.GetList().NewIterator()); } break; case ValueKind::kMap: { CEL_ASSIGN_OR_RETURN(range_iter, range.GetMap().NewIterator()); } break; case ValueKind::kError: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUnknown: result = std::move(range); return absl::OkStatus(); default: result = cel::ErrorValue(CreateNoMatchingOverloadError("")); return absl::OkStatus(); } ABSL_DCHECK(range_iter != nullptr); ComprehensionSlots::Slot* accu_slot = frame.comprehension_slots().Get(accu_slot_); ABSL_DCHECK(accu_slot != nullptr); { Value accu_init; AttributeTrail accu_init_attr; CEL_RETURN_IF_ERROR(accu_init_->Evaluate(frame, accu_init, accu_init_attr)); accu_slot->Set(std::move(accu_init), std::move(accu_init_attr)); } ComprehensionSlots::Slot* iter_slot = frame.comprehension_slots().Get(iter_slot_); ABSL_DCHECK(iter_slot != 
nullptr); iter_slot->Set(); ComprehensionSlots::Slot* iter2_slot = frame.comprehension_slots().Get(iter2_slot_); ABSL_DCHECK(iter2_slot != nullptr); iter2_slot->Set(); Value condition; AttributeTrail condition_attr; bool should_skip_result = false; while (true) { CEL_ASSIGN_OR_RETURN( bool ok, range_iter->Next2(frame.descriptor_pool(), frame.message_factory(), frame.arena(), iter_slot->mutable_value(), iter2_slot->mutable_value())); if (!ok) { break; } CEL_RETURN_IF_ERROR(frame.IncrementIterations()); if (frame.unknown_processing_enabled()) { *iter_slot->mutable_attribute() = *iter2_slot->mutable_attribute() = range_attr.Step(AttributeQualifierFromValue(iter_slot->value())); if (frame.attribute_utility().CheckForUnknownExact( iter_slot->attribute())) { *iter2_slot->mutable_value() = frame.attribute_utility().CreateUnknownSet( iter_slot->attribute().attribute()); } } // Evaluate the loop condition. CEL_RETURN_IF_ERROR(condition_->Evaluate(frame, condition, condition_attr)); switch (condition.kind()) { case ValueKind::kBool: break; case ValueKind::kError: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUnknown: result = std::move(condition); should_skip_result = true; goto finish; default: result = cel::ErrorValue(CreateNoMatchingOverloadError("")); should_skip_result = true; goto finish; } if (shortcircuiting_ && !absl::implicit_cast(condition.GetBool())) { break; } // Evaluate the loop step. 
CEL_RETURN_IF_ERROR(loop_step_->Evaluate(frame, *accu_slot->mutable_value(), *accu_slot->mutable_attribute())); } finish: iter_slot->Clear(); iter2_slot->Clear(); if (!should_skip_result) { CEL_RETURN_IF_ERROR(result_step_->Evaluate(frame, result, trail)); } accu_slot->Clear(); return absl::OkStatus(); } } // namespace absl::Status ComprehensionInitStep::Evaluate(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(1)) { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } const Value& top = frame->value_stack().Peek(); if (top.IsError() || top.IsUnknown()) { return frame->JumpTo(error_jump_offset_); } if (frame->enable_unknowns() && top.IsMap()) { const AttributeTrail& top_attr = frame->value_stack().PeekAttribute(); if (frame->attribute_utility().CheckForUnknownPartial(top_attr)) { frame->value_stack().PopAndPush( frame->attribute_utility().CreateUnknownSet(top_attr.attribute())); return frame->JumpTo(error_jump_offset_); } } switch (top.kind()) { case ValueKind::kList: { CEL_ASSIGN_OR_RETURN(auto iterator, top.GetList().NewIterator()); frame->iterator_stack().Push(std::move(iterator)); } break; case ValueKind::kMap: { CEL_ASSIGN_OR_RETURN(auto iterator, top.GetMap().NewIterator()); frame->iterator_stack().Push(std::move(iterator)); } break; default: // Replace with an error and jump past // ComprehensionFinishStep. 
frame->value_stack().PopAndPush( cel::ErrorValue(CreateNoMatchingOverloadError(""))); return frame->JumpTo(error_jump_offset_); } return absl::OkStatus(); } absl::Status ComprehensionNextStep::Evaluate1(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(2)) { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } { Value& accu_var = frame->value_stack().Peek(); AttributeTrail& accu_var_attr = frame->value_stack().PeekAttribute(); frame->comprehension_slots().Set(accu_slot_, std::move(accu_var), std::move(accu_var_attr)); frame->value_stack().Pop(1); } ComprehensionSlots::Slot* iter_slot = frame->comprehension_slots().Get(iter_slot_); ABSL_DCHECK(iter_slot != nullptr); iter_slot->Set(); if (frame->enable_unknowns()) { Value key_or_value; Value* key; Value* value; switch (frame->value_stack().Peek().kind()) { case ValueKind::kList: key = &key_or_value; value = iter_slot->mutable_value(); break; case ValueKind::kMap: key = iter_slot->mutable_value(); value = nullptr; break; default: ABSL_UNREACHABLE(); } CEL_ASSIGN_OR_RETURN(bool ok, frame->iterator_stack().Peek()->Next2( frame->descriptor_pool(), frame->message_factory(), frame->arena(), key, value)); if (!ok) { iter_slot->Clear(); return frame->JumpTo(jump_offset_); } CEL_RETURN_IF_ERROR(frame->IncrementIterations()); *iter_slot->mutable_attribute() = frame->value_stack().PeekAttribute().Step( AttributeQualifierFromValue(*key)); if (frame->attribute_utility().CheckForUnknownExact( iter_slot->attribute())) { *iter_slot->mutable_value() = frame->attribute_utility().CreateUnknownSet( iter_slot->attribute().attribute()); } } else { CEL_ASSIGN_OR_RETURN(bool ok, frame->iterator_stack().Peek()->Next1( frame->descriptor_pool(), frame->message_factory(), frame->arena(), iter_slot->mutable_value())); if (!ok) { iter_slot->Clear(); return frame->JumpTo(jump_offset_); } CEL_RETURN_IF_ERROR(frame->IncrementIterations()); } return absl::OkStatus(); } absl::Status 
ComprehensionNextStep::Evaluate2(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(2)) { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } { Value& accu_var = frame->value_stack().Peek(); AttributeTrail& accu_var_attr = frame->value_stack().PeekAttribute(); frame->comprehension_slots().Set(accu_slot_, std::move(accu_var), std::move(accu_var_attr)); frame->value_stack().Pop(1); } ComprehensionSlots::Slot* iter_slot = frame->comprehension_slots().Get(iter_slot_); ABSL_DCHECK(iter_slot != nullptr); iter_slot->Set(); ComprehensionSlots::Slot* iter2_slot = frame->comprehension_slots().Get(iter2_slot_); ABSL_DCHECK(iter2_slot != nullptr); iter2_slot->Set(); CEL_ASSIGN_OR_RETURN( bool ok, frame->iterator_stack().Peek()->Next2( frame->descriptor_pool(), frame->message_factory(), frame->arena(), iter_slot->mutable_value(), iter2_slot->mutable_value())); if (!ok) { iter_slot->Clear(); iter2_slot->Clear(); return frame->JumpTo(jump_offset_); } CEL_RETURN_IF_ERROR(frame->IncrementIterations()); if (frame->enable_unknowns()) { *iter_slot->mutable_attribute() = *iter2_slot->mutable_attribute() = frame->value_stack().PeekAttribute().Step( AttributeQualifierFromValue(iter_slot->value())); if (frame->attribute_utility().CheckForUnknownExact( iter2_slot->attribute())) { *iter2_slot->mutable_value() = frame->attribute_utility().CreateUnknownSet( iter2_slot->attribute().attribute()); } } return absl::OkStatus(); } absl::Status ComprehensionCondStep::Evaluate1(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(2)) { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } const Value& top = frame->value_stack().Peek(); switch (top.kind()) { case ValueKind::kBool: break; case ValueKind::kError: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUnknown: frame->value_stack().SwapAndPop(2, 1); frame->comprehension_slots().ClearSlot(iter_slot_); frame->comprehension_slots().ClearSlot(accu_slot_); 
frame->iterator_stack().Pop(); return frame->JumpTo(error_jump_offset_); default: frame->value_stack().PopAndPush( 2, cel::ErrorValue(CreateNoMatchingOverloadError(""))); frame->comprehension_slots().ClearSlot(iter_slot_); frame->comprehension_slots().ClearSlot(accu_slot_); frame->iterator_stack().Pop(); return frame->JumpTo(error_jump_offset_); } const bool loop_condition = absl::implicit_cast(top.GetBool()); frame->value_stack().Pop(1); // loop_condition if (!loop_condition && shortcircuiting_) { return frame->JumpTo(jump_offset_); } return absl::OkStatus(); } absl::Status ComprehensionCondStep::Evaluate2(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(2)) { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } const Value& top = frame->value_stack().Peek(); switch (top.kind()) { case ValueKind::kBool: break; case ValueKind::kError: ABSL_FALLTHROUGH_INTENDED; case ValueKind::kUnknown: frame->value_stack().SwapAndPop(2, 1); frame->comprehension_slots().ClearSlot(iter_slot_); frame->comprehension_slots().ClearSlot(iter2_slot_); frame->comprehension_slots().ClearSlot(accu_slot_); frame->iterator_stack().Pop(); return frame->JumpTo(error_jump_offset_); default: frame->value_stack().PopAndPush( 2, cel::ErrorValue(CreateNoMatchingOverloadError(""))); frame->comprehension_slots().ClearSlot(iter_slot_); frame->comprehension_slots().ClearSlot(iter2_slot_); frame->comprehension_slots().ClearSlot(accu_slot_); frame->iterator_stack().Pop(); return frame->JumpTo(error_jump_offset_); } const bool loop_condition = absl::implicit_cast(top.GetBool()); frame->value_stack().Pop(1); // loop_condition if (!loop_condition && shortcircuiting_) { return frame->JumpTo(jump_offset_); } return absl::OkStatus(); } std::unique_ptr CreateDirectComprehensionStep( size_t iter_slot, size_t iter2_slot, size_t accu_slot, std::unique_ptr range, std::unique_ptr accu_init, std::unique_ptr loop_step, std::unique_ptr condition_step, std::unique_ptr result_step, 
bool shortcircuiting, int64_t expr_id) { return std::make_unique( iter_slot, iter2_slot, accu_slot, std::move(range), std::move(accu_init), std::move(loop_step), std::move(condition_step), std::move(result_step), shortcircuiting, expr_id); } std::unique_ptr CreateComprehensionFinishStep(size_t accu_slot, int64_t expr_id) { return std::make_unique(accu_slot, expr_id); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/comprehension_step.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPREHENSION_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPREHENSION_STEP_H_ #include #include #include #include #include "absl/status/status.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" namespace google::api::expr::runtime { // Comprehension Evaluation // // 0: 1 -> 1 // 1: ComprehensionInitStep 1 -> 1 // 2: 1 -> 2 // 3: ComprehensionNextStep 2 -> 1 // 4: 1 -> 2 // 5: ComprehensionCondStep 2 -> 1 // 6: 1 -> 2 // 8: 1 -> 2 // 9: ComprehensionFinishStep 2 -> 1 class ComprehensionInitStep final : public ExpressionStepBase { public: explicit ComprehensionInitStep(int64_t expr_id) : ExpressionStepBase(expr_id, /*comes_from_ast=*/false) {} void set_error_jump_offset(int offset) { error_jump_offset_ = offset; } absl::Status Evaluate(ExecutionFrame* frame) const override; private: int error_jump_offset_ = std::numeric_limits::max(); }; class ComprehensionNextStep final : public ExpressionStepBase { public: ComprehensionNextStep(size_t iter_slot, size_t iter2_slot, size_t accu_slot, int64_t expr_id) : ExpressionStepBase(expr_id, /*comes_from_ast=*/false), iter_slot_(iter_slot), iter2_slot_(iter2_slot), accu_slot_(accu_slot) {} void set_jump_offset(int offset) { jump_offset_ = offset; } void set_error_jump_offset(int offset) { error_jump_offset_ = offset; } absl::Status Evaluate(ExecutionFrame* frame) const 
override { return iter_slot_ == iter2_slot_ ? Evaluate1(frame) : Evaluate2(frame); } private: absl::Status Evaluate1(ExecutionFrame* frame) const; absl::Status Evaluate2(ExecutionFrame* frame) const; const size_t iter_slot_; const size_t iter2_slot_; const size_t accu_slot_; int jump_offset_ = std::numeric_limits::max(); int error_jump_offset_ = std::numeric_limits::max(); }; class ComprehensionCondStep final : public ExpressionStepBase { public: ComprehensionCondStep(size_t iter_slot, size_t iter2_slot, size_t accu_slot, bool shortcircuiting, int64_t expr_id) : ExpressionStepBase(expr_id, /*comes_from_ast=*/false), iter_slot_(iter_slot), iter2_slot_(iter2_slot), accu_slot_(accu_slot), shortcircuiting_(shortcircuiting) {} void set_jump_offset(int offset) { jump_offset_ = offset; } void set_error_jump_offset(int offset) { error_jump_offset_ = offset; } absl::Status Evaluate(ExecutionFrame* frame) const override { return iter_slot_ == iter2_slot_ ? Evaluate1(frame) : Evaluate2(frame); } private: absl::Status Evaluate1(ExecutionFrame* frame) const; absl::Status Evaluate2(ExecutionFrame* frame) const; const size_t iter_slot_; const size_t iter2_slot_; const size_t accu_slot_; int jump_offset_ = std::numeric_limits::max(); int error_jump_offset_ = std::numeric_limits::max(); const bool shortcircuiting_; }; // Creates a step for executing a comprehension. std::unique_ptr CreateDirectComprehensionStep( size_t iter_slot, size_t iter2_slot, size_t accu_slot, std::unique_ptr range, std::unique_ptr accu_init, std::unique_ptr loop_step, std::unique_ptr condition_step, std::unique_ptr result_step, bool shortcircuiting, int64_t expr_id); // Creates a cleanup step for the comprehension. // Removes the comprehension context then pushes the 'result' sub expression to // the top of the stack. 
std::unique_ptr CreateComprehensionFinishStep(size_t accu_slot, int64_t expr_id); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPREHENSION_STEP_H_ ================================================ FILE: eval/eval/comprehension_step_test.cc ================================================ #include "eval/eval/comprehension_step.h" #include #include #include #include #include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "base/type_provider.h" #include "common/expr.h" #include "common/value.h" #include "common/value_testing.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/cel_expression_flat_impl.h" #include "eval/eval/comprehension_slots.h" #include "eval/eval/const_value_step.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/eval/ident_step.h" #include "eval/public/activation.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" #include "runtime/activation.h" #include "runtime/internal/runtime_env_testing.h" #include "runtime/internal/runtime_type_provider.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { using ::absl_testing::StatusIs; using ::cel::BoolValue; using ::cel::Expr; using ::cel::IdentExpr; using ::cel::IntValue; using ::cel::TypeProvider; using ::cel::Value; using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::test::BoolValueIs; using ::google::protobuf::Struct; using ::google::protobuf::Arena; using ::testing::_; using ::testing::Eq; 
using ::testing::Return; using ::testing::SizeIs; IdentExpr CreateIdent(const std::string& var) { IdentExpr expr; expr.set_name(var); return expr; } class ListKeysStepTest : public testing::Test { public: ListKeysStepTest() = default; std::unique_ptr MakeExpression( ExecutionPath&& path, bool unknown_attributes = false) { cel::RuntimeOptions options; if (unknown_attributes) { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; } auto env = NewTestingRuntimeEnv(); return std::make_unique( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), options)); } private: Expr dummy_expr_; }; class GetListKeysResultStep : public ExpressionStepBase { public: GetListKeysResultStep() : ExpressionStepBase(-1, false) {} absl::Status Evaluate(ExecutionFrame* frame) const override { frame->value_stack().Pop(1); return absl::OkStatus(); } }; MATCHER_P(CelStringValue, val, "") { const CelValue& to_match = arg; absl::string_view value = val; return to_match.IsString() && to_match.StringOrDie().value() == value; } TEST_F(ListKeysStepTest, MapPartiallyUnknown) { ExecutionPath path; auto result = CreateIdentStep("var", 0); ASSERT_OK(result); path.push_back(*std::move(result)); ComprehensionInitStep* init_step = new ComprehensionInitStep(1); init_step->set_error_jump_offset(1); path.push_back(absl::WrapUnique(init_step)); path.push_back(std::make_unique()); auto expression = MakeExpression(std::move(path), /*unknown_attributes=*/true); Activation activation; Arena arena; Struct value; (*value.mutable_fields())["key1"].set_number_value(1.0); (*value.mutable_fields())["key2"].set_number_value(2.0); (*value.mutable_fields())["key3"].set_number_value(3.0); activation.InsertValue("var", CelProtoWrapper::CreateMessage(&value, &arena)); activation.set_unknown_attribute_patterns({CelAttributePattern( "var", {CreateCelAttributeQualifierPattern(CelValue::CreateStringView("key2")), 
CreateCelAttributeQualifierPattern(CelValue::CreateStringView("foo")), CelAttributeQualifierPattern::CreateWildcard()})}); auto eval_result = expression->Evaluate(activation, &arena); ASSERT_OK(eval_result); ASSERT_TRUE(eval_result->IsUnknownSet()); const auto& attrs = eval_result->UnknownSetOrDie()->unknown_attributes(); EXPECT_THAT(attrs, SizeIs(1)); EXPECT_THAT(attrs.begin()->variable_name(), Eq("var")); EXPECT_THAT(attrs.begin()->qualifier_path(), SizeIs(0)); } TEST_F(ListKeysStepTest, ErrorPassedThrough) { ExecutionPath path; auto result = CreateIdentStep("var", 0); ASSERT_OK(result); path.push_back(*std::move(result)); ComprehensionInitStep* init_step = new ComprehensionInitStep(1); init_step->set_error_jump_offset(1); path.push_back(absl::WrapUnique(init_step)); path.push_back(std::make_unique()); auto expression = MakeExpression(std::move(path)); Activation activation; Arena arena; // Var not in activation, turns into cel error at eval time. auto eval_result = expression->Evaluate(activation, &arena); ASSERT_OK(eval_result); ASSERT_TRUE(eval_result->IsError()); EXPECT_THAT(eval_result->ErrorOrDie()->message(), testing::HasSubstr("\"var\"")); EXPECT_EQ(eval_result->ErrorOrDie()->code(), absl::StatusCode::kUnknown); } TEST_F(ListKeysStepTest, UnknownSetPassedThrough) { ExecutionPath path; auto result = CreateIdentStep("var", 0); ASSERT_OK(result); path.push_back(*std::move(result)); ComprehensionInitStep* init_step = new ComprehensionInitStep(1); init_step->set_error_jump_offset(1); path.push_back(absl::WrapUnique(init_step)); path.push_back(std::make_unique()); auto expression = MakeExpression(std::move(path), /*unknown_attributes=*/true); Activation activation; Arena arena; activation.set_unknown_attribute_patterns({CelAttributePattern("var", {})}); auto eval_result = expression->Evaluate(activation, &arena); ASSERT_OK(eval_result); ASSERT_TRUE(eval_result->IsUnknownSet()); EXPECT_THAT(eval_result->UnknownSetOrDie()->unknown_attributes(), SizeIs(1)); } 
class MockDirectStep : public DirectExpressionStep { public: MockDirectStep() : DirectExpressionStep(-1) {} MOCK_METHOD(absl::Status, Evaluate, (ExecutionFrameBase&, Value&, AttributeTrail&), (const, override)); }; // Test fixture for comprehensions. // // Comprehensions are quite involved so tests here focus on edge cases that are // hard to exercise normally in functional-style tests for the planner. class DirectComprehensionTest : public testing::Test { public: DirectComprehensionTest() : type_provider_(cel::internal::GetTestingDescriptorPool()), slots_(2) {} // returns a two element list for testing [1, 2]. absl::StatusOr MakeList() { auto builder = cel::NewListValueBuilder(&arena_); CEL_RETURN_IF_ERROR(builder->Add(IntValue(1))); CEL_RETURN_IF_ERROR(builder->Add(IntValue(2))); return std::move(*builder).Build(); } protected: google::protobuf::Arena arena_; cel::runtime_internal::RuntimeTypeProvider type_provider_; ComprehensionSlots slots_; cel::Activation empty_activation_; }; TEST_F(DirectComprehensionTest, PropagateRangeNonOkStatus) { cel::RuntimeOptions options; ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_, /*embedder_context=*/nullptr, slots_); auto range_step = std::make_unique(); MockDirectStep* mock = range_step.get(); ON_CALL(*mock, Evaluate(_, _, _)) .WillByDefault(Return(absl::InternalError("test range error"))); auto compre_step = CreateDirectComprehensionStep( 0, 0, 1, /*range_step=*/std::move(range_step), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), /*condition_step=*/CreateConstValueDirectStep(BoolValue(true)), /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), /*shortcircuiting=*/true, -1); Value result; AttributeTrail trail; EXPECT_THAT(compre_step->Evaluate(frame, result, trail), StatusIs(absl::StatusCode::kInternal, 
"test range error")); } TEST_F(DirectComprehensionTest, PropagateAccuInitNonOkStatus) { cel::RuntimeOptions options; ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_, /*embedder_context=*/nullptr, slots_); auto accu_init = std::make_unique(); MockDirectStep* mock = accu_init.get(); ON_CALL(*mock, Evaluate(_, _, _)) .WillByDefault(Return(absl::InternalError("test accu init error"))); ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/std::move(accu_init), /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), /*condition_step=*/CreateConstValueDirectStep(BoolValue(true)), /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), /*shortcircuiting=*/true, -1); Value result; AttributeTrail trail; EXPECT_THAT(compre_step->Evaluate(frame, result, trail), StatusIs(absl::StatusCode::kInternal, "test accu init error")); } TEST_F(DirectComprehensionTest, PropagateLoopNonOkStatus) { cel::RuntimeOptions options; ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_, /*embedder_context=*/nullptr, slots_); auto loop_step = std::make_unique(); MockDirectStep* mock = loop_step.get(); ON_CALL(*mock, Evaluate(_, _, _)) .WillByDefault(Return(absl::InternalError("test loop error"))); ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/std::move(loop_step), /*condition_step=*/CreateConstValueDirectStep(BoolValue(true)), /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), 
/*shortcircuiting=*/true, -1); Value result; AttributeTrail trail; EXPECT_THAT(compre_step->Evaluate(frame, result, trail), StatusIs(absl::StatusCode::kInternal, "test loop error")); } TEST_F(DirectComprehensionTest, PropagateConditionNonOkStatus) { cel::RuntimeOptions options; ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_, /*embedder_context=*/nullptr, slots_); auto condition = std::make_unique(); MockDirectStep* mock = condition.get(); ON_CALL(*mock, Evaluate(_, _, _)) .WillByDefault(Return(absl::InternalError("test condition error"))); ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), /*condition_step=*/std::move(condition), /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), /*shortcircuiting=*/true, -1); Value result; AttributeTrail trail; EXPECT_THAT(compre_step->Evaluate(frame, result, trail), StatusIs(absl::StatusCode::kInternal, "test condition error")); } TEST_F(DirectComprehensionTest, PropagateResultNonOkStatus) { cel::RuntimeOptions options; ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_, /*embedder_context=*/nullptr, slots_); auto result_step = std::make_unique(); MockDirectStep* mock = result_step.get(); ON_CALL(*mock, Evaluate(_, _, _)) .WillByDefault(Return(absl::InternalError("test result error"))); ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), 
/*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), /*condition_step=*/CreateConstValueDirectStep(BoolValue(true)), /*result_step=*/std::move(result_step), /*shortcircuiting=*/true, -1); Value result; AttributeTrail trail; EXPECT_THAT(compre_step->Evaluate(frame, result, trail), StatusIs(absl::StatusCode::kInternal, "test result error")); } TEST_F(DirectComprehensionTest, Shortcircuit) { cel::RuntimeOptions options; ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_, /*embedder_context=*/nullptr, slots_); auto loop_step = std::make_unique(); MockDirectStep* mock = loop_step.get(); EXPECT_CALL(*mock, Evaluate(_, _, _)) .Times(0) .WillRepeatedly([](ExecutionFrameBase&, Value& result, AttributeTrail&) { result = BoolValue(false); return absl::OkStatus(); }); ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/std::move(loop_step), /*condition_step=*/CreateConstValueDirectStep(BoolValue(false)), /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), /*shortcircuiting=*/true, -1); Value result; AttributeTrail trail; ASSERT_OK(compre_step->Evaluate(frame, result, trail)); EXPECT_THAT(result, BoolValueIs(false)); } TEST_F(DirectComprehensionTest, IterationLimit) { cel::RuntimeOptions options; options.comprehension_max_iterations = 2; ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_, /*embedder_context=*/nullptr, slots_); auto loop_step = std::make_unique(); MockDirectStep* mock = loop_step.get(); EXPECT_CALL(*mock, Evaluate(_, _, _)) .Times(1) .WillRepeatedly([](ExecutionFrameBase&, Value& result, 
AttributeTrail&) { result = BoolValue(false); return absl::OkStatus(); }); ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/std::move(loop_step), /*condition_step=*/CreateConstValueDirectStep(BoolValue(true)), /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), /*shortcircuiting=*/true, -1); Value result; AttributeTrail trail; EXPECT_THAT(compre_step->Evaluate(frame, result, trail), StatusIs(absl::StatusCode::kInternal)); } TEST_F(DirectComprehensionTest, Exhaustive) { cel::RuntimeOptions options; ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_, /*embedder_context=*/nullptr, slots_); auto loop_step = std::make_unique(); MockDirectStep* mock = loop_step.get(); EXPECT_CALL(*mock, Evaluate(_, _, _)) .Times(2) .WillRepeatedly([](ExecutionFrameBase&, Value& result, AttributeTrail&) { result = BoolValue(false); return absl::OkStatus(); }); ASSERT_OK_AND_ASSIGN(auto list, MakeList()); auto compre_step = CreateDirectComprehensionStep( 0, 0, 1, /*range_step=*/CreateConstValueDirectStep(std::move(list)), /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), /*loop_step=*/std::move(loop_step), /*condition_step=*/CreateConstValueDirectStep(BoolValue(false)), /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), /*shortcircuiting=*/false, -1); Value result; AttributeTrail trail; ASSERT_OK(compre_step->Evaluate(frame, result, trail)); EXPECT_THAT(result, BoolValueIs(false)); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/const_value_step.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_CONST_VALUE_STEP_H_ #define 
THIRD_PARTY_CEL_CPP_EVAL_EVAL_CONST_VALUE_STEP_H_

#include
#include
#include

#include "absl/status/statusor.h"
#include "common/value.h"
#include "eval/eval/compiler_constant_step.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"

namespace google::api::expr::runtime {

// Factory method for Constant AST node expression recursive step.
//
// Moves `value` into a newly allocated step. `id` is the expression id and
// defaults to -1.
inline std::unique_ptr CreateConstValueDirectStep(
    cel::Value value, int64_t id = -1) {
  return std::make_unique(std::move(value), id);
}

// Factory method for Constant AST node expression step.
//
// Moves `value` into a newly allocated step for `expr_id`. As written this
// never returns an error status. `comes_from_ast` defaults to true.
// NOTE(review): presumably callers pass false for constants synthesized by
// the planner rather than read from the source AST -- confirm.
inline absl::StatusOr> CreateConstValueStep(
    cel::Value value, int64_t expr_id, bool comes_from_ast = true) {
  return std::make_unique(std::move(value), expr_id, comes_from_ast);
}

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_EVAL_CONST_VALUE_STEP_H_
================================================ FILE: eval/eval/container_access_step.cc ================================================
#include "eval/eval/container_access_step.h"

#include
#include
#include

#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "base/attribute.h"
#include "common/casting.h"
#include "common/expr.h"
#include "common/kind.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/attribute_utility.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/expression_step_base.h"
#include "eval/internal/errors.h"
#include "internal/number.h"
#include "internal/status_macros.h"
#include "runtime/internal/errors.h"

namespace google::api::expr::runtime {

namespace {

using ::cel::AttributeQualifier;
using ::cel::Cast;
using ::cel::ErrorValue;
using ::cel::InstanceOf;
using ::cel::IntValue;
using ::cel::ListValue;
using ::cel::MapValue;
using
::cel::UintValue; using ::cel::Value; using ::cel::ValueKind; using ::cel::ValueKindToString; using ::cel::internal::Number; using ::cel::runtime_internal::CreateNoSuchKeyError; inline constexpr int kNumContainerAccessArguments = 2; absl::optional CelNumberFromValue(const Value& value) { switch (value->kind()) { case ValueKind::kInt64: return Number::FromInt64(value.GetInt().NativeValue()); case ValueKind::kUint64: return Number::FromUint64(value.GetUint().NativeValue()); case ValueKind::kDouble: return Number::FromDouble(value.GetDouble().NativeValue()); default: return absl::nullopt; } } absl::Status CheckMapKeyType(const Value& key) { ValueKind kind = key->kind(); switch (kind) { case ValueKind::kString: case ValueKind::kInt64: case ValueKind::kUint64: case ValueKind::kBool: return absl::OkStatus(); default: return absl::InvalidArgumentError(absl::StrCat( "Invalid map key type: '", ValueKindToString(kind), "'")); } } AttributeQualifier AttributeQualifierFromValue(const Value& v) { switch (v->kind()) { case ValueKind::kString: return AttributeQualifier::OfString(v.GetString().ToString()); case ValueKind::kInt64: return AttributeQualifier::OfInt(v.GetInt().NativeValue()); case ValueKind::kUint64: return AttributeQualifier::OfUint(v.GetUint().NativeValue()); case ValueKind::kBool: return AttributeQualifier::OfBool(v.GetBool().NativeValue()); default: // Non-matching qualifier. return AttributeQualifier(); } } void LookupInMap(const MapValue& cel_map, const Value& key, ExecutionFrameBase& frame, Value& result) { if (frame.options().enable_heterogeneous_equality) { // Double isn't a supported key type but may be convertible to an integer. absl::optional number = CelNumberFromValue(key); if (number.has_value()) { // Consider uint as uint first then try coercion (prefer matching the // original type of the key value). 
if (key->Is()) { auto lookup = cel_map.Find(key, frame.descriptor_pool(), frame.message_factory(), frame.arena(), &result); if (!lookup.ok()) { result = cel::ErrorValue(std::move(lookup).status()); return; } if (*lookup) { ABSL_DCHECK(!result.IsUnknown()); return; } } // double / int / uint -> int if (number->LosslessConvertibleToInt()) { auto lookup = cel_map.Find(IntValue(number->AsInt()), frame.descriptor_pool(), frame.message_factory(), frame.arena(), &result); if (!lookup.ok()) { result = cel::ErrorValue(std::move(lookup).status()); return; } if (*lookup) { ABSL_DCHECK(!result.IsUnknown()); return; } } // double / int -> uint if (number->LosslessConvertibleToUint()) { auto lookup = cel_map.Find(UintValue(number->AsUint()), frame.descriptor_pool(), frame.message_factory(), frame.arena(), &result); if (!lookup.ok()) { result = cel::ErrorValue(std::move(lookup).status()); return; } if (*lookup) { ABSL_DCHECK(!result.IsUnknown()); return; } } result = cel::ErrorValue(CreateNoSuchKeyError(key->DebugString())); return; } } absl::Status status = CheckMapKeyType(key); if (!status.ok()) { result = cel::ErrorValue(std::move(status)); return; } absl::Status lookup = cel_map.Get(key, frame.descriptor_pool(), frame.message_factory(), frame.arena(), &result); if (!lookup.ok()) { result = cel::ErrorValue(std::move(lookup)); } ABSL_DCHECK(!result.IsUnknown()); } void LookupInList(const ListValue& cel_list, const Value& key, ExecutionFrameBase& frame, Value& result) { absl::optional maybe_idx; if (frame.options().enable_heterogeneous_equality) { auto number = CelNumberFromValue(key); if (number.has_value() && number->LosslessConvertibleToInt()) { maybe_idx = number->AsInt(); } } else if (InstanceOf(key)) { maybe_idx = key.GetInt().NativeValue(); } if (!maybe_idx.has_value()) { result = cel::ErrorValue(absl::UnknownError( absl::StrCat("Index error: expected integer type, got ", cel::KindToString(ValueKindToKind(key->kind()))))); return; } int64_t idx = *maybe_idx; auto size = 
cel_list.Size(); if (!size.ok()) { result = cel::ErrorValue(size.status()); return; } if (idx < 0 || idx >= *size) { result = cel::ErrorValue(absl::UnknownError( absl::StrCat("Index error: index=", idx, " size=", *size))); return; } absl::Status lookup = cel_list.Get(idx, frame.descriptor_pool(), frame.message_factory(), frame.arena(), &result); if (!lookup.ok()) { result = cel::ErrorValue(std::move(lookup)); } ABSL_DCHECK(!result.IsUnknown()); } void LookupInContainer(const Value& container, const Value& key, ExecutionFrameBase& frame, Value& result) { // Select steps can be applied to either maps or messages switch (container.kind()) { case ValueKind::kMap: { LookupInMap(Cast(container), key, frame, result); return; } case ValueKind::kList: { LookupInList(Cast(container), key, frame, result); return; } default: result = cel::ErrorValue(absl::InvalidArgumentError( absl::StrCat("Invalid container type: '", ValueKindToString(container->kind()), "'"))); return; } } void PerformLookup(ExecutionFrameBase& frame, const Value& container, const Value& key, const AttributeTrail& container_trail, bool enable_optional_types, Value& result, AttributeTrail& trail) { if (frame.unknown_processing_enabled()) { AttributeUtility::Accumulator unknowns = frame.attribute_utility().CreateAccumulator(); unknowns.MaybeAdd(container); unknowns.MaybeAdd(key); if (!unknowns.IsEmpty()) { result = std::move(unknowns).Build(); return; } trail = container_trail.Step(AttributeQualifierFromValue(key)); if (frame.attribute_utility().CheckForUnknownExact(trail)) { result = frame.attribute_utility().CreateUnknownSet(trail.attribute()); return; } } if (InstanceOf(container)) { result = container; return; } if (InstanceOf(key)) { result = key; return; } if (enable_optional_types && container.IsOptional()) { const auto& optional_value = container.GetOptional(); if (!optional_value.HasValue()) { result = cel::OptionalValue::None(); return; } Value value; optional_value.Value(&value); 
LookupInContainer(value, key, frame, result); if (auto error_value = cel::As(result); error_value && cel::IsNoSuchKey(*error_value)) { result = cel::OptionalValue::None(); return; } result = cel::OptionalValue::Of(std::move(result), frame.arena()); return; } LookupInContainer(container, key, frame, result); } // ContainerAccessStep performs message field access specified by Expr::Select // message. class ContainerAccessStep : public ExpressionStepBase { public: ContainerAccessStep(int64_t expr_id, bool enable_optional_types) : ExpressionStepBase(expr_id), enable_optional_types_(enable_optional_types) {} absl::Status Evaluate(ExecutionFrame* frame) const override; private: bool enable_optional_types_; }; absl::Status ContainerAccessStep::Evaluate(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(kNumContainerAccessArguments)) { return absl::Status( absl::StatusCode::kInternal, "Insufficient arguments supplied for ContainerAccess-type expression"); } Value result; AttributeTrail result_trail; auto args = frame->value_stack().GetSpan(kNumContainerAccessArguments); const AttributeTrail& container_trail = frame->value_stack().GetAttributeSpan(kNumContainerAccessArguments)[0]; PerformLookup(*frame, args[0], args[1], container_trail, enable_optional_types_, result, result_trail); frame->value_stack().PopAndPush(kNumContainerAccessArguments, std::move(result), std::move(result_trail)); return absl::OkStatus(); } class DirectContainerAccessStep : public DirectExpressionStep { public: DirectContainerAccessStep( std::unique_ptr container_step, std::unique_ptr key_step, bool enable_optional_types, int64_t expr_id) : DirectExpressionStep(expr_id), container_step_(std::move(container_step)), key_step_(std::move(key_step)), enable_optional_types_(enable_optional_types) {} absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& trail) const override; private: std::unique_ptr container_step_; std::unique_ptr key_step_; bool 
enable_optional_types_; }; absl::Status DirectContainerAccessStep::Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& trail) const { Value container; Value key; AttributeTrail container_trail; AttributeTrail key_trail; CEL_RETURN_IF_ERROR( container_step_->Evaluate(frame, container, container_trail)); CEL_RETURN_IF_ERROR(key_step_->Evaluate(frame, key, key_trail)); PerformLookup(frame, container, key, container_trail, enable_optional_types_, result, trail); return absl::OkStatus(); } } // namespace std::unique_ptr CreateDirectContainerAccessStep( std::unique_ptr container_step, std::unique_ptr key_step, bool enable_optional_types, int64_t expr_id) { return std::make_unique( std::move(container_step), std::move(key_step), enable_optional_types, expr_id); } // Factory method for Select - based Execution step absl::StatusOr> CreateContainerAccessStep( const cel::CallExpr& call, int64_t expr_id, bool enable_optional_types) { int arg_count = call.args().size() + (call.has_target() ? 
1 : 0); if (arg_count != kNumContainerAccessArguments) { return absl::InvalidArgumentError(absl::StrCat( "Invalid argument count for index operation: ", arg_count)); } return std::make_unique(expr_id, enable_optional_types); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/container_access_step.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_CONTAINER_ACCESS_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_CONTAINER_ACCESS_STEP_H_ #include #include #include "absl/status/statusor.h" #include "common/expr.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { std::unique_ptr CreateDirectContainerAccessStep( std::unique_ptr container_step, std::unique_ptr key_step, bool enable_optional_types, int64_t expr_id); // Factory method for Select - based Execution step absl::StatusOr> CreateContainerAccessStep( const cel::CallExpr& call, int64_t expr_id, bool enable_optional_types = false); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_CONTAINER_ACCESS_STEP_H_ ================================================ FILE: eval/eval/container_access_step_test.cc ================================================ #include "eval/eval/container_access_step.h" #include #include #include #include #include #include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "base/builtins.h" #include "base/type_provider.h" #include "common/ast.h" #include "common/expr.h" #include "eval/eval/cel_expression_flat_impl.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" #include "eval/public/activation.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include 
"eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/testing/matchers.h" #include "eval/public/unknown_set.h" #include "internal/testing.h" #include "parser/parser.h" #include "runtime/internal/runtime_env.h" #include "runtime/internal/runtime_env_testing.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { using ::absl_testing::StatusIs; using ::cel::Expr; using ::cel::SourceInfo; using ::cel::TypeProvider; using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::runtime_internal::RuntimeEnv; using ::cel::expr::ParsedExpr; using ::google::protobuf::Struct; using ::testing::_; using ::testing::AllOf; using ::testing::HasSubstr; using TestParamType = std::tuple; CelValue EvaluateAttributeHelper( const absl_nonnull std::shared_ptr& env, google::protobuf::Arena* arena, CelValue container, CelValue key, bool use_recursive_impl, bool receiver_style, bool enable_unknown, const std::vector& patterns) { ExecutionPath path; Expr expr; SourceInfo source_info; auto& call = expr.mutable_call_expr(); call.set_function(cel::builtin::kIndex); call.mutable_args().reserve(2); Expr& container_expr = (receiver_style) ? 
call.mutable_target() : call.mutable_args().emplace_back(); Expr& key_expr = call.mutable_args().emplace_back(); container_expr.mutable_ident_expr().set_name("container"); key_expr.mutable_ident_expr().set_name("key"); if (use_recursive_impl) { path.push_back(std::make_unique( CreateDirectContainerAccessStep(CreateDirectIdentStep("container", 1), CreateDirectIdentStep("key", 2), /*enable_optional_types=*/false, 3), 3)); } else { path.push_back(std::move(CreateIdentStep("container", 1).value())); path.push_back(std::move(CreateIdentStep("key", 2).value())); path.push_back(std::move(CreateContainerAccessStep(call, 3).value())); } cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; options.enable_heterogeneous_equality = false; CelExpressionFlatImpl cel_expr( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), options)); Activation activation; activation.InsertValue("container", container); activation.InsertValue("key", key); activation.set_unknown_attribute_patterns(patterns); auto result = cel_expr.Evaluate(activation, arena); return *result; } class ContainerAccessStepTest : public ::testing::Test { protected: ContainerAccessStepTest() = default; void SetUp() override { env_ = NewTestingRuntimeEnv(); } CelValue EvaluateAttribute( CelValue container, CelValue key, bool receiver_style, bool enable_unknown, bool use_recursive_impl = false, const std::vector& patterns = {}) { return EvaluateAttributeHelper(env_, &arena_, container, key, receiver_style, enable_unknown, use_recursive_impl, patterns); } absl_nonnull std::shared_ptr env_; google::protobuf::Arena arena_; }; class ContainerAccessStepUniformityTest : public ::testing::TestWithParam { protected: ContainerAccessStepUniformityTest() = default; void SetUp() override { env_ = NewTestingRuntimeEnv(); } bool receiver_style() { TestParamType params = GetParam(); return std::get<0>(params); } bool 
enable_unknown() { TestParamType params = GetParam(); return std::get<1>(params); } bool use_recursive_impl() { TestParamType params = GetParam(); return std::get<2>(params); } // Helper method. Looks up in registry and tests comparison operation. CelValue EvaluateAttribute( CelValue container, CelValue key, bool receiver_style, bool enable_unknown, bool use_recursive_impl = false, const std::vector& patterns = {}) { return EvaluateAttributeHelper(env_, &arena_, container, key, receiver_style, enable_unknown, use_recursive_impl, patterns); } absl_nonnull std::shared_ptr env_; google::protobuf::Arena arena_; }; TEST_P(ContainerAccessStepUniformityTest, TestListIndexAccess) { ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), CelValue::CreateInt64(2), CelValue::CreateInt64(3)}); CelValue result = EvaluateAttribute(CelValue::CreateList(&cel_list), CelValue::CreateInt64(1), receiver_style(), enable_unknown()); ASSERT_TRUE(result.IsInt64()); ASSERT_EQ(result.Int64OrDie(), 2); } TEST_P(ContainerAccessStepUniformityTest, TestListIndexAccessOutOfBounds) { ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), CelValue::CreateInt64(2), CelValue::CreateInt64(3)}); CelValue result = EvaluateAttribute(CelValue::CreateList(&cel_list), CelValue::CreateInt64(0), receiver_style(), enable_unknown()); ASSERT_TRUE(result.IsInt64()); result = EvaluateAttribute(CelValue::CreateList(&cel_list), CelValue::CreateInt64(2), receiver_style(), enable_unknown()); ASSERT_TRUE(result.IsInt64()); result = EvaluateAttribute(CelValue::CreateList(&cel_list), CelValue::CreateInt64(-1), receiver_style(), enable_unknown()); ASSERT_TRUE(result.IsError()); result = EvaluateAttribute(CelValue::CreateList(&cel_list), CelValue::CreateInt64(3), receiver_style(), enable_unknown()); ASSERT_TRUE(result.IsError()); } TEST_P(ContainerAccessStepUniformityTest, TestListIndexAccessNotAnInt) { ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), CelValue::CreateInt64(2), 
CelValue::CreateInt64(3)}); CelValue result = EvaluateAttribute(CelValue::CreateList(&cel_list), CelValue::CreateUint64(1), receiver_style(), enable_unknown()); ASSERT_TRUE(result.IsError()); } TEST_P(ContainerAccessStepUniformityTest, TestMapKeyAccess) { const std::string kKey0 = "testkey0"; const std::string kKey1 = "testkey1"; const std::string kKey2 = "testkey2"; Struct cel_struct; (*cel_struct.mutable_fields())[kKey0].set_string_value("value0"); (*cel_struct.mutable_fields())[kKey1].set_string_value("value1"); (*cel_struct.mutable_fields())[kKey2].set_string_value("value2"); CelValue result = EvaluateAttribute( CelProtoWrapper::CreateMessage(&cel_struct, &arena_), CelValue::CreateString(&kKey0), receiver_style(), enable_unknown()); ASSERT_TRUE(result.IsString()); ASSERT_EQ(result.StringOrDie().value(), "value0"); } TEST_P(ContainerAccessStepUniformityTest, TestBoolKeyType) { CelMapBuilder cel_map; ASSERT_OK(cel_map.Add(CelValue::CreateBool(true), CelValue::CreateStringView("value_true"))); CelValue result = EvaluateAttribute(CelValue::CreateMap(&cel_map), CelValue::CreateBool(true), receiver_style(), enable_unknown()); ASSERT_THAT(result, test::IsCelString("value_true")); } TEST_P(ContainerAccessStepUniformityTest, TestMapKeyAccessNotFound) { const std::string kKey0 = "testkey0"; const std::string kKey1 = "testkey1"; Struct cel_struct; (*cel_struct.mutable_fields())[kKey0].set_string_value("value0"); CelValue result = EvaluateAttribute( CelProtoWrapper::CreateMessage(&cel_struct, &arena_), CelValue::CreateString(&kKey1), receiver_style(), enable_unknown()); ASSERT_TRUE(result.IsError()); EXPECT_THAT(*result.ErrorOrDie(), StatusIs(absl::StatusCode::kNotFound, AllOf(HasSubstr("Key not found in map : "), HasSubstr("testkey1")))); } TEST_F(ContainerAccessStepTest, TestInvalidReceiverCreateContainerAccessStep) { Expr expr; auto& call = expr.mutable_call_expr(); call.set_function(cel::builtin::kIndex); Expr& container_expr = call.mutable_target(); 
container_expr.mutable_ident_expr().set_name("container"); call.mutable_args().reserve(2); Expr& key_expr = call.mutable_args().emplace_back(); key_expr.mutable_ident_expr().set_name("key"); Expr& extra_arg = call.mutable_args().emplace_back(); extra_arg.mutable_const_expr().set_bool_value(true); EXPECT_THAT(CreateContainerAccessStep(call, 0).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Invalid argument count"))); } TEST_F(ContainerAccessStepTest, TestInvalidGlobalCreateContainerAccessStep) { Expr expr; auto& call = expr.mutable_call_expr(); call.set_function(cel::builtin::kIndex); call.mutable_args().reserve(3); Expr& container_expr = call.mutable_args().emplace_back(); container_expr.mutable_ident_expr().set_name("container"); Expr& key_expr = call.mutable_args().emplace_back(); key_expr.mutable_ident_expr().set_name("key"); Expr& extra_arg = call.mutable_args().emplace_back(); extra_arg.mutable_const_expr().set_bool_value(true); EXPECT_THAT(CreateContainerAccessStep(call, 0).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Invalid argument count"))); } TEST_F(ContainerAccessStepTest, TestListIndexAccessUnknown) { ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), CelValue::CreateInt64(2), CelValue::CreateInt64(3)}); CelValue result = EvaluateAttribute(CelValue::CreateList(&cel_list), CelValue::CreateInt64(1), true, true, {}); ASSERT_TRUE(result.IsInt64()); ASSERT_EQ(result.Int64OrDie(), 2); std::vector patterns = {CelAttributePattern( "container", {CreateCelAttributeQualifierPattern(CelValue::CreateInt64(1))})}; result = EvaluateAttribute(CelValue::CreateList(&cel_list), CelValue::CreateInt64(1), true, true, false, patterns); ASSERT_TRUE(result.IsUnknownSet()); } TEST_F(ContainerAccessStepTest, TestListUnknownKey) { ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), CelValue::CreateInt64(2), CelValue::CreateInt64(3)}); UnknownSet unknown_set; CelValue result = 
EvaluateAttribute(CelValue::CreateList(&cel_list), CelValue::CreateUnknownSet(&unknown_set), true, true); ASSERT_TRUE(result.IsUnknownSet()); } TEST_F(ContainerAccessStepTest, TestMapInvalidKey) { const std::string kKey0 = "testkey0"; const std::string kKey1 = "testkey1"; const std::string kKey2 = "testkey2"; Struct cel_struct; (*cel_struct.mutable_fields())[kKey0].set_string_value("value0"); (*cel_struct.mutable_fields())[kKey1].set_string_value("value1"); (*cel_struct.mutable_fields())[kKey2].set_string_value("value2"); CelValue result = EvaluateAttribute(CelProtoWrapper::CreateMessage(&cel_struct, &arena_), CelValue::CreateDouble(1.0), true, true); ASSERT_TRUE(result.IsError()); EXPECT_THAT(*result.ErrorOrDie(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Invalid map key type: 'double'"))); } TEST_F(ContainerAccessStepTest, TestMapUnknownKey) { const std::string kKey0 = "testkey0"; const std::string kKey1 = "testkey1"; const std::string kKey2 = "testkey2"; Struct cel_struct; (*cel_struct.mutable_fields())[kKey0].set_string_value("value0"); (*cel_struct.mutable_fields())[kKey1].set_string_value("value1"); (*cel_struct.mutable_fields())[kKey2].set_string_value("value2"); UnknownSet unknown_set; CelValue result = EvaluateAttribute(CelProtoWrapper::CreateMessage(&cel_struct, &arena_), CelValue::CreateUnknownSet(&unknown_set), true, true); ASSERT_TRUE(result.IsUnknownSet()); } TEST_F(ContainerAccessStepTest, TestUnknownContainer) { UnknownSet unknown_set; CelValue result = EvaluateAttribute(CelValue::CreateUnknownSet(&unknown_set), CelValue::CreateInt64(1), true, true); ASSERT_TRUE(result.IsUnknownSet()); } TEST_F(ContainerAccessStepTest, TestInvalidContainerType) { CelValue result = EvaluateAttribute(CelValue::CreateInt64(1), CelValue::CreateInt64(0), true, true); ASSERT_TRUE(result.IsError()); EXPECT_THAT(*result.ErrorOrDie(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Invalid container type: 'int"))); } INSTANTIATE_TEST_SUITE_P( 
CombinedContainerTest, ContainerAccessStepUniformityTest, testing::Combine(/*receiver_style*/ testing::Bool(), /*unknown_enabled*/ testing::Bool(), /*use_recursive_impl*/ testing::Bool())); class ContainerAccessHeterogeneousLookupsTest : public testing::Test { public: ContainerAccessHeterogeneousLookupsTest() { options_.enable_heterogeneous_equality = true; builder_ = CreateCelExpressionBuilder(options_); } protected: InterpreterOptions options_; std::unique_ptr builder_; google::protobuf::Arena arena_; Activation activation_; }; TEST_F(ContainerAccessHeterogeneousLookupsTest, DoubleMapKeyInt) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1.0]")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( &expr.expr(), &expr.source_info())); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation_, &arena_)); EXPECT_THAT(result, test::IsCelInt64(2)); } TEST_F(ContainerAccessHeterogeneousLookupsTest, DoubleMapKeyNotAnInt) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1.1]")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( &expr.expr(), &expr.source_info())); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation_, &arena_)); EXPECT_THAT(result, test::IsCelError(_)); } TEST_F(ContainerAccessHeterogeneousLookupsTest, DoubleMapKeyUint) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u}[1.0]")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( &expr.expr(), &expr.source_info())); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation_, &arena_)); EXPECT_THAT(result, test::IsCelUint64(2)); } TEST_F(ContainerAccessHeterogeneousLookupsTest, DoubleListIndex) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1, 2, 3][1.0]")); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression( &expr.expr(), &expr.source_info())); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation_, &arena_)); EXPECT_THAT(result, test::IsCelInt64(2)); } 
TEST_F(ContainerAccessHeterogeneousLookupsTest, DoubleListIndexNotAnInt) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1, 2, 3][1.1]"));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       cel_expr->Evaluate(activation_, &arena_));
  EXPECT_THAT(result, test::IsCelError(_));
}

// Treat a uint key as uint before trying coercion to signed int.
TEST_F(ContainerAccessHeterogeneousLookupsTest, UintKeyAsUint) {
  // TODO(uncreated-issue/4): Map creation should error here instead of permitting
  // mixed key types with equivalent values.
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u, 1: 2}[1u]"));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       cel_expr->Evaluate(activation_, &arena_));
  EXPECT_THAT(result, test::IsCelUint64(2));
}

TEST_F(ContainerAccessHeterogeneousLookupsTest, UintKeyAsInt) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1u]"));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       cel_expr->Evaluate(activation_, &arena_));
  EXPECT_THAT(result, test::IsCelInt64(2));
}

TEST_F(ContainerAccessHeterogeneousLookupsTest, IntKeyAsUint) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u}[1]"));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       cel_expr->Evaluate(activation_, &arena_));
  EXPECT_THAT(result, test::IsCelUint64(2));
}

TEST_F(ContainerAccessHeterogeneousLookupsTest, UintListIndex) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1, 2, 3][2u]"));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       cel_expr->Evaluate(activation_, &arena_));
  EXPECT_THAT(result, test::IsCelInt64(3));
}

TEST_F(ContainerAccessHeterogeneousLookupsTest, StringKeyUnaffected) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2, '1': 3}['1']"));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       cel_expr->Evaluate(activation_, &arena_));
  EXPECT_THAT(result, test::IsCelInt64(3));
}

// Mirrors the ContainerAccessHeterogeneousLookupsTest cases with
// heterogeneous equality disabled. A key or index whose numeric type differs
// from the container's is expected to yield an error. Exact-type lookups
// (e.g. UintKeyAsUint, StringKeyUnaffected) still succeed.
class ContainerAccessHeterogeneousLookupsDisabledTest : public testing::Test {
 public:
  ContainerAccessHeterogeneousLookupsDisabledTest() {
    options_.enable_heterogeneous_equality = false;
    builder_ = CreateCelExpressionBuilder(options_);
  }

 protected:
  InterpreterOptions options_;
  std::unique_ptr builder_;
  google::protobuf::Arena arena_;
  Activation activation_;
};

TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, DoubleMapKeyInt) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1.0]"));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       cel_expr->Evaluate(activation_, &arena_));
  EXPECT_THAT(result, test::IsCelError(_));
}

TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, DoubleMapKeyNotAnInt) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1.1]"));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       cel_expr->Evaluate(activation_, &arena_));
  EXPECT_THAT(result, test::IsCelError(_));
}

TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, DoubleMapKeyUint) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u}[1.0]"));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       cel_expr->Evaluate(activation_, &arena_));
  EXPECT_THAT(result, test::IsCelError(_));
}

TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, DoubleListIndex) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr,
                       parser::Parse("[1, 2, 3][1.0]"));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       cel_expr->Evaluate(activation_, &arena_));
  EXPECT_THAT(result, test::IsCelError(_));
}

TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest,
       DoubleListIndexNotAnInt) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("[1, 2, 3][1.1]"));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       cel_expr->Evaluate(activation_, &arena_));
  EXPECT_THAT(result, test::IsCelError(_));
}

TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, UintKeyAsUint) {
  // TODO(uncreated-issue/4): Map creation should error here instead of permitting
  // mixed key types with equivalent values.
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u, 1: 2}[1u]"));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       cel_expr->Evaluate(activation_, &arena_));
  EXPECT_THAT(result, test::IsCelUint64(2));
}

TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, UintKeyAsInt) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2}[1u]"));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       cel_expr->Evaluate(activation_, &arena_));
  EXPECT_THAT(result, test::IsCelError(_));
}

TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, IntKeyAsUint) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1u: 2u}[1]"));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       cel_expr->Evaluate(activation_, &arena_));
  EXPECT_THAT(result, test::IsCelError(_));
}

TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, UintListIndex) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr,
                       parser::Parse("[1, 2, 3][2u]"));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       cel_expr->Evaluate(activation_, &arena_));
  EXPECT_THAT(result, test::IsCelError(_));
}

TEST_F(ContainerAccessHeterogeneousLookupsDisabledTest, StringKeyUnaffected) {
  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse("{1: 2, '1': 3}['1']"));
  ASSERT_OK_AND_ASSIGN(auto cel_expr, builder_->CreateExpression(
                                          &expr.expr(), &expr.source_info()));
  ASSERT_OK_AND_ASSIGN(CelValue result,
                       cel_expr->Evaluate(activation_, &arena_));
  EXPECT_THAT(result, test::IsCelInt64(3));
}

}  // namespace
}  // namespace google::api::expr::runtime
================================================ FILE: eval/eval/create_list_step.cc ================================================
#include "eval/eval/create_list_step.h"

// NOTE(review): in this copy the standard-library include targets below, and
// the template arguments throughout this file (e.g. the element type of
// absl::flat_hash_set and the target type of static_cast), are missing.
// Restore them from the upstream source before building.
#include
#include
#include
#include

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/optional.h"
#include "common/casting.h"
#include "common/expr.h"
#include "common/value.h"
#include "common/values/list_value_builder.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/attribute_utility.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/expression_step_base.h"
#include "internal/status_macros.h"

namespace google::api::expr::runtime {

namespace {

using ::cel::Cast;
using ::cel::ErrorValue;
using ::cel::InstanceOf;
using ::cel::ListValueBuilderPtr;
using ::cel::UnknownValue;
using ::cel::Value;
using ::cel::common_internal::NewListValueBuilder;

// Stack-based list construction step. It consumes the top `list_size` values
// pushed by the preceding element steps and replaces them with one result:
// the first error among them, a merged unknown (when unknowns are enabled),
// or the built list.
class CreateListStep : public ExpressionStepBase {
 public:
  // `optional_indices` holds the positions of elements that must be optional
  // values; such elements are unwrapped, or omitted when empty.
  CreateListStep(int64_t expr_id, int list_size,
                 absl::flat_hash_set optional_indices)
      : ExpressionStepBase(expr_id),
        list_size_(list_size),
        optional_indices_(std::move(optional_indices)) {}

  absl::Status Evaluate(ExecutionFrame* frame) const override;

 private:
  // Computes the step's result from the top `list_size_` stack entries
  // without modifying the stack.
  absl::Status DoEvaluate(ExecutionFrame* frame, Value* result) const;

  int list_size_;
  absl::flat_hash_set optional_indices_;
};

absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const {
  // NOTE(review): list_size_ comes from an element count passed as `int` (see
  // CreateCreateListStep), so a negative value would indicate overflow or a
  // misconstructed step.
  if (list_size_ < 0) {
    return absl::Status(absl::StatusCode::kInternal,
                        "CreateListStep: list size is <0");
  }

  if (!frame->value_stack().HasEnough(list_size_)) {
    return absl::Status(absl::StatusCode::kInternal,
                        "CreateListStep: stack underflow");
  }

  Value result;
  CEL_RETURN_IF_ERROR(DoEvaluate(frame, &result));

  // Replace the consumed element values with the single result.
  frame->value_stack().PopAndPush(list_size_, std::move(result));
  return absl::OkStatus();
}

absl::Status CreateListStep::DoEvaluate(ExecutionFrame* frame,
                                        Value* result) const {
  auto args = frame->value_stack().GetSpan(list_size_);

  // Errors take precedence over unknowns: scan every argument for an error
  // before looking for unknowns.
  for (const auto& arg : args) {
    if (arg.IsError()) {
      *result = arg;
      return absl::OkStatus();
    }
  }

  if (frame->enable_unknowns()) {
    // Merge unknowns found in the values and in their attribute trails
    // (partial attribute matches count).
    absl::optional unknown_set =
        frame->attribute_utility().IdentifyAndMergeUnknowns(
            args, frame->value_stack().GetAttributeSpan(list_size_),
            /*use_partial=*/true);
    if (unknown_set.has_value()) {
      *result = std::move(*unknown_set);
      return absl::OkStatus();
    }
  }

  ListValueBuilderPtr builder = NewListValueBuilder(frame->arena());
  builder->Reserve(args.size());

  for (size_t i = 0; i < args.size(); ++i) {
    const auto& arg = args[i];
    if (optional_indices_.contains(static_cast(i))) {
      // Optional element: empty optionals are skipped and present ones are
      // unwrapped. A non-optional value here is a type conversion error.
      if (auto optional_arg = arg.AsOptional(); optional_arg) {
        if (!optional_arg->HasValue()) {
          continue;
        }
        Value optional_arg_value;
        optional_arg->Value(&optional_arg_value);
        if (optional_arg_value.IsError()) {
          // Error should never be in optional, but better safe than sorry.
          *result = std::move(optional_arg_value);
          return absl::OkStatus();
        }
        CEL_RETURN_IF_ERROR(builder->Add(std::move(optional_arg_value)));
      } else {
        *result = cel::TypeConversionError(arg.GetTypeName(), "optional_type");
        return absl::OkStatus();
      }
    } else {
      CEL_RETURN_IF_ERROR(builder->Add(arg));
    }
  }

  *result = std::move(*builder).Build();

  return absl::OkStatus();
}

// Returns the positions of the list elements whose `optional()` flag is set.
absl::flat_hash_set MakeOptionalIndicesSet(
    const cel::ListExpr& create_list_expr) {
  absl::flat_hash_set optional_indices;
  for (size_t i = 0; i < create_list_expr.elements().size(); ++i) {
    if (create_list_expr.elements()[i].optional()) {
      optional_indices.insert(static_cast(i));
    }
  }
  return optional_indices;
}

// Recursive (direct) list construction step. It evaluates each element
// sub-step in order. The first CEL error is returned immediately through
// `result`. Unknowns are accumulated across all elements and returned merged
// if any were found; otherwise the built list is returned. A non-OK status is
// returned only when an element step or the builder itself fails.
class CreateListDirectStep : public DirectExpressionStep {
 public:
  CreateListDirectStep(
      std::vector> elements,
      absl::flat_hash_set optional_indices, int64_t expr_id)
      : DirectExpressionStep(expr_id),
        elements_(std::move(elements)),
        optional_indices_(std::move(optional_indices)) {}

  absl::Status Evaluate(ExecutionFrameBase& frame, Value& result,
                        AttributeTrail& attribute_trail) const override {
    ListValueBuilderPtr builder = NewListValueBuilder(frame.arena());
    builder->Reserve(elements_.size());

    AttributeUtility::Accumulator unknowns =
        frame.attribute_utility().CreateAccumulator();
    // NOTE(review): tmp_attr is reused across iterations without being reset.
    // This presumably relies on each element step overwriting it; confirm for
    // steps that do not track attributes.
    AttributeTrail tmp_attr;

    for (size_t i = 0; i < elements_.size(); ++i) {
      const auto& element = elements_[i];
      CEL_RETURN_IF_ERROR(element->Evaluate(frame, result, tmp_attr));

      // An error wins even if earlier elements produced unknowns, matching
      // the stack-based step, which checks errors before unknowns.
      if (result.IsError()) {
        return absl::OkStatus();
      }

      if (frame.attribute_tracking_enabled()) {
        if (frame.missing_attribute_errors_enabled()) {
          if (frame.attribute_utility().CheckForMissingAttribute(tmp_attr)) {
            CEL_ASSIGN_OR_RETURN(
                result, frame.attribute_utility().CreateMissingAttributeError(
                            tmp_attr.attribute()));
            return absl::OkStatus();
          }
        }
        if (frame.unknown_processing_enabled()) {
          if (result.IsUnknown()) {
            unknowns.Add(result.GetUnknown());
          }
          if (frame.attribute_utility().CheckForUnknown(tmp_attr,
                                                        /*use_partial=*/true)) {
            unknowns.Add(tmp_attr);
          }
        }
      }

      if (!unknowns.IsEmpty()) {
        // We found an unknown, there is no point in attempting to create a
        // list. Instead iterate through the remaining elements and look for
        // more unknowns.
        continue;
      }

      // Conditionally add if optional.
      if (optional_indices_.contains(static_cast(i))) {
        if (auto optional_arg = result.AsOptional(); optional_arg) {
          if (!optional_arg->HasValue()) {
            continue;
          }
          Value optional_arg_value;
          optional_arg->Value(&optional_arg_value);
          if (optional_arg_value.IsError()) {
            // Error should never be in optional, but better safe than sorry.
            result = std::move(optional_arg_value);
            return absl::OkStatus();
          }
          CEL_RETURN_IF_ERROR(builder->Add(std::move(optional_arg_value)));
          continue;
        }
        result = cel::TypeConversionError(result.GetTypeName(), "optional_type");
        return absl::OkStatus();
      }

      // Otherwise just add.
      CEL_RETURN_IF_ERROR(builder->Add(std::move(result)));
    }

    if (!unknowns.IsEmpty()) {
      result = std::move(unknowns).Build();
      return absl::OkStatus();
    }
    result = std::move(*builder).Build();

    return absl::OkStatus();
  }

 private:
  std::vector> elements_;
  absl::flat_hash_set optional_indices_;
};

// Stack-based step that pushes a new mutable list value, created by
// NewMutableListValue on the frame's arena.
class MutableListStep : public ExpressionStepBase {
 public:
  explicit MutableListStep(int64_t expr_id) : ExpressionStepBase(expr_id) {}

  absl::Status Evaluate(ExecutionFrame* frame) const override;
};

absl::Status MutableListStep::Evaluate(ExecutionFrame* frame) const {
  frame->value_stack().Push(cel::CustomListValue(
      cel::common_internal::NewMutableListValue(frame->arena()),
      frame->arena()));
  return absl::OkStatus();
}

// Recursive counterpart of MutableListStep: writes a new mutable list value
// into `result`.
class DirectMutableListStep : public DirectExpressionStep {
 public:
  explicit DirectMutableListStep(int64_t expr_id)
      : DirectExpressionStep(expr_id) {}

  absl::Status Evaluate(ExecutionFrameBase& frame, Value& result,
                        AttributeTrail& attribute) const override;
};

absl::Status DirectMutableListStep::Evaluate(
    ExecutionFrameBase& frame, Value& result,
    AttributeTrail& attribute_trail) const {
  // The attribute trail is left untouched.
  result = cel::CustomListValue(
      cel::common_internal::NewMutableListValue(frame.arena()), frame.arena());
  return absl::OkStatus();
}

}  // namespace

// Wraps the already-planned element steps in a CreateListDirectStep.
std::unique_ptr CreateDirectListStep(
    std::vector> deps,
    absl::flat_hash_set optional_indices, int64_t expr_id) {
  return std::make_unique(
      std::move(deps), std::move(optional_indices), expr_id);
}

// The list size is the ListExpr's element count (narrowed to the `int`
// constructor parameter); optional positions come from the element flags.
absl::StatusOr> CreateCreateListStep(
    const cel::ListExpr& create_list_expr, int64_t expr_id) {
  return std::make_unique(
      expr_id, create_list_expr.elements().size(),
      MakeOptionalIndicesSet(create_list_expr));
}

std::unique_ptr CreateMutableListStep(int64_t expr_id) {
  return std::make_unique(expr_id);
}

std::unique_ptr CreateDirectMutableListStep(
    int64_t expr_id) {
  return std::make_unique(expr_id);
}

}  // namespace google::api::expr::runtime
================================================ FILE: eval/eval/create_list_step.h ================================================
#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_LIST_STEP_H_
#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_LIST_STEP_H_

// NOTE(review): the standard-library include targets below are missing in
// this copy; restore them before building.
#include
#include
#include

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "common/expr.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"

namespace google::api::expr::runtime {

// Factory method for CreateList that evaluates recursively.
std::unique_ptr CreateDirectListStep(
    std::vector> deps,
    absl::flat_hash_set optional_indices, int64_t expr_id);

// Factory method for CreateList which constructs an immutable list.
absl::StatusOr> CreateCreateListStep(
    const cel::ListExpr& create_list_expr, int64_t expr_id);

// Factory method for CreateList which constructs a mutable list.
//
// This is intended for the list construction step generated for a
// list-building comprehension (rather than a user-authored expression).
std::unique_ptr CreateMutableListStep(int64_t expr_id);

// Factory method for CreateList which constructs a mutable list.
//
// This is intended for the list construction step generated for a
// list-building comprehension (rather than a user-authored expression).
std::unique_ptr CreateDirectMutableListStep(
    int64_t expr_id);

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_LIST_STEP_H_
================================================ FILE: eval/eval/create_list_step_test.cc ================================================
#include "eval/eval/create_list_step.h"

// NOTE(review): the standard-library include targets below are missing in
// this copy; restore them before building.
#include
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "base/attribute.h"
#include "base/attribute_set.h"
#include "base/type_provider.h"
#include "common/casting.h"
#include "common/expr.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/cel_expression_flat_impl.h"
#include "eval/eval/const_value_step.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/ident_step.h"
#include "eval/internal/interop.h"
#include "eval/public/activation.h"
#include "eval/public/cel_attribute.h"
#include "eval/public/cel_value.h"
#include "eval/public/testing/matchers.h"
#include "eval/public/unknown_attribute_set.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "internal/testing_message_factory.h"
#include "runtime/activation.h"
#include "runtime/internal/runtime_env.h"
#include "runtime/internal/runtime_env_testing.h"
#include "runtime/internal/runtime_type_provider.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"

namespace google::api::expr::runtime {

namespace {

using ::absl_testing::IsOk;
using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::cel::Attribute;
using ::cel::AttributeQualifier;
using ::cel::AttributeSet;
using ::cel::Cast;
using ::cel::ErrorValue;
using ::cel::Expr;
using ::cel::InstanceOf;
using ::cel::IntValue;
using ::cel::ListValue;
using ::cel::TypeProvider;
using ::cel::UnknownValue;
using ::cel::Value;
using ::cel::runtime_internal::NewTestingRuntimeEnv;
using ::cel::runtime_internal::RuntimeEnv;
using ::cel::test::IntValueIs;
using ::testing::Eq;
using ::testing::HasSubstr;
using ::testing::Not;
using ::testing::UnorderedElementsAre;

// Helper method. Builds a pipeline with one constant step per entry in
// `values`, followed by a CreateList step over them, then evaluates it. When
// `enable_unknowns` is true, unknown processing is set to kAttributeOnly.
absl::StatusOr RunExpression(
    const absl_nonnull std::shared_ptr& env,
    const std::vector& values, google::protobuf::Arena* arena,
    bool enable_unknowns) {
  ExecutionPath path;
  Expr dummy_expr;

  // The ListExpr is only consulted by CreateCreateListStep for its element
  // count and optional flags; the element values come from the const steps.
  auto& create_list = dummy_expr.mutable_list_expr();
  for (auto value : values) {
    auto& expr0 = create_list.mutable_elements().emplace_back().mutable_expr();
    expr0.mutable_const_expr().set_int64_value(value);
    CEL_ASSIGN_OR_RETURN(
        auto const_step,
        CreateConstValueStep(cel::interop_internal::CreateIntValue(value),
                             /*expr_id=*/-1));
    path.push_back(std::move(const_step));
  }

  CEL_ASSIGN_OR_RETURN(auto step,
                       CreateCreateListStep(create_list, dummy_expr.id()));

  path.push_back(std::move(step));
  cel::RuntimeOptions options;
  if (enable_unknowns) {
    options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly;
  }
  CelExpressionFlatImpl cel_expr(
      env,
      FlatExpression(std::move(path), /*comprehension_slot_count=*/0,
                     env->type_registry.GetComposedTypeProvider(), options));
  Activation activation;

  return cel_expr.Evaluate(activation, arena);
}

// Helper method. Binds each entry of `values` to its own activation variable
// (name_0, name_1, ...), builds a pipeline of ident steps followed by a
// CreateList step, then evaluates it.
absl::StatusOr RunExpressionWithCelValues( const absl_nonnull std::shared_ptr& env, const std::vector& values, google::protobuf::Arena* arena, bool enable_unknowns) { ExecutionPath path; Expr dummy_expr; Activation activation; auto& create_list = dummy_expr.mutable_list_expr(); int ind = 0; for (auto value : values) { std::string var_name = absl::StrCat("name_", ind++); auto& expr0 = create_list.mutable_elements().emplace_back().mutable_expr(); expr0.set_id(ind); expr0.mutable_ident_expr().set_name(var_name); CEL_ASSIGN_OR_RETURN(auto ident_step, CreateIdentStep(var_name, /*expr_id=*/-1)); path.push_back(std::move(ident_step)); activation.InsertValue(var_name, value); } CEL_ASSIGN_OR_RETURN(auto step0, CreateCreateListStep(create_list, dummy_expr.id())); path.push_back(std::move(step0)); cel::RuntimeOptions options; if (enable_unknowns) { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } CelExpressionFlatImpl cel_expr( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), options)); return cel_expr.Evaluate(activation, arena); } class CreateListStepTest : public testing::TestWithParam { public: CreateListStepTest() : env_(NewTestingRuntimeEnv()) {} protected: absl_nonnull std::shared_ptr env_; google::protobuf::Arena arena_; }; // Tests error when not enough list elements are on the stack during list // creation. 
TEST(CreateListStepTest, TestCreateListStackUnderflow) { ExecutionPath path; Expr dummy_expr; auto& create_list = dummy_expr.mutable_list_expr(); auto& expr0 = create_list.mutable_elements().emplace_back().mutable_expr(); expr0.mutable_const_expr().set_int64_value(1); ASSERT_OK_AND_ASSIGN(auto step0, CreateCreateListStep(create_list, dummy_expr.id())); path.push_back(std::move(step0)); auto env = NewTestingRuntimeEnv(); CelExpressionFlatImpl cel_expr( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), cel::RuntimeOptions{})); Activation activation; google::protobuf::Arena arena; auto status = cel_expr.Evaluate(activation, &arena); ASSERT_THAT(status, Not(IsOk())); } TEST_P(CreateListStepTest, CreateListEmpty) { ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(env_, {}, &arena_, GetParam())); ASSERT_TRUE(result.IsList()); EXPECT_THAT(result.ListOrDie()->size(), Eq(0)); } TEST_P(CreateListStepTest, CreateListOne) { ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(env_, {100}, &arena_, GetParam())); ASSERT_TRUE(result.IsList()); const auto& list = *result.ListOrDie(); ASSERT_THAT(list.size(), Eq(1)); const CelValue& value = list.Get(&arena_, 0); EXPECT_THAT(value, test::IsCelInt64(100)); } TEST_P(CreateListStepTest, CreateListWithError) { std::vector values; CelError error = absl::InvalidArgumentError("bad arg"); values.push_back(CelValue::CreateError(&error)); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpressionWithCelValues( env_, values, &arena_, GetParam())); ASSERT_TRUE(result.IsError()); EXPECT_THAT(*result.ErrorOrDie(), Eq(absl::InvalidArgumentError("bad arg"))); } TEST_P(CreateListStepTest, CreateListWithErrorAndUnknown) { // list composition is: {unknown, error} std::vector values; Expr expr0; expr0.mutable_ident_expr().set_name("name0"); CelAttribute attr0(expr0.ident_expr().name(), {}); UnknownSet unknown_set0(UnknownAttributeSet({attr0})); 
values.push_back(CelValue::CreateUnknownSet(&unknown_set0)); CelError error = absl::InvalidArgumentError("bad arg"); values.push_back(CelValue::CreateError(&error)); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpressionWithCelValues( env_, values, &arena_, GetParam())); // The bad arg should win. ASSERT_TRUE(result.IsError()); EXPECT_THAT(*result.ErrorOrDie(), Eq(absl::InvalidArgumentError("bad arg"))); } TEST_P(CreateListStepTest, CreateListHundred) { std::vector values; for (size_t i = 0; i < 100; i++) { values.push_back(i); } ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(env_, values, &arena_, GetParam())); ASSERT_TRUE(result.IsList()); const auto& list = *result.ListOrDie(); EXPECT_THAT(list.size(), Eq(static_cast(values.size()))); for (size_t i = 0; i < values.size(); i++) { EXPECT_THAT(list.Get(&arena_, i), test::IsCelInt64(values[i])); } } INSTANTIATE_TEST_SUITE_P(CombinedCreateListTest, CreateListStepTest, testing::Bool()); TEST(CreateListStepTest, CreateListHundredAnd2Unknowns) { google::protobuf::Arena arena; std::vector values; Expr expr0; expr0.mutable_ident_expr().set_name("name0"); CelAttribute attr0(expr0.ident_expr().name(), {}); Expr expr1; expr1.mutable_ident_expr().set_name("name1"); CelAttribute attr1(expr1.ident_expr().name(), {}); UnknownSet unknown_set0(UnknownAttributeSet({attr0})); UnknownSet unknown_set1(UnknownAttributeSet({attr1})); for (size_t i = 0; i < 100; i++) { values.push_back(CelValue::CreateInt64(i)); } values.push_back(CelValue::CreateUnknownSet(&unknown_set0)); values.push_back(CelValue::CreateUnknownSet(&unknown_set1)); ASSERT_OK_AND_ASSIGN( CelValue result, RunExpressionWithCelValues(NewTestingRuntimeEnv(), values, &arena, true)); ASSERT_TRUE(result.IsUnknownSet()); const UnknownSet* result_set = result.UnknownSetOrDie(); EXPECT_THAT(result_set->unknown_attributes().size(), Eq(2)); } TEST(CreateDirectListStep, Basic) { google::protobuf::Arena arena; cel::runtime_internal::RuntimeTypeProvider type_provider( 
cel::internal::GetTestingDescriptorPool()); cel::Activation activation; cel::RuntimeOptions options; ExecutionFrameBase frame(activation, options, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); std::vector> deps; deps.push_back(CreateConstValueDirectStep(IntValue(1), -1)); deps.push_back(CreateConstValueDirectStep(IntValue(2), -1)); auto step = CreateDirectListStep(std::move(deps), {}, -1); cel::Value result; AttributeTrail attr; ASSERT_OK(step->Evaluate(frame, result, attr)); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(Cast(result).Size(), IsOkAndHolds(2)); } TEST(CreateDirectListStep, ForwardFirstError) { google::protobuf::Arena arena; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); cel::Activation activation; cel::RuntimeOptions options; ExecutionFrameBase frame(activation, options, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); std::vector> deps; deps.push_back(CreateConstValueDirectStep( cel::ErrorValue(absl::InternalError("test1")), -1)); deps.push_back(CreateConstValueDirectStep( cel::ErrorValue(absl::InternalError("test2")), -1)); auto step = CreateDirectListStep(std::move(deps), {}, -1); cel::Value result; AttributeTrail attr; ASSERT_OK(step->Evaluate(frame, result, attr)); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(Cast(result).NativeValue(), StatusIs(absl::StatusCode::kInternal, "test1")); } std::vector UnknownAttrNames(const UnknownValue& v) { std::vector names; names.reserve(v.attribute_set().size()); for (const auto& attr : v.attribute_set()) { EXPECT_OK(attr.AsString().status()); names.push_back(attr.AsString().value_or("")); } return names; } TEST(CreateDirectListStep, MergeUnknowns) { google::protobuf::Arena arena; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); cel::Activation activation; cel::RuntimeOptions 
options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; ExecutionFrameBase frame(activation, options, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); AttributeSet attr_set1({Attribute("var1")}); AttributeSet attr_set2({Attribute("var2")}); std::vector> deps; deps.push_back(CreateConstValueDirectStep( cel::UnknownValue(cel::Unknown(std::move(attr_set1))), -1)); deps.push_back(CreateConstValueDirectStep( cel::UnknownValue(cel::Unknown(std::move(attr_set2))), -1)); auto step = CreateDirectListStep(std::move(deps), {}, -1); cel::Value result; AttributeTrail attr; ASSERT_OK(step->Evaluate(frame, result, attr)); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(UnknownAttrNames(Cast(result)), UnorderedElementsAre("var1", "var2")); } TEST(CreateDirectListStep, ErrorBeforeUnknown) { google::protobuf::Arena arena; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); cel::Activation activation; cel::RuntimeOptions options; ExecutionFrameBase frame(activation, options, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); AttributeSet attr_set1({Attribute("var1")}); std::vector> deps; deps.push_back(CreateConstValueDirectStep( cel::ErrorValue(absl::InternalError("test1")), -1)); deps.push_back(CreateConstValueDirectStep( cel::ErrorValue(absl::InternalError("test2")), -1)); auto step = CreateDirectListStep(std::move(deps), {}, -1); cel::Value result; AttributeTrail attr; ASSERT_OK(step->Evaluate(frame, result, attr)); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(Cast(result).NativeValue(), StatusIs(absl::StatusCode::kInternal, "test1")); } class SetAttrDirectStep : public DirectExpressionStep { public: explicit SetAttrDirectStep(Attribute attr) : DirectExpressionStep(-1), attr_(std::move(attr)) {} absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attr) const 
override { result = cel::NullValue(); attr = AttributeTrail(attr_); return absl::OkStatus(); } private: cel::Attribute attr_; }; TEST(CreateDirectListStep, MissingAttribute) { google::protobuf::Arena arena; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); cel::Activation activation; cel::RuntimeOptions options; options.enable_missing_attribute_errors = true; activation.SetMissingPatterns({cel::AttributePattern( "var1", {cel::AttributeQualifierPattern::OfString("field1")})}); ExecutionFrameBase frame(activation, options, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); std::vector> deps; deps.push_back(CreateConstValueDirectStep(cel::NullValue(), -1)); deps.push_back(std::make_unique( Attribute("var1", {AttributeQualifier::OfString("field1")}))); auto step = CreateDirectListStep(std::move(deps), {}, -1); cel::Value result; AttributeTrail attr; ASSERT_OK(step->Evaluate(frame, result, attr)); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT( Cast(result).NativeValue(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("var1.field1"))); } TEST(CreateDirectListStep, OptionalPresentSet) { google::protobuf::Arena arena; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); cel::Activation activation; cel::RuntimeOptions options; ExecutionFrameBase frame(activation, options, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); std::vector> deps; deps.push_back(CreateConstValueDirectStep(IntValue(1), -1)); deps.push_back(CreateConstValueDirectStep( cel::OptionalValue::Of(IntValue(2), &arena), -1)); auto step = CreateDirectListStep(std::move(deps), {1}, -1); cel::Value result; AttributeTrail attr; ASSERT_OK(step->Evaluate(frame, result, attr)); ASSERT_TRUE(InstanceOf(result)); auto list = Cast(result); EXPECT_THAT(list.Size(), IsOkAndHolds(2)); 
EXPECT_THAT(list.Get(0, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena), IsOkAndHolds(IntValueIs(1))); EXPECT_THAT(list.Get(1, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena), IsOkAndHolds(IntValueIs(2))); } TEST(CreateDirectListStep, OptionalAbsentNotSet) { google::protobuf::Arena arena; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); cel::Activation activation; cel::RuntimeOptions options; ExecutionFrameBase frame(activation, options, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); std::vector> deps; deps.push_back(CreateConstValueDirectStep(IntValue(1), -1)); deps.push_back(CreateConstValueDirectStep(cel::OptionalValue::None(), -1)); auto step = CreateDirectListStep(std::move(deps), {1}, -1); cel::Value result; AttributeTrail attr; ASSERT_OK(step->Evaluate(frame, result, attr)); ASSERT_TRUE(InstanceOf(result)); auto list = Cast(result); EXPECT_THAT(list.Size(), IsOkAndHolds(1)); EXPECT_THAT(list.Get(0, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena), IsOkAndHolds(IntValueIs(1))); } TEST(CreateDirectListStep, PartialUnknown) { google::protobuf::Arena arena; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); cel::Activation activation; cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; activation.SetUnknownPatterns({cel::AttributePattern( "var1", {cel::AttributeQualifierPattern::OfString("field1")})}); ExecutionFrameBase frame(activation, options, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); std::vector> deps; deps.push_back(CreateConstValueDirectStep(cel::IntValue(1), -1)); deps.push_back(std::make_unique(Attribute("var1", {}))); auto step = 
CreateDirectListStep(std::move(deps), {}, -1); cel::Value result; AttributeTrail attr; ASSERT_OK(step->Evaluate(frame, result, attr)); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(UnknownAttrNames(Cast(result)), UnorderedElementsAre("var1")); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/create_map_step.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/eval/create_map_step.h" #include #include #include #include #include #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/casting.h" #include "common/value.h" #include "common/values/map_value_builder.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "internal/status_macros.h" namespace google::api::expr::runtime { namespace { using ::cel::Cast; using ::cel::ErrorValue; using ::cel::ErrorValueAssign; using ::cel::ErrorValueReturn; using ::cel::InstanceOf; using ::cel::MapValueBuilderPtr; using ::cel::UnknownValue; using ::cel::Value; using ::cel::common_internal::NewMapValueBuilder; using ::cel::common_internal::NewMutableMapValue; // `CreateStruct` implementation for map. 
class CreateStructStepForMap final : public ExpressionStepBase { public: CreateStructStepForMap(int64_t expr_id, size_t entry_count, absl::flat_hash_set optional_indices) : ExpressionStepBase(expr_id), entry_count_(entry_count), optional_indices_(std::move(optional_indices)) {} absl::Status Evaluate(ExecutionFrame* frame) const override; private: absl::StatusOr DoEvaluate(ExecutionFrame* frame) const; size_t entry_count_; absl::flat_hash_set optional_indices_; }; absl::StatusOr CreateStructStepForMap::DoEvaluate( ExecutionFrame* frame) const { auto args = frame->value_stack().GetSpan(2 * entry_count_); for (const auto& arg : args) { if (arg.IsError()) { return arg; } } if (frame->enable_unknowns()) { absl::optional unknown_set = frame->attribute_utility().IdentifyAndMergeUnknowns( args, frame->value_stack().GetAttributeSpan(args.size()), true); if (unknown_set.has_value()) { return *unknown_set; } } MapValueBuilderPtr builder = NewMapValueBuilder(frame->arena()); builder->Reserve(entry_count_); for (size_t i = 0; i < entry_count_; i += 1) { const auto& map_key = args[2 * i]; CEL_RETURN_IF_ERROR(cel::CheckMapKey(map_key)).With(ErrorValueReturn()); const auto& map_value = args[(2 * i) + 1]; if (optional_indices_.contains(static_cast(i))) { if (auto optional_map_value = map_value.AsOptional(); optional_map_value) { if (!optional_map_value->HasValue()) { continue; } Value optional_map_value_value; optional_map_value->Value(&optional_map_value_value); if (optional_map_value_value.IsError()) { // Error should never be in optional, but better safe than sorry. 
return optional_map_value_value; } CEL_RETURN_IF_ERROR( builder->Put(map_key, std::move(optional_map_value_value))); } else { return cel::TypeConversionError(map_value.DebugString(), "optional_type"); } } else { CEL_RETURN_IF_ERROR(builder->Put(map_key, map_value)); } } return std::move(*builder).Build(); } absl::Status CreateStructStepForMap::Evaluate(ExecutionFrame* frame) const { if (frame->value_stack().size() < 2 * entry_count_) { return absl::InternalError("CreateStructStepForMap: stack underflow"); } CEL_ASSIGN_OR_RETURN(auto result, DoEvaluate(frame)); frame->value_stack().PopAndPush(2 * entry_count_, std::move(result)); return absl::OkStatus(); } class DirectCreateMapStep : public DirectExpressionStep { public: DirectCreateMapStep(std::vector> deps, absl::flat_hash_set optional_indices, int64_t expr_id) : DirectExpressionStep(expr_id), deps_(std::move(deps)), optional_indices_(std::move(optional_indices)), entry_count_(deps_.size() / 2) {} absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute_trail) const override; private: std::vector> deps_; absl::flat_hash_set optional_indices_; size_t entry_count_; }; absl::Status DirectCreateMapStep::Evaluate( ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute_trail) const { auto unknowns = frame.attribute_utility().CreateAccumulator(); MapValueBuilderPtr builder = NewMapValueBuilder(frame.arena()); builder->Reserve(entry_count_); for (size_t i = 0; i < entry_count_; i += 1) { Value key; Value value; AttributeTrail tmp_attr; int map_key_index = 2 * i; int map_value_index = map_key_index + 1; CEL_RETURN_IF_ERROR(deps_[map_key_index]->Evaluate(frame, key, tmp_attr)); if (key.IsError()) { result = std::move(key); return absl::OkStatus(); } if (frame.unknown_processing_enabled()) { if (key.IsUnknown()) { unknowns.Add(key.GetUnknown()); } else if (frame.attribute_utility().CheckForUnknownPartial(tmp_attr)) { unknowns.Add(tmp_attr); } } 
CEL_RETURN_IF_ERROR(cel::CheckMapKey(key)).With(ErrorValueAssign(result)); CEL_RETURN_IF_ERROR( deps_[map_value_index]->Evaluate(frame, value, tmp_attr)); if (value.IsError()) { result = std::move(value); return absl::OkStatus(); } if (frame.unknown_processing_enabled()) { if (value.IsUnknown()) { unknowns.Add(value.GetUnknown()); } else if (frame.attribute_utility().CheckForUnknownPartial(tmp_attr)) { unknowns.Add(tmp_attr); } } // Preserve the stack machine behavior of forwarding unknowns before // errors. if (!unknowns.IsEmpty()) { continue; } if (optional_indices_.contains(static_cast(i))) { if (auto optional_map_value = value.AsOptional(); optional_map_value) { if (!optional_map_value->HasValue()) { continue; } Value optional_map_value_value; optional_map_value->Value(&optional_map_value_value); if (optional_map_value_value.IsError()) { // Error should never be in optional, but better safe than sorry. result = optional_map_value_value; return absl::OkStatus(); } CEL_RETURN_IF_ERROR( builder->Put(std::move(key), std::move(optional_map_value_value))); continue; } result = cel::TypeConversionError(value.DebugString(), "optional_type"); return absl::OkStatus(); } CEL_RETURN_IF_ERROR(builder->Put(std::move(key), std::move(value))); } if (!unknowns.IsEmpty()) { result = std::move(unknowns).Build(); return absl::OkStatus(); } result = std::move(*builder).Build(); return absl::OkStatus(); } class MutableMapStep final : public ExpressionStep { public: explicit MutableMapStep(int64_t expr_id) : ExpressionStep(expr_id) {} absl::Status Evaluate(ExecutionFrame* frame) const override { frame->value_stack().Push(cel::CustomMapValue( NewMutableMapValue(frame->arena()), frame->arena())); return absl::OkStatus(); } }; class DirectMutableMapStep final : public DirectExpressionStep { public: explicit DirectMutableMapStep(int64_t expr_id) : DirectExpressionStep(expr_id) {} absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const override { 
result = cel::CustomMapValue(NewMutableMapValue(frame.arena()), frame.arena()); return absl::OkStatus(); } }; } // namespace std::unique_ptr CreateDirectCreateMapStep( std::vector> deps, absl::flat_hash_set optional_indices, int64_t expr_id) { return std::make_unique( std::move(deps), std::move(optional_indices), expr_id); } absl::StatusOr> CreateCreateStructStepForMap( size_t entry_count, absl::flat_hash_set optional_indices, int64_t expr_id) { // Make map-creating step. return std::make_unique(expr_id, entry_count, std::move(optional_indices)); } absl::StatusOr> CreateMutableMapStep( int64_t expr_id) { return std::make_unique(expr_id); } std::unique_ptr CreateDirectMutableMapStep( int64_t expr_id) { return std::make_unique(expr_id); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/create_map_step.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_MAP_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_MAP_STEP_H_ #include #include #include #include #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { // Creates an expression step that evaluates a create map expression. 
// // Deps must have an even number of elements, that alternate key, value pairs. // (key1, value1, key2, value2...). std::unique_ptr CreateDirectCreateMapStep( std::vector> deps, absl::flat_hash_set optional_indices, int64_t expr_id); // Creates an `ExpressionStep` which performs `CreateStruct` for a map. absl::StatusOr> CreateCreateStructStepForMap( size_t entry_count, absl::flat_hash_set optional_indices, int64_t expr_id); // Factory method for CreateMap which constructs a mutable map. // // This is intended for the map construction step is generated for a // map-building comprehension (rather than a user authored expression). absl::StatusOr> CreateMutableMapStep( int64_t expr_id); // Factory method for CreateMap which constructs a mutable map. // // This is intended for the map construction step is generated for a // map-building comprehension (rather than a user authored expression). std::unique_ptr CreateDirectMutableMapStep( int64_t expr_id); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_MAP_STEP_H_ ================================================ FILE: eval/eval/create_map_step_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/eval/create_map_step.h"

#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "cel/expr/syntax.pb.h"
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "base/type_provider.h"
#include "common/expr.h"
#include "eval/eval/cel_expression_flat_impl.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/ident_step.h"
#include "eval/public/activation.h"
#include "eval/public/cel_value.h"
#include "eval/public/unknown_set.h"
#include "eval/testutil/test_message.pb.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "runtime/internal/runtime_env.h"
#include "runtime/internal/runtime_env_testing.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"

namespace google::api::expr::runtime {

namespace {

using ::absl_testing::StatusIs;
using ::cel::Expr;
using ::cel::TypeProvider;
using ::cel::runtime_internal::NewTestingRuntimeEnv;
using ::cel::runtime_internal::RuntimeEnv;
using ::google::protobuf::Arena;

// Builds a stack-machine program for {key0: value0, key1: value1, ...}.
// It emits an ident step per key and per value, then one
// CreateStructStepForMap. The i-th pair is bound in `activation` as
// "key<i>" and "value<i>".
absl::StatusOr<ExecutionPath> CreateStackMachineProgram(
    const std::vector<std::pair<CelValue, CelValue>>& values,
    Activation& activation) {
  ExecutionPath path;

  Expr expr1;

  int index = 0;
  auto& create_struct = expr1.mutable_struct_expr();
  for (const auto& item : values) {
    std::string key_name = absl::StrCat("key", index);
    std::string value_name = absl::StrCat("value", index);

    CEL_ASSIGN_OR_RETURN(auto step_key,
                         CreateIdentStep(key_name, /*expr_id=*/-1));

    CEL_ASSIGN_OR_RETURN(auto step_value,
                         CreateIdentStep(value_name, /*expr_id=*/-1));

    path.push_back(std::move(step_key));
    path.push_back(std::move(step_value));

    activation.InsertValue(key_name, item.first);
    activation.InsertValue(value_name, item.second);

    create_struct.mutable_fields().emplace_back();
    index++;
  }

  CEL_ASSIGN_OR_RETURN(
      auto step1, CreateCreateStructStepForMap(values.size(), {}, expr1.id()));
  path.push_back(std::move(step1));
  return path;
}

// Builds the recursive (direct-step) equivalent of CreateStackMachineProgram.
// The single map step is wrapped so it can run in an ExecutionPath.
absl::StatusOr<ExecutionPath> CreateRecursiveProgram(
    const std::vector<std::pair<CelValue, CelValue>>& values,
    Activation& activation) {
  ExecutionPath path;

  int index = 0;
  std::vector<std::unique_ptr<DirectExpressionStep>> deps;
  for (const auto& item : values) {
    std::string key_name = absl::StrCat("key", index);
    std::string value_name = absl::StrCat("value", index);

    deps.push_back(CreateDirectIdentStep(key_name, -1));

    deps.push_back(CreateDirectIdentStep(value_name, -1));

    activation.InsertValue(key_name, item.first);
    activation.InsertValue(value_name, item.second);

    index++;
  }
  path.push_back(std::make_unique<WrappedDirectStep>(
      CreateDirectCreateMapStep(std::move(deps), {}, -1), -1));

  return path;
}

// Helper method. Builds a simple program that creates a map with a
// CreateStruct step and runs it.
// Equivalent to {key0: value0, ...}
absl::StatusOr<CelValue> RunCreateMapExpression(
    const absl_nonnull std::shared_ptr<const RuntimeEnv>& env,
    const std::vector<std::pair<CelValue, CelValue>>& values,
    google::protobuf::Arena* arena, bool enable_unknowns,
    bool enable_recursive_program) {
  Activation activation;

  ExecutionPath path;
  if (enable_recursive_program) {
    CEL_ASSIGN_OR_RETURN(path, CreateRecursiveProgram(values, activation));
  } else {
    CEL_ASSIGN_OR_RETURN(path, CreateStackMachineProgram(values, activation));
  }
  cel::RuntimeOptions options;
  if (enable_unknowns) {
    options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly;
  }

  CelExpressionFlatImpl cel_expr(
      env,
      FlatExpression(std::move(path), /*comprehension_slot_count=*/0,
                     env->type_registry.GetComposedTypeProvider(), options));
  return cel_expr.Evaluate(activation, arena);
}

// Parameterized over (enable_unknowns, enable_recursive_program).
class CreateMapStepTest
    : public testing::TestWithParam<std::tuple<bool, bool>> {
 public:
  CreateMapStepTest() : env_(NewTestingRuntimeEnv()) {}

  bool enable_unknowns() { return std::get<0>(GetParam()); }
  bool enable_recursive_program() { return std::get<1>(GetParam()); }

  absl::StatusOr<CelValue> RunMapExpression(
      const std::vector<std::pair<CelValue, CelValue>>& values) {
    return RunCreateMapExpression(env_, values, &arena_, enable_unknowns(),
                                  enable_recursive_program());
  }

 protected:
  absl_nonnull std::shared_ptr<const RuntimeEnv> env_;
  google::protobuf::Arena arena_;
};

// Test that Empty Map is created successfully.
TEST_P(CreateMapStepTest, TestCreateEmptyMap) {
  ASSERT_OK_AND_ASSIGN(CelValue result, RunMapExpression({}));
  ASSERT_TRUE(result.IsMap());

  const CelMap* cel_map = result.MapOrDie();
  ASSERT_EQ(cel_map->size(), 0);
}

// Test map creation if unknown argument is passed
TEST(CreateMapStepTest, TestMapCreateWithUnknown) {
  absl_nonnull std::shared_ptr<const RuntimeEnv> env = NewTestingRuntimeEnv();
  Arena arena;
  UnknownSet unknown_set;
  std::vector<std::pair<CelValue, CelValue>> entries;

  std::vector<std::string> kKeys = {"test2", "test1"};

  entries.push_back(
      {CelValue::CreateString(&kKeys[0]), CelValue::CreateInt64(2)});
  entries.push_back({CelValue::CreateString(&kKeys[1]),
                     CelValue::CreateUnknownSet(&unknown_set)});

  ASSERT_OK_AND_ASSIGN(CelValue result, RunCreateMapExpression(
                                            env, entries, &arena, true, false));
  ASSERT_TRUE(result.IsUnknownSet());
}

TEST(CreateMapStepTest, TestMapCreateWithError) {
  absl_nonnull std::shared_ptr<const RuntimeEnv> env = NewTestingRuntimeEnv();
  Arena arena;
  UnknownSet unknown_set;
  absl::Status error = absl::CancelledError();
  std::vector<std::pair<CelValue, CelValue>> entries;
  entries.push_back({CelValue::CreateStringView("foo"),
                     CelValue::CreateUnknownSet(&unknown_set)});
  entries.push_back(
      {CelValue::CreateStringView("bar"), CelValue::CreateError(&error)});

  ASSERT_OK_AND_ASSIGN(CelValue result, RunCreateMapExpression(
                                            env, entries, &arena, true, false));
  ASSERT_TRUE(result.IsError());
  EXPECT_THAT(*result.ErrorOrDie(), StatusIs(absl::StatusCode::kCancelled));
}

TEST(CreateMapStepTest, TestMapCreateWithErrorRecursiveProgram) {
  absl_nonnull std::shared_ptr<const RuntimeEnv> env = NewTestingRuntimeEnv();
  Arena arena;
  UnknownSet unknown_set;
  absl::Status error = absl::CancelledError();
  std::vector<std::pair<CelValue, CelValue>> entries;
  entries.push_back({CelValue::CreateStringView("foo"),
                     CelValue::CreateUnknownSet(&unknown_set)});
  entries.push_back(
      {CelValue::CreateStringView("bar"), CelValue::CreateError(&error)});

  ASSERT_OK_AND_ASSIGN(CelValue result, RunCreateMapExpression(
                                            env, entries, &arena, true, true));
  ASSERT_TRUE(result.IsError());
  EXPECT_THAT(*result.ErrorOrDie(), StatusIs(absl::StatusCode::kCancelled));
}

TEST(CreateMapStepTest, TestMapCreateWithUnknownRecursiveProgram) {
  absl_nonnull std::shared_ptr<const RuntimeEnv> env = NewTestingRuntimeEnv();
  Arena arena;
  UnknownSet unknown_set;
  std::vector<std::pair<CelValue, CelValue>> entries;

  std::vector<std::string> kKeys = {"test2", "test1"};

  entries.push_back(
      {CelValue::CreateString(&kKeys[0]), CelValue::CreateInt64(2)});
  entries.push_back({CelValue::CreateString(&kKeys[1]),
                     CelValue::CreateUnknownSet(&unknown_set)});

  ASSERT_OK_AND_ASSIGN(CelValue result, RunCreateMapExpression(
                                            env, entries, &arena, true, true));
  ASSERT_TRUE(result.IsUnknownSet());
}

// Test that String Map is created successfully.
TEST_P(CreateMapStepTest, TestCreateStringMap) {
  Arena arena;

  std::vector<std::pair<CelValue, CelValue>> entries;

  std::vector<std::string> kKeys = {"test2", "test1"};

  entries.push_back(
      {CelValue::CreateString(&kKeys[0]), CelValue::CreateInt64(2)});
  entries.push_back(
      {CelValue::CreateString(&kKeys[1]), CelValue::CreateInt64(1)});

  ASSERT_OK_AND_ASSIGN(CelValue result, RunMapExpression(entries));
  ASSERT_TRUE(result.IsMap());

  const CelMap* cel_map = result.MapOrDie();
  ASSERT_EQ(cel_map->size(), 2);

  auto lookup0 = cel_map->Get(&arena, CelValue::CreateString(&kKeys[0]));
  ASSERT_TRUE(lookup0.has_value());
  ASSERT_TRUE(lookup0->IsInt64()) << lookup0->DebugString();
  EXPECT_EQ(lookup0->Int64OrDie(), 2);

  auto lookup1 = cel_map->Get(&arena, CelValue::CreateString(&kKeys[1]));
  ASSERT_TRUE(lookup1.has_value());
  ASSERT_TRUE(lookup1->IsInt64());
  EXPECT_EQ(lookup1->Int64OrDie(), 1);
}

INSTANTIATE_TEST_SUITE_P(CreateMapStep, CreateMapStepTest,
                         testing::Combine(testing::Bool(), testing::Bool()));

}  // namespace

}  // namespace google::api::expr::runtime
================================================
FILE: eval/eval/create_struct_step.cc
================================================
// Copyright 2017 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "eval/eval/create_struct_step.h"

#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/casting.h"
#include "common/value.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/expression_step_base.h"
#include "internal/status_macros.h"

namespace google::api::expr::runtime {

namespace {

using ::cel::Cast;
using ::cel::ErrorValue;
using ::cel::InstanceOf;
using ::cel::StructValueBuilderInterface;
using ::cel::UnknownValue;
using ::cel::Value;

// `CreateStruct` implementation for message/struct.
class CreateStructStepForStruct final : public ExpressionStepBase { public: CreateStructStepForStruct(int64_t expr_id, std::string name, std::vector entries, absl::flat_hash_set optional_indices) : ExpressionStepBase(expr_id), name_(std::move(name)), entries_(std::move(entries)), optional_indices_(std::move(optional_indices)) {} absl::Status Evaluate(ExecutionFrame* frame) const override; private: absl::StatusOr DoEvaluate(ExecutionFrame* frame) const; std::string name_; std::vector entries_; absl::flat_hash_set optional_indices_; }; absl::StatusOr CreateStructStepForStruct::DoEvaluate( ExecutionFrame* frame) const { int entries_size = entries_.size(); auto args = frame->value_stack().GetSpan(entries_size); for (const auto& arg : args) { if (arg.IsError()) { return arg; } } if (frame->enable_unknowns()) { absl::optional unknown_set = frame->attribute_utility().IdentifyAndMergeUnknowns( args, frame->value_stack().GetAttributeSpan(entries_size), /*use_partial=*/true); if (unknown_set.has_value()) { return *unknown_set; } } CEL_ASSIGN_OR_RETURN(auto builder, frame->type_provider().NewValueBuilder( name_, frame->message_factory(), frame->arena())); if (builder == nullptr) { return ErrorValue( absl::NotFoundError(absl::StrCat("Unable to find builder: ", name_))); } for (int i = 0; i < entries_size; ++i) { const auto& entry = entries_[i]; const auto& arg = args[i]; if (optional_indices_.contains(static_cast(i))) { if (auto optional_arg = arg.AsOptional(); optional_arg) { if (!optional_arg->HasValue()) { continue; } Value optional_arg_value; optional_arg->Value(&optional_arg_value); if (optional_arg_value.IsError()) { // Error should never be in optional, but better safe than sorry. 
return optional_arg_value; } CEL_ASSIGN_OR_RETURN( absl::optional error_value, builder->SetFieldByName(entry, std::move(optional_arg_value))); if (error_value) { return std::move(*error_value); } } else { return cel::TypeConversionError(arg.DebugString(), "optional_type"); } } else { CEL_ASSIGN_OR_RETURN(absl::optional error_value, builder->SetFieldByName(entry, arg)); if (error_value) { return std::move(*error_value); } } } return std::move(*builder).Build(); } absl::Status CreateStructStepForStruct::Evaluate(ExecutionFrame* frame) const { if (frame->value_stack().size() < entries_.size()) { return absl::InternalError("CreateStructStepForStruct: stack underflow"); } CEL_ASSIGN_OR_RETURN(Value result, DoEvaluate(frame)); frame->value_stack().PopAndPush(entries_.size(), std::move(result)); return absl::OkStatus(); } class DirectCreateStructStep : public DirectExpressionStep { public: DirectCreateStructStep( int64_t expr_id, std::string name, std::vector field_keys, std::vector> deps, absl::flat_hash_set optional_indices) : DirectExpressionStep(expr_id), name_(std::move(name)), field_keys_(std::move(field_keys)), deps_(std::move(deps)), optional_indices_(std::move(optional_indices)) {} absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& trail) const override; private: std::string name_; std::vector field_keys_; std::vector> deps_; absl::flat_hash_set optional_indices_; }; absl::Status DirectCreateStructStep::Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& trail) const { Value field_value; AttributeTrail field_attr; auto unknowns = frame.attribute_utility().CreateAccumulator(); CEL_ASSIGN_OR_RETURN(auto builder, frame.type_provider().NewValueBuilder( name_, frame.message_factory(), frame.arena())); if (builder == nullptr) { result = cel::ErrorValue( absl::NotFoundError(absl::StrCat("Unable to find builder: ", name_))); return absl::OkStatus(); } for (int i = 0; i < field_keys_.size(); i++) { 
CEL_RETURN_IF_ERROR(deps_[i]->Evaluate(frame, field_value, field_attr)); // TODO(uncreated-issue/67): if the value is an error, we should be able to return // early, however some client tests depend on the error message the struct // impl returns in the stack machine version. if (field_value.IsError()) { result = std::move(field_value); return absl::OkStatus(); } if (frame.unknown_processing_enabled()) { if (field_value.IsUnknown()) { unknowns.Add(field_value.GetUnknown()); } else if (frame.attribute_utility().CheckForUnknownPartial(field_attr)) { unknowns.Add(field_attr); } } if (!unknowns.IsEmpty()) { continue; } if (optional_indices_.contains(static_cast(i))) { if (auto optional_arg = field_value.AsOptional(); optional_arg) { if (!optional_arg->HasValue()) { continue; } Value optional_arg_value; optional_arg->Value(&optional_arg_value); if (optional_arg_value.IsError()) { // Error should never be in optional, but better safe than sorry. result = std::move(optional_arg_value); return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN( absl::optional error_value, builder->SetFieldByName(field_keys_[i], std::move(optional_arg_value))); if (error_value) { result = std::move(*error_value); return absl::OkStatus(); } continue; } else { result = cel::TypeConversionError(field_value.DebugString(), "optional_type"); return absl::OkStatus(); } } CEL_ASSIGN_OR_RETURN( absl::optional error_value, builder->SetFieldByName(field_keys_[i], std::move(field_value))); if (error_value) { result = std::move(*error_value); return absl::OkStatus(); } } if (!unknowns.IsEmpty()) { result = std::move(unknowns).Build(); return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN(result, std::move(*builder).Build()); return absl::OkStatus(); } } // namespace std::unique_ptr CreateDirectCreateStructStep( std::string resolved_name, std::vector field_keys, std::vector> deps, absl::flat_hash_set optional_indices, int64_t expr_id) { return std::make_unique( expr_id, std::move(resolved_name), std::move(field_keys), 
std::move(deps), std::move(optional_indices)); } std::unique_ptr CreateCreateStructStep( std::string name, std::vector field_keys, absl::flat_hash_set optional_indices, int64_t expr_id) { // MakeOptionalIndicesSet(create_struct_expr) return std::make_unique( expr_id, std::move(name), std::move(field_keys), std::move(optional_indices)); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/create_struct_step.h ================================================ // Copyright 2017 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_STRUCT_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_STRUCT_STEP_H_ #include #include #include #include #include "absl/container/flat_hash_set.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { // Creates an `ExpressionStep` which performs `CreateStruct` for a // message/struct. std::unique_ptr CreateDirectCreateStructStep( std::string name, std::vector field_keys, std::vector> deps, absl::flat_hash_set optional_indices, int64_t expr_id); // Creates an `ExpressionStep` which performs `CreateStruct` for a // message/struct. 
std::unique_ptr CreateCreateStructStep( std::string name, std::vector field_keys, absl::flat_hash_set optional_indices, int64_t expr_id); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_STRUCT_STEP_H_ ================================================ FILE: eval/eval/create_struct_step_test.cc ================================================ // Copyright 2017 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/eval/create_struct_step.h" #include #include #include #include #include #include #include "cel/expr/syntax.pb.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "base/type_provider.h" #include "common/expr.h" #include "eval/eval/cel_expression_flat_impl.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" #include "eval/public/activation.h" #include "eval/public/cel_type_registry.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/unknown_set.h" #include "eval/testutil/test_message.pb.h" #include "internal/proto_matchers.h" #include "internal/status_macros.h" #include "internal/testing.h" 
#include "runtime/internal/runtime_env.h"
#include "runtime/internal/runtime_env_testing.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {

namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::Expr;
using ::cel::TypeProvider;
using ::cel::internal::test::EqualsProto;
using ::cel::runtime_internal::NewTestingRuntimeEnv;
using ::cel::runtime_internal::RuntimeEnv;
using ::google::protobuf::Arena;
using ::google::protobuf::Message;
using ::testing::Eq;
using ::testing::IsNull;
using ::testing::Not;
using ::testing::Pointwise;

// Builds a stack-machine plan that pushes the `message` variable and then
// creates a TestMessage whose `field` is set from it.
absl::StatusOr<ExecutionPath> MakeStackMachinePath(absl::string_view field) {
  ExecutionPath path;

  CEL_ASSIGN_OR_RETURN(auto step0, CreateIdentStep("message", /*expr_id=*/-1));

  auto step1 = CreateCreateStructStep("google.api.expr.runtime.TestMessage",
                                      {std::string(field)},
                                      /*optional_indices=*/{},
                                      /*expr_id=*/-1);

  path.push_back(std::move(step0));
  path.push_back(std::move(step1));

  return path;
}

// Builds the recursive-planning equivalent of `MakeStackMachinePath`, wrapped
// in a `WrappedDirectStep` so it can run in a flat execution path.
absl::StatusOr<ExecutionPath> MakeRecursivePath(absl::string_view field) {
  ExecutionPath path;

  std::vector<std::unique_ptr<DirectExpressionStep>> deps;
  deps.push_back(CreateDirectIdentStep("message", /*expr_id=*/-1));

  auto step1 = CreateDirectCreateStructStep(
      "google.api.expr.runtime.TestMessage", {std::string(field)},
      std::move(deps),
      /*optional_indices=*/{},
      /*expr_id=*/-1);

  path.push_back(
      std::make_unique<WrappedDirectStep>(std::move(step1), /*expr_id=*/-1));
  return path;
}

// Helper method. Creates simple pipeline containing CreateStruct step that
// builds message and runs it.
absl::StatusOr RunExpression( const absl_nonnull std::shared_ptr& env, absl::string_view field, const CelValue& value, google::protobuf::Arena* arena, bool enable_unknowns, bool enable_recursive_planning) { google::protobuf::LinkMessageReflection(); CEL_ASSIGN_OR_RETURN(auto maybe_type, env->type_registry.GetComposedTypeProvider().FindType( "google.api.expr.runtime.TestMessage")); if (!maybe_type.has_value()) { return absl::Status(absl::StatusCode::kFailedPrecondition, "missing proto message type"); } cel::RuntimeOptions options; if (enable_unknowns) { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } ExecutionPath path; if (enable_recursive_planning) { CEL_ASSIGN_OR_RETURN(path, MakeRecursivePath(field)); } else { CEL_ASSIGN_OR_RETURN(path, MakeStackMachinePath(field)); } CelExpressionFlatImpl cel_expr( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), options)); Activation activation; activation.InsertValue("message", value); return cel_expr.Evaluate(activation, arena); } void RunExpressionAndGetMessage( const absl_nonnull std::shared_ptr& env, absl::string_view field, const CelValue& value, google::protobuf::Arena* arena, TestMessage* test_msg, bool enable_unknowns, bool enable_recursive_planning) { ASSERT_OK_AND_ASSIGN(auto result, RunExpression(env, field, value, arena, enable_unknowns, enable_recursive_planning)); ASSERT_TRUE(result.IsMessage()) << result.DebugString(); const Message* msg = result.MessageOrDie(); ASSERT_THAT(msg, Not(IsNull())); ASSERT_EQ(msg->GetDescriptor()->full_name(), "google.api.expr.runtime.TestMessage"); test_msg->MergePartialFromString(msg->SerializePartialAsCord()); } void RunExpressionAndGetMessage( const absl_nonnull std::shared_ptr& env, absl::string_view field, std::vector values, google::protobuf::Arena* arena, TestMessage* test_msg, bool enable_unknowns, bool enable_recursive_planning) { ContainerBackedListImpl 
cel_list(std::move(values)); CelValue value = CelValue::CreateList(&cel_list); ASSERT_OK_AND_ASSIGN(auto result, RunExpression(env, field, value, arena, enable_unknowns, enable_recursive_planning)); ASSERT_TRUE(result.IsMessage()) << result.DebugString(); const Message* msg = result.MessageOrDie(); ASSERT_THAT(msg, Not(IsNull())); ASSERT_EQ(msg->GetDescriptor()->full_name(), "google.api.expr.runtime.TestMessage"); test_msg->MergePartialFromString(msg->SerializePartialAsCord()); } class CreateCreateStructStepTest : public testing::TestWithParam> { public: CreateCreateStructStepTest() : env_(NewTestingRuntimeEnv()) {} bool enable_unknowns() { return std::get<0>(GetParam()); } bool enable_recursive_planning() { return std::get<1>(GetParam()); } protected: absl_nonnull std::shared_ptr env_; google::protobuf::Arena arena_; }; TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { ExecutionPath path; auto adapter = env_->legacy_type_registry.FindTypeAdapter( "google.api.expr.runtime.TestMessage"); ASSERT_TRUE(adapter.has_value() && adapter->mutation_apis() != nullptr); ASSERT_OK_AND_ASSIGN(auto maybe_type, env_->type_registry.GetComposedTypeProvider().FindType( "google.api.expr.runtime.TestMessage")); ASSERT_TRUE(maybe_type.has_value()); if (enable_recursive_planning()) { auto step = CreateDirectCreateStructStep("google.api.expr.runtime.TestMessage", /*fields=*/{}, /*deps=*/{}, /*optional_indices=*/{}, /*id=*/-1); path.push_back( std::make_unique(std::move(step), /*id=*/-1)); } else { auto step = CreateCreateStructStep("google.api.expr.runtime.TestMessage", /*fields=*/{}, /*optional_indices=*/{}, /*id=*/-1); path.push_back(std::move(step)); } cel::RuntimeOptions options; if (enable_unknowns(), enable_recursive_planning()) { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } CelExpressionFlatImpl cel_expr( env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env_->type_registry.GetComposedTypeProvider(), options)); 
Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsMessage()) << result.DebugString(); const Message* msg = result.MessageOrDie(); ASSERT_THAT(msg, Not(IsNull())); ASSERT_EQ(msg->GetDescriptor()->full_name(), "google.api.expr.runtime.TestMessage"); } TEST(CreateCreateStructStepTest, TestMessageCreateError) { absl_nonnull std::shared_ptr env = NewTestingRuntimeEnv(); Arena arena; TestMessage test_msg; absl::Status error = absl::CancelledError(); auto eval_status = RunExpression(env, "bool_value", CelValue::CreateError(&error), &arena, true, /*enable_recursive_planning=*/false); ASSERT_THAT(eval_status, IsOk()); EXPECT_THAT(*eval_status->ErrorOrDie(), StatusIs(absl::StatusCode::kCancelled)); } TEST(CreateCreateStructStepTest, TestMessageCreateErrorRecursive) { absl_nonnull std::shared_ptr env = NewTestingRuntimeEnv(); Arena arena; TestMessage test_msg; absl::Status error = absl::CancelledError(); auto eval_status = RunExpression(env, "bool_value", CelValue::CreateError(&error), &arena, true, /*enable_recursive_planning=*/true); ASSERT_THAT(eval_status, IsOk()); EXPECT_THAT(*eval_status->ErrorOrDie(), StatusIs(absl::StatusCode::kCancelled)); } // Test message creation if unknown argument is passed TEST(CreateCreateStructStepTest, TestMessageCreateWithUnknown) { absl_nonnull std::shared_ptr env = NewTestingRuntimeEnv(); Arena arena; TestMessage test_msg; UnknownSet unknown_set; auto eval_status = RunExpression(env, "bool_value", CelValue::CreateUnknownSet(&unknown_set), &arena, true, /*enable_recursive_planning=*/false); ASSERT_OK(eval_status); ASSERT_TRUE(eval_status->IsUnknownSet()); } // Test message creation if unknown argument is passed TEST(CreateCreateStructStepTest, TestMessageCreateWithUnknownRecursive) { absl_nonnull std::shared_ptr env = NewTestingRuntimeEnv(); Arena arena; TestMessage test_msg; UnknownSet unknown_set; auto eval_status = RunExpression(env, "bool_value", 
CelValue::CreateUnknownSet(&unknown_set), &arena, true, /*enable_recursive_planning=*/true); ASSERT_OK(eval_status); ASSERT_TRUE(eval_status->IsUnknownSet()) << eval_status->DebugString(); } // Test that fields of type bool are set correctly TEST_P(CreateCreateStructStepTest, TestSetBoolField) { TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "bool_value", CelValue::CreateBool(true), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.bool_value(), true); } // Test that fields of type int32 are set correctly TEST_P(CreateCreateStructStepTest, TestSetInt32Field) { TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "int32_value", CelValue::CreateInt64(1), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.int32_value(), 1); } // Test that fields of type uint32 are set correctly. TEST_P(CreateCreateStructStepTest, TestSetUInt32Field) { TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "uint32_value", CelValue::CreateUint64(1), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.uint32_value(), 1); } // Test that fields of type int64 are set correctly. TEST_P(CreateCreateStructStepTest, TestSetInt64Field) { TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "int64_value", CelValue::CreateInt64(1), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.int64_value(), 1); } // Test that fields of type uint64 are set correctly. 
TEST_P(CreateCreateStructStepTest, TestSetUInt64Field) {
  TestMessage test_msg;

  ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage(
      env_, "uint64_value", CelValue::CreateUint64(1), &arena_, &test_msg,
      enable_unknowns(), enable_recursive_planning()));

  // Unsigned literal avoids a signed/unsigned comparison.
  EXPECT_EQ(test_msg.uint64_value(), 1u);
}

// Test that fields of type float are set correctly
TEST_P(CreateCreateStructStepTest, TestSetFloatField) {
  TestMessage test_msg;

  ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage(
      env_, "float_value", CelValue::CreateDouble(2.0), &arena_, &test_msg,
      enable_unknowns(), enable_recursive_planning()));

  // The proto field is a float, so compare with float precision.
  EXPECT_FLOAT_EQ(test_msg.float_value(), 2.0f);
}

// Test that fields of type double are set correctly
TEST_P(CreateCreateStructStepTest, TestSetDoubleField) {
  TestMessage test_msg;

  ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage(
      env_, "double_value", CelValue::CreateDouble(2.0), &arena_, &test_msg,
      enable_unknowns(), enable_recursive_planning()));
  EXPECT_DOUBLE_EQ(test_msg.double_value(), 2.0);
}

// Test that fields of type string are set correctly.
TEST_P(CreateCreateStructStepTest, TestSetStringField) {
  const std::string kTestStr = "test";

  TestMessage test_msg;

  ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage(
      env_, "string_value", CelValue::CreateString(&kTestStr), &arena_,
      &test_msg, enable_unknowns(), enable_recursive_planning()));
  EXPECT_EQ(test_msg.string_value(), kTestStr);
}

// Test that fields of type bytes are set correctly.
TEST_P(CreateCreateStructStepTest, TestSetBytesField) {
  const std::string kTestStr = "test";
  TestMessage test_msg;

  ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage(
      env_, "bytes_value", CelValue::CreateBytes(&kTestStr), &arena_,
      &test_msg, enable_unknowns(), enable_recursive_planning()));
  EXPECT_EQ(test_msg.bytes_value(), kTestStr);
}

// Test that fields of type duration are set correctly.
TEST_P(CreateCreateStructStepTest, TestSetDurationField) { google::protobuf::Duration test_duration; test_duration.set_seconds(2); test_duration.set_nanos(3); TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "duration_value", CelProtoWrapper::CreateDuration(&test_duration), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_THAT(test_msg.duration_value(), EqualsProto(test_duration)); } // Test that fields of type timestamp are set correctly. TEST_P(CreateCreateStructStepTest, TestSetTimestampField) { google::protobuf::Timestamp test_timestamp; test_timestamp.set_seconds(2); test_timestamp.set_nanos(3); TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "timestamp_value", CelProtoWrapper::CreateTimestamp(&test_timestamp), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_THAT(test_msg.timestamp_value(), EqualsProto(test_timestamp)); } // Test that fields of type Message are set correctly. TEST_P(CreateCreateStructStepTest, TestSetMessageField) { // Create payload message and set some fields. TestMessage orig_msg; orig_msg.set_bool_value(true); orig_msg.set_string_value("test"); TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "message_value", CelProtoWrapper::CreateMessage(&orig_msg, &arena_), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_THAT(test_msg.message_value(), EqualsProto(orig_msg)); } // Test that fields of type Any are set correctly. TEST_P(CreateCreateStructStepTest, TestSetAnyField) { // Create payload message and set some fields. 
TestMessage orig_embedded_msg; orig_embedded_msg.set_bool_value(true); orig_embedded_msg.set_string_value("embedded"); TestMessage orig_msg; orig_msg.mutable_any_value()->PackFrom(orig_embedded_msg); TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "any_value", CelProtoWrapper::CreateMessage(&orig_embedded_msg, &arena_), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_THAT(test_msg, EqualsProto(orig_msg)); TestMessage test_embedded_msg; ASSERT_TRUE(test_msg.any_value().UnpackTo(&test_embedded_msg)); EXPECT_THAT(test_embedded_msg, EqualsProto(orig_embedded_msg)); } // Test that fields of type Message are set correctly. TEST_P(CreateCreateStructStepTest, TestSetEnumField) { TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "enum_value", CelValue::CreateInt64(TestMessage::TEST_ENUM_2), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.enum_value(), TestMessage::TEST_ENUM_2); } // Test that fields of type bool are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedBoolField) { TestMessage test_msg; std::vector kValues = {true, false}; std::vector values; for (auto value : kValues) { values.push_back(CelValue::CreateBool(value)); } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "bool_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.bool_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type int32 are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedInt32Field) { TestMessage test_msg; std::vector kValues = {23, 12}; std::vector values; for (auto value : kValues) { values.push_back(CelValue::CreateInt64(value)); } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "int32_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.int32_list(), Pointwise(Eq(), kValues)); } // Test 
that repeated fields of type uint32 are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedUInt32Field) { TestMessage test_msg; std::vector kValues = {23, 12}; std::vector values; for (auto value : kValues) { values.push_back(CelValue::CreateUint64(value)); } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "uint32_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.uint32_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type int64 are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedInt64Field) { TestMessage test_msg; std::vector kValues = {23, 12}; std::vector values; for (auto value : kValues) { values.push_back(CelValue::CreateInt64(value)); } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "int64_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.int64_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type uint64 are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedUInt64Field) { TestMessage test_msg; std::vector kValues = {23, 12}; std::vector values; for (auto value : kValues) { values.push_back(CelValue::CreateUint64(value)); } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "uint64_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.uint64_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type float are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedFloatField) { TestMessage test_msg; std::vector kValues = {23, 12}; std::vector values; for (auto value : kValues) { values.push_back(CelValue::CreateDouble(value)); } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "float_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.float_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields 
of type uint32 are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedDoubleField) { TestMessage test_msg; std::vector kValues = {23, 12}; std::vector values; for (auto value : kValues) { values.push_back(CelValue::CreateDouble(value)); } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "double_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.double_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type String are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedStringField) { TestMessage test_msg; std::vector kValues = {"test1", "test2"}; std::vector values; for (const auto& value : kValues) { values.push_back(CelValue::CreateString(&value)); } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "string_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.string_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type String are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedBytesField) { TestMessage test_msg; std::vector kValues = {"test1", "test2"}; std::vector values; for (const auto& value : kValues) { values.push_back(CelValue::CreateBytes(&value)); } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "bytes_list", values, &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_THAT(test_msg.bytes_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type Message are set correctly TEST_P(CreateCreateStructStepTest, TestSetRepeatedMessageField) { TestMessage test_msg; std::vector kValues(2); kValues[0].set_string_value("test1"); kValues[1].set_string_value("test2"); std::vector values; for (const auto& value : kValues) { values.push_back(CelProtoWrapper::CreateMessage(&value, &arena_)); } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "message_list", values, &arena_, &test_msg, enable_unknowns(), 
enable_recursive_planning())); ASSERT_THAT(test_msg.message_list()[0], EqualsProto(kValues[0])); ASSERT_THAT(test_msg.message_list()[1], EqualsProto(kValues[1])); } // Test that fields of type map are set correctly TEST_P(CreateCreateStructStepTest, TestSetStringMapField) { TestMessage test_msg; std::vector> entries; const std::vector kKeys = {"test2", "test1"}; entries.push_back( {CelValue::CreateString(&kKeys[0]), CelValue::CreateInt64(2)}); entries.push_back( {CelValue::CreateString(&kKeys[1]), CelValue::CreateInt64(1)}); auto cel_map = *CreateContainerBackedMap(absl::Span>( entries.data(), entries.size())); ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "string_int32_map", CelValue::CreateMap(cel_map.get()), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.string_int32_map().size(), 2); ASSERT_EQ(test_msg.string_int32_map().at(kKeys[0]), 2); ASSERT_EQ(test_msg.string_int32_map().at(kKeys[1]), 1); } // Test that fields of type map are set correctly TEST_P(CreateCreateStructStepTest, TestSetInt64MapField) { TestMessage test_msg; std::vector> entries; const std::vector kKeys = {3, 4}; entries.push_back( {CelValue::CreateInt64(kKeys[0]), CelValue::CreateInt64(1)}); entries.push_back( {CelValue::CreateInt64(kKeys[1]), CelValue::CreateInt64(2)}); auto cel_map = *CreateContainerBackedMap(absl::Span>( entries.data(), entries.size())); ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "int64_int32_map", CelValue::CreateMap(cel_map.get()), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.int64_int32_map().size(), 2); ASSERT_EQ(test_msg.int64_int32_map().at(kKeys[0]), 1); ASSERT_EQ(test_msg.int64_int32_map().at(kKeys[1]), 2); } // Test that fields of type map are set correctly TEST_P(CreateCreateStructStepTest, TestSetUInt64MapField) { TestMessage test_msg; std::vector> entries; const std::vector kKeys = {3, 4}; entries.push_back( {CelValue::CreateUint64(kKeys[0]), 
CelValue::CreateInt64(1)}); entries.push_back( {CelValue::CreateUint64(kKeys[1]), CelValue::CreateInt64(2)}); auto cel_map = *CreateContainerBackedMap(absl::Span>( entries.data(), entries.size())); ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( env_, "uint64_int32_map", CelValue::CreateMap(cel_map.get()), &arena_, &test_msg, enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.uint64_int32_map().size(), 2); ASSERT_EQ(test_msg.uint64_int32_map().at(kKeys[0]), 1); ASSERT_EQ(test_msg.uint64_int32_map().at(kKeys[1]), 2); } INSTANTIATE_TEST_SUITE_P(CombinedCreateStructTest, CreateCreateStructStepTest, testing::Combine(testing::Bool(), testing::Bool())); } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/direct_expression_step.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/eval/direct_expression_step.h"

#include <utility>

#include "absl/status/status.h"
#include "common/value.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/evaluator_core.h"
#include "internal/status_macros.h"

namespace google::api::expr::runtime {

// Adapts the wrapped direct step to the stack machine: the step evaluates into
// locals, and the value plus its attribute trail are then pushed onto the
// frame's value stack.
absl::Status WrappedDirectStep::Evaluate(ExecutionFrame* frame) const {
  cel::Value value;
  AttributeTrail trail;
  CEL_RETURN_IF_ERROR(impl_->Evaluate(*frame, value, trail));
  frame->value_stack().Push(std::move(value), std::move(trail));
  return absl::OkStatus();
}

}  // namespace google::api::expr::runtime
================================================
FILE: eval/eval/direct_expression_step.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_DIRECT_EXPRESSION_STEP_H_
#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_DIRECT_EXPRESSION_STEP_H_

#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/types/optional.h"
#include "common/native_type.h"
#include "common/value.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/evaluator_core.h"

namespace google::api::expr::runtime {

// Represents a directly evaluated CEL expression.
//
// Subexpressions assign to values on the C++ program stack and call their
// dependencies directly.
// // This reduces the setup overhead for evaluation and minimizes value churn // to / from a heap based value stack managed by the CEL runtime, but can't be // used for arbitrarily nested expressions. class DirectExpressionStep { public: explicit DirectExpressionStep(int64_t expr_id) : expr_id_(expr_id) {} DirectExpressionStep() : expr_id_(-1) {} virtual ~DirectExpressionStep() = default; int64_t expr_id() const { return expr_id_; } bool comes_from_ast() const { return expr_id_ >= 0; } virtual absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, AttributeTrail& attribute) const = 0; // Return a type id for this node. // // Users must not make any assumptions about the type if the default value is // returned. virtual cel::NativeTypeId GetNativeTypeId() const { return cel::NativeTypeId(); } // Implementations optionally support inspecting the program tree. virtual absl::optional> GetDependencies() const { return absl::nullopt; } // Implementations optionally support extracting the program tree. // // Extract prevents the callee from functioning, and is only intended for use // when replacing a given expression step. virtual absl::optional>> ExtractDependencies() { return absl::nullopt; }; protected: int64_t expr_id_; }; // Wrapper for direct steps to work with the stack machine impl. 
class WrappedDirectStep : public ExpressionStep { public: WrappedDirectStep(std::unique_ptr impl, int64_t expr_id) : ExpressionStep(expr_id, false), impl_(std::move(impl)) {} absl::Status Evaluate(ExecutionFrame* frame) const override; cel::NativeTypeId GetNativeTypeId() const override { return cel::NativeTypeId::For(); } const DirectExpressionStep* wrapped() const { return impl_.get(); } private: std::unique_ptr impl_; }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_DIRECT_EXPRESSION_STEP_H_ ================================================ FILE: eval/eval/equality_steps.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/eval/equality_steps.h"

#include <cstdint>
#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "base/builtins.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/expression_step_base.h"
#include "internal/number.h"
#include "internal/status_macros.h"
#include "runtime/internal/errors.h"
#include "runtime/standard/equality_functions.h"

namespace google::api::expr::runtime {
namespace {

using ::cel::BoolValue;
using ::cel::IntValue;
using ::cel::MapValue;
using ::cel::UintValue;
using ::cel::Value;
using ::cel::ValueKind;
using ::cel::internal::Number;
using ::cel::runtime_internal::ValueEqualImpl;

// Shared implementation of `_==_` and, when `negation` is true, `_!=_`.
//
// Precedence of the result:
//   1. the first error operand (lhs checked before rhs), returned as-is;
//   2. with unknown processing enabled, the merged unknown set if the
//      accumulator collects anything from either (value, attribute) pair;
//   3. a BoolValue, or a no-matching-overload error value when ValueEqualImpl
//      reports that the operands have no defined equality.
absl::StatusOr<Value> EvaluateEquality(
    ExecutionFrameBase& frame, const Value& lhs, const AttributeTrail& lhs_attr,
    const Value& rhs, const AttributeTrail& rhs_attr, bool negation) {
  if (lhs.IsError()) {
    return lhs;
  }
  if (rhs.IsError()) {
    return rhs;
  }

  if (frame.unknown_processing_enabled()) {
    auto accu = frame.attribute_utility().CreateAccumulator();
    accu.MaybeAdd(lhs, lhs_attr);
    accu.MaybeAdd(rhs, rhs_attr);
    if (!accu.IsEmpty()) {
      return std::move(accu).Build();
    }
  }

  CEL_ASSIGN_OR_RETURN(auto is_equal,
                       ValueEqualImpl(lhs, rhs, frame.descriptor_pool(),
                                      frame.message_factory(), frame.arena()));
  if (!is_equal.has_value()) {
    return cel::ErrorValue(cel::runtime_internal::CreateNoMatchingOverloadError(
        negation ? cel::builtin::kInequal : cel::builtin::kEqual));
  }
  return negation ? BoolValue(!*is_equal) : BoolValue(*is_equal);
}

// Recursive (direct) `_==_` / `_!=_` step: evaluates both operands, then
// delegates to EvaluateEquality.
class DirectEqualityStep : public DirectExpressionStep {
 public:
  explicit DirectEqualityStep(std::unique_ptr<DirectExpressionStep> lhs,
                              std::unique_ptr<DirectExpressionStep> rhs,
                              bool negation, int64_t expr_id)
      : DirectExpressionStep(expr_id),
        lhs_(std::move(lhs)),
        rhs_(std::move(rhs)),
        negation_(negation) {}

  absl::Status Evaluate(ExecutionFrameBase& frame, Value& result,
                        AttributeTrail& attribute_trail) const override {
    // The lhs is evaluated straight into `result`, which is then overwritten
    // with the comparison outcome.
    AttributeTrail lhs_attr;
    CEL_RETURN_IF_ERROR(lhs_->Evaluate(frame, result, lhs_attr));

    Value rhs_result;
    AttributeTrail rhs_attr;
    CEL_RETURN_IF_ERROR(rhs_->Evaluate(frame, rhs_result, rhs_attr));

    CEL_ASSIGN_OR_RETURN(
        result, EvaluateEquality(frame, result, lhs_attr, rhs_result, rhs_attr,
                                 negation_));
    return absl::OkStatus();
  }

 private:
  std::unique_ptr<DirectExpressionStep> lhs_;
  std::unique_ptr<DirectExpressionStep> rhs_;
  bool negation_;
};

// Stack-machine `_==_` / `_!=_` step: consumes the top two stack entries
// (lhs below rhs) and pushes the comparison result.
class IterativeEqualityStep : public ExpressionStepBase {
 public:
  explicit IterativeEqualityStep(bool negation, int64_t expr_id)
      : ExpressionStepBase(expr_id), negation_(negation) {}

  absl::Status Evaluate(ExecutionFrame* frame) const override {
    if (!frame->value_stack().HasEnough(2)) {
      return absl::Status(absl::StatusCode::kInternal, "Value stack underflow");
    }

    auto args = frame->value_stack().GetSpan(2);
    auto attrs = frame->value_stack().GetAttributeSpan(2);

    CEL_ASSIGN_OR_RETURN(Value result,
                         EvaluateEquality(*frame, args[0], attrs[0], args[1],
                                          attrs[1], negation_));

    frame->value_stack().PopAndPush(2, std::move(result));
    return absl::OkStatus();
  }

 private:
  bool negation_;
};

// Implements `item in map`: key membership with CEL's heterogeneous numeric
// key lookup (an int, uint or double item also matches a key of another
// numeric kind holding the same value, when losslessly convertible).
// Items of non-key kinds yield a no-matching-overload error value.
absl::StatusOr<Value> EvaluateInMap(ExecutionFrameBase& frame,
                                    const Value& item,
                                    const MapValue& container) {
  switch (item.kind()) {
    case ValueKind::kBool:
    case ValueKind::kString:
    case ValueKind::kInt:
    case ValueKind::kUint:
    case ValueKind::kDouble:
      break;
    default:
      return cel::ErrorValue(
          cel::runtime_internal::CreateNoMatchingOverloadError(
              cel::builtin::kIn));
  }

  Value result;
  // Doubles are never valid CEL map keys, so probing the map with the double
  // itself cannot find an entry (and some map implementations may reject it
  // as an invalid key type with a non-OK status). Doubles are looked up only
  // through their lossless int/uint equivalents below.
  if (!item.IsDouble()) {
    CEL_RETURN_IF_ERROR(container.Has(item, frame.descriptor_pool(),
                                      frame.message_factory(), frame.arena(),
                                      &result));
    if (result.IsTrue()) {
      return result;
    }
  }

  // Retry as an int key for uint and double items.
  if (item.IsDouble() || item.IsUint()) {
    Number number =
        item.IsDouble() ? Number::FromDouble(item.GetDouble().NativeValue())
                        : Number::FromUint64(item.GetUint().NativeValue());
    if (number.LosslessConvertibleToInt()) {
      CEL_RETURN_IF_ERROR(container.Has(IntValue(number.AsInt()),
                                        frame.descriptor_pool(),
                                        frame.message_factory(), frame.arena(),
                                        &result));
      if (result.IsTrue()) {
        return result;
      }
    }
  }

  // Retry as a uint key for int and double items.
  if (item.IsDouble() || item.IsInt()) {
    Number number =
        item.IsDouble() ? Number::FromDouble(item.GetDouble().NativeValue())
                        : Number::FromInt64(item.GetInt().NativeValue());
    if (number.LosslessConvertibleToUint()) {
      CEL_RETURN_IF_ERROR(container.Has(UintValue(number.AsUint()),
                                        frame.descriptor_pool(),
                                        frame.message_factory(), frame.arena(),
                                        &result));
      if (result.IsTrue()) {
        return result;
      }
    }
  }

  return BoolValue(false);
}

// Shared implementation of `@in`. Errors take precedence over unknowns (item
// checked before container); lists use ListValue::Contains, maps use
// EvaluateInMap, and any other container kind is a no-matching-overload
// error value.
absl::StatusOr<Value> EvaluateIn(ExecutionFrameBase& frame, const Value& item,
                                 const AttributeTrail& item_attr,
                                 const Value& container,
                                 const AttributeTrail& container_attr) {
  if (item.IsError()) {
    return item;
  }
  if (container.IsError()) {
    return container;
  }

  if (frame.unknown_processing_enabled()) {
    auto accu = frame.attribute_utility().CreateAccumulator();
    accu.MaybeAdd(item, item_attr);
    accu.MaybeAdd(container, container_attr);
    if (!accu.IsEmpty()) {
      return std::move(accu).Build();
    }
  }

  if (container.IsList()) {
    return container.GetList().Contains(item, frame.descriptor_pool(),
                                        frame.message_factory(), frame.arena());
  }
  if (container.IsMap()) {
    return EvaluateInMap(frame, item, container.GetMap());
  }
  return cel::ErrorValue(
      cel::runtime_internal::CreateNoMatchingOverloadError(cel::builtin::kIn));
}

// Recursive (direct) `@in` step: evaluates item and container, then
// delegates to EvaluateIn.
class DirectInStep : public DirectExpressionStep {
 public:
  explicit DirectInStep(std::unique_ptr<DirectExpressionStep> item,
                        std::unique_ptr<DirectExpressionStep> container,
                        int64_t expr_id)
      : DirectExpressionStep(expr_id),
        item_(std::move(item)),
        container_(std::move(container)) {}

  absl::Status Evaluate(ExecutionFrameBase& frame, Value& result,
                        AttributeTrail& attribute_trail) const override {
    // The item is evaluated straight into `result`, which is then overwritten
    // with the membership outcome.
    AttributeTrail item_attr;
    CEL_RETURN_IF_ERROR(item_->Evaluate(frame, result, item_attr));

    Value container_result;
    AttributeTrail container_attr;
    CEL_RETURN_IF_ERROR(
        container_->Evaluate(frame, container_result, container_attr));

    CEL_ASSIGN_OR_RETURN(result, EvaluateIn(frame, result, item_attr,
                                            container_result, container_attr));
    return absl::OkStatus();
  }

 private:
  std::unique_ptr<DirectExpressionStep> item_;
  std::unique_ptr<DirectExpressionStep> container_;
};

// Stack-machine `@in` step: consumes the top two stack entries (item below
// container) and pushes the membership result.
class IterativeInStep : public ExpressionStepBase {
 public:
  explicit IterativeInStep(int64_t expr_id) : ExpressionStepBase(expr_id) {}

  absl::Status Evaluate(ExecutionFrame* frame) const override {
    if (!frame->value_stack().HasEnough(2)) {
      return absl::Status(absl::StatusCode::kInternal, "Value stack underflow");
    }

    auto args = frame->value_stack().GetSpan(2);
    auto attrs = frame->value_stack().GetAttributeSpan(2);

    CEL_ASSIGN_OR_RETURN(
        Value result, EvaluateIn(*frame, args[0], attrs[0], args[1], attrs[1]));

    frame->value_stack().PopAndPush(2, std::move(result));
    return absl::OkStatus();
  }
};

}  // namespace

// Factory method for recursive _==_ and _!=_ Execution step
std::unique_ptr<DirectExpressionStep> CreateDirectEqualityStep(
    std::unique_ptr<DirectExpressionStep> lhs,
    std::unique_ptr<DirectExpressionStep> rhs, bool negation, int64_t expr_id) {
  return std::make_unique<DirectEqualityStep>(std::move(lhs), std::move(rhs),
                                              negation, expr_id);
}

// Factory method for iterative _==_ and _!=_ Execution step
std::unique_ptr<ExpressionStep> CreateEqualityStep(bool negation,
                                                   int64_t expr_id) {
  return std::make_unique<IterativeEqualityStep>(negation, expr_id);
}

// Factory method for recursive @in Execution step
std::unique_ptr<DirectExpressionStep> CreateDirectInStep(
    std::unique_ptr<DirectExpressionStep> item,
    std::unique_ptr<DirectExpressionStep> container, int64_t expr_id) {
  return std::make_unique<DirectInStep>(std::move(item), std::move(container),
                                        expr_id);
}

// Factory method for iterative @in Execution step
std::unique_ptr<ExpressionStep> CreateInStep(int64_t expr_id) {
  return std::make_unique<IterativeInStep>(expr_id);
}

}  // namespace google::api::expr::runtime
================================================
FILE:
eval/eval/equality_steps.h ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_EQUALITY_STEPS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_EQUALITY_STEPS_H_ #include #include #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { // Factory method for recursive _==_/_!=_ Execution step std::unique_ptr CreateDirectEqualityStep( std::unique_ptr lhs, std::unique_ptr rhs, bool negation, int64_t expr_id); // Factory method for iterative _==_/_!=_ Execution step std::unique_ptr CreateEqualityStep(bool negation, int64_t expr_id); // Factory method for recursive @in Execution step std::unique_ptr CreateDirectInStep( std::unique_ptr item, std::unique_ptr container, int64_t expr_id); // Factory method for iterative @in Execution step std::unique_ptr CreateInStep(int64_t expr_id); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_EQUALITY_STEPS_H_ ================================================ FILE: eval/eval/equality_steps_test.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/eval/equality_steps.h" #include #include #include #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "base/attribute.h" #include "common/value.h" #include "common/value_kind.h" #include "common/value_testing.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" #include "runtime/activation.h" #include "runtime/internal/runtime_type_provider.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { using ::absl_testing::IsOk; using ::cel::Attribute; using ::cel::DoubleValue; using ::cel::ErrorValue; using ::cel::IntValue; using ::cel::UnknownValue; using ::cel::Value; using ::cel::ValueKind; using ::cel::test::BoolValueIs; using ::cel::test::ValueKindIs; class ValueStep : public ExpressionStep, public DirectExpressionStep { public: ValueStep(Value value, Attribute attr) : ExpressionStep(-1), DirectExpressionStep(-1), value_(std::move(value)), attr_(std::move(attr)) {} explicit ValueStep(Value value) : ExpressionStep(-1), DirectExpressionStep(-1), value_(std::move(value)), attr_() {} absl::Status Evaluate(ExecutionFrame* frame) const override { frame->value_stack().Push(value_, attr_); return absl::OkStatus(); } absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, 
AttributeTrail& attribute_trail) const override { result = value_; attribute_trail = attr_; return absl::OkStatus(); } private: Value value_; AttributeTrail attr_; }; TEST(RecursiveTest, PartialAttrUnknown) { cel::Activation activation; google::protobuf::Arena arena; cel::RuntimeOptions opts; opts.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); // A little contrived for simplicity, but this is for cases where e.g. // `msg == Msg{}` but msg.foo is unknown. auto plan = CreateDirectEqualityStep( std::make_unique(IntValue(1), cel::Attribute("foo")), std::make_unique(IntValue(2)), false, -1); activation.SetUnknownPatterns({cel::AttributePattern( "foo", {cel::AttributeQualifierPattern::OfString("bar")})}); ExecutionFrameBase frame(activation, opts, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); cel::Value result; AttributeTrail attribute_trail; ASSERT_THAT(plan->Evaluate(frame, result, attribute_trail), IsOk()); EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown)); } TEST(RecursiveTest, PartialAttrUnknownDisabled) { cel::Activation activation; google::protobuf::Arena arena; cel::RuntimeOptions opts; opts.unknown_processing = cel::UnknownProcessingOptions::kDisabled; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); auto plan = CreateDirectEqualityStep( std::make_unique(IntValue(1), cel::Attribute("foo")), std::make_unique(IntValue(2)), false, -1); activation.SetUnknownPatterns({cel::AttributePattern( "foo", {cel::AttributeQualifierPattern::OfString("bar")})}); ExecutionFrameBase frame(activation, opts, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); cel::Value result; AttributeTrail attribute_trail; ASSERT_THAT(plan->Evaluate(frame, result, attribute_trail), IsOk()); 
EXPECT_THAT(result, BoolValueIs(false)); } TEST(IterativeTest, PartialAttrUnknown) { cel::Activation activation; google::protobuf::Arena arena; cel::RuntimeOptions opts; opts.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); FlatExpressionEvaluatorState state( /*value_stack_size=*/5, /*comprehension_slot_count=*/0, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); std::vector> steps; steps.push_back( std::make_unique(IntValue(1), cel::Attribute("foo"))); steps.push_back(std::make_unique(IntValue(2))); steps.push_back(CreateEqualityStep(false, -1)); activation.SetUnknownPatterns({cel::AttributePattern( "foo", {cel::AttributeQualifierPattern::OfString("bar")})}); ExecutionFrame frame(steps, activation, opts, state); ASSERT_OK_AND_ASSIGN(Value result, frame.Evaluate()); EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown)); } TEST(IterativeTest, PartialAttrUnknownDisabled) { cel::Activation activation; google::protobuf::Arena arena; cel::RuntimeOptions opts; opts.unknown_processing = cel::UnknownProcessingOptions::kDisabled; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); FlatExpressionEvaluatorState state( /*value_stack_size=*/5, /*comprehension_slot_count=*/0, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); std::vector> steps; steps.push_back( std::make_unique(IntValue(1), cel::Attribute("foo"))); steps.push_back(std::make_unique(IntValue(2))); steps.push_back(CreateEqualityStep(false, -1)); activation.SetUnknownPatterns({cel::AttributePattern( "foo", {cel::AttributeQualifierPattern::OfString("bar")})}); ExecutionFrame frame(steps, activation, opts, state); ASSERT_OK_AND_ASSIGN(Value result, frame.Evaluate()); EXPECT_THAT(result, BoolValueIs(false)); } enum class InputType { 
kInt1, kInt2, kDouble1, kList, kMap, kError, kUnknown }; enum class OutputType { kBoolTrue, kBoolFalse, kError, kUnknown }; struct EqualsTestCase { InputType lhs; InputType rhs; bool negation; OutputType expected_result; }; class EqualsTest : public ::testing::TestWithParam {}; Value MakeValue(InputType type, google::protobuf::Arena* absl_nonnull arena) { switch (type) { case InputType::kInt1: return IntValue(1); case InputType::kInt2: return IntValue(2); case InputType::kDouble1: return DoubleValue(1.0); case InputType::kUnknown: return UnknownValue(); case InputType::kList: { auto builder = cel::NewListValueBuilder(arena); ABSL_CHECK_OK((builder)->Add(IntValue(1))); return (std::move(*builder)).Build(); } case InputType::kMap: { auto builder = cel::NewMapValueBuilder(arena); ABSL_CHECK_OK((builder)->Put(IntValue(1), IntValue(2))); return (std::move(*builder)).Build(); } case InputType::kError: default: return ErrorValue(absl::InternalError("error")); } } TEST_P(EqualsTest, Recursive) { const EqualsTestCase& test_case = GetParam(); cel::Activation activation; google::protobuf::Arena arena; cel::RuntimeOptions opts; opts.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); auto plan = CreateDirectEqualityStep( std::make_unique(MakeValue(test_case.lhs, &arena)), std::make_unique(MakeValue(test_case.rhs, &arena)), test_case.negation, -1); ExecutionFrameBase frame(activation, opts, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); cel::Value result; AttributeTrail attribute_trail; ASSERT_THAT(plan->Evaluate(frame, result, attribute_trail), IsOk()); switch (test_case.expected_result) { case OutputType::kBoolTrue: EXPECT_THAT(result, BoolValueIs(true)); break; case OutputType::kBoolFalse: EXPECT_THAT(result, BoolValueIs(false)); break; case OutputType::kError: EXPECT_THAT(result, 
ValueKindIs(ValueKind::kError)); break; case OutputType::kUnknown: EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown)); break; } } TEST_P(EqualsTest, Iterative) { const EqualsTestCase& test_case = GetParam(); cel::Activation activation; google::protobuf::Arena arena; cel::RuntimeOptions opts; opts.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); FlatExpressionEvaluatorState state( /*value_stack_size=*/5, /*comprehension_slot_count=*/0, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); std::vector> steps; steps.push_back( std::make_unique(MakeValue(test_case.lhs, &arena))); steps.push_back( std::make_unique(MakeValue(test_case.rhs, &arena))); steps.push_back(CreateEqualityStep(test_case.negation, -1)); ExecutionFrame frame(steps, activation, opts, state); ASSERT_OK_AND_ASSIGN(Value result, frame.Evaluate()); switch (test_case.expected_result) { case OutputType::kBoolTrue: EXPECT_THAT(result, BoolValueIs(true)); break; case OutputType::kBoolFalse: EXPECT_THAT(result, BoolValueIs(false)); break; case OutputType::kError: EXPECT_THAT(result, ValueKindIs(ValueKind::kError)); break; case OutputType::kUnknown: EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown)); break; } } INSTANTIATE_TEST_SUITE_P(EqualsTest, EqualsTest, testing::Values( EqualsTestCase{ InputType::kInt1, InputType::kInt2, false, OutputType::kBoolFalse, }, EqualsTestCase{ InputType::kInt1, InputType::kInt1, false, OutputType::kBoolTrue, }, EqualsTestCase{ InputType::kInt1, InputType::kList, false, OutputType::kBoolFalse, }, EqualsTestCase{ InputType::kInt1, InputType::kDouble1, false, OutputType::kBoolTrue, }, EqualsTestCase{ InputType::kInt2, InputType::kDouble1, false, OutputType::kBoolFalse, }, EqualsTestCase{ InputType::kInt1, InputType::kError, false, OutputType::kError, }, EqualsTestCase{ InputType::kError, 
InputType::kInt1, false, OutputType::kError, }, EqualsTestCase{ InputType::kInt1, InputType::kUnknown, false, OutputType::kUnknown, }, EqualsTestCase{ InputType::kUnknown, InputType::kInt1, false, OutputType::kUnknown, }, EqualsTestCase{ InputType::kError, InputType::kUnknown, false, OutputType::kError, }, EqualsTestCase{ InputType::kUnknown, InputType::kError, false, OutputType::kError, }, // != EqualsTestCase{ InputType::kInt1, InputType::kInt2, true, OutputType::kBoolTrue, }, EqualsTestCase{ InputType::kError, InputType::kInt1, true, OutputType::kError, }, EqualsTestCase{ InputType::kUnknown, InputType::kInt1, true, OutputType::kUnknown, }, EqualsTestCase{ InputType::kInt1, InputType::kDouble1, true, OutputType::kBoolFalse, })); struct InTestCase { InputType lhs; InputType rhs; OutputType expected_result; }; class InTest : public ::testing::TestWithParam {}; TEST_P(InTest, Recursive) { const InTestCase& test_case = GetParam(); cel::Activation activation; google::protobuf::Arena arena; cel::RuntimeOptions opts; opts.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); auto plan = CreateDirectInStep( std::make_unique(MakeValue(test_case.lhs, &arena)), std::make_unique(MakeValue(test_case.rhs, &arena)), -1); ExecutionFrameBase frame(activation, opts, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); cel::Value result; AttributeTrail attribute_trail; ASSERT_THAT(plan->Evaluate(frame, result, attribute_trail), IsOk()); switch (test_case.expected_result) { case OutputType::kBoolTrue: EXPECT_THAT(result, BoolValueIs(true)); break; case OutputType::kBoolFalse: EXPECT_THAT(result, BoolValueIs(false)); break; case OutputType::kError: EXPECT_THAT(result, ValueKindIs(ValueKind::kError)); break; case OutputType::kUnknown: EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown)); break; } } TEST_P(InTest, 
Iterative) { const InTestCase& test_case = GetParam(); cel::Activation activation; google::protobuf::Arena arena; cel::RuntimeOptions opts; opts.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); FlatExpressionEvaluatorState state( /*value_stack_size=*/5, /*comprehension_slot_count=*/0, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); std::vector> steps; steps.push_back( std::make_unique(MakeValue(test_case.lhs, &arena))); steps.push_back( std::make_unique(MakeValue(test_case.rhs, &arena))); steps.push_back(CreateInStep(-1)); ExecutionFrame frame(steps, activation, opts, state); ASSERT_OK_AND_ASSIGN(Value result, frame.Evaluate()); switch (test_case.expected_result) { case OutputType::kBoolTrue: EXPECT_THAT(result, BoolValueIs(true)); break; case OutputType::kBoolFalse: EXPECT_THAT(result, BoolValueIs(false)); break; case OutputType::kError: EXPECT_THAT(result, ValueKindIs(ValueKind::kError)); break; case OutputType::kUnknown: EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown)); break; } } INSTANTIATE_TEST_SUITE_P(InTest, InTest, testing::Values( InTestCase{ InputType::kInt1, InputType::kInt2, OutputType::kError, }, InTestCase{ InputType::kInt1, InputType::kList, OutputType::kBoolTrue, }, InTestCase{ InputType::kInt1, InputType::kMap, OutputType::kBoolTrue, }, InTestCase{ InputType::kDouble1, InputType::kList, OutputType::kBoolTrue, }, InTestCase{ InputType::kInt2, InputType::kList, OutputType::kBoolFalse, }, InTestCase{ InputType::kDouble1, InputType::kMap, OutputType::kBoolTrue, }, InTestCase{ InputType::kInt2, InputType::kMap, OutputType::kBoolFalse, }, InTestCase{ InputType::kList, InputType::kMap, OutputType::kError, }, InTestCase{ InputType::kList, InputType::kList, OutputType::kBoolFalse, }, InTestCase{ InputType::kError, InputType::kList, OutputType::kError, }, InTestCase{ 
InputType::kInt1, InputType::kError, OutputType::kError, }, InTestCase{ InputType::kUnknown, InputType::kList, OutputType::kUnknown, }, InTestCase{ InputType::kInt1, InputType::kUnknown, OutputType::kUnknown, }, InTestCase{ InputType::kUnknown, InputType::kError, OutputType::kError, })); } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/evaluator_core.cc ================================================ // Copyright 2017 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/eval/evaluator_core.h"

#include <cstddef>
#include <memory>
#include <new>
#include <utility>

#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "common/value.h"
#include "runtime/activation_interface.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {

// Clears the value stack, iterator stack and comprehension slots so the
// state can be reused for another evaluation.
void FlatExpressionEvaluatorState::Reset() {
  value_stack_.Clear();
  iterator_stack_.Clear();
  comprehension_slots_.Reset();
}

// Returns the next step to execute, or nullptr when evaluation is finished.
//
// When the current path is exhausted and a subframe is on the call stack,
// control returns to the caller's path: the value at the top of the stack
// (with its attribute) is copied into the subframe's comprehension slot (it is
// not popped), the subframe is dropped, and the loop continues on the
// caller's path. Stepping past the end of a path logs an error and ends
// evaluation.
const ExpressionStep* ExecutionFrame::Next() {
  while (true) {
    const size_t end_pos = execution_path_.size();

    if (ABSL_PREDICT_TRUE(pc_ < end_pos)) {
      const auto* step = execution_path_[pc_++].get();
      ABSL_ASSUME(step != nullptr);
      return step;
    }
    if (ABSL_PREDICT_TRUE(pc_ == end_pos)) {
      if (!call_stack_.empty()) {
        SubFrame& subframe = call_stack_.back();
        pc_ = subframe.return_pc;
        execution_path_ = subframe.return_expression;
        ABSL_DCHECK_EQ(value_stack().size(), subframe.expected_stack_size);
        comprehension_slots().Set(subframe.slot_index, value_stack().Peek(),
                                  value_stack().PeekAttribute());
        call_stack_.pop_back();
        continue;
      }
    } else {
      ABSL_LOG(ERROR) << "Attempting to step beyond the end of execution path.";
    }
    return nullptr;
  }
}

namespace {

// This class abuses the fact that `absl::Status` is trivially destructible when
// `absl::Status::ok()` is `true`. If the implementation of `absl::Status` ever
// changes, LSan and ASan should catch it. We cannot deal with the cost of extra
// move assignment and destructor calls.
//
// This is useful only in the evaluation loop and is a direct replacement for
// `RETURN_IF_ERROR`. It yields the most improvements on benchmarks with lots of
// steps which never return non-OK `absl::Status`.
class EvaluationStatus final { public: explicit EvaluationStatus(absl::Status&& status) { ::new (static_cast(&status_[0])) absl::Status(std::move(status)); } EvaluationStatus() = delete; EvaluationStatus(const EvaluationStatus&) = delete; EvaluationStatus(EvaluationStatus&&) = delete; EvaluationStatus& operator=(const EvaluationStatus&) = delete; EvaluationStatus& operator=(EvaluationStatus&&) = delete; absl::Status Consume() && { return std::move(*reinterpret_cast(&status_[0])); } bool ok() const { return ABSL_PREDICT_TRUE( reinterpret_cast(&status_[0])->ok()); } private: alignas(absl::Status) char status_[sizeof(absl::Status)]; }; } // namespace absl::StatusOr ExecutionFrame::Evaluate( EvaluationListener& listener) { const size_t initial_stack_size = value_stack().size(); if (!listener) { for (const ExpressionStep* expr = Next(); ABSL_PREDICT_TRUE(expr != nullptr); expr = Next()) { if (EvaluationStatus status(expr->Evaluate(this)); !status.ok()) { return std::move(status).Consume(); } } } else { for (const ExpressionStep* expr = Next(); ABSL_PREDICT_TRUE(expr != nullptr); expr = Next()) { if (EvaluationStatus status(expr->Evaluate(this)); !status.ok()) { return std::move(status).Consume(); } if (pc_ == 0 || !expr->comes_from_ast()) { // Skip if we just started a Call or if the step doesn't map to an // AST id. continue; } if (ABSL_PREDICT_FALSE(value_stack().empty())) { ABSL_LOG(ERROR) << "Stack is empty after a ExpressionStep.Evaluate. 
" "Try to disable short-circuiting."; continue; } if (EvaluationStatus status(listener(expr->id(), value_stack().Peek(), descriptor_pool(), message_factory(), arena())); !status.ok()) { return std::move(status).Consume(); } } } const size_t final_stack_size = value_stack().size(); if (ABSL_PREDICT_FALSE(final_stack_size != initial_stack_size + 1 || final_stack_size == 0)) { return absl::InternalError(absl::StrCat( "Stack error during evaluation: expected=", initial_stack_size + 1, ", actual=", final_stack_size)); } cel::Value value = std::move(value_stack().Peek()); value_stack().Pop(1); return value; } FlatExpressionEvaluatorState FlatExpression::MakeEvaluatorState( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { return FlatExpressionEvaluatorState(path_.size(), comprehension_slots_size_, type_provider_, descriptor_pool, message_factory, arena); } absl::StatusOr FlatExpression::EvaluateWithCallback( const cel::ActivationInterface& activation, const cel::EmbedderContext* absl_nullable embedder_context, EvaluationListener listener, FlatExpressionEvaluatorState& state) const { state.Reset(); ExecutionFrame frame(subexpressions_, activation, options_, state, std::move(listener), embedder_context); return frame.Evaluate(frame.callback()); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/evaluator_core.h ================================================ // Copyright 2017 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_CORE_H_
#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_CORE_H_

#include
#include
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "base/type_provider.h"
#include "common/native_type.h"
#include "common/value.h"
#include "eval/eval/attribute_utility.h"
#include "eval/eval/comprehension_slots.h"
#include "eval/eval/evaluator_stack.h"
#include "eval/eval/iterator_stack.h"
#include "runtime/activation_interface.h"
#include "runtime/internal/activation_attribute_matcher_access.h"
#include "runtime/runtime.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel {

class EmbedderContext;

}  // namespace cel

namespace google::api::expr::runtime {

// Forward declaration of ExecutionFrame, to resolve circular dependency.
class ExecutionFrame;

using EvaluationListener = cel::TraceableProgram::EvaluationListener;

// ExpressionStep represents a single execution step of a flattened program.
class ExpressionStep {
 public:
  // `id` is the expression id reported by `id()`; `comes_from_ast` records
  // whether this step maps to an AST node (see `comes_from_ast()`).
  explicit ExpressionStep(int64_t id, bool comes_from_ast = true)
      : id_(id), comes_from_ast_(comes_from_ast) {}

  ExpressionStep(const ExpressionStep&) = delete;
  ExpressionStep& operator=(const ExpressionStep&) = delete;

  virtual ~ExpressionStep() = default;

  // Performs actual evaluation.
  // Values are passed between Expression objects via EvaluatorStack, which is
  // supplied with context.
  // Also, Expression gets values supplied by caller through Activation
  // interface.
  // ExpressionStep instances can in specific cases
  // modify execution order(perform jumps).
  virtual absl::Status Evaluate(ExecutionFrame* context) const = 0;

  // Returns corresponding expression object ID.
  // Requires that the input expression has IDs assigned to sub-expressions,
  // e.g. via a checker. The default value 0 is returned if there is no
  // expression associated (e.g. a jump step), or if there is no ID assigned to
  // the corresponding expression. Useful for error scenarios where information
  // from Expr object is needed to create CelError.
  int64_t id() const { return id_; }

  // Returns if the execution step comes from AST. ExecutionFrame::Evaluate
  // does not report steps for which this is false to the evaluation listener.
  bool comes_from_ast() const { return comes_from_ast_; }

  // Return the type of the underlying expression step for special handling in
  // the planning phase. This should only be overridden by special cases, and
  // callers must not make any assumptions about the default case.
  virtual cel::NativeTypeId GetNativeTypeId() const {
    return cel::NativeTypeId();
  }

 private:
  const int64_t id_;
  const bool comes_from_ast_;
};

using ExecutionPath = std::vector>;
using ExecutionPathView =
    absl::Span>;

// Class that wraps the state that needs to be allocated for expression
// evaluation. This can be reused to save on allocations.
class FlatExpressionEvaluatorState {
 public:
  // `value_stack_size` is the initial capacity of the value stack;
  // `comprehension_slot_count` sizes both the iterator stack and the
  // comprehension slots. The type provider, descriptor pool, message factory
  // and arena are stored by reference/pointer and are not freed by this class.
  FlatExpressionEvaluatorState(
      size_t value_stack_size, size_t comprehension_slot_count,
      const cel::TypeProvider& type_provider,
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena)
      : value_stack_(value_stack_size),
        // We currently use comprehension_slot_count because it is less of an
        // over estimate than value_stack_size. In future we should just
        // calculate the correct capacity.
        iterator_stack_(comprehension_slot_count),
        comprehension_slots_(comprehension_slot_count),
        type_provider_(type_provider),
        descriptor_pool_(descriptor_pool),
        message_factory_(message_factory),
        arena_(arena) {}

  // Prepares the state for another evaluation; FlatExpression calls this at
  // the start of EvaluateWithCallback. Defined out of line.
  void Reset();

  EvaluatorStack& value_stack() { return value_stack_; }

  cel::runtime_internal::IteratorStack& iterator_stack() {
    return iterator_stack_;
  }

  ComprehensionSlots& comprehension_slots() { return comprehension_slots_; }

  const cel::TypeProvider& type_provider() { return type_provider_; }

  const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() {
    return descriptor_pool_;
  }

  google::protobuf::MessageFactory* absl_nonnull message_factory() {
    return message_factory_;
  }

  google::protobuf::Arena* absl_nonnull arena() { return arena_; }

 private:
  EvaluatorStack value_stack_;
  cel::runtime_internal::IteratorStack iterator_stack_;
  ComprehensionSlots comprehension_slots_;
  const cel::TypeProvider& type_provider_;
  const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_;
  google::protobuf::MessageFactory* absl_nonnull message_factory_;
  google::protobuf::Arena* absl_nonnull arena_;
};

// Context needed for evaluation. This is sufficient for supporting
// recursive evaluation, but stack machine programs require an
// ExecutionFrame instance for managing a heap-backed stack.
class ExecutionFrameBase {
 public:
  // Overload for test usages.
  // Uses no evaluation listener, no embedder context, and the shared empty
  // ComprehensionSlots instance.
  ExecutionFrameBase(const cel::ActivationInterface& activation,
                     const cel::RuntimeOptions& options,
                     const cel::TypeProvider& type_provider,
                     const google::protobuf::DescriptorPool* absl_nonnull
                         descriptor_pool,
                     google::protobuf::MessageFactory* absl_nonnull
                         message_factory,
                     google::protobuf::Arena* absl_nonnull arena)
      : activation_(&activation),
        callback_(),
        options_(&options),
        type_provider_(type_provider),
        descriptor_pool_(descriptor_pool),
        message_factory_(message_factory),
        arena_(arena),
        embedder_context_(nullptr),
        attribute_utility_(activation.GetUnknownAttributes(),
                           activation.GetMissingAttributes()),
        slots_(&ComprehensionSlots::GetEmptyInstance()),
        max_iterations_(options.comprehension_max_iterations),
        iterations_(0) {
    // Install the activation's attribute matcher, if it has one, only when
    // unknown processing is enabled.
    if (unknown_processing_enabled()) {
      if (auto matcher = cel::runtime_internal::
              ActivationAttributeMatcherAccess::GetAttributeMatcher(activation);
          matcher != nullptr) {
        attribute_utility_.set_matcher(matcher);
      }
    }
  }

  // Full constructor. `callback` is moved into the frame; `slots` is borrowed
  // and must outlive the frame.
  ExecutionFrameBase(const cel::ActivationInterface& activation,
                     EvaluationListener callback,
                     const cel::RuntimeOptions& options,
                     const cel::TypeProvider& type_provider,
                     const google::protobuf::DescriptorPool* absl_nonnull
                         descriptor_pool,
                     google::protobuf::MessageFactory* absl_nonnull
                         message_factory,
                     google::protobuf::Arena* absl_nonnull arena,
                     const cel::EmbedderContext* absl_nullable
                         embedder_context,
                     ComprehensionSlots& slots)
      : activation_(&activation),
        callback_(std::move(callback)),
        options_(&options),
        type_provider_(type_provider),
        descriptor_pool_(descriptor_pool),
        message_factory_(message_factory),
        arena_(arena),
        embedder_context_(embedder_context),
        attribute_utility_(activation.GetUnknownAttributes(),
                           activation.GetMissingAttributes()),
        slots_(&slots),
        max_iterations_(options.comprehension_max_iterations),
        iterations_(0) {
    // Install the activation's attribute matcher, if it has one, only when
    // unknown processing is enabled.
    if (unknown_processing_enabled()) {
      if (auto matcher = cel::runtime_internal::
              ActivationAttributeMatcherAccess::GetAttributeMatcher(activation);
          matcher != nullptr) {
        attribute_utility_.set_matcher(matcher);
      }
    }
  }

  const cel::ActivationInterface& activation() const { return *activation_; }

  EvaluationListener& callback() { return callback_; }

  const cel::RuntimeOptions& options() const { return *options_; }

  const cel::TypeProvider& type_provider() { return type_provider_; }

  const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const {
    return descriptor_pool_;
  }

  google::protobuf::MessageFactory* absl_nonnull message_factory() const {
    return message_factory_;
  }

  google::protobuf::Arena* absl_nonnull arena() const { return arena_; }

  const cel::EmbedderContext* absl_nullable embedder_context() const {
    return embedder_context_;
  }

  const AttributeUtility& attribute_utility() const {
    return attribute_utility_;
  }

  // True when either unknown processing or missing-attribute errors are on.
  bool attribute_tracking_enabled() const {
    return options_->unknown_processing !=
               cel::UnknownProcessingOptions::kDisabled ||
           options_->enable_missing_attribute_errors;
  }

  bool missing_attribute_errors_enabled() const {
    return options_->enable_missing_attribute_errors;
  }

  bool unknown_processing_enabled() const {
    return options_->unknown_processing !=
           cel::UnknownProcessingOptions::kDisabled;
  }

  bool unknown_function_results_enabled() const {
    return options_->unknown_processing ==
           cel::UnknownProcessingOptions::kAttributeAndFunction;
  }

  ComprehensionSlots& comprehension_slots() { return *slots_; }

  // Increment iterations and return an error if the iteration budget is
  // exceeded.
  //
  // A budget of 0 disables the check. With a nonzero budget N, the call that
  // brings the counter to N returns the error, so at most N - 1 calls succeed.
  absl::Status IncrementIterations() {
    if (max_iterations_ == 0) {
      return absl::OkStatus();
    }
    iterations_++;
    if (iterations_ >= max_iterations_) {
      return absl::Status(absl::StatusCode::kInternal,
                          "Iteration budget exceeded");
    }
    return absl::OkStatus();
  }

 protected:
  const cel::ActivationInterface* absl_nonnull activation_;
  EvaluationListener callback_;
  const cel::RuntimeOptions* absl_nonnull options_;
  const cel::TypeProvider& type_provider_;
  const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_;
  google::protobuf::MessageFactory* absl_nonnull message_factory_;
  google::protobuf::Arena* absl_nonnull arena_;
  const cel::EmbedderContext* absl_nullable embedder_context_;
  AttributeUtility attribute_utility_;
  ComprehensionSlots* absl_nonnull slots_;
  // Copied from options.comprehension_max_iterations; 0 means no limit.
  const int max_iterations_;
  int iterations_;
};

// ExecutionFrame manages the context needed for expression evaluation.
// The lifecycle of the object is bound to a FlatExpression::Evaluate*(...)
// call.
class ExecutionFrame : public ExecutionFrameBase {
 public:
  // flat is the flattened sequence of execution steps that will be evaluated.
  // activation provides bindings between parameter names and values.
  // state provides the value stack, iterator stack, comprehension slots, type
  // provider, descriptor pool, message factory and arena used for evaluation.
  //
  // This overload has no subexpressions, so `Call()` cannot be used with it.
  ExecutionFrame(
      ExecutionPathView flat, const cel::ActivationInterface& activation,
      const cel::RuntimeOptions& options, FlatExpressionEvaluatorState& state,
      EvaluationListener callback = EvaluationListener(),
      const cel::EmbedderContext* absl_nullable embedder_context = nullptr)
      : ExecutionFrameBase(activation, std::move(callback), options,
                           state.type_provider(), state.descriptor_pool(),
                           state.message_factory(), state.arena(),
                           embedder_context, state.comprehension_slots()),
        pc_(0UL),
        execution_path_(flat),
        value_stack_(&state.value_stack()),
        iterator_stack_(&state.iterator_stack()),
        subexpressions_() {}

  // Same as above, but evaluation starts at `subexpressions[0]` and the other
  // entries are reachable through `Call()`.
  ExecutionFrame(
      absl::Span subexpressions,
      const cel::ActivationInterface& activation,
      const cel::RuntimeOptions& options, FlatExpressionEvaluatorState& state,
      EvaluationListener callback = EvaluationListener(),
      const cel::EmbedderContext* absl_nullable embedder_context = nullptr)
      : ExecutionFrameBase(activation, std::move(callback), options,
                           state.type_provider(), state.descriptor_pool(),
                           state.message_factory(), state.arena(),
                           embedder_context, state.comprehension_slots()),
        pc_(0UL),
        execution_path_(subexpressions[0]),
        value_stack_(&state.value_stack()),
        iterator_stack_(&state.iterator_stack()),
        subexpressions_(subexpressions) {
    // NOTE(review): `subexpressions[0]` is read in the initializer list above,
    // before this check runs, so it cannot guard that access.
    ABSL_DCHECK(!subexpressions.empty());
  }

  // Returns next expression to evaluate.
  const ExpressionStep* Next();

  // Evaluate the execution frame to completion.
  absl::StatusOr Evaluate(EvaluationListener& listener);

  // Evaluate the execution frame to completion, reporting to the listener
  // this frame was constructed with.
  absl::StatusOr Evaluate() { return Evaluate(callback()); }

  // Intended for use in builtin shortcutting operations.
  //
  // Offset applies after normal pc increment. For example, JumpTo(0) is a
  // no-op, JumpTo(1) skips the expected next step.
  // Returns an internal error if the resulting pc would fall outside
  // [0, execution path size].
  absl::Status JumpTo(int offset) {
    ABSL_DCHECK_LE(offset, static_cast(execution_path_.size()));
    ABSL_DCHECK_GE(offset, -static_cast(pc_));
    int new_pc = static_cast(pc_) + offset;
    if (new_pc < 0 || new_pc > static_cast(execution_path_.size())) {
      return absl::Status(absl::StatusCode::kInternal,
                          absl::StrCat("Jump address out of range: position: ",
                                       pc_, ", offset: ", offset,
                                       ", range: ", execution_path_.size()));
    }
    pc_ = static_cast(new_pc);
    return absl::OkStatus();
  }

  // Move pc to a subexpression.
  //
  // Unlike a `Call` in a programming language, the subexpression is evaluated
  // in the same context as the caller (e.g. no stack isolation or scope change)
  //
  // Only intended for use in built-in notion of lazily evaluated
  // subexpressions.
  //
  // Records a SubFrame holding the return pc, `slot_index`, the caller's path,
  // and an expected stack size of one more than the current size, then
  // restarts at pc 0 of the selected subexpression.
  void Call(size_t slot_index, size_t subexpression_index) {
    ABSL_DCHECK_LT(subexpression_index, subexpressions_.size());
    ExecutionPathView subexpression = subexpressions_[subexpression_index];
    ABSL_DCHECK(subexpression != execution_path_);
    size_t return_pc = pc_;
    // return pc == size() is supported (a tail call).
    ABSL_DCHECK_LE(return_pc, execution_path_.size());
    call_stack_.push_back(SubFrame{return_pc, slot_index, execution_path_,
                                   value_stack().size() + 1});
    pc_ = 0UL;
    execution_path_ = subexpression;
  }

  EvaluatorStack& value_stack() { return *value_stack_; }

  cel::runtime_internal::IteratorStack& iterator_stack() {
    return *iterator_stack_;
  }

  bool enable_attribute_tracking() const {
    return attribute_tracking_enabled();
  }

  bool enable_unknowns() const { return unknown_processing_enabled(); }

  bool enable_unknown_function_results() const {
    return unknown_function_results_enabled();
  }

  bool enable_missing_attribute_errors() const {
    return missing_attribute_errors_enabled();
  }

  bool enable_heterogeneous_numeric_lookups() const {
    return options().enable_heterogeneous_equality;
  }

  bool enable_comprehension_list_append() const {
    return options().enable_comprehension_list_append;
  }

  // Returns reference to the modern API activation.
  const cel::ActivationInterface& modern_activation() const {
    return *activation_;
  }

 private:
  // Bookkeeping pushed by Call() for one subexpression invocation.
  struct SubFrame {
    size_t return_pc;
    size_t slot_index;
    ExecutionPathView return_expression;
    size_t expected_stack_size;
  };

  size_t pc_;  // pc_ - Program Counter. Current position on execution path.
  ExecutionPathView execution_path_;
  EvaluatorStack* absl_nonnull const value_stack_;
  cel::runtime_internal::IteratorStack* absl_nonnull const iterator_stack_;
  absl::Span subexpressions_;
  std::vector call_stack_;
};

// A flattened representation of the input CEL AST.
class FlatExpression { public: // path is flat execution path that is based upon the flattened AST tree // type_provider is the configured type system that should be used for // value creation in evaluation FlatExpression(ExecutionPath path, size_t comprehension_slots_size, const cel::TypeProvider& type_provider, const cel::RuntimeOptions& options, absl_nullable std::shared_ptr arena = nullptr) : path_(std::move(path)), subexpressions_({path_}), comprehension_slots_size_(comprehension_slots_size), type_provider_(type_provider), options_(options), arena_(std::move(arena)) {} FlatExpression(ExecutionPath path, std::vector subexpressions, size_t comprehension_slots_size, const cel::TypeProvider& type_provider, const cel::RuntimeOptions& options, absl_nullable std::shared_ptr arena = nullptr) : path_(std::move(path)), subexpressions_(std::move(subexpressions)), comprehension_slots_size_(comprehension_slots_size), type_provider_(type_provider), options_(options), arena_(std::move(arena)) {} // Move-only FlatExpression(FlatExpression&&) = default; FlatExpression& operator=(FlatExpression&&) = delete; // Create new evaluator state instance with the configured options and type // provider. FlatExpressionEvaluatorState MakeEvaluatorState( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; // Evaluate the expression. // // A status may be returned if an unexpected error occurs. Recoverable errors // will be represented as a cel::ErrorValue result. // // If the listener is not empty, it will be called after each evaluation step // that correlates to an AST node. The value passed to the will be the top of // the evaluation stack, corresponding to the result of the subexpression. 
absl::StatusOr EvaluateWithCallback( const cel::ActivationInterface& activation, const cel::EmbedderContext* absl_nullable embedder_context, EvaluationListener listener, FlatExpressionEvaluatorState& state) const; const ExecutionPath& path() const { return path_; } absl::Span subexpressions() const { return subexpressions_; } const cel::RuntimeOptions& options() const { return options_; } size_t comprehension_slots_size() const { return comprehension_slots_size_; } const cel::TypeProvider& type_provider() const { return type_provider_; } private: ExecutionPath path_; std::vector subexpressions_; size_t comprehension_slots_size_; const cel::TypeProvider& type_provider_; cel::RuntimeOptions options_; // Arena used during planning phase, may hold constant values so should be // kept alive. absl_nullable std::shared_ptr arena_; }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_CORE_H_ ================================================ FILE: eval/eval/evaluator_core_test.cc ================================================ #include "eval/eval/evaluator_core.h" #include #include #include #include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "base/type_provider.h" #include "common/value.h" #include "eval/compiler/cel_expression_builder_flat_impl.h" #include "eval/eval/cel_expression_flat_impl.h" #include "eval/internal/interop.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_value.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" #include "runtime/activation.h" #include "runtime/internal/runtime_env_testing.h" #include "runtime/internal/runtime_type_provider.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { using ::cel::IntValue; using ::cel::TypeProvider; using ::cel::interop_internal::CreateIntValue; using 
::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::expr::Expr; using ::google::api::expr::runtime::RegisterBuiltinFunctions; using ::testing::_; using ::testing::Eq; // Fake expression implementation // Pushes int64(0) on top of value stack. class FakeConstExpressionStep : public ExpressionStep { public: FakeConstExpressionStep() : ExpressionStep(0, true) {} absl::Status Evaluate(ExecutionFrame* frame) const override { frame->value_stack().Push(CreateIntValue(0)); return absl::OkStatus(); } }; // Fake expression implementation // Increments argument on top of the stack. class FakeIncrementExpressionStep : public ExpressionStep { public: FakeIncrementExpressionStep() : ExpressionStep(0, true) {} absl::Status Evaluate(ExecutionFrame* frame) const override { auto value = frame->value_stack().Peek(); frame->value_stack().Pop(1); EXPECT_TRUE(value->Is()); int64_t val = value.GetInt().NativeValue(); frame->value_stack().Push(CreateIntValue(val + 1)); return absl::OkStatus(); } }; TEST(EvaluatorCoreTest, ExecutionFrameNext) { ExecutionPath path; google::protobuf::Arena arena; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); auto const_step = std::make_unique(); auto incr_step1 = std::make_unique(); auto incr_step2 = std::make_unique(); path.push_back(std::move(const_step)); path.push_back(std::move(incr_step1)); path.push_back(std::move(incr_step2)); auto dummy_expr = std::make_unique(); cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kDisabled; cel::Activation activation; FlatExpressionEvaluatorState state( path.size(), /*comprehension_slots_size=*/0, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); ExecutionFrame frame(path, activation, options, state); EXPECT_THAT(frame.Next(), Eq(path[0].get())); EXPECT_THAT(frame.Next(), Eq(path[1].get())); EXPECT_THAT(frame.Next(), Eq(path[2].get())); 
EXPECT_THAT(frame.Next(), Eq(nullptr)); } TEST(EvaluatorCoreTest, SimpleEvaluatorTest) { ExecutionPath path; auto const_step = std::make_unique(); auto incr_step1 = std::make_unique(); auto incr_step2 = std::make_unique(); path.push_back(std::move(const_step)); path.push_back(std::move(incr_step1)); path.push_back(std::move(incr_step2)); auto env = NewTestingRuntimeEnv(); CelExpressionFlatImpl impl( env, FlatExpression(std::move(path), 0, env->type_registry.GetComposedTypeProvider(), cel::RuntimeOptions{})); Activation activation; google::protobuf::Arena arena; auto status = impl.Evaluate(activation, &arena); EXPECT_OK(status); auto value = status.value(); EXPECT_TRUE(value.IsInt64()); EXPECT_THAT(value.Int64OrDie(), Eq(2)); } class MockTraceCallback { public: MOCK_METHOD(void, Call, (int64_t expr_id, const CelValue& value, google::protobuf::Arena*)); }; TEST(EvaluatorCoreTest, TraceTest) { Expr expr; cel::expr::SourceInfo source_info; // 1 && [1,2,3].all(x, x > 0) expr.set_id(1); auto and_call = expr.mutable_call_expr(); and_call->set_function("_&&_"); auto true_expr = and_call->add_args(); true_expr->set_id(2); true_expr->mutable_const_expr()->set_int64_value(1); auto comp_expr = and_call->add_args(); comp_expr->set_id(3); auto comp = comp_expr->mutable_comprehension_expr(); comp->set_iter_var("x"); comp->set_accu_var("accu"); auto list_expr = comp->mutable_iter_range(); list_expr->set_id(4); auto el1_expr = list_expr->mutable_list_expr()->add_elements(); el1_expr->set_id(11); el1_expr->mutable_const_expr()->set_int64_value(1); auto el2_expr = list_expr->mutable_list_expr()->add_elements(); el2_expr->set_id(12); el2_expr->mutable_const_expr()->set_int64_value(2); auto el3_expr = list_expr->mutable_list_expr()->add_elements(); el3_expr->set_id(13); el3_expr->mutable_const_expr()->set_int64_value(3); auto accu_init_expr = comp->mutable_accu_init(); accu_init_expr->set_id(20); accu_init_expr->mutable_const_expr()->set_bool_value(true); auto loop_cond_expr = 
comp->mutable_loop_condition(); loop_cond_expr->set_id(21); loop_cond_expr->mutable_const_expr()->set_bool_value(true); auto loop_step_expr = comp->mutable_loop_step(); loop_step_expr->set_id(22); auto condition = loop_step_expr->mutable_call_expr(); condition->set_function("_>_"); auto iter_expr = condition->add_args(); iter_expr->set_id(23); iter_expr->mutable_ident_expr()->set_name("x"); auto zero_expr = condition->add_args(); zero_expr->set_id(24); zero_expr->mutable_const_expr()->set_int64_value(0); auto result_expr = comp->mutable_result(); result_expr->set_id(25); result_expr->mutable_const_expr()->set_bool_value(true); cel::RuntimeOptions options; options.short_circuiting = false; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); Activation activation; google::protobuf::Arena arena; MockTraceCallback callback; EXPECT_CALL(callback, Call(accu_init_expr->id(), _, &arena)); EXPECT_CALL(callback, Call(el1_expr->id(), _, &arena)); EXPECT_CALL(callback, Call(el2_expr->id(), _, &arena)); EXPECT_CALL(callback, Call(el3_expr->id(), _, &arena)); EXPECT_CALL(callback, Call(list_expr->id(), _, &arena)); EXPECT_CALL(callback, Call(loop_cond_expr->id(), _, &arena)).Times(3); EXPECT_CALL(callback, Call(iter_expr->id(), _, &arena)).Times(3); EXPECT_CALL(callback, Call(zero_expr->id(), _, &arena)).Times(3); EXPECT_CALL(callback, Call(loop_step_expr->id(), _, &arena)).Times(3); EXPECT_CALL(callback, Call(result_expr->id(), _, &arena)); EXPECT_CALL(callback, Call(comp_expr->id(), _, &arena)); EXPECT_CALL(callback, Call(true_expr->id(), _, &arena)); EXPECT_CALL(callback, Call(expr.id(), _, &arena)); auto eval_status = cel_expr->Trace( activation, &arena, [&](int64_t expr_id, const CelValue& value, google::protobuf::Arena* arena) { callback.Call(expr_id, value, arena); return absl::OkStatus(); }); 
ASSERT_OK(eval_status); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/evaluator_stack.cc ================================================ #include "eval/eval/evaluator_stack.h" #include #include #include #include #include #include "absl/base/dynamic_annotations.h" #include "absl/base/nullability.h" #include "absl/log/absl_log.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "internal/new.h" namespace google::api::expr::runtime { void EvaluatorStack::Grow() { const size_t new_max_size = std::max(max_size() * 2, size_t{1}); ABSL_LOG(ERROR) << "evaluation stack is unexpectedly full: growing from " << max_size() << " to " << new_max_size << " as a last resort to avoid crashing: this should not " "have happened so there must be a bug somewhere in " "the planner or evaluator"; Reserve(new_max_size); } void EvaluatorStack::Reserve(size_t size) { static_assert(alignof(cel::Value) <= __STDCPP_DEFAULT_NEW_ALIGNMENT__); static_assert(alignof(AttributeTrail) <= __STDCPP_DEFAULT_NEW_ALIGNMENT__); if (max_size_ >= size) { return; } void* absl_nullability_unknown data = cel::internal::New(SizeBytes(size)); cel::Value* absl_nullability_unknown values_begin = reinterpret_cast(data); cel::Value* absl_nullability_unknown values = values_begin; AttributeTrail* absl_nullability_unknown attributes_begin = reinterpret_cast(reinterpret_cast(data) + AttributesBytesOffset(size)); AttributeTrail* absl_nullability_unknown attributes = attributes_begin; if (max_size_ > 0) { const size_t n = this->size(); const size_t m = std::min(n, size); ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(values_begin, values_begin + size, values_begin + size, values + m); ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(attributes_begin, attributes_begin + size, attributes_begin + size, attributes + m); for (size_t i = 0; i < m; ++i) { ::new (static_cast(values++)) cel::Value(std::move(values_begin_[i])); ::new (static_cast(attributes++)) 
AttributeTrail(std::move(attributes_begin_[i])); } std::destroy_n(values_begin_, n); std::destroy_n(attributes_begin_, n); ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(values_begin_, values_begin_ + max_size_, values_, values_begin_ + max_size_); ABSL_ANNOTATE_CONTIGUOUS_CONTAINER( attributes_begin_, attributes_begin_ + max_size_, attributes_, attributes_begin_ + max_size_); cel::internal::SizedDelete(data_, SizeBytes(max_size_)); } else { ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(values_begin, values_begin + size, values_begin + size, values); ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(attributes_begin, attributes_begin + size, attributes_begin + size, attributes); } values_ = values; values_begin_ = values_begin; values_end_ = values_begin + size; attributes_ = attributes; attributes_begin_ = attributes_begin; data_ = data; max_size_ = size; } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/evaluator_stack.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_STACK_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_STACK_H_ #include #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/dynamic_annotations.h" #include "absl/base/nullability.h" #include "absl/base/optimization.h" #include "absl/log/absl_check.h" #include "absl/meta/type_traits.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "internal/align.h" #include "internal/new.h" namespace google::api::expr::runtime { // CelValue stack. // Implementation is based on vector to allow passing parameters from // stack as Span<>. 
class EvaluatorStack { public: explicit EvaluatorStack(size_t max_size) { Reserve(max_size); } EvaluatorStack(const EvaluatorStack&) = delete; EvaluatorStack(EvaluatorStack&&) = delete; ~EvaluatorStack() { if (max_size() > 0) { const size_t n = size(); std::destroy_n(values_begin_, n); std::destroy_n(attributes_begin_, n); cel::internal::SizedDelete(data_, SizeBytes(max_size_)); } } EvaluatorStack& operator=(const EvaluatorStack&) = delete; EvaluatorStack& operator=(EvaluatorStack&&) = delete; // Return the current stack size. size_t size() const { ABSL_DCHECK_GE(values_, values_begin_); ABSL_DCHECK_LE(values_, values_begin_ + max_size_); ABSL_DCHECK_GE(attributes_, attributes_begin_); ABSL_DCHECK_LE(attributes_, attributes_begin_ + max_size_); ABSL_DCHECK_EQ(values_ - values_begin_, attributes_ - attributes_begin_); return values_ - values_begin_; } // Return the maximum size of the stack. size_t max_size() const { ABSL_DCHECK_GE(values_, values_begin_); ABSL_DCHECK_LE(values_, values_begin_ + max_size_); ABSL_DCHECK_GE(attributes_, attributes_begin_); ABSL_DCHECK_LE(attributes_, attributes_begin_ + max_size_); ABSL_DCHECK_EQ(values_ - values_begin_, attributes_ - attributes_begin_); return max_size_; } // Returns true if stack is empty. bool empty() const { ABSL_DCHECK_GE(values_, values_begin_); ABSL_DCHECK_LE(values_, values_begin_ + max_size_); ABSL_DCHECK_GE(attributes_, attributes_begin_); ABSL_DCHECK_LE(attributes_, attributes_begin_ + max_size_); ABSL_DCHECK_EQ(values_ - values_begin_, attributes_ - attributes_begin_); return values_ == values_begin_; } bool full() const { ABSL_DCHECK_GE(values_, values_begin_); ABSL_DCHECK_LE(values_, values_begin_ + max_size_); ABSL_DCHECK_GE(attributes_, attributes_begin_); ABSL_DCHECK_LE(attributes_, attributes_begin_ + max_size_); ABSL_DCHECK_EQ(values_ - values_begin_, attributes_ - attributes_begin_); return values_ == values_end_; } // Attributes stack size. 
ABSL_DEPRECATED("Use size()") size_t attribute_size() const { return size(); } // Check that stack has enough elements. bool HasEnough(size_t size) const { return this->size() >= size; } // Dumps the entire stack state as is. void Clear() { if (max_size() > 0) { const size_t n = size(); std::destroy_n(values_begin_, n); std::destroy_n(attributes_begin_, n); ABSL_ANNOTATE_CONTIGUOUS_CONTAINER( values_begin_, values_begin_ + max_size_, values_, values_begin_); ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(attributes_begin_, attributes_begin_ + max_size_, attributes_, attributes_begin_); values_ = values_begin_; attributes_ = attributes_begin_; } } // Gets the last size elements of the stack. // Checking that stack has enough elements is caller's responsibility. // Please note that calls to Push may invalidate returned Span object. absl::Span GetSpan(size_t size) const { ABSL_DCHECK(HasEnough(size)); return absl::Span(values_ - size, size); } // Gets the last size attribute trails of the stack. // Checking that stack has enough elements is caller's responsibility. // Please note that calls to Push may invalidate returned Span object. absl::Span GetAttributeSpan(size_t size) const { ABSL_DCHECK(HasEnough(size)); return absl::Span(attributes_ - size, size); } // Peeks the last element of the stack. // Checking that stack is not empty is caller's responsibility. cel::Value& Peek() { ABSL_DCHECK(HasEnough(1)); return *(values_ - 1); } // Peeks the last element of the stack. // Checking that stack is not empty is caller's responsibility. const cel::Value& Peek() const { ABSL_DCHECK(HasEnough(1)); return *(values_ - 1); } // Peeks the last element of the attribute stack. // Checking that stack is not empty is caller's responsibility. const AttributeTrail& PeekAttribute() const { ABSL_DCHECK(HasEnough(1)); return *(attributes_ - 1); } // Peeks the last element of the attribute stack. // Checking that stack is not empty is caller's responsibility. 
AttributeTrail& PeekAttribute() { ABSL_DCHECK(HasEnough(1)); return *(attributes_ - 1); } void Pop() { ABSL_DCHECK(!empty()); --values_; values_->~Value(); --attributes_; attributes_->~AttributeTrail(); ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(values_begin_, values_begin_ + max_size_, values_ + 1, values_); ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(attributes_begin_, attributes_begin_ + max_size_, attributes_ + 1, attributes_); } // Clears the last size elements of the stack. // Checking that stack has enough elements is caller's responsibility. void Pop(size_t size) { ABSL_DCHECK(HasEnough(size)); for (; size > 0; --size) { Pop(); } } template , std::is_convertible>>> void Push(V&& value, A&& attribute) { ABSL_DCHECK(!full()); if (ABSL_PREDICT_FALSE(full())) { Grow(); } ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(values_begin_, values_begin_ + max_size_, values_, values_ + 1); ABSL_ANNOTATE_CONTIGUOUS_CONTAINER(attributes_begin_, attributes_begin_ + max_size_, attributes_, attributes_ + 1); ::new (static_cast(values_++)) cel::Value(std::forward(value)); ::new (static_cast(attributes_++)) AttributeTrail(std::forward(attribute)); } template >> void Push(V&& value) { ABSL_DCHECK(!full()); Push(std::forward(value), absl::nullopt); } // Equivalent to `PopAndPush(1, ...)`. template , std::is_convertible>>> void PopAndPush(V&& value, A&& attribute) { ABSL_DCHECK(!empty()); *(values_ - 1) = std::forward(value); *(attributes_ - 1) = std::forward(attribute); } // Equivalent to `PopAndPush(1, ...)`. template >> void PopAndPush(V&& value) { ABSL_DCHECK(!empty()); PopAndPush(std::forward(value), absl::nullopt); } // Equivalent to `Pop(n)` followed by `Push(...)`. Both `V` and `A` MUST NOT // be located on the stack. If this is the case, use SwapAndPop instead. 
template , std::is_convertible>>> void PopAndPush(size_t n, V&& value, A&& attribute) { if (n > 0) { if constexpr (std::is_same_v>) { ABSL_DCHECK(&value < values_begin_ || &value >= values_begin_ + max_size_) << "Attmpting to push a value about to be popped, use PopAndSwap " "instead."; } if constexpr (std::is_same_v>) { ABSL_DCHECK(&attribute < attributes_begin_ || &attribute >= attributes_begin_ + max_size_) << "Attmpting to push an attribute about to be popped, use " "PopAndSwap instead."; } Pop(n - 1); ABSL_DCHECK(!empty()); *(values_ - 1) = std::forward(value); *(attributes_ - 1) = std::forward(attribute); } else { Push(std::forward(value), std::forward(attribute)); } } // Equivalent to `Pop(n)` followed by `Push(...)`. `V` MUST NOT be located on // the stack. If this is the case, use SwapAndPop instead. template >> void PopAndPush(size_t n, V&& value) { PopAndPush(n, std::forward(value), absl::nullopt); } // Swaps the `n - i` element (from the top of the stack) with the `n` element, // and pops `n - 1` elements. This results in the `n - i` element being at the // top of the stack. void SwapAndPop(size_t n, size_t i) { ABSL_DCHECK_GT(n, 0); ABSL_DCHECK_LT(i, n); ABSL_DCHECK(HasEnough(n - 1)); using std::swap; if (i > 0) { swap(*(values_ - n), *(values_ - n + i)); swap(*(attributes_ - n), *(attributes_ - n + i)); } Pop(n - 1); } // Update the max size of the stack and update capacity if needed. void SetMaxSize(size_t size) { Reserve(size); } private: static size_t AttributesBytesOffset(size_t size) { return cel::internal::AlignUp(sizeof(cel::Value) * size, __STDCPP_DEFAULT_NEW_ALIGNMENT__); } static size_t SizeBytes(size_t size) { return AttributesBytesOffset(size) + (sizeof(AttributeTrail) * size); } void Grow(); // Preallocate stack. 
// NOTE(review): definition not visible here; SetMaxSize() forwards to it.
void Reserve(size_t size);

// One past the top value: Push() constructs at `values_` and Peek() reads
// `values_ - 1`.
cel::Value* absl_nullability_unknown values_ = nullptr;
// First slot of the value array.
cel::Value* absl_nullability_unknown values_begin_ = nullptr;
// One past the top attribute trail; moves in lockstep with `values_`.
AttributeTrail* absl_nullability_unknown attributes_ = nullptr;
// First slot of the attribute-trail array.
AttributeTrail* absl_nullability_unknown attributes_begin_ = nullptr;
// NOTE(review): not referenced in the visible code; presumably
// values_begin_ + max_size_ -- confirm in evaluator_stack.cc.
cel::Value* absl_nullability_unknown values_end_ = nullptr;
// Backing allocation; presumably sized by SizeBytes() with the attribute
// trails at AttributesBytesOffset() -- confirm in evaluator_stack.cc.
void* absl_nullability_unknown data_ = nullptr;
// Capacity in elements; the upper bound used by the sanitizer annotations.
size_t max_size_ = 0;
};

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_STACK_H_
================================================
FILE: eval/eval/evaluator_stack_test.cc
================================================
#include "eval/eval/evaluator_stack.h"

#include "base/attribute.h"
#include "common/value.h"
#include "internal/testing.h"

namespace google::api::expr::runtime {

namespace {

// Test Value Stack Push/Pop operation
TEST(EvaluatorStackTest, StackPushPop) {
  cel::Attribute attribute("name", {});
  EvaluatorStack stack(10);
  // The single-argument Push pairs the value with an empty attribute trail.
  stack.Push(cel::IntValue(1));
  stack.Push(cel::IntValue(2), AttributeTrail());
  stack.Push(cel::IntValue(3), AttributeTrail("name"));

  // Top of stack is the last pushed value and its non-empty trail.
  ASSERT_EQ(stack.Peek().GetInt().NativeValue(), 3);
  ASSERT_FALSE(stack.PeekAttribute().empty());
  ASSERT_EQ(stack.PeekAttribute().attribute(), attribute);

  stack.Pop(1);

  ASSERT_EQ(stack.Peek().GetInt().NativeValue(), 2);
  ASSERT_TRUE(stack.PeekAttribute().empty());

  stack.Pop(1);

  ASSERT_EQ(stack.Peek().GetInt().NativeValue(), 1);
  ASSERT_TRUE(stack.PeekAttribute().empty());
}

// Test that inner stacks within value stack retain the equality of their sizes.
TEST(EvaluatorStackTest, StackBalanced) {
  EvaluatorStack stack(10);
  ASSERT_EQ(stack.size(), stack.attribute_size());
  stack.Push(cel::IntValue(1));
  ASSERT_EQ(stack.size(), stack.attribute_size());
  stack.Push(cel::IntValue(2), AttributeTrail());
  stack.Push(cel::IntValue(3), AttributeTrail());
  ASSERT_EQ(stack.size(), stack.attribute_size());

  // PopAndPush without a count replaces the top element, so size stays 3.
  stack.PopAndPush(cel::IntValue(4), AttributeTrail());
  ASSERT_EQ(stack.size(), stack.attribute_size());
  stack.PopAndPush(cel::IntValue(5));
  ASSERT_EQ(stack.size(), stack.attribute_size());

  stack.Pop(3);
  ASSERT_EQ(stack.size(), stack.attribute_size());
}

// Clear() empties the stack regardless of how many elements it held.
TEST(EvaluatorStackTest, Clear) {
  EvaluatorStack stack(10);
  ASSERT_EQ(stack.size(), stack.attribute_size());
  stack.Push(cel::IntValue(1));
  stack.Push(cel::IntValue(2), AttributeTrail());
  stack.Push(cel::IntValue(3), AttributeTrail());
  ASSERT_EQ(stack.size(), 3);

  stack.Clear();
  ASSERT_EQ(stack.size(), 0);
  ASSERT_TRUE(stack.empty());
}

}  // namespace

}  // namespace google::api::expr::runtime
================================================
FILE: eval/eval/expression_step_base.h
================================================
#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_STEP_BASE_H_
#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_STEP_BASE_H_

#include "eval/eval/evaluator_core.h"

namespace google::api::expr::runtime {

// Alias kept so step implementations can name their base as
// ExpressionStepBase.
using ExpressionStepBase = ExpressionStep;

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_STEP_BASE_H_
================================================
FILE: eval/eval/function_step.cc
================================================
#include "eval/eval/function_step.h"

// NOTE(review): the angle-bracket targets of the following #include lines
// are missing from this copy of the file.
#include
#include
#include
#include
#include
#include

#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "common/casting.h"
#include "common/expr.h"
#include "common/function_descriptor.h"
#include "common/kind.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/expression_step_base.h"
#include "eval/internal/errors.h"
#include "internal/status_macros.h"
#include "runtime/activation_interface.h"
#include "runtime/function.h"
#include "runtime/function_overload_reference.h"
#include "runtime/function_provider.h"
#include "runtime/function_registry.h"
#include "runtime/internal/errors.h"

namespace google::api::expr::runtime {

namespace {

using ::cel::ErrorValue;
using ::cel::UnknownValue;
using ::cel::Value;
using ::cel::ValueKindToKind;
using ::cel::runtime_internal::CreateNoMatchingOverloadError;

// Determine if the overload should be considered. Overloads that can consume
// errors or unknown sets must be allowed as a non-strict function.
bool ShouldAcceptOverload(const cel::FunctionDescriptor& descriptor,
                          absl::Span arguments) {
  for (size_t i = 0; i < arguments.size(); i++) {
    // The first error/unknown argument decides: only a non-strict overload
    // may receive it. (The exact types tested are elided in this copy.)
    if (arguments[i]->Is() ||
        arguments[i]->Is()) {
      return !descriptor.is_strict();
    }
  }
  return true;
}

// Returns true if `arguments` has exactly one element per parameter kind in
// `descriptor` and every argument's kind equals the corresponding parameter
// kind; a parameter of kind kAny matches any argument.
bool ArgumentKindsMatch(const cel::FunctionDescriptor& descriptor,
                        absl::Span arguments) {
  auto types_size = descriptor.types().size();

  if (types_size != arguments.size()) {
    return false;
  }

  for (size_t i = 0; i < types_size; i++) {
    const auto& arg = arguments[i];
    cel::Kind param_kind = descriptor.types()[i];
    if (arg->kind() != param_kind && param_kind != cel::Kind::kAny) {
      return false;
    }
  }

  return true;
}

// Adjust new type names to legacy equivalent. int -> int64.
// Temporary fix to migrate value types without breaking clients.
// TODO(uncreated-issue/46): Update client tests that depend on this value.
std::string ToLegacyKindName(absl::string_view type_name) {
  // Only "int" and "uint" are renamed; every other name passes through.
  if (type_name == "int" || type_name == "uint") {
    return absl::StrCat(type_name, "64");
  }

  return std::string(type_name);
}

// Formats the kinds of `args` as a parenthesized, comma-separated list of
// legacy kind names, e.g. "(int64, uint64)".
std::string CallArgTypeString(absl::Span args) {
  std::string call_sig_string = "";

  for (size_t i = 0; i < args.size(); i++) {
    const auto& arg = args[i];
    if (!call_sig_string.empty()) {
      absl::StrAppend(&call_sig_string, ", ");
    }
    absl::StrAppend(
        &call_sig_string,
        ToLegacyKindName(cel::KindToString(ValueKindToKind(arg->kind()))));
  }
  return absl::StrCat("(", call_sig_string, ")");
}

// Convert partially unknown arguments to unknowns before passing to the
// function.
// TODO(issues/52): See if this can be refactored to remove the eager
// arguments copy.
// Argument and attribute spans are expected to be equal length.
// Returns a copy of `args` in which every argument whose attribute trail is
// reported unknown (partial matches allowed) is replaced by an unknown set
// built from that trail's attribute.
std::vector CheckForPartialUnknowns(
    ExecutionFrame* frame, absl::Span args,
    absl::Span attrs) {
  std::vector result;
  result.reserve(args.size());
  for (size_t i = 0; i < args.size(); i++) {
    // Equivalent to attrs[i].
    const AttributeTrail& trail = attrs.subspan(i, 1)[0];

    if (frame->attribute_utility().CheckForUnknown(trail,
                                                   /*use_partial=*/true)) {
      result.push_back(
          frame->attribute_utility().CreateUnknownSet(trail.attribute()));
    } else {
      result.push_back(args.at(i));
    }
  }

  return result;
}

// Returns true iff `result` is an error value whose status code is
// kUnavailable and whose kPayloadUrlUnknownFunctionResult payload is "true".
bool IsUnknownFunctionResultError(const Value& result) {
  if (!result->Is()) {
    return false;
  }

  const auto& status = result.GetError().NativeValue();

  if (status.code() != absl::StatusCode::kUnavailable) {
    return false;
  }
  auto payload = status.GetPayload(
      cel::runtime_internal::kPayloadUrlUnknownFunctionResult);
  return payload.has_value() && payload.value() == "true";
}

// Simple wrapper around a function resolution result. A function call should
// resolve to a single function implementation and a descriptor or none.
using ResolveResult = absl::optional;

// Implementation of ExpressionStep that finds suitable CelFunction overload and
// invokes it. Abstract base class standardizes behavior between lazy and eager
// function bindings. Derived classes provide ResolveFunction behavior.
class AbstractFunctionStep : public ExpressionStepBase {
 public:
  // Constructs FunctionStep that uses overloads specified.
  // `num_arguments` includes the receiver for receiver-style calls.
  AbstractFunctionStep(const std::string& name, size_t num_arguments,
                       bool receiver_style, int64_t expr_id)
      : ExpressionStepBase(expr_id),
        name_(name),
        num_arguments_(num_arguments),
        receiver_style_(receiver_style) {}

  absl::Status Evaluate(ExecutionFrame* frame) const override;

  // Handles overload resolution and updating result appropriately.
  // Shouldn't update frame state.
  //
  // A non-ok result is an unrecoverable error, either from an illegal
  // evaluation state or forwarded from an extension function. Errors where
  // evaluation can reasonably condition are returned in the result as a
  // cel::ErrorValue.
  absl::StatusOr DoEvaluate(ExecutionFrame* frame) const;

  // Picks at most one overload for `args`; a non-ok status signals an
  // unrecoverable resolution failure (e.g. ambiguous overloads).
  virtual absl::StatusOr ResolveFunction(
      absl::Span args, const ExecutionFrame* frame) const = 0;

 protected:
  std::string name_;
  size_t num_arguments_;
  bool receiver_style_;
};

// Invokes the resolved `overload` with `args`. When unknown function results
// are enabled and the implementation returns the special unknown-function-
// result error, that error is replaced by an unknown set built from the
// overload descriptor, `expr_id`, and `args`.
inline absl::StatusOr Invoke(
    const cel::FunctionOverloadReference& overload, int64_t expr_id,
    absl::Span args, ExecutionFrameBase& frame) {
  cel::Function::InvokeContext context(frame.descriptor_pool(),
                                       frame.message_factory(), frame.arena());
  // Only contextual overloads are given the embedder context.
  if (overload.descriptor.is_contextual()) {
    context.set_embedder_context(frame.embedder_context());
  }
  CEL_ASSIGN_OR_RETURN(Value result,
                       overload.implementation.Invoke(args, context));

  if (frame.unknown_function_results_enabled() &&
      IsUnknownFunctionResultError(result)) {
    return frame.attribute_utility().CreateUnknownSet(overload.descriptor,
                                                      expr_id, args);
  }
  return result;
}

// Computes the call result when no acceptable overload exists, in priority
// order: the first error argument; else (with unknown processing enabled)
// the merged unknowns; else a "no matching overload" error naming the call
// signature (receiver kind first for receiver-style calls).
Value NoOverloadResult(absl::string_view name, absl::Span args,
                       bool receiver_style, ExecutionFrameBase& frame) {
  // No matching overloads.
  // Such absence can be caused by presence of CelError in arguments.
  // To enable behavior of functions that accept CelError( &&, || ), CelErrors
  // should be propagated along execution path.
  for (size_t i = 0; i < args.size(); i++) {
    const auto& arg = args[i];
    if (cel::InstanceOf(arg)) {
      return arg;
    }
  }

  if (frame.unknown_processing_enabled()) {
    // Already converted partial unknowns to unknown sets so just merge.
    absl::optional unknown_set =
        frame.attribute_utility().MergeUnknowns(args);
    if (unknown_set.has_value()) {
      return *unknown_set;
    }
  }

  // If no errors or unknowns in input args, create new CelError for missing
  // overload.
  // NOTE(review): `signature` is never used below.
  std::string signature;
  if (receiver_style) {
    if (args.empty()) {
      // Should not be possible, but return a sensible error in case of logic
      // error.
      return ErrorValue(
          CreateNoMatchingOverloadError(absl::StrCat("().", name, "()")));
    }
    return ErrorValue(CreateNoMatchingOverloadError(absl::StrCat(
        "(",
        ToLegacyKindName(cel::KindToString(ValueKindToKind(args[0].kind()))),
        ").", name, CallArgTypeString(args.subspan(1)))));
  }
  return cel::ErrorValue(CreateNoMatchingOverloadError(
      absl::StrCat(name, CallArgTypeString(args))));
}

absl::StatusOr AbstractFunctionStep::DoEvaluate(
    ExecutionFrame* frame) const {
  // Create Span object that contains input arguments to the function.
  auto input_args = frame->value_stack().GetSpan(num_arguments_);

  // Owns the converted arguments so `input_args` can point at them below.
  std::vector unknowns_args;
  // Preprocess args. If an argument is partially unknown, convert it to an
  // unknown attribute set.
  if (frame->enable_unknowns()) {
    auto input_attrs = frame->value_stack().GetAttributeSpan(num_arguments_);
    unknowns_args = CheckForPartialUnknowns(frame, input_args, input_attrs);
    input_args = absl::MakeConstSpan(unknowns_args);
  }

  // Derived class resolves to a single function overload or none.
  CEL_ASSIGN_OR_RETURN(ResolveResult matched_function,
                       ResolveFunction(input_args, frame));

  // Overload found and is allowed to consume the arguments.
  if (matched_function.has_value() &&
      ShouldAcceptOverload(matched_function->descriptor, input_args)) {
    return Invoke(*matched_function, id(), input_args, *frame);
  }

  return NoOverloadResult(name_, input_args, receiver_style_, *frame);
}

// Replaces the `num_arguments_` topmost stack entries with the call result.
absl::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const {
  if (!frame->value_stack().HasEnough(num_arguments_)) {
    return absl::Status(absl::StatusCode::kInternal, "Value stack underflow");
  }

  // DoEvaluate may return a status for non-recoverable errors (e.g.
  // unexpected typing, illegal expression state). Application errors that can
  // reasonably be handled as a cel error will appear in the result value.
  CEL_ASSIGN_OR_RETURN(auto result, DoEvaluate(frame));

  frame->value_stack().PopAndPush(num_arguments_, std::move(result));

  return absl::OkStatus();
}

// Returns the single overload whose parameter kinds match `input_args`,
// nullopt if none does, or an internal error if more than one does.
absl::StatusOr ResolveStatic(
    absl::Span input_args,
    absl::Span overloads) {
  ResolveResult result = absl::nullopt;

  for (const auto& overload : overloads) {
    if (ArgumentKindsMatch(overload.descriptor, input_args)) {
      // More than one overload matches our arguments.
      if (result.has_value()) {
        return absl::Status(absl::StatusCode::kInternal,
                            "Cannot resolve overloads");
      }

      result.emplace(overload);
    }
  }
  return result;
}

// Like ResolveStatic, but for lazy providers: each provider whose descriptor
// matches the runtime argument kinds is asked (via the activation) for an
// implementation. More than one hit is an internal error.
absl::StatusOr ResolveLazy(
    absl::Span input_args, absl::string_view name,
    bool receiver_style,
    absl::Span providers,
    const ExecutionFrameBase& frame) {
  ResolveResult result = absl::nullopt;

  // Descriptor describing the actual call shape, passed to GetFunction.
  std::vector arg_types(input_args.size());

  std::transform(
      input_args.begin(), input_args.end(), arg_types.begin(),
      [](const cel::Value& value) { return ValueKindToKind(value->kind()); });

  cel::FunctionDescriptor matcher{name, receiver_style, arg_types};

  const cel::ActivationInterface& activation = frame.activation();
  // NOTE(review): `auto provider` copies each element; `const auto&` would
  // avoid that if the element type is not trivially cheap -- confirm.
  for (auto provider : providers) {
    // The LazyFunctionStep has so far only resolved by function shape, check
    // that the runtime argument kinds agree with the specific descriptor for
    // the provider candidates.
    if (!ArgumentKindsMatch(provider.descriptor, input_args)) {
      continue;
    }

    CEL_ASSIGN_OR_RETURN(auto overload,
                         provider.provider.GetFunction(matcher, activation));
    if (overload.has_value()) {
      // More than one overload matches our arguments.
      if (result.has_value()) {
        return absl::Status(absl::StatusCode::kInternal,
                            "Cannot resolve overloads");
      }

      result.emplace(overload.value());
    }
  }

  return result;
}

// Function step whose candidate overloads are fixed when the step is built.
class EagerFunctionStep : public AbstractFunctionStep {
 public:
  EagerFunctionStep(std::vector overloads,
                    const std::string& name, size_t num_args,
                    bool receiver_style, int64_t expr_id)
      : AbstractFunctionStep(name, num_args, receiver_style, expr_id),
        overloads_(std::move(overloads)) {}

  absl::StatusOr ResolveFunction(
      absl::Span input_args,
      const ExecutionFrame* frame) const override {
    return ResolveStatic(input_args, overloads_);
  }

 private:
  std::vector overloads_;
};

class LazyFunctionStep : public AbstractFunctionStep {
 public:
  // Constructs LazyFunctionStep that attempts to lookup function implementation
  // at runtime.
  LazyFunctionStep(const std::string& name, size_t num_args,
                   bool receiver_style,
                   std::vector providers,
                   int64_t expr_id)
      : AbstractFunctionStep(name, num_args, receiver_style, expr_id),
        providers_(std::move(providers)) {}

  absl::StatusOr ResolveFunction(
      absl::Span input_args,
      const ExecutionFrame* frame) const override;

 private:
  std::vector providers_;
};

absl::StatusOr LazyFunctionStep::ResolveFunction(
    absl::Span input_args,
    const ExecutionFrame* frame) const {
  return ResolveLazy(input_args, name_, receiver_style_, providers_, *frame);
}

// Resolver policy for DirectFunctionStepImpl: picks from a fixed overload
// list.
class StaticResolver {
 public:
  explicit StaticResolver(std::vector overloads)
      : overloads_(std::move(overloads)) {}

  absl::StatusOr Resolve(ExecutionFrameBase& frame,
                                        absl::Span input) const {
    return ResolveStatic(input, overloads_);
  }

 private:
  std::vector overloads_;
};

// Resolver policy for DirectFunctionStepImpl: resolves through lazy
// providers at evaluation time.
class LazyResolver {
 public:
  explicit LazyResolver(
      std::vector providers,
      std::string name, bool receiver_style)
      : providers_(std::move(providers)),
        name_(std::move(name)),
        receiver_style_(receiver_style) {}

  absl::StatusOr Resolve(ExecutionFrameBase& frame,
                                        absl::Span input) const {
    return ResolveLazy(input, name_, receiver_style_, providers_, frame);
  }

 private:
  std::vector providers_;
  std::string name_;
  bool receiver_style_;
};

// Direct-evaluation function call: evaluates its own argument steps (rather
// than reading arguments from the value stack), converts partially unknown
// arguments when unknown processing is enabled, then resolves via `Resolver`
// and invokes the chosen overload.
template 
class DirectFunctionStepImpl : public DirectExpressionStep {
 public:
  DirectFunctionStepImpl(
      int64_t expr_id, const std::string& name,
      std::vector> arg_steps,
      bool receiver_style, Resolver&& resolver)
      : DirectExpressionStep(expr_id),
        name_(name),
        arg_steps_(std::move(arg_steps)),
        receiver_style_(receiver_style),
        resolver_(std::forward(resolver)) {}

  // `trail` is not written: the call result carries no attribute trail.
  absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result,
                        AttributeTrail& trail) const override {
    absl::InlinedVector args;
    absl::InlinedVector arg_trails;

    args.resize(arg_steps_.size());
    arg_trails.resize(arg_steps_.size());

    for (size_t i = 0; i < arg_steps_.size(); i++) {
      CEL_RETURN_IF_ERROR(
          arg_steps_[i]->Evaluate(frame, args[i], arg_trails[i]));
    }

    if (frame.unknown_processing_enabled()) {
      for (size_t i = 0; i < arg_trails.size(); i++) {
        if (frame.attribute_utility().CheckForUnknown(arg_trails[i],
                                                      /*use_partial=*/true)) {
          args[i] = frame.attribute_utility().CreateUnknownSet(
              arg_trails[i].attribute());
        }
      }
    }

    CEL_ASSIGN_OR_RETURN(ResolveResult resolved_function,
                         resolver_.Resolve(frame, args));

    if (resolved_function.has_value() &&
        ShouldAcceptOverload(resolved_function->descriptor, args)) {
      CEL_ASSIGN_OR_RETURN(result,
                           Invoke(*resolved_function, expr_id_, args, frame));
      return absl::OkStatus();
    }

    result = NoOverloadResult(name_, args, receiver_style_, frame);

    return absl::OkStatus();
  }

  absl::optional> GetDependencies()
      const override {
    std::vector dependencies;
    dependencies.reserve(arg_steps_.size());
    for (const auto& arg_step : arg_steps_) {
      dependencies.push_back(arg_step.get());
    }
    return dependencies;
  }

  // Moves the argument steps out; this step is left without arguments.
  absl::optional>>
  ExtractDependencies() override {
    return std::move(arg_steps_);
  }

 private:
  friend Resolver;
  std::string name_;
  std::vector> arg_steps_;
  bool receiver_style_;
  Resolver resolver_;
};

}  // namespace

std::unique_ptr CreateDirectFunctionStep(
    int64_t expr_id, const cel::CallExpr& call,
    std::vector> deps,
    std::vector overloads) {
  return std::make_unique>(
      expr_id, call.function(), std::move(deps), call.has_target(),
      StaticResolver(std::move(overloads)));
}

std::unique_ptr CreateDirectLazyFunctionStep(
    int64_t expr_id, const cel::CallExpr& call,
    std::vector> deps,
    std::vector providers) {
  return std::make_unique>(
      expr_id, call.function(), std::move(deps), call.has_target(),
      LazyResolver(std::move(providers), call.function(), call.has_target()));
}

absl::StatusOr> CreateFunctionStep(
    const cel::CallExpr& call_expr, int64_t expr_id,
    std::vector lazy_overloads) {
  bool receiver_style = call_expr.has_target();
  // The receiver (target) occupies an extra leading argument slot.
  size_t num_args = call_expr.args().size() + (receiver_style ? 1 : 0);
  const std::string& name = call_expr.function();
  return std::make_unique(name, num_args, receiver_style,
                                            std::move(lazy_overloads), expr_id);
}

absl::StatusOr> CreateFunctionStep(
    const cel::CallExpr& call_expr, int64_t expr_id,
    std::vector overloads) {
  bool receiver_style = call_expr.has_target();
  // The receiver (target) occupies an extra leading argument slot.
  size_t num_args = call_expr.args().size() + (receiver_style ? 1 : 0);
  const std::string& name = call_expr.function();
  return std::make_unique(
      std::move(overloads), name, num_args, receiver_style, expr_id);
}

}  // namespace google::api::expr::runtime
================================================
FILE: eval/eval/function_step.h
================================================
#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_FUNCTION_STEP_H_
#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_FUNCTION_STEP_H_

// NOTE(review): the angle-bracket targets of the following #include lines
// are missing from this copy of the file.
#include
#include
#include

#include "absl/status/statusor.h"
#include "common/expr.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "runtime/function_overload_reference.h"
#include "runtime/function_registry.h"

namespace google::api::expr::runtime {

// Factory method for Call-based execution step where the function has been
// statically resolved from a set of eager functions configured in the
// CelFunctionRegistry.
std::unique_ptr CreateDirectFunctionStep(
    int64_t expr_id, const cel::CallExpr& call,
    std::vector> deps,
    std::vector overloads);

// Factory method for Call-based execution step where the function has been
// statically resolved from a set of lazy functions configured in the
// CelFunctionRegistry.
std::unique_ptr CreateDirectLazyFunctionStep(
    int64_t expr_id, const cel::CallExpr& call,
    std::vector> deps,
    std::vector providers);

// Factory method for Call-based execution step where the function will be
// resolved at runtime (lazily) from an input Activation.
absl::StatusOr> CreateFunctionStep(
    const cel::CallExpr& call, int64_t expr_id,
    std::vector lazy_overloads);

// Factory method for Call-based execution step where the function has been
// statically resolved from a set of eager functions configured in the
// CelFunctionRegistry.
absl::StatusOr> CreateFunctionStep(
    const cel::CallExpr& call, int64_t expr_id,
    std::vector overloads);

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_EVAL_FUNCTION_STEP_H_
================================================
FILE: eval/eval/function_step_test.cc
================================================
#include "eval/eval/function_step.h"

// NOTE(review): the angle-bracket targets of the following #include lines
// are missing from this copy of the file.
#include
#include
#include
#include
#include
#include

#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "base/builtins.h"
#include "base/type_provider.h"
#include "common/constant.h"
#include "common/expr.h"
#include "common/kind.h"
#include "common/value.h"
#include "eval/eval/cel_expression_flat_impl.h"
#include "eval/eval/const_value_step.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/ident_step.h"
#include "eval/internal/interop.h"
#include "eval/public/activation.h"
#include "eval/public/cel_attribute.h"
#include "eval/public/cel_function.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"
#include "eval/public/cel_value.h"
#include "eval/public/portable_cel_function_adapter.h"
#include "eval/public/structs/cel_proto_wrapper.h"
#include "eval/public/testing/matchers.h"
#include "eval/testutil/test_message.pb.h"
#include "internal/testing.h"
#include "runtime/function_overload_reference.h"
#include "runtime/function_registry.h"
#include "runtime/internal/runtime_env_testing.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_functions.h"
#include "google/protobuf/arena.h"
namespace google::api::expr::runtime {

namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::CallExpr;
using ::cel::Expr;
using ::cel::IdentExpr;
using ::cel::TypeProvider;
using ::cel::runtime_internal::NewTestingRuntimeEnv;
using ::testing::Eq;
using ::testing::Not;
using ::testing::Truly;

// Returns a fresh, process-wide increasing expression id (1, 2, 3, ...).
int GetExprId() {
  static int id = 0;
  id++;
  return id;
}

// Simple function that takes no arguments and returns a constant value.
class ConstFunction : public CelFunction {
 public:
  explicit ConstFunction(const CelValue& value, absl::string_view name)
      : CelFunction(CreateDescriptor(name)), value_(value) {}

  // Non-receiver descriptor named `name` with no parameters.
  static CelFunctionDescriptor CreateDescriptor(absl::string_view name) {
    return CelFunctionDescriptor{name, false, {}};
  }

  // Builds a zero-argument, non-receiver call to `name`.
  static CallExpr MakeCall(absl::string_view name) {
    CallExpr call;
    call.set_function(std::string(name));
    call.set_target(nullptr);
    return call;
  }

  absl::Status Evaluate(absl::Span args, CelValue* result,
                        google::protobuf::Arena* arena) const override {
    if (!args.empty()) {
      return absl::Status(absl::StatusCode::kInvalidArgument,
                          "Bad arguments number");
    }

    *result = value_;
    return absl::OkStatus();
  }

 private:
  CelValue value_;
};

// Selects whether AddFunction returns a sum or an unknown-function-result
// error.
enum class ShouldReturnUnknown : bool { kYes = true, kNo = false };

// Test "_+_" over two int64 arguments.
class AddFunction : public CelFunction {
 public:
  AddFunction()
      : CelFunction(CreateDescriptor()), should_return_unknown_(false) {}
  explicit AddFunction(ShouldReturnUnknown should_return_unknown)
      : CelFunction(CreateDescriptor()),
        should_return_unknown_(static_cast(should_return_unknown)) {}

  static CelFunctionDescriptor CreateDescriptor() {
    return CelFunctionDescriptor{
        "_+_", false, {CelValue::Type::kInt64, CelValue::Type::kInt64}};
  }

  // Builds a non-receiver "_+_" call with two default-constructed argument
  // expressions (the arguments are supplied by preceding steps in tests).
  static CallExpr MakeCall() {
    CallExpr call;
    call.set_function("_+_");
    call.mutable_args().emplace_back();
    call.mutable_args().emplace_back();
    call.set_target(nullptr);
    return call;
  }

  absl::Status Evaluate(absl::Span args, CelValue* result,
                        google::protobuf::Arena* arena) const override {
    if (args.size() != 2 || !args[0].IsInt64() || !args[1].IsInt64()) {
      return absl::Status(absl::StatusCode::kInvalidArgument,
                          "Mismatched arguments passed to method");
    }
    if (should_return_unknown_) {
      *result =
          CreateUnknownFunctionResultError(arena, "Add can't be resolved.");
      return absl::OkStatus();
    }

    int64_t arg0 = args[0].Int64OrDie();
    int64_t arg1 = args[1].Int64OrDie();

    // NOTE(review): signed addition is unchecked; fine for the small test
    // inputs used here.
    *result = CelValue::CreateInt64(arg0 + arg1);
    return absl::OkStatus();
  }

 private:
  bool should_return_unknown_;
};

// One-argument "Sink" function accepting `type`; optionally non-strict so it
// can receive errors/unknowns. Always returns int64 0.
class SinkFunction : public CelFunction {
 public:
  explicit SinkFunction(CelValue::Type type, bool is_strict = true)
      : CelFunction(CreateDescriptor(type, is_strict)) {}

  static CelFunctionDescriptor CreateDescriptor(CelValue::Type type,
                                                bool is_strict = true) {
    return CelFunctionDescriptor{"Sink", false, {type}, is_strict};
  }

  static CallExpr MakeCall() {
    CallExpr call;
    call.set_function("Sink");
    call.mutable_args().emplace_back();
    call.set_target(nullptr);
    return call;
  }

  absl::Status Evaluate(absl::Span args, CelValue* result,
                        google::protobuf::Arena* arena) const override {
    // Return value is ignored.
    *result = CelValue::CreateInt64(0);
    return absl::OkStatus();
  }
};

// Create and initialize a registry with some default functions.
void AddDefaults(CelFunctionRegistry& registry) {
  // Intentionally never freed: function-local static shared by every call.
  static UnknownSet* unknown_set = new UnknownSet();
  EXPECT_TRUE(registry
                  .Register(std::make_unique(
                      CelValue::CreateInt64(3), "Const3"))
                  .ok());
  EXPECT_TRUE(registry
                  .Register(std::make_unique(
                      CelValue::CreateInt64(2), "Const2"))
                  .ok());
  EXPECT_TRUE(registry
                  .Register(std::make_unique(
                      CelValue::CreateUnknownSet(unknown_set), "ConstUnknown"))
                  .ok());
  EXPECT_TRUE(registry.Register(std::make_unique()).ok());

  EXPECT_TRUE(
      registry.Register(std::make_unique(CelValue::Type::kList))
          .ok());

  EXPECT_TRUE(
      registry.Register(std::make_unique(CelValue::Type::kMap))
          .ok());

  EXPECT_TRUE(
      registry
          .Register(std::make_unique(CelValue::Type::kMessage))
          .ok());
}

// Returns `argument_count` wildcard (kAny) argument kinds.
std::vector ArgumentMatcher(int argument_count) {
  std::vector argument_matcher(argument_count);
  for (int i = 0; i < argument_count; i++) {
    argument_matcher[i] = CelValue::Type::kAny;
  }
  return argument_matcher;
}

// Wildcard matcher sized for `call`, counting the receiver as an argument.
// NOTE(review): the size_t count is narrowed to the int overload's parameter.
std::vector ArgumentMatcher(const CallExpr& call) {
  return ArgumentMatcher(call.has_target() ? call.args().size() + 1
                                           : call.args().size());
}

// Wraps a single direct step (expression id -1) in a flat expression built
// with a testing runtime environment and `options`.
std::unique_ptr CreateExpressionImpl(
    const cel::RuntimeOptions& options,
    std::unique_ptr expr) {
  ExecutionPath path;
  path.push_back(std::make_unique(std::move(expr), -1));

  auto env = NewTestingRuntimeEnv();
  return std::make_unique(
      env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0,
                          env->type_registry.GetComposedTypeProvider(),
                          options));
}

// Builds a function step for `call`, preferring lazy overloads from
// `registry` and falling back to static overloads when none are found.
absl::StatusOr> MakeTestFunctionStep(
    const CallExpr& call, const CelFunctionRegistry& registry) {
  auto argument_matcher = ArgumentMatcher(call);
  auto lazy_overloads = registry.ModernFindLazyOverloads(
      call.function(), call.has_target(), argument_matcher);
  if (!lazy_overloads.empty()) {
    return CreateFunctionStep(call, GetExprId(), lazy_overloads);
  }
  auto overloads = registry.FindStaticOverloads(
      call.function(), call.has_target(), argument_matcher);
  return CreateFunctionStep(call, GetExprId(), overloads);
}

// Test common functions with varying levels of unknown support.
class FunctionStepTest : public testing::TestWithParam { public: // underlying expression impl moves path std::unique_ptr GetExpression(ExecutionPath&& path) { cel::RuntimeOptions options; options.unknown_processing = GetParam(); auto env = NewTestingRuntimeEnv(); return std::make_unique( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), options)); } }; TEST_P(FunctionStepTest, SimpleFunctionTest) { ExecutionPath path; CelFunctionRegistry registry; AddDefaults(registry); CallExpr call1 = ConstFunction::MakeCall("Const3"); CallExpr call2 = ConstFunction::MakeCall("Const2"); CallExpr add_call = AddFunction::MakeCall(); ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsInt64()); EXPECT_THAT(value.Int64OrDie(), Eq(5)); } TEST_P(FunctionStepTest, TestStackUnderflow) { ExecutionPath path; CelFunctionRegistry registry; AddDefaults(registry); AddFunction add_func; CallExpr call1 = ConstFunction::MakeCall("Const3"); CallExpr add_call = AddFunction::MakeCall(); ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step2)); std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; EXPECT_THAT(impl->Evaluate(activation, &arena), Not(IsOk())); } // Test situation when no overloads match input arguments during evaluation. 
TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { ExecutionPath path; CelFunctionRegistry registry; AddDefaults(registry); ASSERT_TRUE(registry .Register(std::make_unique( CelValue::CreateUint64(4), "Const4")) .ok()); CallExpr call1 = ConstFunction::MakeCall("Const3"); CallExpr call2 = ConstFunction::MakeCall("Const4"); // Add expects {int64, int64} but it's {int64, uint64}. CallExpr add_call = AddFunction::MakeCall(); ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsError()); EXPECT_THAT(*value.ErrorOrDie(), StatusIs(absl::StatusCode::kUnknown, testing::HasSubstr("_+_(int64, uint64)"))); } TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluationReceiver) { ExecutionPath path; CelFunctionRegistry registry; AddDefaults(registry); CallExpr call1 = ConstFunction::MakeCall("Const3"); CallExpr call2 = ConstFunction::MakeCall("Const3"); // Add expects {int64, int64} but it's {int64, uint64}. 
CallExpr add_call; add_call.add_args(); add_call.set_target(Expr()); add_call.set_function("_+_"); ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsError()); EXPECT_THAT(*value.ErrorOrDie(), StatusIs(absl::StatusCode::kUnknown, testing::HasSubstr("(int64)._+_(int64)"))); } // Test situation when no overloads match input arguments during evaluation. TEST_P(FunctionStepTest, TestNoMatchingOverloadsUnexpectedArgCount) { ExecutionPath path; CelFunctionRegistry registry; AddDefaults(registry); CallExpr call1 = ConstFunction::MakeCall("Const3"); // expect overloads for {int64, int64} but get call for {int64, int64, int64}. 
CallExpr add_call = AddFunction::MakeCall(); add_call.mutable_args().emplace_back(); ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN( auto step3, CreateFunctionStep(add_call, -1, registry.FindStaticOverloads( add_call.function(), false, {cel::Kind::kInt64, cel::Kind::kInt64}))); path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); path.push_back(std::move(step3)); std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsError()); EXPECT_THAT(*value.ErrorOrDie(), StatusIs(absl::StatusCode::kUnknown, testing::HasSubstr("_+_(int64, int64, int64)"))); } // Test situation when no overloads match input arguments during evaluation // and at least one of arguments is error. TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluationErrorForwarding) { ExecutionPath path; CelFunctionRegistry registry; AddDefaults(registry); CelError error0 = absl::CancelledError(); CelError error1 = absl::CancelledError(); // Constants have ERROR type, while AddFunction expects INT. 
ASSERT_TRUE(registry .Register(std::make_unique( CelValue::CreateError(&error0), "ConstError1")) .ok()); ASSERT_TRUE(registry .Register(std::make_unique( CelValue::CreateError(&error1), "ConstError2")) .ok()); CallExpr call1 = ConstFunction::MakeCall("ConstError1"); CallExpr call2 = ConstFunction::MakeCall("ConstError2"); CallExpr add_call = AddFunction::MakeCall(); ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsError()); EXPECT_THAT(*value.ErrorOrDie(), Eq(error0)); } TEST_P(FunctionStepTest, LazyFunctionTest) { ExecutionPath path; Activation activation; CelFunctionRegistry registry; ASSERT_OK( registry.RegisterLazyFunction(ConstFunction::CreateDescriptor("Const3"))); ASSERT_OK(activation.InsertFunction( std::make_unique(CelValue::CreateInt64(3), "Const3"))); ASSERT_OK( registry.RegisterLazyFunction(ConstFunction::CreateDescriptor("Const2"))); ASSERT_OK(activation.InsertFunction( std::make_unique(CelValue::CreateInt64(2), "Const2"))); ASSERT_OK(registry.Register(std::make_unique())); CallExpr call1 = ConstFunction::MakeCall("Const3"); CallExpr call2 = ConstFunction::MakeCall("Const2"); CallExpr add_call = AddFunction::MakeCall(); ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); std::unique_ptr impl = 
GetExpression(std::move(path)); google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsInt64()); EXPECT_THAT(value.Int64OrDie(), Eq(5)); } TEST_P(FunctionStepTest, LazyFunctionOverloadingTest) { ExecutionPath path; Activation activation; CelFunctionRegistry registry; auto floor_int = PortableUnaryFunctionAdapter::Create( "Floor", false, [](google::protobuf::Arena*, int64_t val) { return val; }); auto floor_double = PortableUnaryFunctionAdapter::Create( "Floor", false, [](google::protobuf::Arena*, double val) { return std::floor(val); }); ASSERT_OK(registry.RegisterLazyFunction(floor_int->descriptor())); ASSERT_OK(activation.InsertFunction(std::move(floor_int))); ASSERT_OK(registry.RegisterLazyFunction(floor_double->descriptor())); ASSERT_OK(activation.InsertFunction(std::move(floor_double))); ASSERT_OK(registry.Register( PortableBinaryFunctionAdapter::Create( "_<_", false, [](google::protobuf::Arena*, int64_t lhs, int64_t rhs) -> bool { return lhs < rhs; }))); cel::Constant lhs; lhs.set_int64_value(20); cel::Constant rhs; rhs.set_double_value(21.9); CallExpr call1; call1.mutable_args().emplace_back(); call1.set_function("Floor"); CallExpr call2; call2.mutable_args().emplace_back(); call2.set_function("Floor"); CallExpr lt_call; lt_call.mutable_args().emplace_back(); lt_call.mutable_args().emplace_back(); lt_call.set_function("_<_"); ASSERT_OK_AND_ASSIGN( auto step0, CreateConstValueStep(cel::interop_internal::CreateIntValue(20), -1)); ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN( auto step2, CreateConstValueStep(cel::interop_internal::CreateDoubleValue(21.9), -1)); ASSERT_OK_AND_ASSIGN(auto step3, MakeTestFunctionStep(call2, registry)); ASSERT_OK_AND_ASSIGN(auto step4, MakeTestFunctionStep(lt_call, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); path.push_back(std::move(step3)); 
path.push_back(std::move(step4)); std::unique_ptr impl = GetExpression(std::move(path)); google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsBool()); EXPECT_TRUE(value.BoolOrDie()); } // Test situation when no overloads match input arguments during evaluation // and at least one of arguments is error. TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluationErrorForwardingLazy) { ExecutionPath path; Activation activation; google::protobuf::Arena arena; CelFunctionRegistry registry; AddDefaults(registry); CelError error0 = absl::CancelledError(); CelError error1 = absl::CancelledError(); // Constants have ERROR type, while AddFunction expects INT. ASSERT_OK(registry.RegisterLazyFunction( ConstFunction::CreateDescriptor("ConstError1"))); ASSERT_OK(activation.InsertFunction(std::make_unique( CelValue::CreateError(&error0), "ConstError1"))); ASSERT_OK(registry.RegisterLazyFunction( ConstFunction::CreateDescriptor("ConstError2"))); ASSERT_OK(activation.InsertFunction(std::make_unique( CelValue::CreateError(&error1), "ConstError2"))); CallExpr call1 = ConstFunction::MakeCall("ConstError1"); CallExpr call2 = ConstFunction::MakeCall("ConstError2"); CallExpr add_call = AddFunction::MakeCall(); ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); std::unique_ptr impl = GetExpression(std::move(path)); ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsError()); EXPECT_THAT(*value.ErrorOrDie(), Eq(error0)); } std::string TestNameFn(testing::TestParamInfo opt) { switch (opt.param) { case UnknownProcessingOptions::kDisabled: return "disabled"; case UnknownProcessingOptions::kAttributeOnly: 
return "attribute_only"; case UnknownProcessingOptions::kAttributeAndFunction: return "attribute_and_function"; } return ""; } INSTANTIATE_TEST_SUITE_P( UnknownSupport, FunctionStepTest, testing::Values(UnknownProcessingOptions::kDisabled, UnknownProcessingOptions::kAttributeOnly, UnknownProcessingOptions::kAttributeAndFunction), &TestNameFn); class FunctionStepTestUnknowns : public testing::TestWithParam { public: std::unique_ptr GetExpression(ExecutionPath&& path) { cel::RuntimeOptions options; options.unknown_processing = GetParam(); auto env = NewTestingRuntimeEnv(); return std::make_unique( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), options)); } }; TEST_P(FunctionStepTestUnknowns, PassedUnknownTest) { ExecutionPath path; CelFunctionRegistry registry; AddDefaults(registry); CallExpr call1 = ConstFunction::MakeCall("Const3"); CallExpr call2 = ConstFunction::MakeCall("ConstUnknown"); CallExpr add_call = AddFunction::MakeCall(); ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsUnknownSet()); } TEST_P(FunctionStepTestUnknowns, PartialUnknownHandlingTest) { ExecutionPath path; CelFunctionRegistry registry; AddDefaults(registry); // Build the expression path that corresponds to CEL expression // "sink(param)". 
IdentExpr ident1; ident1.set_name("param"); CallExpr call1 = SinkFunction::MakeCall(); ASSERT_OK_AND_ASSIGN(auto step0, CreateIdentStep("param", GetExprId())); ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call1, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; TestMessage msg; google::protobuf::Arena arena; activation.InsertValue("param", CelProtoWrapper::CreateMessage(&msg, &arena)); CelAttributePattern pattern( "param", {CreateCelAttributeQualifierPattern(CelValue::CreateBool(true))}); // Set attribute pattern that marks attribute "param[true]" as unknown. // It should result in "param" being handled as partially unknown, which is // is handled as fully unknown when used as function input argument. activation.set_unknown_attribute_patterns({pattern}); ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsUnknownSet()); } TEST_P(FunctionStepTestUnknowns, UnknownVsErrorPrecedenceTest) { ExecutionPath path; CelFunctionRegistry registry; AddDefaults(registry); CelError error0 = absl::CancelledError(); CelValue error_value = CelValue::CreateError(&error0); ASSERT_TRUE( registry .Register(std::make_unique(error_value, "ConstError")) .ok()); CallExpr call1 = ConstFunction::MakeCall("ConstError"); CallExpr call2 = ConstFunction::MakeCall("ConstUnknown"); CallExpr add_call = AddFunction::MakeCall(); ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl->Evaluate(activation, &arena)); 
ASSERT_TRUE(value.IsError()); // Making sure we propagate the error. ASSERT_EQ(*value.ErrorOrDie(), *error_value.ErrorOrDie()); } INSTANTIATE_TEST_SUITE_P( UnknownFunctionSupport, FunctionStepTestUnknowns, testing::Values(UnknownProcessingOptions::kAttributeOnly, UnknownProcessingOptions::kAttributeAndFunction), &TestNameFn); TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { ExecutionPath path; CelFunctionRegistry registry; ASSERT_OK(registry.Register( std::make_unique(CelValue::CreateInt64(2), "Const2"))); ASSERT_OK(registry.Register( std::make_unique(CelValue::CreateInt64(3), "Const3"))); ASSERT_OK(registry.Register( std::make_unique(ShouldReturnUnknown::kYes))); CallExpr call1 = ConstFunction::MakeCall("Const2"); CallExpr call2 = ConstFunction::MakeCall("Const3"); CallExpr add_call = AddFunction::MakeCall(); ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; auto env = NewTestingRuntimeEnv(); CelExpressionFlatImpl impl( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), options)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); ASSERT_TRUE(value.IsUnknownSet()); } TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { ExecutionPath path; CelFunctionRegistry registry; ASSERT_OK(registry.Register( std::make_unique(CelValue::CreateInt64(2), "Const2"))); ASSERT_OK(registry.Register( std::make_unique(CelValue::CreateInt64(3), "Const3"))); ASSERT_OK(registry.Register( std::make_unique(ShouldReturnUnknown::kYes))); // 
Add(Add(2, 3), Add(2, 3)) CallExpr call1 = ConstFunction::MakeCall("Const2"); CallExpr call2 = ConstFunction::MakeCall("Const3"); CallExpr add_call = AddFunction::MakeCall(); ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); ASSERT_OK_AND_ASSIGN(auto step3, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN(auto step4, MakeTestFunctionStep(call2, registry)); ASSERT_OK_AND_ASSIGN(auto step5, MakeTestFunctionStep(add_call, registry)); ASSERT_OK_AND_ASSIGN(auto step6, MakeTestFunctionStep(add_call, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); path.push_back(std::move(step3)); path.push_back(std::move(step4)); path.push_back(std::move(step5)); path.push_back(std::move(step6)); cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; auto env = NewTestingRuntimeEnv(); CelExpressionFlatImpl impl( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), options)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); ASSERT_TRUE(value.IsUnknownSet()); } TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { ExecutionPath path; CelFunctionRegistry registry; ASSERT_OK(registry.Register( std::make_unique(CelValue::CreateInt64(2), "Const2"))); ASSERT_OK(registry.Register( std::make_unique(CelValue::CreateInt64(3), "Const3"))); ASSERT_OK(registry.Register( std::make_unique(ShouldReturnUnknown::kYes))); // Add(Add(2, 3), Add(3, 2)) CallExpr call1 = ConstFunction::MakeCall("Const2"); CallExpr call2 = ConstFunction::MakeCall("Const3"); CallExpr add_call = AddFunction::MakeCall(); ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, 
registry)); ASSERT_OK_AND_ASSIGN(auto step1, MakeTestFunctionStep(call2, registry)); ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); ASSERT_OK_AND_ASSIGN(auto step3, MakeTestFunctionStep(call2, registry)); ASSERT_OK_AND_ASSIGN(auto step4, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN(auto step5, MakeTestFunctionStep(add_call, registry)); ASSERT_OK_AND_ASSIGN(auto step6, MakeTestFunctionStep(add_call, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); path.push_back(std::move(step3)); path.push_back(std::move(step4)); path.push_back(std::move(step5)); path.push_back(std::move(step6)); cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; auto env = NewTestingRuntimeEnv(); CelExpressionFlatImpl impl( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), options)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); ASSERT_TRUE(value.IsUnknownSet()) << *(value.ErrorOrDie()); } TEST(FunctionStepTestUnknownFunctionResults, UnknownVsErrorPrecedenceTest) { ExecutionPath path; CelFunctionRegistry registry; CelError error0 = absl::CancelledError(); CelValue error_value = CelValue::CreateError(&error0); UnknownSet unknown_set; CelValue unknown_value = CelValue::CreateUnknownSet(&unknown_set); ASSERT_OK(registry.Register( std::make_unique(error_value, "ConstError"))); ASSERT_OK(registry.Register( std::make_unique(unknown_value, "ConstUnknown"))); ASSERT_OK(registry.Register( std::make_unique(ShouldReturnUnknown::kYes))); CallExpr call1 = ConstFunction::MakeCall("ConstError"); CallExpr call2 = ConstFunction::MakeCall("ConstUnknown"); CallExpr add_call = AddFunction::MakeCall(); ASSERT_OK_AND_ASSIGN(auto step0, MakeTestFunctionStep(call1, registry)); ASSERT_OK_AND_ASSIGN(auto step1, 
MakeTestFunctionStep(call2, registry)); ASSERT_OK_AND_ASSIGN(auto step2, MakeTestFunctionStep(add_call, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; auto env = NewTestingRuntimeEnv(); CelExpressionFlatImpl impl( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), options)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); ASSERT_TRUE(value.IsError()); // Making sure we propagate the error. ASSERT_EQ(*value.ErrorOrDie(), *error_value.ErrorOrDie()); } class MessageFunction : public CelFunction { public: MessageFunction() : CelFunction( CelFunctionDescriptor("Fn", false, {CelValue::Type::kMessage})) {} absl::Status Evaluate(absl::Span args, CelValue* result, google::protobuf::Arena* arena) const override { if (args.size() != 1 || !args.at(0).IsMessage()) { return absl::Status(absl::StatusCode::kInvalidArgument, "Bad arguments number"); } *result = CelValue::CreateStringView("message"); return absl::OkStatus(); } }; class MessageIdentityFunction : public CelFunction { public: MessageIdentityFunction() : CelFunction( CelFunctionDescriptor("Fn", false, {CelValue::Type::kMessage})) {} absl::Status Evaluate(absl::Span args, CelValue* result, google::protobuf::Arena* arena) const override { if (args.size() != 1 || !args.at(0).IsMessage()) { return absl::Status(absl::StatusCode::kInvalidArgument, "Bad arguments number"); } *result = args.at(0); return absl::OkStatus(); } }; class NullFunction : public CelFunction { public: NullFunction() : CelFunction( CelFunctionDescriptor("Fn", false, {CelValue::Type::kNullType})) {} absl::Status Evaluate(absl::Span args, CelValue* result, google::protobuf::Arena* arena) const override { if (args.size() != 1 || 
args.at(0).type() != CelValue::Type::kNullType) { return absl::Status(absl::StatusCode::kInvalidArgument, "Bad arguments number"); } *result = CelValue::CreateStringView("null"); return absl::OkStatus(); } }; TEST(FunctionStepStrictnessTest, IfFunctionStrictAndGivenUnknownSkipsInvocation) { UnknownSet unknown_set; CelFunctionRegistry registry; ASSERT_OK(registry.Register(std::make_unique( CelValue::CreateUnknownSet(&unknown_set), "ConstUnknown"))); ASSERT_OK(registry.Register(std::make_unique( CelValue::Type::kUnknownSet, /*is_strict=*/true))); ExecutionPath path; CallExpr call0 = ConstFunction::MakeCall("ConstUnknown"); CallExpr call1 = SinkFunction::MakeCall(); ASSERT_OK_AND_ASSIGN(std::unique_ptr step0, MakeTestFunctionStep(call0, registry)); ASSERT_OK_AND_ASSIGN(std::unique_ptr step1, MakeTestFunctionStep(call1, registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; auto env = NewTestingRuntimeEnv(); CelExpressionFlatImpl impl( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), options)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); ASSERT_TRUE(value.IsUnknownSet()); } TEST(FunctionStepStrictnessTest, IfFunctionNonStrictAndGivenUnknownInvokesIt) { UnknownSet unknown_set; CelFunctionRegistry registry; ASSERT_OK(registry.Register(std::make_unique( CelValue::CreateUnknownSet(&unknown_set), "ConstUnknown"))); ASSERT_OK(registry.Register(std::make_unique( CelValue::Type::kUnknownSet, /*is_strict=*/false))); ExecutionPath path; CallExpr call0 = ConstFunction::MakeCall("ConstUnknown"); CallExpr call1 = SinkFunction::MakeCall(); ASSERT_OK_AND_ASSIGN(std::unique_ptr step0, MakeTestFunctionStep(call0, registry)); ASSERT_OK_AND_ASSIGN(std::unique_ptr step1, MakeTestFunctionStep(call1, 
registry)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); Expr placeholder_expr; cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeAndFunction; auto env = NewTestingRuntimeEnv(); CelExpressionFlatImpl impl( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), options)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue value, impl.Evaluate(activation, &arena)); ASSERT_THAT(value, test::IsCelInt64(Eq(0))); } class DirectFunctionStepTest : public testing::Test { public: DirectFunctionStepTest() = default; void SetUp() override { ASSERT_OK(cel::RegisterStandardFunctions(registry_, options_)); } std::vector GetOverloads( absl::string_view name, int64_t arguments_size) { std::vector matcher; matcher.resize(arguments_size, cel::Kind::kAny); return registry_.FindStaticOverloads(name, false, matcher); } // Helper for shorthand constructing direct expr deps. // // Works around copies in init-list construction. 
std::vector> MakeDeps( std::unique_ptr dep, std::unique_ptr dep2) { std::vector> result; result.reserve(2); result.push_back(std::move(dep)); result.push_back(std::move(dep2)); return result; }; protected: cel::FunctionRegistry registry_; cel::RuntimeOptions options_; google::protobuf::Arena arena_; }; TEST_F(DirectFunctionStepTest, SimpleCall) { cel::IntValue(1); CallExpr call; call.set_function(cel::builtin::kAdd); call.mutable_args().emplace_back(); call.mutable_args().emplace_back(); std::vector> deps; deps.push_back(CreateConstValueDirectStep(cel::IntValue(1))); deps.push_back(CreateConstValueDirectStep(cel::IntValue(1))); auto expr = CreateDirectFunctionStep(-1, call, std::move(deps), GetOverloads(cel::builtin::kAdd, 2)); auto plan = CreateExpressionImpl(options_, std::move(expr)); Activation activation; ASSERT_OK_AND_ASSIGN(auto value, plan->Evaluate(activation, &arena_)); EXPECT_THAT(value, test::IsCelInt64(2)); } TEST_F(DirectFunctionStepTest, RecursiveCall) { cel::IntValue(1); CallExpr call; call.set_function(cel::builtin::kAdd); call.mutable_args().emplace_back(); call.mutable_args().emplace_back(); auto overloads = GetOverloads(cel::builtin::kAdd, 2); auto MakeLeaf = [&]() { return CreateDirectFunctionStep( -1, call, MakeDeps(CreateConstValueDirectStep(cel::IntValue(1)), CreateConstValueDirectStep(cel::IntValue(1))), overloads); }; auto expr = CreateDirectFunctionStep( -1, call, MakeDeps(CreateDirectFunctionStep( -1, call, MakeDeps(MakeLeaf(), MakeLeaf()), overloads), CreateDirectFunctionStep( -1, call, MakeDeps(MakeLeaf(), MakeLeaf()), overloads)), overloads); auto plan = CreateExpressionImpl(options_, std::move(expr)); Activation activation; ASSERT_OK_AND_ASSIGN(auto value, plan->Evaluate(activation, &arena_)); EXPECT_THAT(value, test::IsCelInt64(8)); } TEST_F(DirectFunctionStepTest, ErrorHandlingCall) { cel::IntValue(1); CallExpr add_call; add_call.set_function(cel::builtin::kAdd); add_call.mutable_args().emplace_back(); 
add_call.mutable_args().emplace_back(); CallExpr div_call; div_call.set_function(cel::builtin::kDivide); div_call.mutable_args().emplace_back(); div_call.mutable_args().emplace_back(); auto add_overloads = GetOverloads(cel::builtin::kAdd, 2); auto div_overloads = GetOverloads(cel::builtin::kDivide, 2); auto error_expr = CreateDirectFunctionStep( -1, div_call, MakeDeps(CreateConstValueDirectStep(cel::IntValue(1)), CreateConstValueDirectStep(cel::IntValue(0))), div_overloads); auto expr = CreateDirectFunctionStep( -1, add_call, MakeDeps(std::move(error_expr), CreateConstValueDirectStep(cel::IntValue(1))), add_overloads); auto plan = CreateExpressionImpl(options_, std::move(expr)); Activation activation; ASSERT_OK_AND_ASSIGN(auto value, plan->Evaluate(activation, &arena_)); EXPECT_THAT(value, test::IsCelError(StatusIs(absl::StatusCode::kInvalidArgument, testing::HasSubstr("divide by zero")))); } TEST_F(DirectFunctionStepTest, NoOverload) { cel::IntValue(1); CallExpr call; call.set_function(cel::builtin::kAdd); call.mutable_args().emplace_back(); call.mutable_args().emplace_back(); std::vector> deps; deps.push_back(CreateConstValueDirectStep(cel::IntValue(1))); deps.push_back(CreateConstValueDirectStep(cel::StringValue("2"))); auto expr = CreateDirectFunctionStep(-1, call, std::move(deps), GetOverloads(cel::builtin::kAdd, 2)); auto plan = CreateExpressionImpl(options_, std::move(expr)); Activation activation; ASSERT_OK_AND_ASSIGN(auto value, plan->Evaluate(activation, &arena_)); EXPECT_THAT(value, Truly(CheckNoMatchingOverloadError)); } TEST_F(DirectFunctionStepTest, NoOverload0Args) { cel::IntValue(1); CallExpr call; call.set_function(cel::builtin::kAdd); std::vector> deps; auto expr = CreateDirectFunctionStep(-1, call, std::move(deps), GetOverloads(cel::builtin::kAdd, 2)); auto plan = CreateExpressionImpl(options_, std::move(expr)); Activation activation; ASSERT_OK_AND_ASSIGN(auto value, plan->Evaluate(activation, &arena_)); EXPECT_THAT(value, 
Truly(CheckNoMatchingOverloadError)); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/ident_step.cc ================================================ #include "eval/eval/ident_step.h" #include #include #include #include #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/comprehension_slots.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/internal/errors.h" #include "internal/status_macros.h" namespace google::api::expr::runtime { namespace { using ::cel::Value; using ::cel::runtime_internal::CreateError; class IdentStep : public ExpressionStepBase { public: IdentStep(absl::string_view name, int64_t expr_id) : ExpressionStepBase(expr_id), name_(name) {} absl::Status Evaluate(ExecutionFrame* frame) const override; private: std::string name_; }; absl::Status LookupIdent(absl::string_view name, ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) { if (frame.attribute_tracking_enabled()) { attribute = AttributeTrail(std::string(name)); if (frame.missing_attribute_errors_enabled() && frame.attribute_utility().CheckForMissingAttribute(attribute)) { CEL_ASSIGN_OR_RETURN( result, frame.attribute_utility().CreateMissingAttributeError( attribute.attribute())); return absl::OkStatus(); } if (frame.unknown_processing_enabled() && frame.attribute_utility().CheckForUnknownExact(attribute)) { result = frame.attribute_utility().CreateUnknownSet(attribute.attribute()); return absl::OkStatus(); } } CEL_ASSIGN_OR_RETURN( auto found, frame.activation().FindVariable(name, frame.descriptor_pool(), frame.message_factory(), frame.arena(), &result)); if (found) { return absl::OkStatus(); } 
result = cel::ErrorValue(CreateError( absl::StrCat("No value with name \"", name, "\" found in Activation"))); return absl::OkStatus(); } absl::Status IdentStep::Evaluate(ExecutionFrame* frame) const { Value value; AttributeTrail attribute; CEL_RETURN_IF_ERROR(LookupIdent(name_, *frame, value, attribute)); frame->value_stack().Push(std::move(value), std::move(attribute)); return absl::OkStatus(); } absl::StatusOr LookupSlot( absl::string_view name, size_t slot_index, ExecutionFrameBase& frame) { ComprehensionSlots::Slot* slot = frame.comprehension_slots().Get(slot_index); if (!slot->Has()) { return absl::InternalError( absl::StrCat("Comprehension variable accessed out of scope: ", name)); } return slot; } class SlotStep : public ExpressionStepBase { public: SlotStep(absl::string_view name, size_t slot_index, int64_t expr_id) : ExpressionStepBase(expr_id), name_(name), slot_index_(slot_index) {} absl::Status Evaluate(ExecutionFrame* frame) const override { CEL_ASSIGN_OR_RETURN(const ComprehensionSlots::Slot* slot, LookupSlot(name_, slot_index_, *frame)); frame->value_stack().Push(slot->value(), slot->attribute()); return absl::OkStatus(); } private: std::string name_; size_t slot_index_; }; class DirectIdentStep : public DirectExpressionStep { public: DirectIdentStep(absl::string_view name, int64_t expr_id) : DirectExpressionStep(expr_id), name_(name) {} absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const override { return LookupIdent(name_, frame, result, attribute); } private: std::string name_; }; class DirectSlotStep : public DirectExpressionStep { public: DirectSlotStep(absl::string_view name, size_t slot_index, int64_t expr_id) : DirectExpressionStep(expr_id), name_(std::move(name)), slot_index_(slot_index) {} absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const override { CEL_ASSIGN_OR_RETURN(const ComprehensionSlots::Slot* slot, LookupSlot(name_, slot_index_, frame)); if 
(frame.attribute_tracking_enabled()) { attribute = slot->attribute(); } result = slot->value(); return absl::OkStatus(); } private: std::string name_; size_t slot_index_; }; } // namespace std::unique_ptr CreateDirectIdentStep( absl::string_view identifier, int64_t expr_id) { return std::make_unique(identifier, expr_id); } std::unique_ptr CreateDirectSlotIdentStep( absl::string_view identifier, size_t slot_index, int64_t expr_id) { return std::make_unique(identifier, slot_index, expr_id); } absl::StatusOr> CreateIdentStep( const absl::string_view name, int64_t expr_id) { return std::make_unique(name, expr_id); } absl::StatusOr> CreateIdentStepForSlot( const absl::string_view name, size_t slot_index, int64_t expr_id) { return std::make_unique(name, slot_index, expr_id); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/ident_step.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_IDENT_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_IDENT_STEP_H_ #include #include #include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { std::unique_ptr CreateDirectIdentStep( absl::string_view identifier, int64_t expr_id); std::unique_ptr CreateDirectSlotIdentStep( absl::string_view identifier, size_t slot_index, int64_t expr_id); // Factory method for Ident - based Execution step absl::StatusOr> CreateIdentStep( absl::string_view name, int64_t expr_id); // Factory method for identifier that has been assigned to a slot. 
absl::StatusOr> CreateIdentStepForSlot( absl::string_view name, size_t slot_index, int64_t expr_id); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_IDENT_STEP_H_ ================================================ FILE: eval/eval/ident_step_test.cc ================================================ #include "eval/eval/ident_step.h" #include #include #include #include #include "absl/status/status.h" #include "base/type_provider.h" #include "common/casting.h" #include "common/memory.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/cel_expression_flat_impl.h" #include "eval/eval/evaluator_core.h" #include "eval/public/activation.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_value.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" #include "runtime/activation.h" #include "runtime/internal/runtime_env_testing.h" #include "runtime/internal/runtime_type_provider.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { using ::absl_testing::StatusIs; using ::cel::Cast; using ::cel::ErrorValue; using ::cel::InstanceOf; using ::cel::IntValue; using ::cel::MemoryManagerRef; using ::cel::RuntimeOptions; using ::cel::TypeProvider; using ::cel::UnknownValue; using ::cel::Value; using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::google::protobuf::Arena; using ::testing::Eq; using ::testing::HasSubstr; using ::testing::SizeIs; TEST(IdentStepTest, TestIdentStep) { ASSERT_OK_AND_ASSIGN(auto step, CreateIdentStep("name0", /*id=*/-1)); ExecutionPath path; path.push_back(std::move(step)); auto env = NewTestingRuntimeEnv(); CelExpressionFlatImpl impl( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), cel::RuntimeOptions{})); Activation activation; Arena arena; std::string 
value("test"); activation.InsertValue("name0", CelValue::CreateString(&value)); auto status0 = impl.Evaluate(activation, &arena); ASSERT_OK(status0); CelValue result = status0.value(); ASSERT_TRUE(result.IsString()); EXPECT_THAT(result.StringOrDie().value(), Eq("test")); } TEST(IdentStepTest, TestIdentStepNameNotFound) { ASSERT_OK_AND_ASSIGN(auto step, CreateIdentStep("name0", /*id=*/-1)); ExecutionPath path; path.push_back(std::move(step)); auto env = NewTestingRuntimeEnv(); CelExpressionFlatImpl impl( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), cel::RuntimeOptions{})); Activation activation; Arena arena; std::string value("test"); auto status0 = impl.Evaluate(activation, &arena); ASSERT_OK(status0); CelValue result = status0.value(); ASSERT_TRUE(result.IsError()); } TEST(IdentStepTest, DisableMissingAttributeErrorsOK) { ASSERT_OK_AND_ASSIGN(auto step, CreateIdentStep("name0", /*id=*/-1)); ExecutionPath path; path.push_back(std::move(step)); cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kDisabled; auto env = NewTestingRuntimeEnv(); CelExpressionFlatImpl impl( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), options)); Activation activation; Arena arena; std::string value("test"); activation.InsertValue("name0", CelValue::CreateString(&value)); auto status0 = impl.Evaluate(activation, &arena); ASSERT_OK(status0); CelValue result = status0.value(); ASSERT_TRUE(result.IsString()); EXPECT_THAT(result.StringOrDie().value(), Eq("test")); const CelAttributePattern pattern("name0", {}); activation.set_missing_attribute_patterns({pattern}); status0 = impl.Evaluate(activation, &arena); ASSERT_OK(status0); EXPECT_THAT(status0->StringOrDie().value(), Eq("test")); } TEST(IdentStepTest, TestIdentStepMissingAttributeErrors) { ASSERT_OK_AND_ASSIGN(auto step, CreateIdentStep("name0", /*expr_id=*/1)); 
ExecutionPath path; path.push_back(std::move(step)); cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kDisabled; options.enable_missing_attribute_errors = true; auto env = NewTestingRuntimeEnv(); CelExpressionFlatImpl impl( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), options)); Activation activation; Arena arena; std::string value("test"); activation.InsertValue("name0", CelValue::CreateString(&value)); auto status0 = impl.Evaluate(activation, &arena); ASSERT_OK(status0); CelValue result = status0.value(); ASSERT_TRUE(result.IsString()); EXPECT_THAT(result.StringOrDie().value(), Eq("test")); CelAttributePattern pattern("name0", {}); activation.set_missing_attribute_patterns({pattern}); status0 = impl.Evaluate(activation, &arena); ASSERT_OK(status0); EXPECT_EQ(status0->ErrorOrDie()->code(), absl::StatusCode::kInvalidArgument); EXPECT_EQ(status0->ErrorOrDie()->message(), "MissingAttributeError: name0"); } TEST(IdentStepTest, TestIdentStepUnknownAttribute) { ASSERT_OK_AND_ASSIGN(auto step, CreateIdentStep("name0", /*expr_id=*/1)); ExecutionPath path; path.push_back(std::move(step)); // Expression with unknowns enabled. 
cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; auto env = NewTestingRuntimeEnv(); CelExpressionFlatImpl impl( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), options)); Activation activation; Arena arena; std::string value("test"); activation.InsertValue("name0", CelValue::CreateString(&value)); std::vector unknown_patterns; unknown_patterns.push_back(CelAttributePattern("name_bad", {})); activation.set_unknown_attribute_patterns(unknown_patterns); auto status0 = impl.Evaluate(activation, &arena); ASSERT_OK(status0); CelValue result = status0.value(); ASSERT_TRUE(result.IsString()); EXPECT_THAT(result.StringOrDie().value(), Eq("test")); unknown_patterns.push_back(CelAttributePattern("name0", {})); activation.set_unknown_attribute_patterns(unknown_patterns); status0 = impl.Evaluate(activation, &arena); ASSERT_OK(status0); result = status0.value(); ASSERT_TRUE(result.IsUnknownSet()); } TEST(DirectIdentStepTest, Basic) { google::protobuf::Arena arena; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); cel::Activation activation; RuntimeOptions options; activation.InsertOrAssignValue("var1", IntValue(42)); ExecutionFrameBase frame(activation, options, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); Value result; AttributeTrail trail; auto step = CreateDirectIdentStep("var1", -1); ASSERT_OK(step->Evaluate(frame, result, trail)); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(Cast(result).NativeValue(), Eq(42)); } TEST(DirectIdentStepTest, UnknownAttribute) { google::protobuf::Arena arena; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); cel::Activation activation; RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; 
activation.InsertOrAssignValue("var1", IntValue(42)); activation.SetUnknownPatterns({CreateCelAttributePattern("var1", {})}); ExecutionFrameBase frame(activation, options, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); Value result; AttributeTrail trail; auto step = CreateDirectIdentStep("var1", -1); ASSERT_OK(step->Evaluate(frame, result, trail)); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(Cast(result).attribute_set(), SizeIs(1)); } TEST(DirectIdentStepTest, MissingAttribute) { google::protobuf::Arena arena; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); cel::Activation activation; RuntimeOptions options; options.enable_missing_attribute_errors = true; activation.InsertOrAssignValue("var1", IntValue(42)); activation.SetMissingPatterns({CreateCelAttributePattern("var1", {})}); ExecutionFrameBase frame(activation, options, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); Value result; AttributeTrail trail; auto step = CreateDirectIdentStep("var1", -1); ASSERT_OK(step->Evaluate(frame, result, trail)); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(Cast(result).NativeValue(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("var1"))); } TEST(DirectIdentStepTest, NotFound) { google::protobuf::Arena arena; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); cel::Activation activation; RuntimeOptions options; ExecutionFrameBase frame(activation, options, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena); Value result; AttributeTrail trail; auto step = CreateDirectIdentStep("var1", -1); ASSERT_OK(step->Evaluate(frame, result, trail)); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(Cast(result).NativeValue(), StatusIs(absl::StatusCode::kUnknown, HasSubstr("\"var1\" found in 
Activation"))); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/iterator_stack.h ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_ITERATOR_STACK_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_ITERATOR_STACK_H_ #include #include #include #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "common/value.h" namespace cel::runtime_internal { class IteratorStack final { public: explicit IteratorStack(size_t max_size) : max_size_(max_size) { iterators_.reserve(max_size_); } IteratorStack(const IteratorStack&) = delete; IteratorStack(IteratorStack&&) = delete; IteratorStack& operator=(const IteratorStack&) = delete; IteratorStack& operator=(IteratorStack&&) = delete; size_t size() const { return iterators_.size(); } bool empty() const { return iterators_.empty(); } bool full() const { return iterators_.size() == max_size_; } size_t max_size() const { return max_size_; } void Clear() { iterators_.clear(); } void Push(absl_nonnull ValueIteratorPtr iterator) { ABSL_DCHECK(!full()); ABSL_DCHECK(iterator != nullptr); iterators_.push_back(std::move(iterator)); } ValueIterator* absl_nonnull Peek() { ABSL_DCHECK(!empty()); ABSL_DCHECK(iterators_.back() != nullptr); return iterators_.back().get(); } void Pop() { ABSL_DCHECK(!empty()); iterators_.pop_back(); } private: 
std::vector iterators_; size_t max_size_; }; } // namespace cel::runtime_internal #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_ITERATOR_STACK_H_ ================================================ FILE: eval/eval/jump_step.cc ================================================ // Copyright 2017 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/eval/jump_step.h" #include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "common/value.h" #include "eval/internal/errors.h" namespace google::api::expr::runtime { namespace { using ::cel::BoolValue; using ::cel::ErrorValue; using ::cel::UnknownValue; using ::cel::Value; using ::cel::runtime_internal::CreateNoMatchingOverloadError; class JumpStep : public JumpStepBase { public: // Constructs FunctionStep that uses overloads specified. JumpStep(absl::optional jump_offset, int64_t expr_id) : JumpStepBase(jump_offset, expr_id) {} absl::Status Evaluate(ExecutionFrame* frame) const override { return Jump(frame); } }; class CondJumpStep : public JumpStepBase { public: // Constructs FunctionStep that uses overloads specified. 
CondJumpStep(bool jump_condition, bool leave_on_stack, absl::optional jump_offset, int64_t expr_id) : JumpStepBase(jump_offset, expr_id), jump_condition_(jump_condition), leave_on_stack_(leave_on_stack) {} absl::Status Evaluate(ExecutionFrame* frame) const override { // Peek the top value if (!frame->value_stack().HasEnough(1)) { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } const auto& value = frame->value_stack().Peek(); const auto should_jump = value.Is() && jump_condition_ == value.GetBool().NativeValue(); if (!leave_on_stack_) { frame->value_stack().Pop(1); } if (should_jump) { return Jump(frame); } return absl::OkStatus(); } private: const bool jump_condition_; const bool leave_on_stack_; }; class BoolCheckJumpStep : public JumpStepBase { public: // Checks if the top value is a boolean: // - no-op if it is a boolean // - jump to the label if it is an error value // - jump to the label if it is unknown value // - jump to the label if it is neither an error nor a boolean, pops it and // pushes "no matching overload" error BoolCheckJumpStep(absl::optional jump_offset, int64_t expr_id) : JumpStepBase(jump_offset, expr_id) {} absl::Status Evaluate(ExecutionFrame* frame) const override { // Peek the top value if (!frame->value_stack().HasEnough(1)) { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } const Value& value = frame->value_stack().Peek(); if (value->Is()) { return absl::OkStatus(); } if (value->Is() || value->Is()) { return Jump(frame); } // Neither bool, error, nor unknown set. Value error_value = cel::ErrorValue(CreateNoMatchingOverloadError("")); frame->value_stack().PopAndPush(std::move(error_value)); return Jump(frame); return absl::OkStatus(); } }; } // namespace // Factory method for Conditional Jump step. // Conditional Jump requires a boolean value to sit on the stack. // It is compared to jump_condition, and if matched, jump is performed. 
std::unique_ptr CreateCondJumpStep( bool jump_condition, bool leave_on_stack, absl::optional jump_offset, int64_t expr_id) { return std::make_unique(jump_condition, leave_on_stack, jump_offset, expr_id); } // Factory method for Jump step. std::unique_ptr CreateJumpStep(absl::optional jump_offset, int64_t expr_id) { return std::make_unique(jump_offset, expr_id); } // Factory method for Conditional Jump step. // Conditional Jump requires a value to sit on the stack. // If this value is an error or unknown, a jump is performed. std::unique_ptr CreateBoolCheckJumpStep( absl::optional jump_offset, int64_t expr_id) { return std::make_unique(jump_offset, expr_id); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/jump_step.h ================================================ // Copyright 2017 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_JUMP_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_JUMP_STEP_H_ #include #include #include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/types/optional.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" namespace google::api::expr::runtime { class JumpStepBase : public ExpressionStepBase { public: JumpStepBase(absl::optional jump_offset, int64_t expr_id) : ExpressionStepBase(expr_id, false), jump_offset_(jump_offset) {} void set_jump_offset(int offset) { jump_offset_ = offset; } absl::Status Jump(ExecutionFrame* frame) const { if (!jump_offset_.has_value()) { return absl::Status(absl::StatusCode::kInternal, "Jump offset not set"); } return frame->JumpTo(jump_offset_.value()); } private: absl::optional jump_offset_; }; // Factory method for Jump step. std::unique_ptr CreateJumpStep(absl::optional jump_offset, int64_t expr_id); // Factory method for Conditional Jump step. // Conditional Jump requires a boolean value to sit on the stack. // It is compared to jump_condition, and if matched, jump is performed. // leave on stack indicates whether value should be kept on top of the stack or // removed. std::unique_ptr CreateCondJumpStep( bool jump_condition, bool leave_on_stack, absl::optional jump_offset, int64_t expr_id); // Factory method for ErrorJump step. // This step performs a Jump when an Error is on the top of the stack. // Value is left on stack if it is a bool or an error. std::unique_ptr CreateBoolCheckJumpStep( absl::optional jump_offset, int64_t expr_id); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_JUMP_STEP_H_ ================================================ FILE: eval/eval/lazy_init_step.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/eval/lazy_init_step.h" #include #include #include #include #include "cel/expr/value.pb.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/comprehension_slots.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "internal/status_macros.h" namespace google::api::expr::runtime { namespace { using ::cel::Value; class LazyInitStep final : public ExpressionStepBase { public: LazyInitStep(size_t slot_index, size_t subexpression_index, int64_t expr_id) : ExpressionStepBase(expr_id), slot_index_(slot_index), subexpression_index_(subexpression_index) {} absl::Status Evaluate(ExecutionFrame* frame) const override { ComprehensionSlot* slot = frame->comprehension_slots().Get(slot_index_); if (slot->Has()) { frame->value_stack().Push(slot->value(), slot->attribute()); } else { frame->Call(slot_index_, subexpression_index_); } return absl::OkStatus(); } private: const size_t slot_index_; const size_t subexpression_index_; }; class DirectLazyInitStep final : public DirectExpressionStep { public: DirectLazyInitStep(size_t slot_index, const DirectExpressionStep* subexpression, int64_t expr_id) : DirectExpressionStep(expr_id), slot_index_(slot_index), subexpression_(subexpression) {} absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const override { ComprehensionSlot* slot = frame.comprehension_slots().Get(slot_index_); if 
(slot->Has()) { result = slot->value(); attribute = slot->attribute(); } else { CEL_RETURN_IF_ERROR(subexpression_->Evaluate(frame, result, attribute)); slot->Set(result, attribute); } return absl::OkStatus(); } private: const size_t slot_index_; const DirectExpressionStep* absl_nonnull const subexpression_; }; class BindStep : public DirectExpressionStep { public: BindStep(size_t slot_index, std::unique_ptr subexpression, int64_t expr_id) : DirectExpressionStep(expr_id), slot_index_(slot_index), subexpression_(std::move(subexpression)) {} absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const override { CEL_RETURN_IF_ERROR(subexpression_->Evaluate(frame, result, attribute)); frame.comprehension_slots().ClearSlot(slot_index_); return absl::OkStatus(); } private: size_t slot_index_; std::unique_ptr subexpression_; }; class AssignSlotAndPopStepStep final : public ExpressionStepBase { public: explicit AssignSlotAndPopStepStep(size_t slot_index) : ExpressionStepBase(/*expr_id=*/-1, /*comes_from_ast=*/false), slot_index_(slot_index) {} absl::Status Evaluate(ExecutionFrame* frame) const override { if (!frame->value_stack().HasEnough(1)) { return absl::InternalError("Stack underflow assigning lazy value"); } frame->comprehension_slots().Set(slot_index_, frame->value_stack().Peek(), frame->value_stack().PeekAttribute()); frame->value_stack().Pop(1); return absl::OkStatus(); } private: const size_t slot_index_; }; class ClearSlotStep : public ExpressionStepBase { public: explicit ClearSlotStep(size_t slot_index, int64_t expr_id) : ExpressionStepBase(expr_id), slot_index_(slot_index) {} absl::Status Evaluate(ExecutionFrame* frame) const override { frame->comprehension_slots().ClearSlot(slot_index_); return absl::OkStatus(); } private: size_t slot_index_; }; class ClearSlotsStep final : public ExpressionStepBase { public: explicit ClearSlotsStep(size_t slot_index, size_t slot_count, int64_t expr_id) : ExpressionStepBase(expr_id), 
slot_index_(slot_index), slot_count_(slot_count) {} absl::Status Evaluate(ExecutionFrame* frame) const override { for (size_t i = 0; i < slot_count_; ++i) { frame->comprehension_slots().ClearSlot(slot_index_ + i); } return absl::OkStatus(); } private: const size_t slot_index_; const size_t slot_count_; }; class BlockStep : public DirectExpressionStep { public: BlockStep(size_t slot_index, size_t slot_count, std::unique_ptr subexpression, int64_t expr_id) : DirectExpressionStep(expr_id), slot_index_(slot_index), slot_count_(slot_count), subexpression_(std::move(subexpression)) {} absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const override { CEL_RETURN_IF_ERROR(subexpression_->Evaluate(frame, result, attribute)); for (size_t i = 0; i < slot_count_; ++i) { frame.comprehension_slots().ClearSlot(slot_index_ + i); } return absl::OkStatus(); } private: size_t slot_index_; size_t slot_count_; std::unique_ptr subexpression_; }; } // namespace std::unique_ptr CreateDirectBindStep( size_t slot_index, std::unique_ptr expression, int64_t expr_id) { return std::make_unique(slot_index, std::move(expression), expr_id); } std::unique_ptr CreateDirectBlockStep( size_t slot_index, size_t slot_count, std::unique_ptr expression, int64_t expr_id) { return std::make_unique(slot_index, slot_count, std::move(expression), expr_id); } std::unique_ptr CreateDirectLazyInitStep( size_t slot_index, const DirectExpressionStep* absl_nonnull subexpression, int64_t expr_id) { return std::make_unique(slot_index, subexpression, expr_id); } std::unique_ptr CreateLazyInitStep(size_t slot_index, size_t subexpression_index, int64_t expr_id) { return std::make_unique(slot_index, subexpression_index, expr_id); } std::unique_ptr CreateAssignSlotAndPopStep(size_t slot_index) { return std::make_unique(slot_index); } std::unique_ptr CreateClearSlotStep(size_t slot_index, int64_t expr_id) { return std::make_unique(slot_index, expr_id); } std::unique_ptr 
CreateClearSlotsStep(size_t slot_index, size_t slot_count, int64_t expr_id) { ABSL_DCHECK_GT(slot_count, 0); return std::make_unique(slot_index, slot_count, expr_id); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/lazy_init_step.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Program steps for lazily initialized aliases (e.g. cel.bind). // // When used, any reference to variable should be replaced with a conditional // step that either runs the initialization routine or pushes the already // initialized variable to the stack. // // All references to the variable should be replaced with: // // +-----------------+-------------------+--------------------+ // | stack | pc | step | // +-----------------+-------------------+--------------------+ // | {} | 0 | check init slot(i) | // +-----------------+-------------------+--------------------+ // | {value} | 1 | assign slot(i) | // +-----------------+-------------------+--------------------+ // | {value} | 2 | | // +-----------------+-------------------+--------------------+ // | .... 
| // +-----------------+-------------------+--------------------+ // | {...} | n (end of scope) | clear slot(i) | // +-----------------+-------------------+--------------------+ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_LAZY_INIT_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_LAZY_INIT_STEP_H_ #include #include #include #include "absl/base/nullability.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { // Creates a step representing a Bind expression. std::unique_ptr CreateDirectBindStep( size_t slot_index, std::unique_ptr expression, int64_t expr_id); // Creates a step representing a cel.@block expression. std::unique_ptr CreateDirectBlockStep( size_t slot_index, size_t slot_count, std::unique_ptr expression, int64_t expr_id); // Creates a direct step representing accessing a lazily evaluated alias from // a bind or block. std::unique_ptr CreateDirectLazyInitStep( size_t slot_index, const DirectExpressionStep* absl_nonnull subexpression, int64_t expr_id); // Creates a step representing accessing a lazily evaluated alias from // a bind or block. std::unique_ptr CreateLazyInitStep(size_t slot_index, size_t subexpression_index, int64_t expr_id); // Helper step to assign a slot value from the top of stack on initialization. std::unique_ptr CreateAssignSlotAndPopStep(size_t slot_index); // Helper step to clear a slot. // Slots may be reused in different contexts so need to be cleared after a // context is done. 
std::unique_ptr CreateClearSlotStep(size_t slot_index, int64_t expr_id); std::unique_ptr CreateClearSlotsStep(size_t slot_index, size_t slot_count, int64_t expr_id); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_LAZY_INIT_STEP_H_ ================================================ FILE: eval/eval/lazy_init_step_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/eval/lazy_init_step.h" #include #include #include "base/type_provider.h" #include "common/value.h" #include "eval/eval/const_value_step.h" #include "eval/eval/evaluator_core.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" #include "runtime/activation.h" #include "runtime/internal/runtime_type_provider.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { using ::cel::Activation; using ::cel::IntValue; using ::cel::RuntimeOptions; using ::cel::TypeProvider; class LazyInitStepTest : public testing::Test { private: // arbitrary numbers enough for basic tests. 
static constexpr size_t kValueStack = 5; static constexpr size_t kComprehensionSlotCount = 3; public: LazyInitStepTest() : type_provider_(cel::internal::GetTestingDescriptorPool()), evaluator_state_(kValueStack, kComprehensionSlotCount, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_) {} protected: google::protobuf::Arena arena_; cel::runtime_internal::RuntimeTypeProvider type_provider_; FlatExpressionEvaluatorState evaluator_state_; RuntimeOptions runtime_options_; Activation activation_; }; TEST_F(LazyInitStepTest, CreateCheckInitStepDoesInit) { ExecutionPath path; ExecutionPath subpath; path.push_back(CreateLazyInitStep(/*slot_index=*/0, /*subexpression_index=*/1, -1)); ASSERT_OK_AND_ASSIGN(subpath.emplace_back(), CreateConstValueStep(cel::IntValue(42), -1, false)); std::vector expression_table{path, subpath}; ExecutionFrame frame(expression_table, activation_, runtime_options_, evaluator_state_); ASSERT_OK_AND_ASSIGN(auto value, frame.Evaluate()); EXPECT_TRUE(value->Is() && value.GetInt().NativeValue() == 42); } TEST_F(LazyInitStepTest, CreateCheckInitStepSkipInit) { ExecutionPath path; ExecutionPath subpath; // This is the expected usage, but in this test we are just depending on the // fact that these don't change the stack and fit the program layout // requirements. 
path.push_back(CreateLazyInitStep(/*slot_index=*/0, -1, -1)); ASSERT_OK_AND_ASSIGN(subpath.emplace_back(), CreateConstValueStep(cel::IntValue(42), -1, false)); std::vector expression_table{path, subpath}; ExecutionFrame frame(expression_table, activation_, runtime_options_, evaluator_state_); frame.comprehension_slots().Set(0, cel::IntValue(42)); ASSERT_OK_AND_ASSIGN(auto value, frame.Evaluate()); EXPECT_TRUE(value->Is() && value.GetInt().NativeValue() == 42); } TEST_F(LazyInitStepTest, CreateAssignSlotAndPopStepBasic) { ExecutionPath path; path.push_back(CreateAssignSlotAndPopStep(0)); ExecutionFrame frame(path, activation_, runtime_options_, evaluator_state_); frame.comprehension_slots().ClearSlot(0); frame.value_stack().Push(cel::IntValue(42)); // This will error because no return value, step will still evaluate. frame.Evaluate().IgnoreError(); auto* slot = frame.comprehension_slots().Get(0); ASSERT_TRUE(slot->Has()); EXPECT_TRUE(slot->value()->Is() && slot->value().GetInt().NativeValue() == 42); EXPECT_TRUE(frame.value_stack().empty()); } TEST_F(LazyInitStepTest, CreateClearSlotStepBasic) { ExecutionPath path; path.push_back(CreateClearSlotStep(0, -1)); ExecutionFrame frame(path, activation_, runtime_options_, evaluator_state_); frame.comprehension_slots().Set(0, cel::IntValue(42)); // This will error because no return value, step will still evaluate. frame.Evaluate().IgnoreError(); auto* slot = frame.comprehension_slots().Get(0); ASSERT_FALSE(slot->Has()); } TEST_F(LazyInitStepTest, CreateClearSlotsStepBasic) { ExecutionPath path; path.push_back(CreateClearSlotsStep(0, 2, -1)); ExecutionFrame frame(path, activation_, runtime_options_, evaluator_state_); frame.comprehension_slots().Set(0, cel::IntValue(42)); frame.comprehension_slots().Set(1, cel::IntValue(42)); // This will error because no return value, step will still evaluate. 
frame.Evaluate().IgnoreError(); EXPECT_FALSE(frame.comprehension_slots().Get(0)->Has()); EXPECT_FALSE(frame.comprehension_slots().Get(1)->Has()); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/logic_step.cc ================================================ #include "eval/eval/logic_step.h" #include #include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "base/builtins.h" #include "common/casting.h" #include "common/value.h" #include "common/value_kind.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/internal/errors.h" #include "internal/status_macros.h" #include "runtime/internal/errors.h" namespace google::api::expr::runtime { namespace { using ::cel::BoolValue; using ::cel::Cast; using ::cel::ErrorValue; using ::cel::InstanceOf; using ::cel::UnknownValue; using ::cel::Value; using ::cel::ValueKind; using ::cel::runtime_internal::CreateNoMatchingOverloadError; enum class OpType { kAnd, kOr }; // Shared logic for the fall through case (we didn't see the shortcircuit // value). absl::Status ReturnLogicResult(ExecutionFrameBase& frame, OpType op_type, Value& lhs_result, Value& rhs_result, AttributeTrail& attribute_trail, AttributeTrail& rhs_attr) { ValueKind lhs_kind = lhs_result.kind(); ValueKind rhs_kind = rhs_result.kind(); if (frame.unknown_processing_enabled()) { if (lhs_kind == ValueKind::kUnknown && rhs_kind == ValueKind::kUnknown) { lhs_result = frame.attribute_utility().MergeUnknownValues( Cast(lhs_result), Cast(rhs_result)); // Clear attribute trail so this doesn't get re-identified as a new // unknown and reset the accumulated attributes. 
attribute_trail = AttributeTrail(); return absl::OkStatus(); } else if (lhs_kind == ValueKind::kUnknown) { return absl::OkStatus(); } else if (rhs_kind == ValueKind::kUnknown) { lhs_result = std::move(rhs_result); attribute_trail = std::move(rhs_attr); return absl::OkStatus(); } } if (lhs_kind == ValueKind::kError) { return absl::OkStatus(); } else if (rhs_kind == ValueKind::kError) { lhs_result = std::move(rhs_result); attribute_trail = std::move(rhs_attr); return absl::OkStatus(); } if (lhs_kind == ValueKind::kBool && rhs_kind == ValueKind::kBool) { return absl::OkStatus(); } // Otherwise, add a no overload error. attribute_trail = AttributeTrail(); lhs_result = cel::ErrorValue(CreateNoMatchingOverloadError( op_type == OpType::kOr ? cel::builtin::kOr : cel::builtin::kAnd)); return absl::OkStatus(); } class ExhaustiveDirectLogicStep : public DirectExpressionStep { public: explicit ExhaustiveDirectLogicStep(std::unique_ptr lhs, std::unique_ptr rhs, OpType op_type, int64_t expr_id) : DirectExpressionStep(expr_id), lhs_(std::move(lhs)), rhs_(std::move(rhs)), op_type_(op_type) {} absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, AttributeTrail& attribute_trail) const override; private: std::unique_ptr lhs_; std::unique_ptr rhs_; OpType op_type_; }; absl::Status ExhaustiveDirectLogicStep::Evaluate( ExecutionFrameBase& frame, cel::Value& result, AttributeTrail& attribute_trail) const { CEL_RETURN_IF_ERROR(lhs_->Evaluate(frame, result, attribute_trail)); ValueKind lhs_kind = result.kind(); Value rhs_result; AttributeTrail rhs_attr; CEL_RETURN_IF_ERROR(rhs_->Evaluate(frame, rhs_result, attribute_trail)); ValueKind rhs_kind = rhs_result.kind(); if (lhs_kind == ValueKind::kBool) { bool lhs_bool = Cast(result).NativeValue(); if ((op_type_ == OpType::kOr && lhs_bool) || (op_type_ == OpType::kAnd && !lhs_bool)) { return absl::OkStatus(); } } if (rhs_kind == ValueKind::kBool) { bool rhs_bool = Cast(rhs_result).NativeValue(); if ((op_type_ == OpType::kOr && 
rhs_bool) || (op_type_ == OpType::kAnd && !rhs_bool)) { result = std::move(rhs_result); attribute_trail = std::move(rhs_attr); return absl::OkStatus(); } } return ReturnLogicResult(frame, op_type_, result, rhs_result, attribute_trail, rhs_attr); } class DirectLogicStep : public DirectExpressionStep { public: explicit DirectLogicStep(std::unique_ptr lhs, std::unique_ptr rhs, OpType op_type, int64_t expr_id) : DirectExpressionStep(expr_id), lhs_(std::move(lhs)), rhs_(std::move(rhs)), op_type_(op_type) {} absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, AttributeTrail& attribute_trail) const override; private: std::unique_ptr lhs_; std::unique_ptr rhs_; OpType op_type_; }; absl::Status DirectLogicStep::Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute_trail) const { CEL_RETURN_IF_ERROR(lhs_->Evaluate(frame, result, attribute_trail)); ValueKind lhs_kind = result.kind(); if (lhs_kind == ValueKind::kBool) { bool lhs_bool = Cast(result).NativeValue(); if ((op_type_ == OpType::kOr && lhs_bool) || (op_type_ == OpType::kAnd && !lhs_bool)) { return absl::OkStatus(); } } Value rhs_result; AttributeTrail rhs_attr; CEL_RETURN_IF_ERROR(rhs_->Evaluate(frame, rhs_result, attribute_trail)); ValueKind rhs_kind = rhs_result.kind(); if (rhs_kind == ValueKind::kBool) { bool rhs_bool = Cast(rhs_result).NativeValue(); if ((op_type_ == OpType::kOr && rhs_bool) || (op_type_ == OpType::kAnd && !rhs_bool)) { result = std::move(rhs_result); attribute_trail = std::move(rhs_attr); return absl::OkStatus(); } } return ReturnLogicResult(frame, op_type_, result, rhs_result, attribute_trail, rhs_attr); } class LogicalOpStep : public ExpressionStepBase { public: // Constructs FunctionStep that uses overloads specified. 
LogicalOpStep(OpType op_type, int64_t expr_id) : ExpressionStepBase(expr_id), op_type_(op_type) { shortcircuit_ = (op_type_ == OpType::kOr); } absl::Status Evaluate(ExecutionFrame* frame) const override; private: void Calculate(ExecutionFrame* frame, absl::Span args, Value& result) const { bool bool_args[2]; bool has_bool_args[2]; for (size_t i = 0; i < args.size(); i++) { has_bool_args[i] = args[i]->Is(); if (has_bool_args[i]) { bool_args[i] = args[i].GetBool().NativeValue(); if (bool_args[i] == shortcircuit_) { result = BoolValue{bool_args[i]}; return; } } } if (has_bool_args[0] && has_bool_args[1]) { switch (op_type_) { case OpType::kAnd: result = BoolValue{bool_args[0] && bool_args[1]}; return; case OpType::kOr: result = BoolValue{bool_args[0] || bool_args[1]}; return; } } // As opposed to regular function, logical operation treat Unknowns with // higher precedence than error. This is due to the fact that after Unknown // is resolved to actual value, it may short-circuit and thus hide the // error. if (frame->enable_unknowns()) { // Check if unknown? absl::optional unknown_set = frame->attribute_utility().MergeUnknowns(args); if (unknown_set.has_value()) { result = std::move(*unknown_set); return; } } if (args[0]->Is()) { result = args[0]; return; } else if (args[1]->Is()) { result = args[1]; return; } // Fallback. result = cel::ErrorValue(CreateNoMatchingOverloadError( (op_type_ == OpType::kOr) ? cel::builtin::kOr : cel::builtin::kAnd)); } const OpType op_type_; bool shortcircuit_; }; absl::Status LogicalOpStep::Evaluate(ExecutionFrame* frame) const { // Must have 2 or more values on the stack. if (!frame->value_stack().HasEnough(2)) { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } // Create Span object that contains input arguments to the function. 
auto args = frame->value_stack().GetSpan(2); Value result; Calculate(frame, args, result); frame->value_stack().PopAndPush(args.size(), std::move(result)); return absl::OkStatus(); } std::unique_ptr CreateDirectLogicStep( std::unique_ptr lhs, std::unique_ptr rhs, int64_t expr_id, OpType op_type, bool shortcircuiting) { if (shortcircuiting) { return std::make_unique(std::move(lhs), std::move(rhs), op_type, expr_id); } else { return std::make_unique( std::move(lhs), std::move(rhs), op_type, expr_id); } } class DirectNotStep : public DirectExpressionStep { public: explicit DirectNotStep(std::unique_ptr operand, int64_t expr_id) : DirectExpressionStep(expr_id), operand_(std::move(operand)) {} absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute_trail) const override; private: std::unique_ptr operand_; }; absl::Status DirectNotStep::Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute_trail) const { CEL_RETURN_IF_ERROR(operand_->Evaluate(frame, result, attribute_trail)); if (frame.unknown_processing_enabled()) { if (frame.attribute_utility().CheckForUnknownPartial(attribute_trail)) { result = frame.attribute_utility().CreateUnknownSet( attribute_trail.attribute()); return absl::OkStatus(); } } switch (result.kind()) { case ValueKind::kBool: result = BoolValue{!result.GetBool().NativeValue()}; break; case ValueKind::kUnknown: case ValueKind::kError: // just forward. 
break; default: result = cel::ErrorValue(CreateNoMatchingOverloadError(cel::builtin::kNot)); break; } return absl::OkStatus(); } class IterativeNotStep : public ExpressionStepBase { public: explicit IterativeNotStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} absl::Status Evaluate(ExecutionFrame* frame) const override; }; absl::Status IterativeNotStep::Evaluate(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(1)) { return absl::InternalError("Value stack underflow"); } const Value& operand = frame->value_stack().Peek(); if (frame->unknown_processing_enabled()) { const AttributeTrail& attribute_trail = frame->value_stack().PeekAttribute(); if (frame->attribute_utility().CheckForUnknownPartial(attribute_trail)) { frame->value_stack().PopAndPush( frame->attribute_utility().CreateUnknownSet( attribute_trail.attribute())); return absl::OkStatus(); } } switch (operand.kind()) { case ValueKind::kBool: frame->value_stack().PopAndPush( BoolValue{!operand.GetBool().NativeValue()}); break; case ValueKind::kUnknown: case ValueKind::kError: // just forward. break; default: frame->value_stack().PopAndPush( cel::ErrorValue(CreateNoMatchingOverloadError(cel::builtin::kNot))); break; } return absl::OkStatus(); } class DirectNotStrictlyFalseStep : public DirectExpressionStep { public: explicit DirectNotStrictlyFalseStep( std::unique_ptr operand, int64_t expr_id) : DirectExpressionStep(expr_id), operand_(std::move(operand)) {} absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute_trail) const override; private: std::unique_ptr operand_; }; absl::Status DirectNotStrictlyFalseStep::Evaluate( ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute_trail) const { CEL_RETURN_IF_ERROR(operand_->Evaluate(frame, result, attribute_trail)); switch (result.kind()) { case ValueKind::kBool: // just forward. 
break; case ValueKind::kUnknown: case ValueKind::kError: result = BoolValue(true); break; default: result = cel::ErrorValue(CreateNoMatchingOverloadError(cel::builtin::kNot)); break; } return absl::OkStatus(); } class IterativeNotStrictlyFalseStep : public ExpressionStepBase { public: explicit IterativeNotStrictlyFalseStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} absl::Status Evaluate(ExecutionFrame* frame) const override; }; absl::Status IterativeNotStrictlyFalseStep::Evaluate( ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(1)) { return absl::InternalError("Value stack underflow"); } const Value& operand = frame->value_stack().Peek(); switch (operand.kind()) { case ValueKind::kBool: // just forward. break; case ValueKind::kUnknown: case ValueKind::kError: frame->value_stack().PopAndPush(BoolValue(true)); break; default: frame->value_stack().PopAndPush( cel::ErrorValue(CreateNoMatchingOverloadError(cel::builtin::kNot))); break; } return absl::OkStatus(); } } // namespace // Factory method for "And" Execution step std::unique_ptr CreateDirectAndStep( std::unique_ptr lhs, std::unique_ptr rhs, int64_t expr_id, bool shortcircuiting) { return CreateDirectLogicStep(std::move(lhs), std::move(rhs), expr_id, OpType::kAnd, shortcircuiting); } // Factory method for "Or" Execution step std::unique_ptr CreateDirectOrStep( std::unique_ptr lhs, std::unique_ptr rhs, int64_t expr_id, bool shortcircuiting) { return CreateDirectLogicStep(std::move(lhs), std::move(rhs), expr_id, OpType::kOr, shortcircuiting); } // Factory method for "And" Execution step absl::StatusOr> CreateAndStep(int64_t expr_id) { return std::make_unique(OpType::kAnd, expr_id); } // Factory method for "Or" Execution step absl::StatusOr> CreateOrStep(int64_t expr_id) { return std::make_unique(OpType::kOr, expr_id); } // Factory method for recursive logical not "!" 
Execution step std::unique_ptr CreateDirectNotStep( std::unique_ptr operand, int64_t expr_id) { return std::make_unique(std::move(operand), expr_id); } // Factory method for iterative logical not "!" Execution step std::unique_ptr CreateNotStep(int64_t expr_id) { return std::make_unique(expr_id); } // Factory method for recursive logical "@not_strictly_false" Execution step. std::unique_ptr CreateDirectNotStrictlyFalseStep( std::unique_ptr operand, int64_t expr_id) { return std::make_unique(std::move(operand), expr_id); } // Factory method for iterative logical "@not_strictly_false" Execution step. std::unique_ptr CreateNotStrictlyFalseStep(int64_t expr_id) { return std::make_unique(expr_id); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/logic_step.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_LOGIC_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_LOGIC_STEP_H_ #include #include #include "absl/status/statusor.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { // Factory method for "And" Execution step std::unique_ptr CreateDirectAndStep( std::unique_ptr lhs, std::unique_ptr rhs, int64_t expr_id, bool shortcircuiting); // Factory method for "Or" Execution step std::unique_ptr CreateDirectOrStep( std::unique_ptr lhs, std::unique_ptr rhs, int64_t expr_id, bool shortcircuiting); // Factory method for "And" Execution step absl::StatusOr> CreateAndStep(int64_t expr_id); // Factory method for "Or" Execution step absl::StatusOr> CreateOrStep(int64_t expr_id); // Factory method for recursive logical not "!" Execution step std::unique_ptr CreateDirectNotStep( std::unique_ptr operand, int64_t expr_id); // Factory method for iterative logical not "!" Execution step std::unique_ptr CreateNotStep(int64_t expr_id); // Factory method for recursive logical "@not_strictly_false" Execution step. 
std::unique_ptr CreateDirectNotStrictlyFalseStep( std::unique_ptr operand, int64_t expr_id); // Factory method for iterative logical "@not_strictly_false" Execution step. std::unique_ptr CreateNotStrictlyFalseStep(int64_t expr_id); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_LOGIC_STEP_H_ ================================================ FILE: eval/eval/logic_step_test.cc ================================================ #include "eval/eval/logic_step.h" #include #include #include #include #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "base/attribute.h" #include "base/attribute_set.h" #include "base/type_provider.h" #include "common/casting.h" #include "common/expr.h" #include "common/unknown.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/cel_expression_flat_impl.h" #include "eval/eval/const_value_step.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" #include "eval/public/activation.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" #include "runtime/activation.h" #include "runtime/internal/runtime_env.h" #include "runtime/internal/runtime_env_testing.h" #include "runtime/internal/runtime_type_provider.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { using ::absl_testing::IsOk; using ::cel::Attribute; using ::cel::AttributeSet; using ::cel::BoolValue; using ::cel::Cast; using ::cel::Expr; using ::cel::InstanceOf; using 
::cel::IntValue; using ::cel::TypeProvider; using ::cel::UnknownValue; using ::cel::Value; using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::runtime_internal::RuntimeEnv; using ::google::protobuf::Arena; using ::testing::Eq; class LogicStepTest : public testing::TestWithParam { public: LogicStepTest() : env_(NewTestingRuntimeEnv()) {} absl::Status EvaluateLogic(CelValue arg0, CelValue arg1, bool is_or, CelValue* result, bool enable_unknown) { ExecutionPath path; CEL_ASSIGN_OR_RETURN(auto step, CreateIdentStep("name0", /*expr_id=*/-1)); path.push_back(std::move(step)); CEL_ASSIGN_OR_RETURN(step, CreateIdentStep("name1", /*expr_id=*/-1)); path.push_back(std::move(step)); CEL_ASSIGN_OR_RETURN(step, (is_or) ? CreateOrStep(2) : CreateAndStep(2)); path.push_back(std::move(step)); auto dummy_expr = std::make_unique(); cel::RuntimeOptions options; if (enable_unknown) { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } CelExpressionFlatImpl impl( env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env_->type_registry.GetComposedTypeProvider(), options)); Activation activation; activation.InsertValue("name0", arg0); activation.InsertValue("name1", arg1); CEL_ASSIGN_OR_RETURN(CelValue value, impl.Evaluate(activation, &arena_)); *result = value; return absl::OkStatus(); } private: absl_nonnull std::shared_ptr env_; Arena arena_; }; TEST_P(LogicStepTest, TestAndLogic) { CelValue result; absl::Status status = EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(true), false, &result, GetParam()); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); status = EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(false), false, &result, GetParam()); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); status = EvaluateLogic(CelValue::CreateBool(false), CelValue::CreateBool(true), false, &result, GetParam()); 
ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); status = EvaluateLogic(CelValue::CreateBool(false), CelValue::CreateBool(false), false, &result, GetParam()); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); } TEST_P(LogicStepTest, TestOrLogic) { CelValue result; absl::Status status = EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(true), true, &result, GetParam()); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); status = EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(false), true, &result, GetParam()); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); status = EvaluateLogic(CelValue::CreateBool(false), CelValue::CreateBool(true), true, &result, GetParam()); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); status = EvaluateLogic(CelValue::CreateBool(false), CelValue::CreateBool(false), true, &result, GetParam()); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); } TEST_P(LogicStepTest, TestAndLogicErrorHandling) { CelValue result; CelError error = absl::CancelledError(); CelValue error_value = CelValue::CreateError(&error); absl::Status status = EvaluateLogic(error_value, CelValue::CreateBool(true), false, &result, GetParam()); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsError()); status = EvaluateLogic(CelValue::CreateBool(true), error_value, false, &result, GetParam()); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsError()); status = EvaluateLogic(CelValue::CreateBool(false), error_value, false, &result, GetParam()); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); status = EvaluateLogic(error_value, CelValue::CreateBool(false), false, &result, GetParam()); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); 
ASSERT_FALSE(result.BoolOrDie()); } TEST_P(LogicStepTest, TestOrLogicErrorHandling) { CelValue result; CelError error = absl::CancelledError(); CelValue error_value = CelValue::CreateError(&error); absl::Status status = EvaluateLogic(error_value, CelValue::CreateBool(false), true, &result, GetParam()); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsError()); status = EvaluateLogic(CelValue::CreateBool(false), error_value, true, &result, GetParam()); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsError()); status = EvaluateLogic(CelValue::CreateBool(true), error_value, true, &result, GetParam()); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); status = EvaluateLogic(error_value, CelValue::CreateBool(true), true, &result, GetParam()); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); } TEST_F(LogicStepTest, TestAndLogicUnknownHandling) { CelValue result; UnknownSet unknown_set; CelError cel_error = absl::CancelledError(); CelValue unknown_value = CelValue::CreateUnknownSet(&unknown_set); CelValue error_value = CelValue::CreateError(&cel_error); absl::Status status = EvaluateLogic(unknown_value, CelValue::CreateBool(true), false, &result, true); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); status = EvaluateLogic(CelValue::CreateBool(true), unknown_value, false, &result, true); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); status = EvaluateLogic(CelValue::CreateBool(false), unknown_value, false, &result, true); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); status = EvaluateLogic(unknown_value, CelValue::CreateBool(false), false, &result, true); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); status = EvaluateLogic(error_value, unknown_value, false, &result, true); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); status = 
EvaluateLogic(unknown_value, error_value, false, &result, true); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); CelAttribute attr0("name0", {}), attr1("name1", {}); UnknownAttributeSet unknown_attr_set0({attr0}); UnknownAttributeSet unknown_attr_set1({attr1}); UnknownSet unknown_set0(unknown_attr_set0); UnknownSet unknown_set1(unknown_attr_set1); EXPECT_THAT(unknown_attr_set0.size(), Eq(1)); EXPECT_THAT(unknown_attr_set1.size(), Eq(1)); status = EvaluateLogic(CelValue::CreateUnknownSet(&unknown_set0), CelValue::CreateUnknownSet(&unknown_set1), false, &result, true); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); ASSERT_THAT(result.UnknownSetOrDie()->unknown_attributes().size(), Eq(2)); } TEST_F(LogicStepTest, TestOrLogicUnknownHandling) { CelValue result; UnknownSet unknown_set; CelError cel_error = absl::CancelledError(); CelValue unknown_value = CelValue::CreateUnknownSet(&unknown_set); CelValue error_value = CelValue::CreateError(&cel_error); absl::Status status = EvaluateLogic( unknown_value, CelValue::CreateBool(false), true, &result, true); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); status = EvaluateLogic(CelValue::CreateBool(false), unknown_value, true, &result, true); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); status = EvaluateLogic(CelValue::CreateBool(true), unknown_value, true, &result, true); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); status = EvaluateLogic(unknown_value, CelValue::CreateBool(true), true, &result, true); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); status = EvaluateLogic(unknown_value, error_value, true, &result, true); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); status = EvaluateLogic(error_value, unknown_value, true, &result, true); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); CelAttribute attr0("name0", {}), 
attr1("name1", {}); UnknownAttributeSet unknown_attr_set0({attr0}); UnknownAttributeSet unknown_attr_set1({attr1}); UnknownSet unknown_set0(unknown_attr_set0); UnknownSet unknown_set1(unknown_attr_set1); EXPECT_THAT(unknown_attr_set0.size(), Eq(1)); EXPECT_THAT(unknown_attr_set1.size(), Eq(1)); status = EvaluateLogic(CelValue::CreateUnknownSet(&unknown_set0), CelValue::CreateUnknownSet(&unknown_set1), true, &result, true); ASSERT_THAT(status, IsOk()); ASSERT_TRUE(result.IsUnknownSet()); ASSERT_THAT(result.UnknownSetOrDie()->unknown_attributes().size(), Eq(2)); } INSTANTIATE_TEST_SUITE_P(LogicStepTest, LogicStepTest, testing::Bool()); enum class BinaryOp { kAnd, kOr }; enum class UnaryOp { kNot, kNotStrictlyFalse }; enum class OpArg { kTrue, kFalse, kUnknown, kError, // Arbitrary incorrect type kInt }; enum class OpResult { kTrue, kFalse, kUnknown, kError, }; struct BinaryTestCase { std::string name; BinaryOp op; OpArg arg0; OpArg arg1; OpResult result; }; UnknownValue MakeUnknownValue(std::string attr) { std::vector attrs; attrs.push_back(Attribute(std::move(attr))); return cel::UnknownValue(cel::Unknown(AttributeSet(attrs))); } std::unique_ptr MakeArgStep(OpArg arg, absl::string_view name) { switch (arg) { case OpArg::kTrue: return CreateConstValueDirectStep(BoolValue(true)); case OpArg::kFalse: return CreateConstValueDirectStep(BoolValue(false)); case OpArg::kUnknown: return CreateConstValueDirectStep(MakeUnknownValue(std::string(name))); case OpArg::kError: return CreateConstValueDirectStep( cel::ErrorValue(absl::InternalError(name))); case OpArg::kInt: return CreateConstValueDirectStep(IntValue(42)); } }; class DirectBinaryLogicStepTest : public testing::TestWithParam> { public: DirectBinaryLogicStepTest() = default; bool ShortcircuitingEnabled() { return std::get<0>(GetParam()); } const BinaryTestCase& GetTestCase() { return std::get<1>(GetParam()); } protected: Arena arena_; }; TEST_P(DirectBinaryLogicStepTest, TestCases) { const BinaryTestCase& test_case = 
GetTestCase(); std::unique_ptr lhs = MakeArgStep(test_case.arg0, "lhs"); std::unique_ptr rhs = MakeArgStep(test_case.arg1, "rhs"); std::unique_ptr op = (test_case.op == BinaryOp::kAnd) ? CreateDirectAndStep(std::move(lhs), std::move(rhs), -1, ShortcircuitingEnabled()) : CreateDirectOrStep(std::move(lhs), std::move(rhs), -1, ShortcircuitingEnabled()); cel::Activation activation; cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); ExecutionFrameBase frame(activation, options, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value value; AttributeTrail attr; ASSERT_THAT(op->Evaluate(frame, value, attr), IsOk()); switch (test_case.result) { case OpResult::kTrue: ASSERT_TRUE(value.IsBool()); EXPECT_TRUE(value.GetBool().NativeValue()); break; case OpResult::kFalse: ASSERT_TRUE(value.IsBool()); EXPECT_FALSE(value.GetBool().NativeValue()); break; case OpResult::kUnknown: EXPECT_TRUE(value.IsUnknown()); break; case OpResult::kError: EXPECT_TRUE(value.IsError()); break; } } INSTANTIATE_TEST_SUITE_P( DirectBinaryLogicStepTest, DirectBinaryLogicStepTest, testing::Combine(testing::Bool(), testing::ValuesIn>({ { "AndFalseFalse", BinaryOp::kAnd, OpArg::kFalse, OpArg::kFalse, OpResult::kFalse, }, { "AndFalseTrue", BinaryOp::kAnd, OpArg::kFalse, OpArg::kTrue, OpResult::kFalse, }, { "AndTrueFalse", BinaryOp::kAnd, OpArg::kTrue, OpArg::kFalse, OpResult::kFalse, }, { "AndTrueTrue", BinaryOp::kAnd, OpArg::kTrue, OpArg::kTrue, OpResult::kTrue, }, { "AndTrueError", BinaryOp::kAnd, OpArg::kTrue, OpArg::kError, OpResult::kError, }, { "AndErrorTrue", BinaryOp::kAnd, OpArg::kError, OpArg::kTrue, OpResult::kError, }, { "AndFalseError", BinaryOp::kAnd, OpArg::kFalse, OpArg::kError, OpResult::kFalse, }, { "AndErrorFalse", BinaryOp::kAnd, OpArg::kError, OpArg::kFalse, 
OpResult::kFalse, }, { "AndErrorError", BinaryOp::kAnd, OpArg::kError, OpArg::kError, OpResult::kError, }, { "AndTrueUnknown", BinaryOp::kAnd, OpArg::kTrue, OpArg::kUnknown, OpResult::kUnknown, }, { "AndUnknownTrue", BinaryOp::kAnd, OpArg::kUnknown, OpArg::kTrue, OpResult::kUnknown, }, { "AndFalseUnknown", BinaryOp::kAnd, OpArg::kFalse, OpArg::kUnknown, OpResult::kFalse, }, { "AndUnknownFalse", BinaryOp::kAnd, OpArg::kUnknown, OpArg::kFalse, OpResult::kFalse, }, { "AndUnknownUnknown", BinaryOp::kAnd, OpArg::kUnknown, OpArg::kUnknown, OpResult::kUnknown, }, { "AndUnknownError", BinaryOp::kAnd, OpArg::kUnknown, OpArg::kError, OpResult::kUnknown, }, { "AndErrorUnknown", BinaryOp::kAnd, OpArg::kError, OpArg::kUnknown, OpResult::kUnknown, }, // Or cases are simplified since the logic generalizes // and is covered by and cases. })), [](const testing::TestParamInfo& info) -> std::string { bool shortcircuiting_enabled = std::get<0>(info.param); absl::string_view name = std::get<1>(info.param).name; return absl::StrCat( name, (shortcircuiting_enabled ? "ShortcircuitingEnabled" : "")); }); struct UnaryTestCase { std::string name; UnaryOp op; OpArg arg; OpResult result; }; class DirectUnaryLogicStepTest : public testing::TestWithParam { public: DirectUnaryLogicStepTest() = default; const UnaryTestCase& GetTestCase() { return GetParam(); } protected: Arena arena_; }; TEST_P(DirectUnaryLogicStepTest, TestCases) { const UnaryTestCase& test_case = GetTestCase(); std::unique_ptr arg = MakeArgStep(test_case.arg, "arg"); std::unique_ptr op = (test_case.op == UnaryOp::kNot) ? 
CreateDirectNotStep(std::move(arg), -1) : CreateDirectNotStrictlyFalseStep(std::move(arg), -1); cel::Activation activation; cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; cel::runtime_internal::RuntimeTypeProvider type_provider( cel::internal::GetTestingDescriptorPool()); ExecutionFrameBase frame(activation, options, type_provider, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value value; AttributeTrail attr; ASSERT_THAT(op->Evaluate(frame, value, attr), IsOk()); switch (test_case.result) { case OpResult::kTrue: ASSERT_TRUE(value.IsBool()); EXPECT_TRUE(value.GetBool().NativeValue()); break; case OpResult::kFalse: ASSERT_TRUE(value.IsBool()); EXPECT_FALSE(value.GetBool().NativeValue()); break; case OpResult::kUnknown: EXPECT_TRUE(value.IsUnknown()); break; case OpResult::kError: EXPECT_TRUE(value.IsError()); break; } } INSTANTIATE_TEST_SUITE_P( DirectUnaryLogicStepTest, DirectUnaryLogicStepTest, testing::ValuesIn>( {UnaryTestCase{"NotTrue", UnaryOp::kNot, OpArg::kTrue, OpResult::kFalse}, UnaryTestCase{"NotError", UnaryOp::kNot, OpArg::kError, OpResult::kError}, UnaryTestCase{"NotUnknown", UnaryOp::kNot, OpArg::kUnknown, OpResult::kUnknown}, UnaryTestCase{"NotInt", UnaryOp::kNot, OpArg::kInt, OpResult::kError}, UnaryTestCase{"NotFalse", UnaryOp::kNot, OpArg::kFalse, OpResult::kTrue}, UnaryTestCase{"NotStrictlyFalseTrue", UnaryOp::kNotStrictlyFalse, OpArg::kTrue, OpResult::kTrue}, UnaryTestCase{"NotStrictlyFalseError", UnaryOp::kNotStrictlyFalse, OpArg::kError, OpResult::kTrue}, UnaryTestCase{"NotStrictlyFalseUnknown", UnaryOp::kNotStrictlyFalse, OpArg::kUnknown, OpResult::kTrue}, UnaryTestCase{"NotStrictlyFalseInt", UnaryOp::kNotStrictlyFalse, OpArg::kInt, OpResult::kError}, UnaryTestCase{"NotStrictlyFalseFalse", UnaryOp::kNotStrictlyFalse, OpArg::kFalse, OpResult::kFalse}}), [](const testing::TestParamInfo& info) -> std::string { return info.param.name; }); } 
// namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/optional_or_step.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/eval/optional_or_step.h" #include #include #include #include "absl/base/optimization.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "common/casting.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/eval/jump_step.h" #include "internal/status_macros.h" #include "runtime/internal/errors.h" namespace google::api::expr::runtime { namespace { using ::cel::As; using ::cel::ErrorValue; using ::cel::InstanceOf; using ::cel::OptionalValue; using ::cel::UnknownValue; using ::cel::Value; using ::cel::runtime_internal::CreateNoMatchingOverloadError; enum class OptionalOrKind { kOrOptional, kOrValue }; ErrorValue MakeNoOverloadError(OptionalOrKind kind) { switch (kind) { case OptionalOrKind::kOrOptional: return ErrorValue(CreateNoMatchingOverloadError("or")); case OptionalOrKind::kOrValue: return ErrorValue(CreateNoMatchingOverloadError("orValue")); } ABSL_UNREACHABLE(); } // Implements short-circuiting for optional.or. 
// Expected layout if short-circuiting enabled: // // +--------+-----------------------+-------------------------------+ // | idx | Step | Stack After | // +--------+-----------------------+-------------------------------+ // | 1 | | OptionalValue | // +--------+-----------------------+-------------------------------+ // | 2 | Jump to 5 if present | OptionalValue | // +--------+-----------------------+-------------------------------+ // | 3 | | OptionalValue, OptionalValue | // +--------+-----------------------+-------------------------------+ // | 4 | optional.or | OptionalValue | // +--------+-----------------------+-------------------------------+ // | 5 | | ... | // +--------------------------------+-------------------------------+ // // If implementing the orValue variant, the jump step handles unwrapping ( // getting the result of optional.value()) class OptionalHasValueJumpStep final : public JumpStepBase { public: OptionalHasValueJumpStep(int64_t expr_id, OptionalOrKind kind) : JumpStepBase({}, expr_id), kind_(kind) {} absl::Status Evaluate(ExecutionFrame* frame) const override { if (!frame->value_stack().HasEnough(1)) { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } const auto& value = frame->value_stack().Peek(); auto optional_value = As(value); // We jump if the receiver is `optional_type` which has a value or the // receiver is an error/unknown. Unlike `_||_` we are not commutative. If // we run into an error/unknown, we skip the `else` branch. 
const bool should_jump = (optional_value.has_value() && optional_value->HasValue()) || (!optional_value.has_value() && (cel::InstanceOf(value) || cel::InstanceOf(value))); if (should_jump) { if (kind_ == OptionalOrKind::kOrValue && optional_value.has_value()) { frame->value_stack().PopAndPush(optional_value->Value()); } return Jump(frame); } return absl::OkStatus(); } private: const OptionalOrKind kind_; }; class OptionalOrStep : public ExpressionStepBase { public: explicit OptionalOrStep(int64_t expr_id, OptionalOrKind kind) : ExpressionStepBase(expr_id), kind_(kind) {} absl::Status Evaluate(ExecutionFrame* frame) const override; private: const OptionalOrKind kind_; }; // Shared implementation for optional or. // // If return value is Ok, the result is assigned to the result reference // argument. absl::Status EvalOptionalOr(OptionalOrKind kind, const Value& lhs, const Value& rhs, const AttributeTrail& lhs_attr, const AttributeTrail& rhs_attr, Value& result, AttributeTrail& result_attr) { if (InstanceOf(lhs) || InstanceOf(lhs)) { result = lhs; result_attr = lhs_attr; return absl::OkStatus(); } auto lhs_optional_value = As(lhs); if (!lhs_optional_value.has_value()) { result = MakeNoOverloadError(kind); result_attr = AttributeTrail(); return absl::OkStatus(); } if (lhs_optional_value->HasValue()) { if (kind == OptionalOrKind::kOrValue) { result = lhs_optional_value->Value(); } else { result = lhs; } result_attr = lhs_attr; return absl::OkStatus(); } if (kind == OptionalOrKind::kOrOptional && !InstanceOf(rhs) && !InstanceOf(rhs) && !InstanceOf(rhs)) { result = MakeNoOverloadError(kind); result_attr = AttributeTrail(); return absl::OkStatus(); } result = rhs; result_attr = rhs_attr; return absl::OkStatus(); } absl::Status OptionalOrStep::Evaluate(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(2)) { return absl::InternalError("Value stack underflow"); } absl::Span args = frame->value_stack().GetSpan(2); absl::Span args_attr = 
frame->value_stack().GetAttributeSpan(2); Value result; AttributeTrail result_attr; CEL_RETURN_IF_ERROR(EvalOptionalOr(kind_, args[0], args[1], args_attr[0], args_attr[1], result, result_attr)); frame->value_stack().PopAndPush(2, std::move(result), std::move(result_attr)); return absl::OkStatus(); } class ExhaustiveDirectOptionalOrStep : public DirectExpressionStep { public: ExhaustiveDirectOptionalOrStep( int64_t expr_id, std::unique_ptr optional, std::unique_ptr alternative, OptionalOrKind kind) : DirectExpressionStep(expr_id), kind_(kind), optional_(std::move(optional)), alternative_(std::move(alternative)) {} absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const override; private: OptionalOrKind kind_; std::unique_ptr optional_; std::unique_ptr alternative_; }; absl::Status ExhaustiveDirectOptionalOrStep::Evaluate( ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const { CEL_RETURN_IF_ERROR(optional_->Evaluate(frame, result, attribute)); Value rhs; AttributeTrail rhs_attr; CEL_RETURN_IF_ERROR(alternative_->Evaluate(frame, rhs, rhs_attr)); CEL_RETURN_IF_ERROR(EvalOptionalOr(kind_, result, rhs, attribute, rhs_attr, result, attribute)); return absl::OkStatus(); } class DirectOptionalOrStep : public DirectExpressionStep { public: DirectOptionalOrStep(int64_t expr_id, std::unique_ptr optional, std::unique_ptr alternative, OptionalOrKind kind) : DirectExpressionStep(expr_id), kind_(kind), optional_(std::move(optional)), alternative_(std::move(alternative)) {} absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const override; private: OptionalOrKind kind_; std::unique_ptr optional_; std::unique_ptr alternative_; }; absl::Status DirectOptionalOrStep::Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const { CEL_RETURN_IF_ERROR(optional_->Evaluate(frame, result, attribute)); if (InstanceOf(result) || InstanceOf(result)) { // Forward the lhs 
error instead of attempting to evaluate the alternative // (unlike CEL's commutative logic operators). return absl::OkStatus(); } auto optional_value = As(static_cast(result)); if (!optional_value.has_value()) { result = MakeNoOverloadError(kind_); return absl::OkStatus(); } if (optional_value->HasValue()) { if (kind_ == OptionalOrKind::kOrValue) { result = optional_value->Value(); } return absl::OkStatus(); } CEL_RETURN_IF_ERROR(alternative_->Evaluate(frame, result, attribute)); // If optional.or check that rhs is an optional. // // Otherwise, we don't know what type to expect so can't check anything. if (kind_ == OptionalOrKind::kOrOptional) { if (!InstanceOf(result) && !InstanceOf(result) && !InstanceOf(result)) { result = MakeNoOverloadError(kind_); } } return absl::OkStatus(); } } // namespace std::unique_ptr CreateOptionalHasValueJumpStep(bool or_value, int64_t expr_id) { return std::make_unique( expr_id, or_value ? OptionalOrKind::kOrValue : OptionalOrKind::kOrOptional); } std::unique_ptr CreateOptionalOrStep(bool is_or_value, int64_t expr_id) { return std::make_unique( expr_id, is_or_value ? OptionalOrKind::kOrValue : OptionalOrKind::kOrOptional); } std::unique_ptr CreateDirectOptionalOrStep( int64_t expr_id, std::unique_ptr optional, std::unique_ptr alternative, bool is_or_value, bool short_circuiting) { auto kind = is_or_value ? OptionalOrKind::kOrValue : OptionalOrKind::kOrOptional; if (short_circuiting) { return std::make_unique(expr_id, std::move(optional), std::move(alternative), kind); } else { return std::make_unique( expr_id, std::move(optional), std::move(alternative), kind); } } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/optional_or_step.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_OPTIONAL_OR_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_OPTIONAL_OR_STEP_H_ #include #include #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/jump_step.h" namespace google::api::expr::runtime { // Factory method for OptionalHasValueJump step, used to implement // short-circuiting optional.or and optional.orValue. // // Requires that the top of the stack is an optional. If `optional.hasValue` is // true, performs a jump. If `or_value` is true and we are jumping, // `optional.value` is called and the result replaces the optional at the top of // the stack. std::unique_ptr CreateOptionalHasValueJumpStep(bool or_value, int64_t expr_id); // Factory method for OptionalOr step, used to implement optional.or and // optional.orValue. std::unique_ptr CreateOptionalOrStep(bool is_or_value, int64_t expr_id); // Creates a step implementing the short-circuiting optional.or or // optional.orValue step. std::unique_ptr CreateDirectOptionalOrStep( int64_t expr_id, std::unique_ptr optional, std::unique_ptr alternative, bool is_or_value, bool short_circuiting); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_OPTIONAL_OR_STEP_H_ ================================================ FILE: eval/eval/optional_or_step_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Tests for the direct optional.or / optional.orValue steps built by
// CreateDirectOptionalOrStep, in both short-circuiting and exhaustive modes.
//
// NOTE(review): angle-bracketed text (system #include targets and template
// arguments such as NiceMock / std::unique_ptr element types) appears to be
// missing from this copy; confirm against upstream.

#include "eval/eval/optional_or_step.h"

#include

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "common/casting.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "common/value_testing.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/const_value_step.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "internal/testing_message_factory.h"
#include "runtime/activation.h"
#include "runtime/internal/errors.h"
#include "runtime/internal/runtime_type_provider.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"

namespace google::api::expr::runtime {
namespace {

using ::absl_testing::StatusIs;
using ::cel::Activation;
using ::cel::As;
using ::cel::ErrorValue;
using ::cel::InstanceOf;
using ::cel::IntValue;
using ::cel::OptionalValue;
using ::cel::RuntimeOptions;
using ::cel::UnknownValue;
using ::cel::Value;
using ::cel::ValueKind;
using ::cel::test::ErrorValueIs;
using ::cel::test::IntValueIs;
using ::cel::test::OptionalValueIs;
using ::cel::test::ValueKindIs;
using ::testing::HasSubstr;
using ::testing::NiceMock;

// gMock stand-in for a DirectExpressionStep subexpression. Used as the
// alternative to observe whether the optional-or step evaluates it.
class MockDirectStep : public DirectExpressionStep {
 public:
  MOCK_METHOD(absl::Status, Evaluate,
              (ExecutionFrameBase & frame, Value& result,
               AttributeTrail& scratch),
              (const, override));
};

// Returns a mock whose Evaluate must never be called (Times(0)). Used as the
// alternative in cases the step is expected to short-circuit.
std::unique_ptr MockNeverCalledDirectStep() {
  auto* mock = new NiceMock();
  EXPECT_CALL(*mock, Evaluate).Times(0);
  return absl::WrapUnique(mock);
}

// Returns a mock that must be evaluated exactly once (Times(1)) and writes an
// error value to `result`. Used as the alternative in exhaustive
// (non-short-circuiting) cases, where the tests expect the left-hand result
// to be selected instead of this error.
// NOTE(review): WillOnce would state the single-call intent more directly
// than WillRepeatedly here.
std::unique_ptr MockExpectCallDirectStep() {
  auto* mock = new NiceMock();
  EXPECT_CALL(*mock, Evaluate)
      .Times(1)
      .WillRepeatedly(
          [](ExecutionFrameBase& frame, Value& result, AttributeTrail& attr) {
            result = ErrorValue(absl::InternalError("expected to be unused"));
            return absl::OkStatus();
          });
  return absl::WrapUnique(mock);
}

// Fixture owning the arena, the type provider (backed by the testing
// descriptor pool) and the empty activation each test passes to its
// ExecutionFrameBase.
class OptionalOrTest : public testing::Test {
 public:
  OptionalOrTest()
      : type_provider_(cel::internal::GetTestingDescriptorPool()) {}

 protected:
  google::protobuf::Arena arena_;
  cel::runtime_internal::RuntimeTypeProvider type_provider_;
  Activation empty_activation_;
};

// ---- optional.or ----

TEST_F(OptionalOrTest, OptionalOrLeftPresentShortcutRight) {
  RuntimeOptions options;
  ExecutionFrameBase frame(empty_activation_, options, type_provider_,
                           cel::internal::GetTestingDescriptorPool(),
                           cel::internal::GetTestingMessageFactory(), &arena_);

  std::unique_ptr step = CreateDirectOptionalOrStep(
      /*expr_id=*/-1,
      CreateConstValueDirectStep(OptionalValue::Of(IntValue(42), &arena_)),
      MockNeverCalledDirectStep(),
      /*is_or_value=*/false,
      /*short_circuiting=*/true);

  Value result;
  AttributeTrail scratch;

  ASSERT_OK(step->Evaluate(frame, result, scratch));

  EXPECT_THAT(result, OptionalValueIs(IntValueIs(42)));
}

TEST_F(OptionalOrTest, OptionalOrLeftErrorShortcutsRight) {
  RuntimeOptions options;
  ExecutionFrameBase frame(empty_activation_, options, type_provider_,
                           cel::internal::GetTestingDescriptorPool(),
                           cel::internal::GetTestingMessageFactory(), &arena_);

  std::unique_ptr step = CreateDirectOptionalOrStep(
      /*expr_id=*/-1,
      CreateConstValueDirectStep(ErrorValue(absl::InternalError("error"))),
      MockNeverCalledDirectStep(),
      /*is_or_value=*/false,
      /*short_circuiting=*/true);

  Value result;
  AttributeTrail scratch;

  ASSERT_OK(step->Evaluate(frame, result, scratch));

  EXPECT_THAT(result, ValueKindIs(ValueKind::kError));
}

// Exhaustive mode still evaluates the alternative, but the lhs error wins.
TEST_F(OptionalOrTest, OptionalOrLeftErrorExhaustiveRight) {
  RuntimeOptions options;
  ExecutionFrameBase frame(empty_activation_, options, type_provider_,
                           cel::internal::GetTestingDescriptorPool(),
                           cel::internal::GetTestingMessageFactory(), &arena_);

  std::unique_ptr step = CreateDirectOptionalOrStep(
      /*expr_id=*/-1,
      CreateConstValueDirectStep(ErrorValue(absl::InternalError("error"))),
      MockExpectCallDirectStep(),
      /*is_or_value=*/false,
      /*short_circuiting=*/false);

  Value result;
  AttributeTrail scratch;

  ASSERT_OK(step->Evaluate(frame, result, scratch));

  EXPECT_THAT(result, ValueKindIs(ValueKind::kError));
}

TEST_F(OptionalOrTest, OptionalOrLeftUnknownShortcutsRight) {
  RuntimeOptions options;
  ExecutionFrameBase frame(empty_activation_, options, type_provider_,
                           cel::internal::GetTestingDescriptorPool(),
                           cel::internal::GetTestingMessageFactory(), &arena_);

  std::unique_ptr step = CreateDirectOptionalOrStep(
      /*expr_id=*/-1, CreateConstValueDirectStep(UnknownValue()),
      MockNeverCalledDirectStep(),
      /*is_or_value=*/false,
      /*short_circuiting=*/true);

  Value result;
  AttributeTrail scratch;

  ASSERT_OK(step->Evaluate(frame, result, scratch));

  EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown));
}

TEST_F(OptionalOrTest, OptionalOrLeftUnknownExhaustiveRight) {
  RuntimeOptions options;
  ExecutionFrameBase frame(empty_activation_, options, type_provider_,
                           cel::internal::GetTestingDescriptorPool(),
                           cel::internal::GetTestingMessageFactory(), &arena_);

  std::unique_ptr step = CreateDirectOptionalOrStep(
      /*expr_id=*/-1, CreateConstValueDirectStep(UnknownValue()),
      MockExpectCallDirectStep(),
      /*is_or_value=*/false,
      /*short_circuiting=*/false);

  Value result;
  AttributeTrail scratch;

  ASSERT_OK(step->Evaluate(frame, result, scratch));

  EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown));
}

TEST_F(OptionalOrTest, OptionalOrLeftAbsentReturnRight) {
  RuntimeOptions options;
  ExecutionFrameBase frame(empty_activation_, options, type_provider_,
                           cel::internal::GetTestingDescriptorPool(),
                           cel::internal::GetTestingMessageFactory(), &arena_);

  std::unique_ptr step = CreateDirectOptionalOrStep(
      /*expr_id=*/-1, CreateConstValueDirectStep(OptionalValue::None()),
      CreateConstValueDirectStep(OptionalValue::Of(IntValue(42), &arena_)),
      /*is_or_value=*/false,
      /*short_circuiting=*/true);

  Value result;
  AttributeTrail scratch;

  ASSERT_OK(step->Evaluate(frame, result, scratch));

  EXPECT_THAT(result, OptionalValueIs(IntValueIs(42)));
}

// A non-optional receiver is reported as a "no matching overload" error value.
TEST_F(OptionalOrTest, OptionalOrLeftWrongType) {
  RuntimeOptions options;
  ExecutionFrameBase frame(empty_activation_, options, type_provider_,
                           cel::internal::GetTestingDescriptorPool(),
                           cel::internal::GetTestingMessageFactory(), &arena_);

  std::unique_ptr step = CreateDirectOptionalOrStep(
      /*expr_id=*/-1, CreateConstValueDirectStep(IntValue(42)),
      MockNeverCalledDirectStep(),
      /*is_or_value=*/false,
      /*short_circuiting=*/true);

  Value result;
  AttributeTrail scratch;

  ASSERT_OK(step->Evaluate(frame, result, scratch));

  EXPECT_THAT(result,
              ErrorValueIs(StatusIs(
                  absl::StatusCode::kUnknown,
                  HasSubstr(cel::runtime_internal::kErrNoMatchingOverload))));
}

// For optional.or, a non-optional alternative is also an overload error.
TEST_F(OptionalOrTest, OptionalOrRightWrongType) {
  RuntimeOptions options;
  ExecutionFrameBase frame(empty_activation_, options, type_provider_,
                           cel::internal::GetTestingDescriptorPool(),
                           cel::internal::GetTestingMessageFactory(), &arena_);

  std::unique_ptr step = CreateDirectOptionalOrStep(
      /*expr_id=*/-1, CreateConstValueDirectStep(OptionalValue::None()),
      CreateConstValueDirectStep(IntValue(42)),
      /*is_or_value=*/false,
      /*short_circuiting=*/true);

  Value result;
  AttributeTrail scratch;

  ASSERT_OK(step->Evaluate(frame, result, scratch));

  EXPECT_THAT(result,
              ErrorValueIs(StatusIs(
                  absl::StatusCode::kUnknown,
                  HasSubstr(cel::runtime_internal::kErrNoMatchingOverload))));
}

// ---- optional.orValue ----

TEST_F(OptionalOrTest, OptionalOrValueLeftPresentShortcutRight) {
  RuntimeOptions options;
  ExecutionFrameBase frame(empty_activation_, options, type_provider_,
                           cel::internal::GetTestingDescriptorPool(),
                           cel::internal::GetTestingMessageFactory(), &arena_);

  std::unique_ptr step = CreateDirectOptionalOrStep(
      /*expr_id=*/-1,
      CreateConstValueDirectStep(OptionalValue::Of(IntValue(42), &arena_)),
      MockNeverCalledDirectStep(),
      /*is_or_value=*/true,
      /*short_circuiting=*/true);

  Value result;
  AttributeTrail scratch;

  ASSERT_OK(step->Evaluate(frame, result, scratch));

  // orValue unwraps the present optional.
  EXPECT_THAT(result, IntValueIs(42));
}

TEST_F(OptionalOrTest, OptionalOrValueLeftPresentExhaustiveRight) {
  RuntimeOptions options;
  ExecutionFrameBase frame(empty_activation_, options, type_provider_,
                           cel::internal::GetTestingDescriptorPool(),
                           cel::internal::GetTestingMessageFactory(), &arena_);

  std::unique_ptr step = CreateDirectOptionalOrStep(
      /*expr_id=*/-1,
      CreateConstValueDirectStep(OptionalValue::Of(IntValue(42), &arena_)),
      MockExpectCallDirectStep(),
      /*is_or_value=*/true,
      /*short_circuiting=*/false);

  Value result;
  AttributeTrail scratch;

  ASSERT_OK(step->Evaluate(frame, result, scratch));

  EXPECT_THAT(result, IntValueIs(42));
}

TEST_F(OptionalOrTest, OptionalOrValueLeftErrorShortcutsRight) {
  RuntimeOptions options;
  ExecutionFrameBase frame(empty_activation_, options, type_provider_,
                           cel::internal::GetTestingDescriptorPool(),
                           cel::internal::GetTestingMessageFactory(), &arena_);

  std::unique_ptr step = CreateDirectOptionalOrStep(
      /*expr_id=*/-1,
      CreateConstValueDirectStep(ErrorValue(absl::InternalError("error"))),
      MockNeverCalledDirectStep(),
      /*is_or_value=*/true,
      /*short_circuiting=*/true);

  Value result;
  AttributeTrail scratch;

  ASSERT_OK(step->Evaluate(frame, result, scratch));

  EXPECT_THAT(result, ValueKindIs(ValueKind::kError));
}

TEST_F(OptionalOrTest, OptionalOrValueLeftUnknownShortcutsRight) {
  RuntimeOptions options;
  ExecutionFrameBase frame(empty_activation_, options, type_provider_,
                           cel::internal::GetTestingDescriptorPool(),
                           cel::internal::GetTestingMessageFactory(), &arena_);

  std::unique_ptr step = CreateDirectOptionalOrStep(
      /*expr_id=*/-1, CreateConstValueDirectStep(UnknownValue()),
      MockNeverCalledDirectStep(), /*is_or_value=*/true,
      /*short_circuiting=*/true);

  Value result;
  AttributeTrail scratch;

  ASSERT_OK(step->Evaluate(frame, result, scratch));

  EXPECT_THAT(result, ValueKindIs(ValueKind::kUnknown));
}

// orValue accepts any alternative type, so a plain int is returned as-is.
TEST_F(OptionalOrTest, OptionalOrValueLeftAbsentReturnRight) {
  RuntimeOptions options;
  ExecutionFrameBase frame(empty_activation_, options, type_provider_,
                           cel::internal::GetTestingDescriptorPool(),
                           cel::internal::GetTestingMessageFactory(), &arena_);

  std::unique_ptr step = CreateDirectOptionalOrStep(
      /*expr_id=*/-1, CreateConstValueDirectStep(OptionalValue::None()),
      CreateConstValueDirectStep(IntValue(42)),
      /*is_or_value=*/true,
      /*short_circuiting=*/true);

  Value result;
  AttributeTrail scratch;

  ASSERT_OK(step->Evaluate(frame, result, scratch));

  EXPECT_THAT(result, IntValueIs(42));
}

TEST_F(OptionalOrTest, OptionalOrValueLeftWrongType) {
  RuntimeOptions options;
  ExecutionFrameBase frame(empty_activation_, options, type_provider_,
                           cel::internal::GetTestingDescriptorPool(),
                           cel::internal::GetTestingMessageFactory(), &arena_);

  std::unique_ptr step = CreateDirectOptionalOrStep(
      /*expr_id=*/-1, CreateConstValueDirectStep(IntValue(42)),
      MockNeverCalledDirectStep(), /*is_or_value=*/true,
      /*short_circuiting=*/true);

  Value result;
  AttributeTrail scratch;

  ASSERT_OK(step->Evaluate(frame, result, scratch));

  EXPECT_THAT(result,
              ErrorValueIs(StatusIs(
                  absl::StatusCode::kUnknown,
                  HasSubstr(cel::runtime_internal::kErrNoMatchingOverload))));
}

}  // namespace
}  // namespace google::api::expr::runtime
================================================
FILE: eval/eval/regex_match_step.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "eval/eval/regex_match_step.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "common/value.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/expression_step_base.h"
#include "internal/status_macros.h"
#include "re2/re2.h"

namespace google::api::expr::runtime {

namespace {

using ::cel::BoolValue;
using ::cel::StringValue;
using ::cel::Value;

inline constexpr int kNumRegexMatchArguments = 1;
inline constexpr size_t kRegexMatchStepSubject = 0;

// Visitor applied to a string subject's native representation. Reports
// whether `re` matches anywhere in the subject (RE2::PartialMatch).
struct MatchesVisitor final {
  const RE2& re;

  bool operator()(const absl::Cord& value) const {
    // A flat cord can be matched in place; otherwise RE2 needs a contiguous
    // copy of the data.
    if (auto flat = value.TryFlat(); flat.has_value()) {
      return RE2::PartialMatch(*flat, re);
    }
    return RE2::PartialMatch(static_cast<std::string>(value), re);
  }

  bool operator()(absl::string_view value) const {
    return RE2::PartialMatch(value, re);
  }
};

// Stack-based step for `subject.matches(<precompiled regex>)`: replaces the
// subject on top of the value stack with a bool match result.
class RegexMatchStep final : public ExpressionStepBase {
 public:
  RegexMatchStep(int64_t expr_id, std::shared_ptr<const RE2> re2)
      : ExpressionStepBase(expr_id, /*comes_from_ast=*/true),
        re2_(std::move(re2)) {}

  absl::Status Evaluate(ExecutionFrame* frame) const override {
    if (!frame->value_stack().HasEnough(kNumRegexMatchArguments)) {
      return absl::Status(absl::StatusCode::kInternal,
                          "Insufficient arguments supplied for regular "
                          "expression match");
    }
    auto input_args = frame->value_stack().GetSpan(kNumRegexMatchArguments);
    const auto& subject = input_args[kRegexMatchStepSubject];
    if (subject.IsError() || subject.IsUnknown()) {
      // Propagate errors and unknowns as the result instead of aborting
      // evaluation, matching RegexMatchDirectStep. The value is re-pushed
      // without the subject's attribute trail, as the direct step does.
      // Copy first: `subject` refers into the stack slot being popped.
      Value passthrough = subject;
      frame->value_stack().Pop(kNumRegexMatchArguments);
      frame->value_stack().Push(std::move(passthrough));
      return absl::OkStatus();
    }
    if (!subject.IsString()) {
      return absl::Status(absl::StatusCode::kInternal,
                          "First argument for regular "
                          "expression match must be a string");
    }
    bool match = subject.GetString().NativeValue(MatchesVisitor{*re2_});
    frame->value_stack().Pop(kNumRegexMatchArguments);
    frame->value_stack().Push(cel::BoolValue(match));
    return absl::OkStatus();
  }

 private:
  const std::shared_ptr<const RE2> re2_;
};

// Direct step for `subject.matches(<precompiled regex>)`. Error and unknown
// subjects are returned unchanged; other non-string subjects are an internal
// error status.
class RegexMatchDirectStep final : public DirectExpressionStep {
 public:
  RegexMatchDirectStep(int64_t expr_id,
                       std::unique_ptr<DirectExpressionStep> subject,
                       std::shared_ptr<const RE2> re2)
      : DirectExpressionStep(expr_id),
        subject_(std::move(subject)),
        re2_(std::move(re2)) {}

  absl::Status Evaluate(ExecutionFrameBase& frame, Value& result,
                        AttributeTrail& attribute) const override {
    AttributeTrail subject_attr;
    CEL_RETURN_IF_ERROR(subject_->Evaluate(frame, result, subject_attr));
    if (result.IsError() || result.IsUnknown()) {
      return absl::OkStatus();
    }

    if (!result.IsString()) {
      return absl::Status(absl::StatusCode::kInternal,
                          "First argument for regular "
                          "expression match must be a string");
    }
    bool match = result.GetString().NativeValue(MatchesVisitor{*re2_});
    result = BoolValue(match);
    return absl::OkStatus();
  }

 private:
  std::unique_ptr<DirectExpressionStep> subject_;
  const std::shared_ptr<const RE2> re2_;
};

}  // namespace

std::unique_ptr<DirectExpressionStep> CreateDirectRegexMatchStep(
    int64_t expr_id, std::unique_ptr<DirectExpressionStep> subject,
    std::shared_ptr<const RE2> re2) {
  return std::make_unique<RegexMatchDirectStep>(expr_id, std::move(subject),
                                                std::move(re2));
}

absl::StatusOr<std::unique_ptr<ExpressionStep>> CreateRegexMatchStep(
    std::shared_ptr<const RE2> re2, int64_t expr_id) {
  return std::make_unique<RegexMatchStep>(expr_id, std::move(re2));
}

}  // namespace google::api::expr::runtime
================================================
FILE: eval/eval/regex_match_step.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_REGEX_MATCH_STEP_H_
#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_REGEX_MATCH_STEP_H_

// Factories for steps that evaluate `matches` against an already-compiled RE2
// pattern (exercised by the regex-precompilation tests).
//
// NOTE(review): angle-bracketed text (system #include targets and template
// arguments such as the std::unique_ptr / std::shared_ptr element types)
// appears to be missing from this copy; confirm against upstream.

#include
#include

#include "absl/status/statusor.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "re2/re2.h"

namespace google::api::expr::runtime {

// Creates a direct step that evaluates `subject` and produces a bool: whether
// `re2` matches anywhere within it (RE2::PartialMatch). Error and unknown
// subjects are returned unchanged; any other non-string subject makes Evaluate
// return an internal error status.
std::unique_ptr CreateDirectRegexMatchStep(
    int64_t expr_id, std::unique_ptr subject,
    std::shared_ptr re2);

// Creates a stack-based step that replaces the string subject on top of the
// value stack with a bool: whether `re2` matches anywhere within it
// (RE2::PartialMatch). Non-string subjects are rejected with an internal error
// status; see the step implementation for how errors/unknowns are handled.
absl::StatusOr> CreateRegexMatchStep(
    std::shared_ptr re2, int64_t expr_id);

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_EVAL_REGEX_MATCH_STEP_H_
================================================
FILE: eval/eval/regex_match_step_test.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "eval/eval/regex_match_step.h"

#include "cel/expr/checked.pb.h"
#include "cel/expr/syntax.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "eval/public/activation.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_expr_builder_factory.h"
#include "eval/public/cel_options.h"
#include "internal/testing.h"
#include "parser/parser.h"
#include "google/protobuf/arena.h"

namespace google::api::expr::runtime {
namespace {

using ::absl_testing::StatusIs;
using cel::expr::CheckedExpr;
using cel::expr::Reference;
using ::testing::Eq;
using ::testing::HasSubstr;

// Builds a reference that binds a call to the `matches_string` overload.
Reference MakeMatchesStringOverload() {
  Reference matches_ref;
  matches_ref.add_overload_id("matches_string");
  return matches_ref;
}

// Parses `source` and wraps it in a CheckedExpr whose root expression is
// annotated with the `matches_string` overload reference. Parse failures are
// returned as the status.
absl::StatusOr<CheckedExpr> ParseAsCheckedMatches(const char* source) {
  auto parsed = parser::Parse(source);
  if (!parsed.ok()) {
    return parsed.status();
  }
  CheckedExpr checked;
  *checked.mutable_expr() = parsed->expr();
  *checked.mutable_source_info() = parsed->source_info();
  checked.mutable_reference_map()->insert(
      {checked.expr().id(), MakeMatchesStringOverload()});
  return checked;
}

// A valid pattern is precompiled and evaluated against the activation.
TEST(RegexMatchStep, Precompiled) {
  google::protobuf::Arena arena;
  Activation activation;
  ASSERT_OK_AND_ASSIGN(CheckedExpr checked_expr,
                       ParseAsCheckedMatches("foo.matches('hello')"));

  InterpreterOptions options;
  options.enable_regex_precompilation = true;
  auto builder = CreateCelExpressionBuilder(options);
  ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options));
  ASSERT_OK_AND_ASSIGN(auto expr, builder->CreateExpression(&checked_expr));

  activation.InsertValue("foo", CelValue::CreateStringView("hello world!"));
  ASSERT_OK_AND_ASSIGN(auto result, expr->Evaluate(activation, &arena));
  EXPECT_TRUE(result.IsBool());
  EXPECT_TRUE(result.BoolOrDie());
}

// An unparsable pattern is rejected when the expression is planned.
TEST(RegexMatchStep, PrecompiledInvalidRegex) {
  Activation activation;
  ASSERT_OK_AND_ASSIGN(CheckedExpr checked_expr,
                       ParseAsCheckedMatches("foo.matches('(')"));

  InterpreterOptions options;
  options.enable_regex_precompilation = true;
  auto builder = CreateCelExpressionBuilder(options);
  ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options));

  EXPECT_THAT(builder->CreateExpression(&checked_expr),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("invalid regular expression")));
}

// A pattern whose compiled program exceeds the configured limit is rejected.
TEST(RegexMatchStep, PrecompiledInvalidProgramTooLarge) {
  Activation activation;
  ASSERT_OK_AND_ASSIGN(CheckedExpr checked_expr,
                       ParseAsCheckedMatches("foo.matches('hello')"));

  InterpreterOptions options;
  options.regex_max_program_size = 1;
  options.enable_regex_precompilation = true;
  auto builder = CreateCelExpressionBuilder(options);
  ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options));

  EXPECT_THAT(builder->CreateExpression(&checked_expr),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       Eq("regular expression exceeds max allowed size")));
}

}  // namespace
}  // namespace google::api::expr::runtime
================================================
FILE: eval/eval/select_step.cc
================================================
#include "eval/eval/select_step.h"

#include
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/expr.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/expression_step_base.h"
#include "internal/status_macros.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {

namespace {

using ::cel::BoolValue;
using ::cel::ErrorValue;
using ::cel::MapValue;
using ::cel::NullValue;
using ::cel::OptionalValue;
using ::cel::ProtoWrapperTypeOptions;
using ::cel::StringValue;
using ::cel::StructValue;
using ::cel::Value;
using ::cel::ValueKind;

// Common error for cases where evaluation attempts to perform select operations
// on an unsupported type.
//
// This should not happen under normal usage of the evaluator, but useful for
// troubleshooting broken invariants.
absl::Status InvalidSelectTargetError() {
  return absl::Status(absl::StatusCode::kInvalidArgument,
                      "Applying SELECT to non-message type");
}

// Checks the attribute trail of the selected field against the frame's marked
// attribute patterns.
//
// Returns:
//  - an unknown set for the attribute, if unknown processing is enabled and
//    the trail is an exact unknown match;
//  - otherwise, a missing-attribute error value, if missing-attribute errors
//    are enabled and the trail matches a missing-attribute pattern;
//  - otherwise, nullopt (nothing marked; the caller continues the select).
absl::optional CheckForMarkedAttributes(const AttributeTrail& trail,
                                        ExecutionFrameBase& frame) {
  if (frame.unknown_processing_enabled() &&
      frame.attribute_utility().CheckForUnknownExact(trail)) {
    return frame.attribute_utility().CreateUnknownSet(trail.attribute());
  }

  if (frame.missing_attribute_errors_enabled() &&
      frame.attribute_utility().CheckForMissingAttribute(trail)) {
    auto result = frame.attribute_utility().CreateMissingAttributeError(
        trail.attribute());

    if (result.ok()) {
      return std::move(result).value();
    }

    // Invariant broken (an invalid CEL Attribute shouldn't match anything).
    // Log and return an error value carrying the creation failure.
    ABSL_LOG(ERROR) << "Invalid attribute pattern matched select path: "
                    << result.status().ToString();  // NOLINT: OSS compatibility
    return cel::ErrorValue(std::move(result).status());
  }

  return absl::nullopt;
}

// Presence test (`has()`) on a struct value: stores a BoolValue with the
// outcome of HasFieldByName(field) in `*result`, or an ErrorValue carrying the
// failure status if that lookup fails.
//
// `descriptor_pool`, `message_factory` and `arena` are not used by this
// overload; presumably they are kept so both TestOnlySelect overloads share a
// call shape.
void TestOnlySelect(const StructValue& msg, const std::string& field,
                    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                    google::protobuf::MessageFactory* absl_nonnull message_factory,
                    google::protobuf::Arena* absl_nonnull arena,
                    Value* absl_nonnull result) {
  absl::StatusOr has_field = msg.HasFieldByName(field);

  if (!has_field.ok()) {
    *result = ErrorValue(std::move(has_field).status());
    return;
  }
  *result = BoolValue{*has_field};
}

// Presence test (`has()`) on a map value: MapValue::Has writes its answer into
// `*result`; if Has fails, `*result` is replaced by an ErrorValue carrying the
// failure status.
void TestOnlySelect(const MapValue& map, const StringValue& field_name,
                    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                    google::protobuf::MessageFactory* absl_nonnull message_factory,
                    google::protobuf::Arena* absl_nonnull arena,
                    Value* absl_nonnull result) {
  // Field presence only supports string keys containing valid identifier
  // characters.
  absl::Status presence = map.Has(field_name, descriptor_pool, message_factory,
                                  arena, result);

  if (!presence.ok()) {
    *result = ErrorValue(std::move(presence));
    return;
  }
  // Debug-only sanity check: Has() is not expected to produce an unknown.
  ABSL_DCHECK(!result->IsUnknown());
}

// SelectStep performs message field access specified by Expr::Select
// message.
//
// Stack-based variant: Evaluate() reads the operand from the top of the value
// stack and replaces it (PopAndPush) with the selection outcome.
class SelectStep : public ExpressionStepBase {
 public:
  SelectStep(StringValue value, bool test_field_presence, int64_t expr_id,
             bool enable_wrapper_type_null_unboxing, bool enable_optional_types)
      : ExpressionStepBase(expr_id),
        field_value_(std::move(value)),
        field_(field_value_.ToString()),
        test_field_presence_(test_field_presence),
        // Maps the boolean flag onto the matching ProtoWrapperTypeOptions.
        unboxing_option_(enable_wrapper_type_null_unboxing
                             ? ProtoWrapperTypeOptions::kUnsetNull
                             : ProtoWrapperTypeOptions::kUnsetProtoDefault),
        enable_optional_types_(enable_optional_types) {}

  absl::Status Evaluate(ExecutionFrame* frame) const override;

 private:
  // Presence test on an unwrapped operand; replaces the stack top with the
  // result.
  absl::Status PerformTestOnlySelect(ExecutionFrame* frame,
                                     const Value& arg) const;
  // Selects field_ from `arg` into `result`; the returned bool reports whether
  // the struct field (per HasFieldByName) or map key (per Find) was found.
  absl::StatusOr PerformSelect(ExecutionFrame* frame, const Value& arg,
                               Value& result) const;

  // Field name as a StringValue, used for map lookups.
  cel::StringValue field_value_;
  // The same field name as a std::string, used for struct field access and
  // for stepping the attribute trail.
  std::string field_;
  // True when this step implements a has() presence test.
  bool test_field_presence_;
  ProtoWrapperTypeOptions unboxing_option_;
  bool enable_optional_types_;
};

absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const {
  if (!frame->value_stack().HasEnough(1)) {
    return absl::Status(absl::StatusCode::kInternal,
                        "No arguments supplied for Select-type expression");
  }

  const Value& arg = frame->value_stack().Peek();
  const AttributeTrail& trail = frame->value_stack().PeekAttribute();

  if (arg.IsUnknown() || arg.IsError()) {
    // Bubble up unknowns and errors.
    return absl::OkStatus();
  }

  AttributeTrail result_trail;

  // Handle unknown resolution.
  // The trail is only extended when unknowns or missing-attribute errors are
  // enabled; otherwise result_trail stays default-constructed.
  if (frame->enable_unknowns() || frame->enable_missing_attribute_errors()) {
    result_trail = trail.Step(&field_);
  }

  absl::optional optional_arg;

  if (enable_optional_types_ && arg.IsOptional()) {
    optional_arg = arg.GetOptional();
  }

  // Reject operands that cannot be selected from. An optional operand passes
  // this check whatever it wraps; the wrapped value's kind is only examined
  // later by PerformTestOnlySelect / PerformSelect.
  if (!(optional_arg || arg->Is() || arg->Is())) {
    frame->value_stack().PopAndPush(cel::ErrorValue(InvalidSelectTargetError()),
                                    std::move(result_trail));
    return absl::OkStatus();
  }

  absl::optional marked_attribute_check =
      CheckForMarkedAttributes(result_trail, *frame);
  if (marked_attribute_check.has_value()) {
    frame->value_stack().PopAndPush(std::move(marked_attribute_check).value(),
                                    std::move(result_trail));
    return absl::OkStatus();
  }

  // Handle test only Select.
  if (test_field_presence_) {
    if (optional_arg) {
      // has() on an empty optional is false.
      if (!optional_arg->HasValue()) {
        frame->value_stack().PopAndPush(cel::BoolValue{false});
        return absl::OkStatus();
      }
      Value value;
      optional_arg->Value(&value);
      return PerformTestOnlySelect(frame, value);
    }
    return PerformTestOnlySelect(frame, arg);
  }

  // Optional select path: the operand is an optional. The result is an
  // optional wrapping the selected value, or none when the field/key is
  // absent.
  if (optional_arg) {
    if (!optional_arg->HasValue()) {
      // Leave optional_arg at the top of the stack. It's empty.
      return absl::OkStatus();
    }
    Value value;
    Value result;
    bool ok;
    optional_arg->Value(&value);
    // NOTE(review): errors from PerformSelect (including an optional that
    // wraps a value PerformSelect does not handle) are propagated out of
    // Evaluate as a non-OK Status, whereas the non-optional path below turns
    // field access failures into an ErrorValue. Confirm this is intended.
    CEL_ASSIGN_OR_RETURN(ok, PerformSelect(frame, value, result));
    if (!ok) {
      frame->value_stack().PopAndPush(cel::OptionalValue::None(),
                                      std::move(result_trail));
      return absl::OkStatus();
    }

    frame->value_stack().PopAndPush(
        cel::OptionalValue::Of(std::move(result), frame->arena()),
        std::move(result_trail));
    return absl::OkStatus();
  }

  // Normal select path.
  // Select steps can be applied to either maps or messages
  switch (arg.kind()) {
    case ValueKind::kStruct: {
      Value result;
      auto status = arg.GetStruct().GetFieldByName(
          field_, unboxing_option_, frame->descriptor_pool(),
          frame->message_factory(), frame->arena(), &result);
      // Field access failures are pushed as an ErrorValue, not returned.
      if (!status.ok()) {
        result = ErrorValue(std::move(status));
      }
      frame->value_stack().PopAndPush(std::move(result),
                                      std::move(result_trail));
      return absl::OkStatus();
    }
    case ValueKind::kMap: {
      Value result;
      auto status =
          arg.GetMap().Get(field_value_, frame->descriptor_pool(),
                           frame->message_factory(), frame->arena(), &result);
      // Lookup failures are pushed as an ErrorValue, not returned.
      if (!status.ok()) {
        result = ErrorValue(std::move(status));
      }
      frame->value_stack().PopAndPush(std::move(result),
                                      std::move(result_trail));
      return absl::OkStatus();
    }
    default:
      // Control flow should have returned earlier.
      return InvalidSelectTargetError();
  }
}

// Presence test on an unwrapped operand: replaces the stack top with the
// BoolValue (or ErrorValue) produced by the matching TestOnlySelect overload.
absl::Status SelectStep::PerformTestOnlySelect(ExecutionFrame* frame,
                                               const Value& arg) const {
  switch (arg.kind()) {
    case ValueKind::kMap: {
      Value result;
      TestOnlySelect(arg.GetMap(), field_value_, frame->descriptor_pool(),
                     frame->message_factory(), frame->arena(), &result);
      frame->value_stack().PopAndPush(std::move(result));
      return absl::OkStatus();
    }
    // NOTE(review): kMessage here vs. kStruct in PerformSelect; presumably the
    // same ValueKind under another name — confirm.
    case ValueKind::kMessage: {
      Value result;
      TestOnlySelect(arg.GetStruct(), field_, frame->descriptor_pool(),
                     frame->message_factory(), frame->arena(), &result);
      frame->value_stack().PopAndPush(std::move(result));
      return absl::OkStatus();
    }
    default:
      // Reachable when an optional wraps a value of another kind (Evaluate
      // does not check the wrapped kind). NOTE(review): this returns a non-OK
      // Status from Evaluate, while DirectSelectStep::PerformTestOnlySelect
      // stores an ErrorValue instead. Confirm which behavior is intended.
      return InvalidSelectTargetError();
  }
}

// Selects `field_` from an unwrapped operand into `result`.
//
// Struct: returns false (and sets `result` to null) when HasFieldByName
// reports the field as absent, otherwise reads it and returns true.
// Map: returns the `found` flag reported by MapValue::Find.
// Accessor failures are returned as a non-OK status.
absl::StatusOr SelectStep::PerformSelect(ExecutionFrame* frame,
                                         const Value& arg,
                                         Value& result) const {
  switch (arg->kind()) {
    case ValueKind::kStruct: {
      const auto& struct_value = arg.GetStruct();
      CEL_ASSIGN_OR_RETURN(auto ok, struct_value.HasFieldByName(field_));
      if (!ok) {
        result = NullValue{};
        return false;
      }
      CEL_RETURN_IF_ERROR(struct_value.GetFieldByName(
          field_, unboxing_option_, frame->descriptor_pool(),
          frame->message_factory(), frame->arena(), &result));
      ABSL_DCHECK(!result.IsUnknown());
      return true;
    }
    case ValueKind::kMap: {
      CEL_ASSIGN_OR_RETURN(
          auto found,
          arg.GetMap().Find(field_value_, frame->descriptor_pool(),
                            frame->message_factory(), frame->arena(), &result));
      ABSL_DCHECK(!found || !result.IsUnknown());
      return found;
    }
    default:
      // Control flow should have returned earlier.
      return InvalidSelectTargetError();
  }
}

// Recursive (direct) variant of SelectStep: evaluates `operand_` into
// `result`/`attribute`, then overwrites `result` in place with the selection
// outcome.
class DirectSelectStep : public DirectExpressionStep {
 public:
  DirectSelectStep(int64_t expr_id, std::unique_ptr operand,
                   StringValue field, bool test_only,
                   bool enable_wrapper_type_null_unboxing,
                   bool enable_optional_types)
      : DirectExpressionStep(expr_id),
        operand_(std::move(operand)),
        field_value_(std::move(field)),
        field_(field_value_.ToString()),
        test_only_(test_only),
        // Maps the boolean flag onto the matching ProtoWrapperTypeOptions.
        unboxing_option_(enable_wrapper_type_null_unboxing
                             ? ProtoWrapperTypeOptions::kUnsetNull
                             : ProtoWrapperTypeOptions::kUnsetProtoDefault),
        enable_optional_types_(enable_optional_types) {}

  absl::Status Evaluate(ExecutionFrameBase& frame, Value& result,
                        AttributeTrail& attribute) const override {
    CEL_RETURN_IF_ERROR(operand_->Evaluate(frame, result, attribute));

    if (result.IsError() || result.IsUnknown()) {
      // Just forward.
      return absl::OkStatus();
    }

    // Unlike SelectStep::Evaluate, marked attributes are checked here before
    // the operand's kind is validated.
    if (frame.attribute_tracking_enabled()) {
      attribute = attribute.Step(&field_);
      absl::optional value = CheckForMarkedAttributes(attribute, frame);
      if (value.has_value()) {
        result = std::move(value).value();
        return absl::OkStatus();
      }
    }

    absl::optional optional_arg;

    if (enable_optional_types_ && result.IsOptional()) {
      optional_arg = result.GetOptional();
    }

    // Only structs, maps and (when enabled) optionals may be selected from;
    // anything else becomes an InvalidSelectTargetError value.
    switch (result.kind()) {
      case ValueKind::kStruct:
      case ValueKind::kMap:
        break;
      default:
        if (optional_arg) {
          break;
        }
        result = cel::ErrorValue(InvalidSelectTargetError());
        return absl::OkStatus();
    }

    if (test_only_) {
      if (optional_arg) {
        // has() on an empty optional is false.
        if (!optional_arg->HasValue()) {
          result = cel::BoolValue{false};
          return absl::OkStatus();
        }
        Value value;
        optional_arg->Value(&value);
        PerformTestOnlySelect(frame, value, result);
        return absl::OkStatus();
      }
      // `result` is both the operand and the output here.
      PerformTestOnlySelect(frame, result, result);
      return absl::OkStatus();
    }

    if (optional_arg) {
      if (!optional_arg->HasValue()) {
        // `result` still holds the empty optional operand, which is also the
        // result of the select; just return.
        return absl::OkStatus();
      }
      Value value;
      optional_arg->Value(&value);
      return PerformOptionalSelect(frame, value, result);
    }

    // Field access failures become an ErrorValue result, not a failed status.
    auto status = PerformSelect(frame, result, result);
    if (!status.ok()) {
      result = ErrorValue(std::move(status));
    }
    return absl::OkStatus();
  }

 private:
  std::unique_ptr operand_;

  // Presence test on an unwrapped operand; writes a BoolValue or ErrorValue
  // into `result`. `value` may be the same object as `result`.
  void PerformTestOnlySelect(ExecutionFrameBase& frame, const Value& value,
                             Value& result) const;
  // Select on a value unwrapped from an optional; writes an optional (or
  // none) into `result`.
  absl::Status PerformOptionalSelect(ExecutionFrameBase& frame,
                                     const Value& value, Value& result) const;
  // Plain select; writes the field/entry value into `result`. `value` may be
  // the same object as `result`.
  absl::Status PerformSelect(ExecutionFrameBase& frame, const Value& value,
                             Value& result) const;

  // Field name in formats supported by each of the map and struct field access
  // APIs.
  //
  // ToString or ValueManager::CreateString may force a copy so we do this at
  // plan time.
  StringValue field_value_;
  std::string field_;
  // whether this is a has() expression.
  bool test_only_;
  ProtoWrapperTypeOptions unboxing_option_;
  bool enable_optional_types_;
};

void DirectSelectStep::PerformTestOnlySelect(ExecutionFrameBase& frame,
                                             const cel::Value& value,
                                             Value& result) const {
  switch (value.kind()) {
    case ValueKind::kMap:
      TestOnlySelect(value.GetMap(), field_value_, frame.descriptor_pool(),
                     frame.message_factory(), frame.arena(), &result);
      return;
    // NOTE(review): kMessage here vs. kStruct elsewhere; presumably the same
    // ValueKind under another name — confirm.
    case ValueKind::kMessage:
      TestOnlySelect(value.GetStruct(), field_, frame.descriptor_pool(),
                     frame.message_factory(), frame.arena(), &result);
      return;
    default:
      // Reachable when an optional wraps a value of another kind; reported as
      // an error value rather than a non-OK status.
      result = cel::ErrorValue(InvalidSelectTargetError());
      return;
  }
}

// Selects `field_` from a value unwrapped from an optional and stores the
// outcome in `result` as an optional: OptionalValue::None() when the struct
// field is absent (per HasFieldByName) or the map key is not found, otherwise
// OptionalValue::Of(<selected value>). Accessor failures are returned as a
// non-OK status.
absl::Status DirectSelectStep::PerformOptionalSelect(ExecutionFrameBase& frame,
                                                     const Value& value,
                                                     Value& result) const {
  switch (value.kind()) {
    case ValueKind::kStruct: {
      auto struct_value = value.GetStruct();
      CEL_ASSIGN_OR_RETURN(auto ok, struct_value.HasFieldByName(field_));
      if (!ok) {
        result = OptionalValue::None();
        return absl::OkStatus();
      }
      CEL_RETURN_IF_ERROR(struct_value.GetFieldByName(
          field_, unboxing_option_, frame.descriptor_pool(),
          frame.message_factory(), frame.arena(), &result));
      ABSL_DCHECK(!result.IsUnknown());
      result = OptionalValue::Of(std::move(result), frame.arena());
      return absl::OkStatus();
    }
    case ValueKind::kMap: {
      CEL_ASSIGN_OR_RETURN(
          auto found,
          value.GetMap().Find(field_value_, frame.descriptor_pool(),
                              frame.message_factory(), frame.arena(), &result));
      if (!found) {
        result = OptionalValue::None();
        return absl::OkStatus();
      }
      ABSL_DCHECK(!result.IsUnknown());
      result = OptionalValue::Of(std::move(result), frame.arena());
      return absl::OkStatus();
    }
    default:
      // Reachable when the optional wraps a value of another kind.
      // NOTE(review): this surfaces as a non-OK status from Evaluate rather
      // than an ErrorValue. Confirm this is intended.
      return InvalidSelectTargetError();
  }
}

// Plain select: writes the struct field or map entry named by `field_` into
// `result`. Accessor failures are returned as a status, which Evaluate turns
// into an ErrorValue.
absl::Status DirectSelectStep::PerformSelect(ExecutionFrameBase& frame,
                                             const cel::Value& value,
                                             Value& result) const {
  switch (value.kind()) {
    case ValueKind::kStruct:
      CEL_RETURN_IF_ERROR(value.GetStruct().GetFieldByName(
          field_, unboxing_option_, frame.descriptor_pool(),
          frame.message_factory(), frame.arena(), &result));
      ABSL_DCHECK(!result.IsUnknown());
      return absl::OkStatus();
    case ValueKind::kMap:
      CEL_RETURN_IF_ERROR(
          value.GetMap().Get(field_value_, frame.descriptor_pool(),
                             frame.message_factory(), frame.arena(), &result));
      ABSL_DCHECK(!result.IsUnknown());
      return absl::OkStatus();
    default:
      // Control flow should have returned earlier.
      return InvalidSelectTargetError();
  }
}

}  // namespace

// Creates the recursive (direct) select step over an already-planned operand.
std::unique_ptr CreateDirectSelectStep(
    std::unique_ptr operand, StringValue field, bool test_only,
    int64_t expr_id, bool enable_wrapper_type_null_unboxing,
    bool enable_optional_types) {
  return std::make_unique(
      expr_id, std::move(operand), std::move(field), test_only,
      enable_wrapper_type_null_unboxing, enable_optional_types);
}

// Factory method for Select - based Execution step
// The field name and has() flag are taken from `select_expr`.
absl::StatusOr> CreateSelectStep(
    const cel::SelectExpr& select_expr, int64_t expr_id,
    bool enable_wrapper_type_null_unboxing, bool enable_optional_types) {
  return std::make_unique(
      cel::StringValue(select_expr.field()), select_expr.test_only(), expr_id,
      enable_wrapper_type_null_unboxing, enable_optional_types);
}

}  // namespace google::api::expr::runtime
================================================
FILE: eval/eval/select_step.h
================================================
#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_SELECT_STEP_H_
#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_SELECT_STEP_H_

#include
#include

#include "absl/status/statusor.h"
#include "common/expr.h"
#include "common/value.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"

namespace google::api::expr::runtime {

// Factory method for recursively evaluated select step.
//
// `field` is the selected field/key name; `test_only` requests a has()
// presence test instead of a value select.
std::unique_ptr CreateDirectSelectStep(
    std::unique_ptr operand, cel::StringValue field, bool test_only,
    int64_t expr_id, bool enable_wrapper_type_null_unboxing,
    bool enable_optional_types = false);

// Factory method for Select - based Execution step
// The field name and has() flag are taken from `select_expr`.
absl::StatusOr> CreateSelectStep(
    const cel::SelectExpr& select_expr, int64_t expr_id,
    bool enable_wrapper_type_null_unboxing, bool enable_optional_types = false);

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_EVAL_SELECT_STEP_H_
================================================
FILE: eval/eval/select_step_test.cc
================================================
#include "eval/eval/select_step.h"

#include
#include
#include
#include

#include "cel/expr/syntax.pb.h"
#include "google/protobuf/wrappers.pb.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "base/attribute.h"
#include "base/attribute_set.h"
#include "base/type_provider.h"
#include "common/casting.h"
#include "common/expr.h"
#include "common/legacy_value.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/cel_expression_flat_impl.h"
#include "eval/eval/const_value_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/ident_step.h"
#include "eval/public/activation.h"
#include "eval/public/cel_attribute.h"
#include "eval/public/cel_value.h"
#include "eval/public/containers/container_backed_map_impl.h"
#include "eval/public/structs/cel_proto_wrapper.h"
#include "eval/public/structs/legacy_type_adapter.h"
#include "eval/public/structs/trivial_legacy_type_info.h"
#include "eval/public/testing/matchers.h"
#include "eval/testutil/test_extensions.pb.h"
#include "eval/testutil/test_message.pb.h"
#include "extensions/protobuf/value.h"
#include "internal/proto_matchers.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "internal/testing_message_factory.h"
#include "runtime/activation.h"
#include "runtime/internal/runtime_env.h"
#include "runtime/internal/runtime_env_testing.h"
#include "runtime/internal/runtime_type_provider.h"
#include "runtime/runtime_options.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"

namespace google::api::expr::runtime {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::Attribute;
using ::cel::AttributeQualifier;
using ::cel::AttributeSet;
using ::cel::BoolValue;
using ::cel::Cast;
using ::cel::ErrorValue;
using ::cel::Expr;
using ::cel::InstanceOf;
using ::cel::IntValue;
using ::cel::OptionalValue;
using ::cel::RuntimeOptions;
using ::cel::TypeProvider;
using ::cel::UnknownValue;
using ::cel::Value;
using ::cel::expr::conformance::proto3::TestAllTypes;
using ::cel::extensions::ProtoMessageToValue;
using ::cel::internal::test::EqualsProto;
using ::cel::runtime_internal::NewTestingRuntimeEnv;
using ::cel::runtime_internal::RuntimeEnv;
using ::cel::test::IntValueIs;
using ::testing::_;
using ::testing::Eq;
using ::testing::HasSubstr;
using ::testing::Return;
using ::testing::UnorderedElementsAre;

// Per-run knobs for the RunExpression test helpers.
struct RunExpressionOptions {
  // When true, the runtime is configured with attribute-only unknown
  // processing.
  bool enable_unknowns = false;
  // Forwarded to CreateSelectStep.
  bool enable_wrapper_type_null_unboxing = false;
};

// Simple implementation of LegacyTypeAccessApis / LegacyTypeInfoApis that
// allows mocking for getters/setters.
class MockAccessor : public LegacyTypeAccessApis, public LegacyTypeInfoApis { public: MOCK_METHOD(absl::StatusOr, HasField, (absl::string_view field_name, const CelValue::MessageWrapper& value), (const, override)); MOCK_METHOD(absl::StatusOr, GetField, (absl::string_view field_name, const CelValue::MessageWrapper& instance, ProtoWrapperTypeOptions unboxing_option, cel::MemoryManagerRef memory_manager), (const, override)); MOCK_METHOD(absl::string_view, GetTypename, (const CelValue::MessageWrapper& instance), (const, override)); MOCK_METHOD(std::string, DebugString, (const CelValue::MessageWrapper& instance), (const, override)); MOCK_METHOD(std::vector, ListFields, (const CelValue::MessageWrapper& value), (const, override)); const LegacyTypeAccessApis* GetAccessApis( const CelValue::MessageWrapper& instance) const override { return this; } }; class SelectStepTest : public testing::Test { public: SelectStepTest() : env_(NewTestingRuntimeEnv()) {} // Helper method. Creates simple pipeline containing Select step and runs it. 
absl::StatusOr RunExpression(const CelValue target, absl::string_view field, bool test, absl::string_view unknown_path, RunExpressionOptions options) { ExecutionPath path; Expr expr; auto& select = expr.mutable_select_expr(); select.set_field(std::string(field)); select.set_test_only(test); Expr& expr0 = select.mutable_operand(); auto& ident = expr0.mutable_ident_expr(); ident.set_name("target"); CEL_ASSIGN_OR_RETURN(auto step0, CreateIdentStep(ident.name(), expr0.id())); CEL_ASSIGN_OR_RETURN( auto step1, CreateSelectStep(select, expr.id(), options.enable_wrapper_type_null_unboxing)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); cel::RuntimeOptions runtime_options; if (options.enable_unknowns) { runtime_options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } CelExpressionFlatImpl cel_expr( env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env_->type_registry.GetComposedTypeProvider(), runtime_options)); Activation activation; activation.InsertValue("target", target); return cel_expr.Evaluate(activation, &arena_); } absl::StatusOr RunExpression(const TestExtensions* message, absl::string_view field, bool test, RunExpressionOptions options) { return RunExpression(CelProtoWrapper::CreateMessage(message, &arena_), field, test, "", options); } absl::StatusOr RunExpression(const TestMessage* message, absl::string_view field, bool test, absl::string_view unknown_path, RunExpressionOptions options) { return RunExpression(CelProtoWrapper::CreateMessage(message, &arena_), field, test, unknown_path, options); } absl::StatusOr RunExpression(const TestMessage* message, absl::string_view field, bool test, RunExpressionOptions options) { return RunExpression(message, field, test, "", options); } absl::StatusOr RunExpression(const CelMap* map_value, absl::string_view field, bool test, absl::string_view unknown_path, RunExpressionOptions options) { return RunExpression(CelValue::CreateMap(map_value), field, test, 
unknown_path, options); } absl::StatusOr RunExpression(const CelMap* map_value, absl::string_view field, bool test, RunExpressionOptions options) { return RunExpression(map_value, field, test, "", options); } protected: absl_nonnull std::shared_ptr env_; google::protobuf::Arena arena_; }; class SelectStepConformanceTest : public SelectStepTest, public testing::WithParamInterface {}; TEST_P(SelectStepConformanceTest, SelectMessageIsNull) { RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(static_cast(nullptr), "bool_value", true, options)); ASSERT_TRUE(result.IsError()); } TEST_P(SelectStepConformanceTest, SelectTargetNotStructOrMap) { RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN( CelValue result, RunExpression(CelValue::CreateStringView("some_value"), "some_field", /*test=*/false, /*unknown_path=*/"", options)); ASSERT_TRUE(result.IsError()); EXPECT_THAT(*result.ErrorOrDie(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Applying SELECT to non-message type"))); } TEST_P(SelectStepConformanceTest, PresenseIsFalseTest) { TestMessage message; RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(&message, "bool_value", true, options)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), false); } TEST_P(SelectStepConformanceTest, PresenseIsTrueTest) { RunExpressionOptions options; options.enable_unknowns = GetParam(); TestMessage message; message.set_bool_value(true); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(&message, "bool_value", true, options)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), true); } TEST_P(SelectStepConformanceTest, ExtensionsPresenceIsTrueTest) { TestExtensions exts; TestExtensions* nested = exts.MutableExtension(nested_ext); nested->set_name("nested"); RunExpressionOptions options; options.enable_unknowns = GetParam(); 
ASSERT_OK_AND_ASSIGN( CelValue result, RunExpression(&exts, "google.api.expr.runtime.nested_ext", true, options)); ASSERT_TRUE(result.IsBool()); EXPECT_TRUE(result.BoolOrDie()); } TEST_P(SelectStepConformanceTest, ExtensionsPresenceIsFalseTest) { TestExtensions exts; RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN( CelValue result, RunExpression(&exts, "google.api.expr.runtime.nested_ext", true, options)); ASSERT_TRUE(result.IsBool()); EXPECT_FALSE(result.BoolOrDie()); } TEST_P(SelectStepConformanceTest, MapPresenseIsFalseTest) { RunExpressionOptions options; options.enable_unknowns = GetParam(); std::string key1 = "key1"; std::vector> key_values{ {CelValue::CreateString(&key1), CelValue::CreateInt64(1)}}; auto map_value = CreateContainerBackedMap( absl::Span>(key_values)) .value(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(map_value.get(), "key2", true, options)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), false); } TEST_P(SelectStepConformanceTest, MapPresenseIsTrueTest) { RunExpressionOptions options; options.enable_unknowns = GetParam(); std::string key1 = "key1"; std::vector> key_values{ {CelValue::CreateString(&key1), CelValue::CreateInt64(1)}}; auto map_value = CreateContainerBackedMap( absl::Span>(key_values)) .value(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(map_value.get(), "key1", true, options)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), true); } TEST_F(SelectStepTest, MapPresenseIsErrorTest) { TestMessage message; Expr select_expr; auto& select = select_expr.mutable_select_expr(); select.set_field("1"); select.set_test_only(true); Expr& expr1 = select.mutable_operand(); auto& select_map = expr1.mutable_select_expr(); select_map.set_field("int32_int32_map"); Expr& expr0 = select_map.mutable_operand(); auto& ident = expr0.mutable_ident_expr(); ident.set_name("target"); ASSERT_OK_AND_ASSIGN(auto step0, CreateIdentStep(ident.name(), expr0.id())); 
ASSERT_OK_AND_ASSIGN( auto step1, CreateSelectStep(select_map, expr1.id(), /*enable_wrapper_type_null_unboxing=*/false)); ASSERT_OK_AND_ASSIGN( auto step2, CreateSelectStep(select, select_expr.id(), /*enable_wrapper_type_null_unboxing=*/false)); ExecutionPath path; path.push_back(std::move(step0)); path.push_back(std::move(step1)); path.push_back(std::move(step2)); CelExpressionFlatImpl cel_expr( env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env_->type_registry.GetComposedTypeProvider(), cel::RuntimeOptions{})); Activation activation; activation.InsertValue("target", CelProtoWrapper::CreateMessage(&message, &arena_)); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena_)); EXPECT_TRUE(result.IsError()); EXPECT_EQ(result.ErrorOrDie()->code(), absl::StatusCode::kInvalidArgument); } TEST_F(SelectStepTest, MapPresenseIsTrueWithUnknownTest) { UnknownSet unknown_set; std::string key1 = "key1"; std::vector> key_values{ {CelValue::CreateString(&key1), CelValue::CreateUnknownSet(&unknown_set)}}; auto map_value = CreateContainerBackedMap( absl::Span>(key_values)) .value(); RunExpressionOptions options; options.enable_unknowns = true; ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(map_value.get(), "key1", true, options)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), true); } TEST_P(SelectStepConformanceTest, FieldIsNotPresentInProtoTest) { TestMessage message; RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(&message, "fake_field", false, options)); ASSERT_TRUE(result.IsError()); EXPECT_THAT(result.ErrorOrDie()->code(), Eq(absl::StatusCode::kNotFound)); } TEST_P(SelectStepConformanceTest, FieldIsNotSetTest) { TestMessage message; RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(&message, "bool_value", false, options)); ASSERT_TRUE(result.IsBool()); 
EXPECT_EQ(result.BoolOrDie(), false); } TEST_P(SelectStepConformanceTest, SimpleBoolTest) { TestMessage message; message.set_bool_value(true); RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(&message, "bool_value", false, options)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), true); } TEST_P(SelectStepConformanceTest, SimpleInt32Test) { TestMessage message; message.set_int32_value(1); RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(&message, "int32_value", false, options)); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 1); } TEST_P(SelectStepConformanceTest, SimpleInt64Test) { TestMessage message; message.set_int64_value(1); RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(&message, "int64_value", false, options)); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 1); } TEST_P(SelectStepConformanceTest, SimpleUInt32Test) { TestMessage message; message.set_uint32_value(1); RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(&message, "uint32_value", false, options)); ASSERT_TRUE(result.IsUint64()); EXPECT_EQ(result.Uint64OrDie(), 1); } TEST_P(SelectStepConformanceTest, SimpleUint64Test) { TestMessage message; message.set_uint64_value(1); RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(&message, "uint64_value", false, options)); ASSERT_TRUE(result.IsUint64()); EXPECT_EQ(result.Uint64OrDie(), 1); } TEST_P(SelectStepConformanceTest, SimpleStringTest) { TestMessage message; std::string value = "test"; message.set_string_value(value); RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(&message, "string_value", 
false, options)); ASSERT_TRUE(result.IsString()); EXPECT_EQ(result.StringOrDie().value(), "test"); } TEST_P(SelectStepConformanceTest, WrapperTypeNullUnboxingEnabledTest) { TestMessage message; message.mutable_string_wrapper_value()->set_value("test"); RunExpressionOptions options; options.enable_unknowns = GetParam(); options.enable_wrapper_type_null_unboxing = true; ASSERT_OK_AND_ASSIGN( CelValue result, RunExpression(&message, "string_wrapper_value", false, options)); ASSERT_TRUE(result.IsString()); EXPECT_EQ(result.StringOrDie().value(), "test"); ASSERT_OK_AND_ASSIGN( result, RunExpression(&message, "int32_wrapper_value", false, options)); EXPECT_TRUE(result.IsNull()); } TEST_P(SelectStepConformanceTest, WrapperTypeNullUnboxingDisabledTest) { TestMessage message; message.mutable_string_wrapper_value()->set_value("test"); RunExpressionOptions options; options.enable_unknowns = GetParam(); options.enable_wrapper_type_null_unboxing = false; ASSERT_OK_AND_ASSIGN( CelValue result, RunExpression(&message, "string_wrapper_value", false, options)); ASSERT_TRUE(result.IsString()); EXPECT_EQ(result.StringOrDie().value(), "test"); ASSERT_OK_AND_ASSIGN( result, RunExpression(&message, "int32_wrapper_value", false, options)); EXPECT_TRUE(result.IsInt64()); } TEST_P(SelectStepConformanceTest, SimpleBytesTest) { TestMessage message; std::string value = "test"; message.set_bytes_value(value); RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(&message, "bytes_value", false, options)); ASSERT_TRUE(result.IsBytes()); EXPECT_EQ(result.BytesOrDie().value(), "test"); } TEST_P(SelectStepConformanceTest, SimpleMessageTest) { TestMessage message; TestMessage* message2 = message.mutable_message_value(); message2->set_int32_value(1); message2->set_string_value("test"); RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(&message, "message_value", 
false, options)); ASSERT_TRUE(result.IsMessage()); EXPECT_THAT(*message2, EqualsProto(*result.MessageOrDie())); } /* A package-scoped int32 extension selected by its fully qualified name evaluates to the stored value as an int64. */ TEST_P(SelectStepConformanceTest, GlobalExtensionsIntTest) { TestExtensions exts; exts.SetExtension(int32_ext, 42); RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(&exts, "google.api.expr.runtime.int32_ext", false, options)); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 42L); } /* A set message-typed extension evaluates to that same message instance (compared by pointer). */ TEST_P(SelectStepConformanceTest, GlobalExtensionsMessageTest) { TestExtensions exts; TestExtensions* nested = exts.MutableExtension(nested_ext); nested->set_name("nested"); RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN( CelValue result, RunExpression(&exts, "google.api.expr.runtime.nested_ext", false, options)); ASSERT_TRUE(result.IsMessage()); EXPECT_THAT(result.MessageOrDie(), Eq(nested)); } /* An unset message-typed extension evaluates to TestExtensions::default_instance(). */ TEST_P(SelectStepConformanceTest, GlobalExtensionsMessageUnsetTest) { TestExtensions exts; RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN( CelValue result, RunExpression(&exts, "google.api.expr.runtime.nested_ext", false, options)); ASSERT_TRUE(result.IsMessage()); EXPECT_THAT(result.MessageOrDie(), Eq(&TestExtensions::default_instance())); } /* A set google.protobuf.Int32Value extension is unwrapped to a plain int64. */ TEST_P(SelectStepConformanceTest, GlobalExtensionsWrapperTest) { TestExtensions exts; google::protobuf::Int32Value* wrapper = exts.MutableExtension(int32_wrapper_ext); wrapper->set_value(42); RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN( CelValue result, RunExpression(&exts, "google.api.expr.runtime.int32_wrapper_ext", false, options)); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(42L)); } /* With enable_wrapper_type_null_unboxing set, an unset wrapper extension evaluates to null. */ TEST_P(SelectStepConformanceTest, GlobalExtensionsWrapperUnsetTest) { TestExtensions exts; RunExpressionOptions options; options.enable_wrapper_type_null_unboxing = true; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(
CelValue result, RunExpression(&exts, "google.api.expr.runtime.int32_wrapper_ext", false, options)); ASSERT_TRUE(result.IsNull()); } /* An enum extension declared inside TestMessageExtensions evaluates to its numeric value as an int64. */ TEST_P(SelectStepConformanceTest, MessageExtensionsEnumTest) { TestExtensions exts; exts.SetExtension(TestMessageExtensions::enum_ext, TestExtEnum::TEST_EXT_1); RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN( CelValue result, RunExpression(&exts, "google.api.expr.runtime.TestMessageExtensions.enum_ext", false, options)); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(TestExtEnum::TEST_EXT_1)); } /* A repeated string extension evaluates to a CEL list with one entry per element. */ TEST_P(SelectStepConformanceTest, MessageExtensionsRepeatedStringTest) { TestExtensions exts; exts.AddExtension(TestMessageExtensions::repeated_string_exts, "test1"); exts.AddExtension(TestMessageExtensions::repeated_string_exts, "test2"); RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN( CelValue result, RunExpression( &exts, "google.api.expr.runtime.TestMessageExtensions.repeated_string_exts", false, options)); ASSERT_TRUE(result.IsList()); const CelList* cel_list = result.ListOrDie(); EXPECT_THAT(cel_list->size(), Eq(2)); } /* An unset repeated extension evaluates to an empty CEL list. */ TEST_P(SelectStepConformanceTest, MessageExtensionsRepeatedStringUnsetTest) { TestExtensions exts; RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN( CelValue result, RunExpression( &exts, "google.api.expr.runtime.TestMessageExtensions.repeated_string_exts", false, options)); ASSERT_TRUE(result.IsList()); const CelList* cel_list = result.ListOrDie(); EXPECT_THAT(cel_list->size(), Eq(0)); } /* A message wrapped with TrivialTypeInfo (per the test name, no usable accessor) makes both select and has() evaluate to a kNotFound error. */ TEST_P(SelectStepConformanceTest, NullMessageAccessor) { TestMessage message; TestMessage* message2 = message.mutable_message_value(); message2->set_int32_value(1); message2->set_string_value("test"); RunExpressionOptions options; options.enable_unknowns = GetParam(); CelValue value = CelValue::CreateMessageWrapper( CelValue::MessageWrapper(&message, TrivialTypeInfo::GetInstance()));
ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(value, "message_value", /*test=*/false, /*unknown_path=*/"", options)); ASSERT_TRUE(result.IsError()); EXPECT_THAT(*result.ErrorOrDie(), StatusIs(absl::StatusCode::kNotFound)); // same for has ASSERT_OK_AND_ASSIGN(result, RunExpression(value, "message_value", /*test=*/true, /*unknown_path=*/"", options)); ASSERT_TRUE(result.IsError()); EXPECT_THAT(*result.ErrorOrDie(), StatusIs(absl::StatusCode::kNotFound)); } /* Select and has() are served by the mocked accessor supplied in the MessageWrapper (GetField returns int 2, HasField returns false). */ TEST_P(SelectStepConformanceTest, CustomAccessor) { TestMessage message; TestMessage* message2 = message.mutable_message_value(); message2->set_int32_value(1); message2->set_string_value("test"); RunExpressionOptions options; options.enable_unknowns = GetParam(); testing::NiceMock accessor; CelValue value = CelValue::CreateMessageWrapper( CelValue::MessageWrapper(&message, &accessor)); ON_CALL(accessor, GetField(_, _, _, _)) .WillByDefault(Return(CelValue::CreateInt64(2))); ON_CALL(accessor, HasField(_, _)).WillByDefault(Return(false)); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(value, "message_value", /*test=*/false, /*unknown_path=*/"", options)); EXPECT_THAT(result, test::IsCelInt64(2)); // testonly select (has) ASSERT_OK_AND_ASSIGN(result, RunExpression(value, "message_value", /*test=*/true, /*unknown_path=*/"", options)); EXPECT_THAT(result, test::IsCelBool(false)); } /* Status errors returned by the mocked accessor surface as CEL error values rather than failing evaluation. */ TEST_P(SelectStepConformanceTest, CustomAccessorErrorHandling) { TestMessage message; TestMessage* message2 = message.mutable_message_value(); message2->set_int32_value(1); message2->set_string_value("test"); RunExpressionOptions options; options.enable_unknowns = GetParam(); testing::NiceMock accessor; CelValue value = CelValue::CreateMessageWrapper( CelValue::MessageWrapper(&message, &accessor)); ON_CALL(accessor, GetField(_, _, _, _)) .WillByDefault(Return(absl::InternalError("bad data"))); ON_CALL(accessor, HasField(_, _)) .WillByDefault(Return(absl::NotFoundError("not found"))); // For get field, implementation may return an
error-type cel value or a // status (e.g. broken assumption using a core type). ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(value, "message_value", /*test=*/false, /*unknown_path=*/"", options)); EXPECT_THAT(result, test::IsCelError(StatusIs(absl::StatusCode::kInternal))); // testonly select (has) errors are coerced to CelError. ASSERT_OK_AND_ASSIGN(result, RunExpression(value, "message_value", /*test=*/true, /*unknown_path=*/"", options)); EXPECT_THAT(result, test::IsCelError(StatusIs(absl::StatusCode::kNotFound))); } /* An enum field evaluates to its numeric value as an int64. */ TEST_P(SelectStepConformanceTest, SimpleEnumTest) { TestMessage message; message.set_enum_value(TestMessage::TEST_ENUM_1); RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(&message, "enum_value", false, options)); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); } /* A repeated scalar field evaluates to a CEL list of the same size. */ TEST_P(SelectStepConformanceTest, SimpleListTest) { TestMessage message; message.add_int32_list(1); message.add_int32_list(2); RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(&message, "int32_list", false, options)); ASSERT_TRUE(result.IsList()); const CelList* cel_list = result.ListOrDie(); EXPECT_THAT(cel_list->size(), Eq(2)); } /* A proto map field evaluates to a CEL map of the same size. */ TEST_P(SelectStepConformanceTest, SimpleMapTest) { TestMessage message; auto map_field = message.mutable_string_int32_map(); (*map_field)["test0"] = 1; (*map_field)["test1"] = 2; RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN( CelValue result, RunExpression(&message, "string_int32_map", false, options)); ASSERT_TRUE(result.IsMap()); const CelMap* cel_map = result.MapOrDie(); EXPECT_THAT(cel_map->size(), Eq(2)); } /* Selecting a key from a container-backed CelMap returns the mapped value. */ TEST_P(SelectStepConformanceTest, MapSimpleInt32Test) { std::string key1 = "key1"; std::string key2 = "key2"; std::vector> key_values{ {CelValue::CreateString(&key1), CelValue::CreateInt64(1)},
{CelValue::CreateString(&key2), CelValue::CreateInt64(2)}}; auto map_value = CreateContainerBackedMap( absl::Span>(key_values)) .value(); RunExpressionOptions options; options.enable_unknowns = GetParam(); ASSERT_OK_AND_ASSIGN(CelValue result, RunExpression(map_value.get(), "key1", false, options)); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 1); } // Test Select behavior, when expression to select from is an Error. TEST_P(SelectStepConformanceTest, CelErrorAsArgument) { ExecutionPath path; Expr dummy_expr; auto& select = dummy_expr.mutable_select_expr(); select.set_field("position"); select.set_test_only(false); Expr& expr0 = select.mutable_operand(); auto& ident = expr0.mutable_ident_expr(); ident.set_name("message"); ASSERT_OK_AND_ASSIGN(auto step0, CreateIdentStep(ident.name(), expr0.id())); ASSERT_OK_AND_ASSIGN( auto step1, CreateSelectStep(select, dummy_expr.id(), /*enable_wrapper_type_null_unboxing=*/false)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); CelError error = absl::CancelledError(); cel::RuntimeOptions options; if (GetParam()) { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } CelExpressionFlatImpl cel_expr( env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env_->type_registry.GetComposedTypeProvider(), options)); Activation activation; activation.InsertValue("message", CelValue::CreateError(&error)); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsError()); EXPECT_THAT(*result.ErrorOrDie(), Eq(error)); } /* With default RuntimeOptions, adding a missing-attribute pattern does not change the selected bool result. */ TEST_F(SelectStepTest, DisableMissingAttributeOK) { TestMessage message; message.set_bool_value(true); ExecutionPath path; Expr dummy_expr; auto& select = dummy_expr.mutable_select_expr(); select.set_field("bool_value"); select.set_test_only(false); Expr& expr0 = select.mutable_operand(); auto& ident = expr0.mutable_ident_expr(); ident.set_name("message"); ASSERT_OK_AND_ASSIGN(auto step0,
CreateIdentStep(ident.name(), expr0.id())); ASSERT_OK_AND_ASSIGN( auto step1, CreateSelectStep(select, dummy_expr.id(), /*enable_wrapper_type_null_unboxing=*/false)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); CelExpressionFlatImpl cel_expr( env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env_->type_registry.GetComposedTypeProvider(), cel::RuntimeOptions{})); Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena_)); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), true); CelAttributePattern pattern("message", {}); activation.set_missing_attribute_patterns({pattern}); ASSERT_OK_AND_ASSIGN(result, cel_expr.Evaluate(activation, &arena_)); EXPECT_EQ(result.BoolOrDie(), true); } TEST_F(SelectStepTest, UnrecoverableUnknownValueProducesError) { TestMessage message; message.set_bool_value(true); ExecutionPath path; Expr dummy_expr; auto& select = dummy_expr.mutable_select_expr(); select.set_field("bool_value"); select.set_test_only(false); Expr& expr0 = select.mutable_operand(); auto& ident = expr0.mutable_ident_expr(); ident.set_name("message"); ASSERT_OK_AND_ASSIGN(auto step0, CreateIdentStep(ident.name(), expr0.id())); ASSERT_OK_AND_ASSIGN( auto step1, CreateSelectStep(select, dummy_expr.id(), /*enable_wrapper_type_null_unboxing=*/false)); path.push_back(std::move(step0)); path.push_back(std::move(step1)); cel::RuntimeOptions options; options.enable_missing_attribute_errors = true; CelExpressionFlatImpl cel_expr( env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env_->type_registry.GetComposedTypeProvider(), options)); Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena_)); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), 
true); CelAttributePattern pattern("message", {CreateCelAttributeQualifierPattern( CelValue::CreateStringView("bool_value"))}); activation.set_missing_attribute_patterns({pattern}); ASSERT_OK_AND_ASSIGN(result, cel_expr.Evaluate(activation, &arena_)); EXPECT_THAT(*result.ErrorOrDie(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("MissingAttributeError: message.bool_value"))); } TEST_F(SelectStepTest, UnknownPatternResolvesToUnknown) { TestMessage message; message.set_bool_value(true); ExecutionPath path; Expr dummy_expr; auto& select = dummy_expr.mutable_select_expr(); select.set_field("bool_value"); select.set_test_only(false); Expr& expr0 = select.mutable_operand(); auto& ident = expr0.mutable_ident_expr(); ident.set_name("message"); auto step0_status = CreateIdentStep(ident.name(), expr0.id()); auto step1_status = CreateSelectStep(select, dummy_expr.id(), /*enable_wrapper_type_null_unboxing=*/false); ASSERT_THAT(step0_status, IsOk()); ASSERT_THAT(step1_status, IsOk()); path.push_back(*std::move(step0_status)); path.push_back(*std::move(step1_status)); cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; CelExpressionFlatImpl cel_expr( env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env_->type_registry.GetComposedTypeProvider(), options)); { std::vector unknown_patterns; Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena_)); activation.set_unknown_attribute_patterns(unknown_patterns); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), true); } const std::string kSegmentCorrect1 = "bool_value"; const std::string kSegmentIncorrect = "message_value"; { std::vector unknown_patterns; unknown_patterns.push_back(CelAttributePattern("message", {})); Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, 
&arena_)); activation.set_unknown_attribute_patterns(unknown_patterns); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsUnknownSet()); } /* Pattern naming the exact selected field: the selection is unknown. */ { std::vector unknown_patterns; unknown_patterns.push_back(CelAttributePattern( "message", {CreateCelAttributeQualifierPattern( CelValue::CreateString(&kSegmentCorrect1))})); Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena_)); activation.set_unknown_attribute_patterns(unknown_patterns); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsUnknownSet()); } /* Wildcard qualifier under `message`: the selection is unknown. */ { std::vector unknown_patterns; unknown_patterns.push_back(CelAttributePattern( "message", {CelAttributeQualifierPattern::CreateWildcard()})); Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena_)); activation.set_unknown_attribute_patterns(unknown_patterns); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsUnknownSet()); } /* Pattern naming a different field: the selection evaluates normally. */ { std::vector unknown_patterns; unknown_patterns.push_back(CelAttributePattern( "message", {CreateCelAttributeQualifierPattern( CelValue::CreateString(&kSegmentIncorrect))})); Activation activation; activation.InsertValue("message", CelProtoWrapper::CreateMessage(&message, &arena_)); activation.set_unknown_attribute_patterns(unknown_patterns); ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr.Evaluate(activation, &arena_)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), true); } } /* Runs every SelectStepConformanceTest twice; GetParam() feeds enable_unknowns (false, then true). */ INSTANTIATE_TEST_SUITE_P(UnknownsEnabled, SelectStepConformanceTest, testing::Bool()); /* Fixture for the direct-evaluation select step: owns the arena and runtime type provider, and offers helpers to wrap a proto as a cel::Value and to stringify an unknown value's attributes. */ class DirectSelectStepTest : public testing::Test { public: DirectSelectStepTest() : type_provider_(cel::internal::GetTestingDescriptorPool()) {} cel::Value TestWrapMessage(const google::protobuf::Message* message) { CelValue value = CelProtoWrapper::CreateMessage(message, &arena_); auto result =
cel::interop_internal::FromLegacyValue(&arena_, value); ABSL_DCHECK_OK(result.status()); return std::move(result).value(); } /* Returns AsString() of every attribute in `v`'s attribute set. */ std::vector AttributeStrings(const UnknownValue& v) { std::vector result; for (const Attribute& attr : v.attribute_set()) { auto attr_str = attr.AsString(); ABSL_DCHECK_OK(attr_str.status()); result.push_back(std::move(attr_str).value()); } return result; } protected: google::protobuf::Arena arena_; cel::runtime_internal::RuntimeTypeProvider type_provider_; }; /* Selecting an existing string key from a map yields its value. */ TEST_F(DirectSelectStepTest, SelectFromMap) { cel::Activation activation; RuntimeOptions options; auto step = CreateDirectSelectStep( CreateDirectIdentStep("map_val", -1), cel::StringValue("one"), /*test_only=*/false, -1, /*enable_wrapper_type_null_unboxing=*/true); auto map_builder = cel::NewMapValueBuilder(&arena_); ASSERT_THAT(map_builder->Put(cel::StringValue("one"), IntValue(1)), IsOk()); ASSERT_THAT(map_builder->Put(cel::StringValue("two"), IntValue(2)), IsOk()); activation.InsertOrAssignValue("map_val", std::move(*map_builder).Build()); ExecutionFrameBase frame(activation, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_EQ(Cast(result).NativeValue(), 1); } /* test_only (has()) on a present map key yields true. */ TEST_F(DirectSelectStepTest, HasMap) { cel::Activation activation; RuntimeOptions options; auto step = CreateDirectSelectStep( CreateDirectIdentStep("map_val", -1), cel::StringValue("two"), /*test_only=*/true, -1, /*enable_wrapper_type_null_unboxing=*/true); auto map_builder = cel::NewMapValueBuilder(&arena_); ASSERT_THAT(map_builder->Put(cel::StringValue("one"), IntValue(1)), IsOk()); ASSERT_THAT(map_builder->Put(cel::StringValue("two"), IntValue(2)), IsOk()); activation.InsertOrAssignValue("map_val", std::move(*map_builder).Build()); ExecutionFrameBase frame(activation, options, type_provider_,
cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_TRUE(Cast(result).NativeValue()); } /* With optional types enabled, selecting a present key from optional(map) yields optional(value). */ TEST_F(DirectSelectStepTest, SelectFromOptionalMap) { cel::Activation activation; RuntimeOptions options; auto step = CreateDirectSelectStep(CreateDirectIdentStep("map_val", -1), cel::StringValue("one"), /*test_only=*/false, -1, /*enable_wrapper_type_null_unboxing=*/true, /*enable_optional_types=*/true); auto map_builder = cel::NewMapValueBuilder(&arena_); ASSERT_THAT(map_builder->Put(cel::StringValue("one"), IntValue(1)), IsOk()); ASSERT_THAT(map_builder->Put(cel::StringValue("two"), IntValue(2)), IsOk()); activation.InsertOrAssignValue( "map_val", OptionalValue::Of(std::move(*map_builder).Build(), &arena_)); ExecutionFrameBase frame(activation, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(Cast(static_cast(result)).Value(), IntValueIs(1)); } /* Selecting an absent key from optional(map) yields an empty optional. */ TEST_F(DirectSelectStepTest, SelectFromOptionalMapAbsent) { cel::Activation activation; RuntimeOptions options; auto step = CreateDirectSelectStep(CreateDirectIdentStep("map_val", -1), cel::StringValue("three"), /*test_only=*/false, -1, /*enable_wrapper_type_null_unboxing=*/true, /*enable_optional_types=*/true); auto map_builder = cel::NewMapValueBuilder(&arena_); ASSERT_THAT(map_builder->Put(cel::StringValue("one"), IntValue(1)), IsOk()); ASSERT_THAT(map_builder->Put(cel::StringValue("two"), IntValue(2)), IsOk()); activation.InsertOrAssignValue( "map_val", OptionalValue::Of(std::move(*map_builder).Build(), &arena_)); ExecutionFrameBase frame(activation, options, type_provider_, cel::internal::GetTestingDescriptorPool(),
cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_FALSE( Cast(static_cast(result)).HasValue()); } /* Selecting a set field from optional(message) yields optional(field value). */ TEST_F(DirectSelectStepTest, SelectFromOptionalStruct) { cel::Activation activation; RuntimeOptions options; auto step = CreateDirectSelectStep(CreateDirectIdentStep("struct_val", -1), cel::StringValue("single_int64"), /*test_only=*/false, -1, /*enable_wrapper_type_null_unboxing=*/true, /*enable_optional_types=*/true); TestAllTypes message; message.set_single_int64(1); ASSERT_OK_AND_ASSIGN( Value struct_val, ProtoMessageToValue(std::move(message), cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_)); activation.InsertOrAssignValue("struct_val", OptionalValue::Of(struct_val, &arena_)); ExecutionFrameBase frame(activation, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(Cast(static_cast(result)).Value(), IntValueIs(1)); } /* Selecting an unset field from optional(message) yields an empty optional. */ TEST_F(DirectSelectStepTest, SelectFromOptionalStructFieldNotSet) { cel::Activation activation; RuntimeOptions options; auto step = CreateDirectSelectStep(CreateDirectIdentStep("struct_val", -1), cel::StringValue("single_string"), /*test_only=*/false, -1, /*enable_wrapper_type_null_unboxing=*/true, /*enable_optional_types=*/true); TestAllTypes message; message.set_single_int64(1); ASSERT_OK_AND_ASSIGN( Value struct_val, ProtoMessageToValue(std::move(message), cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_)); activation.InsertOrAssignValue("struct_val", OptionalValue::Of(struct_val, &arena_)); ExecutionFrameBase frame(activation, options, type_provider_, cel::internal::GetTestingDescriptorPool(),
cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_FALSE( Cast(static_cast(result)).HasValue()); } /* Selecting from optional.none() yields an empty optional. */ TEST_F(DirectSelectStepTest, SelectFromEmptyOptional) { cel::Activation activation; RuntimeOptions options; auto step = CreateDirectSelectStep(CreateDirectIdentStep("map_val", -1), cel::StringValue("one"), /*test_only=*/false, -1, /*enable_wrapper_type_null_unboxing=*/true, /*enable_optional_types=*/true); activation.InsertOrAssignValue("map_val", OptionalValue::None()); ExecutionFrameBase frame(activation, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_FALSE( cel::Cast(static_cast(result)).HasValue()); } /* has() on a present key of optional(map) yields true. */ TEST_F(DirectSelectStepTest, HasOptional) { cel::Activation activation; RuntimeOptions options; auto step = CreateDirectSelectStep(CreateDirectIdentStep("map_val", -1), cel::StringValue("two"), /*test_only=*/true, -1, /*enable_wrapper_type_null_unboxing=*/true, /*enable_optional_types=*/true); auto map_builder = cel::NewMapValueBuilder(&arena_); ASSERT_THAT(map_builder->Put(cel::StringValue("one"), IntValue(1)), IsOk()); ASSERT_THAT(map_builder->Put(cel::StringValue("two"), IntValue(2)), IsOk()); activation.InsertOrAssignValue( "map_val", OptionalValue::Of(std::move(*map_builder).Build(), &arena_)); ExecutionFrameBase frame(activation, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_TRUE(Cast(result).NativeValue()); } /* has() on optional.none() yields false. */ TEST_F(DirectSelectStepTest, HasEmptyOptional) { cel::Activation activation; RuntimeOptions options; auto step =
CreateDirectSelectStep(CreateDirectIdentStep("map_val", -1), cel::StringValue("two"), /*test_only=*/true, -1, /*enable_wrapper_type_null_unboxing=*/true, /*enable_optional_types=*/true); activation.InsertOrAssignValue("map_val", OptionalValue::None()); ExecutionFrameBase frame(activation, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_FALSE(Cast(result).NativeValue()); } /* Selecting a set scalar field from a message yields its value. */ TEST_F(DirectSelectStepTest, SelectFromStruct) { cel::Activation activation; RuntimeOptions options; auto step = CreateDirectSelectStep(CreateDirectIdentStep("test_all_types", -1), cel::StringValue("single_int64"), /*test_only=*/false, -1, /*enable_wrapper_type_null_unboxing=*/true); TestAllTypes message; message.set_single_int64(1); activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); ExecutionFrameBase frame(activation, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_EQ(Cast(result).NativeValue(), 1); } /* has() on an unset message field yields false. */ TEST_F(DirectSelectStepTest, HasStruct) { cel::Activation activation; RuntimeOptions options; auto step = CreateDirectSelectStep(CreateDirectIdentStep("test_all_types", -1), cel::StringValue("single_string"), /*test_only=*/true, -1, /*enable_wrapper_type_null_unboxing=*/true); TestAllTypes message; message.set_single_int64(1); activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); ExecutionFrameBase frame(activation, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; // has(test_all_types.single_string) ASSERT_THAT(step->Evaluate(frame, result,
attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_FALSE(Cast(result).NativeValue()); } /* Selecting from a non-container value (bool) yields a kInvalidArgument error value. */ TEST_F(DirectSelectStepTest, SelectFromUnsupportedType) { cel::Activation activation; RuntimeOptions options; auto step = CreateDirectSelectStep( CreateDirectIdentStep("bool_val", -1), cel::StringValue("one"), /*test_only=*/false, -1, /*enable_wrapper_type_null_unboxing=*/true); activation.InsertOrAssignValue("bool_val", BoolValue(false)); ExecutionFrameBase frame(activation, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(Cast(result).NativeValue(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Applying SELECT to non-message type"))); } /* With kAttributeOnly unknown processing, the attribute trail is extended to "test_all_types.single_int64". */ TEST_F(DirectSelectStepTest, AttributeUpdatedIfRequested) { cel::Activation activation; RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; auto step = CreateDirectSelectStep(CreateDirectIdentStep("test_all_types", -1), cel::StringValue("single_int64"), /*test_only=*/false, -1, /*enable_wrapper_type_null_unboxing=*/true); TestAllTypes message; message.set_single_int64(1); activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); ExecutionFrameBase frame(activation, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_EQ(Cast(result).NativeValue(), 1); ASSERT_OK_AND_ASSIGN(std::string attr_str, attr.attribute().AsString()); EXPECT_EQ(attr_str, "test_all_types.single_int64"); } /* With enable_missing_attribute_errors set, a field matched by a missing pattern yields a kInvalidArgument error naming the attribute. */ TEST_F(DirectSelectStepTest, MissingAttributesToErrors) { cel::Activation activation; RuntimeOptions options; options.enable_missing_attribute_errors = true; auto step =
CreateDirectSelectStep(CreateDirectIdentStep("test_all_types", -1), cel::StringValue("single_int64"), /*test_only=*/false, -1, /*enable_wrapper_type_null_unboxing=*/true); TestAllTypes message; message.set_single_int64(1); activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); activation.SetMissingPatterns({cel::AttributePattern( "test_all_types", {cel::AttributeQualifierPattern::OfString("single_int64")})}); ExecutionFrameBase frame(activation, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(Cast(result).NativeValue(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("test_all_types.single_int64"))); } /* With kAttributeOnly unknown processing, a field matched by an unknown pattern yields an unknown value carrying that attribute. */ TEST_F(DirectSelectStepTest, IdentifiesUnknowns) { cel::Activation activation; RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; auto step = CreateDirectSelectStep(CreateDirectIdentStep("test_all_types", -1), cel::StringValue("single_int64"), /*test_only=*/false, -1, /*enable_wrapper_type_null_unboxing=*/true); TestAllTypes message; message.set_single_int64(1); activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); activation.SetUnknownPatterns({cel::AttributePattern( "test_all_types", {cel::AttributeQualifierPattern::OfString("single_int64")})}); ExecutionFrameBase frame(activation, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(AttributeStrings(Cast(result)), UnorderedElementsAre("test_all_types.single_int64")); } /* An error-valued operand is forwarded unchanged as the select result. */ TEST_F(DirectSelectStepTest, ForwardErrorValue) { cel::Activation activation; RuntimeOptions options; options.unknown_processing =
cel::UnknownProcessingOptions::kAttributeOnly; auto step = CreateDirectSelectStep( CreateConstValueDirectStep(cel::ErrorValue(absl::InternalError("test1")), -1), cel::StringValue("single_int64"), /*test_only=*/false, -1, /*enable_wrapper_type_null_unboxing=*/true); ExecutionFrameBase frame(activation, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(Cast(result).NativeValue(), StatusIs(absl::StatusCode::kInternal, HasSubstr("test1"))); } TEST_F(DirectSelectStepTest, ForwardUnknownOperand) { cel::Activation activation; RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; AttributeSet attr_set({Attribute("attr", {AttributeQualifier::OfInt(0)})}); auto step = CreateDirectSelectStep( CreateConstValueDirectStep( cel::UnknownValue(cel::Unknown(std::move(attr_set))), -1), cel::StringValue("single_int64"), /*test_only=*/false, -1, /*enable_wrapper_type_null_unboxing=*/true); TestAllTypes message; message.set_single_int64(1); activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); ExecutionFrameBase frame(activation, options, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); Value result; AttributeTrail attr; ASSERT_THAT(step->Evaluate(frame, result, attr), IsOk()); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(AttributeStrings(Cast(result)), UnorderedElementsAre("attr[0]")); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/shadowable_value_step.cc ================================================ #include "eval/eval/shadowable_value_step.h" #include #include #include #include #include "absl/memory/memory.h" #include "absl/status/status.h" #include 
"absl/status/statusor.h" #include "absl/strings/string_view.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "internal/status_macros.h" namespace google::api::expr::runtime { namespace { using ::cel::Value; class ShadowableValueStep : public ExpressionStepBase { public: ShadowableValueStep(std::string identifier, cel::Value value, int64_t expr_id) : ExpressionStepBase(expr_id), identifier_(std::move(identifier)), value_(std::move(value)) {} absl::Status Evaluate(ExecutionFrame* frame) const override; private: std::string identifier_; Value value_; }; absl::Status ShadowableValueStep::Evaluate(ExecutionFrame* frame) const { cel::Value result; CEL_ASSIGN_OR_RETURN(auto found, frame->modern_activation().FindVariable( identifier_, frame->descriptor_pool(), frame->message_factory(), frame->arena(), &result)); if (found) { frame->value_stack().Push(std::move(result)); } else { frame->value_stack().Push(value_); } return absl::OkStatus(); } class DirectShadowableValueStep : public DirectExpressionStep { public: DirectShadowableValueStep(std::string identifier, cel::Value value, int64_t expr_id) : DirectExpressionStep(expr_id), identifier_(std::move(identifier)), value_(std::move(value)) {} absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const override; private: std::string identifier_; Value value_; }; // TODO(uncreated-issue/67): Attribute tracking is skipped for the shadowed case. May // cause problems for users with unknown tracking and variables named like // 'list' etc, but follows the current behavior of the stack machine version. 
absl::Status DirectShadowableValueStep::Evaluate( ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const { CEL_ASSIGN_OR_RETURN(auto found, frame.activation().FindVariable( identifier_, frame.descriptor_pool(), frame.message_factory(), frame.arena(), &result)); if (!found) { result = value_; } return absl::OkStatus(); } } // namespace absl::StatusOr> CreateShadowableValueStep( absl::string_view name, cel::Value value, int64_t expr_id) { return absl::make_unique(std::string(name), std::move(value), expr_id); } std::unique_ptr CreateDirectShadowableValueStep( absl::string_view name, cel::Value value, int64_t expr_id) { return std::make_unique(std::string(name), std::move(value), expr_id); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/shadowable_value_step.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_SHADOWABLE_VALUE_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_SHADOWABLE_VALUE_STEP_H_ #include #include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "common/value.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { // Create an identifier resolution step with a default value that may be // shadowed by an identifier of the same name within the runtime-provided // Activation. 
absl::StatusOr> CreateShadowableValueStep( absl::string_view name, cel::Value value, int64_t expr_id); std::unique_ptr CreateDirectShadowableValueStep( absl::string_view name, cel::Value value, int64_t expr_id); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_SHADOWABLE_VALUE_STEP_H_ ================================================ FILE: eval/eval/shadowable_value_step_test.cc ================================================ #include "eval/eval/shadowable_value_step.h" #include #include #include #include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "base/type_provider.h" #include "common/value.h" #include "eval/eval/cel_expression_flat_impl.h" #include "eval/eval/evaluator_core.h" #include "eval/internal/interop.h" #include "eval/public/activation.h" #include "eval/public/cel_value.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "runtime/internal/runtime_env.h" #include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_options.h" namespace google::api::expr::runtime { namespace { using ::cel::TypeProvider; using ::cel::interop_internal::CreateTypeValueFromView; using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::runtime_internal::RuntimeEnv; using ::google::protobuf::Arena; using ::testing::Eq; absl::StatusOr RunShadowableExpression( const absl_nonnull std::shared_ptr& env, std::string identifier, cel::Value value, const Activation& activation, Arena* arena) { CEL_ASSIGN_OR_RETURN( auto step, CreateShadowableValueStep(std::move(identifier), std::move(value), 1)); ExecutionPath path; path.push_back(std::move(step)); CelExpressionFlatImpl impl( env, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env->type_registry.GetComposedTypeProvider(), cel::RuntimeOptions{})); return impl.Evaluate(activation, arena); } TEST(ShadowableValueStepTest, TestEvaluateNoShadowing) { absl_nonnull std::shared_ptr env = NewTestingRuntimeEnv(); std::string 
type_name = "google.api.expr.runtime.TestMessage"; Activation activation; Arena arena; auto type_value = CreateTypeValueFromView(&arena, type_name); auto status = RunShadowableExpression(env, type_name, type_value, activation, &arena); ASSERT_OK(status); auto value = status.value(); ASSERT_TRUE(value.IsCelType()); EXPECT_THAT(value.CelTypeOrDie().value(), Eq(type_name)); } TEST(ShadowableValueStepTest, TestEvaluateShadowedIdentifier) { absl_nonnull std::shared_ptr env = NewTestingRuntimeEnv(); std::string type_name = "int"; auto shadow_value = CelValue::CreateInt64(1024L); Activation activation; activation.InsertValue(type_name, shadow_value); Arena arena; auto type_value = CreateTypeValueFromView(&arena, type_name); auto status = RunShadowableExpression(env, type_name, type_value, activation, &arena); ASSERT_OK(status); auto value = status.value(); ASSERT_TRUE(value.IsInt64()); EXPECT_THAT(value.Int64OrDie(), Eq(1024L)); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/ternary_step.cc ================================================ #include "eval/eval/ternary_step.h" #include #include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "base/builtins.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/internal/errors.h" #include "internal/status_macros.h" namespace google::api::expr::runtime { namespace { using ::cel::builtin::kTernary; using ::cel::runtime_internal::CreateNoMatchingOverloadError; inline constexpr size_t kTernaryStepCondition = 0; inline constexpr size_t kTernaryStepTrue = 1; inline constexpr size_t kTernaryStepFalse = 2; class ExhaustiveDirectTernaryStep : public DirectExpressionStep { public: ExhaustiveDirectTernaryStep(std::unique_ptr condition, std::unique_ptr left, 
std::unique_ptr right, int64_t expr_id) : DirectExpressionStep(expr_id), condition_(std::move(condition)), left_(std::move(left)), right_(std::move(right)) {} absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, AttributeTrail& attribute) const override { cel::Value condition; cel::Value lhs; cel::Value rhs; AttributeTrail condition_attr; AttributeTrail lhs_attr; AttributeTrail rhs_attr; CEL_RETURN_IF_ERROR(condition_->Evaluate(frame, condition, condition_attr)); CEL_RETURN_IF_ERROR(left_->Evaluate(frame, lhs, lhs_attr)); CEL_RETURN_IF_ERROR(right_->Evaluate(frame, rhs, rhs_attr)); if (condition.IsError() || condition.IsUnknown()) { result = std::move(condition); attribute = std::move(condition_attr); return absl::OkStatus(); } if (!condition.IsBool()) { result = cel::ErrorValue(CreateNoMatchingOverloadError(kTernary)); return absl::OkStatus(); } if (condition.GetBool().NativeValue()) { result = std::move(lhs); attribute = std::move(lhs_attr); } else { result = std::move(rhs); attribute = std::move(rhs_attr); } return absl::OkStatus(); } private: std::unique_ptr condition_; std::unique_ptr left_; std::unique_ptr right_; }; class ShortcircuitingDirectTernaryStep : public DirectExpressionStep { public: ShortcircuitingDirectTernaryStep( std::unique_ptr condition, std::unique_ptr left, std::unique_ptr right, int64_t expr_id) : DirectExpressionStep(expr_id), condition_(std::move(condition)), left_(std::move(left)), right_(std::move(right)) {} absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, AttributeTrail& attribute) const override { cel::Value condition; AttributeTrail condition_attr; CEL_RETURN_IF_ERROR(condition_->Evaluate(frame, condition, condition_attr)); if (condition.IsError() || condition.IsUnknown()) { result = std::move(condition); attribute = std::move(condition_attr); return absl::OkStatus(); } if (!condition.IsBool()) { result = cel::ErrorValue(CreateNoMatchingOverloadError(kTernary)); return absl::OkStatus(); } if 
(condition.GetBool().NativeValue()) { return left_->Evaluate(frame, result, attribute); } return right_->Evaluate(frame, result, attribute); } private: std::unique_ptr condition_; std::unique_ptr left_; std::unique_ptr right_; }; class TernaryStep : public ExpressionStepBase { public: // Constructs FunctionStep that uses overloads specified. explicit TernaryStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} absl::Status Evaluate(ExecutionFrame* frame) const override; }; absl::Status TernaryStep::Evaluate(ExecutionFrame* frame) const { // Must have 3 or more values on the stack. if (!frame->value_stack().HasEnough(3)) { return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } // Create Span object that contains input arguments to the function. auto args = frame->value_stack().GetSpan(3); const auto& condition = args[kTernaryStepCondition]; // As opposed to regular functions, ternary treats unknowns or errors on the // condition (arg0) as blocking. If we get an error or unknown then we // ignore the other arguments and forward the condition as the result. if (frame->enable_unknowns()) { // Check if unknown? 
if (condition.IsUnknown()) { frame->value_stack().Pop(2); return absl::OkStatus(); } } if (condition.IsError()) { frame->value_stack().Pop(2); return absl::OkStatus(); } cel::Value result; if (!condition.IsBool()) { result = cel::ErrorValue(CreateNoMatchingOverloadError(kTernary)); } else if (condition.GetBool().NativeValue()) { result = args[kTernaryStepTrue]; } else { result = args[kTernaryStepFalse]; } frame->value_stack().PopAndPush(args.size(), std::move(result)); return absl::OkStatus(); } } // namespace // Factory method for ternary (_?_:_) recursive execution step std::unique_ptr CreateDirectTernaryStep( std::unique_ptr condition, std::unique_ptr left, std::unique_ptr right, int64_t expr_id, bool shortcircuiting) { if (shortcircuiting) { return std::make_unique( std::move(condition), std::move(left), std::move(right), expr_id); } return std::make_unique( std::move(condition), std::move(left), std::move(right), expr_id); } absl::StatusOr> CreateTernaryStep( int64_t expr_id) { return std::make_unique(expr_id); } } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/ternary_step.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_TERNARY_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_TERNARY_STEP_H_ #include #include #include "absl/status/statusor.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { // Factory method for ternary (_?_:_) recursive execution step std::unique_ptr CreateDirectTernaryStep( std::unique_ptr condition, std::unique_ptr left, std::unique_ptr right, int64_t expr_id, bool shortcircuiting = true); // Factory method for ternary (_?_:_) execution step absl::StatusOr> CreateTernaryStep( int64_t expr_id); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_TERNARY_STEP_H_ ================================================ FILE: 
eval/eval/ternary_step_test.cc ================================================ #include "eval/eval/ternary_step.h" #include #include #include #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "base/attribute.h" #include "base/attribute_set.h" #include "base/type_provider.h" #include "common/casting.h" #include "common/expr.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/cel_expression_flat_impl.h" #include "eval/eval/const_value_step.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" #include "eval/public/activation.h" #include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" #include "runtime/activation.h" #include "runtime/internal/runtime_env.h" #include "runtime/internal/runtime_env_testing.h" #include "runtime/internal/runtime_type_provider.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { using ::absl_testing::StatusIs; using ::cel::BoolValue; using ::cel::Cast; using ::cel::ErrorValue; using ::cel::Expr; using ::cel::InstanceOf; using ::cel::IntValue; using ::cel::RuntimeOptions; using ::cel::TypeProvider; using ::cel::UnknownValue; using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::runtime_internal::RuntimeEnv; using ::google::protobuf::Arena; using ::testing::ElementsAre; using ::testing::Eq; using ::testing::HasSubstr; using ::testing::Truly; class LogicStepTest : public testing::TestWithParam { public: LogicStepTest() : env_(NewTestingRuntimeEnv()) {} absl::Status EvaluateLogic(CelValue arg0, CelValue arg1, CelValue arg2, CelValue* result, bool enable_unknown) { ExecutionPath path; CEL_ASSIGN_OR_RETURN(auto 
step, CreateIdentStep("name0", /*expr_id=*/-1)); path.push_back(std::move(step)); CEL_ASSIGN_OR_RETURN(step, CreateIdentStep("name1", /*expr_id=*/-1)); path.push_back(std::move(step)); CEL_ASSIGN_OR_RETURN(step, CreateIdentStep("name2", /*expr_id=*/-1)); path.push_back(std::move(step)); CEL_ASSIGN_OR_RETURN(step, CreateTernaryStep(4)); path.push_back(std::move(step)); cel::RuntimeOptions options; if (enable_unknown) { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } CelExpressionFlatImpl impl( env_, FlatExpression(std::move(path), /*comprehension_slot_count=*/0, env_->type_registry.GetComposedTypeProvider(), options)); Activation activation; std::string value("test"); activation.InsertValue("name0", arg0); activation.InsertValue("name1", arg1); activation.InsertValue("name2", arg2); auto status0 = impl.Evaluate(activation, &arena_); if (!status0.ok()) return status0.status(); *result = status0.value(); return absl::OkStatus(); } private: absl_nonnull std::shared_ptr env_; Arena arena_; }; TEST_P(LogicStepTest, TestBoolCond) { CelValue result; absl::Status status = EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(true), CelValue::CreateBool(false), &result, GetParam()); ASSERT_OK(status); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); status = EvaluateLogic(CelValue::CreateBool(false), CelValue::CreateBool(true), CelValue::CreateBool(false), &result, GetParam()); ASSERT_OK(status); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); } TEST_P(LogicStepTest, TestErrorHandling) { CelValue result; CelError error = absl::CancelledError(); CelValue error_value = CelValue::CreateError(&error); ASSERT_OK(EvaluateLogic(error_value, CelValue::CreateBool(true), CelValue::CreateBool(false), &result, GetParam())); ASSERT_TRUE(result.IsError()); ASSERT_OK(EvaluateLogic(CelValue::CreateBool(true), error_value, CelValue::CreateBool(false), &result, GetParam())); ASSERT_TRUE(result.IsError()); 
ASSERT_OK(EvaluateLogic(CelValue::CreateBool(false), error_value, CelValue::CreateBool(false), &result, GetParam())); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); } TEST_F(LogicStepTest, TestUnknownHandling) { CelValue result; UnknownSet unknown_set; CelError cel_error = absl::CancelledError(); CelValue unknown_value = CelValue::CreateUnknownSet(&unknown_set); CelValue error_value = CelValue::CreateError(&cel_error); ASSERT_OK(EvaluateLogic(unknown_value, CelValue::CreateBool(true), CelValue::CreateBool(false), &result, true)); ASSERT_TRUE(result.IsUnknownSet()); ASSERT_OK(EvaluateLogic(CelValue::CreateBool(true), unknown_value, CelValue::CreateBool(false), &result, true)); ASSERT_TRUE(result.IsUnknownSet()); ASSERT_OK(EvaluateLogic(CelValue::CreateBool(false), unknown_value, CelValue::CreateBool(false), &result, true)); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.BoolOrDie()); ASSERT_OK(EvaluateLogic(error_value, unknown_value, CelValue::CreateBool(false), &result, true)); ASSERT_TRUE(result.IsError()); ASSERT_OK(EvaluateLogic(unknown_value, error_value, CelValue::CreateBool(false), &result, true)); ASSERT_TRUE(result.IsUnknownSet()); Expr expr0; auto& ident_expr0 = expr0.mutable_ident_expr(); ident_expr0.set_name("name0"); Expr expr1; auto& ident_expr1 = expr1.mutable_ident_expr(); ident_expr1.set_name("name1"); CelAttribute attr0(expr0.ident_expr().name(), {}), attr1(expr1.ident_expr().name(), {}); UnknownAttributeSet unknown_attr_set0({attr0}); UnknownAttributeSet unknown_attr_set1({attr1}); UnknownSet unknown_set0(unknown_attr_set0); UnknownSet unknown_set1(unknown_attr_set1); EXPECT_THAT(unknown_attr_set0.size(), Eq(1)); EXPECT_THAT(unknown_attr_set1.size(), Eq(1)); ASSERT_OK(EvaluateLogic(CelValue::CreateUnknownSet(&unknown_set0), CelValue::CreateUnknownSet(&unknown_set1), CelValue::CreateBool(false), &result, true)); ASSERT_TRUE(result.IsUnknownSet()); const auto& attrs = result.UnknownSetOrDie()->unknown_attributes(); 
ASSERT_THAT(attrs, testing::SizeIs(1)); EXPECT_THAT(attrs.begin()->variable_name(), Eq("name0")); } INSTANTIATE_TEST_SUITE_P(LogicStepTest, LogicStepTest, testing::Bool()); class TernaryStepDirectTest : public testing::TestWithParam { public: TernaryStepDirectTest() : type_provider_(cel::internal::GetTestingDescriptorPool()) {} bool Shortcircuiting() { return GetParam(); } protected: Arena arena_; cel::runtime_internal::RuntimeTypeProvider type_provider_; }; TEST_P(TernaryStepDirectTest, ReturnLhs) { cel::Activation activation; RuntimeOptions opts; ExecutionFrameBase frame(activation, opts, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); std::unique_ptr step = CreateDirectTernaryStep( CreateConstValueDirectStep(BoolValue(true), -1), CreateConstValueDirectStep(IntValue(1), -1), CreateConstValueDirectStep(IntValue(2), -1), -1, Shortcircuiting()); cel::Value result; AttributeTrail attr_unused; ASSERT_OK(step->Evaluate(frame, result, attr_unused)); ASSERT_TRUE(InstanceOf(result)); EXPECT_EQ(Cast(result).NativeValue(), 1); } TEST_P(TernaryStepDirectTest, ReturnRhs) { cel::Activation activation; RuntimeOptions opts; ExecutionFrameBase frame(activation, opts, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); std::unique_ptr step = CreateDirectTernaryStep( CreateConstValueDirectStep(BoolValue(false), -1), CreateConstValueDirectStep(IntValue(1), -1), CreateConstValueDirectStep(IntValue(2), -1), -1, Shortcircuiting()); cel::Value result; AttributeTrail attr_unused; ASSERT_OK(step->Evaluate(frame, result, attr_unused)); ASSERT_TRUE(InstanceOf(result)); EXPECT_EQ(Cast(result).NativeValue(), 2); } TEST_P(TernaryStepDirectTest, ForwardError) { cel::Activation activation; RuntimeOptions opts; ExecutionFrameBase frame(activation, opts, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); cel::Value 
error_value = cel::ErrorValue(absl::InternalError("test error")); std::unique_ptr step = CreateDirectTernaryStep( CreateConstValueDirectStep(error_value, -1), CreateConstValueDirectStep(IntValue(1), -1), CreateConstValueDirectStep(IntValue(2), -1), -1, Shortcircuiting()); cel::Value result; AttributeTrail attr_unused; ASSERT_OK(step->Evaluate(frame, result, attr_unused)); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(Cast(result).NativeValue(), StatusIs(absl::StatusCode::kInternal, "test error")); } TEST_P(TernaryStepDirectTest, ForwardUnknown) { cel::Activation activation; RuntimeOptions opts; opts.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; ExecutionFrameBase frame(activation, opts, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); std::vector attrs{{cel::Attribute("var")}}; cel::UnknownValue unknown_value = cel::UnknownValue(cel::Unknown(cel::AttributeSet(attrs))); std::unique_ptr step = CreateDirectTernaryStep( CreateConstValueDirectStep(unknown_value, -1), CreateConstValueDirectStep(IntValue(2), -1), CreateConstValueDirectStep(IntValue(3), -1), -1, Shortcircuiting()); cel::Value result; AttributeTrail attr_unused; ASSERT_OK(step->Evaluate(frame, result, attr_unused)); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(Cast(result).NativeValue().unknown_attributes(), ElementsAre(Truly([](const cel::Attribute& attr) { return attr.variable_name() == "var"; }))); } TEST_P(TernaryStepDirectTest, UnexpectedCondtionKind) { cel::Activation activation; RuntimeOptions opts; ExecutionFrameBase frame(activation, opts, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); std::unique_ptr step = CreateDirectTernaryStep( CreateConstValueDirectStep(IntValue(-1), -1), CreateConstValueDirectStep(IntValue(1), -1), CreateConstValueDirectStep(IntValue(2), -1), -1, Shortcircuiting()); cel::Value result; AttributeTrail attr_unused; 
ASSERT_OK(step->Evaluate(frame, result, attr_unused)); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(Cast(result).NativeValue(), StatusIs(absl::StatusCode::kUnknown, HasSubstr("No matching overloads found"))); } TEST_P(TernaryStepDirectTest, Shortcircuiting) { class RecordCallStep : public DirectExpressionStep { public: explicit RecordCallStep(bool& was_called) : DirectExpressionStep(-1), was_called_(&was_called) {} absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, AttributeTrail& trail) const override { *was_called_ = true; result = IntValue(1); return absl::OkStatus(); } private: bool* absl_nonnull was_called_; }; bool lhs_was_called = false; bool rhs_was_called = false; cel::Activation activation; RuntimeOptions opts; ExecutionFrameBase frame(activation, opts, type_provider_, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); std::unique_ptr step = CreateDirectTernaryStep( CreateConstValueDirectStep(BoolValue(false), -1), std::make_unique(lhs_was_called), std::make_unique(rhs_was_called), -1, Shortcircuiting()); cel::Value result; AttributeTrail attr_unused; ASSERT_OK(step->Evaluate(frame, result, attr_unused)); ASSERT_TRUE(InstanceOf(result)); EXPECT_THAT(Cast(result).NativeValue(), Eq(1)); bool expect_eager_eval = !Shortcircuiting(); EXPECT_EQ(lhs_was_called, expect_eager_eval); EXPECT_TRUE(rhs_was_called); } INSTANTIATE_TEST_SUITE_P(TernaryStepDirectTest, TernaryStepDirectTest, testing::Bool()); } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/eval/trace_step.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_TRACE_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_TRACE_STEP_H_ #include #include #include #include "absl/status/status.h" #include "absl/types/optional.h" #include "common/native_type.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "internal/status_macros.h" namespace google::api::expr::runtime { // A decorator that implements tracing for recursively evaluated CEL // expressions. // // Allows inspection for extensions to extract the wrapped expression. 
class TraceStep : public DirectExpressionStep { public: explicit TraceStep(std::unique_ptr expression) : DirectExpressionStep(-1), expression_(std::move(expression)) {} absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, AttributeTrail& trail) const override { CEL_RETURN_IF_ERROR(expression_->Evaluate(frame, result, trail)); if (!frame.callback()) { return absl::OkStatus(); } return frame.callback()(expression_->expr_id(), result, frame.descriptor_pool(), frame.message_factory(), frame.arena()); } cel::NativeTypeId GetNativeTypeId() const override { return cel::NativeTypeId::For(); } absl::optional> GetDependencies() const override { return {{expression_.get()}}; } absl::optional>> ExtractDependencies() override { std::vector> dependencies; dependencies.push_back(std::move(expression_)); return dependencies; }; private: std::unique_ptr expression_; }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_TRACE_STEP_H_ ================================================ FILE: eval/internal/BUILD ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

# Build targets for eval/internal.
package(default_visibility = ["//visibility:public"])

licenses(["notice"])

# Header-only target (hdrs, no srcs) for interop.h.
cc_library(
    name = "interop",
    hdrs = ["interop.h"],
    deps = ["//common:legacy_value"],
)

# Equality helpers for legacy CelValue.
cc_library(
    name = "cel_value_equal",
    srcs = ["cel_value_equal.cc"],
    hdrs = ["cel_value_equal.h"],
    deps = [
        "//common:kind",
        "//eval/public:cel_number",
        "//eval/public:cel_value",
        "//eval/public:message_wrapper",
        "//eval/public/structs:legacy_type_adapter",
        "//eval/public/structs:legacy_type_info_apis",
        "//internal:number",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:optional",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "cel_value_equal_test",
    srcs = ["cel_value_equal_test.cc"],
    deps = [
        ":cel_value_equal",
        "//eval/public:cel_value",
        "//eval/public:message_wrapper",
        "//eval/public/containers:container_backed_list_impl",
        "//eval/public/containers:container_backed_map_impl",
        "//eval/public/structs:cel_proto_wrapper",
        "//eval/public/structs:trivial_legacy_type_info",
        "//eval/testutil:test_message_cc_proto",
        "//internal:testing",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:variant",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
        "@com_google_protobuf//:any_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "errors",
    srcs = ["errors.cc"],
    hdrs = ["errors.h"],
    deps = [
        "//runtime/internal:errors",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_protobuf//:protobuf",
    ],
)

# ActivationInterface implementation backed by a legacy BaseActivation.
# NOTE(review): adapter_activation_impl.cc uses ABSL_ATTRIBUTE_LIFETIME_BOUND;
# confirm whether an explicit "@com_google_absl//absl/base:core_headers" dep
# is needed for strict layering checks.
cc_library(
    name = "adapter_activation_impl",
    srcs = ["adapter_activation_impl.cc"],
    hdrs = ["adapter_activation_impl.h"],
    deps = [
        ":interop",
        "//base:attributes",
        "//common:value",
        "//eval/public:base_activation",
        "//eval/public:cel_value",
        "//internal:status_macros",
        "//runtime:activation_interface",
        "//runtime:function_overload_reference",
        "//runtime/internal:activation_attribute_matcher_access",
        "//runtime/internal:attribute_matcher",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_protobuf//:protobuf",
    ],
)

================================================
FILE: eval/internal/adapter_activation_impl.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "eval/internal/adapter_activation_impl.h"

#include <vector>

#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/value.h"
#include "eval/internal/interop.h"
#include "eval/public/cel_value.h"
#include "internal/status_macros.h"
#include "runtime/function_overload_reference.h"
#include "runtime/internal/activation_attribute_matcher_access.h"
#include "runtime/internal/attribute_matcher.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::interop_internal {

using ::google::api::expr::runtime::CelFunction;

// Looks up `name` in the legacy activation and converts a hit to a modern
// Value in `*result`. Returns false when the legacy activation has no binding.
absl::StatusOr<bool> AdapterActivationImpl::FindVariable(
    absl::string_view name,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  // The legacy lookup and the conversion both use `arena`;
  // `descriptor_pool` and `message_factory` are not consulted here.
  absl::optional<google::api::expr::runtime::CelValue> legacy_value =
      legacy_activation_.FindValue(name, arena);
  if (!legacy_value.has_value()) {
    return false;
  }
  CEL_RETURN_IF_ERROR(ModernValue(arena, *legacy_value, *result));
  return true;
}

// Wraps each non-null legacy overload as a FunctionOverloadReference. The
// references point at functions owned through the legacy activation.
std::vector<FunctionOverloadReference>
AdapterActivationImpl::FindFunctionOverloads(absl::string_view name) const
    ABSL_ATTRIBUTE_LIFETIME_BOUND {
  std::vector<const CelFunction*> legacy_candidates =
      legacy_activation_.FindFunctionOverloads(name);
  std::vector<FunctionOverloadReference> result;
  result.reserve(legacy_candidates.size());
  for (const auto* candidate : legacy_candidates) {
    if (candidate == nullptr) {
      continue;
    }
    result.push_back({candidate->descriptor(), *candidate});
  }
  return result;
}

absl::Span<const cel::AttributePattern>
AdapterActivationImpl::GetUnknownAttributes() const {
  return legacy_activation_.unknown_attribute_patterns();
}

absl::Span<const cel::AttributePattern>
AdapterActivationImpl::GetMissingAttributes() const {
  return legacy_activation_.missing_attribute_patterns();
}

const runtime_internal::AttributeMatcher* absl_nullable
AdapterActivationImpl::GetAttributeMatcher() const {
  return runtime_internal::ActivationAttributeMatcherAccess::
      GetAttributeMatcher(legacy_activation_);
}

}  // namespace cel::interop_internal

================================================
FILE: eval/internal/adapter_activation_impl.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_ADAPTER_ACTIVATION_IMPL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_ADAPTER_ACTIVATION_IMPL_H_ #include #include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "base/attribute.h" #include "common/value.h" #include "eval/public/base_activation.h" #include "runtime/activation_interface.h" #include "runtime/function_overload_reference.h" #include "runtime/internal/attribute_matcher.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel::interop_internal { // An Activation implementation that adapts the legacy version (based on // expr::CelValue) to the new cel::Handle based version. This implementation // must be scoped to an evaluation. class AdapterActivationImpl : public ActivationInterface { public: explicit AdapterActivationImpl( const google::api::expr::runtime::BaseActivation& legacy_activation) : legacy_activation_(legacy_activation) {} absl::StatusOr FindVariable( absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const override; std::vector FindFunctionOverloads( absl::string_view name) const override; absl::Span GetUnknownAttributes() const override; absl::Span GetMissingAttributes() const override; private: const runtime_internal::AttributeMatcher* absl_nullable GetAttributeMatcher() const override; const google::api::expr::runtime::BaseActivation& legacy_activation_; }; } // namespace cel::interop_internal #endif // THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_ADAPTER_ACTIVATION_IMPL_H_ ================================================ FILE: eval/internal/cel_value_equal.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache 
License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/internal/cel_value_equal.h" #include #include "absl/time/time.h" #include "absl/types/optional.h" #include "common/kind.h" #include "eval/public/cel_number.h" #include "eval/public/cel_value.h" #include "eval/public/message_wrapper.h" #include "eval/public/structs/legacy_type_adapter.h" #include "eval/public/structs/legacy_type_info_apis.h" #include "internal/number.h" #include "google/protobuf/arena.h" namespace cel::interop_internal { namespace { using ::cel::internal::Number; using ::google::api::expr::runtime::CelList; using ::google::api::expr::runtime::CelMap; using ::google::api::expr::runtime::CelValue; using ::google::api::expr::runtime::GetNumberFromCelValue; using ::google::api::expr::runtime::LegacyTypeAccessApis; using ::google::api::expr::runtime::LegacyTypeInfoApis; // Forward declaration of the functors for generic equality operator. // Equal defined between compatible types. struct HeterogeneousEqualProvider { absl::optional operator()(const CelValue& lhs, const CelValue& rhs) const; }; // Comparison template functions template absl::optional Inequal(Type lhs, Type rhs) { return lhs != rhs; } template absl::optional Equal(Type lhs, Type rhs) { return lhs == rhs; } // Equality for lists. Template parameter provides either heterogeneous or // homogenous equality for comparing members. 
template absl::optional ListEqual(const CelList* t1, const CelList* t2) { if (t1 == t2) { return true; } int index_size = t1->size(); if (t2->size() != index_size) { return false; } google::protobuf::Arena arena; for (int i = 0; i < index_size; i++) { CelValue e1 = (*t1).Get(&arena, i); CelValue e2 = (*t2).Get(&arena, i); absl::optional eq = EqualsProvider()(e1, e2); if (eq.has_value()) { if (!(*eq)) { return false; } } else { // Propagate that the equality is undefined. return eq; } } return true; } // Equality for maps. Template parameter provides either heterogeneous or // homogenous equality for comparing values. template absl::optional MapEqual(const CelMap* t1, const CelMap* t2) { if (t1 == t2) { return true; } if (t1->size() != t2->size()) { return false; } google::protobuf::Arena arena; auto list_keys = t1->ListKeys(&arena); if (!list_keys.ok()) { return absl::nullopt; } const CelList* keys = *list_keys; for (int i = 0; i < keys->size(); i++) { CelValue key = (*keys).Get(&arena, i); CelValue v1 = (*t1).Get(&arena, key).value(); absl::optional v2 = (*t2).Get(&arena, key); if (!v2.has_value()) { auto number = GetNumberFromCelValue(key); if (!number.has_value()) { return false; } if (!key.IsInt64() && number->LosslessConvertibleToInt()) { CelValue int_key = CelValue::CreateInt64(number->AsInt()); absl::optional eq = EqualsProvider()(key, int_key); if (eq.has_value() && *eq) { v2 = (*t2).Get(&arena, int_key); } } if (!key.IsUint64() && !v2.has_value() && number->LosslessConvertibleToUint()) { CelValue uint_key = CelValue::CreateUint64(number->AsUint()); absl::optional eq = EqualsProvider()(key, uint_key); if (eq.has_value() && *eq) { v2 = (*t2).Get(&arena, uint_key); } } } if (!v2.has_value()) { return false; } absl::optional eq = EqualsProvider()(v1, *v2); if (!eq.has_value() || !*eq) { // Shortcircuit on value comparison errors and 'false' results. 
return eq; } } return true; } bool MessageEqual(const CelValue::MessageWrapper& m1, const CelValue::MessageWrapper& m2) { const LegacyTypeInfoApis* lhs_type_info = m1.legacy_type_info(); const LegacyTypeInfoApis* rhs_type_info = m2.legacy_type_info(); if (lhs_type_info->GetTypename(m1) != rhs_type_info->GetTypename(m2)) { return false; } const LegacyTypeAccessApis* accessor = lhs_type_info->GetAccessApis(m1); if (accessor == nullptr) { return false; } return accessor->IsEqualTo(m1, m2); } // Generic equality for CEL values of the same type. // EqualityProvider is used for equality among members of container types. template absl::optional HomogenousCelValueEqual(const CelValue& t1, const CelValue& t2) { if (t1.type() != t2.type()) { return absl::nullopt; } switch (t1.type()) { case Kind::kNullType: return Equal(CelValue::NullType(), CelValue::NullType()); case Kind::kBool: return Equal(t1.BoolOrDie(), t2.BoolOrDie()); case Kind::kInt64: return Equal(t1.Int64OrDie(), t2.Int64OrDie()); case Kind::kUint64: return Equal(t1.Uint64OrDie(), t2.Uint64OrDie()); case Kind::kDouble: return Equal(t1.DoubleOrDie(), t2.DoubleOrDie()); case Kind::kString: return Equal(t1.StringOrDie(), t2.StringOrDie()); case Kind::kBytes: return Equal(t1.BytesOrDie(), t2.BytesOrDie()); case Kind::kDuration: return Equal(t1.DurationOrDie(), t2.DurationOrDie()); case Kind::kTimestamp: return Equal(t1.TimestampOrDie(), t2.TimestampOrDie()); case Kind::kList: return ListEqual(t1.ListOrDie(), t2.ListOrDie()); case Kind::kMap: return MapEqual(t1.MapOrDie(), t2.MapOrDie()); case Kind::kCelType: return Equal(t1.CelTypeOrDie(), t2.CelTypeOrDie()); default: break; } return absl::nullopt; } absl::optional HeterogeneousEqualProvider::operator()( const CelValue& lhs, const CelValue& rhs) const { return CelValueEqualImpl(lhs, rhs); } } // namespace // Equal operator is defined for all types at plan time. 
Runtime delegates to // the correct implementation for types or returns nullopt if the comparison // isn't defined. absl::optional CelValueEqualImpl(const CelValue& v1, const CelValue& v2) { if (v1.type() == v2.type()) { // Message equality is only defined if heterogeneous comparisons are enabled // to preserve the legacy behavior for equality. if (CelValue::MessageWrapper lhs, rhs; v1.GetValue(&lhs) && v2.GetValue(&rhs)) { return MessageEqual(lhs, rhs); } return HomogenousCelValueEqual(v1, v2); } absl::optional lhs = GetNumberFromCelValue(v1); absl::optional rhs = GetNumberFromCelValue(v2); if (rhs.has_value() && lhs.has_value()) { return *lhs == *rhs; } // TODO(uncreated-issue/6): It's currently possible for the interpreter to create a // map containing an Error. Return no matching overload to propagate an error // instead of a false result. if (v1.IsError() || v1.IsUnknownSet() || v2.IsError() || v2.IsUnknownSet()) { return absl::nullopt; } return false; } } // namespace cel::interop_internal ================================================ FILE: eval/internal/cel_value_equal.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_CEL_VALUE_EQUAL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_CEL_VALUE_EQUAL_H_ #include "absl/types/optional.h" #include "eval/public/cel_value.h" namespace cel::interop_internal { // Implementation for general equality between CELValues. Exposed for // consistent behavior in set membership functions. // // Returns nullopt if the comparison is undefined between differently typed // values. absl::optional CelValueEqualImpl( const google::api::expr::runtime::CelValue& v1, const google::api::expr::runtime::CelValue& v2); } // namespace cel::interop_internal #endif // THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_CEL_VALUE_EQUAL_H_ ================================================ FILE: eval/internal/cel_value_equal_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/internal/cel_value_equal.h"

#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "google/protobuf/any.pb.h"
#include "google/rpc/context/attribute_context.pb.h"
#include "google/protobuf/descriptor.pb.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "eval/public/cel_value.h"
#include "eval/public/containers/container_backed_list_impl.h"
#include "eval/public/containers/container_backed_map_impl.h"
#include "eval/public/message_wrapper.h"
#include "eval/public/structs/cel_proto_wrapper.h"
#include "eval/public/structs/trivial_legacy_type_info.h"
#include "eval/testutil/test_message.pb.h"
#include "internal/testing.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

namespace cel::interop_internal {
namespace {

using ::google::api::expr::runtime::CelList;
using ::google::api::expr::runtime::CelMap;
using ::google::api::expr::runtime::CelProtoWrapper;
using ::google::api::expr::runtime::CelValue;
using ::google::api::expr::runtime::ContainerBackedListImpl;
using ::google::api::expr::runtime::CreateContainerBackedMap;
using ::google::api::expr::runtime::MessageWrapper;
using ::google::api::expr::runtime::TestMessage;
using ::google::api::expr::runtime::TrivialTypeInfo;
using ::testing::_;
using ::testing::Combine;
using ::testing::Optional;
using ::testing::Values;
using ::testing::ValuesIn;

struct EqualityTestCase {
  enum class ErrorKind { kMissingOverload, kMissingIdentifier };
  absl::string_view expr;
  absl::variant<bool, ErrorKind> result;
  CelValue lhs = CelValue::CreateNull();
  CelValue rhs = CelValue::CreateNull();
};

// True for the numeric kinds that CelValueEqualImpl compares across types.
bool IsNumeric(CelValue::Type type) {
  return type == CelValue::Type::kDouble || type == CelValue::Type::kInt64 ||
         type == CelValue::Type::kUint64;
}

// The container examples below are intentionally leaked so they can back the
// static example vectors.
const CelList& CelListExample1() {
  static ContainerBackedListImpl* example =
      new ContainerBackedListImpl({CelValue::CreateInt64(1)});
  return *example;
}

const CelList& CelListExample2() {
  static ContainerBackedListImpl* example =
      new ContainerBackedListImpl({CelValue::CreateInt64(2)});
  return *example;
}

const CelMap& CelMapExample1() {
  static CelMap* example = []() {
    std::vector<std::pair<CelValue, CelValue>> values{
        {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}};
    // Implementation copies values into a hash map.
    auto map = CreateContainerBackedMap(absl::MakeSpan(values));
    return map->release();
  }();
  return *example;
}

const CelMap& CelMapExample2() {
  static CelMap* example = []() {
    std::vector<std::pair<CelValue, CelValue>> values{
        {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}};
    auto map = CreateContainerBackedMap(absl::MakeSpan(values));
    return map->release();
  }();
  return *example;
}

// One value per CEL type. Crossed with itself, same-type pairs are expected to
// compare equal; crossed with ValueExamples2(), they are expected to differ.
const std::vector<CelValue>& ValueExamples1() {
  static std::vector<CelValue>* examples = []() {
    google::protobuf::Arena arena;
    auto result = std::make_unique<std::vector<CelValue>>();

    result->push_back(CelValue::CreateNull());
    result->push_back(CelValue::CreateBool(false));
    result->push_back(CelValue::CreateInt64(1));
    result->push_back(CelValue::CreateUint64(1));
    result->push_back(CelValue::CreateDouble(1.0));
    result->push_back(CelValue::CreateStringView("string"));
    result->push_back(CelValue::CreateBytesView("bytes"));
    // No arena allocs expected in this example.
    result->push_back(CelProtoWrapper::CreateMessage(
        std::make_unique<TestMessage>().release(), &arena));
    result->push_back(CelValue::CreateDuration(absl::Seconds(1)));
    result->push_back(CelValue::CreateTimestamp(absl::FromUnixSeconds(1)));
    result->push_back(CelValue::CreateList(&CelListExample1()));
    result->push_back(CelValue::CreateMap(&CelMapExample1()));
    result->push_back(CelValue::CreateCelTypeView("type"));

    return result.release();
  }();
  return *examples;
}

const std::vector<CelValue>& ValueExamples2() {
  static std::vector<CelValue>* examples = []() {
    google::protobuf::Arena arena;
    auto result = std::make_unique<std::vector<CelValue>>();
    auto message2 = std::make_unique<TestMessage>();
    message2->set_int64_value(2);

    result->push_back(CelValue::CreateNull());
    result->push_back(CelValue::CreateBool(true));
    result->push_back(CelValue::CreateInt64(2));
    result->push_back(CelValue::CreateUint64(2));
    result->push_back(CelValue::CreateDouble(2.0));
    result->push_back(CelValue::CreateStringView("string2"));
    result->push_back(CelValue::CreateBytesView("bytes2"));
    // No arena allocs expected in this example.
    result->push_back(
        CelProtoWrapper::CreateMessage(message2.release(), &arena));
    result->push_back(CelValue::CreateDuration(absl::Seconds(2)));
    result->push_back(CelValue::CreateTimestamp(absl::FromUnixSeconds(2)));
    result->push_back(CelValue::CreateList(&CelListExample2()));
    result->push_back(CelValue::CreateMap(&CelMapExample2()));
    result->push_back(CelValue::CreateCelTypeView("type2"));

    return result.release();
  }();
  return *examples;
}

class CelValueEqualImplTypesTest
    : public testing::TestWithParam<std::tuple<CelValue, CelValue, bool>> {
 public:
  CelValueEqualImplTypesTest() = default;

  const CelValue& lhs() { return std::get<0>(GetParam()); }

  const CelValue& rhs() { return std::get<1>(GetParam()); }

  bool should_be_equal() { return std::get<2>(GetParam()); }
};

std::string CelValueEqualTestName(
    const testing::TestParamInfo<std::tuple<CelValue, CelValue, bool>>&
        test_case) {
  return absl::StrCat(CelValue::TypeName(std::get<0>(test_case.param).type()),
                      CelValue::TypeName(std::get<1>(test_case.param).type()),
                      (std::get<2>(test_case.param)) ? "Equal" : "Inequal");
}

TEST_P(CelValueEqualImplTypesTest, Basic) {
  absl::optional<bool> result = CelValueEqualImpl(lhs(), rhs());

  if (lhs().IsNull() || rhs().IsNull()) {
    if (lhs().IsNull() && rhs().IsNull()) {
      EXPECT_THAT(result, Optional(true));
    } else {
      EXPECT_THAT(result, Optional(false));
    }
  } else if (lhs().type() == rhs().type() ||
             (IsNumeric(lhs().type()) && IsNumeric(rhs().type()))) {
    EXPECT_THAT(result, Optional(should_be_equal()));
  } else {
    EXPECT_THAT(result, Optional(false));
  }
}

INSTANTIATE_TEST_SUITE_P(EqualityBetweenTypes, CelValueEqualImplTypesTest,
                         Combine(ValuesIn(ValueExamples1()),
                                 ValuesIn(ValueExamples1()), Values(true)),
                         &CelValueEqualTestName);

INSTANTIATE_TEST_SUITE_P(InequalityBetweenTypes, CelValueEqualImplTypesTest,
                         Combine(ValuesIn(ValueExamples1()),
                                 ValuesIn(ValueExamples2()), Values(false)),
                         &CelValueEqualTestName);

struct NumericInequalityTestCase {
  std::string name;
  CelValue a;
  CelValue b;
};

const std::vector<NumericInequalityTestCase>& NumericValuesNotEqualExample() {
  static std::vector<NumericInequalityTestCase>* examples = []() {
    auto result = std::make_unique<std::vector<NumericInequalityTestCase>>();
    result->push_back({"NegativeIntAndUint", CelValue::CreateInt64(-1),
                       CelValue::CreateUint64(2)});
    result->push_back(
        {"IntAndLargeUint", CelValue::CreateInt64(1),
         CelValue::CreateUint64(
             static_cast<uint64_t>(std::numeric_limits<int64_t>::max()) + 1)});
    result->push_back(
        {"IntAndLargeDouble", CelValue::CreateInt64(2),
         CelValue::CreateDouble(
             static_cast<double>(std::numeric_limits<int64_t>::max()) + 1025)});
    result->push_back(
        {"IntAndSmallDouble", CelValue::CreateInt64(2),
         CelValue::CreateDouble(
             static_cast<double>(std::numeric_limits<int64_t>::lowest()) -
             1025)});
    result->push_back(
        {"UintAndLargeDouble", CelValue::CreateUint64(2),
         CelValue::CreateDouble(
             static_cast<double>(std::numeric_limits<uint64_t>::max()) +
             2049)});
    result->push_back({"NegativeDoubleAndUint", CelValue::CreateDouble(-2.0),
                       CelValue::CreateUint64(123)});

    // NaN tests.
    result->push_back({"NanAndDouble", CelValue::CreateDouble(NAN),
                       CelValue::CreateDouble(1.0)});
    result->push_back({"NanAndNan", CelValue::CreateDouble(NAN),
                       CelValue::CreateDouble(NAN)});
    result->push_back({"DoubleAndNan", CelValue::CreateDouble(1.0),
                       CelValue::CreateDouble(NAN)});
    result->push_back(
        {"IntAndNan", CelValue::CreateInt64(1), CelValue::CreateDouble(NAN)});
    result->push_back(
        {"NanAndInt", CelValue::CreateDouble(NAN), CelValue::CreateInt64(1)});
    result->push_back(
        {"UintAndNan", CelValue::CreateUint64(1), CelValue::CreateDouble(NAN)});
    result->push_back(
        {"NanAndUint", CelValue::CreateDouble(NAN), CelValue::CreateUint64(1)});

    return result.release();
  }();
  return *examples;
}

using NumericInequalityTest = testing::TestWithParam<NumericInequalityTestCase>;

TEST_P(NumericInequalityTest, NumericValues) {
  NumericInequalityTestCase test_case = GetParam();
  absl::optional<bool> result = CelValueEqualImpl(test_case.a, test_case.b);
  // Optional(false) fails cleanly on nullopt; EXPECT_TRUE(has_value()) followed
  // by *result would dereference an empty optional after a non-fatal failure.
  EXPECT_THAT(result, Optional(false));
}

INSTANTIATE_TEST_SUITE_P(
    InequalityBetweenNumericTypesTest, NumericInequalityTest,
    ValuesIn(NumericValuesNotEqualExample()),
    [](const testing::TestParamInfo<NumericInequalityTestCase>& info) {
      return info.param.name;
    });

TEST(CelValueEqualImplTest, LossyNumericEquality) {
  absl::optional<bool> result = CelValueEqualImpl(
      CelValue::CreateDouble(
          static_cast<double>(std::numeric_limits<int64_t>::max()) - 1),
      CelValue::CreateInt64(std::numeric_limits<int64_t>::max()));
  // Avoids dereferencing an empty optional if the comparison is undefined.
  EXPECT_THAT(result, Optional(true));
}

TEST(CelValueEqualImplTest, ListMixedTypesInequal) {
  ContainerBackedListImpl lhs({CelValue::CreateInt64(1)});
  ContainerBackedListImpl rhs({CelValue::CreateStringView("abc")});

  EXPECT_THAT(
      CelValueEqualImpl(CelValue::CreateList(&lhs), CelValue::CreateList(&rhs)),
      Optional(false));
}

TEST(CelValueEqualImplTest, NestedList) {
  ContainerBackedListImpl inner_lhs({CelValue::CreateInt64(1)});
  ContainerBackedListImpl lhs({CelValue::CreateList(&inner_lhs)});
  ContainerBackedListImpl inner_rhs({CelValue::CreateNull()});
  ContainerBackedListImpl rhs({CelValue::CreateList(&inner_rhs)});

  EXPECT_THAT(
      CelValueEqualImpl(CelValue::CreateList(&lhs), CelValue::CreateList(&rhs)),
      Optional(false));
}

TEST(CelValueEqualImplTest, MapMixedValueTypesInequal) {
  std::vector<std::pair<CelValue, CelValue>> lhs_data{
      {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}};
  std::vector<std::pair<CelValue, CelValue>> rhs_data{
      {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}};
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelMap> lhs,
                       CreateContainerBackedMap(absl::MakeSpan(lhs_data)));
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelMap> rhs,
                       CreateContainerBackedMap(absl::MakeSpan(rhs_data)));

  EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()),
                                CelValue::CreateMap(rhs.get())),
              Optional(false));
}

TEST(CelValueEqualImplTest, MapMixedKeyTypesEqual) {
  std::vector<std::pair<CelValue, CelValue>> lhs_data{
      {CelValue::CreateUint64(1), CelValue::CreateStringView("abc")}};
  std::vector<std::pair<CelValue, CelValue>> rhs_data{
      {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}};
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelMap> lhs,
                       CreateContainerBackedMap(absl::MakeSpan(lhs_data)));
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelMap> rhs,
                       CreateContainerBackedMap(absl::MakeSpan(rhs_data)));

  EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()),
                                CelValue::CreateMap(rhs.get())),
              Optional(true));
}

TEST(CelValueEqualImplTest, MapMixedKeyTypesInequal) {
  std::vector<std::pair<CelValue, CelValue>> lhs_data{
      {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}};
  std::vector<std::pair<CelValue, CelValue>> rhs_data{
      {CelValue::CreateInt64(2), CelValue::CreateInt64(2)}};
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelMap> lhs,
                       CreateContainerBackedMap(absl::MakeSpan(lhs_data)));
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelMap> rhs,
                       CreateContainerBackedMap(absl::MakeSpan(rhs_data)));

  EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()),
                                CelValue::CreateMap(rhs.get())),
              Optional(false));
}

TEST(CelValueEqualImplTest, NestedMaps) {
  std::vector<std::pair<CelValue, CelValue>> inner_lhs_data{
      {CelValue::CreateInt64(2), CelValue::CreateStringView("abc")}};
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<CelMap> inner_lhs,
      CreateContainerBackedMap(absl::MakeSpan(inner_lhs_data)));
  std::vector<std::pair<CelValue, CelValue>> lhs_data{
      {CelValue::CreateInt64(1), CelValue::CreateMap(inner_lhs.get())}};

  std::vector<std::pair<CelValue, CelValue>> inner_rhs_data{
      {CelValue::CreateInt64(2), CelValue::CreateNull()}};
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<CelMap> inner_rhs,
      CreateContainerBackedMap(absl::MakeSpan(inner_rhs_data)));
  std::vector<std::pair<CelValue, CelValue>> rhs_data{
      {CelValue::CreateInt64(1), CelValue::CreateMap(inner_rhs.get())}};

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelMap> lhs,
                       CreateContainerBackedMap(absl::MakeSpan(lhs_data)));
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelMap> rhs,
                       CreateContainerBackedMap(absl::MakeSpan(rhs_data)));

  EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()),
                                CelValue::CreateMap(rhs.get())),
              Optional(false));
}

TEST(CelValueEqualImplTest, ProtoEqualityDifferingTypenameInequal) {
  // If message wrappers report a different typename, treat as inequal without
  // calling into the provided equal implementation.
  google::protobuf::Arena arena;
  TestMessage example;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( int32_value: 1 uint32_value: 2 string_value: "test" )",
                                                  &example));

  CelValue lhs = CelProtoWrapper::CreateMessage(&example, &arena);
  CelValue rhs = CelValue::CreateMessageWrapper(
      MessageWrapper(&example, TrivialTypeInfo::GetInstance()));

  EXPECT_THAT(CelValueEqualImpl(lhs, rhs), Optional(false));
}

TEST(CelValueEqualImplTest, ProtoEqualityNoAccessorInequal) {
  // If message wrappers report no access apis, then treat as inequal.
  TestMessage example;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( int32_value: 1 uint32_value: 2 string_value: "test" )",
                                                  &example));

  CelValue lhs = CelValue::CreateMessageWrapper(
      MessageWrapper(&example, TrivialTypeInfo::GetInstance()));
  CelValue rhs = CelValue::CreateMessageWrapper(
      MessageWrapper(&example, TrivialTypeInfo::GetInstance()));

  EXPECT_THAT(CelValueEqualImpl(lhs, rhs), Optional(false));
}

TEST(CelValueEqualImplTest, ProtoEqualityAny) {
  google::protobuf::Arena arena;
  TestMessage packed_value;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( int32_value: 1 uint32_value: 2 string_value: "test" )",
                                                  &packed_value));

  TestMessage lhs;
  lhs.mutable_any_value()->PackFrom(packed_value);

  TestMessage rhs;
  rhs.mutable_any_value()->PackFrom(packed_value);

  EXPECT_THAT(CelValueEqualImpl(CelProtoWrapper::CreateMessage(&lhs, &arena),
                                CelProtoWrapper::CreateMessage(&rhs, &arena)),
              Optional(true));

  // Equality falls back to bytewise comparison if type is missing.
  lhs.mutable_any_value()->clear_type_url();
  rhs.mutable_any_value()->clear_type_url();
  EXPECT_THAT(CelValueEqualImpl(CelProtoWrapper::CreateMessage(&lhs, &arena),
                                CelProtoWrapper::CreateMessage(&rhs, &arena)),
              Optional(true));
}

// Add transitive dependencies in appropriate order for the dynamic descriptor
// pool.
// Return false if the dependencies could not be added to the pool.
bool AddDepsToPool(const google::protobuf::FileDescriptor* descriptor,
                   google::protobuf::DescriptorPool& pool) {
  for (int i = 0; i < descriptor->dependency_count(); i++) {
    if (!AddDepsToPool(descriptor->dependency(i), pool)) {
      return false;
    }
  }
  google::protobuf::FileDescriptorProto descriptor_proto;
  descriptor->CopyTo(&descriptor_proto);
  return pool.BuildFile(descriptor_proto) != nullptr;
}

// Equivalent descriptors managed by separate descriptor pools are not equal, so
// the underlying messages are not considered equal.
TEST(CelValueEqualImplTest, DynamicDescriptorAndGeneratedInequal) {
  // Simulate a dynamically loaded descriptor that happens to match the
  // compiled version.
  google::protobuf::DescriptorPool pool;
  google::protobuf::DynamicMessageFactory factory;
  google::protobuf::Arena arena;
  factory.SetDelegateToGeneratedFactory(false);

  ASSERT_TRUE(AddDepsToPool(TestMessage::descriptor()->file(), pool));

  TestMessage example_message;
  ASSERT_TRUE(
      google::protobuf::TextFormat::ParseFromString(R"pb( int64_value: 12345 bool_list: false bool_list: true message_value { float_value: 1.0 } )pb",
                                          &example_message));

  // Messages from a loaded descriptor and generated versions can't be compared
  // via MessageDifferencer, so return false.
  std::unique_ptr<google::protobuf::Message> example_dynamic_message(
      factory
          .GetPrototype(pool.FindMessageTypeByName(
              TestMessage::descriptor()->full_name()))
          ->New());

  ASSERT_TRUE(example_dynamic_message->ParseFromString(
      example_message.SerializeAsString()));

  EXPECT_THAT(CelValueEqualImpl(
                  CelProtoWrapper::CreateMessage(&example_message, &arena),
                  CelProtoWrapper::CreateMessage(example_dynamic_message.get(),
                                                 &arena)),
              Optional(false));
}

TEST(CelValueEqualImplTest, DynamicMessageAndMessageEqual) {
  google::protobuf::DynamicMessageFactory factory;
  google::protobuf::Arena arena;
  factory.SetDelegateToGeneratedFactory(false);

  TestMessage example_message;
  ASSERT_TRUE(
      google::protobuf::TextFormat::ParseFromString(R"pb( int64_value: 12345 bool_list: false bool_list: true message_value { float_value: 1.0 } )pb",
                                          &example_message));

  // Dynamic message and generated Message subclass with the same generated
  // descriptor are comparable.
  std::unique_ptr<google::protobuf::Message> example_dynamic_message(
      factory.GetPrototype(TestMessage::descriptor())->New());

  ASSERT_TRUE(example_dynamic_message->ParseFromString(
      example_message.SerializeAsString()));

  EXPECT_THAT(CelValueEqualImpl(
                  CelProtoWrapper::CreateMessage(&example_message, &arena),
                  CelProtoWrapper::CreateMessage(example_dynamic_message.get(),
                                                 &arena)),
              Optional(true));
}

}  // namespace
}  // namespace cel::interop_internal

================================================
FILE: eval/internal/errors.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include "eval/internal/errors.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "runtime/internal/errors.h" #include "google/protobuf/arena.h" namespace cel { namespace interop_internal { using ::google::protobuf::Arena; const absl::Status* CreateNoMatchingOverloadError(google::protobuf::Arena* arena, absl::string_view fn) { return Arena::Create( arena, runtime_internal::CreateNoMatchingOverloadError(fn)); } const absl::Status* CreateNoSuchFieldError(google::protobuf::Arena* arena, absl::string_view field) { return Arena::Create( arena, runtime_internal::CreateNoSuchFieldError(field)); } const absl::Status* CreateNoSuchKeyError(google::protobuf::Arena* arena, absl::string_view key) { return Arena::Create( arena, runtime_internal::CreateNoSuchKeyError(key)); } const absl::Status* CreateMissingAttributeError( google::protobuf::Arena* arena, absl::string_view missing_attribute_path) { return Arena::Create( arena, runtime_internal::CreateMissingAttributeError(missing_attribute_path)); } const absl::Status* CreateUnknownFunctionResultError( google::protobuf::Arena* arena, absl::string_view help_message) { return Arena::Create( arena, runtime_internal::CreateUnknownFunctionResultError(help_message)); } const absl::Status* CreateError(google::protobuf::Arena* arena, absl::string_view message, absl::StatusCode code) { return Arena::Create(arena, code, message); } } // namespace interop_internal } // namespace cel ================================================ FILE: eval/internal/errors.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Factories and constants for well-known CEL errors. #ifndef THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_ERRORS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_ERRORS_H_ #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "runtime/internal/errors.h" // IWYU pragma: export #include "google/protobuf/arena.h" namespace cel { namespace interop_internal { // Factories for interop error values. // const pointer Results are arena allocated to support interop with cel::Handle // and expr::runtime::CelValue. const absl::Status* CreateNoMatchingOverloadError(google::protobuf::Arena* arena, absl::string_view fn); const absl::Status* CreateNoSuchFieldError(google::protobuf::Arena* arena, absl::string_view field); const absl::Status* CreateNoSuchKeyError(google::protobuf::Arena* arena, absl::string_view key); const absl::Status* CreateUnknownValueError(google::protobuf::Arena* arena, absl::string_view unknown_path); const absl::Status* CreateMissingAttributeError( google::protobuf::Arena* arena, absl::string_view missing_attribute_path); const absl::Status* CreateUnknownFunctionResultError( google::protobuf::Arena* arena, absl::string_view help_message); const absl::Status* CreateError( google::protobuf::Arena* arena, absl::string_view message, absl::StatusCode code = absl::StatusCode::kUnknown); } // namespace interop_internal } // namespace cel #endif // THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_ERRORS_H_ ================================================ FILE: eval/internal/interop.h ================================================ // Copyright 2022 Google 
LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_INTEROP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_INTEROP_H_ #include "common/legacy_value.h" // IWYU pragma: export #endif // THIRD_PARTY_CEL_CPP_EVAL_INTERNAL_INTEROP_H_ ================================================ FILE: eval/public/BUILD ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

# Targets in this package are public unless they say otherwise.
package(default_visibility = ["//visibility:public"])

package_group(
    name = "ast_visibility",
    packages = [
        "//eval/compiler",
        "//extensions",
    ],
)

licenses(["notice"])

exports_files(["LICENSE"])

cc_library(
    name = "message_wrapper",
    hdrs = ["message_wrapper.h"],
    deps = [
        "//base/internal:message_wrapper",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/numeric:bits",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "message_wrapper_test",
    srcs = ["message_wrapper_test.cc"],
    deps = [
        ":message_wrapper",
        "//eval/public/structs:trivial_legacy_type_info",
        "//eval/testutil:test_message_cc_proto",
        "//internal:casts",
        "//internal:testing",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "cel_value_internal",
    hdrs = ["cel_value_internal.h"],
    deps = [
        ":message_wrapper",
        "//internal:casts",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/numeric:bits",
        "@com_google_absl//absl/types:variant",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "cel_value",
    srcs = ["cel_value.cc"],
    hdrs = ["cel_value.h"],
    deps = [
        ":cel_value_internal",
        ":message_wrapper",
        ":unknown_set",
        "//common:kind",
        "//common:memory",
        "//common:native_type",
        "//eval/internal:errors",
        "//eval/public/structs:legacy_type_info_apis",
        "//extensions/protobuf:memory_manager",
        "//internal:casts",
        "//internal:status_macros",
        "//internal:utf8",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/log:absl_log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:variant",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "cel_attribute",
    srcs = ["cel_attribute.cc"],
    hdrs = ["cel_attribute.h"],
    deps = [
        ":cel_value",
        "//base:attributes",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:variant",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
    ],
)

cc_library(
    name = "cel_value_producer",
    hdrs = ["cel_value_producer.h"],
    deps = [":cel_value"],
)

cc_library(
    name = "unknown_attribute_set",
    hdrs = ["unknown_attribute_set.h"],
    deps = ["//base:attributes"],
)

cc_library(
    name = "activation",
    srcs = ["activation.cc"],
    hdrs = ["activation.h"],
    deps = [
        ":base_activation",
        ":cel_attribute",
        ":cel_function",
        ":cel_value",
        ":cel_value_producer",
        "//runtime/internal:attribute_matcher",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "activation_bind_helper",
    srcs = ["activation_bind_helper.cc"],
    hdrs = ["activation_bind_helper.h"],
    deps = [
        ":activation",
        "//eval/public/containers:field_access",
        "//eval/public/containers:field_backed_list_impl",
        "//eval/public/containers:field_backed_map_impl",
        "@com_google_absl//absl/status",
    ],
)

cc_library(
    name = "cel_function",
    srcs = ["cel_function.cc"],
    hdrs = ["cel_function.h"],
    deps = [
        ":cel_value",
        "//common:function_descriptor",
        "//common:value",
        "//eval/internal:interop",
        "//internal:status_macros",
        "//runtime:function",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:span",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "cel_function_adapter_impl",
    hdrs = ["cel_function_adapter_impl.h"],
    deps = [
        ":cel_function",
        ":cel_function_registry",
        ":cel_value",
        "//internal:status_macros",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "cel_function_adapter",
    hdrs = ["cel_function_adapter.h"],
    deps = [
        ":cel_function_adapter_impl",
        ":cel_value",
        "//eval/public/structs:cel_proto_wrapper",
        "@com_google_absl//absl/status",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "portable_cel_function_adapter",
    hdrs = ["portable_cel_function_adapter.h"],
    deps = [":cel_function_adapter"],
)

cc_library(
    name = "cel_builtins",
    hdrs = ["cel_builtins.h"],
    deps = ["//base:builtins"],
)

cc_library(
    name = "builtin_func_registrar",
    srcs = ["builtin_func_registrar.cc"],
    hdrs = ["builtin_func_registrar.h"],
    deps = [
        ":cel_function_registry",
        ":cel_options",
        "//internal:status_macros",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "//runtime/standard:arithmetic_functions",
        "//runtime/standard:comparison_functions",
        "//runtime/standard:container_functions",
        "//runtime/standard:container_membership_functions",
        "//runtime/standard:equality_functions",
        "//runtime/standard:logical_functions",
        "//runtime/standard:regex_functions",
        "//runtime/standard:string_functions",
        "//runtime/standard:time_functions",
        "//runtime/standard:type_conversion_functions",
        "@com_google_absl//absl/status",
    ],
)

cc_library(
    name = "comparison_functions",
    srcs = ["comparison_functions.cc"],
    hdrs = ["comparison_functions.h"],
    deps = [
        ":cel_function_registry",
        ":cel_options",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "//runtime/standard:comparison_functions",
        "@com_google_absl//absl/status",
    ],
)

cc_test(
    name = "comparison_functions_test",
    size = "small",
    srcs = ["comparison_functions_test.cc"],
    deps = [
        ":activation",
        ":cel_expr_builder_factory",
        ":cel_expression",
        ":cel_function_registry",
        ":cel_options",
        ":cel_value",
        ":comparison_functions",
        "//eval/public/testing:matchers",
        "//internal:status_macros",
        "//internal:testing",
        "//parser",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "equality_function_registrar",
    srcs = ["equality_function_registrar.cc"],
    hdrs = ["equality_function_registrar.h"],
    deps = [
        ":cel_function_registry",
        ":cel_options",
        "//eval/internal:cel_value_equal",
        "//runtime:runtime_options",
        "//runtime/standard:equality_functions",
        "@com_google_absl//absl/status",
    ],
)

cc_test(
    name = "equality_function_registrar_test",
    size = "small",
    srcs = ["equality_function_registrar_test.cc"],
    deps = [
        ":activation",
        ":cel_builtins",
        ":cel_expr_builder_factory",
        ":cel_expression",
        ":cel_function_registry",
        ":cel_options",
        ":cel_value",
        ":equality_function_registrar",
        ":message_wrapper",
        "//eval/public/containers:container_backed_list_impl",
        "//eval/public/containers:container_backed_map_impl",
        "//eval/public/structs:cel_proto_wrapper",
        "//eval/public/structs:trivial_legacy_type_info",
        "//eval/public/testing:matchers",
        "//eval/testutil:test_message_cc_proto",
        "//internal:benchmark",
        "//internal:status_macros",
        "//internal:testing",
        "//parser",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:variant",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
        "@com_google_protobuf//:any_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "container_function_registrar",
    srcs = ["container_function_registrar.cc"],
    hdrs = ["container_function_registrar.h"],
    deps = [
        ":cel_function_registry",
        ":cel_options",
        "//runtime:runtime_options",
        "//runtime/standard:container_functions",
        "@com_google_absl//absl/status",
    ],
)
cc_test( name = "container_function_registrar_test", size = "small", srcs = [ "container_function_registrar_test.cc", ], deps = [ ":activation", ":cel_expr_builder_factory", ":cel_expression", ":cel_value", ":container_function_registrar", ":equality_function_registrar", "//eval/public/containers:container_backed_list_impl", "//eval/public/testing:matchers", "//internal:testing", "//parser", ], ) cc_library( name = "logical_function_registrar", srcs = [ "logical_function_registrar.cc", ], hdrs = [ "logical_function_registrar.h", ], deps = [ ":cel_function_registry", ":cel_options", "//runtime/standard:logical_functions", "@com_google_absl//absl/status", ], ) cc_test( name = "logical_function_registrar_test", size = "small", srcs = [ "logical_function_registrar_test.cc", ], deps = [ ":activation", ":cel_expr_builder_factory", ":cel_expression", ":cel_options", ":cel_value", ":logical_function_registrar", ":portable_cel_function_adapter", "//eval/public/testing:matchers", "//internal:testing", "//parser", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "extension_func_registrar", srcs = [ "extension_func_registrar.cc", ], hdrs = [ "extension_func_registrar.h", ], deps = [ ":cel_function", ":cel_function_adapter", ":cel_function_registry", ":cel_value", "//eval/public/structs:cel_proto_wrapper", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_googleapis//google/type:timeofday_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "cel_expression", hdrs = [ "cel_expression.h", ], deps = [ ":base_activation", ":cel_function_registry", ":cel_type_registry", ":cel_value", "//common:legacy_value", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", 
"@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_library( name = "source_position", srcs = ["source_position.cc"], hdrs = ["source_position.h"], deps = [ "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_library( name = "ast_visitor", hdrs = [ "ast_visitor.h", ], deps = [ ":source_position", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_library( name = "ast_visitor_base", hdrs = [ "ast_visitor_base.h", ], deps = [ ":ast_visitor", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_library( name = "ast_traverse", srcs = [ "ast_traverse.cc", ], hdrs = [ "ast_traverse.h", ], deps = [ ":ast_visitor", ":source_position", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/types:variant", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_library( name = "cel_options", srcs = [ "cel_options.cc", ], hdrs = [ "cel_options.h", ], deps = [ "//runtime:runtime_options", "@com_google_absl//absl/base:core_headers", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "cel_expr_builder_factory", srcs = [ "cel_expr_builder_factory.cc", ], hdrs = [ "cel_expr_builder_factory.h", ], deps = [ ":cel_expression", ":cel_function", ":cel_options", "//common:kind", "//common:memory", "//eval/compiler:cel_expression_builder_flat_impl", "//eval/compiler:comprehension_vulnerability_check", "//eval/compiler:constant_folding", "//eval/compiler:flat_expr_builder", "//eval/compiler:qualified_reference_resolver", "//eval/compiler:regex_precompilation_optimization", "//extensions:select_optimization", "//internal:noop_delete", "//runtime:runtime_options", "//runtime/internal:runtime_env", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "value_export_util", srcs = [ "value_export_util.cc", 
], hdrs = [ "value_export_util.h", ], deps = [ ":cel_value", "//internal:proto_time_encoding", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:json_util", "@com_google_protobuf//:protobuf", "@com_google_protobuf//:struct_cc_proto", "@com_google_protobuf//:time_util", ], ) cc_library( name = "cel_function_registry", srcs = ["cel_function_registry.cc"], hdrs = ["cel_function_registry.h"], deps = [ ":cel_function", ":cel_options", ":cel_value", "//common:function_descriptor", "//common:kind", "//common:value", "//eval/internal:interop", "//internal:status_macros", "//runtime:function", "//runtime:function_overload_reference", "//runtime:function_registry", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "cel_value_test", size = "small", srcs = [ "cel_value_test.cc", ], deps = [ ":cel_value", ":unknown_set", "//common:memory", "//eval/internal:errors", "//eval/public/structs:trivial_legacy_type_info", "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", "//extensions/protobuf:memory_manager", "//internal:testing", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "cel_attribute_test", size = "small", srcs = [ "cel_attribute_test.cc", ], deps = [ ":cel_attribute", ":cel_value", "//eval/public/structs:cel_proto_wrapper", "//internal:testing", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", 
"@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "activation_test", size = "small", srcs = [ "activation_test.cc", ], deps = [ ":activation", ":cel_attribute", ":cel_function", "//eval/eval:attribute_trail", "//eval/eval:ident_step", "//extensions/protobuf:memory_manager", "//internal:status_macros", "//internal:testing", "//parser", ], ) cc_test( name = "ast_traverse_test", srcs = [ "ast_traverse_test.cc", ], deps = [ ":ast_traverse", ":ast_visitor", "//internal:testing", ], ) cc_library( name = "ast_rewrite", srcs = [ "ast_rewrite.cc", ], hdrs = [ "ast_rewrite.h", ], deps = [ ":ast_visitor", ":source_position", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_test( name = "ast_rewrite_test", srcs = [ "ast_rewrite_test.cc", ], deps = [ ":ast_rewrite", ":ast_visitor", ":source_position", "//internal:testing", "//parser", "//testutil:util", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_test( name = "activation_bind_helper_test", size = "small", srcs = [ "activation_bind_helper_test.cc", ], deps = [ ":activation", ":activation_bind_helper", "//eval/testutil:test_message_cc_proto", "//internal:status_macros", "//internal:testing", "//testutil:util", "@com_google_absl//absl/status", ], ) cc_test( name = "cel_function_registry_test", srcs = [ "cel_function_registry_test.cc", ], deps = [ ":activation", ":cel_function", ":cel_function_registry", "//common:kind", "//eval/internal:adapter_activation_impl", "//internal:testing", "//runtime:function_overload_reference", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], ) cc_test( name = "cel_function_adapter_test", size = "small", srcs = [ "cel_function_adapter_test.cc", ], deps = [ ":cel_function_adapter", "//internal:status_macros", "//internal:testing", ], ) 
cc_library( name = "cel_type_registry", srcs = ["cel_type_registry.cc"], hdrs = ["cel_type_registry.h"], deps = [ "//base:data", "//eval/public/structs:legacy_type_adapter", "//eval/public/structs:legacy_type_provider", "//eval/public/structs:protobuf_descriptor_type_provider", "//runtime:type_registry", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "cel_type_registry_test", srcs = ["cel_type_registry_test.cc"], deps = [ ":cel_type_registry", "//base:data", "//common:memory", "//common:type", "//eval/public/structs:legacy_type_adapter", "//eval/public/structs:legacy_type_provider", "//internal:testing", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) cc_test( name = "cel_type_registry_protobuf_reflection_test", srcs = ["cel_type_registry_protobuf_reflection_test.cc"], deps = [ ":cel_type_registry", "//common:memory", "//common:type", "//eval/testutil:test_message_cc_proto", "//internal:testing", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", "@com_google_protobuf//:struct_cc_proto", ], ) cc_test( name = "builtin_func_test", size = "small", srcs = [ "builtin_func_test.cc", ], deps = [ ":activation", ":builtin_func_registrar", ":cel_builtins", ":cel_expr_builder_factory", ":cel_function_registry", ":cel_options", ":cel_value", "//eval/public/structs:cel_proto_wrapper", "//internal:status_macros", "//internal:testing", "//internal:time", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_test( name = "extension_func_test", size = "small", srcs = [ "extension_func_test.cc", ], deps = [ ":builtin_func_registrar", 
":cel_expr_builder_factory", ":cel_expression", ":cel_function_registry", ":cel_value", ":extension_func_registrar", "//eval/public/structs:cel_proto_wrapper", "//internal:status_macros", "//internal:testing", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@com_google_googleapis//google/type:timeofday_cc_proto", "@com_google_protobuf//:protobuf", "@com_google_protobuf//:time_util", ], ) cc_test( name = "source_position_test", size = "small", srcs = [ "source_position_test.cc", ], deps = [ ":source_position", "//internal:testing", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_test( name = "unknown_attribute_set_test", size = "small", srcs = [ "unknown_attribute_set_test.cc", ], deps = [ ":cel_attribute", ":cel_value", ":unknown_attribute_set", "//internal:testing", ], ) cc_test( name = "value_export_util_test", size = "small", srcs = [ "value_export_util_test.cc", ], deps = [ ":value_export_util", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", "//eval/testutil:test_message_cc_proto", "//internal:status_macros", "//internal:testing", "//testutil:util", "@com_google_absl//absl/strings", ], ) cc_library( name = "unknown_function_result_set", srcs = ["unknown_function_result_set.cc"], hdrs = ["unknown_function_result_set.h"], deps = [ "//base:function_result", "//base:function_result_set", ], ) cc_test( name = "unknown_function_result_set_test", size = "small", srcs = [ "unknown_function_result_set_test.cc", ], deps = [ ":cel_function", ":cel_value", ":unknown_function_result_set", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", "//internal:testing", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@com_google_protobuf//:duration_cc_proto", 
"@com_google_protobuf//:empty_cc_proto", "@com_google_protobuf//:protobuf", "@com_google_protobuf//:struct_cc_proto", "@com_google_protobuf//:timestamp_cc_proto", ], ) cc_library( name = "unknown_set", hdrs = ["unknown_set.h"], deps = [ ":unknown_attribute_set", ":unknown_function_result_set", "//base/internal:unknown_set", ], ) cc_test( name = "unknown_set_test", srcs = ["unknown_set_test.cc"], deps = [ ":cel_attribute", ":cel_function", ":unknown_attribute_set", ":unknown_function_result_set", ":unknown_set", "//internal:testing", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "transform_utility", srcs = [ "transform_utility.cc", ], hdrs = [ "transform_utility.h", ], deps = [ ":cel_value", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", "//internal:proto_time_encoding", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:value_cc_proto", "@com_google_protobuf//:any_cc_proto", "@com_google_protobuf//:differencer", "@com_google_protobuf//:protobuf", "@com_google_protobuf//:struct_cc_proto", ], ) cc_library( name = "set_util", srcs = ["set_util.cc"], hdrs = ["set_util.h"], deps = [":cel_value"], ) cc_library( name = "base_activation", hdrs = ["base_activation.h"], deps = [ ":cel_attribute", ":cel_function", ":cel_value", "//runtime/internal:attribute_matcher", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/strings", "@com_google_protobuf//:field_mask_cc_proto", ], ) cc_test( name = "set_util_test", size = "small", srcs = [ "set_util_test.cc", ], deps = [ ":cel_value", ":set_util", ":unknown_set", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", 
"//internal:testing", "@com_google_absl//absl/status", "@com_google_absl//absl/time", "@com_google_protobuf//:empty_cc_proto", "@com_google_protobuf//:protobuf", "@com_google_protobuf//:struct_cc_proto", ], ) cc_test( name = "builtin_func_registrar_test", srcs = ["builtin_func_registrar_test.cc"], deps = [ ":activation", ":builtin_func_registrar", ":cel_expr_builder_factory", ":cel_expression", ":cel_options", ":cel_value", "//eval/public/testing:matchers", "//internal:testing", "//internal:time", "//parser", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "cel_number", srcs = ["cel_number.cc"], hdrs = ["cel_number.h"], deps = [ ":cel_value", "//internal:number", "@com_google_absl//absl/types:optional", ], ) cc_test( name = "cel_number_test", srcs = ["cel_number_test.cc"], deps = [ ":cel_number", ":cel_value", "//internal:testing", "@com_google_absl//absl/types:optional", ], ) cc_library( name = "string_extension_func_registrar", srcs = ["string_extension_func_registrar.cc"], hdrs = ["string_extension_func_registrar.h"], deps = [ ":cel_function_registry", ":cel_options", "//extensions:strings", "@com_google_absl//absl/status", ], ) cc_test( name = "string_extension_func_registrar_test", srcs = ["string_extension_func_registrar_test.cc"], deps = [ ":builtin_func_registrar", ":cel_function_registry", ":cel_value", ":string_extension_func_registrar", "//eval/public/containers:container_backed_list_impl", "//internal:testing", "@com_google_absl//absl/types:span", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_protobuf//:protobuf", ], ) ================================================ FILE: eval/public/LICENSE ================================================ Apache License Version 2.0, January 
2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. 
For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: eval/public/activation.cc ================================================ #include "eval/public/activation.h" #include #include #include #include #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "eval/public/cel_function.h" namespace google { namespace api { namespace expr { namespace runtime { /* Returns an empty optional when `name` is unbound; otherwise the entry's value, which a producer-backed entry computes on first lookup and then caches. */ absl::optional Activation::FindValue(absl::string_view name, google::protobuf::Arena* arena) const { auto entry = value_map_.find(name); // No entry found.
if (entry == value_map_.end()) { return {}; } return entry->second.RetrieveValue(arena); } /* Registers `function` under its descriptor name; fails with InvalidArgument if an overload with a matching shape is already registered under that name. */ absl::Status Activation::InsertFunction(std::unique_ptr function) { auto& overloads = function_map_[function->descriptor().name()]; for (const auto& overload : overloads) { if (overload->descriptor().ShapeMatches(function->descriptor())) { return absl::InvalidArgumentError( "Function with same shape already defined in activation"); } } overloads.emplace_back(std::move(function)); return absl::OkStatus(); } /* Returns non-owning pointers to the overloads registered under `name` (empty if none). */ std::vector Activation::FindFunctionOverloads( absl::string_view name) const { const auto map_entry = function_map_.find(name); std::vector overloads; if (map_entry == function_map_.end()) { return overloads; } overloads.resize(map_entry->second.size()); std::transform(map_entry->second.begin(), map_entry->second.end(), overloads.begin(), [](const auto& func) { return func.get(); }); return overloads; } /* Erases every overload whose shape matches `descriptor` and drops the name once no overloads remain; returns true if anything was erased. */ bool Activation::RemoveFunctionEntries( const CelFunctionDescriptor& descriptor) { auto map_entry = function_map_.find(descriptor.name()); if (map_entry == function_map_.end()) { return false; } std::vector>& overloads = map_entry->second; bool funcs_removed = false; auto func_iter = overloads.begin(); while (func_iter != overloads.end()) { if (descriptor.ShapeMatches(func_iter->get()->descriptor())) { func_iter = overloads.erase(func_iter); funcs_removed = true; } else { ++func_iter; } } if (overloads.empty()) { function_map_.erase(map_entry); } return funcs_removed; } /* try_emplace leaves an existing entry untouched, so binding a name that is already bound is a silent no-op (likewise for InsertValueProducer). */ void Activation::InsertValue(absl::string_view name, const CelValue& value) { value_map_.try_emplace(name, ValueEntry(value)); } void Activation::InsertValueProducer( absl::string_view name, std::unique_ptr value_producer) { value_map_.try_emplace(name, ValueEntry(std::move(value_producer))); } bool Activation::RemoveValueEntry(absl::string_view name) { return value_map_.erase(name); } /* Returns true if `name` is bound and held a value that was reset. NOTE(review): unlike ClearCachedValues, this does not check HasProducer(), so on an entry added via InsertValue it discards the value for good and later FindValue calls return an empty optional; confirm this is intended. */ bool Activation::ClearValueEntry(absl::string_view name) { auto entry = value_map_.find(name); // No entry found.
if (entry == value_map_.end()) { return false; } return entry->second.ClearValue(); } /* Resets cached values of producer-backed entries only (plain values are kept) and returns how many entries actually held a cached value. */ int Activation::ClearCachedValues() { int n = 0; for (auto& entry : value_map_) { if (entry.second.HasProducer()) { if (entry.second.ClearValue()) { n++; } } } return n; } } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: eval/public/activation.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_ACTIVATION_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_ACTIVATION_H_ #include #include #include #include #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "eval/public/base_activation.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_function.h" #include "eval/public/cel_value.h" #include "eval/public/cel_value_producer.h" #include "runtime/internal/attribute_matcher.h" #include "google/protobuf/arena.h" namespace cel::runtime_internal { class ActivationAttributeMatcherAccess; } namespace google::api::expr::runtime { // Instance of Activation class is used by evaluator. // It provides binding between references used in expressions // and actual values. class Activation : public BaseActivation { public: Activation() = default; // Non-copyable/non-assignable Activation(const Activation&) = delete; Activation& operator=(const Activation&) = delete; // Move-constructible/move-assignable Activation(Activation&& other) = default; Activation& operator=(Activation&& other) = default; // BaseActivation std::vector FindFunctionOverloads( absl::string_view name) const override; absl::optional FindValue(absl::string_view name, google::protobuf::Arena* arena) const override; // Insert a function into the activation (i.e. a lazily bound function). Returns // an InvalidArgument error if the name and shape of the function match one // that has already been bound.
absl::Status InsertFunction(std::unique_ptr function); // Insert value into Activation. If `name` is already bound, the existing // entry is kept and this call has no effect. void InsertValue(absl::string_view name, const CelValue& value); // Insert ValueProducer into Activation. If `name` is already bound, the // existing entry is kept and this call has no effect. void InsertValueProducer(absl::string_view name, std::unique_ptr value_producer); // Remove functions that have the same name and shape as descriptor. Returns // true if matching functions were found and removed. bool RemoveFunctionEntries(const CelFunctionDescriptor& descriptor); // Removes value or producer, returns true if entry with the name was found bool RemoveValueEntry(absl::string_view name); // Clears the cached value for `name`; returns true if the entry was found // and held a value that was cleared. An entry added with InsertValue has no // producer to recompute it, so clearing it discards the value. bool ClearValueEntry(absl::string_view name); // Clears all cached values for value producers. Returns the number of entries // cleared. int ClearCachedValues(); // Set missing attribute patterns for evaluation. // // If a field access is found to match any of the provided patterns, the // result is treated as a missing attribute error. void set_missing_attribute_patterns( std::vector missing_attribute_patterns) { missing_attribute_patterns_ = std::move(missing_attribute_patterns); } const std::vector& missing_attribute_patterns() const override { return missing_attribute_patterns_; } // Sets the collection of attribute patterns that will be recognized as // "unknown" values during expression evaluation. void set_unknown_attribute_patterns( std::vector unknown_attribute_patterns) { unknown_attribute_patterns_ = std::move(unknown_attribute_patterns); } // Return the collection of attribute patterns that determine "unknown" // values. const std::vector& unknown_attribute_patterns() const override { return unknown_attribute_patterns_; } private: /* Storage for one binding: either a fixed value, or a producer whose result is cached on first retrieval. */ class ValueEntry { public: explicit ValueEntry(std::unique_ptr prod) : value_(), producer_(std::move(prod)) {} explicit ValueEntry(const CelValue& value) : value_(value), producer_() {} // Retrieve associated CelValue.
// If the value is not set and producer is set, // obtain and cache value from producer. NOTE(review): this assigns the // mutable value_ from a const method with no visible locking, so concurrent // lookups of one entry presumably need external synchronization. absl::optional RetrieveValue(google::protobuf::Arena* arena) const { if (!value_.has_value()) { if (producer_) { value_ = producer_->Produce(arena); } } return value_; } /* Resets the stored value; returns whether one was present. */ bool ClearValue() { bool result = value_.has_value(); value_.reset(); return result; } bool HasProducer() const { return producer_ != nullptr; } private: mutable absl::optional value_; std::unique_ptr producer_; }; friend class cel::runtime_internal::ActivationAttributeMatcherAccess; /* Non-owning overload: the caller keeps `matcher` alive while it is installed; a previously owned matcher is not released here. */ void SetAttributeMatcher( const cel::runtime_internal::AttributeMatcher* matcher) { attribute_matcher_ = matcher; } /* Owning overload: takes ownership and points attribute_matcher_ at the owned matcher. */ void SetAttributeMatcher( std::unique_ptr matcher) { owned_attribute_matcher_ = std::move(matcher); attribute_matcher_ = owned_attribute_matcher_.get(); } const cel::runtime_internal::AttributeMatcher* absl_nullable GetAttributeMatcher() const override { return attribute_matcher_; } absl::flat_hash_map value_map_; absl::flat_hash_map>> function_map_; std::vector missing_attribute_patterns_; std::vector unknown_attribute_patterns_; const cel::runtime_internal::AttributeMatcher* attribute_matcher_ = nullptr; std::unique_ptr owned_attribute_matcher_; }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_ACTIVATION_H_ ================================================ FILE: eval/public/activation_bind_helper.cc ================================================ #include "eval/public/activation_bind_helper.h" #include "absl/status/status.h" #include "eval/public/containers/field_access.h" #include "eval/public/containers/field_backed_list_impl.h" #include "eval/public/containers/field_backed_map_impl.h" namespace google { namespace api { namespace expr { namespace runtime { namespace { using google::protobuf::Arena; using google::protobuf::Message; using google::protobuf::FieldDescriptor; using google::protobuf::Descriptor; absl::Status CreateValueFromField(const google::protobuf::Message* msg,
const FieldDescriptor* field_desc, google::protobuf::Arena* arena, CelValue* result) { /* Map fields and other repeated fields are wrapped in arena-allocated container views (presumably the field-backed map/list impls included above; the template arguments are missing from this extraction); singular fields are converted by CreateValueFromSingleField. */ if (field_desc->is_map()) { *result = CelValue::CreateMap(google::protobuf::Arena::Create( arena, msg, field_desc, arena)); return absl::OkStatus(); } else if (field_desc->is_repeated()) { *result = CelValue::CreateList(google::protobuf::Arena::Create( arena, msg, field_desc, arena)); return absl::OkStatus(); } else { return CreateValueFromSingleField(msg, field_desc, arena, result); } } } // namespace absl::Status BindProtoToActivation(const Message* message, Arena* arena, Activation* activation, ProtoUnsetFieldOptions options) { // Map and list values are allocated on |arena|; without one those // allocations would have no owner and leak, so a null arena is rejected. if (arena == nullptr) { return absl::InvalidArgumentError( "arena must not be null for BindProtoToActivation."); } // TODO(issues/24): Improve the utilities to bind dynamic values as well. const Descriptor* desc = message->GetDescriptor(); const google::protobuf::Reflection* reflection = message->GetReflection(); for (int i = 0; i < desc->field_count(); i++) { CelValue value; const FieldDescriptor* field_desc = desc->field(i); /* kSkip skips unset singular fields only; repeated (including map) fields are always bound, as empty containers when unset. */ if (options == ProtoUnsetFieldOptions::kSkip) { if (!field_desc->is_repeated() && !reflection->HasField(*message, field_desc)) { continue; } } auto status = CreateValueFromField(message, field_desc, arena, &value); /* A conversion error aborts binding; fields bound earlier stay bound. */ if (!status.ok()) { return status; } activation->InsertValue(field_desc->name(), value); } return absl::OkStatus(); } } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: eval/public/activation_bind_helper.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_ACTIVATION_BIND_HELPER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_ACTIVATION_BIND_HELPER_H_ #include "eval/public/activation.h" namespace google { namespace api { namespace expr { namespace runtime { enum class ProtoUnsetFieldOptions { // Do
not bind a field if it is unset. Repeated fields are bound as empty // list. kSkip, // Bind the (cc api) default value for a field. kBindDefault }; // Utility method that takes a protobuf Message and interprets it as a // namespace, binding its fields to Activation. |arena| must be non-null; a // null arena yields an InvalidArgument error. // // Field names and values become respective names and values of parameters // bound to the Activation object. Names already bound in |activation| keep // their existing value, and a field conversion error is returned immediately, // leaving earlier fields bound. // Example: // Assume we have a protobuf message of type: // message Person { // int32 age = 1; // string name = 2; // } // // The sample code snippet will look as follows: // // Person person; // person.set_name("John Doe"); // person.set_age(42); // // CEL_RETURN_IF_ERROR(BindProtoToActivation(&person, &arena, &activation)); // // After this snippet, activation will have two parameters bound: // "name", with string value of "John Doe" // "age", with int value of 42. // // The default behavior for unset fields is to skip them. E.g. if the name field // is not set on the Person message, it will not be bound into the activation. // ProtoUnsetFieldOptions::kBindDefault will bind the cc proto api default for // the field (either an explicit default value or a type specific default). // // TODO(issues/41): Consider updating the default behavior to bind default // values for unset fields.
absl::Status BindProtoToActivation( const google::protobuf::Message* message, google::protobuf::Arena* arena, Activation* activation, ProtoUnsetFieldOptions options = ProtoUnsetFieldOptions::kSkip); } // namespace runtime } // namespace expr } // namespace api } // namespace google #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_ACTIVATION_BIND_HELPER_H_ ================================================ FILE: eval/public/activation_bind_helper_test.cc ================================================ #include "eval/public/activation_bind_helper.h" #include "absl/status/status.h" #include "eval/public/activation.h" #include "eval/testutil/test_message.pb.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "testutil/util.h" namespace google { namespace api { namespace expr { namespace runtime { namespace { using testutil::EqualsProto; /* Each test binds a TestMessage into a fresh Activation and checks what FindValue returns for the bound field names. */ TEST(ActivationBindHelperTest, TestSingleBoolBind) { TestMessage message; message.set_bool_value(true); google::protobuf::Arena arena; Activation activation; ASSERT_OK(BindProtoToActivation(&message, &arena, &activation)); auto result = activation.FindValue("bool_value", &arena); ASSERT_TRUE(result.has_value()); CelValue value = result.value(); ASSERT_TRUE(value.IsBool()); EXPECT_EQ(value.BoolOrDie(), true); } /* int32 fields are exposed as CEL int64 values. */ TEST(ActivationBindHelperTest, TestSingleInt32Bind) { TestMessage message; message.set_int32_value(42); google::protobuf::Arena arena; Activation activation; ASSERT_OK(BindProtoToActivation(&message, &arena, &activation)); auto result = activation.FindValue("int32_value", &arena); ASSERT_TRUE(result.has_value()); CelValue value = result.value(); ASSERT_TRUE(value.IsInt64()); EXPECT_EQ(value.Int64OrDie(), 42); } /* Under the default kSkip option an unset repeated field is still bound, as an empty list. */ TEST(ActivationBindHelperTest, TestUnsetRepeatedIsEmptyList) { TestMessage message; google::protobuf::Arena arena; Activation activation; ASSERT_OK(BindProtoToActivation(&message, &arena, &activation)); auto result = activation.FindValue("int32_list", &arena); ASSERT_TRUE(result.has_value()); CelValue
value = result.value(); ASSERT_TRUE(value.IsList()); EXPECT_TRUE(value.ListOrDie()->empty()); } /* kSkip: the unset singular message_value must not be bound. */ TEST(ActivationBindHelperTest, TestSkipUnsetFields) { TestMessage message; message.set_int32_value(42); google::protobuf::Arena arena; Activation activation; ASSERT_OK(BindProtoToActivation(&message, &arena, &activation, ProtoUnsetFieldOptions::kSkip)); // Explicitly set field is unaffected. auto result = activation.FindValue("int32_value", &arena); ASSERT_TRUE(result.has_value()); CelValue value = result.value(); ASSERT_TRUE(value.IsInt64()); EXPECT_EQ(value.Int64OrDie(), 42); result = activation.FindValue("message_value", &arena); ASSERT_FALSE(result.has_value()); } /* kBindDefault: the unset message_value is bound to a message equal to TestMessage::default_instance(). */ TEST(ActivationBindHelperTest, TestBindDefaultFields) { TestMessage message; message.set_int32_value(42); google::protobuf::Arena arena; Activation activation; ASSERT_OK(BindProtoToActivation(&message, &arena, &activation, ProtoUnsetFieldOptions::kBindDefault)); auto result = activation.FindValue("int32_value", &arena); ASSERT_TRUE(result.has_value()); CelValue value = result.value(); ASSERT_TRUE(value.IsInt64()); EXPECT_EQ(value.Int64OrDie(), 42); result = activation.FindValue("message_value", &arena); ASSERT_TRUE(result.has_value()); EXPECT_NE(nullptr, result->MessageOrDie()); EXPECT_THAT(TestMessage::default_instance(), EqualsProto(*result->MessageOrDie())); } /* A null arena is rejected with the exact InvalidArgument error produced by BindProtoToActivation. */ TEST(ActivationBindHelperTest, RejectsNullArena) { TestMessage message; message.set_bool_value(true); Activation activation; ASSERT_EQ(BindProtoToActivation(&message, /*arena=*/nullptr, &activation), absl::InvalidArgumentError( "arena must not be null for BindProtoToActivation.")); } } // namespace } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: eval/public/activation_test.cc ================================================ #include "eval/public/activation.h" #include #include #include #include "eval/eval/attribute_trail.h" #include "eval/eval/ident_step.h"
#include "eval/public/cel_attribute.h" #include "eval/public/cel_function.h" #include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/parser.h" namespace google { namespace api { namespace expr { namespace runtime { namespace { using ::absl_testing::StatusIs; using ::cel::extensions::ProtoMemoryManager; using ::cel::expr::Expr; using ::google::protobuf::Arena; using ::testing::ElementsAre; using ::testing::Eq; using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::Property; using ::testing::Return; /* gMock producer so tests can count how often Produce() is invoked. */ class MockValueProducer : public CelValueProducer { public: MOCK_METHOD(CelValue, Produce, (Arena*), (override)); }; // Test function that always returns int64 42 and ignores its arguments. The // name-only constructor declares no arguments; the descriptor constructor lets // tests declare argument types for overload-shape checks. class ConstCelFunction : public CelFunction { public: explicit ConstCelFunction(absl::string_view name) : CelFunction({std::string(name), false, {}}) {} explicit ConstCelFunction(const CelFunctionDescriptor& desc) : CelFunction(desc) {} absl::Status Evaluate(absl::Span args, CelValue* output, google::protobuf::Arena* arena) const override { *output = CelValue::CreateInt64(42); return absl::OkStatus(); } }; TEST(ActivationTest, CheckValueInsertFindAndRemove) { Activation activation; Arena arena; activation.InsertValue("value42", CelValue::CreateInt64(42)); // Test getting unbound value EXPECT_FALSE(activation.FindValue("value43", &arena)); // Test getting bound value EXPECT_TRUE(activation.FindValue("value42", &arena)); CelValue value = activation.FindValue("value42", &arena).value(); // Test value is correct.
EXPECT_THAT(value.Int64OrDie(), Eq(42)); // Test removing unbound value EXPECT_FALSE(activation.RemoveValueEntry("value43")); // Test removing bound value EXPECT_TRUE(activation.RemoveValueEntry("value42")); // Now the value is unbound EXPECT_FALSE(activation.FindValue("value42", &arena)); } TEST(ActivationTest, CheckValueProducerInsertFindAndRemove) { const std::string kValue = "42"; auto producer = std::make_unique(); google::protobuf::Arena arena; ON_CALL(*producer, Produce(&arena)) .WillByDefault(Return(CelValue::CreateString(&kValue))); // ValueProducer is expected to be invoked only once. EXPECT_CALL(*producer, Produce(&arena)).Times(1); Activation activation; activation.InsertValueProducer("value42", std::move(producer)); // Test getting unbound value EXPECT_FALSE(activation.FindValue("value43", &arena)); // Test getting bound value on two consecutive passes; the second pass must // be served from the cached value. // ValueProducer is expected to be invoked only once. for (int i = 0; i < 2; i++) { auto opt_value = activation.FindValue("value42", &arena); EXPECT_TRUE(opt_value.has_value()) << " for pass " << i; CelValue value = opt_value.value(); EXPECT_THAT(value.StringOrDie().value(), Eq(kValue)) << " for pass " << i; } // Test removing bound value EXPECT_TRUE(activation.RemoveValueEntry("value42")); // Now the value is unbound EXPECT_FALSE(activation.FindValue("value42", &arena)); } /* A second registration with the same name and shape must be rejected. */ TEST(ActivationTest, CheckInsertFunction) { Activation activation; ASSERT_OK(activation.InsertFunction( std::make_unique("ConstFunc"))); auto overloads = activation.FindFunctionOverloads("ConstFunc"); EXPECT_THAT(overloads, ElementsAre(Property( &CelFunction::descriptor, Property(&CelFunctionDescriptor::name, Eq("ConstFunc"))))); EXPECT_THAT(activation.InsertFunction( std::make_unique("ConstFunc")), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Function with same shape"))); EXPECT_THAT(activation.FindFunctionOverloads("ConstFunc0"), IsEmpty()); } TEST(ActivationTest, CheckRemoveFunction) { Activation
activation; ASSERT_OK(activation.InsertFunction(std::make_unique( CelFunctionDescriptor{"ConstFunc", false, {CelValue::Type::kInt64}}))); EXPECT_OK(activation.InsertFunction(std::make_unique( CelFunctionDescriptor{"ConstFunc", false, {CelValue::Type::kUint64}}))); auto overloads = activation.FindFunctionOverloads("ConstFunc"); EXPECT_THAT( overloads, ElementsAre( Property(&CelFunction::descriptor, Property(&CelFunctionDescriptor::name, Eq("ConstFunc"))), Property(&CelFunction::descriptor, Property(&CelFunctionDescriptor::name, Eq("ConstFunc"))))); /* Per the expectation below, a kAny argument shape-matches both the kInt64 and kUint64 overloads, so both are removed. */ EXPECT_TRUE(activation.RemoveFunctionEntries( {"ConstFunc", false, {CelValue::Type::kAny}})); EXPECT_THAT(activation.FindFunctionOverloads("ConstFunc"), IsEmpty()); } TEST(ActivationTest, CheckValueProducerClear) { const std::string kValue1 = "42"; const std::string kValue2 = "43"; auto producer1 = std::make_unique(); auto producer2 = std::make_unique(); google::protobuf::Arena arena; ON_CALL(*producer1, Produce(&arena)) .WillByDefault(Return(CelValue::CreateString(&kValue1))); ON_CALL(*producer2, Produce(&arena)) .WillByDefault(Return(CelValue::CreateString(&kValue2))); EXPECT_CALL(*producer1, Produce(&arena)).Times(2); EXPECT_CALL(*producer2, Produce(&arena)).Times(1); Activation activation; activation.InsertValueProducer("value42", std::move(producer1)); activation.InsertValueProducer("value43", std::move(producer2)); // Produce first value auto opt_value = activation.FindValue("value42", &arena); EXPECT_TRUE(opt_value.has_value()); EXPECT_THAT(opt_value->StringOrDie().value(), Eq(kValue1)); // Test clearing bound value; value43 has not been produced yet, so clearing // it reports false. EXPECT_TRUE(activation.ClearValueEntry("value42")); EXPECT_FALSE(activation.ClearValueEntry("value43")); // Produce second value auto opt_value2 = activation.FindValue("value43", &arena); EXPECT_TRUE(opt_value2.has_value()); EXPECT_THAT(opt_value2->StringOrDie().value(), Eq(kValue2)); // Clear all cached values; only value43 holds one at this point. EXPECT_EQ(1, activation.ClearCachedValues()); EXPECT_FALSE(activation.ClearValueEntry("value42"));
EXPECT_FALSE(activation.ClearValueEntry("value43")); // Produce first value again (producer1's second expected call) auto opt_value3 = activation.FindValue("value42", &arena); EXPECT_TRUE(opt_value3.has_value()); EXPECT_THAT(opt_value3->StringOrDie().value(), Eq(kValue1)); EXPECT_EQ(1, activation.ClearCachedValues()); } /* Checks that missing-attribute patterns round-trip through the Activation. NOTE(review): the `expr` built at the top of this test is never used by the assertions; it looks like leftover setup. */ TEST(ActivationTest, ErrorPathTest) { Activation activation; Expr expr; auto* select_expr = expr.mutable_select_expr(); select_expr->set_field("ip"); Expr* ident_expr = select_expr->mutable_operand(); ident_expr->mutable_ident_expr()->set_name("destination"); const CelAttributePattern destination_ip_pattern( "destination", {CreateCelAttributeQualifierPattern(CelValue::CreateStringView("ip"))}); AttributeTrail trail("destination"); trail = trail.Step(CreateCelAttributeQualifier(CelValue::CreateStringView("ip"))); ASSERT_EQ(destination_ip_pattern.IsMatch(trail.attribute()), CelAttributePattern::MatchType::FULL); EXPECT_TRUE(activation.missing_attribute_patterns().empty()); activation.set_missing_attribute_patterns({destination_ip_pattern}); EXPECT_EQ( activation.missing_attribute_patterns()[0].IsMatch(trail.attribute()), CelAttributePattern::MatchType::FULL); } } // namespace } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: eval/public/ast_rewrite.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "eval/public/ast_rewrite.h" #include #include #include "cel/expr/syntax.pb.h" #include "absl/log/absl_log.h" #include "absl/types/variant.h" #include "eval/public/ast_visitor.h" #include "eval/public/source_position.h" namespace google::api::expr::runtime { using cel::expr::Expr; using cel::expr::SourceInfo; using Ident = cel::expr::Expr::Ident; using Select = cel::expr::Expr::Select; using Call = cel::expr::Expr::Call; using CreateList = cel::expr::Expr::CreateList; using CreateStruct = cel::expr::Expr::CreateStruct; using Comprehension = cel::expr::Expr::Comprehension; namespace { /* Stack record for a call argument or receiver target, or, when comprehension callbacks are disabled, for a comprehension part. */ struct ArgRecord { // Not null. Expr* expr; // Not null. const SourceInfo* source_info; // For records that are direct arguments to a call, PostVisitArg (or // PostVisitTarget when call_arg is StackRecord::kTarget) is called right // after the argument subtree has been visited. const Expr* calling_expr; int call_arg; }; /* Stack record for a comprehension part when use_comprehension_callbacks is enabled. */ struct ComprehensionRecord { // Not null. Expr* expr; // Not null. const SourceInfo* source_info; const Comprehension* comprehension; const Expr* comprehension_expr; ComprehensionArg comprehension_arg; bool use_comprehension_callbacks; }; /* Stack record for an ordinary expression node. */ struct ExprRecord { // Not null. Expr* expr; // Not null.
const SourceInfo* source_info; }; using StackRecordKind = absl::variant; /* One entry of the explicit traversal stack: the record variant plus whether its pre-visit step has already run. */ struct StackRecord { public: ABSL_ATTRIBUTE_UNUSED static constexpr int kNotCallArg = -1; static constexpr int kTarget = -2; StackRecord(Expr* e, const SourceInfo* info) { ExprRecord record; record.expr = e; record.source_info = info; record_variant = record; } StackRecord(Expr* e, const SourceInfo* info, Comprehension* comprehension, Expr* comprehension_expr, ComprehensionArg comprehension_arg, bool use_comprehension_callbacks) { if (use_comprehension_callbacks) { ComprehensionRecord record; record.expr = e; record.source_info = info; record.comprehension = comprehension; record.comprehension_expr = comprehension_expr; record.comprehension_arg = comprehension_arg; record.use_comprehension_callbacks = use_comprehension_callbacks; record_variant = record; return; } /* Without comprehension callbacks, comprehension parts are reported through the argument callbacks, using comprehension_arg as the argument index. */ ArgRecord record; record.expr = e; record.source_info = info; record.calling_expr = comprehension_expr; record.call_arg = comprehension_arg; record_variant = record; } StackRecord(Expr* e, const SourceInfo* info, const Expr* call, int argnum) { ArgRecord record; record.expr = e; record.source_info = info; record.calling_expr = call; record.call_arg = argnum; record_variant = record; } /* Only meaningful for ExprRecord entries (the absl::get template argument is missing from this extraction, presumably ExprRecord); AstRewrite checks IsExprRecord() before calling these. */ Expr* expr() const { return absl::get(record_variant).expr; } const SourceInfo* source_info() const { return absl::get(record_variant).source_info; } bool IsExprRecord() const { return absl::holds_alternative(record_variant); } StackRecordKind record_variant; bool visited = false; }; /* Pre-visit: PreVisitExpr, then the kind-specific pre-callback (Select, Call and Comprehension only); argument records have no pre-visit callback and comprehension records call PreVisitComprehensionSubexpression. */ struct PreVisitor { void operator()(const ExprRecord& record) { Expr* expr = record.expr; const SourcePosition position(expr->id(), record.source_info); visitor->PreVisitExpr(expr, &position); switch (expr->expr_kind_case()) { case Expr::kSelectExpr: visitor->PreVisitSelect(&expr->select_expr(), expr, &position); break; case Expr::kCallExpr: visitor->PreVisitCall(&expr->call_expr(), expr, &position); break; case Expr::kComprehensionExpr:
visitor->PreVisitComprehension(&expr->comprehension_expr(), expr, &position); break; default: // No pre-visit action. break; } } // Do nothing for Arg variant; argument callbacks fire on post-visit only. void operator()(const ArgRecord&) {} void operator()(const ComprehensionRecord& record) { Expr* expr = record.expr; const SourcePosition position(expr->id(), record.source_info); visitor->PreVisitComprehensionSubexpression( expr, record.comprehension, record.comprehension_arg, &position); } AstVisitor* visitor; }; void PreVisit(const StackRecord& record, AstVisitor* visitor) { absl::visit(PreVisitor{visitor}, record.record_variant); } /* Post-visit: the kind-specific callback, then PostVisitExpr; argument and comprehension records get their own callbacks. */ struct PostVisitor { void operator()(const ExprRecord& record) { Expr* expr = record.expr; const SourcePosition position(expr->id(), record.source_info); switch (expr->expr_kind_case()) { case Expr::kConstExpr: visitor->PostVisitConst(&expr->const_expr(), expr, &position); break; case Expr::kIdentExpr: visitor->PostVisitIdent(&expr->ident_expr(), expr, &position); break; case Expr::kSelectExpr: visitor->PostVisitSelect(&expr->select_expr(), expr, &position); break; case Expr::kCallExpr: visitor->PostVisitCall(&expr->call_expr(), expr, &position); break; case Expr::kListExpr: visitor->PostVisitCreateList(&expr->list_expr(), expr, &position); break; case Expr::kStructExpr: visitor->PostVisitCreateStruct(&expr->struct_expr(), expr, &position); break; case Expr::kComprehensionExpr: visitor->PostVisitComprehension(&expr->comprehension_expr(), expr, &position); break; case Expr::EXPR_KIND_NOT_SET: break; default: ABSL_LOG(ERROR) << "Unsupported Expr kind: " << expr->expr_kind_case(); } visitor->PostVisitExpr(expr, &position); } void operator()(const ArgRecord& record) { Expr* expr = record.expr; const SourcePosition position(expr->id(), record.source_info); if (record.call_arg == StackRecord::kTarget) { visitor->PostVisitTarget(record.calling_expr, &position); } else { visitor->PostVisitArg(record.call_arg, record.calling_expr, &position); } } void operator()(const
ComprehensionRecord& record) { Expr* expr = record.expr; const SourcePosition position(expr->id(), record.source_info); visitor->PostVisitComprehensionSubexpression( expr, record.comprehension, record.comprehension_arg, &position); } AstVisitor* visitor; }; void PostVisit(const StackRecord& record, AstVisitor* visitor) { absl::visit(PostVisitor{visitor}, record.record_variant); } /* The Push*Deps helpers push children in reverse so that the LIFO stack pops them in source order. */ void PushSelectDeps(Select* select_expr, const SourceInfo* source_info, std::stack* stack) { if (select_expr->has_operand()) { stack->push(StackRecord(select_expr->mutable_operand(), source_info)); } } void PushCallDeps(Call* call_expr, Expr* expr, const SourceInfo* source_info, std::stack* stack) { const int arg_size = call_expr->args_size(); // Our contract is that we visit arguments in order. To do that, we need // to push them onto the stack in reverse order. for (int i = arg_size - 1; i >= 0; --i) { stack->push(StackRecord(call_expr->mutable_args(i), source_info, expr, i)); } // Receiver-style call: the target is pushed last, so it is visited before // the arguments. if (call_expr->has_target()) { stack->push(StackRecord(call_expr->mutable_target(), source_info, expr, StackRecord::kTarget)); } } void PushListDeps(CreateList* list_expr, const SourceInfo* source_info, std::stack* stack) { auto& elements = *list_expr->mutable_elements(); for (auto it = elements.rbegin(); it != elements.rend(); ++it) { auto& element = *it; stack->push(StackRecord(&element, source_info)); } } void PushStructDeps(CreateStruct* struct_expr, const SourceInfo* source_info, std::stack* stack) { auto& entries = *struct_expr->mutable_entries(); for (auto it = entries.rbegin(); it != entries.rend(); ++it) { auto& entry = *it; // The contract is to visit key, then value. So put them on the stack // in the opposite order.
if (entry.has_value()) { stack->push(StackRecord(entry.mutable_value(), source_info)); } if (entry.has_map_key()) { stack->push(StackRecord(entry.mutable_map_key(), source_info)); } } } /* Children are visited in the order iter_range, accu_init, loop_condition, loop_step, result. NOTE(review): the proto mutable_ accessors instantiate unset submessages, so traversal may add empty subexpressions to an incomplete comprehension. */ void PushComprehensionDeps(Comprehension* c, Expr* expr, const SourceInfo* source_info, std::stack* stack, bool use_comprehension_callbacks) { StackRecord iter_range(c->mutable_iter_range(), source_info, c, expr, ITER_RANGE, use_comprehension_callbacks); StackRecord accu_init(c->mutable_accu_init(), source_info, c, expr, ACCU_INIT, use_comprehension_callbacks); StackRecord loop_condition(c->mutable_loop_condition(), source_info, c, expr, LOOP_CONDITION, use_comprehension_callbacks); StackRecord loop_step(c->mutable_loop_step(), source_info, c, expr, LOOP_STEP, use_comprehension_callbacks); StackRecord result(c->mutable_result(), source_info, c, expr, RESULT, use_comprehension_callbacks); // Push them in reverse order. stack->push(result); stack->push(loop_step); stack->push(loop_condition); stack->push(accu_init); stack->push(iter_range); } /* Argument and comprehension wrapper records re-push their expression as a plain ExprRecord, so it is visited as an ordinary node nested inside the wrapper's callbacks. */ struct PushDepsVisitor { void operator()(const ExprRecord& record) { Expr* expr = record.expr; switch (expr->expr_kind_case()) { case Expr::kSelectExpr: PushSelectDeps(expr->mutable_select_expr(), record.source_info, &stack); break; case Expr::kCallExpr: PushCallDeps(expr->mutable_call_expr(), expr, record.source_info, &stack); break; case Expr::kListExpr: PushListDeps(expr->mutable_list_expr(), record.source_info, &stack); break; case Expr::kStructExpr: PushStructDeps(expr->mutable_struct_expr(), record.source_info, &stack); break; case Expr::kComprehensionExpr: PushComprehensionDeps(expr->mutable_comprehension_expr(), expr, record.source_info, &stack, options.use_comprehension_callbacks); break; default: break; } } void operator()(const ArgRecord& record) { stack.push(StackRecord(record.expr, record.source_info)); } void operator()(const ComprehensionRecord& record) { stack.push(StackRecord(record.expr, record.source_info)); } std::stack&
stack; const RewriteTraversalOptions& options; }; /* Pushes whatever must be visited beneath `record` under the given traversal options. */ void PushDependencies(const StackRecord& record, std::stack& stack, const RewriteTraversalOptions& options) { absl::visit(PushDepsVisitor{stack, options}, record.record_variant); } } // namespace bool AstRewrite(Expr* expr, const SourceInfo* source_info, AstRewriter* visitor) { return AstRewrite(expr, source_info, visitor, RewriteTraversalOptions{}); } bool AstRewrite(Expr* expr, const SourceInfo* source_info, AstRewriter* visitor, RewriteTraversalOptions options) { /* Iterative depth-first traversal: each record is handled twice, once on first sight (pre-visit, then push its dependencies) and once when it resurfaces on top (post-visit, then pop). traversal_path holds the ExprRecord nodes from the root down to the current node. */ std::stack stack; std::vector traversal_path; stack.push(StackRecord(expr, source_info)); bool rewritten = false; while (!stack.empty()) { StackRecord& record = stack.top(); if (!record.visited) { if (record.IsExprRecord()) { traversal_path.push_back(record.expr()); visitor->TraversalStackUpdate(absl::MakeSpan(traversal_path)); SourcePosition pos(record.expr()->id(), record.source_info()); if (visitor->PreVisitRewrite(record.expr(), &pos)) { rewritten = true; } } PreVisit(record, visitor); /* NOTE(review): `record` refers to stack.top() and is still written after PushDependencies pushes new entries; that is safe only because the default std::stack container, std::deque, keeps references to existing elements valid on push_back. The template arguments are missing from this extraction, so confirm the default container is used. */ PushDependencies(record, stack, options); record.visited = true; } else { PostVisit(record, visitor); if (record.IsExprRecord()) { SourcePosition pos(record.expr()->id(), record.source_info()); if (visitor->PostVisitRewrite(record.expr(), &pos)) { rewritten = true; } traversal_path.pop_back(); visitor->TraversalStackUpdate(absl::MakeSpan(traversal_path)); } stack.pop(); } } return rewritten; } } // namespace google::api::expr::runtime ================================================ FILE: eval/public/ast_rewrite.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_REWRITE_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_REWRITE_H_ #include "cel/expr/syntax.pb.h" #include "absl/types/span.h" #include "eval/public/ast_visitor.h" namespace google::api::expr::runtime { // Traversal options for AstRewrite. struct RewriteTraversalOptions { // If enabled, use the comprehension-specific callbacks // (PreVisitComprehensionSubexpression / PostVisitComprehensionSubexpression) // instead of the general argument callbacks. Disabled by default. bool use_comprehension_callbacks; RewriteTraversalOptions() : use_comprehension_callbacks(false) {} }; // Interface for AST rewriters. // Extends AstVisitor interface with update methods. // See AstRewrite for more details on usage. class AstRewriter : public AstVisitor { public: ~AstRewriter() override {} // Rewrite a sub expression before visiting. // Occurs before visiting Expr. If expr is modified, the new value will be // visited. Return true if expr was modified; AstRewrite reports whether any // rewrite hook returned true. virtual bool PreVisitRewrite(cel::expr::Expr* expr, const SourcePosition* position) = 0; // Rewrite a sub expression after visiting. // Occurs after visiting expr and its children, so if expr is modified the // replacement is not visited (the old sub expression already was). Return // true if expr was modified. virtual bool PostVisitRewrite(cel::expr::Expr* expr, const SourcePosition* position) = 0; // Notify the visitor of updates to the traversal stack. |path| lists the // expressions from the root down to the node currently being visited. virtual void TraversalStackUpdate( absl::Span path) = 0; }; // Trivial implementation for AST rewriters. // Virtual methods are overridden with no-op callbacks.
class AstRewriterBase : public AstRewriter { public: ~AstRewriterBase() override {} void PreVisitExpr(const cel::expr::Expr*, const SourcePosition*) override {} void PostVisitExpr(const cel::expr::Expr*, const SourcePosition*) override {} void PostVisitConst(const cel::expr::Constant*, const cel::expr::Expr*, const SourcePosition*) override {} void PostVisitIdent(const cel::expr::Expr::Ident*, const cel::expr::Expr*, const SourcePosition*) override {} void PostVisitSelect(const cel::expr::Expr::Select*, const cel::expr::Expr*, const SourcePosition*) override {} void PreVisitCall(const cel::expr::Expr::Call*, const cel::expr::Expr*, const SourcePosition*) override {} void PostVisitCall(const cel::expr::Expr::Call*, const cel::expr::Expr*, const SourcePosition*) override {} void PreVisitComprehension(const cel::expr::Expr::Comprehension*, const cel::expr::Expr*, const SourcePosition*) override {} void PostVisitComprehension(const cel::expr::Expr::Comprehension*, const cel::expr::Expr*, const SourcePosition*) override {} void PostVisitArg(int, const cel::expr::Expr*, const SourcePosition*) override {} void PostVisitTarget(const cel::expr::Expr*, const SourcePosition*) override {} void PostVisitCreateList(const cel::expr::Expr::CreateList*, const cel::expr::Expr*, const SourcePosition*) override {} void PostVisitCreateStruct(const cel::expr::Expr::CreateStruct*, const cel::expr::Expr*, const SourcePosition*) override {} bool PreVisitRewrite(cel::expr::Expr* expr, const SourcePosition* position) override { return false; } bool PostVisitRewrite(cel::expr::Expr* expr, const SourcePosition* position) override { return false; } void TraversalStackUpdate( absl::Span path) override {} }; // Traverses the AST representation in an expr proto. Returns true if any // rewrites occur. // // Rewrites may happen before and/or after visiting an expr subtree. If a // change happens during the pre-visit rewrite, the updated subtree will be // visited. 
If a change happens during the post-visit rewrite, the old subtree // will be visited. // // expr: root node of the tree. // source_info: optional additional parse information about the expression // visitor: the callback object that receives the visitation notifications // options: options for traversal. see RewriteTraversalOptions. Defaults are // used if not sepecified. // // Traversal order follows the pattern: // PreVisitRewrite // PreVisitExpr // ..PreVisit{ExprKind} // ....PreVisit{ArgumentIndex} // .......PreVisitExpr (subtree) // .......PostVisitExpr (subtree) // ....PostVisit{ArgumentIndex} // ..PostVisit{ExprKind} // PostVisitExpr // PostVisitRewrite // // Example callback order for fn(1, var): // PreVisitExpr // ..PreVisitCall(fn) // ......PreVisitExpr // ........PostVisitConst(1) // ......PostVisitExpr // ....PostVisitArg(fn, 0) // ......PreVisitExpr // ........PostVisitIdent(var) // ......PostVisitExpr // ....PostVisitArg(fn, 1) // ..PostVisitCall(fn) // PostVisitExpr bool AstRewrite(cel::expr::Expr* expr, const cel::expr::SourceInfo* source_info, AstRewriter* visitor); bool AstRewrite(cel::expr::Expr* expr, const cel::expr::SourceInfo* source_info, AstRewriter* visitor, RewriteTraversalOptions options); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_REWRITE_H_ ================================================ FILE: eval/public/ast_rewrite_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "eval/public/ast_rewrite.h" #include #include #include "cel/expr/syntax.pb.h" #include "eval/public/ast_visitor.h" #include "eval/public/source_position.h" #include "internal/testing.h" #include "parser/parser.h" #include "testutil/util.h" namespace google::api::expr::runtime { namespace { using ::cel::expr::Constant; using ::cel::expr::Expr; using ::cel::expr::ParsedExpr; using ::cel::expr::SourceInfo; using ::testing::_; using ::testing::ElementsAre; using ::testing::InSequence; using Ident = cel::expr::Expr::Ident; using Select = cel::expr::Expr::Select; using Call = cel::expr::Expr::Call; using CreateList = cel::expr::Expr::CreateList; using CreateStruct = cel::expr::Expr::CreateStruct; using Comprehension = cel::expr::Expr::Comprehension; class MockAstRewriter : public AstRewriter { public: // Expr handler. MOCK_METHOD(void, PreVisitExpr, (const Expr* expr, const SourcePosition* position), (override)); // Expr handler. MOCK_METHOD(void, PostVisitExpr, (const Expr* expr, const SourcePosition* position), (override)); MOCK_METHOD(void, PostVisitConst, (const Constant* const_expr, const Expr* expr, const SourcePosition* position), (override)); // Ident node handler. 
MOCK_METHOD(void, PostVisitIdent, (const Ident* ident_expr, const Expr* expr, const SourcePosition* position), (override)); // Select node handler group MOCK_METHOD(void, PreVisitSelect, (const Select* select_expr, const Expr* expr, const SourcePosition* position), (override)); MOCK_METHOD(void, PostVisitSelect, (const Select* select_expr, const Expr* expr, const SourcePosition* position), (override)); // Call node handler group MOCK_METHOD(void, PreVisitCall, (const Call* call_expr, const Expr* expr, const SourcePosition* position), (override)); MOCK_METHOD(void, PostVisitCall, (const Call* call_expr, const Expr* expr, const SourcePosition* position), (override)); // Comprehension node handler group MOCK_METHOD(void, PreVisitComprehension, (const Comprehension* comprehension_expr, const Expr* expr, const SourcePosition* position), (override)); MOCK_METHOD(void, PostVisitComprehension, (const Comprehension* comprehension_expr, const Expr* expr, const SourcePosition* position), (override)); // Comprehension node handler group MOCK_METHOD(void, PreVisitComprehensionSubexpression, (const Expr* expr, const Comprehension* comprehension_expr, ComprehensionArg comprehension_arg, const SourcePosition* position), (override)); MOCK_METHOD(void, PostVisitComprehensionSubexpression, (const Expr* expr, const Comprehension* comprehension_expr, ComprehensionArg comprehension_arg, const SourcePosition* position), (override)); // We provide finer granularity for Call and Comprehension node callbacks // to allow special handling for short-circuiting. 
MOCK_METHOD(void, PostVisitTarget, (const Expr* expr, const SourcePosition* position), (override)); MOCK_METHOD(void, PostVisitArg, (int arg_num, const Expr* expr, const SourcePosition* position), (override)); // CreateList node handler group MOCK_METHOD(void, PostVisitCreateList, (const CreateList* list_expr, const Expr* expr, const SourcePosition* position), (override)); // CreateStruct node handler group MOCK_METHOD(void, PostVisitCreateStruct, (const CreateStruct* struct_expr, const Expr* expr, const SourcePosition* position), (override)); MOCK_METHOD(bool, PreVisitRewrite, (Expr * expr, const SourcePosition* position), (override)); MOCK_METHOD(bool, PostVisitRewrite, (Expr * expr, const SourcePosition* position), (override)); MOCK_METHOD(void, TraversalStackUpdate, (absl::Span path), (override)); }; TEST(AstCrawlerTest, CheckCrawlConstant) { SourceInfo source_info; MockAstRewriter handler; Expr expr; auto const_expr = expr.mutable_const_expr(); EXPECT_CALL(handler, PostVisitConst(const_expr, &expr, _)).Times(1); AstRewrite(&expr, &source_info, &handler); } TEST(AstCrawlerTest, CheckCrawlIdent) { SourceInfo source_info; MockAstRewriter handler; Expr expr; auto ident_expr = expr.mutable_ident_expr(); EXPECT_CALL(handler, PostVisitIdent(ident_expr, &expr, _)).Times(1); AstRewrite(&expr, &source_info, &handler); } // Test handling of Select node when operand is not set. 
TEST(AstCrawlerTest, CheckCrawlSelectNotCrashingPostVisitAbsentOperand) {
  SourceInfo source_info;
  MockAstRewriter handler;

  Expr expr;
  auto select_expr = expr.mutable_select_expr();

  // With no operand there are no children, so the select itself is the only
  // node post-visited.
  EXPECT_CALL(handler, PostVisitSelect(select_expr, &expr, _)).Times(1);

  AstRewrite(&expr, &source_info, &handler);
}

// Test handling of Select node: the operand is post-visited before the select.
TEST(AstCrawlerTest, CheckCrawlSelect) {
  SourceInfo source_info;
  MockAstRewriter handler;

  Expr expr;
  auto select_expr = expr.mutable_select_expr();
  auto operand = select_expr->mutable_operand();
  auto ident_expr = operand->mutable_ident_expr();

  testing::InSequence seq;

  // Lowest level entry will be called first
  EXPECT_CALL(handler, PostVisitIdent(ident_expr, operand, _)).Times(1);
  EXPECT_CALL(handler, PostVisitSelect(select_expr, &expr, _)).Times(1);

  AstRewrite(&expr, &source_info, &handler);
}

// Test handling of Call node without receiver
TEST(AstCrawlerTest, CheckCrawlCallNoReceiver) {
  SourceInfo source_info;
  MockAstRewriter handler;

  // Shape: fn(const, ident) -- two arguments and no receiver/target.
  Expr expr;
  auto* call_expr = expr.mutable_call_expr();
  Expr* arg0 = call_expr->add_args();
  auto* const_expr = arg0->mutable_const_expr();
  Expr* arg1 = call_expr->add_args();
  auto* ident_expr = arg1->mutable_ident_expr();

  testing::InSequence seq;

  // The call is pre-visited first; with no target, PostVisitTarget must not
  // fire.
  EXPECT_CALL(handler, PreVisitCall(call_expr, &expr, _)).Times(1);
  EXPECT_CALL(handler, PostVisitTarget(_, _)).Times(0);

  // Arg0
  EXPECT_CALL(handler, PostVisitConst(const_expr, arg0, _)).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(arg0, _)).Times(1);
  EXPECT_CALL(handler, PostVisitArg(0, &expr, _)).Times(1);

  // Arg1
  EXPECT_CALL(handler, PostVisitIdent(ident_expr, arg1, _)).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(arg1, _)).Times(1);
  EXPECT_CALL(handler, PostVisitArg(1, &expr, _)).Times(1);

  // Back to call
  EXPECT_CALL(handler, PostVisitCall(call_expr, &expr, _)).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(&expr, _)).Times(1);

  AstRewrite(&expr, &source_info, &handler);
}
// Test handling of Call node with receiver
TEST(AstCrawlerTest, CheckCrawlCallReceiver) {
  SourceInfo source_info;
  MockAstRewriter handler;

  // Shape: target.fn(const, ident) -- receiver-style call with two arguments.
  Expr expr;
  auto* call_expr = expr.mutable_call_expr();
  Expr* target = call_expr->mutable_target();
  auto* target_ident = target->mutable_ident_expr();
  Expr* arg0 = call_expr->add_args();
  auto* const_expr = arg0->mutable_const_expr();
  Expr* arg1 = call_expr->add_args();
  auto* ident_expr = arg1->mutable_ident_expr();

  testing::InSequence seq;

  // The call is pre-visited first, then the target, then the arguments.
  EXPECT_CALL(handler, PreVisitCall(call_expr, &expr, _)).Times(1);

  // Target
  EXPECT_CALL(handler, PostVisitIdent(target_ident, target, _)).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(target, _)).Times(1);
  EXPECT_CALL(handler, PostVisitTarget(&expr, _)).Times(1);

  // Arg0
  EXPECT_CALL(handler, PostVisitConst(const_expr, arg0, _)).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(arg0, _)).Times(1);
  EXPECT_CALL(handler, PostVisitArg(0, &expr, _)).Times(1);

  // Arg1
  EXPECT_CALL(handler, PostVisitIdent(ident_expr, arg1, _)).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(arg1, _)).Times(1);
  EXPECT_CALL(handler, PostVisitArg(1, &expr, _)).Times(1);

  // Back to call
  EXPECT_CALL(handler, PostVisitCall(call_expr, &expr, _)).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(&expr, _)).Times(1);

  AstRewrite(&expr, &source_info, &handler);
}

// Test handling of Comprehension node with
// use_comprehension_callbacks enabled: each subexpression is bracketed by
// Pre/PostVisitComprehensionSubexpression.
TEST(AstCrawlerTest, CheckCrawlComprehension) {
  SourceInfo source_info;
  MockAstRewriter handler;

  Expr expr;
  auto c = expr.mutable_comprehension_expr();
  auto iter_range = c->mutable_iter_range();
  auto iter_range_expr = iter_range->mutable_const_expr();
  auto accu_init = c->mutable_accu_init();
  auto accu_init_expr = accu_init->mutable_ident_expr();
  auto loop_condition = c->mutable_loop_condition();
  auto loop_condition_expr = loop_condition->mutable_const_expr();
  auto loop_step = c->mutable_loop_step();
  auto loop_step_expr = loop_step->mutable_ident_expr();
  auto result = c->mutable_result();
  auto result_expr = result->mutable_const_expr();

  testing::InSequence seq;

  // The comprehension is pre-visited first, then each subexpression in
  // ITER_RANGE, ACCU_INIT, LOOP_CONDITION, LOOP_STEP, RESULT order.
  EXPECT_CALL(handler, PreVisitComprehension(c, &expr, _)).Times(1);

  EXPECT_CALL(handler,
              PreVisitComprehensionSubexpression(iter_range, c, ITER_RANGE, _))
      .Times(1);
  EXPECT_CALL(handler, PostVisitConst(iter_range_expr, iter_range, _)).Times(1);
  EXPECT_CALL(handler,
              PostVisitComprehensionSubexpression(iter_range, c, ITER_RANGE, _))
      .Times(1);

  // ACCU_INIT
  EXPECT_CALL(handler,
              PreVisitComprehensionSubexpression(accu_init, c, ACCU_INIT, _))
      .Times(1);
  EXPECT_CALL(handler, PostVisitIdent(accu_init_expr, accu_init, _)).Times(1);
  EXPECT_CALL(handler,
              PostVisitComprehensionSubexpression(accu_init, c, ACCU_INIT, _))
      .Times(1);

  // LOOP CONDITION
  EXPECT_CALL(handler, PreVisitComprehensionSubexpression(loop_condition, c,
                                                          LOOP_CONDITION, _))
      .Times(1);
  EXPECT_CALL(handler, PostVisitConst(loop_condition_expr, loop_condition, _))
      .Times(1);
  EXPECT_CALL(handler, PostVisitComprehensionSubexpression(loop_condition, c,
                                                           LOOP_CONDITION, _))
      .Times(1);

  // LOOP STEP
  EXPECT_CALL(handler,
              PreVisitComprehensionSubexpression(loop_step, c, LOOP_STEP, _))
      .Times(1);
  EXPECT_CALL(handler, PostVisitIdent(loop_step_expr, loop_step, _)).Times(1);
  EXPECT_CALL(handler,
              PostVisitComprehensionSubexpression(loop_step, c, LOOP_STEP, _))
      .Times(1);

  // RESULT
  EXPECT_CALL(handler,
              PreVisitComprehensionSubexpression(result, c, RESULT, _))
      .Times(1);
  EXPECT_CALL(handler, PostVisitConst(result_expr, result, _)).Times(1);
  EXPECT_CALL(handler,
              PostVisitComprehensionSubexpression(result, c, RESULT, _))
      .Times(1);

  EXPECT_CALL(handler, PostVisitComprehension(c, &expr, _)).Times(1);

  RewriteTraversalOptions opts;
  opts.use_comprehension_callbacks = true;
  AstRewrite(&expr, &source_info, &handler, opts);
}

// Test handling of Comprehension node with default options
// (use_comprehension_callbacks == false): each subexpression is reported via
// PostVisitArg with its ComprehensionArg value as the argument index.
TEST(AstCrawlerTest, CheckCrawlComprehensionLegacyCallbacks) {
  SourceInfo source_info;
  MockAstRewriter handler;

  Expr expr;
  auto c = expr.mutable_comprehension_expr();
  auto iter_range = c->mutable_iter_range();
  auto iter_range_expr = iter_range->mutable_const_expr();
  auto accu_init = c->mutable_accu_init();
  auto accu_init_expr = accu_init->mutable_ident_expr();
  auto loop_condition = c->mutable_loop_condition();
  auto loop_condition_expr = loop_condition->mutable_const_expr();
  auto loop_step = c->mutable_loop_step();
  auto loop_step_expr = loop_step->mutable_ident_expr();
  auto result = c->mutable_result();
  auto result_expr = result->mutable_const_expr();

  testing::InSequence seq;

  // The comprehension is pre-visited first, then each subexpression.
  EXPECT_CALL(handler, PreVisitComprehension(c, &expr, _)).Times(1);

  EXPECT_CALL(handler, PostVisitConst(iter_range_expr, iter_range, _)).Times(1);
  EXPECT_CALL(handler, PostVisitArg(ITER_RANGE, &expr, _)).Times(1);

  // ACCU_INIT
  EXPECT_CALL(handler, PostVisitIdent(accu_init_expr, accu_init, _)).Times(1);
  EXPECT_CALL(handler, PostVisitArg(ACCU_INIT, &expr, _)).Times(1);

  // LOOP CONDITION
  EXPECT_CALL(handler, PostVisitConst(loop_condition_expr, loop_condition, _))
      .Times(1);
  EXPECT_CALL(handler, PostVisitArg(LOOP_CONDITION, &expr, _)).Times(1);

  // LOOP STEP
  EXPECT_CALL(handler, PostVisitIdent(loop_step_expr, loop_step, _)).Times(1);
  EXPECT_CALL(handler, PostVisitArg(LOOP_STEP, &expr, _)).Times(1);

  // RESULT
  EXPECT_CALL(handler, PostVisitConst(result_expr, result, _)).Times(1);
  EXPECT_CALL(handler, PostVisitArg(RESULT, &expr, _)).Times(1);

  EXPECT_CALL(handler, PostVisitComprehension(c, &expr, _)).Times(1);

  AstRewrite(&expr, &source_info, &handler);
}

// Test handling of CreateList node.
TEST(AstCrawlerTest, CheckCreateList) {
  SourceInfo source_info;
  MockAstRewriter handler;

  Expr expr;
  auto list_expr = expr.mutable_list_expr();
  auto arg0 = list_expr->add_elements();
  auto const_expr = arg0->mutable_const_expr();
  auto arg1 = list_expr->add_elements();
  auto ident_expr = arg1->mutable_ident_expr();

  testing::InSequence seq;

  // Elements are visited in order, before the list itself is post-visited.
  EXPECT_CALL(handler, PostVisitConst(const_expr, arg0, _)).Times(1);
  EXPECT_CALL(handler, PostVisitIdent(ident_expr, arg1, _)).Times(1);
  EXPECT_CALL(handler, PostVisitCreateList(list_expr, &expr, _)).Times(1);

  AstRewrite(&expr, &source_info, &handler);
}

// Test handling of CreateStruct node.
TEST(AstCrawlerTest, CheckCreateStruct) {
  SourceInfo source_info;
  MockAstRewriter handler;

  Expr expr;
  auto struct_expr = expr.mutable_struct_expr();
  auto entry0 = struct_expr->add_entries();

  auto key = entry0->mutable_map_key()->mutable_const_expr();
  auto value = entry0->mutable_value()->mutable_ident_expr();

  testing::InSequence seq;

  // For each entry the key is visited before the value.
  EXPECT_CALL(handler, PostVisitConst(key, &entry0->map_key(), _)).Times(1);
  EXPECT_CALL(handler, PostVisitIdent(value, &entry0->value(), _)).Times(1);
  EXPECT_CALL(handler, PostVisitCreateStruct(struct_expr, &expr, _)).Times(1);

  AstRewrite(&expr, &source_info, &handler);
}

// Test generic Expr handlers.
TEST(AstCrawlerTest, CheckExprHandlers) {
  SourceInfo source_info;
  MockAstRewriter handler;

  Expr expr;
  auto struct_expr = expr.mutable_struct_expr();
  auto entry0 = struct_expr->add_entries();

  entry0->mutable_map_key()->mutable_const_expr();
  entry0->mutable_value()->mutable_ident_expr();

  // Three Expr nodes: the struct, its key, and its value.
  EXPECT_CALL(handler, PreVisitExpr(_, _)).Times(3);
  EXPECT_CALL(handler, PostVisitExpr(_, _)).Times(3);

  AstRewrite(&expr, &source_info, &handler);
}

// Test rewrite callbacks and traversal stack updates.
TEST(AstCrawlerTest, CheckExprRewriteHandlers) {
  SourceInfo source_info;
  MockAstRewriter handler;

  // Builds `top.mid.var`: select(var) -> select(mid) -> ident(top).
  Expr select_expr;
  select_expr.mutable_select_expr()->set_field("var");
  auto* inner_select_expr =
      select_expr.mutable_select_expr()->mutable_operand();
  inner_select_expr->mutable_select_expr()->set_field("mid");
  auto* ident = inner_select_expr->mutable_select_expr()->mutable_operand();
  ident->mutable_ident_expr()->set_name("top");

  {
    // Each node: the path grows, then PreVisitRewrite fires. On the way back
    // up: PostVisitRewrite fires, then the path shrinks.
    InSequence sequence;
    EXPECT_CALL(handler,
                TraversalStackUpdate(testing::ElementsAre(&select_expr)));
    EXPECT_CALL(handler, PreVisitRewrite(&select_expr, _));

    EXPECT_CALL(handler, TraversalStackUpdate(testing::ElementsAre(
                             &select_expr, inner_select_expr)));
    EXPECT_CALL(handler, PreVisitRewrite(inner_select_expr, _));

    EXPECT_CALL(handler, TraversalStackUpdate(testing::ElementsAre(
                             &select_expr, inner_select_expr, ident)));
    EXPECT_CALL(handler, PreVisitRewrite(ident, _));

    EXPECT_CALL(handler, PostVisitRewrite(ident, _));
    EXPECT_CALL(handler, TraversalStackUpdate(testing::ElementsAre(
                             &select_expr, inner_select_expr)));

    EXPECT_CALL(handler, PostVisitRewrite(inner_select_expr, _));
    EXPECT_CALL(handler,
                TraversalStackUpdate(testing::ElementsAre(&select_expr)));

    EXPECT_CALL(handler, PostVisitRewrite(&select_expr, _));
    EXPECT_CALL(handler, TraversalStackUpdate(testing::IsEmpty()));
  }

  // The mock's rewrite hooks return false by default, so nothing changes.
  EXPECT_FALSE(AstRewrite(&select_expr, &source_info, &handler));
}

// Simple rewrite that replaces a select path with a dot-qualified identifier.
class RewriterExample : public AstRewriterBase { public: RewriterExample() {} bool PostVisitRewrite(Expr* expr, const SourcePosition* info) override { if (target_.has_value() && expr->id() == *target_) { expr->mutable_ident_expr()->set_name("com.google.Identifier"); return true; } return false; } void PostVisitIdent(const Ident* ident, const Expr* expr, const SourcePosition* pos) override { if (path_.size() >= 3) { if (ident->name() == "com") { const Expr* p1 = path_.at(path_.size() - 2); const Expr* p2 = path_.at(path_.size() - 3); if (p1->has_select_expr() && p1->select_expr().field() == "google" && p2->has_select_expr() && p2->select_expr().field() == "Identifier") { target_ = p2->id(); } } } } void TraversalStackUpdate(absl::Span path) override { path_ = path; } private: absl::Span path_; absl::optional target_; }; TEST(AstRewrite, SelectRewriteExample) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed, parser::Parse("com.google.Identifier")); RewriterExample example; ASSERT_TRUE( AstRewrite(parsed.mutable_expr(), &parsed.source_info(), &example)); EXPECT_THAT(parsed.expr(), testutil::EqualsProto(R"pb( id: 3 ident_expr { name: "com.google.Identifier" } )pb")); } // Rewrites x -> y -> z to demonstrate traversal when a node is rewritten on // both passes. 
class PreRewriterExample : public AstRewriterBase { public: PreRewriterExample() {} bool PreVisitRewrite(Expr* expr, const SourcePosition* info) override { if (expr->ident_expr().name() == "x") { expr->mutable_ident_expr()->set_name("y"); return true; } return false; } bool PostVisitRewrite(Expr* expr, const SourcePosition* info) override { if (expr->ident_expr().name() == "y") { expr->mutable_ident_expr()->set_name("z"); return true; } return false; } void PostVisitIdent(const Ident* ident, const Expr* expr, const SourcePosition* pos) override { visited_idents_.push_back(ident->name()); } const std::vector& visited_idents() const { return visited_idents_; } private: std::vector visited_idents_; }; TEST(AstRewrite, PreAndPostVisitExpample) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed, parser::Parse("x")); PreRewriterExample visitor; ASSERT_TRUE( AstRewrite(parsed.mutable_expr(), &parsed.source_info(), &visitor)); EXPECT_THAT(parsed.expr(), testutil::EqualsProto(R"pb( id: 1 ident_expr { name: "z" } )pb")); EXPECT_THAT(visitor.visited_idents(), ElementsAre("y")); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/public/ast_traverse.cc ================================================ // Copyright 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/public/ast_traverse.h"

#include <stack>

#include "cel/expr/syntax.pb.h"
#include "absl/log/absl_log.h"
#include "absl/types/variant.h"
#include "eval/public/ast_visitor.h"
#include "eval/public/source_position.h"

namespace google::api::expr::runtime {

using cel::expr::Expr;
using cel::expr::SourceInfo;
using Ident = cel::expr::Expr::Ident;
using Select = cel::expr::Expr::Select;
using Call = cel::expr::Expr::Call;
using CreateList = cel::expr::Expr::CreateList;
using CreateStruct = cel::expr::Expr::CreateStruct;
using Comprehension = cel::expr::Expr::Comprehension;

namespace {

// Stack entry wrapping a child expression whose completion is reported to
// its parent via PostVisitArg or PostVisitTarget.
struct ArgRecord {
  // Not null.
  const Expr* expr;
  // Forwarded unchanged to SourcePosition.
  const SourceInfo* source_info;

  // For records that are direct arguments to a call (or comprehension
  // subexpressions in legacy mode), PostVisitArg / PostVisitTarget is
  // invoked on `calling_expr` immediately after the argument is visited.
  const Expr* calling_expr;
  // Argument index, or StackRecord::kTarget for a call receiver.
  int call_arg;
};

// Stack entry wrapping a comprehension subexpression when
// TraversalOptions::use_comprehension_callbacks is set; reported via
// Pre/PostVisitComprehensionSubexpression.
struct ComprehensionRecord {
  // Not null.
  const Expr* expr;
  // Forwarded unchanged to SourcePosition.
  const SourceInfo* source_info;

  const Comprehension* comprehension;
  const Expr* comprehension_expr;
  ComprehensionArg comprehension_arg;
  bool use_comprehension_callbacks;
};

// Stack entry for an expression node itself: drives Pre/PostVisitExpr and
// the per-kind callbacks, and is where the node's children get pushed.
struct ExprRecord {
  // Not null.
  const Expr* expr;
  // Forwarded unchanged to SourcePosition.
  const SourceInfo* source_info;
};

using StackRecordKind =
    absl::variant<ExprRecord, ArgRecord, ComprehensionRecord>;

// One pending unit of work on the explicit traversal stack. `visited` is
// false until the record has been pre-visited and its dependencies pushed;
// the record is post-visited and popped once those dependencies are done.
struct StackRecord {
 public:
  ABSL_ATTRIBUTE_UNUSED static constexpr int kNotCallArg = -1;
  static constexpr int kTarget = -2;

  StackRecord(const Expr* e, const SourceInfo* info) {
    ExprRecord record;
    record.expr = e;
    record.source_info = info;
    record_variant = record;
  }

  // Comprehension subexpression: becomes a ComprehensionRecord when
  // comprehension callbacks are enabled, otherwise an ArgRecord whose
  // argument index is the ComprehensionArg value.
  StackRecord(const Expr* e, const SourceInfo* info,
              const Comprehension* comprehension,
              const Expr* comprehension_expr,
              ComprehensionArg comprehension_arg,
              bool use_comprehension_callbacks) {
    if (use_comprehension_callbacks) {
      ComprehensionRecord record;
      record.expr = e;
      record.source_info = info;
      record.comprehension = comprehension;
      record.comprehension_expr = comprehension_expr;
      record.comprehension_arg = comprehension_arg;
      record.use_comprehension_callbacks = use_comprehension_callbacks;
      record_variant = record;
      return;
    }
    ArgRecord record;
    record.expr = e;
    record.source_info = info;
    record.calling_expr = comprehension_expr;
    record.call_arg = comprehension_arg;
    record_variant = record;
  }

  // Call argument (argnum >= 0) or call target (argnum == kTarget).
  StackRecord(const Expr* e, const SourceInfo* info, const Expr* call,
              int argnum) {
    ArgRecord record;
    record.expr = e;
    record.source_info = info;
    record.calling_expr = call;
    record.call_arg = argnum;
    record_variant = record;
  }

  StackRecordKind record_variant;
  bool visited = false;
};

// Issues the "pre" callbacks for a record before its dependencies are pushed.
struct PreVisitor {
  void operator()(const ExprRecord& record) {
    const Expr* expr = record.expr;
    const SourcePosition position(expr->id(), record.source_info);
    visitor->PreVisitExpr(expr, &position);
    switch (expr->expr_kind_case()) {
      case Expr::kConstExpr:
        visitor->PreVisitConst(&expr->const_expr(), expr, &position);
        break;
      case Expr::kIdentExpr:
        visitor->PreVisitIdent(&expr->ident_expr(), expr, &position);
        break;
      case Expr::kSelectExpr:
        visitor->PreVisitSelect(&expr->select_expr(), expr, &position);
        break;
      case Expr::kCallExpr:
        visitor->PreVisitCall(&expr->call_expr(), expr, &position);
        break;
      case Expr::kListExpr:
        visitor->PreVisitCreateList(&expr->list_expr(), expr, &position);
        break;
      case Expr::kStructExpr:
        visitor->PreVisitCreateStruct(&expr->struct_expr(), expr, &position);
        break;
      case Expr::kComprehensionExpr:
        visitor->PreVisitComprehension(&expr->comprehension_expr(), expr,
                                       &position);
        break;
      default:
        // No pre-visit action.
        break;
    }
  }

  // Do nothing for Arg variant.
  void operator()(const ArgRecord&) {}

  void operator()(const ComprehensionRecord& record) {
    const Expr* expr = record.expr;
    const SourcePosition position(expr->id(), record.source_info);
    visitor->PreVisitComprehensionSubexpression(
        expr, record.comprehension, record.comprehension_arg, &position);
  }

  AstVisitor* visitor;
};

void PreVisit(const StackRecord& record, AstVisitor* visitor) {
  absl::visit(PreVisitor{visitor}, record.record_variant);
}

// Issues the "post" callbacks for a record after its dependencies are done.
struct PostVisitor {
  void operator()(const ExprRecord& record) {
    const Expr* expr = record.expr;
    const SourcePosition position(expr->id(), record.source_info);
    switch (expr->expr_kind_case()) {
      case Expr::kConstExpr:
        visitor->PostVisitConst(&expr->const_expr(), expr, &position);
        break;
      case Expr::kIdentExpr:
        visitor->PostVisitIdent(&expr->ident_expr(), expr, &position);
        break;
      case Expr::kSelectExpr:
        visitor->PostVisitSelect(&expr->select_expr(), expr, &position);
        break;
      case Expr::kCallExpr:
        visitor->PostVisitCall(&expr->call_expr(), expr, &position);
        break;
      case Expr::kListExpr:
        visitor->PostVisitCreateList(&expr->list_expr(), expr, &position);
        break;
      case Expr::kStructExpr:
        visitor->PostVisitCreateStruct(&expr->struct_expr(), expr, &position);
        break;
      case Expr::kComprehensionExpr:
        visitor->PostVisitComprehension(&expr->comprehension_expr(), expr,
                                        &position);
        break;
      default:
        ABSL_LOG(ERROR) << "Unsupported Expr kind: " << expr->expr_kind_case();
    }

    visitor->PostVisitExpr(expr, &position);
  }

  void operator()(const ArgRecord& record) {
    const Expr* expr = record.expr;
    const SourcePosition position(expr->id(), record.source_info);
    if (record.call_arg == StackRecord::kTarget) {
      visitor->PostVisitTarget(record.calling_expr, &position);
    } else {
      visitor->PostVisitArg(record.call_arg, record.calling_expr, &position);
    }
  }

  void operator()(const ComprehensionRecord& record) {
    const Expr* expr = record.expr;
    const SourcePosition position(expr->id(), record.source_info);
    visitor->PostVisitComprehensionSubexpression(
        expr, record.comprehension, record.comprehension_arg, &position);
  }

  AstVisitor* visitor;
};

void PostVisit(const StackRecord& record, AstVisitor* visitor) {
  absl::visit(PostVisitor{visitor}, record.record_variant);
}

void PushSelectDeps(const Select* select_expr, const SourceInfo* source_info,
                    std::stack<StackRecord>* stack) {
  if (select_expr->has_operand()) {
    stack->push(StackRecord(&select_expr->operand(), source_info));
  }
}

void PushCallDeps(const Call* call_expr, const Expr* expr,
                  const SourceInfo* source_info,
                  std::stack<StackRecord>* stack) {
  const int arg_size = call_expr->args_size();
  // Our contract is that we visit arguments in order. To do that, we need
  // to push them onto the stack in reverse order.
  for (int i = arg_size - 1; i >= 0; --i) {
    stack->push(StackRecord(&call_expr->args(i), source_info, expr, i));
  }
  // Are we receiver-style? The target is pushed last so it is visited first.
  if (call_expr->has_target()) {
    stack->push(StackRecord(&call_expr->target(), source_info, expr,
                            StackRecord::kTarget));
  }
}

void PushListDeps(const CreateList* list_expr, const SourceInfo* source_info,
                  std::stack<StackRecord>* stack) {
  const auto& elements = list_expr->elements();
  // Reverse push so elements are visited in declaration order.
  for (auto it = elements.rbegin(); it != elements.rend(); ++it) {
    const auto& element = *it;
    stack->push(StackRecord(&element, source_info));
  }
}

void PushStructDeps(const CreateStruct* struct_expr,
                    const SourceInfo* source_info,
                    std::stack<StackRecord>* stack) {
  const auto& entries = struct_expr->entries();
  for (auto it = entries.rbegin(); it != entries.rend(); ++it) {
    const auto& entry = *it;
    // The contract is to visit key, then value. So put them on the stack
    // in the opposite order.
    if (entry.has_value()) {
      stack->push(StackRecord(&entry.value(), source_info));
    }

    if (entry.has_map_key()) {
      stack->push(StackRecord(&entry.map_key(), source_info));
    }
  }
}

void PushComprehensionDeps(const Comprehension* c, const Expr* expr,
                           const SourceInfo* source_info,
                           std::stack<StackRecord>* stack,
                           bool use_comprehension_callbacks) {
  StackRecord iter_range(&c->iter_range(), source_info, c, expr, ITER_RANGE,
                         use_comprehension_callbacks);
  StackRecord accu_init(&c->accu_init(), source_info, c, expr, ACCU_INIT,
                        use_comprehension_callbacks);
  StackRecord loop_condition(&c->loop_condition(), source_info, c, expr,
                             LOOP_CONDITION, use_comprehension_callbacks);
  StackRecord loop_step(&c->loop_step(), source_info, c, expr, LOOP_STEP,
                        use_comprehension_callbacks);
  StackRecord result(&c->result(), source_info, c, expr, RESULT,
                     use_comprehension_callbacks);
  // Push them in reverse order.
  stack->push(result);
  stack->push(loop_step);
  stack->push(loop_condition);
  stack->push(accu_init);
  stack->push(iter_range);
}

// Pushes a record's dependencies: the children of an expression node, or the
// wrapped expression of an argument / comprehension-subexpression record.
struct PushDepsVisitor {
  void operator()(const ExprRecord& record) {
    const Expr* expr = record.expr;
    switch (expr->expr_kind_case()) {
      case Expr::kSelectExpr:
        PushSelectDeps(&expr->select_expr(), record.source_info, &stack);
        break;
      case Expr::kCallExpr:
        PushCallDeps(&expr->call_expr(), expr, record.source_info, &stack);
        break;
      case Expr::kListExpr:
        PushListDeps(&expr->list_expr(), record.source_info, &stack);
        break;
      case Expr::kStructExpr:
        PushStructDeps(&expr->struct_expr(), record.source_info, &stack);
        break;
      case Expr::kComprehensionExpr:
        PushComprehensionDeps(&expr->comprehension_expr(), expr,
                              record.source_info, &stack,
                              options.use_comprehension_callbacks);
        break;
      default:
        break;
    }
  }

  void operator()(const ArgRecord& record) {
    stack.push(StackRecord(record.expr, record.source_info));
  }

  void operator()(const ComprehensionRecord& record) {
    stack.push(StackRecord(record.expr, record.source_info));
  }

  std::stack<StackRecord>& stack;
  const TraversalOptions& options;
};

void PushDependencies(const StackRecord& record, std::stack<StackRecord>& stack,
                      const TraversalOptions& options) {
  absl::visit(PushDepsVisitor{stack, options}, record.record_variant);
}

}  // namespace

void AstTraverse(const Expr* expr, const SourceInfo* source_info,
                 AstVisitor* visitor, TraversalOptions options) {
  std::stack<StackRecord> stack;
  stack.push(StackRecord(expr, source_info));

  while (!stack.empty()) {
    StackRecord& record = stack.top();
    if (!record.visited) {
      PreVisit(record, visitor);
      // std::stack defaults to std::deque; push_back on a deque does not
      // invalidate references to existing elements, so `record` stays valid.
      PushDependencies(record, stack, options);
      record.visited = true;
    } else {
      PostVisit(record, visitor);
      stack.pop();
    }
  }
}

}  // namespace google::api::expr::runtime

================================================
FILE: eval/public/ast_traverse.h
================================================
/*
 * Copyright 2018 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_TRAVERSE_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_TRAVERSE_H_

#include "cel/expr/syntax.pb.h"
#include "eval/public/ast_visitor.h"

namespace google::api::expr::runtime {

// Options controlling AstTraverse.
struct TraversalOptions {
  // If true, comprehension subexpressions are reported through
  // Pre/PostVisitComprehensionSubexpression. Otherwise (the default) they are
  // reported through PostVisitArg, using the ComprehensionArg value as the
  // argument index.
  bool use_comprehension_callbacks;

  TraversalOptions() : use_comprehension_callbacks(false) {}
};

// Traverses the AST representation in an expr proto.
//
// expr: root node of the tree.
// source_info: optional additional parse information about the expression // visitor: the callback object that receives the visitation notifications // // Traversal order follows the pattern: // PreVisitExpr // ..PreVisit{ExprKind} // ....PreVisit{ArgumentIndex} // .......PreVisitExpr (subtree) // .......PostVisitExpr (subtree) // ....PostVisit{ArgumentIndex} // ..PostVisit{ExprKind} // PostVisitExpr // // Example callback order for fn(1, var): // PreVisitExpr // ..PreVisitCall(fn) // ......PreVisitExpr // ........PostVisitConst(1) // ......PostVisitExpr // ....PostVisitArg(fn, 0) // ......PreVisitExpr // ........PostVisitIdent(var) // ......PostVisitExpr // ....PostVisitArg(fn, 1) // ..PostVisitCall(fn) // PostVisitExpr void AstTraverse(const cel::expr::Expr* expr, const cel::expr::SourceInfo* source_info, AstVisitor* visitor, TraversalOptions options = TraversalOptions()); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_TRAVERSE_H_ ================================================ FILE: eval/public/ast_traverse_test.cc ================================================ // Copyright 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/public/ast_traverse.h"

#include "eval/public/ast_visitor.h"
#include "internal/testing.h"

namespace google::api::expr::runtime {

namespace {

using cel::expr::Constant;
using cel::expr::Expr;
using cel::expr::SourceInfo;
using testing::_;
using Ident = cel::expr::Expr::Ident;
using Select = cel::expr::Expr::Select;
using Call = cel::expr::Expr::Call;
using CreateList = cel::expr::Expr::CreateList;
using CreateStruct = cel::expr::Expr::CreateStruct;
using Comprehension = cel::expr::Expr::Comprehension;

// gMock implementation of every AstVisitor callback, so tests can assert which
// callbacks AstTraverse fires, with which arguments, and (under
// testing::InSequence) in which order.
class MockAstVisitor : public AstVisitor {
 public:
  // Expr handler (before children).
  MOCK_METHOD(void, PreVisitExpr,
              (const Expr* expr, const SourcePosition* position), (override));

  // Expr handler (after children).
  MOCK_METHOD(void, PostVisitExpr,
              (const Expr* expr, const SourcePosition* position), (override));

  // Constant node handler (pre).
  MOCK_METHOD(void, PreVisitConst,
              (const Constant* const_expr, const Expr* expr,
               const SourcePosition* position),
              (override));

  // Constant node handler (post).
  MOCK_METHOD(void, PostVisitConst,
              (const Constant* const_expr, const Expr* expr,
               const SourcePosition* position),
              (override));

  // Ident node handler (pre).
  MOCK_METHOD(void, PreVisitIdent,
              (const Ident* ident_expr, const Expr* expr,
               const SourcePosition* position),
              (override));

  // Ident node handler (post).
  MOCK_METHOD(void, PostVisitIdent,
              (const Ident* ident_expr, const Expr* expr,
               const SourcePosition* position),
              (override));

  // Select node handler group
  MOCK_METHOD(void, PreVisitSelect,
              (const Select* select_expr, const Expr* expr,
               const SourcePosition* position),
              (override));

  MOCK_METHOD(void, PostVisitSelect,
              (const Select* select_expr, const Expr* expr,
               const SourcePosition* position),
              (override));

  // Call node handler group
  MOCK_METHOD(void, PreVisitCall,
              (const Call* call_expr, const Expr* expr,
               const SourcePosition* position),
              (override));
  MOCK_METHOD(void, PostVisitCall,
              (const Call* call_expr, const Expr* expr,
               const SourcePosition* position),
              (override));

  // Comprehension node handler group
  MOCK_METHOD(void, PreVisitComprehension,
              (const Comprehension* comprehension_expr, const Expr* expr,
               const SourcePosition* position),
              (override));
  MOCK_METHOD(void, PostVisitComprehension,
              (const Comprehension* comprehension_expr, const Expr* expr,
               const SourcePosition* position),
              (override));

  // Comprehension subexpression handler group (per-slot callbacks).
  MOCK_METHOD(void, PreVisitComprehensionSubexpression,
              (const Expr* expr, const Comprehension* comprehension_expr,
               ComprehensionArg comprehension_arg,
               const SourcePosition* position),
              (override));
  MOCK_METHOD(void, PostVisitComprehensionSubexpression,
              (const Expr* expr, const Comprehension* comprehension_expr,
               ComprehensionArg comprehension_arg,
               const SourcePosition* position),
              (override));

  // We provide finer granularity for Call and Comprehension node callbacks
  // to allow special handling for short-circuiting.
  MOCK_METHOD(void, PostVisitTarget,
              (const Expr* expr, const SourcePosition* position), (override));
  MOCK_METHOD(void, PostVisitArg,
              (int arg_num, const Expr* expr, const SourcePosition* position),
              (override));

  // CreateList node handler group
  MOCK_METHOD(void, PreVisitCreateList,
              (const CreateList* list_expr, const Expr* expr,
               const SourcePosition* position),
              (override));

  // CreateList node handler group
  MOCK_METHOD(void, PostVisitCreateList,
              (const CreateList* list_expr, const Expr* expr,
               const SourcePosition* position),
              (override));

  // CreateStruct node handler group
  MOCK_METHOD(void, PreVisitCreateStruct,
              (const CreateStruct* struct_expr, const Expr* expr,
               const SourcePosition* position),
              (override));

  // CreateStruct node handler group
  MOCK_METHOD(void, PostVisitCreateStruct,
              (const CreateStruct* struct_expr, const Expr* expr,
               const SourcePosition* position),
              (override));
};

// A lone constant node receives exactly one Pre/PostVisitConst pair.
TEST(AstCrawlerTest, CheckCrawlConstant) {
  SourceInfo source_info;
  MockAstVisitor handler;

  Expr expr;
  auto const_expr = expr.mutable_const_expr();

  EXPECT_CALL(handler, PreVisitConst(const_expr, &expr, _)).Times(1);
  EXPECT_CALL(handler, PostVisitConst(const_expr, &expr, _)).Times(1);

  AstTraverse(&expr, &source_info, &handler);
}

// A lone identifier node receives exactly one Pre/PostVisitIdent pair.
TEST(AstCrawlerTest, CheckCrawlIdent) {
  SourceInfo source_info;
  MockAstVisitor handler;

  Expr expr;
  auto ident_expr = expr.mutable_ident_expr();

  EXPECT_CALL(handler, PreVisitIdent(ident_expr, &expr, _)).Times(1);
  EXPECT_CALL(handler, PostVisitIdent(ident_expr, &expr, _)).Times(1);

  AstTraverse(&expr, &source_info, &handler);
}

// Test handling of Select node when operand is not set.
TEST(AstCrawlerTest, CheckCrawlSelectNotCrashingPostVisitAbsentOperand) {
  SourceInfo source_info;
  MockAstVisitor handler;

  Expr expr;
  auto select_expr = expr.mutable_select_expr();

  // The operand is never set; the test only requires that traversal finishes
  // and post-visits the Select node.
  EXPECT_CALL(handler, PostVisitSelect(select_expr, &expr, _)).Times(1);

  AstTraverse(&expr, &source_info, &handler);
}

// Test handling of Select node
TEST(AstCrawlerTest, CheckCrawlSelect) {
  SourceInfo source_info;
  MockAstVisitor handler;

  Expr expr;
  auto select_expr = expr.mutable_select_expr();
  auto operand = select_expr->mutable_operand();
  auto ident_expr = operand->mutable_ident_expr();

  testing::InSequence seq;

  // The operand subtree is post-visited before the Select node itself.
  EXPECT_CALL(handler, PostVisitIdent(ident_expr, operand, _)).Times(1);
  EXPECT_CALL(handler, PostVisitSelect(select_expr, &expr, _)).Times(1);

  AstTraverse(&expr, &source_info, &handler);
}

// Test handling of Call node without receiver
TEST(AstCrawlerTest, CheckCrawlCallNoReceiver) {
  SourceInfo source_info;
  MockAstVisitor handler;

  // Shape: fn(const, ident) with no receiver.
  Expr expr;
  auto* call_expr = expr.mutable_call_expr();
  Expr* arg0 = call_expr->add_args();
  auto* const_expr = arg0->mutable_const_expr();
  Expr* arg1 = call_expr->add_args();
  auto* ident_expr = arg1->mutable_ident_expr();

  testing::InSequence seq;

  // Expected callbacks, in order. Without a target, PostVisitTarget must not
  // fire.
  EXPECT_CALL(handler, PreVisitCall(call_expr, &expr, _)).Times(1);
  EXPECT_CALL(handler, PostVisitTarget(_, _)).Times(0);

  // Arg0
  EXPECT_CALL(handler, PostVisitConst(const_expr, arg0, _)).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(arg0, _)).Times(1);
  EXPECT_CALL(handler, PostVisitArg(0, &expr, _)).Times(1);

  // Arg1
  EXPECT_CALL(handler, PostVisitIdent(ident_expr, arg1, _)).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(arg1, _)).Times(1);
  EXPECT_CALL(handler, PostVisitArg(1, &expr, _)).Times(1);

  // Back to call
  EXPECT_CALL(handler, PostVisitCall(call_expr, &expr, _)).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(&expr, _)).Times(1);

  AstTraverse(&expr, &source_info, &handler);
}

// Test handling of Call node with receiver
TEST(AstCrawlerTest, CheckCrawlCallReceiver) {
  SourceInfo source_info;
  MockAstVisitor handler;

  // Shape: ident.fn(const, ident); the receiver is the call target.
  Expr expr;
  auto* call_expr = expr.mutable_call_expr();
  Expr* target = call_expr->mutable_target();
  auto* target_ident = target->mutable_ident_expr();
  Expr* arg0 = call_expr->add_args();
  auto* const_expr = arg0->mutable_const_expr();
  Expr* arg1 = call_expr->add_args();
  auto* ident_expr = arg1->mutable_ident_expr();

  testing::InSequence seq;

  // Expected callbacks, in order: the target is fully processed before the
  // arguments.
  EXPECT_CALL(handler, PreVisitCall(call_expr, &expr, _)).Times(1);

  // Target
  EXPECT_CALL(handler, PostVisitIdent(target_ident, target, _)).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(target, _)).Times(1);
  EXPECT_CALL(handler, PostVisitTarget(&expr, _)).Times(1);

  // Arg0
  EXPECT_CALL(handler, PostVisitConst(const_expr, arg0, _)).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(arg0, _)).Times(1);
  EXPECT_CALL(handler, PostVisitArg(0, &expr, _)).Times(1);

  // Arg1
  EXPECT_CALL(handler, PostVisitIdent(ident_expr, arg1, _)).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(arg1, _)).Times(1);
  EXPECT_CALL(handler, PostVisitArg(1, &expr, _)).Times(1);

  // Back to call
  EXPECT_CALL(handler, PostVisitCall(call_expr, &expr, _)).Times(1);
  EXPECT_CALL(handler, PostVisitExpr(&expr, _)).Times(1);

  AstTraverse(&expr, &source_info, &handler);
}

// Test handling of Comprehension node
TEST(AstCrawlerTest, CheckCrawlComprehension) {
  SourceInfo source_info;
  MockAstVisitor handler;

  Expr expr;
  auto c = expr.mutable_comprehension_expr();
  auto iter_range = c->mutable_iter_range();
  auto iter_range_expr = iter_range->mutable_const_expr();
  auto accu_init = c->mutable_accu_init();
  auto accu_init_expr = accu_init->mutable_ident_expr();
  auto loop_condition = c->mutable_loop_condition();
  auto loop_condition_expr = loop_condition->mutable_const_expr();
  auto loop_step = c->mutable_loop_step();
  auto loop_step_expr = loop_step->mutable_ident_expr();
  auto result = c->mutable_result();
  auto result_expr = result->mutable_const_expr();

  testing::InSequence seq;

  // With use_comprehension_callbacks (set below), each slot is bracketed by
  // Pre/PostVisitComprehensionSubexpression, in slot order.
  EXPECT_CALL(handler, PreVisitComprehension(c, &expr, _)).Times(1);

  // ITER_RANGE
  EXPECT_CALL(handler,
              PreVisitComprehensionSubexpression(iter_range, c, ITER_RANGE, _))
      .Times(1);
  EXPECT_CALL(handler, PostVisitConst(iter_range_expr, iter_range, _)).Times(1);
  EXPECT_CALL(handler,
              PostVisitComprehensionSubexpression(iter_range, c, ITER_RANGE, _))
      .Times(1);

  // ACCU_INIT
  EXPECT_CALL(handler,
              PreVisitComprehensionSubexpression(accu_init, c, ACCU_INIT, _))
      .Times(1);
  EXPECT_CALL(handler, PostVisitIdent(accu_init_expr, accu_init, _)).Times(1);
  EXPECT_CALL(handler,
              PostVisitComprehensionSubexpression(accu_init, c, ACCU_INIT, _))
      .Times(1);

  // LOOP CONDITION
  EXPECT_CALL(handler, PreVisitComprehensionSubexpression(loop_condition, c,
                                                          LOOP_CONDITION, _))
      .Times(1);
  EXPECT_CALL(handler, PostVisitConst(loop_condition_expr, loop_condition, _))
      .Times(1);
  EXPECT_CALL(handler, PostVisitComprehensionSubexpression(loop_condition, c,
                                                           LOOP_CONDITION, _))
      .Times(1);

  // LOOP STEP
  EXPECT_CALL(handler,
              PreVisitComprehensionSubexpression(loop_step, c, LOOP_STEP, _))
      .Times(1);
  EXPECT_CALL(handler, PostVisitIdent(loop_step_expr, loop_step, _)).Times(1);
  EXPECT_CALL(handler,
              PostVisitComprehensionSubexpression(loop_step, c, LOOP_STEP, _))
      .Times(1);

  // RESULT
  EXPECT_CALL(handler,
              PreVisitComprehensionSubexpression(result, c, RESULT, _))
      .Times(1);
  EXPECT_CALL(handler, PostVisitConst(result_expr, result, _)).Times(1);
  EXPECT_CALL(handler,
              PostVisitComprehensionSubexpression(result, c, RESULT, _))
      .Times(1);

  EXPECT_CALL(handler, PostVisitComprehension(c, &expr, _)).Times(1);

  TraversalOptions opts;
  opts.use_comprehension_callbacks = true;
  AstTraverse(&expr, &source_info, &handler, opts);
}

// Test handling of Comprehension node with the default (legacy) callbacks:
// each slot is reported through PostVisitArg with a ComprehensionArg value.
TEST(AstCrawlerTest, CheckCrawlComprehensionLegacyCallbacks) {
  SourceInfo source_info;
  MockAstVisitor handler;

  Expr expr;
  auto c = expr.mutable_comprehension_expr();
  auto iter_range = c->mutable_iter_range();
  auto iter_range_expr = iter_range->mutable_const_expr();
  auto accu_init = c->mutable_accu_init();
  auto accu_init_expr = accu_init->mutable_ident_expr();
  auto loop_condition = c->mutable_loop_condition();
  auto loop_condition_expr = loop_condition->mutable_const_expr();
  auto loop_step = c->mutable_loop_step();
  auto loop_step_expr = loop_step->mutable_ident_expr();
  auto result = c->mutable_result();
  auto result_expr = result->mutable_const_expr();

  testing::InSequence seq;

  // Expected callbacks, in slot order. PostVisitArg receives the
  // comprehension's own Expr.
  EXPECT_CALL(handler, PreVisitComprehension(c, &expr, _)).Times(1);

  EXPECT_CALL(handler, PostVisitConst(iter_range_expr, iter_range, _)).Times(1);
  EXPECT_CALL(handler, PostVisitArg(ITER_RANGE, &expr, _)).Times(1);

  // ACCU_INIT
  EXPECT_CALL(handler, PostVisitIdent(accu_init_expr, accu_init, _)).Times(1);
  EXPECT_CALL(handler, PostVisitArg(ACCU_INIT, &expr, _)).Times(1);

  // LOOP CONDITION
  EXPECT_CALL(handler, PostVisitConst(loop_condition_expr, loop_condition, _))
      .Times(1);
  EXPECT_CALL(handler, PostVisitArg(LOOP_CONDITION, &expr, _)).Times(1);

  // LOOP STEP
  EXPECT_CALL(handler, PostVisitIdent(loop_step_expr, loop_step, _)).Times(1);
  EXPECT_CALL(handler, PostVisitArg(LOOP_STEP, &expr, _)).Times(1);

  // RESULT
  EXPECT_CALL(handler, PostVisitConst(result_expr, result, _)).Times(1);
  EXPECT_CALL(handler, PostVisitArg(RESULT, &expr, _)).Times(1);

  EXPECT_CALL(handler, PostVisitComprehension(c, &expr, _)).Times(1);

  AstTraverse(&expr, &source_info, &handler);
}

// Test handling of CreateList node.
TEST(AstCrawlerTest, CheckCreateList) {
  SourceInfo source_info;
  MockAstVisitor handler;

  Expr expr;
  auto list_expr = expr.mutable_list_expr();
  auto arg0 = list_expr->add_elements();
  auto const_expr = arg0->mutable_const_expr();
  auto arg1 = list_expr->add_elements();
  auto ident_expr = arg1->mutable_ident_expr();

  testing::InSequence seq;

  // Elements are visited first to last, between the list's pre- and
  // post-visit.
  EXPECT_CALL(handler, PreVisitCreateList(list_expr, &expr, _)).Times(1);
  EXPECT_CALL(handler, PostVisitConst(const_expr, arg0, _)).Times(1);
  EXPECT_CALL(handler, PostVisitIdent(ident_expr, arg1, _)).Times(1);
  EXPECT_CALL(handler, PostVisitCreateList(list_expr, &expr, _)).Times(1);

  AstTraverse(&expr, &source_info, &handler);
}

// Test handling of CreateStruct node.
TEST(AstCrawlerTest, CheckCreateStruct) {
  SourceInfo source_info;
  MockAstVisitor handler;

  Expr expr;
  auto struct_expr = expr.mutable_struct_expr();
  auto entry0 = struct_expr->add_entries();

  auto key = entry0->mutable_map_key()->mutable_const_expr();
  auto value = entry0->mutable_value()->mutable_ident_expr();

  testing::InSequence seq;

  // For each entry the key is visited before the value.
  EXPECT_CALL(handler, PreVisitCreateStruct(struct_expr, &expr, _)).Times(1);
  EXPECT_CALL(handler, PostVisitConst(key, &entry0->map_key(), _)).Times(1);
  EXPECT_CALL(handler, PostVisitIdent(value, &entry0->value(), _)).Times(1);
  EXPECT_CALL(handler, PostVisitCreateStruct(struct_expr, &expr, _)).Times(1);

  AstTraverse(&expr, &source_info, &handler);
}

// Test generic Expr handlers.
TEST(AstCrawlerTest, CheckExprHandlers) {
  SourceInfo source_info;
  MockAstVisitor handler;

  // One struct with a single entry: three Expr nodes in total (the struct,
  // its key and its value), so each generic Expr callback fires three times.
  Expr expr;
  auto struct_expr = expr.mutable_struct_expr();
  auto entry0 = struct_expr->add_entries();

  entry0->mutable_map_key()->mutable_const_expr();
  entry0->mutable_value()->mutable_ident_expr();

  EXPECT_CALL(handler, PreVisitExpr(_, _)).Times(3);
  EXPECT_CALL(handler, PostVisitExpr(_, _)).Times(3);

  AstTraverse(&expr, &source_info, &handler);
}

}  // namespace

}  // namespace google::api::expr::runtime

================================================
FILE: eval/public/ast_visitor.h
================================================
/*
 * Copyright 2018 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_H_

#include "cel/expr/syntax.pb.h"
#include "eval/public/source_position.h"

namespace google {
namespace api {
namespace expr {
namespace runtime {

// ComprehensionArg specifies arg_num values passed to PostVisitArg
// for subexpressions of Comprehension. The same values identify the slot in
// the Pre/PostVisitComprehensionSubexpression callbacks. Because PostVisitArg
// receives them as a plain int, the enumerator order is significant.
enum ComprehensionArg {
  ITER_RANGE,
  ACCU_INIT,
  LOOP_CONDITION,
  LOOP_STEP,
  RESULT,
};

// Callback handler class, used in conjunction with AstTraverse.
// Methods of this class are invoked when AST nodes with corresponding
// types are processed.
//
// For all types with children, the children will be visited in the natural
// order from first to last. For structs, keys are visited before values.
class AstVisitor { public: virtual ~AstVisitor() {} // Expr node handler method. Called for all Expr nodes. // Is invoked before child Expr nodes being processed. // TODO(issues/22): this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. virtual void PreVisitExpr(const cel::expr::Expr*, const SourcePosition*) {} // Expr node handler method. Called for all Expr nodes. // Is invoked after child Expr nodes are processed. // TODO(issues/22): this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. virtual void PostVisitExpr(const cel::expr::Expr*, const SourcePosition*) {} // Const node handler. // Invoked before child nodes are processed. // TODO(issues/22): this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. virtual void PreVisitConst(const cel::expr::Constant*, const cel::expr::Expr*, const SourcePosition*) {} // Const node handler. // Invoked after child nodes are processed. virtual void PostVisitConst(const cel::expr::Constant*, const cel::expr::Expr*, const SourcePosition*) = 0; // Ident node handler. // Invoked before child nodes are processed. // TODO(issues/22): this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. virtual void PreVisitIdent(const cel::expr::Expr::Ident*, const cel::expr::Expr*, const SourcePosition*) {} // Ident node handler. // Invoked after child nodes are processed. virtual void PostVisitIdent(const cel::expr::Expr::Ident*, const cel::expr::Expr*, const SourcePosition*) = 0; // Select node handler // Invoked before child nodes are processed. // TODO(issues/22): this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. virtual void PreVisitSelect(const cel::expr::Expr::Select*, const cel::expr::Expr*, const SourcePosition*) {} // Select node handler // Invoked after child nodes are processed. 
virtual void PostVisitSelect(const cel::expr::Expr::Select*, const cel::expr::Expr*, const SourcePosition*) = 0; // Call node handler group // We provide finer granularity for Call node callbacks to allow special // handling for short-circuiting // PreVisitCall is invoked before child nodes are processed. virtual void PreVisitCall(const cel::expr::Expr::Call*, const cel::expr::Expr*, const SourcePosition*) = 0; // Invoked after all child nodes are processed. virtual void PostVisitCall(const cel::expr::Expr::Call*, const cel::expr::Expr*, const SourcePosition*) = 0; // Invoked after target node is processed. // Expr is the call expression. virtual void PostVisitTarget(const cel::expr::Expr*, const SourcePosition*) = 0; // Invoked before all child nodes are processed. virtual void PreVisitComprehension( const cel::expr::Expr::Comprehension*, const cel::expr::Expr*, const SourcePosition*) = 0; // Invoked before comprehension child node is processed. virtual void PreVisitComprehensionSubexpression( const cel::expr::Expr* subexpr, const cel::expr::Expr::Comprehension* compr, ComprehensionArg comprehension_arg, const SourcePosition*) {} // Invoked after comprehension child node is processed. virtual void PostVisitComprehensionSubexpression( const cel::expr::Expr* subexpr, const cel::expr::Expr::Comprehension* compr, ComprehensionArg comprehension_arg, const SourcePosition*) {} // Invoked after all child nodes are processed. virtual void PostVisitComprehension( const cel::expr::Expr::Comprehension*, const cel::expr::Expr*, const SourcePosition*) = 0; // Invoked after each argument node processed. // For Call arg_num is the index of the argument. // For Comprehension arg_num is specified by ComprehensionArg. // Expr is the call expression. virtual void PostVisitArg(int arg_num, const cel::expr::Expr*, const SourcePosition*) = 0; // CreateList node handler // Invoked before child nodes are processed. 
// TODO(issues/22): this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. virtual void PreVisitCreateList(const cel::expr::Expr::CreateList*, const cel::expr::Expr*, const SourcePosition*) {} // CreateList node handler // Invoked after child nodes are processed. virtual void PostVisitCreateList(const cel::expr::Expr::CreateList*, const cel::expr::Expr*, const SourcePosition*) = 0; // CreateStruct node handler // Invoked before child nodes are processed. // TODO(issues/22): this method is not pure virtual to avoid dependencies // breakage. Change it in subsequent CLs. virtual void PreVisitCreateStruct( const cel::expr::Expr::CreateStruct*, const cel::expr::Expr*, const SourcePosition*) {} // CreateStruct node handler // Invoked after child nodes are processed. virtual void PostVisitCreateStruct( const cel::expr::Expr::CreateStruct*, const cel::expr::Expr*, const SourcePosition*) = 0; }; } // namespace runtime } // namespace expr } // namespace api } // namespace google #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_H_ ================================================ FILE: eval/public/ast_visitor_base.h ================================================ /* * Copyright 2018 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_BASE_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_BASE_H_ #include "eval/public/ast_visitor.h" #include "cel/expr/syntax.pb.h" namespace google { namespace api { namespace expr { namespace runtime { // Trivial base implementation of AstVisitor. class AstVisitorBase : public AstVisitor { public: AstVisitorBase() {} // Non-copyable AstVisitorBase(const AstVisitorBase&) = delete; AstVisitorBase& operator=(AstVisitorBase const&) = delete; ~AstVisitorBase() override {} // Const node handler. // Invoked after child nodes are processed. void PostVisitConst(const cel::expr::Constant*, const cel::expr::Expr*, const SourcePosition*) override {} // Ident node handler. // Invoked after child nodes are processed. void PostVisitIdent(const cel::expr::Expr::Ident*, const cel::expr::Expr*, const SourcePosition*) override {} // Select node handler // Invoked after child nodes are processed. void PostVisitSelect(const cel::expr::Expr::Select*, const cel::expr::Expr*, const SourcePosition*) override {} // Call node handler group // We provide finer granularity for Call node callbacks to allow special // handling for short-circuiting // PreVisitCall is invoked before child nodes are processed. void PreVisitCall(const cel::expr::Expr::Call*, const cel::expr::Expr*, const SourcePosition*) override {} // Invoked after all child nodes are processed. void PostVisitCall(const cel::expr::Expr::Call*, const cel::expr::Expr*, const SourcePosition*) override {} // Invoked before all child nodes are processed. void PreVisitComprehension(const cel::expr::Expr::Comprehension*, const cel::expr::Expr*, const SourcePosition*) override {} // Invoked after all child nodes are processed. void PostVisitComprehension(const cel::expr::Expr::Comprehension*, const cel::expr::Expr*, const SourcePosition*) override {} // Invoked after each argument node processed. // For Call arg_num is the index of the argument. 
// For Comprehension arg_num is specified by ComprehensionArg. // Expr is the call expression. void PostVisitArg(int, const cel::expr::Expr*, const SourcePosition*) override {} // Invoked after target node processed. void PostVisitTarget(const cel::expr::Expr*, const SourcePosition*) override {} // CreateList node handler // Invoked after child nodes are processed. void PostVisitCreateList(const cel::expr::Expr::CreateList*, const cel::expr::Expr*, const SourcePosition*) override {} // CreateStruct node handler // Invoked after child nodes are processed. void PostVisitCreateStruct(const cel::expr::Expr::CreateStruct*, const cel::expr::Expr*, const SourcePosition*) override {} }; } // namespace runtime } // namespace expr } // namespace api } // namespace google #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_AST_VISITOR_BASE_H_ ================================================ FILE: eval/public/base_activation.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_BASE_ACTIVATION_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_BASE_ACTIVATION_H_ #include #include "google/protobuf/field_mask.pb.h" #include "absl/base/nullability.h" #include "absl/strings/string_view.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_function.h" #include "eval/public/cel_value.h" #include "runtime/internal/attribute_matcher.h" namespace cel::runtime_internal { class ActivationAttributeMatcherAccess; } namespace google::api::expr::runtime { // Base class for an activation. class BaseActivation { public: BaseActivation() = default; // Non-copyable/non-assignable BaseActivation(const BaseActivation&) = delete; BaseActivation& operator=(const BaseActivation&) = delete; // Move-constructible/move-assignable BaseActivation(BaseActivation&& other) = default; BaseActivation& operator=(BaseActivation&& other) = default; // Return a list of function overloads for the given name. 
virtual std::vector FindFunctionOverloads( absl::string_view) const = 0; // Provide the value that is bound to the name, if found. // arena parameter is provided to support the case when we want to pass the // ownership of returned object ( Message/List/Map ) to Evaluator. virtual absl::optional FindValue(absl::string_view, google::protobuf::Arena*) const = 0; // Return the collection of attribute patterns that determine missing // attributes. virtual const std::vector& missing_attribute_patterns() const { static const std::vector* empty = new std::vector({}); return *empty; } // Return the collection of attribute patterns that determine "unknown" // values. virtual const std::vector& unknown_attribute_patterns() const { static const std::vector* empty = new std::vector({}); return *empty; } virtual ~BaseActivation() = default; private: friend class cel::runtime_internal::ActivationAttributeMatcherAccess; // Internal getter for overriding the attribute matching behavior. virtual const cel::runtime_internal::AttributeMatcher* absl_nullable GetAttributeMatcher() const { return nullptr; } }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_BASE_ACTIVATION_H_ ================================================ FILE: eval/public/builtin_func_registrar.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/public/builtin_func_registrar.h"

#include "absl/status/status.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"
#include "internal/status_macros.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"
#include "runtime/standard/arithmetic_functions.h"
#include "runtime/standard/comparison_functions.h"
#include "runtime/standard/container_functions.h"
#include "runtime/standard/container_membership_functions.h"
#include "runtime/standard/equality_functions.h"
#include "runtime/standard/logical_functions.h"
#include "runtime/standard/regex_functions.h"
#include "runtime/standard/string_functions.h"
#include "runtime/standard/time_functions.h"
#include "runtime/standard/type_conversion_functions.h"

namespace google::api::expr::runtime {

absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry,
                                      const InterpreterOptions& options) {
  // All standard-library groups are installed on the cel::FunctionRegistry
  // that backs the legacy CelFunctionRegistry.
  cel::FunctionRegistry& function_registry = registry->InternalGetRegistry();

  // Convert the legacy interpreter options once; every group below is
  // configured from the same converted value.
  cel::RuntimeOptions converted_options = ConvertToRuntimeOptions(options);

  // Groups are registered in a fixed order; the first failing group aborts
  // registration and its status is returned.
  CEL_RETURN_IF_ERROR(
      cel::RegisterLogicalFunctions(function_registry, converted_options));
  CEL_RETURN_IF_ERROR(
      cel::RegisterComparisonFunctions(function_registry, converted_options));
  CEL_RETURN_IF_ERROR(
      cel::RegisterContainerFunctions(function_registry, converted_options));
  CEL_RETURN_IF_ERROR(cel::RegisterContainerMembershipFunctions(
      function_registry, converted_options));
  CEL_RETURN_IF_ERROR(cel::RegisterTypeConversionFunctions(
      function_registry, converted_options));
  CEL_RETURN_IF_ERROR(
      cel::RegisterArithmeticFunctions(function_registry, converted_options));
  CEL_RETURN_IF_ERROR(
      cel::RegisterTimeFunctions(function_registry, converted_options));
  CEL_RETURN_IF_ERROR(
      cel::RegisterStringFunctions(function_registry, converted_options));
  CEL_RETURN_IF_ERROR(
      cel::RegisterRegexFunctions(function_registry, converted_options));
  CEL_RETURN_IF_ERROR(
      cel::RegisterEqualityFunctions(function_registry, converted_options));

  return absl::OkStatus();
}

}  // namespace google::api::expr::runtime

================================================
FILE: eval/public/builtin_func_registrar.h
================================================
// Copyright 2017 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_BUILTIN_FUNC_REGISTRAR_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_BUILTIN_FUNC_REGISTRAR_H_

#include "absl/status/status.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"

namespace google::api::expr::runtime {

// Registers the standard CEL builtin function groups (logical, comparison,
// container, container membership, type conversion, arithmetic, time,
// string, regex and equality) into `registry`, configured from `options`.
// Returns the first non-OK status reported by a registration step.
absl::Status RegisterBuiltinFunctions(
    CelFunctionRegistry* registry,
    const InterpreterOptions& options = InterpreterOptions());

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_BUILTIN_FUNC_REGISTRAR_H_

================================================
FILE: eval/public/builtin_func_registrar_test.cc
================================================
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include "eval/public/builtin_func_registrar.h" #include #include #include #include #include "cel/expr/syntax.pb.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/time/time.h" #include "eval/public/activation.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/testing/matchers.h" #include "internal/testing.h" #include "internal/time.h" #include "parser/parser.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { using cel::expr::Expr; using cel::expr::SourceInfo; using ::absl_testing::StatusIs; using ::cel::internal::MaxDuration; using ::cel::internal::MinDuration; using ::testing::HasSubstr; struct TestCase { std::string test_name; std::string expr; absl::flat_hash_map vars; absl::StatusOr result = CelValue::CreateBool(true); InterpreterOptions options; }; InterpreterOptions OverflowChecksEnabled() { static InterpreterOptions options; options.enable_timestamp_duration_overflow_errors = true; return options; } void ExpectResult(const TestCase& test_case) { auto parsed_expr = parser::Parse(test_case.expr); ASSERT_OK(parsed_expr); const Expr& expr_ast = parsed_expr->expr(); const SourceInfo& source_info = parsed_expr->source_info(); std::unique_ptr builder = CreateCelExpressionBuilder(test_case.options); ASSERT_OK( RegisterBuiltinFunctions(builder->GetRegistry(), test_case.options)); ASSERT_OK_AND_ASSIGN(auto cel_expression, builder->CreateExpression(&expr_ast, &source_info)); Activation activation; for (auto var : test_case.vars) { activation.InsertValue(var.first, var.second); } google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(auto value, cel_expression->Evaluate(activation, &arena)); 
if (!test_case.result.ok()) { EXPECT_TRUE(value.IsError()) << value.DebugString(); EXPECT_THAT(*value.ErrorOrDie(), StatusIs(test_case.result.status().code(), HasSubstr(test_case.result.status().message()))); return; } EXPECT_THAT(value, test::EqualsCelValue(*test_case.result)); } using BuiltinFuncParamsTest = testing::TestWithParam; TEST_P(BuiltinFuncParamsTest, StandardFunctions) { ExpectResult(GetParam()); } INSTANTIATE_TEST_SUITE_P( BuiltinFuncParamsTest, BuiltinFuncParamsTest, testing::ValuesIn({ // Legacy duration and timestamp arithmetic tests. {"TimeSubTimeLegacy", "t0 - t1 == duration('90s90ns')", { {"t0", CelValue::CreateTimestamp(absl::FromUnixSeconds(100) + absl::Nanoseconds(100))}, {"t1", CelValue::CreateTimestamp(absl::FromUnixSeconds(10) + absl::Nanoseconds(10))}, }}, {"TimeSubDurationLegacy", "t0 - duration('90s90ns')", { {"t0", CelValue::CreateTimestamp(absl::FromUnixSeconds(100) + absl::Nanoseconds(100))}, }, CelValue::CreateTimestamp(absl::FromUnixSeconds(10) + absl::Nanoseconds(10))}, {"TimeAddDurationLegacy", "t + duration('90s90ns')", {{"t", CelValue::CreateTimestamp(absl::FromUnixSeconds(10) + absl::Nanoseconds(10))}}, CelValue::CreateTimestamp(absl::FromUnixSeconds(100) + absl::Nanoseconds(100))}, {"DurationAddTimeLegacy", "duration('90s90ns') + t == t + duration('90s90ns')", {{"t", CelValue::CreateTimestamp(absl::FromUnixSeconds(10) + absl::Nanoseconds(10))}}}, {"DurationAddDurationLegacy", "duration('80s80ns') + duration('10s10ns') == duration('90s90ns')"}, {"DurationSubDurationLegacy", "duration('90s90ns') - duration('80s80ns') == duration('10s10ns')"}, {"MinDurationSubDurationLegacy", "min - duration('1ns') < duration('-87660000h')", {{"min", CelValue::CreateDuration(MinDuration())}}}, {"MaxDurationAddDurationLegacy", "max + duration('1ns') > duration('87660000h')", {{"max", CelValue::CreateDuration(MaxDuration())}}}, {"TimestampConversionFromStringLegacy", "timestamp('10000-01-02T00:00:00Z') > " "timestamp('9999-01-01T00:00:00Z')"}, 
{"TimestampFromUnixEpochSeconds", "timestamp(123) > timestamp('1970-01-01T00:02:02.999999999Z') && " "timestamp(123) == timestamp('1970-01-01T00:02:03Z') && " "timestamp(123) < timestamp('1970-01-01T00:02:03.000000001Z')"}, // Timestamp duration tests with fixes enabled for overflow checking. {"TimeSubTime", "t0 - t1 == duration('90s90ns')", { {"t0", CelValue::CreateTimestamp(absl::FromUnixSeconds(100) + absl::Nanoseconds(100))}, {"t1", CelValue::CreateTimestamp(absl::FromUnixSeconds(10) + absl::Nanoseconds(10))}, }, CelValue::CreateBool(true), OverflowChecksEnabled()}, {"TimeSubDuration", "t0 - duration('90s90ns')", { {"t0", CelValue::CreateTimestamp(absl::FromUnixSeconds(100) + absl::Nanoseconds(100))}, }, CelValue::CreateTimestamp(absl::FromUnixSeconds(10) + absl::Nanoseconds(10)), OverflowChecksEnabled()}, {"TimeSubtractDurationOverflow", "timestamp('0001-01-01T00:00:00Z') - duration('1ns')", {}, absl::OutOfRangeError("timestamp overflow"), OverflowChecksEnabled()}, {"TimeAddDuration", "t + duration('90s90ns')", {{"t", CelValue::CreateTimestamp(absl::FromUnixSeconds(10) + absl::Nanoseconds(10))}}, CelValue::CreateTimestamp(absl::FromUnixSeconds(100) + absl::Nanoseconds(100)), OverflowChecksEnabled()}, {"TimeAddDurationOverflow", "timestamp('9999-12-31T23:59:59.999999999Z') + duration('1ns')", {}, absl::OutOfRangeError("timestamp overflow"), OverflowChecksEnabled()}, {"DurationAddTime", "duration('90s90ns') + t == t + duration('90s90ns')", {{"t", CelValue::CreateTimestamp(absl::FromUnixSeconds(10) + absl::Nanoseconds(10))}}, CelValue::CreateBool(true), OverflowChecksEnabled()}, {"DurationAddTimeOverflow", "duration('1ns') + timestamp('9999-12-31T23:59:59.999999999Z')", {}, absl::OutOfRangeError("timestamp overflow"), OverflowChecksEnabled()}, {"DurationAddDuration", "duration('80s80ns') + duration('10s10ns') == duration('90s90ns')", {}, CelValue::CreateBool(true), OverflowChecksEnabled()}, {"DurationSubDuration", "duration('90s90ns') - duration('80s80ns') == 
duration('10s10ns')", {}, CelValue::CreateBool(true), OverflowChecksEnabled()}, {"MinDurationSubDuration", "min - duration('1ns')", {{"min", CelValue::CreateDuration(MinDuration())}}, absl::OutOfRangeError("overflow"), OverflowChecksEnabled()}, {"MaxDurationAddDuration", "max + duration('1ns')", {{"max", CelValue::CreateDuration(MaxDuration())}}, absl::OutOfRangeError("overflow"), OverflowChecksEnabled()}, // Timestamp conversion overflow checks. {"TimestampConversionFromStringOverflow", "timestamp('10000-01-02T00:00:00Z')", {}, absl::OutOfRangeError("timestamp overflow"), OverflowChecksEnabled()}, {"TimestampConversionFromStringUnderflow", "timestamp('0000-01-01T00:00:00Z')", {}, absl::OutOfRangeError("timestamp overflow"), OverflowChecksEnabled()}, // List concatenation tests. {"ListConcatEmptyInputs", "[] + [] == []", {}, CelValue::CreateBool(true), OverflowChecksEnabled()}, {"ListConcatRightEmpty", "[1] + [] == [1]", {}, CelValue::CreateBool(true), OverflowChecksEnabled()}, {"ListConcatLeftEmpty", "[] + [1] == [1]", {}, CelValue::CreateBool(true), OverflowChecksEnabled()}, {"ListConcat", "[2] + [1] == [2, 1]", {}, CelValue::CreateBool(true), OverflowChecksEnabled()}, {"StringToBool", "string(true) + string(false)", {}, CelValue::CreateStringView("truefalse"), OverflowChecksEnabled()}, }), [](const testing::TestParamInfo& info) { return info.param.test_name; }); } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/public/builtin_func_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include #include #include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/time/time.h" #include "absl/types/optional.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_builtins.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/time.h" namespace google::api::expr::runtime { namespace { using google::protobuf::Duration; using google::protobuf::Timestamp; using cel::expr::Expr; using cel::expr::SourceInfo; using google::protobuf::Arena; using ::cel::internal::MaxDuration; using ::cel::internal::MinDuration; using ::cel::internal::MinTimestamp; using ::testing::Eq; class BuiltinsTest : public ::testing::Test { protected: BuiltinsTest() {} // Helper method. Looks up in registry and tests comparison operation. void PerformRun(absl::string_view operation, absl::optional target, const std::vector& values, CelValue* result) { PerformRun(operation, target, values, result, options_); } // Helper method. Looks up in registry and tests comparison operation. 
void PerformRun(absl::string_view operation, absl::optional target, const std::vector& values, CelValue* result, const InterpreterOptions& options) { Activation activation; Expr expr; SourceInfo source_info; auto call = expr.mutable_call_expr(); call->set_function(operation); if (target.has_value()) { std::string param_name = "target"; activation.InsertValue(param_name, target.value()); auto target_arg = call->mutable_target(); auto ident = target_arg->mutable_ident_expr(); ident->set_name(param_name); } int counter = 0; for (const auto& value : values) { std::string param_name = absl::StrCat("param_", counter++); activation.InsertValue(param_name, value); auto arg = call->add_args(); auto ident = arg->mutable_ident_expr(); ident->set_name(param_name); } // Obtain CEL Expression builder. std::unique_ptr builder = CreateCelExpressionBuilder(options); // Builtin registration. ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); // Create CelExpression from AST (Expr object). ASSERT_OK_AND_ASSIGN(auto cel_expression, builder->CreateExpression(&expr, &source_info)); ASSERT_OK_AND_ASSIGN(auto value, cel_expression->Evaluate(activation, &arena_)); *result = value; } // Helper method. Looks up in registry and tests comparison operation. void TestComparison(absl::string_view operation, const CelValue& ref, const CelValue& other, bool result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE( PerformRun(operation, {}, {ref, other}, &result_value)); ASSERT_EQ(result_value.IsBool(), true) << absl::StrCat(CelValue::TypeName(ref.type()), " ", operation, " ", CelValue::TypeName(other.type())); ASSERT_EQ(result_value.BoolOrDie(), result) << operation << " for " << ref.DebugString() << " with " << other.DebugString(); } // Helper method. Looks up in registry and tests for no matching equality // overload. 
void TestNoMatchingEqualOverload(const CelValue& ref, const CelValue& other) { options_.enable_heterogeneous_equality = false; CelValue eq_value; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kEqual, {}, {ref, other}, &eq_value, options_)); ASSERT_TRUE(eq_value.IsError()) << " for " << CelValue::TypeName(ref.type()) << " and " << CelValue::TypeName(other.type()); EXPECT_TRUE(CheckNoMatchingOverloadError(eq_value)); CelValue ineq_value; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kInequal, {}, {ref, other}, &ineq_value, options_)); ASSERT_TRUE(ineq_value.IsError()) << " for " << CelValue::TypeName(ref.type()) << " and " << CelValue::TypeName(other.type()); EXPECT_TRUE(CheckNoMatchingOverloadError(ineq_value)); } // Helper method. Looks up in registry and tests Type conversions. void TestTypeConverts(absl::string_view operation, const CelValue& ref, CelValue::BytesHolder result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); ASSERT_EQ(result_value.IsBytes(), true); ASSERT_EQ(result_value.BytesOrDie(), result) << operation << " for " << CelValue::TypeName(ref.type()); } // Helper method. Looks up in registry and tests Type conversions. 
void TestTypeConverts(absl::string_view operation, const CelValue& ref, CelValue::StringHolder result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); ASSERT_EQ(result_value.IsString(), true); ASSERT_EQ(result_value.StringOrDie().value(), result.value()) << operation << " for " << CelValue::TypeName(ref.type()); } void TestTypeConverts(absl::string_view operation, const CelValue& ref, double result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); ASSERT_EQ(result_value.IsDouble(), true); ASSERT_EQ(result_value.DoubleOrDie(), result) << operation << " for " << CelValue::TypeName(ref.type()); } void TestTypeConverts(absl::string_view operation, const CelValue& ref, int64_t result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); ASSERT_EQ(result_value.IsInt64(), true); ASSERT_EQ(result_value.Int64OrDie(), result) << operation << " for " << CelValue::TypeName(ref.type()); } void TestTypeConverts(absl::string_view operation, const CelValue& ref, uint64_t result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); ASSERT_EQ(result_value.IsUint64(), true); ASSERT_EQ(result_value.Uint64OrDie(), result) << operation << " for " << CelValue::TypeName(ref.type()); } void TestTypeConverts(absl::string_view operation, const CelValue& ref, Duration& result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); ASSERT_EQ(result_value.IsDuration(), true); ASSERT_EQ(result_value.DurationOrDie(), CelProtoWrapper::CreateDuration(&result).DurationOrDie()) << operation << " for " << CelValue::TypeName(ref.type()); } // Helper method. Attempts to perform a type conversion and expects an error // as the result. 
void TestTypeConversionError(absl::string_view operation, const CelValue& ref) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {}, {ref}, &result_value)); ASSERT_EQ(result_value.IsError(), true) << result_value.DebugString(); } // Helper method. Looks up in registry and tests functions without params. void TestFunctions(absl::string_view operation, const CelValue& ref, int64_t result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(operation, {ref}, {}, &result_value)); ASSERT_EQ(result_value.IsInt64(), true); ASSERT_EQ(result_value.Int64OrDie(), result) << operation << " for " << CelValue::TypeName(ref.type()); } // Helper method. Looks up in registry and tests functions with params. void TestFunctionsWithParams(absl::string_view operation, const CelValue& ref, const std::vector& params, int64_t result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE( PerformRun(operation, {ref}, {params}, &result_value)); ASSERT_EQ(result_value.IsInt64(), true); ASSERT_EQ(result_value.Int64OrDie(), result) << operation << " for " << CelValue::TypeName(ref.type()); } // Helper method to test && and || operations void TestLogicalOperation(absl::string_view operation, bool v1, bool v2, bool result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun( operation, {}, {CelValue::CreateBool(v1), CelValue::CreateBool(v2)}, &result_value)); ASSERT_EQ(result_value.IsBool(), true); ASSERT_EQ(result_value.BoolOrDie(), result) << operation; } void TestComparisonsForType(CelValue::Type kind, const CelValue& ref, const CelValue& lesser) { std::string type_name = CelValue::TypeName(kind); TestComparison(builtin::kEqual, ref, ref, true); TestComparison(builtin::kEqual, ref, lesser, false); TestComparison(builtin::kInequal, ref, ref, false); TestComparison(builtin::kInequal, ref, lesser, true); TestComparison(builtin::kLess, ref, ref, false); TestComparison(builtin::kLess, ref, lesser, false); TestComparison(builtin::kLess, lesser, ref, true); 
TestComparison(builtin::kLessOrEqual, ref, ref, true); TestComparison(builtin::kLessOrEqual, ref, lesser, false); TestComparison(builtin::kLessOrEqual, lesser, ref, true); TestComparison(builtin::kGreater, ref, ref, false); TestComparison(builtin::kGreater, ref, lesser, true); TestComparison(builtin::kGreater, lesser, ref, false); TestComparison(builtin::kGreaterOrEqual, ref, ref, true); TestComparison(builtin::kGreaterOrEqual, ref, lesser, true); TestComparison(builtin::kGreaterOrEqual, lesser, ref, false); } // Helper method to test arithmetical operations for Int64 void TestArithmeticalOperationInt64(absl::string_view operation, int64_t v1, int64_t v2, int64_t result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun( operation, {}, {CelValue::CreateInt64(v1), CelValue::CreateInt64(v2)}, &result_value)); ASSERT_EQ(result_value.IsInt64(), true); ASSERT_EQ(result_value.Int64OrDie(), result) << operation; } // Helper for testing invalid signed integer arithmetic operations. void TestArithmeticalErrorInt64(absl::string_view operation, int64_t v1, int64_t v2, absl::StatusCode code) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun( operation, {}, {CelValue::CreateInt64(v1), CelValue::CreateInt64(v2)}, &result_value)); ASSERT_EQ(result_value.IsError(), true); ASSERT_EQ(result_value.ErrorOrDie()->code(), code) << operation; } // Helper method to test arithmetical operations for Uint64 void TestArithmeticalOperationUint64(absl::string_view operation, uint64_t v1, uint64_t v2, uint64_t result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun( operation, {}, {CelValue::CreateUint64(v1), CelValue::CreateUint64(v2)}, &result_value)); ASSERT_EQ(result_value.IsUint64(), true); ASSERT_EQ(result_value.Uint64OrDie(), result) << operation; } // Helper for testing invalid unsigned integer arithmetic operations. 
void TestArithmeticalErrorUint64(absl::string_view operation, uint64_t v1, uint64_t v2, absl::StatusCode code) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun( operation, {}, {CelValue::CreateUint64(v1), CelValue::CreateUint64(v2)}, &result_value)); ASSERT_EQ(result_value.IsError(), true); ASSERT_EQ(result_value.ErrorOrDie()->code(), code) << operation; } // Helper method to test arithmetical operations for Double void TestArithmeticalOperationDouble(absl::string_view operation, double v1, double v2, double result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun( operation, {}, {CelValue::CreateDouble(v1), CelValue::CreateDouble(v2)}, &result_value)); ASSERT_EQ(result_value.IsDouble(), true); ASSERT_DOUBLE_EQ(result_value.DoubleOrDie(), result) << operation; } void TestInList(const CelList* cel_list, const CelValue& value, bool result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kIn, {}, {value, CelValue::CreateList(cel_list)}, &result_value)); ASSERT_EQ(result_value.IsBool(), true) << result_value.DebugString() << " argument: " << value.DebugString(); ASSERT_EQ(result_value.BoolOrDie(), result) << " for " << CelValue::TypeName(value.type()); } void TestInDeprecatedList(const CelList* cel_list, const CelValue& value, bool result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kInDeprecated, {}, {value, CelValue::CreateList(cel_list)}, &result_value)); ASSERT_EQ(result_value.IsBool(), true); ASSERT_EQ(result_value.BoolOrDie(), result) << " for " << CelValue::TypeName(value.type()); } void TestInFunctionList(const CelList* cel_list, const CelValue& value, bool result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kInFunction, {}, {value, CelValue::CreateList(cel_list)}, &result_value)); ASSERT_EQ(result_value.IsBool(), true); ASSERT_EQ(result_value.BoolOrDie(), result) << " for " << CelValue::TypeName(value.type()); } void TestInMap(const CelMap* cel_map, const CelValue& 
value, bool result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kIn, {}, {value, CelValue::CreateMap(cel_map)}, &result_value, options_)); ASSERT_EQ(result_value.IsBool(), true); ASSERT_EQ(result_value.BoolOrDie(), result) << " for " << value.DebugString(); } void TestInDeprecatedMap(const CelMap* cel_map, const CelValue& value, bool result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kInDeprecated, {}, {value, CelValue::CreateMap(cel_map)}, &result_value, options_)); ASSERT_EQ(result_value.IsBool(), true); ASSERT_EQ(result_value.BoolOrDie(), result) << " for " << CelValue::TypeName(value.type()); } void TestInFunctionMap(const CelMap* cel_map, const CelValue& value, bool result) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kInFunction, {}, {value, CelValue::CreateMap(cel_map)}, &result_value, options_)); ASSERT_EQ(result_value.IsBool(), true); ASSERT_EQ(result_value.BoolOrDie(), result) << " for " << CelValue::TypeName(value.type()); } InterpreterOptions options_; // Arena Arena arena_; }; class HeterogeneousEqualityTest : public BuiltinsTest { public: HeterogeneousEqualityTest() { options_.enable_heterogeneous_equality = true; } }; // Test Not() operation for Bool TEST_F(BuiltinsTest, TestNotOp) { CelValue result; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kNot, {}, {CelValue::CreateBool(true)}, &result)); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), false); } // Test negation operation for numeric types. TEST_F(BuiltinsTest, TestNegOp) { CelValue result; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kNeg, {}, {CelValue::CreateInt64(-1)}, &result)); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 1); ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kNeg, {}, {CelValue::CreateDouble(-1.0)}, &result)); ASSERT_TRUE(result.IsDouble()); EXPECT_EQ(result.DoubleOrDie(), 1.0); } // Test integer negation overflow. 
TEST_F(BuiltinsTest, TestNegIntOverflow) { CelValue result; ASSERT_NO_FATAL_FAILURE(PerformRun( builtin::kNeg, {}, {CelValue::CreateInt64(std::numeric_limits::lowest())}, &result)); ASSERT_TRUE(result.IsError()); } // Test Equality/Non-Equality operation for Bool TEST_F(BuiltinsTest, TestBoolEqual) { CelValue ref = CelValue::CreateBool(true); CelValue lesser = CelValue::CreateBool(false); TestComparisonsForType(CelValue::Type::kBool, ref, lesser); } // Test Equality/Non-Equality operation for Int64 TEST_F(BuiltinsTest, TestInt64Equal) { CelValue ref = CelValue::CreateInt64(2); CelValue lesser = CelValue::CreateInt64(1); TestComparisonsForType(CelValue::Type::kInt64, ref, lesser); } // Test Equality/Non-Equality operation for Uint64 TEST_F(BuiltinsTest, TestUint64Comparisons) { CelValue ref = CelValue::CreateUint64(2); CelValue lesser = CelValue::CreateUint64(1); TestComparisonsForType(CelValue::Type::kUint64, ref, lesser); } // Test Equality/Non-Equality operation for Double TEST_F(BuiltinsTest, TestDoubleComparisons) { CelValue ref = CelValue::CreateDouble(2); CelValue lesser = CelValue::CreateDouble(1); TestComparisonsForType(CelValue::Type::kDouble, ref, lesser); } // Test Equality/Non-Equality operation for String TEST_F(BuiltinsTest, TestStringEqual) { std::string test1 = "test1"; std::string test2 = "test2"; CelValue ref = CelValue::CreateString(&test2); CelValue lesser = CelValue::CreateString(&test1); TestComparisonsForType(CelValue::Type::kString, ref, lesser); } // Test Equality/Non-Equality operation for Double TEST_F(BuiltinsTest, TestDurationComparisons) { Duration ref; Duration lesser; ref.set_seconds(2); ref.set_nanos(1); lesser.set_seconds(1); lesser.set_nanos(2); TestComparisonsForType(CelValue::Type::kDuration, CelProtoWrapper::CreateDuration(&ref), CelProtoWrapper::CreateDuration(&lesser)); } // Test Equality/Non-Equality operation for messages TEST_F(BuiltinsTest, TestNullMessageEqual) { CelValue ref = CelValue::CreateNull(); Expr dummy; 
CelValue value = CelProtoWrapper::CreateMessage(&dummy, &arena_); TestComparison(builtin::kEqual, ref, ref, true); TestComparison(builtin::kInequal, ref, ref, false); TestComparison(builtin::kEqual, value, ref, false); TestComparison(builtin::kInequal, value, ref, true); TestComparison(builtin::kEqual, ref, value, false); TestComparison(builtin::kInequal, ref, value, true); } // Test functions for Duration TEST_F(BuiltinsTest, TestDurationFunctions) { Duration ref; ref.set_seconds(93541L); ref.set_nanos(11000000L); TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), int64_t{25L}); TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateDuration(&ref), int64_t{1559L}); TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateDuration(&ref), int64_t{93541L}); TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateDuration(&ref), int64_t{11L}); std::string result = "93541.011s"; TestTypeConverts(builtin::kString, CelProtoWrapper::CreateDuration(&ref), CelValue::StringHolder(&result)); TestTypeConverts(builtin::kDuration, CelValue::CreateString(&result), ref); ref.set_seconds(-93541L); ref.set_nanos(-11000000L); TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), int64_t{-25L}); TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateDuration(&ref), int64_t{-1559L}); TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateDuration(&ref), int64_t{-93541L}); TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateDuration(&ref), int64_t{-11L}); result = "-93541.011s"; TestTypeConverts(builtin::kString, CelProtoWrapper::CreateDuration(&ref), CelValue::StringHolder(&result)); TestTypeConverts(builtin::kDuration, CelValue::CreateString(&result), ref); absl::Duration d = MinDuration() + absl::Seconds(-1); result = absl::FormatDuration(d); TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&result)); d = MaxDuration() + absl::Seconds(1); result = absl::FormatDuration(d); TestTypeConversionError(builtin::kDuration, 
CelValue::CreateString(&result));
  std::string inf = "inf";
  std::string ninf = "-inf";
  TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&inf));
  TestTypeConversionError(builtin::kDuration, CelValue::CreateString(&ninf));
}

// Test functions for Timestamp
TEST_F(BuiltinsTest, TestTimestampFunctions) {
  Timestamp ref;
  // Test timestamp functions w/o timezone
  // seconds=1, nanos=11000000 formats as "1970-01-01T00:00:01.011Z" (see
  // TestTimestampConversionToString). The expectations below show that the
  // month, day-of-year and day-of-month accessors are zero-based while the
  // date accessor is one-based.
  ref.set_seconds(1L);
  ref.set_nanos(11000000L);
  TestFunctions(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{1970L});
  TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{0L});
  TestFunctions(builtin::kDayOfYear, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{0L});
  TestFunctions(builtin::kDayOfMonth, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{0L});
  TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{1L});
  TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{0L});
  TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{0L});
  TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{1L});
  TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{11L});

  // 259200 seconds formats as "1970-01-04T00:00:00Z"; its day-of-week is
  // expected to be 0.
  ref.set_seconds(259200L);
  ref.set_nanos(0L);
  TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{0L});
}

// string() of a timestamp is expected to produce a UTC date-time string with a
// trailing "Z"; the fractional seconds are omitted when nanos is zero.
TEST_F(BuiltinsTest, TestTimestampConversionToString) {
  Timestamp ref;
  ref.set_seconds(1L);
  ref.set_nanos(11000000L);
  std::string result = "1970-01-01T00:00:01.011Z";
  TestTypeConverts(builtin::kString, CelProtoWrapper::CreateTimestamp(&ref),
                   CelValue::StringHolder(&result));
  ref.set_seconds(259200L);
  ref.set_nanos(0L);
  result = "1970-01-04T00:00:00Z";
  TestTypeConverts(builtin::kString, CelProtoWrapper::CreateTimestamp(&ref),
                   CelValue::StringHolder(&result));
}

// Same accessors as TestTimestampFunctions, but with a time zone argument.
TEST_F(BuiltinsTest, TestTimestampFunctionsWithTimeZone) {
  // Test timestamp functions w/ IANA timezone
  Timestamp ref;
  ref.set_seconds(1L);
  ref.set_nanos(11000000L);
  std::vector params;
  // CreateString() is handed the address of the zone string, so presumably
  // the CelValue stored in `params` borrows it; the string must stay alive
  // while `params` is in use.
  const std::string timezone = "America/Los_Angeles";
  params.push_back(CelValue::CreateString(&timezone));
  // 1970-01-01T00:00:01.011Z is expected to read as 1969-12-31 16:00:01.011
  // in this zone (month and day-of-month accessors are zero-based).
  TestFunctionsWithParams(builtin::kFullYear,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{1969L});
  TestFunctionsWithParams(builtin::kMonth,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{11L});
  TestFunctionsWithParams(builtin::kDayOfYear,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{364L});
  TestFunctionsWithParams(builtin::kDayOfMonth,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{30L});
  TestFunctionsWithParams(builtin::kDate,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{31L});
  TestFunctionsWithParams(builtin::kHours,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{16L});
  TestFunctionsWithParams(builtin::kMinutes,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{0L});
  TestFunctionsWithParams(builtin::kSeconds,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{1L});
  TestFunctionsWithParams(builtin::kMilliseconds,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{11L});
  // 1970-01-04T00:00:00Z is still the previous day in this zone, so the
  // expected day-of-week is 6 rather than the UTC value of 0.
  ref.set_seconds(259200L);
  ref.set_nanos(0L);
  TestFunctionsWithParams(builtin::kDayOfWeek,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{6L});

  // Test timestamp functions with negative value
  // NOTE(review): these checks call TestFunctions() without `params`, so they
  // exercise UTC (hours == 23), not the zone under test. The same section is
  // repeated verbatim after the fixed-zone checks below; consider switching
  // to TestFunctionsWithParams() or dropping the duplicate.
  ref.set_seconds(-1L);
  ref.set_nanos(0L);
  TestFunctions(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{1969L});
  TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{11L});
  TestFunctions(builtin::kDayOfYear, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{364L});
  TestFunctions(builtin::kDayOfMonth, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{30L});
  TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{31L});
  TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{23L});
  TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{59L});
  TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{59L});
  TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{3L});

  // Test timestamp functions w/ fixed timezone
  // A fixed "-08:00" offset is expected to give the same results as
  // America/Los_Angeles for these instants.
  ref.set_seconds(1L);
  ref.set_nanos(11000000L);
  const std::string fixedzone = "-08:00";
  params.clear();
  params.push_back(CelValue::CreateString(&fixedzone));
  TestFunctionsWithParams(builtin::kFullYear,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{1969L});
  TestFunctionsWithParams(builtin::kMonth,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{11L});
  TestFunctionsWithParams(builtin::kDayOfYear,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{364L});
  TestFunctionsWithParams(builtin::kDayOfMonth,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{30L});
  TestFunctionsWithParams(builtin::kDate,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{31L});
  TestFunctionsWithParams(builtin::kHours,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{16L});
  TestFunctionsWithParams(builtin::kMinutes,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{0L});
  TestFunctionsWithParams(builtin::kSeconds,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{1L});
  TestFunctionsWithParams(builtin::kMilliseconds,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{11L});
  ref.set_seconds(259200L);
  ref.set_nanos(0L);
  TestFunctionsWithParams(builtin::kDayOfWeek,
                          CelProtoWrapper::CreateTimestamp(&ref), params,
                          int64_t{6L});

  // Test timestamp functions with negative value
  ref.set_seconds(-1L);
  ref.set_nanos(0L);
  TestFunctions(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{1969L});
  TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{11L});
  TestFunctions(builtin::kDayOfYear, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{364L});
  TestFunctions(builtin::kDayOfMonth, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{30L});
  TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{31L});
  TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{23L});
  TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{59L});
  TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{59L});
  TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref),
                int64_t{3L});

  // One second before MinTimestamp() is expected to fail string conversion.
  TestTypeConversionError(
      builtin::kString,
      CelValue::CreateTimestamp(MinTimestamp() + absl::Seconds(-1)));
}

TEST_F(BuiltinsTest, TestBytesConversions_bytes) {
  std::string txt = "hello";
  CelValue::BytesHolder result = CelValue::BytesHolder(&txt);
  TestTypeConverts(builtin::kBytes, CelValue::CreateBytes(&txt), result);
}

TEST_F(BuiltinsTest, TestBytesConversions_string) {
  std::string txt = "hello";
  CelValue::BytesHolder result = CelValue::BytesHolder(&txt);
  TestTypeConverts(builtin::kBytes, CelValue::CreateString(&txt), result);
}

TEST_F(BuiltinsTest, TestDoubleConversions_double) {
  double ref = 100.1;
  TestTypeConverts(builtin::kDouble, CelValue::CreateDouble(ref),
                   double{100.1});
}

TEST_F(BuiltinsTest, TestDoubleConversions_int) {
  int64_t ref = 100L;
  TestTypeConverts(builtin::kDouble, CelValue::CreateInt64(ref),
                   double{100.0});
}

TEST_F(BuiltinsTest, TestDoubleConversions_string) {
  std::string ref = "-100.1";
  TestTypeConverts(builtin::kDouble, CelValue::CreateString(&ref),
                   double{-100.1});
}

TEST_F(BuiltinsTest, TestDoubleConversions_uint) {
  uint64_t ref = 100UL;
  TestTypeConverts(builtin::kDouble, CelValue::CreateUint64(ref),
                   double{100.0});
}

// A string whose exponent has a fractional part is expected to be rejected.
TEST_F(BuiltinsTest, TestDoubleConversionError_stringInvalid) {
  std::string invalid = "-100e-10.0";
  TestTypeConversionError(builtin::kDouble, CelValue::CreateString(&invalid));
}

// dyn() is expected to return its argument unchanged.
TEST_F(BuiltinsTest, TestDynConversions) {
  TestTypeConverts(builtin::kDyn, CelValue::CreateDouble(100.1),
                   double{100.1});
  TestTypeConverts(builtin::kDyn, CelValue::CreateInt64(100L),
                   int64_t{100L});
  TestTypeConverts(builtin::kDyn, CelValue::CreateUint64(100UL),
                   uint64_t{100UL});
}

TEST_F(BuiltinsTest, TestIntConversions_int) {
TestTypeConverts(builtin::kInt, CelValue::CreateInt64(100L), int64_t{100L});
}

// int() of a timestamp is expected to yield its seconds since the epoch.
TEST_F(BuiltinsTest, TestIntConversions_Timestamp) {
  Timestamp ref;
  ref.set_seconds(100);
  TestTypeConverts(builtin::kInt, CelProtoWrapper::CreateTimestamp(&ref),
                   int64_t{100L});
}

TEST_F(BuiltinsTest, TestIntConversions_double) {
  double ref = 100.1;
  TestTypeConverts(builtin::kInt, CelValue::CreateDouble(ref), int64_t{100L});
}

TEST_F(BuiltinsTest, TestIntConversions_string) {
  std::string ref = "100";
  TestTypeConverts(builtin::kInt, CelValue::CreateString(&ref), int64_t{100L});
}

TEST_F(BuiltinsTest, TestIntConversions_uint) {
  uint64_t ref = 100;
  TestTypeConverts(builtin::kInt, CelValue::CreateUint64(ref), int64_t{100L});
}

TEST_F(BuiltinsTest, TestIntConversions_doubleIntMin) {
  // Converting int64_t min may or may not roundtrip properly without overflow
  // depending on compiler flags, so the conservative approach is to treat this
  // case as overflow.
  double range = std::numeric_limits::lowest();
  TestTypeConversionError(builtin::kInt, CelValue::CreateDouble(range));
}

TEST_F(BuiltinsTest, TestIntConversions_doubleIntMinMinus1024) {
  // Converting values between [int64_t::lowest(), (int64_t::lowest() - 1024)]
  // will result in an int64_t representable value, in some cases, but not all
  // as the conversion depends on how the value rounds to the nearest double:
  // adjacent doubles near -2^63 are 2048 apart, so lowest() - 1024 may round
  // back to lowest(). As with lowest() itself, the conservative expectation
  // is an overflow error.
  double range = std::numeric_limits::lowest();
  range -= 1024L;
  TestTypeConversionError(builtin::kInt, CelValue::CreateDouble(range));
}

TEST_F(BuiltinsTest, TestIntConversionError_doubleIntMaxMinus512) {
  // Converting int64_t max - 512 to a double will not roundtrip to the original
  // value, but it will roundtrip to a valid 64-bit integer.
  // Adjacent doubles just below 2^63 are 1024 apart, so max() - 512 is
  // expected to round to max() - 1023 (that is, 2^63 - 1024).
  // NOTE(review): despite the "ConversionError" name, this case expects a
  // successful conversion.
  double range = std::numeric_limits::max() - 512;
  TestTypeConverts(builtin::kInt, CelValue::CreateDouble(range),
                   int64_t{std::numeric_limits::max() - 1023});
}

TEST_F(BuiltinsTest, TestIntConversionError_doubleNegRange) {
  double range = -1.0e99;
  TestTypeConversionError(builtin::kInt, CelValue::CreateDouble(range));
}

TEST_F(BuiltinsTest, TestIntConversionError_doublePosRange) {
  double range = 1.0e99;
  TestTypeConversionError(builtin::kInt, CelValue::CreateDouble(range));
}

TEST_F(BuiltinsTest, TestIntConversionError_doubleIntMax) {
  // Converting int64_t max to a double results in a double value of int64_t max
  // + 1 which should cause the overflow testing to trip.
  double range = std::numeric_limits::max();
  TestTypeConversionError(builtin::kInt, CelValue::CreateDouble(range));
}

TEST_F(BuiltinsTest, TestIntConversionError_doubleIntMaxMinus1) {
  // Converting values between int64_t::max() and int64_t::max() - 511 will
  // result in overflow errors during round-tripping.
  double range = std::numeric_limits::max() - 1;
  TestTypeConversionError(builtin::kInt, CelValue::CreateDouble(range));
}

TEST_F(BuiltinsTest, TestIntConversionError_doubleIntMaxMinus511) {
  // Converting values between int64_t::max() and int64_t::max() - 511 will
  // result in overflow errors during round-tripping.
  double range = std::numeric_limits::max() - 511;
  TestTypeConversionError(builtin::kInt, CelValue::CreateDouble(range));
}

TEST_F(BuiltinsTest, TestIntConversionError_doubleIntMinMinus1025) {
  // Converting double values at or below int64_t::lowest() - 1025 will result
  // in an overflow error: the subtraction below rounds to the nearest double,
  // lowest() - 2048, which lies outside the int64_t range.
  double range = std::numeric_limits::lowest();
  range -= 1025L;
  TestTypeConversionError(builtin::kInt, CelValue::CreateDouble(range));
}

// 18446744073709551615 is the largest uint64_t value; it does not fit in
// int64_t.
TEST_F(BuiltinsTest, TestIntConversionError_uintRange) {
  uint64_t range = 18446744073709551615UL;
  TestTypeConversionError(builtin::kInt, CelValue::CreateUint64(range));
}

TEST_F(BuiltinsTest, TestUintConversions_double) {
  double ref = 100.1;
  TestTypeConverts(builtin::kUint, CelValue::CreateDouble(ref),
                   uint64_t{100UL});
}

TEST_F(BuiltinsTest, TestUintConversions_int) {
  int64_t ref = 100L;
  TestTypeConverts(builtin::kUint, CelValue::CreateInt64(ref),
                   uint64_t{100UL});
}

TEST_F(BuiltinsTest, TestUintConversions_string) {
  std::string ref = "100";
  TestTypeConverts(builtin::kUint, CelValue::CreateString(&ref),
                   uint64_t{100UL});
}

TEST_F(BuiltinsTest, TestUintConversions_uint) {
  TestTypeConverts(builtin::kUint, CelValue::CreateUint64(uint64_t{100UL}),
                   uint64_t{100UL});
}

TEST_F(BuiltinsTest, TestUintConversionError_doubleNegRange) {
  double range = -1.0e99;
  TestTypeConversionError(builtin::kUint, CelValue::CreateDouble(range));
}

TEST_F(BuiltinsTest, TestUintConversionError_doublePosRange) {
  double range = 1.0e99;
  TestTypeConversionError(builtin::kUint, CelValue::CreateDouble(range));
}

// Negative integers are expected to be rejected by uint().
TEST_F(BuiltinsTest, TestUintConversionError_intRange) {
  int64_t range = -1L;
  TestTypeConversionError(builtin::kUint, CelValue::CreateInt64(range));
}

TEST_F(BuiltinsTest, TestUintConversionError_stringInvalid) {
  std::string invalid = "-100";
  TestTypeConversionError(builtin::kUint, CelValue::CreateString(&invalid));
}

// `ref` (2s + 1ns) is later than `lesser` (1s + 2ns): seconds dominate nanos.
TEST_F(BuiltinsTest, TestTimestampComparisons) {
  Timestamp ref;
  Timestamp lesser;
  ref.set_seconds(2);
  ref.set_nanos(1);
  lesser.set_seconds(1);
  lesser.set_nanos(2);
  TestComparisonsForType(CelValue::Type::kTimestamp,
                         CelProtoWrapper::CreateTimestamp(&ref),
                         CelProtoWrapper::CreateTimestamp(&lesser));
}

TEST_F(BuiltinsTest, TestLogicalOr) {
  const char* op_name = builtin::kOr;
  TestLogicalOperation(op_name, true, true, true);
TestLogicalOperation(op_name, false, true, true);
  TestLogicalOperation(op_name, true, false, true);
  TestLogicalOperation(op_name, false, false, false);

  CelError error = absl::CancelledError();

  // Test special cases - mix of bool and error
  // A `true` operand is expected to decide || whichever side the error is on;
  // with a `false` operand the error propagates.
  // true || error
  CelValue result;
  ASSERT_NO_FATAL_FAILURE(PerformRun(
      op_name, {}, {CelValue::CreateBool(true), CelValue::CreateError(&error)},
      &result));
  ASSERT_TRUE(result.IsBool());
  EXPECT_EQ(result.BoolOrDie(), true);

  // error || true
  ASSERT_NO_FATAL_FAILURE(PerformRun(
      op_name, {}, {CelValue::CreateError(&error), CelValue::CreateBool(true)},
      &result));
  ASSERT_TRUE(result.IsBool());
  EXPECT_EQ(result.BoolOrDie(), true);

  // false || error
  ASSERT_NO_FATAL_FAILURE(PerformRun(
      op_name, {}, {CelValue::CreateBool(false), CelValue::CreateError(&error)},
      &result));
  EXPECT_TRUE(result.IsError());

  // error || false
  ASSERT_NO_FATAL_FAILURE(PerformRun(
      op_name, {}, {CelValue::CreateError(&error), CelValue::CreateBool(false)},
      &result));
  EXPECT_TRUE(result.IsError());

  // error || error
  ASSERT_NO_FATAL_FAILURE(PerformRun(
      op_name, {},
      {CelValue::CreateError(&error), CelValue::CreateError(&error)},
      &result));
  EXPECT_TRUE(result.IsError());

  // "foo" || "bar"
  // Non-bool operands are expected to yield a no-matching-overload error.
  std::string arg0 = "foo";
  std::string arg1 = "bar";
  ASSERT_NO_FATAL_FAILURE(PerformRun(
      op_name, {},
      {CelValue::CreateString(&arg0), CelValue::CreateString(&arg1)},
      &result));
  EXPECT_TRUE(CheckNoMatchingOverloadError(result));
}

TEST_F(BuiltinsTest, TestLogicalAnd) {
  const char* op_name = builtin::kAnd;
  TestLogicalOperation(op_name, true, true, true);
  TestLogicalOperation(op_name, false, true, false);
  TestLogicalOperation(op_name, true, false, false);
  TestLogicalOperation(op_name, false, false, false);

  CelError error = absl::CancelledError();

  // Test special cases - mix of bool and error
  // A `false` operand is expected to decide && whichever side the error is on;
  // with a `true` operand the error propagates.
  // false && error
  CelValue result;
  ASSERT_NO_FATAL_FAILURE(PerformRun(
      op_name, {}, {CelValue::CreateBool(false), CelValue::CreateError(&error)},
      &result));
  ASSERT_TRUE(result.IsBool());
  EXPECT_EQ(result.BoolOrDie(), false);

  // error && false
  ASSERT_NO_FATAL_FAILURE(PerformRun(
      op_name, {}, {CelValue::CreateError(&error), CelValue::CreateBool(false)},
      &result));
  ASSERT_TRUE(result.IsBool());
  EXPECT_EQ(result.BoolOrDie(), false);

  // true && error
  ASSERT_NO_FATAL_FAILURE(PerformRun(
      op_name, {}, {CelValue::CreateBool(true), CelValue::CreateError(&error)},
      &result));
  EXPECT_TRUE(result.IsError());

  // error && true
  ASSERT_NO_FATAL_FAILURE(PerformRun(
      op_name, {}, {CelValue::CreateError(&error), CelValue::CreateBool(true)},
      &result));
  EXPECT_TRUE(result.IsError());

  // error && error
  ASSERT_NO_FATAL_FAILURE(PerformRun(
      op_name, {},
      {CelValue::CreateError(&error), CelValue::CreateError(&error)},
      &result));
  EXPECT_TRUE(result.IsError());
}

// The ternary operator is expected to select the second argument for a true
// condition and the third for a false one.
TEST_F(BuiltinsTest, TestTernary) {
  std::vector args = {CelValue::CreateBool(true), CelValue::CreateInt64(1),
                      CelValue::CreateInt64(2)};

  CelValue result_value;
  ASSERT_NO_FATAL_FAILURE(
      PerformRun(builtin::kTernary, {}, args, &result_value));
  ASSERT_EQ(result_value.IsInt64(), true);
  ASSERT_EQ(result_value.Int64OrDie(), 1);

  args[0] = CelValue::CreateBool(false);
  ASSERT_NO_FATAL_FAILURE(
      PerformRun(builtin::kTernary, {}, args, &result_value));
  ASSERT_EQ(result_value.IsInt64(), true);
  ASSERT_EQ(result_value.Int64OrDie(), 2);
}

// An error condition is expected to be returned unchanged as the result.
TEST_F(BuiltinsTest, TestTernaryErrorAsCondition) {
  CelError cel_error = absl::CancelledError();
  std::vector args = {CelValue::CreateError(&cel_error),
                      CelValue::CreateInt64(1), CelValue::CreateInt64(2)};

  CelValue result_value;
  ASSERT_NO_FATAL_FAILURE(
      PerformRun(builtin::kTernary, {}, args, &result_value));
  ASSERT_EQ(result_value.IsError(), true);
  ASSERT_EQ(*result_value.ErrorOrDie(), cel_error);
}

// A non-bool condition is expected to yield a no-matching-overload error.
TEST_F(BuiltinsTest, TestTernaryStringAsCondition) {
  std::string test = "test";
  std::vector args = {CelValue::CreateString(&test), CelValue::CreateInt64(1),
                      CelValue::CreateInt64(2)};

  CelValue result_value;
  ASSERT_NO_FATAL_FAILURE(
      PerformRun(builtin::kTernary, {}, args, &result_value));
EXPECT_TRUE(CheckNoMatchingOverloadError(result_value)); } class FakeList : public CelList { public: explicit FakeList(const std::vector& values) : values_(values) {} int size() const override { return values_.size(); } CelValue operator[](int index) const override { return values_[index]; } private: std::vector values_; }; class FakeErrorMap : public CelMap { public: FakeErrorMap() {} int size() const override { return 0; } absl::StatusOr Has(const CelValue& key) const override { return absl::InvalidArgumentError("bad key type"); } absl::optional operator[](CelValue key) const override { return absl::nullopt; } absl::StatusOr ListKeys() const override { return absl::UnimplementedError("CelMap::ListKeys is not implemented"); } }; template class FakeMap : public CelMap { public: template FakeMap(const std::map& data, const CreateCelValue& create_cel_value, const GetCelValue& get_cel_value) : data_(data), get_cel_value_(get_cel_value) { std::vector keys; for (auto kv : data) { keys.push_back(create_cel_value(kv.first)); } keys_ = std::make_unique(keys); } int size() const override { return data_.size(); } absl::optional operator[](CelValue key) const override { absl::optional raw_value = get_cel_value_(key); if (!raw_value) { return absl::nullopt; } auto it = data_.find(*raw_value); if (it == data_.end()) { return absl::nullopt; } return it->second; } absl::StatusOr ListKeys() const override { return keys_.get(); } private: std::map data_; std::unique_ptr keys_; std::function(CelValue)> get_cel_value_; }; class FakeBoolMap : public FakeMap { public: explicit FakeBoolMap(const std::map& data) : FakeMap(data, CelValue::CreateBool, [](CelValue v) -> absl::optional { if (!v.IsBool()) { return absl::nullopt; } return v.BoolOrDie(); }) {} }; class FakeInt64Map : public FakeMap { public: explicit FakeInt64Map(const std::map& data) : FakeMap(data, CelValue::CreateInt64, [](CelValue v) -> absl::optional { if (!v.IsInt64()) { return absl::nullopt; } return v.Int64OrDie(); }) 
{} }; class FakeUint64Map : public FakeMap { public: explicit FakeUint64Map(const std::map& data) : FakeMap(data, CelValue::CreateUint64, [](CelValue v) -> absl::optional { if (!v.IsUint64()) { return absl::nullopt; } return v.Uint64OrDie(); }) {} }; class FakeStringMap : public FakeMap { public: explicit FakeStringMap(const std::map& data) : FakeMap( data, [](CelValue::StringHolder v) { return CelValue::CreateString(v); }, [](CelValue v) -> absl::optional { if (!v.IsString()) { return absl::nullopt; } return v.StringOrDie(); }) {} }; // Test list index access function TEST_F(BuiltinsTest, ListIndex) { constexpr int64_t kValues[] = {3, 4, 5, 6}; std::vector values; for (auto value : kValues) { values.push_back(CelValue::CreateInt64(value)); } FakeList cel_list(values); for (size_t i = 0; i < values.size(); i++) { CelValue result_value; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kIndex, {}, {CelValue::CreateList(&cel_list), CelValue::CreateInt64(i)}, &result_value)); ASSERT_TRUE(result_value.IsInt64()); EXPECT_THAT(result_value.Int64OrDie(), Eq(kValues[i])); } } // Test Equality/Non-Equality operation for lists TEST_F(BuiltinsTest, TestListEqual) { const FakeList kList0({}); const FakeList kList1({CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); const FakeList kList2({CelValue::CreateInt64(1), CelValue::CreateInt64(3)}); const FakeList kList3({CelValue::CreateInt64(1), CelValue::CreateInt64(2), CelValue::CreateInt64(3)}); std::vector values; values.push_back(CelValue::CreateList(&kList0)); values.push_back(CelValue::CreateList(&kList1)); values.push_back(CelValue::CreateList(&kList2)); values.push_back(CelValue::CreateList(&kList3)); for (size_t i = 0; i < values.size(); i++) { for (size_t j = 0; j < values.size(); j++) { if (i == j) { TestComparison(builtin::kEqual, values[i], values[j], true); TestComparison(builtin::kInequal, values[i], values[j], false); } else { TestComparison(builtin::kInequal, values[i], values[j], true); 
TestComparison(builtin::kEqual, values[i], values[j], false);
      }
    }
  }

  // Lists whose element types differ from kList1's are expected to have no
  // matching equality overload.
  const FakeList kList({CelValue::CreateInt64(1), CelValue::CreateBool(true)});
  TestNoMatchingEqualOverload(CelValue::CreateList(&kList1),
                              CelValue::CreateList(&kList));
}

// Test map index access function
// Present keys are expected to return their value; a missing key is expected
// to yield a kNotFound "no such key" error.
TEST_F(BuiltinsTest, MapInt64Index) {
  constexpr int64_t kValues[] = {3, -4, 5, -6};

  std::map data;
  for (auto value : kValues) {
    data[value] = CelValue::CreateInt64(value * value);
  }

  FakeInt64Map cel_map(data);

  for (int64_t value : kValues) {
    CelValue result_value;
    ASSERT_NO_FATAL_FAILURE(PerformRun(
        builtin::kIndex, {},
        {CelValue::CreateMap(&cel_map), CelValue::CreateInt64(value)},
        &result_value));

    ASSERT_TRUE(result_value.IsInt64());
    EXPECT_THAT(result_value.Int64OrDie(), Eq(value * value));
  }

  CelValue result_value;
  ASSERT_NO_FATAL_FAILURE(
      PerformRun(builtin::kIndex, {},
                 {CelValue::CreateMap(&cel_map), CelValue::CreateInt64(100)},
                 &result_value));
  ASSERT_TRUE(result_value.IsError());
  EXPECT_THAT(result_value.ErrorOrDie()->code(),
              Eq(absl::StatusCode::kNotFound));
  EXPECT_TRUE(CheckNoSuchKeyError(result_value));
}

TEST_F(BuiltinsTest, MapUint64Index) {
  constexpr uint64_t kValues[] = {3, 4, 5, 6};

  std::map data;
  for (auto value : kValues) {
    data[value] = CelValue::CreateUint64(value * value);
  }

  FakeUint64Map cel_map(data);

  for (uint64_t value : kValues) {
    CelValue result_value;
    ASSERT_NO_FATAL_FAILURE(PerformRun(
        builtin::kIndex, {},
        {CelValue::CreateMap(&cel_map), CelValue::CreateUint64(value)},
        &result_value));

    ASSERT_TRUE(result_value.IsUint64());
    EXPECT_THAT(result_value.Uint64OrDie(), Eq(value * value));
  }

  CelValue result_value;
  ASSERT_NO_FATAL_FAILURE(
      PerformRun(builtin::kIndex, {},
                 {CelValue::CreateMap(&cel_map), CelValue::CreateUint64(100)},
                 &result_value));
  ASSERT_TRUE(result_value.IsError());
  EXPECT_THAT(result_value.ErrorOrDie()->code(),
              Eq(absl::StatusCode::kNotFound));
  EXPECT_TRUE(CheckNoSuchKeyError(result_value));
}

TEST_F(BuiltinsTest, MapStringIndex) {
  std::vector kValues = {"test0", "test1", "test2"};

  // The StringHolder keys point into kValues, which outlives cel_map.
  std::map data;
  for (size_t i = 0; i < kValues.size(); i++) {
    data[CelValue::StringHolder(&kValues[i])] = CelValue::CreateInt64(i);
  }

  FakeStringMap cel_map(data);

  for (size_t i = 0; i < kValues.size(); i++) {
    std::string value = kValues[i];
    CelValue result_value;
    ASSERT_NO_FATAL_FAILURE(PerformRun(
        builtin::kIndex, {},
        {CelValue::CreateMap(&cel_map), CelValue::CreateString(&value)},
        &result_value));

    ASSERT_TRUE(result_value.IsInt64());
    EXPECT_THAT(result_value.Int64OrDie(), Eq(i));
  }

  CelValue result_value;
  const std::string kMissingKey = "no_such_key_is_present";
  ASSERT_NO_FATAL_FAILURE(PerformRun(
      builtin::kIndex, {},
      {CelValue::CreateMap(&cel_map), CelValue::CreateString(&kMissingKey)},
      &result_value));
  ASSERT_TRUE(result_value.IsError());
  EXPECT_THAT(result_value.ErrorOrDie()->code(),
              Eq(absl::StatusCode::kNotFound));
  EXPECT_TRUE(CheckNoSuchKeyError(result_value));
}

TEST_F(BuiltinsTest, MapBoolIndex) {
  std::vector kValues = {true, false};

  std::map data;
  for (size_t i = 0; i < kValues.size(); i++) {
    data[kValues[i]] = CelValue::CreateInt64(i);
  }

  FakeBoolMap cel_map(data);

  for (size_t i = 0; i < kValues.size(); i++) {
    bool value = kValues[i];
    CelValue result_value;
    ASSERT_NO_FATAL_FAILURE(
        PerformRun(builtin::kIndex, {},
                   {CelValue::CreateMap(&cel_map), CelValue::CreateBool(value)},
                   &result_value));

    ASSERT_TRUE(result_value.IsInt64());
    EXPECT_THAT(result_value.Int64OrDie(), Eq(i));
  }
}

// Test Equality/Non-Equality operation for maps
// Every map is expected to equal itself and differ from each of the others.
TEST_F(BuiltinsTest, TestMapEqual) {
  const FakeInt64Map kMap0({});
  const FakeInt64Map kMap1({{0, CelValue::CreateInt64(0)}});
  const FakeInt64Map kMap2({{0, CelValue::CreateInt64(1)}});
  const FakeInt64Map kMap3(
      {{0, CelValue::CreateInt64(0)}, {1, CelValue::CreateInt64(1)}});

  std::vector values;
  values.push_back(CelValue::CreateMap(&kMap0));
  values.push_back(CelValue::CreateMap(&kMap1));
  values.push_back(CelValue::CreateMap(&kMap2));
  values.push_back(CelValue::CreateMap(&kMap3));

  for (size_t i = 0; i <
values.size(); i++) {
    for (size_t j = 0; j < values.size(); j++) {
      if (i == j) {
        TestComparison(builtin::kEqual, values[i], values[j], true);
        TestComparison(builtin::kInequal, values[i], values[j], false);
      } else {
        TestComparison(builtin::kInequal, values[i], values[j], true);
        TestComparison(builtin::kEqual, values[i], values[j], false);
      }
    }
  }

  // A map whose value type differs from kMap1's is expected to have no
  // matching equality overload.
  const FakeInt64Map kMap({{0, CelValue::CreateBool(true)}});
  TestNoMatchingEqualOverload(CelValue::CreateMap(&kMap1),
                              CelValue::CreateMap(&kMap));
}

// Single-element lists wrapping values of different types are expected to equal
// themselves and to have no matching equality overload with each other.
TEST_F(BuiltinsTest, TestNestedEqual) {
  const std::string test = "testvalue";
  Duration dur;
  dur.set_seconds(2);
  dur.set_nanos(1);
  Timestamp ts;
  ts.set_seconds(100);
  ts.set_nanos(100);
  const FakeInt64Map kMap({{0, CelValue::CreateBool(true)}});

  const FakeList kList1({CelValue::CreateBool(true)});
  const FakeList kList2({CelValue::CreateInt64(12)});
  const FakeList kList3({CelValue::CreateUint64(13)});
  const FakeList kList4({CelValue::CreateDouble(14)});
  const FakeList kList5({CelValue::CreateString(&test)});
  const FakeList kList6({CelValue::CreateBytes(&test)});
  const FakeList kList7({CelValue::CreateNull()});
  const FakeList kList8({CelProtoWrapper::CreateDuration(&dur)});
  const FakeList kList9({CelProtoWrapper::CreateTimestamp(&ts)});
  const FakeList kList10({CelValue::CreateList(&kList1)});
  const FakeList kList11({CelValue::CreateMap(&kMap)});

  std::vector values;
  values.push_back(CelValue::CreateList(&kList1));
  values.push_back(CelValue::CreateList(&kList2));
  values.push_back(CelValue::CreateList(&kList3));
  values.push_back(CelValue::CreateList(&kList4));
  values.push_back(CelValue::CreateList(&kList5));
  values.push_back(CelValue::CreateList(&kList6));
  values.push_back(CelValue::CreateList(&kList7));
  values.push_back(CelValue::CreateList(&kList8));
  values.push_back(CelValue::CreateList(&kList9));
  values.push_back(CelValue::CreateList(&kList10));
  values.push_back(CelValue::CreateList(&kList11));

  for (size_t i = 0; i < values.size(); i++) {
    for (size_t j = 0; j < values.size(); j++) {
      if (i == j) {
        TestComparison(builtin::kEqual, values[i], values[j], true);
        TestComparison(builtin::kInequal, values[i], values[j], false);
      } else {
        TestNoMatchingEqualOverload(values[i], values[j]);
      }
    }
  }
}

TEST_F(BuiltinsTest, StringSize) {
  std::string test = "testvalue";
  CelValue result_value;

  ASSERT_NO_FATAL_FAILURE(PerformRun(
      builtin::kSize, {}, {CelValue::CreateString(&test)}, &result_value));
  ASSERT_EQ(result_value.IsInt64(), true);
  ASSERT_EQ(result_value.Int64OrDie(), 9);
}

// size() of a string is expected to count code points: the five-character
// Greek literal below occupies more than five UTF-8 bytes.
TEST_F(BuiltinsTest, StringUnicodeSize) {
  std::string test = "πέντε";
  CelValue result_value;

  ASSERT_NO_FATAL_FAILURE(PerformRun(
      builtin::kSize, {}, {CelValue::CreateString(&test)}, &result_value));
  ASSERT_EQ(result_value.IsInt64(), true);
  ASSERT_EQ(result_value.Int64OrDie(), 5);
}

// size() of bytes is expected to count bytes.
TEST_F(BuiltinsTest, BytesSize) {
  std::string test = "testvalue";
  CelValue result_value;

  ASSERT_NO_FATAL_FAILURE(PerformRun(
      builtin::kSize, {}, {CelValue::CreateBytes(&test)}, &result_value));
  ASSERT_EQ(result_value.IsInt64(), true);
  ASSERT_EQ(result_value.Int64OrDie(), test.size());
}

TEST_F(BuiltinsTest, ListSize) {
  constexpr int64_t kValues[] = {3, 4, 5, 6};

  std::vector values;
  for (auto value : kValues) {
    values.push_back(CelValue::CreateInt64(value));
  }

  FakeList cel_list(values);
  CelValue result_value;

  ASSERT_NO_FATAL_FAILURE(PerformRun(
      builtin::kSize, {}, {CelValue::CreateList(&cel_list)}, &result_value));
  ASSERT_EQ(result_value.IsInt64(), true);
  ASSERT_EQ(result_value.Int64OrDie(), values.size());
}

TEST_F(BuiltinsTest, MapSize) {
  constexpr int64_t kValues[] = {3, -4, 5, -6};

  std::map data;
  for (auto value : kValues) {
    data[value] = CelValue::CreateInt64(value * value);
  }

  FakeInt64Map cel_map(data);
  CelValue result_value;

  ASSERT_NO_FATAL_FAILURE(PerformRun(
      builtin::kSize, {}, {CelValue::CreateMap(&cel_map)}, &result_value));
  ASSERT_EQ(result_value.IsInt64(), true);
  ASSERT_EQ(result_value.Int64OrDie(), data.size());
}

TEST_F(BuiltinsTest, TestBoolListIn) {
  FakeList
cel_list({CelValue::CreateBool(false), CelValue::CreateBool(false)}); TestInList(&cel_list, CelValue::CreateBool(false), true); TestInList(&cel_list, CelValue::CreateBool(true), false); } TEST_F(BuiltinsTest, TestInt64ListIn) { FakeList cel_list({CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); TestInList(&cel_list, CelValue::CreateInt64(2), true); TestInList(&cel_list, CelValue::CreateInt64(3), false); } TEST_F(BuiltinsTest, TestUint64ListIn) { FakeList cel_list({CelValue::CreateUint64(1), CelValue::CreateUint64(2)}); TestInList(&cel_list, CelValue::CreateUint64(2), true); TestInList(&cel_list, CelValue::CreateUint64(3), false); } TEST_F(BuiltinsTest, TestDoubleListIn) { FakeList cel_list({CelValue::CreateDouble(1), CelValue::CreateDouble(2)}); TestInList(&cel_list, CelValue::CreateDouble(2), true); TestInList(&cel_list, CelValue::CreateDouble(3), false); } TEST_F(BuiltinsTest, TestStringListIn) { std::string v0 = "test0"; std::string v1 = "test1"; std::string v2 = "test2"; FakeList cel_list({CelValue::CreateString(&v0), CelValue::CreateString(&v1)}); TestInList(&cel_list, CelValue::CreateString(&v1), true); TestInList(&cel_list, CelValue::CreateString(&v2), false); } TEST_F(BuiltinsTest, TestBytesListIn) { std::vector values; std::string v0 = "test0"; std::string v1 = "test1"; std::string v2 = "test2"; FakeList cel_list({CelValue::CreateBytes(&v0), CelValue::CreateBytes(&v1)}); TestInList(&cel_list, CelValue::CreateBytes(&v1), true); TestInList(&cel_list, CelValue::CreateBytes(&v2), false); } TEST_F(HeterogeneousEqualityTest, MixedTypes) { FakeList cel_list({CelValue::CreateDuration(absl::Seconds(1)), CelValue::CreateNull(), CelValue::CreateInt64(1)}); ASSERT_NO_FATAL_FAILURE( TestInList(&cel_list, CelValue::CreateDuration(absl::Seconds(1)), true)); ASSERT_NO_FATAL_FAILURE( TestInList(&cel_list, CelValue::CreateInt64(1), true)); ASSERT_NO_FATAL_FAILURE( TestInList(&cel_list, CelValue::CreateUint64(1), true)); ASSERT_NO_FATAL_FAILURE( TestInList(&cel_list, 
CelValue::CreateInt64(2), false)); ASSERT_NO_FATAL_FAILURE( TestInList(&cel_list, CelValue::CreateStringView("abc"), false)); } TEST_F(HeterogeneousEqualityTest, NullIn) { FakeList cel_list({CelValue::CreateInt64(0), CelValue::CreateNull(), CelValue::CreateInt64(1)}); ASSERT_NO_FATAL_FAILURE( TestInList(&cel_list, CelValue::CreateInt64(1), true)); ASSERT_NO_FATAL_FAILURE(TestInList(&cel_list, CelValue::CreateNull(), true)); ASSERT_NO_FATAL_FAILURE( TestInList(&cel_list, CelValue::CreateInt64(2), false)); } TEST_F(HeterogeneousEqualityTest, NullNotIn) { FakeList cel_list({CelValue::CreateInt64(0), CelValue::CreateInt64(1)}); ASSERT_NO_FATAL_FAILURE(TestInList(&cel_list, CelValue::CreateNull(), false)); } TEST_F(BuiltinsTest, TestMapInError) { FakeErrorMap cel_map; std::vector kValues = { CelValue::CreateBool(true), CelValue::CreateInt64(1), CelValue::CreateStringView("hello"), CelValue::CreateUint64(2), }; options_.enable_heterogeneous_equality = true; for (auto key : kValues) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun( builtin::kIn, {}, {key, CelValue::CreateMap(&cel_map)}, &result_value)); ASSERT_TRUE(result_value.IsBool()) << key.DebugString() << " : " << result_value.DebugString(); EXPECT_FALSE(result_value.BoolOrDie()); } options_.enable_heterogeneous_equality = false; for (auto key : kValues) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun( builtin::kIn, {}, {key, CelValue::CreateMap(&cel_map)}, &result_value)); EXPECT_TRUE(result_value.IsError()); EXPECT_EQ(result_value.ErrorOrDie()->message(), "bad key type"); EXPECT_EQ(result_value.ErrorOrDie()->code(), absl::StatusCode::kInvalidArgument); } } TEST_F(BuiltinsTest, TestBoolMapIn) { constexpr bool kValues[] = {true, true}; std::map data; for (auto value : kValues) { data[value] = CelValue::CreateBool(value); } FakeBoolMap cel_map(data); TestInMap(&cel_map, CelValue::CreateBool(true), true); TestInMap(&cel_map, CelValue::CreateBool(false), false); TestInMap(&cel_map, 
CelValue::CreateUint64(3), false);
}

// Without heterogeneous equality only int64 keys are expected to match; with
// it, uint64 and double keys are expected to match numerically equal int64
// keys (NaN and non-integral doubles never match).
TEST_F(BuiltinsTest, TestInt64MapIn) {
  constexpr int64_t kValues[] = {3, -4, 5, -6};

  std::map data;
  for (auto value : kValues) {
    data[value] = CelValue::CreateInt64(value * value);
  }

  FakeInt64Map cel_map(data);

  options_.enable_heterogeneous_equality = false;
  TestInMap(&cel_map, CelValue::CreateInt64(-4), true);
  TestInMap(&cel_map, CelValue::CreateInt64(4), false);
  TestInMap(&cel_map, CelValue::CreateUint64(3), false);
  TestInMap(&cel_map, CelValue::CreateUint64(4), false);

  options_.enable_heterogeneous_equality = true;
  TestInMap(&cel_map, CelValue::CreateUint64(3), true);
  TestInMap(&cel_map, CelValue::CreateUint64(4), false);
  TestInMap(&cel_map, CelValue::CreateDouble(NAN), false);
  TestInMap(&cel_map, CelValue::CreateDouble(-4.0), true);
  TestInMap(&cel_map, CelValue::CreateDouble(-4.1), false);
  TestInMap(&cel_map,
            CelValue::CreateDouble(std::numeric_limits::max()), false);
}

// Same as TestInt64MapIn for uint64 keys; negative int64 and negative double
// lookups are expected to miss.
TEST_F(BuiltinsTest, TestUint64MapIn) {
  constexpr uint64_t kValues[] = {3, 4, 5, 6};

  std::map data;
  for (auto value : kValues) {
    data[value] = CelValue::CreateUint64(value * value);
  }

  FakeUint64Map cel_map(data);

  options_.enable_heterogeneous_equality = false;
  TestInMap(&cel_map, CelValue::CreateUint64(4), true);
  TestInMap(&cel_map, CelValue::CreateUint64(44), false);
  TestInMap(&cel_map, CelValue::CreateInt64(4), false);

  options_.enable_heterogeneous_equality = true;
  TestInMap(&cel_map, CelValue::CreateInt64(-1), false);
  TestInMap(&cel_map, CelValue::CreateInt64(4), true);
  TestInMap(&cel_map, CelValue::CreateDouble(4.0), true);
  TestInMap(&cel_map, CelValue::CreateDouble(-4.0), false);
  TestInMap(&cel_map, CelValue::CreateDouble(7.0), false);
}

// The string key "42" is expected not to match the integer 42.
TEST_F(BuiltinsTest, TestStringMapIn) {
  std::vector kValues = {"test0", "test1", "test2", "42"};

  std::map data;
  for (size_t i = 0; i < kValues.size(); i++) {
    data[CelValue::StringHolder(&kValues[i])] = CelValue::CreateInt64(i);
  }

  FakeStringMap cel_map(data);

  TestInMap(&cel_map, CelValue::CreateString(&kValues[0]), true);
  TestInMap(&cel_map, CelValue::CreateString(&kValues[3]), true);
  TestInMap(&cel_map, CelValue::CreateInt64(42), false);
}

TEST_F(BuiltinsTest, TestInt64Arithmetics) {
  TestArithmeticalOperationInt64(builtin::kAdd, 2, 3, 5);
  TestArithmeticalOperationInt64(builtin::kSubtract, 2, 3, -1);
  TestArithmeticalOperationInt64(builtin::kMultiply, 2, 3, 6);
  TestArithmeticalOperationInt64(builtin::kDivide, 10, 5, 2);
}

// Overflowing int64 arithmetic is expected to fail with kOutOfRange; division
// by zero is expected to fail with kInvalidArgument.
TEST_F(BuiltinsTest, TestInt64ArithmeticOverflow) {
  const int64_t min = std::numeric_limits::lowest();
  const int64_t max = std::numeric_limits::max();
  TestArithmeticalErrorInt64(builtin::kAdd, max, 1,
                             absl::StatusCode::kOutOfRange);
  TestArithmeticalErrorInt64(builtin::kSubtract, min, max,
                             absl::StatusCode::kOutOfRange);
  TestArithmeticalErrorInt64(builtin::kMultiply, max, 2,
                             absl::StatusCode::kOutOfRange);
  TestArithmeticalErrorInt64(builtin::kModulo, min, -1,
                             absl::StatusCode::kOutOfRange);
  TestArithmeticalErrorInt64(builtin::kDivide, min, -1,
                             absl::StatusCode::kOutOfRange);
  TestArithmeticalErrorInt64(builtin::kDivide, min, 0,
                             absl::StatusCode::kInvalidArgument);
}

TEST_F(BuiltinsTest, TestUint64Arithmetics) {
  TestArithmeticalOperationUint64(builtin::kAdd, 2, 3, 5);
  TestArithmeticalOperationUint64(builtin::kSubtract, 3, 2, 1);
  TestArithmeticalOperationUint64(builtin::kMultiply, 2, 3, 6);
  TestArithmeticalOperationUint64(builtin::kDivide, 10, 5, 2);
}

// Unsigned overflow and underflow are expected to fail with kOutOfRange rather
// than wrap; division by zero is expected to fail with kInvalidArgument.
TEST_F(BuiltinsTest, TestUint64ArithmeticOverflow) {
  const uint64_t max = std::numeric_limits::max();
  TestArithmeticalErrorUint64(builtin::kAdd, max, 1,
                              absl::StatusCode::kOutOfRange);
  TestArithmeticalErrorUint64(builtin::kSubtract, 2, 3,
                              absl::StatusCode::kOutOfRange);
  TestArithmeticalErrorUint64(builtin::kMultiply, max, 2,
                              absl::StatusCode::kOutOfRange);
  TestArithmeticalErrorUint64(builtin::kDivide, 1, 0,
                              absl::StatusCode::kInvalidArgument);
}

TEST_F(BuiltinsTest, TestDoubleArithmetics) {
  TestArithmeticalOperationDouble(builtin::kAdd, 2.5, 3, 5.5);
  TestArithmeticalOperationDouble(builtin::kSubtract, 2.9, 3.9, -1.);
TestArithmeticalOperationDouble(builtin::kMultiply, 2, 3, 6); TestArithmeticalOperationDouble(builtin::kDivide, 1.44, 1.2, 1.2); } TEST_F(BuiltinsTest, TestDoubleDivisionByZero) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun( builtin::kDivide, {}, {CelValue::CreateDouble(1), CelValue::CreateDouble(0)}, &result_value)); ASSERT_TRUE(result_value.IsDouble()); ASSERT_EQ(result_value.DoubleOrDie(), std::numeric_limits::infinity()); } // Test Concatenation operation for string TEST_F(BuiltinsTest, TestConcatString) { const std::string kString1 = "t1"; const std::string kString2 = "t2"; std::vector args = {CelValue::CreateString(&kString1), CelValue::CreateString(&kString2)}; CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kAdd, {}, args, &result_value)); ASSERT_TRUE(result_value.IsString()); EXPECT_EQ(result_value.StringOrDie().value(), kString1 + kString2); } // Test Concatenation operation for Bytes TEST_F(BuiltinsTest, TestConcatBytes) { const std::string kBytes1 = "t1"; const std::string kBytes2 = "t2"; std::vector args = {CelValue::CreateBytes(&kBytes1), CelValue::CreateBytes(&kBytes2)}; CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kAdd, {}, args, &result_value)); ASSERT_TRUE(result_value.IsBytes()); EXPECT_EQ(result_value.BytesOrDie().value(), kBytes1 + kBytes2); } // Test Concatenation operation for CelList TEST_F(BuiltinsTest, TestConcatList) { const std::vector kValues({5, 6, 7, 8}); const FakeList kList1( {CelValue::CreateInt64(kValues[0]), CelValue::CreateInt64(kValues[1])}); const FakeList kList2( {CelValue::CreateInt64(kValues[2]), CelValue::CreateInt64(kValues[3])}); std::vector args = {CelValue::CreateList(&kList1), CelValue::CreateList(&kList2)}; CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kAdd, {}, args, &result_value)); ASSERT_TRUE(result_value.IsList()); const CelList* result_list = result_value.ListOrDie(); ASSERT_EQ(result_list->size(), kValues.size()); for (int i = 0; i < 
result_list->size(); i++) { CelValue item = (*result_list)[i]; ASSERT_TRUE(item.IsInt64()); EXPECT_EQ(item.Int64OrDie(), kValues[i]); } } TEST_F(BuiltinsTest, MatchesPartialTrue) { std::string target = "haystack"; std::string regex = "\\w{2}ack"; std::vector args = {CelValue::CreateString(&target), CelValue::CreateString(®ex)}; CelValue result_value; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kRegexMatch, {}, args, &result_value)); ASSERT_TRUE(result_value.IsBool()); EXPECT_TRUE(result_value.BoolOrDie()); } TEST_F(BuiltinsTest, MatchesPartialFalse) { std::string target = "haystack"; std::string regex = "hy"; std::vector args = {CelValue::CreateString(&target), CelValue::CreateString(®ex)}; CelValue result_value; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kRegexMatch, {}, args, &result_value)); ASSERT_TRUE(result_value.IsBool()); EXPECT_FALSE(result_value.BoolOrDie()); } TEST_F(BuiltinsTest, MatchesPartialError) { std::string target = "haystack"; std::string invalid_regex = "("; std::vector args = {CelValue::CreateString(&target), CelValue::CreateString(&invalid_regex)}; CelValue result_value; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kRegexMatch, {}, args, &result_value)); EXPECT_TRUE(result_value.IsError()); } TEST_F(BuiltinsTest, MatchesMaxSize) { std::string target = "haystack"; std::string large_regex = "[hj][ab][yt][st][tv][ac]"; std::vector args = {CelValue::CreateString(&target), CelValue::CreateString(&large_regex)}; CelValue result_value; InterpreterOptions options; options.regex_max_program_size = 1; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kRegexMatch, {}, args, &result_value, options)); EXPECT_TRUE(result_value.IsError()); } TEST_F(BuiltinsTest, StringToIntNonInt) { std::string target = "not_a_number"; std::vector args = {CelValue::CreateString(&target)}; CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kInt, {}, args, &result_value)); ASSERT_TRUE(result_value.IsError()); } TEST_F(BuiltinsTest, IntToString) { std::vector 
args = {CelValue::CreateInt64(-42)}; CelValue result_value; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kString, {}, args, &result_value)); ASSERT_TRUE(result_value.IsString()); EXPECT_EQ(result_value.StringOrDie().value(), "-42"); } TEST_F(BuiltinsTest, UIntToString) { std::vector args = {CelValue::CreateUint64(42)}; CelValue result_value; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kString, {}, args, &result_value)); ASSERT_TRUE(result_value.IsString()); EXPECT_EQ(result_value.StringOrDie().value(), "42"); } TEST_F(BuiltinsTest, DoubleToString) { std::vector args = {CelValue::CreateDouble(37.5)}; CelValue result_value; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kString, {}, args, &result_value)); ASSERT_TRUE(result_value.IsString()); EXPECT_EQ(result_value.StringOrDie().value(), "37.5"); } TEST_F(BuiltinsTest, BytesToString) { std::string input = "abcd"; std::vector args = {CelValue::CreateBytes(&input)}; CelValue result_value; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kString, {}, args, &result_value)); ASSERT_TRUE(result_value.IsString()); EXPECT_EQ(result_value.StringOrDie().value(), "abcd"); } TEST_F(BuiltinsTest, BytesToStringInvalid) { std::string input = "\xFF"; std::vector args = {CelValue::CreateBytes(&input)}; CelValue result_value; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kString, {}, args, &result_value)); ASSERT_TRUE(result_value.IsError()); } TEST_F(BuiltinsTest, StringToString) { std::string input = "abcd"; std::vector args = {CelValue::CreateString(&input)}; CelValue result_value; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kString, {}, args, &result_value)); ASSERT_TRUE(result_value.IsString()); EXPECT_EQ(result_value.StringOrDie().value(), "abcd"); } // Type operations TEST_F(BuiltinsTest, TypeComparisons) { std::vector> paired_values; paired_values.push_back( {CelValue::CreateBool(false), CelValue::CreateBool(true)}); paired_values.push_back( {CelValue::CreateInt64(-1), CelValue::CreateInt64(1)}); paired_values.push_back( 
{CelValue::CreateUint64(1), CelValue::CreateUint64(2)}); paired_values.push_back( {CelValue::CreateDouble(1.), CelValue::CreateDouble(2.)}); std::string str1 = "test1"; std::string str2 = "test2"; paired_values.push_back( {CelValue::CreateString(&str1), CelValue::CreateString(&str2)}); paired_values.push_back( {CelValue::CreateBytes(&str1), CelValue::CreateBytes(&str2)}); FakeList cel_list1({CelValue::CreateBool(false)}); FakeList cel_list2({CelValue::CreateBool(true)}); paired_values.push_back( {CelValue::CreateList(&cel_list1), CelValue::CreateList(&cel_list2)}); std::map data1; std::map data2; FakeInt64Map cel_map1(data1); FakeInt64Map cel_map2(data2); paired_values.push_back( {CelValue::CreateMap(&cel_map1), CelValue::CreateMap(&cel_map2)}); for (size_t i = 0; i < paired_values.size(); i++) { for (size_t j = 0; j < paired_values.size(); j++) { CelValue result1; CelValue result2; PerformRun(builtin::kType, {}, {paired_values[i].first}, &result1); PerformRun(builtin::kType, {}, {paired_values[j].second}, &result2); ASSERT_TRUE(result1.IsCelType()) << "Unexpected result for value at index" << i << ":" << result1.DebugString(); ASSERT_TRUE(result2.IsCelType()) << "Unexpected result for value at index" << j << ":" << result2.DebugString(); if (i == j) { ASSERT_EQ(result1.CelTypeOrDie(), result2.CelTypeOrDie()); } else { ASSERT_TRUE(result1.CelTypeOrDie() != result2.CelTypeOrDie()); } } } } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/public/cel_attribute.cc ================================================ #include "eval/public/cel_attribute.h" #include #include #include #include #include #include "absl/strings/string_view.h" #include "eval/public/cel_value.h" namespace google::api::expr::runtime { namespace { // Visitation for attribute qualifier kinds struct QualifierVisitor { CelAttributeQualifierPattern operator()(absl::string_view v) { if (v == "*") { return 
CelAttributeQualifierPattern::CreateWildcard(); } return CelAttributeQualifierPattern::OfString(std::string(v)); } CelAttributeQualifierPattern operator()(int64_t v) { return CelAttributeQualifierPattern::OfInt(v); } CelAttributeQualifierPattern operator()(uint64_t v) { return CelAttributeQualifierPattern::OfUint(v); } CelAttributeQualifierPattern operator()(bool v) { return CelAttributeQualifierPattern::OfBool(v); } CelAttributeQualifierPattern operator()(CelAttributeQualifierPattern v) { return v; } }; } // namespace CelAttributeQualifierPattern CreateCelAttributeQualifierPattern( const CelValue& value) { switch (value.type()) { case cel::Kind::kInt64: return CelAttributeQualifierPattern::OfInt(value.Int64OrDie()); case cel::Kind::kUint64: return CelAttributeQualifierPattern::OfUint(value.Uint64OrDie()); case cel::Kind::kString: return CelAttributeQualifierPattern::OfString( std::string(value.StringOrDie().value())); case cel::Kind::kBool: return CelAttributeQualifierPattern::OfBool(value.BoolOrDie()); default: return CelAttributeQualifierPattern(CelAttributeQualifier()); } } CelAttributeQualifier CreateCelAttributeQualifier(const CelValue& value) { switch (value.type()) { case cel::Kind::kInt64: return CelAttributeQualifier::OfInt(value.Int64OrDie()); case cel::Kind::kUint64: return CelAttributeQualifier::OfUint(value.Uint64OrDie()); case cel::Kind::kString: return CelAttributeQualifier::OfString( std::string(value.StringOrDie().value())); case cel::Kind::kBool: return CelAttributeQualifier::OfBool(value.BoolOrDie()); default: return CelAttributeQualifier(); } } CelAttributePattern CreateCelAttributePattern( absl::string_view variable, std::initializer_list> path_spec) { std::vector path; path.reserve(path_spec.size()); for (const auto& spec_elem : path_spec) { path.emplace_back(absl::visit(QualifierVisitor(), spec_elem)); } return CelAttributePattern(std::string(variable), std::move(path)); } } // namespace google::api::expr::runtime 
================================================ FILE: eval/public/cel_attribute.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_ATTRIBUTE_PATTERN_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_ATTRIBUTE_PATTERN_H_ #include #include #include #include #include #include #include #include #include #include #include #include "cel/expr/syntax.pb.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "absl/types/optional.h" #include "absl/types/variant.h" #include "base/attribute.h" #include "eval/public/cel_value.h" namespace google::api::expr::runtime { // CelAttributeQualifier represents a segment in // attribute resolutuion path. A segment can be qualified by values of // following types: string/int64_t/uint64_t/bool. using CelAttributeQualifier = ::cel::AttributeQualifier; // CelAttribute represents resolved attribute path. using CelAttribute = ::cel::Attribute; // CelAttributeQualifierPattern matches a segment in // attribute resolutuion path. CelAttributeQualifierPattern is capable of // matching path elements of types string/int64_t/uint64_t/bool. using CelAttributeQualifierPattern = ::cel::AttributeQualifierPattern; // CelAttributePattern is a fully-qualified absolute attribute path pattern. // Supported segments steps in the path are: // - field selection; // - map lookup by key; // - list access by index. using CelAttributePattern = ::cel::AttributePattern; CelAttributeQualifierPattern CreateCelAttributeQualifierPattern( const CelValue& value); CelAttributeQualifier CreateCelAttributeQualifier(const CelValue& value); // Short-hand helper for creating |CelAttributePattern|s. string_view arguments // must outlive the returned pattern. 
CelAttributePattern CreateCelAttributePattern( absl::string_view variable, std::initializer_list> path_spec = {}); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_ATTRIBUTE_PATTERN_H_ ================================================ FILE: eval/public/cel_attribute_test.cc ================================================ #include "eval/public/cel_attribute.h" #include #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "internal/testing.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { using cel::expr::Expr; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::testing::Eq; using ::testing::IsEmpty; using ::testing::SizeIs; class DummyMap : public CelMap { public: absl::optional operator[](CelValue value) const override { return CelValue::CreateNull(); } absl::StatusOr ListKeys() const override { return absl::UnimplementedError("CelMap::ListKeys is not implemented"); } int size() const override { return 0; } }; class DummyList : public CelList { public: int size() const override { return 0; } CelValue operator[](int index) const override { return CelValue::CreateNull(); } }; TEST(CelAttributeQualifierTest, TestBoolAccess) { auto qualifier = CreateCelAttributeQualifier(CelValue::CreateBool(true)); EXPECT_FALSE(qualifier.GetStringKey().has_value()); EXPECT_FALSE(qualifier.GetInt64Key().has_value()); EXPECT_FALSE(qualifier.GetUint64Key().has_value()); EXPECT_TRUE(qualifier.GetBoolKey().has_value()); EXPECT_THAT(qualifier.GetBoolKey().value(), Eq(true)); EXPECT_THAT(qualifier.AsString(), IsOkAndHolds("true")); } TEST(CelAttributeQualifierTest, TestInt64Access) { auto qualifier = CreateCelAttributeQualifier(CelValue::CreateInt64(-1)); 
EXPECT_FALSE(qualifier.GetBoolKey().has_value()); EXPECT_FALSE(qualifier.GetStringKey().has_value()); EXPECT_FALSE(qualifier.GetUint64Key().has_value()); EXPECT_TRUE(qualifier.GetInt64Key().has_value()); EXPECT_THAT(qualifier.GetInt64Key().value(), Eq(-1)); EXPECT_THAT(qualifier.AsString(), IsOkAndHolds("-1")); } TEST(CelAttributeQualifierTest, TestUint64Access) { auto qualifier = CreateCelAttributeQualifier(CelValue::CreateUint64(1)); EXPECT_FALSE(qualifier.GetBoolKey().has_value()); EXPECT_FALSE(qualifier.GetStringKey().has_value()); EXPECT_FALSE(qualifier.GetInt64Key().has_value()); EXPECT_TRUE(qualifier.GetUint64Key().has_value()); EXPECT_THAT(qualifier.GetUint64Key().value(), Eq(1UL)); EXPECT_THAT(qualifier.AsString(), IsOkAndHolds("1")); } TEST(CelAttributeQualifierTest, TestStringAccess) { const std::string test = "test"; auto qualifier = CreateCelAttributeQualifier(CelValue::CreateString(&test)); EXPECT_FALSE(qualifier.GetBoolKey().has_value()); EXPECT_FALSE(qualifier.GetInt64Key().has_value()); EXPECT_FALSE(qualifier.GetUint64Key().has_value()); EXPECT_TRUE(qualifier.GetStringKey().has_value()); EXPECT_THAT(qualifier.GetStringKey().value(), Eq("test")); EXPECT_THAT(qualifier.AsString(), IsOkAndHolds("test")); } void TestAllInequalities(const CelAttributeQualifier& qualifier) { EXPECT_FALSE(qualifier == CreateCelAttributeQualifier(CelValue::CreateBool(false))); EXPECT_FALSE(qualifier == CreateCelAttributeQualifier(CelValue::CreateInt64(0))); EXPECT_FALSE(qualifier == CreateCelAttributeQualifier(CelValue::CreateUint64(0))); const std::string test = "Those are not the droids you are looking for."; EXPECT_FALSE(qualifier == CreateCelAttributeQualifier(CelValue::CreateString(&test))); } TEST(CelAttributeQualifierTest, TestBoolComparison) { auto qualifier = CreateCelAttributeQualifier(CelValue::CreateBool(true)); TestAllInequalities(qualifier); EXPECT_TRUE(qualifier == CreateCelAttributeQualifier(CelValue::CreateBool(true))); } TEST(CelAttributeQualifierTest, 
TestInt64Comparison) { auto qualifier = CreateCelAttributeQualifier(CelValue::CreateInt64(true)); TestAllInequalities(qualifier); EXPECT_TRUE(qualifier == CreateCelAttributeQualifier(CelValue::CreateInt64(true))); } TEST(CelAttributeQualifierTest, TestUint64Comparison) { auto qualifier = CreateCelAttributeQualifier(CelValue::CreateUint64(true)); TestAllInequalities(qualifier); EXPECT_TRUE(qualifier == CreateCelAttributeQualifier(CelValue::CreateUint64(true))); } TEST(CelAttributeQualifierTest, TestStringComparison) { const std::string kTest = "test"; auto qualifier = CreateCelAttributeQualifier(CelValue::CreateString(&kTest)); TestAllInequalities(qualifier); EXPECT_TRUE(qualifier == CreateCelAttributeQualifier(CelValue::CreateString(&kTest))); } void TestAllQualifierMismatches(const CelAttributeQualifierPattern& qualifier) { const std::string test = "Those are not the droids you are looking for."; EXPECT_FALSE(qualifier.IsMatch( CreateCelAttributeQualifier(CelValue::CreateBool(false)))); EXPECT_FALSE( qualifier.IsMatch(CreateCelAttributeQualifier(CelValue::CreateInt64(0)))); EXPECT_FALSE(qualifier.IsMatch( CreateCelAttributeQualifier(CelValue::CreateUint64(0)))); EXPECT_FALSE(qualifier.IsMatch( CreateCelAttributeQualifier(CelValue::CreateString(&test)))); } TEST(CelAttributeQualifierPatternTest, TestQualifierBoolMatch) { auto qualifier = CreateCelAttributeQualifierPattern(CelValue::CreateBool(true)); TestAllQualifierMismatches(qualifier); EXPECT_TRUE(qualifier.IsMatch( CreateCelAttributeQualifier(CelValue::CreateBool(true)))); } TEST(CelAttributeQualifierPatternTest, TestQualifierInt64Match) { auto qualifier = CreateCelAttributeQualifierPattern(CelValue::CreateInt64(1)); TestAllQualifierMismatches(qualifier); EXPECT_TRUE( qualifier.IsMatch(CreateCelAttributeQualifier(CelValue::CreateInt64(1)))); } TEST(CelAttributeQualifierPatternTest, TestQualifierUint64Match) { auto qualifier = CreateCelAttributeQualifierPattern(CelValue::CreateUint64(1)); 
TestAllQualifierMismatches(qualifier); EXPECT_TRUE(qualifier.IsMatch( CreateCelAttributeQualifier(CelValue::CreateUint64(1)))); } TEST(CelAttributeQualifierPatternTest, TestQualifierStringMatch) { const std::string test = "test"; auto qualifier = CreateCelAttributeQualifierPattern(CelValue::CreateString(&test)); TestAllQualifierMismatches(qualifier); EXPECT_TRUE(qualifier.IsMatch( CreateCelAttributeQualifier(CelValue::CreateString(&test)))); } TEST(CelAttributeQualifierPatternTest, TestQualifierWildcardMatch) { auto qualifier = CelAttributeQualifierPattern::CreateWildcard(); EXPECT_TRUE(qualifier.IsMatch( CreateCelAttributeQualifier(CelValue::CreateBool(false)))); EXPECT_TRUE(qualifier.IsMatch( CreateCelAttributeQualifier(CelValue::CreateBool(true)))); EXPECT_TRUE( qualifier.IsMatch(CreateCelAttributeQualifier(CelValue::CreateInt64(0)))); EXPECT_TRUE( qualifier.IsMatch(CreateCelAttributeQualifier(CelValue::CreateInt64(1)))); EXPECT_TRUE(qualifier.IsMatch( CreateCelAttributeQualifier(CelValue::CreateUint64(0)))); EXPECT_TRUE(qualifier.IsMatch( CreateCelAttributeQualifier(CelValue::CreateUint64(1)))); const std::string kTest1 = "test1"; const std::string kTest2 = "test2"; EXPECT_TRUE(qualifier.IsMatch( CreateCelAttributeQualifier(CelValue::CreateString(&kTest1)))); EXPECT_TRUE(qualifier.IsMatch( CreateCelAttributeQualifier(CelValue::CreateString(&kTest2)))); } TEST(CreateCelAttributePattern, Basic) { const std::string kTest = "def"; CelAttributePattern pattern = CreateCelAttributePattern( "abc", {kTest, static_cast(1), static_cast(-1), false, CelAttributeQualifierPattern::CreateWildcard()}); EXPECT_THAT(pattern.variable(), Eq("abc")); ASSERT_THAT(pattern.qualifier_path(), SizeIs(5)); EXPECT_TRUE(pattern.qualifier_path()[4].IsWildcard()); } TEST(CreateCelAttributePattern, EmptyPath) { CelAttributePattern pattern = CreateCelAttributePattern("abc"); EXPECT_THAT(pattern.variable(), Eq("abc")); EXPECT_THAT(pattern.qualifier_path(), IsEmpty()); } 
TEST(CreateCelAttributePattern, Wildcards) { const std::string kTest = "*"; CelAttributePattern pattern = CreateCelAttributePattern( "abc", {kTest, "false", CelAttributeQualifierPattern::CreateWildcard()}); EXPECT_THAT(pattern.variable(), Eq("abc")); ASSERT_THAT(pattern.qualifier_path(), SizeIs(3)); EXPECT_TRUE(pattern.qualifier_path()[0].IsWildcard()); EXPECT_FALSE(pattern.qualifier_path()[1].IsWildcard()); EXPECT_TRUE(pattern.qualifier_path()[2].IsWildcard()); } TEST(CelAttribute, AsStringBasic) { CelAttribute attr( "var", { CreateCelAttributeQualifier(CelValue::CreateStringView("qual1")), CreateCelAttributeQualifier(CelValue::CreateStringView("qual2")), CreateCelAttributeQualifier(CelValue::CreateStringView("qual3")), }); ASSERT_OK_AND_ASSIGN(std::string string_format, attr.AsString()); EXPECT_EQ(string_format, "var.qual1.qual2.qual3"); } TEST(CelAttribute, AsStringInvalidRoot) { CelAttribute attr( "", { CreateCelAttributeQualifier(CelValue::CreateStringView("qual1")), CreateCelAttributeQualifier(CelValue::CreateStringView("qual2")), CreateCelAttributeQualifier(CelValue::CreateStringView("qual3")), }); EXPECT_EQ(attr.AsString().status().code(), absl::StatusCode::kInvalidArgument); } TEST(CelAttribute, InvalidQualifiers) { Expr expr; expr.mutable_ident_expr()->set_name("var"); google::protobuf::Arena arena; CelAttribute attr1("var", { CreateCelAttributeQualifier( CelValue::CreateDuration(absl::Minutes(2))), }); CelAttribute attr2("var", { CreateCelAttributeQualifier( CelProtoWrapper::CreateMessage(&expr, &arena)), }); CelAttribute attr3( "var", { CreateCelAttributeQualifier(CelValue::CreateBool(false)), }); // Implementation detail: Messages as attribute qualifiers are unsupported, // so the implementation treats them inequal to any other. This is included // for coverage. 
EXPECT_FALSE(attr1 == attr2); EXPECT_FALSE(attr2 == attr1); EXPECT_FALSE(attr2 == attr2); EXPECT_FALSE(attr1 == attr3); EXPECT_FALSE(attr3 == attr1); EXPECT_FALSE(attr2 == attr3); EXPECT_FALSE(attr3 == attr2); // If the attribute includes an unsupported qualifier, return invalid argument // error. EXPECT_THAT(attr1.AsString(), StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT(attr2.AsString(), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(CelAttribute, AsStringQualiferTypes) { CelAttribute attr( "var", { CreateCelAttributeQualifier(CelValue::CreateStringView("qual1")), CreateCelAttributeQualifier(CelValue::CreateUint64(1)), CreateCelAttributeQualifier(CelValue::CreateInt64(-1)), CreateCelAttributeQualifier(CelValue::CreateBool(false)), }); ASSERT_OK_AND_ASSIGN(std::string string_format, attr.AsString()); EXPECT_EQ(string_format, "var.qual1[1][-1][false]"); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/public/cel_builtins.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_BUILTINS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_BUILTINS_H_ #include "base/builtins.h" namespace google { namespace api { namespace expr { namespace runtime { // Alias new namespace until external CEL users can be updated. namespace builtin = cel::builtin; } // namespace runtime } // namespace expr } // namespace api } // namespace google #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_BUILTINS_H_ ================================================ FILE: eval/public/cel_expr_builder_factory.cc ================================================ /* * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "eval/public/cel_expr_builder_factory.h" #include #include #include "absl/base/nullability.h" #include "absl/log/absl_log.h" #include "absl/status/status.h" #include "common/kind.h" #include "common/memory.h" #include "eval/compiler/cel_expression_builder_flat_impl.h" #include "eval/compiler/comprehension_vulnerability_check.h" #include "eval/compiler/constant_folding.h" #include "eval/compiler/flat_expr_builder.h" #include "eval/compiler/qualified_reference_resolver.h" #include "eval/compiler/regex_precompilation_optimization.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_function.h" #include "eval/public/cel_options.h" #include "extensions/select_optimization.h" #include "internal/noop_delete.h" #include "runtime/internal/runtime_env.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace google::api::expr::runtime { namespace { using ::cel::MemoryManagerRef; using ::cel::extensions::CreateSelectOptimizationProgramOptimizer; using ::cel::extensions::kCelAttribute; using ::cel::extensions::kCelHasField; using ::cel::extensions::SelectOptimizationAstUpdater; using ::cel::runtime_internal::CreateConstantFoldingOptimizer; using ::cel::runtime_internal::RuntimeEnv; } // namespace std::unique_ptr CreateCelExpressionBuilder( const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::MessageFactory* message_factory, const InterpreterOptions& options) { if (descriptor_pool == nullptr) { ABSL_LOG(ERROR) << "Cannot pass 
nullptr as descriptor pool to " "CreateCelExpressionBuilder"; return nullptr; } cel::RuntimeOptions runtime_options = ConvertToRuntimeOptions(options); absl_nullable std::shared_ptr shared_message_factory; if (message_factory != nullptr) { shared_message_factory = std::shared_ptr( message_factory, cel::internal::NoopDeleteFor()); } auto env = std::make_shared( std::shared_ptr( descriptor_pool, cel::internal::NoopDeleteFor()), shared_message_factory); if (auto status = env->Initialize(); !status.ok()) { ABSL_LOG(ERROR) << "Failed to validate standard message types: " << status.ToString(); // NOLINT: OSS compatibility return nullptr; } auto builder = std::make_unique( std::move(env), runtime_options); FlatExprBuilder& flat_expr_builder = builder->flat_expr_builder(); flat_expr_builder.AddAstTransform(NewReferenceResolverExtension( (options.enable_qualified_identifier_rewrites) ? ReferenceResolverOption::kAlways : ReferenceResolverOption::kCheckedOnly)); if (options.enable_comprehension_vulnerability_check) { builder->flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); } if (options.constant_folding) { std::shared_ptr shared_arena; if (options.constant_arena != nullptr) { shared_arena = std::shared_ptr( options.constant_arena, cel::internal::NoopDeleteFor()); } builder->flat_expr_builder().AddProgramOptimizer( CreateConstantFoldingOptimizer(std::move(shared_arena), std::move(shared_message_factory))); } if (options.enable_regex_precompilation) { flat_expr_builder.AddProgramOptimizer( CreateRegexPrecompilationExtension(options.regex_max_program_size)); } if (options.enable_select_optimization) { // Add AST transform to update select branches on a stored // CheckedExpression. This may already be performed by a type checker. flat_expr_builder.AddAstTransform( std::make_unique()); // Add overloads for select optimization signature. // These are never bound, only used to prevent the builder from failing on // the overloads check. 
absl::Status status = builder->GetRegistry()->RegisterLazyFunction(CelFunctionDescriptor( kCelAttribute, false, {cel::Kind::kAny, cel::Kind::kList})); if (!status.ok()) { ABSL_LOG(ERROR) << "Failed to register " << kCelAttribute << ": " << status; } status = builder->GetRegistry()->RegisterLazyFunction(CelFunctionDescriptor( kCelHasField, false, {cel::Kind::kAny, cel::Kind::kList})); if (!status.ok()) { ABSL_LOG(ERROR) << "Failed to register " << kCelHasField << ": " << status; } // Add runtime implementation. flat_expr_builder.AddProgramOptimizer( CreateSelectOptimizationProgramOptimizer()); } return builder; } } // namespace google::api::expr::runtime ================================================ FILE: eval/public/cel_expr_builder_factory.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPR_BUILDER_FACTORY_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPR_BUILDER_FACTORY_H_ #include #include "absl/base/attributes.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_options.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace google { namespace api { namespace expr { namespace runtime { // Factory creates CelExpressionBuilder implementation for public use. std::unique_ptr CreateCelExpressionBuilder( const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::MessageFactory* message_factory, const InterpreterOptions& options = InterpreterOptions()); ABSL_DEPRECATED( "This overload uses the generated descriptor pool, which allows " "expressions to create any messages linked into the binary. This is not " "hermetic and potentially dangerous, you should select the descriptor pool " "carefully. Use the other overload and explicitly pass your descriptor " "pool. It can still be the generated descriptor pool, but the choice " "should be explicit. 
If you do not need struct creation, use " "`cel::GetMinimalDescriptorPool()`.") inline std::unique_ptr CreateCelExpressionBuilder( const InterpreterOptions& options = InterpreterOptions()) { return CreateCelExpressionBuilder(google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), options); } } // namespace runtime } // namespace expr } // namespace api } // namespace google #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPR_BUILDER_FACTORY_H_ ================================================ FILE: eval/public/cel_expression.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPRESSION_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPRESSION_H_ #include #include #include #include #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "eval/public/base_activation.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_type_registry.h" #include "eval/public/cel_value.h" namespace google::api::expr::runtime { // CelEvaluationListener is the callback that is passed to (and called by) // CelExpression::Trace. It gets an expression node ID from the original // expression, its value and the arena object. If an expression node // is evaluated multiple times (e.g. as a part of Comprehension.loop_step) // then the order of the callback invocations is guaranteed to correspond // the order of variable sub-elements (e.g. the order of elements returned // by Comprehension.iter_range). using CelEvaluationListener = std::function; // An opaque state used for evaluation of a CEL expression. class CelEvaluationState { public: virtual ~CelEvaluationState() = default; }; // Base interface for expression evaluating objects. 
class CelExpression { public: virtual ~CelExpression() = default; // Initializes the state virtual std::unique_ptr InitializeState( google::protobuf::Arena* arena) const = 0; // Evaluates expression and returns value. // activation contains bindings from parameter names to values // arena parameter specifies Arena object where output result and // internal data will be allocated. virtual absl::StatusOr Evaluate(const BaseActivation& activation, google::protobuf::Arena* arena) const = 0; // Evaluates expression and returns value. // activation contains bindings from parameter names to values // state must be non-null and created prior to calling Evaluate by // InitializeState. virtual absl::StatusOr Evaluate( const BaseActivation& activation, CelEvaluationState* state) const = 0; // Trace evaluates expression calling the callback on each sub-tree. virtual absl::StatusOr Trace( const BaseActivation& activation, google::protobuf::Arena* arena, CelEvaluationListener callback) const = 0; // Trace evaluates expression calling the callback on each sub-tree. // state must be non-null and created prior to calling Evaluate by // InitializeState. virtual absl::StatusOr Trace( const BaseActivation& activation, CelEvaluationState* state, CelEvaluationListener callback) const = 0; }; // Base class for Expression Builder implementations // Provides user with factory to register extension functions. // ExpressionBuilder MUST NOT be destroyed before CelExpression objects // it built. class CelExpressionBuilder { public: CelExpressionBuilder() = default; virtual ~CelExpressionBuilder() = default; // Creates CelExpression object from AST tree. // expr specifies root of AST tree. // Method implementation is expected to create copies of expr and source_info, // so that the returned CelExpression is not dependent on the lifetime of // the input arguments. 
virtual absl::StatusOr> CreateExpression( const cel::expr::Expr* expr, const cel::expr::SourceInfo* source_info) const = 0; // Creates CelExpression object from AST tree. // expr specifies root of AST tree. // non-fatal build warnings are written to warnings if encountered. // Method implementation is expected to create copies of expr and source_info, // so that the returned CelExpression is not dependent on the lifetime of // the input arguments. virtual absl::StatusOr> CreateExpression( const cel::expr::Expr* expr, const cel::expr::SourceInfo* source_info, std::vector* warnings) const = 0; // Creates CelExpression object from a checked expression. // This includes an AST, source info, type hints and ident hints. // Method implementation is expected to create copy of checked_expr, // so that the returned CelExpression is not dependent on the lifetime of // the input arguments. virtual absl::StatusOr> CreateExpression( const cel::expr::CheckedExpr* checked_expr) const { // Default implementation just passes through the expr and source info. return CreateExpression(&checked_expr->expr(), &checked_expr->source_info()); } // Creates CelExpression object from a checked expression. // This includes an AST, source info, type hints and ident hints. // non-fatal build warnings are written to warnings if encountered. // Method implementation is expected to create copy of checked_expr, // so that the returned CelExpression is not dependent on the lifetime of // the input arguments. virtual absl::StatusOr> CreateExpression( const cel::expr::CheckedExpr* checked_expr, std::vector* warnings) const { // Default implementation just passes through the expr and source_info. return CreateExpression(&checked_expr->expr(), &checked_expr->source_info(), warnings); } // CelFunction registry. Extension function should be registered with it // prior to expression creation. virtual CelFunctionRegistry* GetRegistry() const = 0; // CEL Type registry. 
Provides a means to resolve the CEL built-in types to // CelValue instances, and to extend the set of types and enums known to // expressions by registering them ahead of time. virtual CelTypeRegistry* GetTypeRegistry() const = 0; virtual void set_container(std::string container) = 0; virtual absl::string_view container() const = 0; }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPRESSION_H_ ================================================ FILE: eval/public/cel_function.cc ================================================ #include "eval/public/cel_function.h" #include #include #include "absl/status/statusor.h" #include "absl/types/span.h" #include "common/value.h" #include "eval/internal/interop.h" #include "eval/public/cel_value.h" #include "internal/status_macros.h" #include "runtime/function.h" namespace google::api::expr::runtime { using ::cel::Value; using ::cel::interop_internal::ToLegacyValue; bool CelFunction::MatchArguments(absl::Span arguments) const { auto types_size = descriptor().types().size(); if (types_size != arguments.size()) { return false; } for (size_t i = 0; i < types_size; i++) { const auto& value = arguments[i]; CelValue::Type arg_type = descriptor().types()[i]; if (value.type() != arg_type && arg_type != CelValue::Type::kAny) { return false; } } return true; } bool CelFunction::MatchArguments(absl::Span arguments) const { auto types_size = descriptor().types().size(); if (types_size != arguments.size()) { return false; } for (size_t i = 0; i < types_size; i++) { const auto& value = arguments[i]; CelValue::Type arg_type = descriptor().types()[i]; if (value->kind() != arg_type && arg_type != CelValue::Type::kAny) { return false; } } return true; } absl::StatusOr CelFunction::Invoke( absl::Span arguments, const cel::Function::InvokeContext& context) const { std::vector legacy_args; legacy_args.reserve(arguments.size()); // Users shouldn't be able to create expressions that call registered // 
functions with unconvertible types, but it's possible to create an AST that // can trigger this by making an unexpected call on a value that the // interpreter expects to only be used with internal program steps. for (const auto& arg : arguments) { CEL_ASSIGN_OR_RETURN(legacy_args.emplace_back(), ToLegacyValue(context.arena(), arg, true)); } CelValue legacy_result; CEL_RETURN_IF_ERROR(Evaluate(legacy_args, &legacy_result, context.arena())); return cel::interop_internal::LegacyValueToModernValueOrDie( context.arena(), legacy_result, /*unchecked=*/true); } } // namespace google::api::expr::runtime ================================================ FILE: eval/public/cel_function.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_H_ #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "common/function_descriptor.h" #include "common/value.h" #include "eval/public/cel_value.h" #include "runtime/function.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { // Type that describes CelFunction. // This complex structure is needed for overloads support. using CelFunctionDescriptor = ::cel::FunctionDescriptor; // CelFunction is a handler that represents single // CEL function. // CelFunction provides Evaluate() method, that performs // evaluation of the function. CelFunction instances provide // descriptors that contain function information: // - name // - is function receiver style (e.f(g) vs f(e,g)) // - amount of arguments and their types. // Function overloads are resolved based on their arguments and // receiver style. 
class CelFunction : public ::cel::Function { public: // Build CelFunction from descriptor explicit CelFunction(CelFunctionDescriptor descriptor) : descriptor_(std::move(descriptor)) {} // Non-copyable CelFunction(const CelFunction& other) = delete; CelFunction& operator=(const CelFunction& other) = delete; ~CelFunction() override = default; // Evaluates CelValue based on arguments supplied. // If result content is to be allocated (e.g. string concatenation), // arena parameter must be used as allocation manager. // Provides resulting value in *result, returns evaluation success/failure. // Methods should discriminate between internal evaluator errors, that // makes further evaluation impossible or unreasonable (example - argument // type or number mismatch) and business logic errors (example division by // zero). When former happens, error Status is returned and *result is // not changed. In case of business logic error, returned Status is Ok, and // error is provided as CelValue - wrapped CelError in *result. virtual absl::Status Evaluate(absl::Span arguments, CelValue* result, google::protobuf::Arena* arena) const = 0; // Determines whether instance of CelFunction is applicable to // arguments supplied. // Method is called during runtime. bool MatchArguments(absl::Span arguments) const; bool MatchArguments(absl::Span arguments) const; // Implements cel::Function. 
using cel::Function::Invoke; absl::StatusOr Invoke( absl::Span arguments, const cel::Function::InvokeContext& context) const final; // CelFunction descriptor const CelFunctionDescriptor& descriptor() const { return descriptor_; } private: CelFunctionDescriptor descriptor_; }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_H_ ================================================ FILE: eval/public/cel_function_adapter.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_H_ #include #include #include #include #include #include "absl/status/status.h" #include "eval/public/cel_function_adapter_impl.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "google/protobuf/message.h" namespace google::api::expr::runtime { namespace internal { // A type code matcher that adds support for google::protobuf::Message. struct ProtoAdapterTypeCodeMatcher { template constexpr static std::optional type_code() { if constexpr (std::is_same_v) { return CelValue::Type::kMessage; } else { return internal::TypeCodeMatcher().type_code(); } } }; // A value converter that handles wrapping google::protobuf::Messages as CelValues. struct ProtoAdapterValueConverter : public internal::ValueConverterBase { using BaseType = internal::ValueConverterBase; using BaseType::NativeToValue; using BaseType::ValueToNative; absl::Status NativeToValue(const ::google::protobuf::Message* value, ::google::protobuf::Arena* arena, CelValue* result) { if (value == nullptr) { return absl::Status(absl::StatusCode::kInvalidArgument, "Null Message pointer returned"); } *result = CelProtoWrapper::CreateMessage(value, arena); return absl::OkStatus(); } }; } // namespace internal // FunctionAdapter is a helper class that simplifies creation of CelFunction // implementations. 
// // The static Create member function accepts CelFunction::Evaluate method // implementations as std::function, allowing them to be lambdas/regular C++ // functions. CEL method descriptors are deduced from the C++ function signatures. // // The adapted CelFunction::Evaluate implementation will set result to the // value returned by the handler. To handle errors, choose CelValue as the // return type, and use the CreateError/Create* helpers in cel_value.h. // // The wrapped std::function may return absl::StatusOr. If the wrapped // function returns the absl::Status variant, the generated CelFunction // implementation will return a non-ok status code, rather than a CelError // wrapping an absl::Status value. A returned non-ok status indicates a hard // error, meaning the interpreter cannot reasonably continue evaluation (e.g. // data corruption or broken invariant). To create a CelError that follows // logical pruning rules, the extension function implementation should return a // CelError or an error-typed CelValue.
// // FunctionAdapter // ReturnType: the C++ return type of the function implementation // Arguments: the C++ Argument type of the function implementation // // Static Methods: // // Create(absl::string_view function_name, bool receiver_style, // FunctionType func) -> absl::StatusOr> // // Usage example: // // auto func = [](::google::protobuf::Arena* arena, int64_t i, int64_t j) -> bool { // return i < j; // }; // // CEL_ASSIGN_OR_RETURN(auto cel_func, // FunctionAdapter::Create("<", false, func)); // // CreateAndRegister(absl::string_view function_name, bool receiver_style, // FunctionType func, CelFunctionRegistry* registry) // -> absl::Status // // Usage example: // // auto func = [](::google::protobuf::Arena* arena, int64_t i, int64_t j) -> bool { // return i < j; // }; // // CEL_RETURN_IF_ERROR(( // FunctionAdapter::CreateAndRegister("<", false, // func, cel_expression_builder->GetRegistry())); // template using FunctionAdapter = internal::FunctionAdapterImpl:: FunctionAdapter; template using UnaryFunctionAdapter = internal::FunctionAdapterImpl< internal::ProtoAdapterTypeCodeMatcher, internal::ProtoAdapterValueConverter>::UnaryFunction; template using BinaryFunctionAdapter = internal::FunctionAdapterImpl< internal::ProtoAdapterTypeCodeMatcher, internal::ProtoAdapterValueConverter>::BinaryFunction; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_H_ ================================================ FILE: eval/public/cel_function_adapter_impl.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_IMPL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_IMPL_H_ #include #include #include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_value.h" #include "internal/status_macros.h" // NOTE(review): CEL_CPP_DISABLE_PARTIAL_SPECIALIZATION selects the fallback // (std::function-chaining) RunWrap path in FunctionAdapter; presumably GCC is // excluded because it rejects the in-class `template <>` RunWrap // specialization -- confirm. #if defined(__clang__) || !defined(__GNUC__) // Do not disable. #else #define CEL_CPP_DISABLE_PARTIAL_SPECIALIZATION 1 #endif namespace google::api::expr::runtime { namespace internal { // TypeCodeMatcher template helper. // Used for CEL type deduction based on C++ native type. struct TypeCodeMatcher { template constexpr static std::optional type_code() { if constexpr (std::is_same_v) { // A bit of a trick - to pass Any kind of value, we use generic CelValue // parameters. return CelValue::Type::kAny; } else { int index = CelValue::IndexOf::value; if (index < 0) return {}; CelValue::Type arg_type = static_cast(index); if (arg_type >= CelValue::Type::kAny) { return {}; } return arg_type; } } }; // Template helper to construct an argument list for a CelFunctionDescriptor.
template struct TypeAdder { template bool AddType(std::vector* arg_types) const { auto kind = TypeCodeMatcher().template type_code(); if (!kind) { return false; } arg_types->push_back(*kind); return AddType(arg_types); } template bool AddType(std::vector* arg_types) const { return true; } }; // Template helper for C++ types to CEL conversions. // Uses CRTP to dispatch to derived class overloads in the StatusOr helper. template struct ValueConverterBase { // Value to native unwraps a CelValue to a native type. template bool ValueToNative(CelValue value, T* result) { if constexpr (std::is_same_v) { *result = std::move(value); return true; } else { return value.GetValue(result); } } // Native to value wraps a native return type to a CelValue. absl::Status NativeToValue(bool value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateBool(value); return absl::OkStatus(); } absl::Status NativeToValue(int64_t value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateInt64(value); return absl::OkStatus(); } absl::Status NativeToValue(uint64_t value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateUint64(value); return absl::OkStatus(); } absl::Status NativeToValue(double value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateDouble(value); return absl::OkStatus(); } absl::Status NativeToValue(CelValue::StringHolder value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateString(value); return absl::OkStatus(); } absl::Status NativeToValue(CelValue::BytesHolder value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateBytes(value); return absl::OkStatus(); } absl::Status NativeToValue(const CelList* value, ::google::protobuf::Arena*, CelValue* result) { if (value == nullptr) { return absl::Status(absl::StatusCode::kInvalidArgument, "Null CelList pointer returned"); } *result = CelValue::CreateList(value); return
absl::OkStatus(); } absl::Status NativeToValue(const CelMap* value, ::google::protobuf::Arena*, CelValue* result) { if (value == nullptr) { return absl::Status(absl::StatusCode::kInvalidArgument, "Null CelMap pointer returned"); } *result = CelValue::CreateMap(value); return absl::OkStatus(); } absl::Status NativeToValue(CelValue::CelTypeHolder value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateCelType(value); return absl::OkStatus(); } absl::Status NativeToValue(const CelError* value, ::google::protobuf::Arena*, CelValue* result) { if (value == nullptr) { return absl::Status(absl::StatusCode::kInvalidArgument, "Null CelError pointer returned"); } *result = CelValue::CreateError(value); return absl::OkStatus(); } // Special case -- just forward a CelValue. absl::Status NativeToValue(const CelValue& value, ::google::protobuf::Arena*, CelValue* result) { *result = value; return absl::OkStatus(); } template absl::Status NativeToValue(absl::StatusOr value, ::google::protobuf::Arena* arena, CelValue* result) { CEL_ASSIGN_OR_RETURN(auto held_value, value); return Derived().NativeToValue(held_value, arena, result); } }; struct ValueConverter : public ValueConverterBase {}; // Generalized implementation for function adapters. See comments on // instantiated versions for details on usage. // // TypeCodeMatcher provides the mapping from C++ type to CEL type. // ValueConverter provides value conversions from native to CEL and vice versa. // ReturnType and Arguments types are instantiated for the particular shape of // the adapted functions. template class FunctionAdapterImpl { public: // Implementations for the common cases of unary and binary functions. // This reduces the binary size substantially over the generic templated // versions. 
template class BinaryFunction : public CelFunction { public: using FuncType = std::function; static std::unique_ptr Create(absl::string_view name, bool receiver_style, FuncType handler) { constexpr auto arg1_type = TypeCodeMatcher::template type_code(); static_assert(arg1_type.has_value(), "T does not map to a CEL type."); constexpr auto arg2_type = TypeCodeMatcher::template type_code(); static_assert(arg2_type.has_value(), "U does not map to a CEL type."); std::vector arg_types{*arg1_type, *arg2_type}; return absl::WrapUnique(new BinaryFunction( CelFunctionDescriptor(name, receiver_style, std::move(arg_types)), std::move(handler))); } absl::Status Evaluate(absl::Span arguments, CelValue* result, google::protobuf::Arena* arena) const override { if (arguments.size() != 2) { return absl::InternalError("Argument number mismatch, expected 2"); } // NOTE(review): ValueToNative converts CEL -> C++; the message below says // the opposite direction but is kept verbatim in case callers match on it. T arg; if (!ValueConverter().ValueToNative(arguments[0], &arg)) { return absl::InternalError("C++ to CEL type conversion failed"); } U arg2; if (!ValueConverter().ValueToNative(arguments[1], &arg2)) { return absl::InternalError("C++ to CEL type conversion failed"); } ReturnType handlerResult = handler_(arena, arg, arg2); return ValueConverter().NativeToValue(handlerResult, arena, result); } private: // The descriptor is taken by value, so move it into the base instead of // copying it (matches the generic FunctionAdapter constructor). BinaryFunction(CelFunctionDescriptor descriptor, FuncType handler) : CelFunction(std::move(descriptor)), handler_(std::move(handler)) {} FuncType handler_; }; template class UnaryFunction : public CelFunction { public: using FuncType = std::function; static std::unique_ptr Create(absl::string_view name, bool receiver_style, FuncType handler) { constexpr auto arg_type = TypeCodeMatcher::template type_code(); static_assert(arg_type.has_value(), "T does not map to a CEL type."); std::vector arg_types{*arg_type}; return absl::WrapUnique(new UnaryFunction( CelFunctionDescriptor(name, receiver_style, std::move(arg_types)), std::move(handler))); } absl::Status Evaluate(absl::Span arguments, CelValue* result, google::protobuf::Arena* arena) const
override { if (arguments.size() != 1) { return absl::InternalError("Argument number mismatch, expected 1"); } T arg; if (!ValueConverter().ValueToNative(arguments[0], &arg)) { return absl::InternalError("C++ to CEL type conversion failed"); } ReturnType handlerResult = handler_(arena, arg); return ValueConverter().NativeToValue(handlerResult, arena, result); } private: // The descriptor is taken by value, so move it into the base instead of // copying it (matches the generic FunctionAdapter constructor). UnaryFunction(CelFunctionDescriptor descriptor, FuncType handler) : CelFunction(std::move(descriptor)), handler_(std::move(handler)) {} FuncType handler_; }; // Generalized implementation. template class FunctionAdapter : public CelFunction { public: using FuncType = std::function; using TypeAdder = internal::TypeAdder; FunctionAdapter(CelFunctionDescriptor descriptor, FuncType handler) : CelFunction(std::move(descriptor)), handler_(std::move(handler)) {} static absl::StatusOr> Create( absl::string_view name, bool receiver_type, std::function handler) { std::vector arg_types; arg_types.reserve(sizeof...(Arguments)); if (!TypeAdder().template AddType<0, Arguments...>(&arg_types)) { return absl::Status( absl::StatusCode::kInternal, absl::StrCat("Failed to create adapter for ", name, ": failed to determine input parameter type")); } return std::make_unique( CelFunctionDescriptor(name, receiver_type, std::move(arg_types)), std::move(handler)); } // Creates function handler and attempts to register it with // supplied function registry.
static absl::Status CreateAndRegister( absl::string_view name, bool receiver_type, std::function handler, CelFunctionRegistry* registry) { CEL_ASSIGN_OR_RETURN(auto cel_function, Create(name, receiver_type, std::move(handler))); return registry->Register(std::move(cel_function)); } #if !defined(CEL_CPP_DISABLE_PARTIAL_SPECIALIZATION) template inline absl::Status RunWrap( absl::Span arguments, std::tuple<::google::protobuf::Arena*, Arguments...> input, CelValue* result, ::google::protobuf::Arena* arena) const { if (!ValueConverter().ValueToNative(arguments[arg_index], &std::get(input))) { return absl::Status(absl::StatusCode::kInvalidArgument, "Type conversion failed"); } return RunWrap(arguments, input, result, arena); } template <> inline absl::Status RunWrap( absl::Span, std::tuple<::google::protobuf::Arena*, Arguments...> input, CelValue* result, ::google::protobuf::Arena* arena) const { return ValueConverter().NativeToValue(absl::apply(handler_, input), arena, result); } #else inline absl::Status RunWrap( std::function func, ABSL_ATTRIBUTE_UNUSED const absl::Span argset, ::google::protobuf::Arena* arena, CelValue* result, ABSL_ATTRIBUTE_UNUSED int arg_index) const { return ValueConverter().NativeToValue(func(), arena, result); } template inline absl::Status RunWrap(std::function func, const absl::Span argset, ::google::protobuf::Arena* arena, CelValue* result, int arg_index) const { Arg argument; if (!ValueConverter().ValueToNative(argset[arg_index], &argument)) { return absl::Status(absl::StatusCode::kInvalidArgument, "Type conversion failed"); } std::function wrapped_func = [func, argument](Args... 
args) -> ReturnType { return func(argument, args...); }; return RunWrap(std::move(wrapped_func), argset, arena, result, arg_index + 1); } #endif absl::Status Evaluate(absl::Span arguments, CelValue* result, ::google::protobuf::Arena* arena) const override { if (arguments.size() != sizeof...(Arguments)) { return absl::Status(absl::StatusCode::kInternal, "Argument number mismatch"); } #if !defined(CEL_CPP_DISABLE_PARTIAL_SPECIALIZATION) std::tuple<::google::protobuf::Arena*, Arguments...> input; std::get<0>(input) = arena; return RunWrap<0>(arguments, input, result, arena); #else const auto* handler = &handler_; std::function wrapped_handler = [handler, arena](Arguments... args) -> ReturnType { return (*handler)(arena, args...); }; return RunWrap(std::move(wrapped_handler), arguments, arena, result, 0); #endif } private: FuncType handler_; }; }; } // namespace internal } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_ADAPTER_IMPL_H_ ================================================ FILE: eval/public/cel_function_adapter_test.cc ================================================ #include "eval/public/cel_function_adapter.h" #include #include #include #include #include "internal/status_macros.h" #include "internal/testing.h" namespace google { namespace api { namespace expr { namespace runtime { namespace { TEST(CelFunctionAdapterTest, TestAdapterNoArg) { auto func = [](google::protobuf::Arena*) -> int64_t { return 100; }; ASSERT_OK_AND_ASSIGN( auto cel_func, (FunctionAdapter::Create("const", false, func))); absl::Span args; CelValue result = CelValue::CreateNull(); google::protobuf::Arena arena; ASSERT_OK(cel_func->Evaluate(args, &result, &arena)); ASSERT_TRUE(result.IsInt64()); } TEST(CelFunctionAdapterTest, TestAdapterOneArg) { std::function func = [](google::protobuf::Arena* arena, int64_t i) -> int64_t { return i + 1; }; ASSERT_OK_AND_ASSIGN( auto cel_func, (FunctionAdapter::Create("_++_", false, func))); std::vector 
args_vec; args_vec.push_back(CelValue::CreateInt64(99)); CelValue result = CelValue::CreateNull(); google::protobuf::Arena arena; absl::Span args(&args_vec[0], args_vec.size()); ASSERT_OK(cel_func->Evaluate(args, &result, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 100); } TEST(CelFunctionAdapterTest, TestAdapterTwoArgs) { auto func = [](google::protobuf::Arena* arena, int64_t i, int64_t j) -> int64_t { return i + j; }; ASSERT_OK_AND_ASSIGN(auto cel_func, (FunctionAdapter::Create( "_++_", false, func))); std::vector args_vec; args_vec.push_back(CelValue::CreateInt64(20)); args_vec.push_back(CelValue::CreateInt64(22)); CelValue result = CelValue::CreateNull(); google::protobuf::Arena arena; absl::Span args(&args_vec[0], args_vec.size()); ASSERT_OK(cel_func->Evaluate(args, &result, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 42); } using StringHolder = CelValue::StringHolder; TEST(CelFunctionAdapterTest, TestAdapterThreeArgs) { auto func = [](google::protobuf::Arena* arena, StringHolder s1, StringHolder s2, StringHolder s3) -> StringHolder { std::string value = absl::StrCat(s1.value(), s2.value(), s3.value()); return StringHolder( google::protobuf::Arena::Create(arena, std::move(value))); }; ASSERT_OK_AND_ASSIGN( auto cel_func, (FunctionAdapter::Create("concat", false, func))); std::string test1 = "1"; std::string test2 = "2"; std::string test3 = "3"; std::vector args_vec; args_vec.push_back(CelValue::CreateString(&test1)); args_vec.push_back(CelValue::CreateString(&test2)); args_vec.push_back(CelValue::CreateString(&test3)); CelValue result = CelValue::CreateNull(); google::protobuf::Arena arena; absl::Span args(&args_vec[0], args_vec.size()); ASSERT_OK(cel_func->Evaluate(args, &result, &arena)); ASSERT_TRUE(result.IsString()); EXPECT_EQ(result.StringOrDie().value(), "123"); } TEST(CelFunctionAdapterTest, TestTypeDeductionForCelValueBasicTypes) { auto func = [](google::protobuf::Arena* arena, bool, int64_t, 
uint64_t, double, CelValue::StringHolder, CelValue::BytesHolder, const google::protobuf::Message*, absl::Duration, absl::Time, const CelList*, const CelMap*, const CelError*) -> bool { return false; }; ASSERT_OK_AND_ASSIGN( auto cel_func, (FunctionAdapter::Create("dummy_func", false, func))); auto descriptor = cel_func->descriptor(); EXPECT_EQ(descriptor.receiver_style(), false); EXPECT_EQ(descriptor.name(), "dummy_func"); int pos = 0; ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kBool); ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kInt64); ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kUint64); ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kDouble); ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kString); ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kBytes); ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kMessage); ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kDuration); ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kTimestamp); ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kList); ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kMap); ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kError); } TEST(CelFunctionAdapterTest, TestAdapterStatusOrMessage) { auto func = [](google::protobuf::Arena* arena) -> absl::StatusOr { auto* ret = google::protobuf::Arena::Create(arena); ret->set_seconds(123); return ret; }; ASSERT_OK_AND_ASSIGN( auto cel_func, (FunctionAdapter>::Create( "const", false, func))); absl::Span args; CelValue result = CelValue::CreateNull(); google::protobuf::Arena arena; ASSERT_OK(cel_func->Evaluate(args, &result, &arena)); ASSERT_TRUE(result.IsTimestamp()); EXPECT_EQ(result.TimestampOrDie(), absl::FromUnixSeconds(123)); } } // namespace } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: eval/public/cel_function_registry.cc ================================================ #include 
"eval/public/cel_function_registry.h" #include #include #include #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "common/function_descriptor.h" #include "common/value.h" #include "eval/internal/interop.h" #include "eval/public/cel_function.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "internal/status_macros.h" #include "runtime/function.h" #include "runtime/function_overload_reference.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace google::api::expr::runtime { namespace { // Legacy cel function that proxies to the modern cel::Function interface. // // This is used to wrap new-style cel::Functions for clients consuming // legacy CelFunction-based APIs. The evaluate implementation on this class // should not be called by the CEL evaluator, but a sensible result is returned // for unit tests that haven't been migrated to the new APIs yet. class ProxyToModernCelFunction : public CelFunction { public: ProxyToModernCelFunction(const cel::FunctionDescriptor& descriptor, const cel::Function& implementation) : CelFunction(descriptor), implementation_(&implementation) {} absl::Status Evaluate(absl::Span args, CelValue* result, google::protobuf::Arena* arena) const override { // This is only safe for use during interop where the MemoryManager is // assumed to always be backed by a google::protobuf::Arena instance. After all // dependencies on legacy CelFunction are removed, we can remove this // implementation. 
std::vector modern_args =
    cel::interop_internal::LegacyValueToModernValueOrDie(arena, args);
CEL_ASSIGN_OR_RETURN(
    auto modern_result,
    implementation_->Invoke(
        modern_args, google::protobuf::DescriptorPool::generated_pool(),
        google::protobuf::MessageFactory::generated_factory(), arena));
*result = cel::interop_internal::ModernValueToLegacyValueOrDie(
    arena, modern_result);
return absl::OkStatus();
}

private:
// owned by the registry
const cel::Function* implementation_;
};

}  // namespace

// Applies each registrar in `registrars`, in order, to this registry with the
// given interpreter options. Stops at, and returns, the first non-OK status;
// registrations already performed by earlier registrars are not rolled back.
absl::Status CelFunctionRegistry::RegisterAll(
    std::initializer_list registrars, const InterpreterOptions& opts) {
  for (Registrar registrar : registrars) {
    CEL_RETURN_IF_ERROR(registrar(this, opts));
  }
  return absl::OkStatus();
}

// Returns legacy CelFunction views of the static overloads that match `name`,
// `receiver_style` and `types`.
//
// Matching is delegated to the modern registry. Each match is wrapped in a
// legacy adapter that is created on first request and cached in `functions_`,
// keyed by the address of the modern cel::Function. Later calls reuse the
// cached adapter. Entries are never removed here, so the returned pointers
// remain owned by this registry. Access to the cache is serialized by `mu_`
// because this method is const and may be called concurrently.
std::vector CelFunctionRegistry::FindOverloads(
    absl::string_view name, bool receiver_style,
    const std::vector& types) const {
  std::vector matched_funcs =
      modern_registry_.FindStaticOverloads(name, receiver_style, types);

  // For backwards compatibility, lazily initialize a legacy CEL function
  // if required.
  // The registry should remain add-only until migration to the new type is
  // complete, so this should work whether the function was introduced via
  // the modern registry or the old registry wrapping a modern instance.
  std::vector results;
  results.reserve(matched_funcs.size());

  {
    absl::MutexLock lock(mu_);
    for (cel::FunctionOverloadReference entry : matched_funcs) {
      // operator[] default-inserts an empty slot on first lookup; it is
      // populated just below.
      std::unique_ptr& legacy_impl = functions_[&entry.implementation];

      if (legacy_impl == nullptr) {
        legacy_impl = std::make_unique(
            entry.descriptor, entry.implementation);
      }

      results.push_back(legacy_impl.get());
    }
  }
  return results;
}

// Returns the descriptors of the lazily provided overloads that match `name`,
// `receiver_style` and `types`.
//
// NOTE(review): the returned pointers are taken from `overload.descriptor` of
// the elements of the local `lazy_overloads` vector. This is only safe if
// LazyOverload::descriptor refers to storage owned by the registry, rather
// than holding a copy. Confirm against runtime/function_registry.h.
std::vector CelFunctionRegistry::FindLazyOverloads(
    absl::string_view name, bool receiver_style,
    const std::vector& types) const {
  std::vector lazy_overloads =
      modern_registry_.FindLazyOverloads(name, receiver_style, types);
  std::vector result;
  result.reserve(lazy_overloads.size());

  for (const LazyOverload& overload : lazy_overloads) {
    result.push_back(&overload.descriptor);
  }
  return result;
}

}  // namespace google::api::expr::runtime
================================================
FILE: eval/public/cel_function_registry.h
================================================
#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_REGISTRY_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_REGISTRY_H_

#include
#include
#include
#include

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "common/function_descriptor.h"
#include "common/kind.h"
#include "eval/public/cel_function.h"
#include "eval/public/cel_options.h"
#include "eval/public/cel_value.h"
#include "runtime/function.h"
#include "runtime/function_overload_reference.h"
#include "runtime/function_registry.h"

namespace google::api::expr::runtime {

// CelFunctionRegistry lets callers register builtin or custom CelFunction
// handlers and look them up when creating CelExpression objects from Expr
// ASTs.
//
// It is a legacy-API facade over a cel::FunctionRegistry (`modern_registry_`).
// Registration and lookup are delegated to that registry.
class CelFunctionRegistry {
 public:
  // Represents a single overload for a lazily provided function.
  using LazyOverload = cel::FunctionRegistry::LazyOverload;

  CelFunctionRegistry() = default;

  ~CelFunctionRegistry() = default;

  // Signature of a bulk-registration callback accepted by RegisterAll.
  using Registrar = absl::Status (*)(CelFunctionRegistry*,
                                     const InterpreterOptions&);

  // Registers a CelFunction object. Ownership of the object is passed to the
  // registry.
  // Function registration should be performed prior to CelExpression
  // creation.
  absl::Status Register(std::unique_ptr function) {
    // We need to copy the descriptor, otherwise there is no guarantee that the
    // lvalue reference to the descriptor is valid as function may be destroyed.
    auto descriptor = function->descriptor();
    return Register(descriptor, std::move(function));
  }

  // Registers `implementation` under `descriptor` in the underlying modern
  // registry and returns that registry's status. A non-OK status is returned,
  // for example, when the descriptor conflicts with an existing registration.
  absl::Status Register(const cel::FunctionDescriptor& descriptor,
                        std::unique_ptr implementation) {
    return modern_registry_.Register(descriptor, std::move(implementation));
  }

  // Applies each registrar in `registrars`, in order, with `opts`. Returns the
  // first non-OK status, or OK if every registrar succeeds.
  absl::Status RegisterAll(std::initializer_list registrars,
                           const InterpreterOptions& opts);

  // Registers a lazily provided function. This overload uses a default
  // provider that delegates to the activation at evaluation time.
  absl::Status RegisterLazyFunction(const CelFunctionDescriptor& descriptor) {
    return modern_registry_.RegisterLazyFunction(descriptor);
  }

  // Finds the subset of CelFunctions that match the overload conditions.
  // As types may not be available during expression compilation,
  // further narrowing of this subset will happen at evaluation stage.
  // name - the name of CelFunction;
  // receiver_style - indicates whether function has receiver style;
  // types - argument types. If type is not known during compilation,
  //       DYN value should be passed.
  //
  // Results refer to underlying registry entries by pointer. Results are
  // invalid after the registry is deleted.
  std::vector FindOverloads(
      absl::string_view name, bool receiver_style,
      const std::vector& types) const;

  // Same query as FindOverloads, but returns the modern registry's overload
  // references directly; no legacy CelFunction wrappers are created.
  std::vector FindStaticOverloads(
      absl::string_view name, bool receiver_style,
      const std::vector& types) const {
    return modern_registry_.FindStaticOverloads(name, receiver_style, types);
  }

  // Finds the descriptors of the lazily provided functions that match the
  // overload conditions.
  // As types may not be available during expression compilation,
  // further narrowing of this subset will happen at evaluation stage.
  // name - the name of CelFunction;
  // receiver_style - indicates whether function has receiver style;
  // types - argument types. If type is not known during compilation,
  //       DYN value should be passed.
  std::vector FindLazyOverloads(
      absl::string_view name, bool receiver_style,
      const std::vector& types) const;

  // Same query as FindLazyOverloads, but returns the modern registry's
  // LazyOverload entries, each exposing its `descriptor` and `provider`.
  // name - the name of CelFunction;
  // receiver_style - indicates whether function has receiver style;
  // types - argument types. If type is not known during compilation,
  //       DYN value should be passed.
  std::vector ModernFindLazyOverloads(
      absl::string_view name, bool receiver_style,
      const std::vector& types) const {
    return modern_registry_.FindLazyOverloads(name, receiver_style, types);
  }

  // Retrieves the registered function descriptors, grouped by function name.
  // This includes both static and lazy functions.
  absl::node_hash_map>
  ListFunctions() const {
    return modern_registry_.ListFunctions();
  }

  // cel internal accessor for returning backing modern registry.
  //
  // This is intended to allow migrating the CEL evaluator internals while
  // maintaining the existing CelRegistry API.
  //
  // CEL users should not use this.
  const cel::FunctionRegistry& InternalGetRegistry() const {
    return modern_registry_;
  }

  cel::FunctionRegistry& InternalGetRegistry() { return modern_registry_; }

 private:
  cel::FunctionRegistry modern_registry_;

  // Maintain backwards compatibility for callers expecting CelFunction
  // interface.
  // This is not used internally, but some client tests check that a specific
  // CelFunction overload is used.
  // Lazily initialized by FindOverloads, keyed by the address of the modern
  // cel::Function being wrapped.
  mutable absl::Mutex mu_;
  mutable absl::flat_hash_map>
      functions_ ABSL_GUARDED_BY(mu_);
};

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_FUNCTION_REGISTRY_H_
================================================
FILE: eval/public/cel_function_registry_test.cc
================================================
#include "eval/public/cel_function_registry.h"

#include
#include
#include

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "common/kind.h"
#include "eval/internal/adapter_activation_impl.h"
#include "eval/public/activation.h"
#include "eval/public/cel_function.h"
#include "internal/testing.h"
#include "runtime/function_overload_reference.h"

namespace google::api::expr::runtime {

namespace {

using ::absl_testing::StatusIs;
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::HasSubstr;
using ::testing::Property;
using ::testing::SizeIs;
using ::testing::Truly;

// Test helper: a CelFunction that ignores its arguments and always produces
// the int64 value 42. By default it is described as `ConstFunction()`.
class ConstCelFunction : public CelFunction {
 public:
  ConstCelFunction() : CelFunction(MakeDescriptor()) {}
  explicit ConstCelFunction(const CelFunctionDescriptor& desc)
      : CelFunction(desc) {}

  // Descriptor for a non-receiver-style, zero-argument `ConstFunction`.
  static CelFunctionDescriptor MakeDescriptor() {
    return {"ConstFunction", false, {}};
  }

  absl::Status Evaluate(absl::Span args, CelValue* output,
                        google::protobuf::Arena* arena) const override {
    *output = CelValue::CreateInt64(42);

    return absl::OkStatus();
  }
};

TEST(CelFunctionRegistryTest, InsertAndRetrieveLazyFunction) {
  CelFunctionDescriptor lazy_function_desc{"LazyFunction", false, {}};
  CelFunctionRegistry registry;
  // NOTE(review): `activation` is not used by this test.
  Activation activation;
  ASSERT_OK(registry.RegisterLazyFunction(lazy_function_desc));

  const auto descriptors =
      registry.FindLazyOverloads("LazyFunction", false, {});
  EXPECT_THAT(descriptors, testing::SizeIs(1));
}

// Confirm that lazy and static functions share the same descriptor space:
// i.e. you can't insert both a lazy function and a static function for the same
// descriptors.
TEST(CelFunctionRegistryTest, LazyAndStaticFunctionShareDescriptorSpace) {
  CelFunctionRegistry registry;
  CelFunctionDescriptor desc = ConstCelFunction::MakeDescriptor();
  ASSERT_OK(registry.RegisterLazyFunction(desc));

  absl::Status status = registry.Register(ConstCelFunction::MakeDescriptor(),
                                          std::make_unique());
  EXPECT_FALSE(status.ok());
}

// Confirm that FindStaticOverloads returns the statically registered overload,
// exposed as a cel::FunctionOverloadReference carrying its descriptor.
TEST(CelFunctionRegistryTest, FindStaticOverloadsReturns) {
  CelFunctionRegistry registry;
  CelFunctionDescriptor desc = ConstCelFunction::MakeDescriptor();
  ASSERT_OK(registry.Register(desc, std::make_unique(desc)));

  std::vector overloads =
      registry.FindStaticOverloads(desc.name(), false, {});

  EXPECT_THAT(overloads,
              ElementsAre(Truly(
                  [](const cel::FunctionOverloadReference& overload) -> bool {
                    return overload.descriptor.name() == "ConstFunction";
                  })))
      << "Expected single ConstFunction()";
}

// ListFunctions reports both lazy and static registrations, keyed by name.
TEST(CelFunctionRegistryTest, ListFunctions) {
  CelFunctionDescriptor lazy_function_desc{"LazyFunction", false, {}};
  CelFunctionRegistry registry;

  ASSERT_OK(registry.RegisterLazyFunction(lazy_function_desc));
  EXPECT_OK(registry.Register(ConstCelFunction::MakeDescriptor(),
                              std::make_unique()));

  auto registered_functions = registry.ListFunctions();

  EXPECT_THAT(registered_functions, SizeIs(2));
  EXPECT_THAT(registered_functions["LazyFunction"], SizeIs(1));
  EXPECT_THAT(registered_functions["ConstFunction"], SizeIs(1));
}

// The legacy FindLazyOverloads returns only lazy registrations, not static
// ones.
TEST(CelFunctionRegistryTest, LegacyFindLazyOverloads) {
  CelFunctionDescriptor lazy_function_desc{"LazyFunction", false, {}};
  CelFunctionRegistry registry;

  ASSERT_OK(registry.RegisterLazyFunction(lazy_function_desc));
  ASSERT_OK(registry.Register(ConstCelFunction::MakeDescriptor(),
                              std::make_unique()));

  EXPECT_THAT(registry.FindLazyOverloads("LazyFunction", false, {}),
              ElementsAre(Truly([](const CelFunctionDescriptor* descriptor) {
                return descriptor->name() == "LazyFunction";
              })))
      << "Expected single lazy overload for LazyFunction()";
}

// The default lazy provider resolves a lazily registered function from the
// functions inserted into the activation.
TEST(CelFunctionRegistryTest, DefaultLazyProvider) {
  CelFunctionDescriptor lazy_function_desc{"LazyFunction", false, {}};
  CelFunctionRegistry registry;
  Activation activation;
  cel::interop_internal::AdapterActivationImpl modern_activation(activation);
  EXPECT_OK(registry.RegisterLazyFunction(lazy_function_desc));
  EXPECT_OK(activation.InsertFunction(
      std::make_unique(lazy_function_desc)));

  auto providers = registry.ModernFindLazyOverloads("LazyFunction", false, {});
  EXPECT_THAT(providers, testing::SizeIs(1));
  ASSERT_OK_AND_ASSIGN(auto func, providers[0].provider.GetFunction(
                                      lazy_function_desc, modern_activation));
  ASSERT_TRUE(func.has_value());
  EXPECT_THAT(func->descriptor,
              Property(&cel::FunctionDescriptor::name, Eq("LazyFunction")));
}

// The activation only holds LazyFunction(), so asking the provider for
// LazyFunc(int64) yields an OK status with no function (nullopt).
TEST(CelFunctionRegistryTest, DefaultLazyProviderNoOverloadFound) {
  CelFunctionRegistry registry;
  Activation legacy_activation;
  cel::interop_internal::AdapterActivationImpl activation(legacy_activation);
  CelFunctionDescriptor lazy_function_desc{"LazyFunction", false, {}};
  EXPECT_OK(registry.RegisterLazyFunction(lazy_function_desc));
  EXPECT_OK(legacy_activation.InsertFunction(
      std::make_unique(lazy_function_desc)));

  const auto providers =
      registry.ModernFindLazyOverloads("LazyFunction", false, {});
  ASSERT_THAT(providers, testing::SizeIs(1));
  const auto& provider = providers[0].provider;
  auto func = provider.GetFunction({"LazyFunc", false, {cel::Kind::kInt64}},
                                   activation);

  ASSERT_OK(func.status());
  EXPECT_EQ(*func, absl::nullopt);
}

// The activation holds two LazyFunc overloads (int64 and uint64). Looking up
// the kAny descriptor fails with "Couldn't resolve function", presumably
// because both overloads match it.
TEST(CelFunctionRegistryTest, DefaultLazyProviderAmbiguousLookup) {
  CelFunctionRegistry registry;
  Activation legacy_activation;
  cel::interop_internal::AdapterActivationImpl activation(legacy_activation);
  CelFunctionDescriptor desc1{"LazyFunc", false, {CelValue::Type::kInt64}};
  CelFunctionDescriptor desc2{"LazyFunc", false, {CelValue::Type::kUint64}};
  CelFunctionDescriptor match_desc{"LazyFunc", false, {CelValue::Type::kAny}};
  ASSERT_OK(registry.RegisterLazyFunction(match_desc));
  ASSERT_OK(legacy_activation.InsertFunction(
      std::make_unique(desc1)));
  ASSERT_OK(legacy_activation.InsertFunction(
      std::make_unique(desc2)));

  auto providers =
      registry.ModernFindLazyOverloads("LazyFunc", false, {cel::Kind::kAny});
  ASSERT_THAT(providers, testing::SizeIs(1));
  const auto& provider = providers[0].provider;
  auto func = provider.GetFunction(match_desc, activation);

  EXPECT_THAT(std::string(func.status().message()),
              HasSubstr("Couldn't resolve function"));
}

// A single non-strict overload can be registered, both statically and lazily.
TEST(CelFunctionRegistryTest, CanRegisterNonStrictFunction) {
  {
    CelFunctionRegistry registry;
    CelFunctionDescriptor descriptor("NonStrictFunction",
                                     /*receiver_style=*/false,
                                     {CelValue::Type::kAny},
                                     /*is_strict=*/false);
    ASSERT_OK(registry.Register(
        descriptor, std::make_unique(descriptor)));
    EXPECT_THAT(registry.FindStaticOverloads("NonStrictFunction", false,
                                             {CelValue::Type::kAny}),
                SizeIs(1));
  }
  {
    CelFunctionRegistry registry;
    CelFunctionDescriptor descriptor("NonStrictLazyFunction",
                                     /*receiver_style=*/false,
                                     {CelValue::Type::kAny},
                                     /*is_strict=*/false);
    EXPECT_OK(registry.RegisterLazyFunction(descriptor));
    EXPECT_THAT(registry.FindLazyOverloads("NonStrictLazyFunction", false,
                                           {CelValue::Type::kAny}),
                SizeIs(1));
  }
}

// Parameters: (existing_function_is_lazy, new_function_is_lazy).
using NonStrictTestCase = std::tuple;
using NonStrictRegistrationFailTest = testing::TestWithParam;

// Adding a non-strict overload next to an existing strict overload of the
// same name is rejected with kAlreadyExists.
TEST_P(NonStrictRegistrationFailTest,
       IfOtherOverloadExistsRegisteringNonStrictFails) {
  bool existing_function_is_lazy, new_function_is_lazy;
  std::tie(existing_function_is_lazy, new_function_is_lazy) = GetParam();
  CelFunctionRegistry registry;
  CelFunctionDescriptor descriptor("OverloadedFunction",
                                   /*receiver_style=*/false,
                                   {CelValue::Type::kAny},
                                   /*is_strict=*/true);
  if (existing_function_is_lazy) {
    ASSERT_OK(registry.RegisterLazyFunction(descriptor));
  } else {
    ASSERT_OK(registry.Register(
        descriptor, std::make_unique(descriptor)));
  }
  CelFunctionDescriptor new_descriptor(
      "OverloadedFunction",
      /*receiver_style=*/false, {CelValue::Type::kAny, CelValue::Type::kAny},
      /*is_strict=*/false);
  absl::Status status;
  if (new_function_is_lazy) {
    status = registry.RegisterLazyFunction(new_descriptor);
  } else {
    status = registry.Register(
        new_descriptor, std::make_unique(new_descriptor));
  }
  EXPECT_THAT(status, StatusIs(absl::StatusCode::kAlreadyExists,
                               HasSubstr("Only one overload")));
}

// Adding a strict overload next to an existing non-strict overload of the
// same name is rejected with kAlreadyExists.
TEST_P(NonStrictRegistrationFailTest,
       IfOtherNonStrictExistsRegisteringStrictFails) {
  bool existing_function_is_lazy, new_function_is_lazy;
  std::tie(existing_function_is_lazy, new_function_is_lazy) = GetParam();
  CelFunctionRegistry registry;
  CelFunctionDescriptor descriptor("OverloadedFunction",
                                   /*receiver_style=*/false,
                                   {CelValue::Type::kAny},
                                   /*is_strict=*/false);
  if (existing_function_is_lazy) {
    ASSERT_OK(registry.RegisterLazyFunction(descriptor));
  } else {
    ASSERT_OK(registry.Register(
        descriptor, std::make_unique(descriptor)));
  }
  CelFunctionDescriptor new_descriptor(
      "OverloadedFunction",
      /*receiver_style=*/false, {CelValue::Type::kAny, CelValue::Type::kAny},
      /*is_strict=*/true);
  absl::Status status;
  if (new_function_is_lazy) {
    status = registry.RegisterLazyFunction(new_descriptor);
  } else {
    status = registry.Register(
        new_descriptor, std::make_unique(new_descriptor));
  }
  EXPECT_THAT(status, StatusIs(absl::StatusCode::kAlreadyExists,
                               HasSubstr("Only one overload")));
}

// Strict overloads with distinct signatures can be combined freely.
TEST_P(NonStrictRegistrationFailTest,
       CanRegisterStrictFunctionsWithoutLimit) {
  bool existing_function_is_lazy, new_function_is_lazy;
  std::tie(existing_function_is_lazy, new_function_is_lazy) = GetParam();
  CelFunctionRegistry registry;
  CelFunctionDescriptor descriptor("OverloadedFunction",
                                   /*receiver_style=*/false,
                                   {CelValue::Type::kAny},
                                   /*is_strict=*/true);
  if (existing_function_is_lazy) {
    ASSERT_OK(registry.RegisterLazyFunction(descriptor));
  } else {
    ASSERT_OK(registry.Register(
        descriptor, std::make_unique(descriptor)));
  }
  CelFunctionDescriptor new_descriptor(
      "OverloadedFunction",
      /*receiver_style=*/false, {CelValue::Type::kAny, CelValue::Type::kAny},
      /*is_strict=*/true);
  absl::Status status;
  if (new_function_is_lazy) {
    status = registry.RegisterLazyFunction(new_descriptor);
  } else {
    status = registry.Register(
        new_descriptor, std::make_unique(new_descriptor));
  }
  EXPECT_OK(status);
}

INSTANTIATE_TEST_SUITE_P(NonStrictRegistrationFailTest,
                         NonStrictRegistrationFailTest,
                         testing::Combine(testing::Bool(), testing::Bool()));

}  // namespace

}  // namespace google::api::expr::runtime
================================================
FILE: eval/public/cel_number.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "eval/public/cel_number.h"

#include "eval/public/cel_value.h"

namespace google::api::expr::runtime {

// Converts `value` to a CelNumber if it holds a numeric CEL value.
//
// The representations are probed in this order: int64, uint64, double. The
// first one CelValue::GetValue accepts is used. Any other kind of value
// yields nullopt.
absl::optional GetNumberFromCelValue(const CelValue& value) {
  // if-with-initializer keeps each `val` scoped to the branch that probes it.
  if (int64_t val; value.GetValue(&val)) {
    return CelNumber(val);
  } else if (uint64_t val; value.GetValue(&val)) {
    return CelNumber(val);
  } else if (double val; value.GetValue(&val)) {
    return CelNumber(val);
  }
  return absl::nullopt;
}

}  // namespace google::api::expr::runtime
================================================
FILE: eval/public/cel_number.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the include-guard macro says CEL_NUMERIC_H_ although the file
// is cel_number.h. It is harmless as long as the name stays unique, but
// consider aligning it with the file name.
#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_NUMERIC_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_NUMERIC_H_

#include
#include
#include

#include "absl/types/optional.h"
#include "eval/public/cel_value.h"
#include "internal/number.h"

namespace google::api::expr::runtime {

// Legacy-API name for the shared numeric helper type.
using CelNumber = cel::internal::Number;

// Returns a CelNumber if the value holds a numeric type (int64, uint64 or
// double), otherwise returns nullopt.
absl::optional GetNumberFromCelValue(const CelValue& value);

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_NUMERIC_H_
================================================
FILE: eval/public/cel_number_test.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "eval/public/cel_number.h"

#include
#include

#include "absl/types/optional.h"
#include "eval/public/cel_value.h"
#include "internal/testing.h"

namespace google::api::expr::runtime {
namespace {

using ::testing::Optional;

// Numeric values convert and non-numeric values (here a duration) yield
// nullopt. The int64/uint64 expectations compare against a double-backed
// CelNumber. This assumes CelNumber equality compares numeric value across
// representations (confirm in internal/number.h).
TEST(CelNumber, GetNumberFromCelValue) {
  EXPECT_THAT(GetNumberFromCelValue(CelValue::CreateDouble(1.1)),
              Optional(CelNumber::FromDouble(1.1)));
  EXPECT_THAT(GetNumberFromCelValue(CelValue::CreateInt64(1)),
              Optional(CelNumber::FromDouble(1.0)));
  EXPECT_THAT(GetNumberFromCelValue(CelValue::CreateUint64(1)),
              Optional(CelNumber::FromDouble(1.0)));

  EXPECT_EQ(GetNumberFromCelValue(CelValue::CreateDuration(absl::Seconds(1))),
            absl::nullopt);
}

}  // namespace
}  // namespace google::api::expr::runtime
================================================
FILE: eval/public/cel_options.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/public/cel_options.h" #include "runtime/runtime_options.h" namespace google::api::expr::runtime { cel::RuntimeOptions ConvertToRuntimeOptions(const InterpreterOptions& options) { return cel::RuntimeOptions{/*.container=*/"", options.unknown_processing, options.enable_missing_attribute_errors, options.enable_timestamp_duration_overflow_errors, options.short_circuiting, options.enable_comprehension, options.comprehension_max_iterations, options.enable_comprehension_list_append, options.enable_comprehension_mutable_map, options.enable_regex, options.regex_max_program_size, options.enable_string_conversion, options.enable_string_concat, options.enable_list_concat, options.enable_list_contains, options.fail_on_warnings, options.enable_qualified_type_identifiers, options.enable_heterogeneous_equality, options.enable_empty_wrapper_null_unboxing, options.enable_lazy_bind_initialization, options.max_recursion_depth, options.enable_recursive_tracing, options.enable_fast_builtins}; } } // namespace google::api::expr::runtime ================================================ FILE: eval/public/cel_options.h ================================================ /* * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_OPTIONS_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_OPTIONS_H_

#include "absl/base/attributes.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"

namespace google::api::expr::runtime {

// Legacy-API aliases for option enums defined by the modern runtime.
using UnknownProcessingOptions = cel::UnknownProcessingOptions;
using ProtoWrapperTypeOptions = cel::ProtoWrapperTypeOptions;

// LINT.IfChange
// Interpreter options for controlling evaluation and builtin functions.
//
// Field declaration order is part of this struct's interface: callers may
// aggregate-initialize it.
struct InterpreterOptions {
  // Level of unknown support enabled.
  UnknownProcessingOptions unknown_processing =
      UnknownProcessingOptions::kDisabled;

  // NOTE(review): undocumented here. It is forwarded unchanged to
  // cel::RuntimeOptions by ConvertToRuntimeOptions; see
  // runtime/runtime_options.h for its semantics.
  bool enable_missing_attribute_errors = false;

  // Enable timestamp duration overflow checks.
  //
  // The CEL-Spec indicates that overflow should occur outside the range of
  // string-representable timestamps, and at the limit of durations which can be
  // expressed with a single int64 value.
  bool enable_timestamp_duration_overflow_errors = false;

  // Enable short-circuiting of the logical operator evaluation. If enabled,
  // AND, OR, and TERNARY do not evaluate the entire expression once the
  // resulting value is known from the left-hand side.
  bool short_circuiting = true;

  // Enable constant folding during the expression creation.
  //
  // Note that expression tracing will apply to a modified expression if this
  // option is enabled.
  bool constant_folding = false;

  // Optionally specified arena for constant folding. If not specified, the
  // builder will create one as needed per expression built. Any arena created
  // by the builder will be destroyed when the corresponding expression is
  // destroyed.
  google::protobuf::Arena* constant_arena = nullptr;

  // Enable comprehension expressions (e.g. exists, all).
  bool enable_comprehension = true;

  // Set maximum number of iterations in the comprehension expressions if
  // comprehensions are enabled. The limit applies globally per evaluation,
  // including nested loops. Use value 0 to disable the upper bound.
  int comprehension_max_iterations = 10000;

  // Enable list append within comprehensions. Note, this option is not safe
  // with hand-rolled ASTs.
  bool enable_comprehension_list_append = false;

  // Enable mutable map construction within comprehensions. Note, this option is
  // not safe with hand-rolled ASTs.
  bool enable_comprehension_mutable_map = false;

  // Enable RE2 match() overload.
  bool enable_regex = true;

  // Set maximum program size for RE2 regex if regex overload is enabled.
  // Evaluates to an error if a regex exceeds it. Use value 0 to disable the
  // upper bound.
  int regex_max_program_size = 0;

  // Enable string() overloads.
  bool enable_string_conversion = true;

  // Enable string concatenation overload.
  bool enable_string_concat = true;

  // Enable list concatenation overload.
  bool enable_list_concat = true;

  // Enable list membership overload.
  bool enable_list_contains = true;

  // Treat builder warnings as fatal errors.
  bool fail_on_warnings = true;

  // Enable the resolution of qualified type identifiers as type values instead
  // of field selections.
  //
  // This toggle may cause certain identifiers which overlap with CEL built-in
  // types or with protobuf message types linked into the binary to be resolved
  // as static type values rather than as per-eval variables.
  bool enable_qualified_type_identifiers = false;

  // Enable a check for memory vulnerabilities within comprehension
  // sub-expressions.
  //
  // Note: This flag is not necessary if you are only using Core CEL macros.
  //
  // Consider enabling this feature when using custom comprehensions, and
  // absolutely enable the feature when using hand-written ASTs for
  // comprehension expressions.
  bool enable_comprehension_vulnerability_check = false;

  // Enable heterogeneous comparisons (e.g. support for cross-type comparisons).
  ABSL_DEPRECATED(
      "The ability to disable heterogeneous equality is being removed in the "
      "near future")
  bool enable_heterogeneous_equality = true;

  // Enables unwrapping proto wrapper types to null if unset. E.g. if an
  // expression accesses a field of type google.protobuf.Int64Value that is
  // unset, that will result in a Null cel value, as opposed to returning the
  // cel representation of the proto defined default int64: 0.
  bool enable_empty_wrapper_null_unboxing = false;

  // Enables expression rewrites to disambiguate namespace qualified identifiers
  // from container access for variables and receiver-style calls for functions.
  //
  // Note: This makes an implicit copy of the input expression for lifetime
  // safety.
  bool enable_qualified_identifier_rewrites = false;

  // Historically regular expressions were compiled on each invocation to
  // `matches` and not re-used, even if the regular expression is a constant.
  // Enabling this option causes constant regular expressions to be compiled
  // ahead-of-time and re-used for each invocation to `matches`. A side effect
  // of this is that invalid regular expressions will result in errors when
  // building an expression.
  //
  // It is recommended that this option be enabled in conjunction with
  // `constant_folding`.
  //
  // Note: In most cases enabling this option is safe, however to perform this
  // optimization overloads are not consulted for applicable calls. If you have
  // overridden the default `matches` function you should not enable this
  // option.
  bool enable_regex_precompilation = false;

  // Enable select optimization, replacing long select chains with a single
  // operation.
  //
  // This assumes that the type information at check time agrees with the
  // configured types at runtime.
  //
  // Important: The select optimization follows spec behavior for traversals.
  //  - `enable_empty_wrapper_null_unboxing` is ignored and optimized traversals
  //    always operate as though it is `true`.
  //  - `enable_heterogeneous_equality` is ignored and optimized traversals
  //    always operate as though it is `true`.
  bool enable_select_optimization = false;

  // Enable lazy cel.bind alias initialization.
  //
  // This is now always enabled. Setting this option has no effect. It will be
  // removed in a later update.
  bool enable_lazy_bind_initialization = true;

  // Enable recursive planning with a maximum recursion depth for evaluable
  // programs.
  //
  // This limit is proportional to the maximum number of recursive Evaluate
  // calls that a single expression program might require while evaluating. This
  // is coarse -- the actual C++ stack requirements will vary depending on the
  // expression.
  //
  // This does not account for re-entrant evaluation in a client's extension
  // function (i.e. a CEL function that calls Evaluate on another CEL program).
  //
  // If the limit is exceeded, the planner will return an error instead of
  // planning the program.
  //
  // -1 means unbounded.
  // 0 means disabled (using a heap-based stack machine instead), which is the
  // default.
  int max_recursion_depth = 0;

  // Enable tracing support for recursively planned programs.
  //
  // Unlike the stack machine implementation, supporting tracing can affect
  // performance whether or not tracing is requested for a given evaluation.
  bool enable_recursive_tracing = false;

  // Enable fast implementations for some CEL standard functions.
  //
  // Uses a custom implementation for some functions in the CEL standard,
  // bypassing normal dispatching logic and safety checks for functions.
  //
  // This prevents extending or disabling these functions in most cases. The
  // expression planner will make a best effort attempt to check if custom
  // overloads have been added for these functions, and will attempt to use them
  // if they exist.
  //
  // Currently applies to !_, @not_strictly_false, _==_, _!=_, @in
  bool enable_fast_builtins = true;

  // When enabled, string(double) will format the double with enough precision
  // to ensure that the original double value can be recovered exactly.
  //
  // If available, will use the `std::to_chars` standard library function to
  // perform the conversion to generate the shortest representation.
  //
  // Otherwise, will fall back to formatting with the worst-case required
  // precision.
  bool enable_precision_preserving_double_format = true;
};
// LINT.ThenChange(//depot/google3/runtime/runtime_options.h)

// Converts legacy InterpreterOptions into a cel::RuntimeOptions. Only a subset
// of the fields above is carried over; see the definition in cel_options.cc.
cel::RuntimeOptions ConvertToRuntimeOptions(const InterpreterOptions& options);

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_OPTIONS_H_
================================================
FILE: eval/public/cel_type_registry.cc
================================================
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "eval/public/cel_type_registry.h"

#include
#include
#include
#include

#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "eval/public/structs/legacy_type_adapter.h"
#include "google/protobuf/descriptor.h"

namespace google::api::expr::runtime {

namespace {

// Registers the protobuf enum `desc` with `registry` under its full name.
// Each enum value becomes an enumerator of (value name, value number), in
// descriptor order.
void AddEnumFromDescriptor(const google::protobuf::EnumDescriptor* desc,
                           CelTypeRegistry& registry) {
  std::vector enumerators;
  enumerators.reserve(desc->value_count());
  for (int i = 0; i < desc->value_count(); i++) {
    enumerators.push_back(
        {std::string(desc->value(i)->name()), desc->value(i)->number()});
  }
  registry.RegisterEnum(desc->full_name(), std::move(enumerators));
}

}  // namespace

// Registers the enum described by `enum_descriptor` (see
// AddEnumFromDescriptor).
void CelTypeRegistry::Register(const google::protobuf::EnumDescriptor* enum_descriptor) {
  AddEnumFromDescriptor(enum_descriptor, *this);
}

// Forwards the enum definition to the underlying modern type registry.
void CelTypeRegistry::RegisterEnum(absl::string_view enum_name,
                                   std::vector enumerators) {
  modern_type_registry_.RegisterEnum(enum_name, std::move(enumerators));
}

// Finds the legacy type adapter for `fully_qualified_type_name` by querying
// the provider returned by GetFirstTypeProvider(). Returns nullopt when that
// provider has no adapter for the name.
// NOTE(review): the provider pointer is dereferenced without a null check.
// This assumes GetFirstTypeProvider() never returns null; confirm this.
absl::optional CelTypeRegistry::FindTypeAdapter(
    absl::string_view fully_qualified_type_name) const {
  auto maybe_adapter =
      GetFirstTypeProvider()->ProvideLegacyType(fully_qualified_type_name);
  if (maybe_adapter.has_value()) {
    return maybe_adapter;
  }

  return absl::nullopt;
}

}  // namespace google::api::expr::runtime
================================================
FILE: eval/public/cel_type_registry.h
================================================
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_TYPE_REGISTRY_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_TYPE_REGISTRY_H_

#include
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "base/type_provider.h"
#include "eval/public/structs/legacy_type_adapter.h"
#include "eval/public/structs/legacy_type_provider.h"
#include "runtime/type_registry.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {

// CelTypeRegistry manages the set of registered types available for use within
// object literal construction, enum comparisons, and type testing.
//
// The CelTypeRegistry is intended to live for the duration of all CelExpression
// values created by a given CelExpressionBuilder and one is created by default
// within the standard CelExpressionBuilder.
//
// By default, all core CEL types and all linked protobuf message types are
// implicitly registered by way of the generated descriptor pool. A descriptor
// pool can be given to avoid accidentally exposing linked protobuf types to CEL
// which were intended to remain internal or to operate on hermetic descriptor
// pools.
//
// All state lives in a wrapped cel::TypeRegistry (`modern_type_registry_`);
// this class is a legacy-API facade over it.
class CelTypeRegistry {
 public:
  // Representation of an enum constant.
  using Enumerator = cel::TypeRegistry::Enumerator;

  // Representation of an enum.
  using Enumeration = cel::TypeRegistry::Enumeration;

  // Uses the generated descriptor pool and generated message factory.
  CelTypeRegistry()
      : CelTypeRegistry(google::protobuf::DescriptorPool::generated_pool(),
                        google::protobuf::MessageFactory::generated_factory()) {}

  // Uses the given descriptor pool (must not be null) and message factory
  // (may be null).
  CelTypeRegistry(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                  google::protobuf::MessageFactory* absl_nullable message_factory)
      : modern_type_registry_(descriptor_pool, message_factory) {}

  ~CelTypeRegistry() = default;

  // Register an enum whose values may be used within CEL expressions.
  //
  // Enum registration must be performed prior to CelExpression creation.
  void Register(const google::protobuf::EnumDescriptor* enum_descriptor);

  // Register an enum, given by name and enumerator list, whose values may be
  // used within CEL expressions.
  //
  // Enum registration must be performed prior to CelExpression creation.
  void RegisterEnum(absl::string_view name,
                    std::vector enumerators);

  // Returns the legacy runtime type provider associated with the underlying
  // modern registry, as obtained from
  // cel::runtime_internal::GetLegacyRuntimeTypeProvider.
  std::shared_ptr GetFirstTypeProvider() const {
    return cel::runtime_internal::GetLegacyRuntimeTypeProvider(
        modern_type_registry_);
  }

  // Returns the effective type provider that has been configured with the
  // registry.
  //
  // This is a composited type provider that should check in order:
  // - builtins
  // - custom enumerations
  // - registered extension type providers in the order registered.
  const cel::TypeProvider& GetTypeProvider() const {
    return modern_type_registry_.GetComposedTypeProvider();
  }

  // Find a type adapter given a fully qualified type name.
  // Adapter provides a generic interface for the reflection operations the
  // interpreter needs to provide. Returns nullopt if no adapter is found.
  absl::optional FindTypeAdapter(
      absl::string_view fully_qualified_type_name) const;

  // Return the registered enums configured within the type registry in the
  // internal format that can be identified as int constants at plan time.
  const absl::flat_hash_map& resolveable_enums()
      const {
    return modern_type_registry_.resolveable_enums();
  }

  // Return the names of the registered enums configured within the type
  // registry (the keys of resolveable_enums()).
  //
  // This is provided for validating registry setup, it should not be used
  // internally.
  //
  // Invalidated whenever registered enums are updated.
  absl::flat_hash_set ListResolveableEnums() const {
    const auto& enums = resolveable_enums();
    absl::flat_hash_set result;
    result.reserve(enums.size());

    for (const auto& entry : enums) {
      result.insert(entry.first);
    }

    return result;
  }

  // Accessor for underlying modern registry.
  //
  // This is exposed for migrating runtime internals, CEL users should not call
  // this.
  cel::TypeRegistry& InternalGetModernRegistry() {
    return modern_type_registry_;
  }

  const cel::TypeRegistry& InternalGetModernRegistry() const {
    return modern_type_registry_;
  }

 private:
  // Internal modern registry.
  cel::TypeRegistry modern_type_registry_;
};

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_TYPE_REGISTRY_H_
================================================
FILE: eval/public/cel_type_registry_protobuf_reflection_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "google/protobuf/struct.pb.h"
#include "absl/types/optional.h"
#include "common/memory.h"
#include "common/type.h"
#include "eval/public/cel_type_registry.h"
#include "eval/testutil/test_message.pb.h"
#include "internal/testing.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {
namespace {

using ::cel::MemoryManagerRef;
using ::cel::StructType;
using ::cel::Type;
using ::google::protobuf::Struct;
using ::testing::AllOf;
using ::testing::Contains;
using ::testing::Eq;
using ::testing::Optional;
using ::testing::Pair;
using ::testing::UnorderedElementsAre;

MATCHER_P(TypeNameIs, name, "") {
  const Type& type = arg;
  *result_listener << "got typename: " << type.name();
  return type.name() == name;
}

// Matches a registry enumeration whose enumerators carry exactly the names and
// numbers of `desc`'s values, in declaration order.
MATCHER_P(MatchesEnumDescriptor, desc, "") {
  const auto& enum_type = arg;

  // value_count() is an int while size() is a size_t; compare as int to avoid
  // a signed/unsigned comparison.
  if (static_cast<int>(enum_type.enumerators.size()) != desc->value_count()) {
    return false;
  }

  for (int i = 0; i < desc->value_count(); i++) {
    const auto& constant = enum_type.enumerators[i];
    const auto* value_desc = desc->value(i);

    if (value_desc->name() != constant.name) {
      return false;
    }
    if (value_desc->number() != constant.number) {
      return false;
    }
  }
  return true;
}

TEST(CelTypeRegistryTest, RegisterEnumDescriptor) {
  CelTypeRegistry registry;
  registry.Register(
      google::protobuf::GetEnumDescriptor<TestMessage::TestEnum>());

  EXPECT_THAT(
      registry.ListResolveableEnums(),
      UnorderedElementsAre("google.protobuf.NullValue",
                           "google.api.expr.runtime.TestMessage.TestEnum"));

  EXPECT_THAT(
      registry.resolveable_enums(),
      AllOf(Contains(Pair(
                "google.protobuf.NullValue",
                MatchesEnumDescriptor(
                    google::protobuf::GetEnumDescriptor<
                        google::protobuf::NullValue>()))),
            Contains(Pair(
                "google.api.expr.runtime.TestMessage.TestEnum",
                MatchesEnumDescriptor(
                    google::protobuf::GetEnumDescriptor<
                        TestMessage::TestEnum>())))));
}

TEST(CelTypeRegistryTypeProviderTest, StructTypes) {
  CelTypeRegistry registry;
  google::protobuf::LinkMessageReflection<
      google::api::expr::runtime::TestMessage>();
  google::protobuf::LinkMessageReflection<google::protobuf::Struct>();

  ASSERT_OK_AND_ASSIGN(absl::optional<Type> struct_message_type,
                       registry.GetTypeProvider().FindType(
                           "google.api.expr.runtime.TestMessage"));
  ASSERT_TRUE(struct_message_type.has_value());
  ASSERT_TRUE((*struct_message_type).Is<StructType>())
      << (*struct_message_type).DebugString();
  EXPECT_THAT(struct_message_type->As<StructType>()->name(),
              Eq("google.api.expr.runtime.TestMessage"));

  // Can't override builtins.
  ASSERT_OK_AND_ASSIGN(
      absl::optional<Type> struct_type,
      registry.GetTypeProvider().FindType("google.protobuf.Struct"));
  EXPECT_THAT(struct_type, Optional(TypeNameIs("map")));
}

}  // namespace
}  // namespace google::api::expr::runtime

================================================
FILE: eval/public/cel_type_registry_test.cc
================================================
#include "eval/public/cel_type_registry.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "base/type_provider.h"
#include "common/memory.h"
#include "common/type.h"
#include "eval/public/structs/legacy_type_adapter.h"
#include "eval/public/structs/legacy_type_provider.h"
#include "internal/testing.h"

namespace google::api::expr::runtime {

namespace {

using ::cel::MemoryManagerRef;
using ::cel::Type;
using ::cel::TypeProvider;
using ::testing::Contains;
using ::testing::Key;
using ::testing::Optional;

// Legacy type provider that reports an opaque adapter for each type name it
// was constructed with.
// NOTE(review): no test in this file references this class; it appears to be
// left over from when extra providers could be registered. Candidate for
// removal.
class TestTypeProvider : public LegacyTypeProvider {
 public:
  explicit TestTypeProvider(std::vector<std::string> types)
      : types_(std::move(types)) {}

  // Return a type adapter for an opaque type
  // (no reflection operations supported).
  absl::optional<LegacyTypeAdapter> ProvideLegacyType(
      absl::string_view name) const override {
    for (const auto& type : types_) {
      if (name == type) {
        return LegacyTypeAdapter(/*access=*/nullptr, /*mutation=*/nullptr);
      }
    }
    return absl::nullopt;
  }

 private:
  std::vector<std::string> types_;
};

TEST(CelTypeRegistryTest, RegisterEnum) {
  CelTypeRegistry registry;
  registry.RegisterEnum("google.api.expr.runtime.TestMessage.TestEnum",
                        {
                            {"TEST_ENUM_UNSPECIFIED", 0},
                            {"TEST_ENUM_1", 10},
                            {"TEST_ENUM_2", 20},
                            {"TEST_ENUM_3", 30},
                        });

  EXPECT_THAT(registry.resolveable_enums(),
              Contains(Key("google.api.expr.runtime.TestMessage.TestEnum")));
}

TEST(CelTypeRegistryTest, TestRegisterBuiltInEnum) {
  CelTypeRegistry registry;

  ASSERT_THAT(registry.resolveable_enums(),
              Contains(Key("google.protobuf.NullValue")));
}

TEST(CelTypeRegistryTest, TestGetFirstTypeProviderSuccess) {
  CelTypeRegistry registry;
  auto type_provider = registry.GetFirstTypeProvider();
  ASSERT_NE(type_provider, nullptr);
  ASSERT_FALSE(
      type_provider->ProvideLegacyType("google.protobuf.Int64").has_value());
  ASSERT_TRUE(
      type_provider->ProvideLegacyType("google.protobuf.Any").has_value());
}

TEST(CelTypeRegistryTest, TestFindTypeAdapterFound) {
  CelTypeRegistry registry;
  auto desc = registry.FindTypeAdapter("google.protobuf.Any");
  ASSERT_TRUE(desc.has_value());
}

// NOTE(review): currently identical to TestFindTypeAdapterFound; no additional
// provider is registered, so "multiple providers" is not exercised.
TEST(CelTypeRegistryTest, TestFindTypeAdapterFoundMultipleProviders) {
  CelTypeRegistry registry;
  auto desc = registry.FindTypeAdapter("google.protobuf.Any");
  ASSERT_TRUE(desc.has_value());
}

TEST(CelTypeRegistryTest, TestFindTypeAdapterNotFound) {
  CelTypeRegistry registry;
  auto desc = registry.FindTypeAdapter("missing.MessageType");
  EXPECT_FALSE(desc.has_value());
}

MATCHER_P(TypeNameIs, name, "") {
  const Type& type = arg;
  *result_listener << "got typename: " << type.name();
  return type.name() == name;
}

TEST(CelTypeRegistryTypeProviderTest, Builtins) {
  CelTypeRegistry registry;

  // simple
  ASSERT_OK_AND_ASSIGN(absl::optional<Type> bool_type,
                       registry.GetTypeProvider().FindType("bool"));
  EXPECT_THAT(bool_type, Optional(TypeNameIs("bool")));
  // opaque
  ASSERT_OK_AND_ASSIGN(
      absl::optional<Type> timestamp_type,
      registry.GetTypeProvider().FindType("google.protobuf.Timestamp"));
  EXPECT_THAT(timestamp_type,
              Optional(TypeNameIs("google.protobuf.Timestamp")));
  // wrapper
  ASSERT_OK_AND_ASSIGN(
      absl::optional<Type> int_wrapper_type,
      registry.GetTypeProvider().FindType("google.protobuf.Int64Value"));
  EXPECT_THAT(int_wrapper_type,
              Optional(TypeNameIs("google.protobuf.Int64Value")));
  // json
  ASSERT_OK_AND_ASSIGN(
      absl::optional<Type> json_struct_type,
      registry.GetTypeProvider().FindType("google.protobuf.Struct"));
  EXPECT_THAT(json_struct_type, Optional(TypeNameIs("map")));
  // special
  ASSERT_OK_AND_ASSIGN(
      absl::optional<Type> any_type,
      registry.GetTypeProvider().FindType("google.protobuf.Any"));
  EXPECT_THAT(any_type, Optional(TypeNameIs("google.protobuf.Any")));
}

}  // namespace

}  // namespace google::api::expr::runtime

================================================
FILE: eval/public/cel_value.cc
================================================
#include "eval/public/cel_value.h"

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/base/no_destructor.h"
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/memory.h"
#include "eval/internal/errors.h"
#include "eval/public/structs/legacy_type_info_apis.h"
#include "extensions/protobuf/memory_manager.h"
#include "google/protobuf/arena.h"

namespace google::api::expr::runtime {

namespace {

using ::google::protobuf::Arena;
namespace interop = ::cel::interop_internal;

constexpr absl::string_view kNullTypeName = "null_type";
constexpr absl::string_view kBoolTypeName = "bool";
constexpr absl::string_view kInt64TypeName = "int";
constexpr absl::string_view kUInt64TypeName = "uint";
constexpr absl::string_view kDoubleTypeName = "double";
constexpr absl::string_view kStringTypeName = "string";
constexpr absl::string_view kBytesTypeName = "bytes";
constexpr absl::string_view kDurationTypeName = "google.protobuf.Duration";
constexpr absl::string_view kTimestampTypeName = "google.protobuf.Timestamp";
// CEL type names for the builtin aggregate kinds and for `type` itself.
constexpr absl::string_view kListTypeName = "list";
constexpr absl::string_view kMapTypeName = "map";
constexpr absl::string_view kCelTypeTypeName = "type";

// Visitor producing the human-readable payload used by CelValue::DebugString.
// `arena` is used for allocations needed while reading list elements and map
// keys/values.
struct DebugStringVisitor {
  google::protobuf::Arena* const arena;

  std::string operator()(bool arg) { return absl::StrFormat("%d", arg); }
  std::string operator()(int64_t arg) { return absl::StrFormat("%lld", arg); }
  std::string operator()(uint64_t arg) {
    return absl::StrFormat("%llu", arg);
  }

  std::string operator()(double arg) { return absl::StrFormat("%f", arg); }
  std::string operator()(CelValue::NullType) { return "null"; }

  std::string operator()(CelValue::StringHolder arg) {
    return absl::StrFormat("%s", arg.value());
  }

  std::string operator()(CelValue::BytesHolder arg) {
    return absl::StrFormat("%s", arg.value());
  }

  std::string operator()(const MessageWrapper& arg) {
    return arg.message_ptr() == nullptr
               ? "NULL"
               : arg.legacy_type_info()->DebugString(arg);
  }

  std::string operator()(absl::Duration arg) {
    return absl::FormatDuration(arg);
  }

  std::string operator()(absl::Time arg) {
    return absl::FormatTime(arg, absl::UTCTimeZone());
  }

  std::string operator()(const CelList* arg) {
    std::vector<std::string> elements;
    elements.reserve(arg->size());
    for (int i = 0; i < arg->size(); i++) {
      elements.push_back(arg->Get(arena, i).DebugString());
    }
    return absl::StrCat("[", absl::StrJoin(elements, ", "), "]");
  }

  std::string operator()(const CelMap* arg) {
    auto keys_or_error = arg->ListKeys(arena);
    if (!keys_or_error.status().ok()) {
      return "invalid list keys";
    }
    // Plain pointer copy; moving a raw pointer is a no-op.
    const CelList* keys = keys_or_error.value();

    std::vector<std::string> elements;
    elements.reserve(keys->size());
    for (int i = 0; i < keys->size(); i++) {
      const CelValue key = keys->Get(arena, i);
      const absl::optional<CelValue> optional_value = arg->Get(arena, key);
      elements.push_back(absl::StrCat("<", key.DebugString(), ">: <",
                                      optional_value.has_value()
                                          ? optional_value->DebugString()
                                          : "nullopt",
                                      ">"));
    }
    return absl::StrCat("{", absl::StrJoin(elements, ", "), "}");
  }

  std::string operator()(const UnknownSet* arg) {
    return "?";  // Not implemented.
  }

  std::string operator()(CelValue::CelTypeHolder arg) {
    return absl::StrCat(arg.value());
  }

  std::string operator()(const CelError* arg) { return arg->ToString(); }
};

}  // namespace

ABSL_CONST_INIT const absl::string_view kPayloadUrlMissingAttributePath =
    cel::runtime_internal::kPayloadUrlMissingAttributePath;

// Returns a duration value, or a duration-overflow error value when `value`
// reaches or exceeds the runtime's supported duration bounds.
CelValue CelValue::CreateDuration(absl::Duration value) {
  if (value >= cel::runtime_internal::kDurationHigh ||
      value <= cel::runtime_internal::kDurationLow) {
    return CelValue(cel::runtime_internal::DurationOverflowError());
  }
  return CreateUncheckedDuration(value);
}

// TODO(issues/136): These don't match the CEL runtime typenames. They should
// be updated where possible for consistency.
std::string CelValue::TypeName(Type value_type) { switch (value_type) { case Type::kNullType: return "null_type"; case Type::kBool: return "bool"; case Type::kInt64: return "int64"; case Type::kUint64: return "uint64"; case Type::kDouble: return "double"; case Type::kString: return "string"; case Type::kBytes: return "bytes"; case Type::kMessage: return "Message"; case Type::kDuration: return "Duration"; case Type::kTimestamp: return "Timestamp"; case Type::kList: return "CelList"; case Type::kMap: return "CelMap"; case Type::kCelType: return "CelType"; case Type::kUnknownSet: return "UnknownSet"; case Type::kError: return "CelError"; case Type::kAny: return "Any type"; default: return "unknown"; } } absl::Status CelValue::CheckMapKeyType(const CelValue& key) { switch (key.type()) { case CelValue::Type::kString: case CelValue::Type::kInt64: case CelValue::Type::kUint64: case CelValue::Type::kBool: return absl::OkStatus(); default: return absl::InvalidArgumentError(absl::StrCat( "Invalid map key type: '", CelValue::TypeName(key.type()), "'")); } } CelValue CelValue::ObtainCelType() const { switch (type()) { case Type::kNullType: return CreateCelType(CelTypeHolder(kNullTypeName)); case Type::kBool: return CreateCelType(CelTypeHolder(kBoolTypeName)); case Type::kInt64: return CreateCelType(CelTypeHolder(kInt64TypeName)); case Type::kUint64: return CreateCelType(CelTypeHolder(kUInt64TypeName)); case Type::kDouble: return CreateCelType(CelTypeHolder(kDoubleTypeName)); case Type::kString: return CreateCelType(CelTypeHolder(kStringTypeName)); case Type::kBytes: return CreateCelType(CelTypeHolder(kBytesTypeName)); case Type::kMessage: { MessageWrapper wrapper; CelValue::GetValue(&wrapper); if (wrapper.message_ptr() == nullptr) { return CreateCelType(CelTypeHolder(kNullTypeName)); } // Descritptor::full_name() returns const reference, so using pointer // should be safe. 
return CreateCelType( CelTypeHolder(wrapper.legacy_type_info()->GetTypename(wrapper))); } case Type::kDuration: return CreateCelType(CelTypeHolder(kDurationTypeName)); case Type::kTimestamp: return CreateCelType(CelTypeHolder(kTimestampTypeName)); case Type::kList: return CreateCelType(CelTypeHolder(kListTypeName)); case Type::kMap: return CreateCelType(CelTypeHolder(kMapTypeName)); case Type::kCelType: return CreateCelType(CelTypeHolder(kCelTypeTypeName)); case Type::kUnknownSet: return *this; case Type::kError: return *this; default: { static const CelError* invalid_type_error = new CelError(absl::InvalidArgumentError("Unsupported CelValue type")); return CreateError(invalid_type_error); } } } // Returns debug string describing a value const std::string CelValue::DebugString() const { google::protobuf::Arena arena; return absl::StrCat(CelValue::TypeName(type()), ": ", InternalVisit(DebugStringVisitor{&arena})); } namespace { class EmptyCelList final : public CelList { public: static const EmptyCelList* Get() { static const absl::NoDestructor instance; return &*instance; } CelValue operator[](int index) const override { static const CelError* invalid_argument = new CelError(absl::InvalidArgumentError("index out of bounds")); return CelValue::CreateError(invalid_argument); } int size() const override { return 0; } bool empty() const override { return true; } }; class EmptyCelMap final : public CelMap { public: static const EmptyCelMap* Get() { static const absl::NoDestructor instance; return &*instance; } absl::optional operator[](CelValue key) const override { return absl::nullopt; } absl::StatusOr Has(const CelValue& key) const override { CEL_RETURN_IF_ERROR(CelValue::CheckMapKeyType(key)); return false; } int size() const override { return 0; } bool empty() const override { return true; } absl::StatusOr ListKeys() const override { return EmptyCelList::Get(); } }; } // namespace CelValue CelValue::CreateList() { return CreateList(EmptyCelList::Get()); } CelValue 
CelValue::CreateMap() { return CreateMap(EmptyCelMap::Get()); } CelValue CreateErrorValue(cel::MemoryManagerRef manager, absl::string_view message, absl::StatusCode error_code) { // TODO(uncreated-issue/1): assume arena-style allocator while migrating to new // value type. Arena* arena = cel::extensions::ProtoMemoryManagerArena(manager); return CreateErrorValue(arena, message, error_code); } CelValue CreateErrorValue(cel::MemoryManagerRef manager, const absl::Status& status) { // TODO(uncreated-issue/1): assume arena-style allocator while migrating to new // value type. Arena* arena = cel::extensions::ProtoMemoryManagerArena(manager); return CreateErrorValue(arena, status); } CelValue CreateErrorValue(Arena* arena, absl::string_view message, absl::StatusCode error_code) { CelError* error = Arena::Create(arena, error_code, message); return CelValue::CreateError(error); } CelValue CreateErrorValue(Arena* arena, const absl::Status& status) { CelError* error = Arena::Create(arena, status); return CelValue::CreateError(error); } CelValue CreateNoMatchingOverloadError(cel::MemoryManagerRef manager, absl::string_view fn) { return CelValue::CreateError(interop::CreateNoMatchingOverloadError( cel::extensions::ProtoMemoryManagerArena(manager), fn)); } CelValue CreateNoMatchingOverloadError(google::protobuf::Arena* arena, absl::string_view fn) { return CelValue::CreateError( interop::CreateNoMatchingOverloadError(arena, fn)); } bool CheckNoMatchingOverloadError(CelValue value) { return value.IsError() && value.ErrorOrDie()->code() == absl::StatusCode::kUnknown && absl::StrContains(value.ErrorOrDie()->message(), cel::runtime_internal::kErrNoMatchingOverload); } CelValue CreateNoSuchFieldError(cel::MemoryManagerRef manager, absl::string_view field) { return CelValue::CreateError(interop::CreateNoSuchFieldError( cel::extensions::ProtoMemoryManagerArena(manager), field)); } CelValue CreateNoSuchFieldError(google::protobuf::Arena* arena, absl::string_view field) { return 
CelValue::CreateError(interop::CreateNoSuchFieldError(arena, field)); } CelValue CreateNoSuchKeyError(cel::MemoryManagerRef manager, absl::string_view key) { return CelValue::CreateError(interop::CreateNoSuchKeyError( cel::extensions::ProtoMemoryManagerArena(manager), key)); } CelValue CreateNoSuchKeyError(google::protobuf::Arena* arena, absl::string_view key) { return CelValue::CreateError(interop::CreateNoSuchKeyError(arena, key)); } bool CheckNoSuchKeyError(CelValue value) { return value.IsError() && absl::StartsWith(value.ErrorOrDie()->message(), cel::runtime_internal::kErrNoSuchKey); } CelValue CreateMissingAttributeError(google::protobuf::Arena* arena, absl::string_view missing_attribute_path) { return CelValue::CreateError( interop::CreateMissingAttributeError(arena, missing_attribute_path)); } CelValue CreateMissingAttributeError(cel::MemoryManagerRef manager, absl::string_view missing_attribute_path) { // TODO(uncreated-issue/1): assume arena-style allocator while migrating // to new value type. 
return CelValue::CreateError(interop::CreateMissingAttributeError( cel::extensions::ProtoMemoryManagerArena(manager), missing_attribute_path)); } bool IsMissingAttributeError(const CelValue& value) { const CelError* error; if (!value.GetValue(&error)) return false; if (error && error->code() == absl::StatusCode::kInvalidArgument) { auto path = error->GetPayload( cel::runtime_internal::kPayloadUrlMissingAttributePath); return path.has_value(); } return false; } CelValue CreateUnknownFunctionResultError(cel::MemoryManagerRef manager, absl::string_view help_message) { return CelValue::CreateError(interop::CreateUnknownFunctionResultError( cel::extensions::ProtoMemoryManagerArena(manager), help_message)); } CelValue CreateUnknownFunctionResultError(google::protobuf::Arena* arena, absl::string_view help_message) { return CelValue::CreateError( interop::CreateUnknownFunctionResultError(arena, help_message)); } bool IsUnknownFunctionResult(const CelValue& value) { const CelError* error; if (!value.GetValue(&error)) return false; if (error == nullptr || error->code() != absl::StatusCode::kUnavailable) { return false; } auto payload = error->GetPayload( cel::runtime_internal::kPayloadUrlUnknownFunctionResult); return payload.has_value() && payload.value() == "true"; } } // namespace google::api::expr::runtime ================================================ FILE: eval/public/cel_value.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_H_ // CelValue is a holder, capable of storing all kinds of data // supported by CEL. // CelValue defines explicitly typed/named getters/setters. // When storing pointers to objects, CelValue does not accept ownership // to them and does not control their lifecycle. 
Instead objects are expected // to be either external to expression evaluation, and controlled beyond the // scope or to be allocated and associated with some allocation/ownership // controller (Arena). // Usage examples: // (a) For primitive types: // CelValue value = CelValue::CreateInt64(1); // (b) For string: // string* msg = google::protobuf::Arena::Create(arena,"test"); // CelValue value = CelValue::CreateString(msg); // (c) For messages: // const MyMessage * msg = google::protobuf::Arena::Create(arena); // CelValue value = CelProtoWrapper::CreateMessage(msg, &arena); #include #include "absl/base/attributes.h" #include "absl/base/macros.h" #include "absl/base/optimization.h" #include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/optional.h" #include "absl/types/variant.h" #include "common/kind.h" #include "common/memory.h" #include "common/native_type.h" #include "eval/public/cel_value_internal.h" #include "eval/public/message_wrapper.h" #include "eval/public/unknown_set.h" #include "internal/casts.h" #include "internal/status_macros.h" #include "internal/utf8.h" #include "google/protobuf/message.h" namespace cel::interop_internal { struct CelListAccess; struct CelMapAccess; } // namespace cel::interop_internal namespace google::api::expr::runtime { using CelError = absl::Status; // Break cyclic dependencies for container types. class CelList; class CelMap; class LegacyTypeAdapter; class CelValue { public: // This class is a container to hold strings/bytes. // Template parameter N is an artificial discriminator, used to create // distinct types for String and Bytes (we need distinct types for Oneof). 
template class StringHolderBase { public: StringHolderBase() : value_(absl::string_view()) {} StringHolderBase(const StringHolderBase&) = default; StringHolderBase& operator=(const StringHolderBase&) = default; // string parameter is passed through pointer to ensure string_view is not // initialized with string rvalue. Also, according to Google style guide, // passing pointers conveys the message that the reference to string is kept // in the constructed holder object. explicit StringHolderBase(const std::string* str) : value_(*str) {} absl::string_view value() const { return value_; } // Group of comparison operations. friend bool operator==(StringHolderBase value1, StringHolderBase value2) { return value1.value_ == value2.value_; } friend bool operator!=(StringHolderBase value1, StringHolderBase value2) { return value1.value_ != value2.value_; } friend bool operator<(StringHolderBase value1, StringHolderBase value2) { return value1.value_ < value2.value_; } friend bool operator<=(StringHolderBase value1, StringHolderBase value2) { return value1.value_ <= value2.value_; } friend bool operator>(StringHolderBase value1, StringHolderBase value2) { return value1.value_ > value2.value_; } friend bool operator>=(StringHolderBase value1, StringHolderBase value2) { return value1.value_ >= value2.value_; } friend class CelValue; private: explicit StringHolderBase(absl::string_view other) : value_(other) {} absl::string_view value_; }; // Helper structure for String datatype. using StringHolder = StringHolderBase<0>; // Helper structure for Bytes datatype. using BytesHolder = StringHolderBase<1>; // Helper structure for CelType datatype. using CelTypeHolder = StringHolderBase<2>; // Type for CEL Null values. Implemented as a monostate to behave well in // absl::variant. using NullType = absl::monostate; // GCC: fully qualified to avoid change of meaning error. 
using MessageWrapper = google::api::expr::runtime::MessageWrapper; private: // CelError MUST BE the last in the declaration - it is a ceiling for Type // enum using ValueHolder = internal::ValueHolder< NullType, bool, int64_t, uint64_t, double, StringHolder, BytesHolder, MessageWrapper, absl::Duration, absl::Time, const CelList*, const CelMap*, const UnknownSet*, CelTypeHolder, const CelError*>; public: // Metafunction providing positions corresponding to specific // types. If type is not supported, compile-time error will occur. template using IndexOf = ValueHolder::IndexOf; // Enum for types supported. // This is not recommended for use in exhaustive switches in client code. // Types may be updated over time. using Type = ::cel::Kind; // Legacy enumeration that is here for testing purposes. Do not use. enum class LegacyType { kNullType = IndexOf::value, kBool = IndexOf::value, kInt64 = IndexOf::value, kUint64 = IndexOf::value, kDouble = IndexOf::value, kString = IndexOf::value, kBytes = IndexOf::value, kMessage = IndexOf::value, kDuration = IndexOf::value, kTimestamp = IndexOf::value, kList = IndexOf::value, kMap = IndexOf::value, kUnknownSet = IndexOf::value, kCelType = IndexOf::value, kError = IndexOf::value, kAny // Special value. Used in function descriptors. }; // Default constructor. // Creates CelValue with null data type. CelValue() : CelValue(NullType()) {} // Returns Type that describes the type of value stored. Type type() const { return static_cast(value_.index()); } // Returns debug string describing a value const std::string DebugString() const; // We will use factory methods instead of public constructors // The reason for this is the high risk of implicit type conversions // between bool/int/pointer types. // We rely on copy elision to avoid extra copying. static CelValue CreateNull() { return CelValue(NullType()); } // Transitional factory for migrating to null types. 
static CelValue CreateNullTypedValue() { return CelValue(NullType()); } static CelValue CreateBool(bool value) { return CelValue(value); } static CelValue CreateInt64(int64_t value) { return CelValue(value); } static CelValue CreateUint64(uint64_t value) { return CelValue(value); } static CelValue CreateDouble(double value) { return CelValue(value); } static CelValue CreateString(StringHolder holder) { ABSL_ASSERT(::cel::internal::Utf8IsValid(holder.value())); return CelValue(holder); } // Returns a string value from a string_view. Warning: the caller is // responsible for the lifecycle of the backing string. Prefer CreateString // instead. static CelValue CreateStringView(absl::string_view value) { return CelValue(StringHolder(value)); } static CelValue CreateString(const std::string* str) { return CelValue(StringHolder(str)); } static CelValue CreateBytes(BytesHolder holder) { return CelValue(holder); } static CelValue CreateBytesView(absl::string_view value) { return CelValue(BytesHolder(value)); } static CelValue CreateBytes(const std::string* str) { return CelValue(BytesHolder(str)); } static CelValue CreateDuration(absl::Duration value); static CelValue CreateUncheckedDuration(absl::Duration value) { return CelValue(value); } static CelValue CreateTimestamp(absl::Time value) { return CelValue(value); } static CelValue CreateList(const CelList* value) { CheckNullPointer(value, Type::kList); return CelValue(value); } // Creates a CelValue backed by an empty immutable list. static CelValue CreateList(); static CelValue CreateMap(const CelMap* value) { CheckNullPointer(value, Type::kMap); return CelValue(value); } // Creates a CelValue backed by an empty immutable map. 
static CelValue CreateMap(); static CelValue CreateUnknownSet(const UnknownSet* value) { CheckNullPointer(value, Type::kUnknownSet); return CelValue(value); } static CelValue CreateCelType(CelTypeHolder holder) { return CelValue(holder); } static CelValue CreateCelTypeView(absl::string_view value) { // This factory method is used for dealing with string references which // come from protobuf objects or other containers which promise pointer // stability. In general, this is a risky method to use and should not // be invoked outside the core CEL library. return CelValue(CelTypeHolder(value)); } static CelValue CreateError(const CelError* value) { CheckNullPointer(value, Type::kError); return CelValue(value); } // Returns an absl::OkStatus() when the key is a valid protobuf map type, // meaning it is a scalar value that is neither floating point nor bytes. static absl::Status CheckMapKeyType(const CelValue& key); // Obtain the CelType of the value. CelValue ObtainCelType() const; // Methods for accessing values of specific type // They have the common usage pattern - prior to accessing the // value, the caller should check that the value of this type is indeed // stored in CelValue, using type() or Is...() methods. // Returns stored boolean value. // Fails if stored value type is not boolean. bool BoolOrDie() const { return GetValueOrDie(Type::kBool); } // Returns stored int64 value. // Fails if stored value type is not int64. int64_t Int64OrDie() const { return GetValueOrDie(Type::kInt64); } // Returns stored uint64 value. // Fails if stored value type is not uint64. uint64_t Uint64OrDie() const { return GetValueOrDie(Type::kUint64); } // Returns stored double value. // Fails if stored value type is not double. double DoubleOrDie() const { return GetValueOrDie(Type::kDouble); } // Returns stored const string* value. // Fails if stored value type is not const string*. 
StringHolder StringOrDie() const { return GetValueOrDie(Type::kString); } BytesHolder BytesOrDie() const { return GetValueOrDie(Type::kBytes); } // Returns stored const Message* value. // Fails if stored value type is not const Message*. const google::protobuf::Message* MessageOrDie() const { MessageWrapper wrapped = MessageWrapperOrDie(); ABSL_ASSERT(wrapped.HasFullProto()); return static_cast(wrapped.message_ptr()); } ABSL_DEPRECATED("Use MessageOrDie") MessageWrapper MessageWrapperOrDie() const { return GetValueOrDie(Type::kMessage); } // Returns stored duration value. // Fails if stored value type is not duration. const absl::Duration DurationOrDie() const { return GetValueOrDie(Type::kDuration); } // Returns stored timestamp value. // Fails if stored value type is not timestamp. const absl::Time TimestampOrDie() const { return GetValueOrDie(Type::kTimestamp); } // Returns stored const CelList* value. // Fails if stored value type is not const CelList*. const CelList* ListOrDie() const { return GetValueOrDie(Type::kList); } // Returns stored const CelMap * value. // Fails if stored value type is not const CelMap *. const CelMap* MapOrDie() const { return GetValueOrDie(Type::kMap); } // Returns stored const CelTypeHolder value. // Fails if stored value type is not CelTypeHolder. CelTypeHolder CelTypeOrDie() const { return GetValueOrDie(Type::kCelType); } // Returns stored const UnknownAttributeSet * value. // Fails if stored value type is not const UnknownAttributeSet *. const UnknownSet* UnknownSetOrDie() const { return GetValueOrDie(Type::kUnknownSet); } // Returns stored const CelError * value. // Fails if stored value type is not const CelError *. 
const CelError* ErrorOrDie() const { return GetValueOrDie(Type::kError); } bool IsNull() const { return value_.template Visit(NullCheckOp()); } bool IsBool() const { return value_.is(); } bool IsInt64() const { return value_.is(); } bool IsUint64() const { return value_.is(); } bool IsDouble() const { return value_.is(); } bool IsString() const { return value_.is(); } bool IsBytes() const { return value_.is(); } bool IsMessage() const { return value_.is(); } bool IsDuration() const { return value_.is(); } bool IsTimestamp() const { return value_.is(); } bool IsList() const { return value_.is(); } bool IsMap() const { return value_.is(); } bool IsUnknownSet() const { return value_.is(); } bool IsCelType() const { return value_.is(); } bool IsError() const { return value_.is(); } // Invokes op() with the active value, and returns the result. // All overloads of op() must have the same return type. // Note: this depends on the internals of CelValue, so use with caution. template ReturnType InternalVisit(Op&& op) const { return value_.template Visit(std::forward(op)); } // Invokes op() with the active value, and returns the result. // All overloads of op() must have the same return type. // TODO(uncreated-issue/2): Move to CelProtoWrapper to retain the assumed // google::protobuf::Message variant version behavior for client code. template ReturnType Visit(Op&& op) const { return value_.template Visit( internal::MessageVisitAdapter(std::forward(op))); } // Template-style getter. // Returns true, if assignment successful template bool GetValue(Arg* value) const { return this->template InternalVisit(AssignerOp(value)); } // Provides type names for internal logging. static std::string TypeName(Type value_type); // Factory for message wrapper. This should only be used by internal // libraries. // TODO(uncreated-issue/2): exposed for testing while wiring adapter APIs. Should // make private visibility after refactors are done. 
ABSL_DEPRECATED("Use CelProtoWrapper::CreateMessage") static CelValue CreateMessageWrapper(MessageWrapper value) { CheckNullPointer(value.message_ptr(), Type::kMessage); CheckNullPointer(value.legacy_type_info(), Type::kMessage); return CelValue(value); } private: ValueHolder value_; template struct AssignerOp { explicit AssignerOp(T* val) : value(val) {} template bool operator()(const U&) { return false; } bool operator()(const T& arg) { *value = arg; return true; } T* value; }; // Specialization for MessageWrapper to support legacy behavior while // migrating off hard dependency on google::protobuf::Message. // TODO(uncreated-issue/2): Move to CelProtoWrapper. template struct AssignerOp< T, std::enable_if_t>> { explicit AssignerOp(const google::protobuf::Message** val) : value(val) {} template bool operator()(const U&) { return false; } bool operator()(const MessageWrapper& held_value) { if (!held_value.HasFullProto()) { return false; } *value = static_cast(held_value.message_ptr()); return true; } const google::protobuf::Message** value; }; struct NullCheckOp { template bool operator()(const T&) const { return false; } bool operator()(NullType) const { return true; } // Note: this is not typically possible, but is supported for allowing // function resolution for null ptrs as Messages. bool operator()(const MessageWrapper& arg) const { return arg.message_ptr() == nullptr; } }; // Constructs CelValue wrapping value supplied as argument. // Value type T should be supported by specification of ValueHolder. template explicit CelValue(T value) : value_(value) {} // Crashes with a null pointer error. static void CrashNullPointer(Type type) ABSL_ATTRIBUTE_COLD { ABSL_LOG(FATAL) << "Null pointer supplied for " << TypeName(type); // Crash ok } // Null pointer checker for pointer-based types. static void CheckNullPointer(const void* ptr, Type type) { if (ABSL_PREDICT_FALSE(ptr == nullptr)) { CrashNullPointer(type); } } // Crashes with a type mismatch error. 
static void CrashTypeMismatch(Type requested_type, Type actual_type) ABSL_ATTRIBUTE_COLD { ABSL_LOG(FATAL) << "Type mismatch" // Crash ok << ": expected " << TypeName(requested_type) // Crash ok << ", encountered " << TypeName(actual_type); // Crash ok } // Gets value of type specified template T GetValueOrDie(Type requested_type) const { auto value_ptr = value_.get(); if (ABSL_PREDICT_FALSE(value_ptr == nullptr)) { CrashTypeMismatch(requested_type, type()); } return *value_ptr; } friend class CelProtoWrapper; friend class ProtoMessageTypeAdapter; friend class EvaluatorStack; friend class TestOnly_FactoryAccessor; }; static_assert(absl::is_trivially_destructible::value, "Non-trivially-destructible CelValue impacts " "performance"); // CelList is a base class for list adapting classes. class CelList { public: ABSL_DEPRECATED( "Unless you are sure of the underlying CelList implementation, call Get " "and pass an arena instead") virtual CelValue operator[](int index) const = 0; // Like `operator[](int)` above, but also accepts an arena. Prefer calling // this variant if the arena is known. virtual CelValue Get(google::protobuf::Arena* arena, int index) const { static_cast(arena); return (*this)[index]; } // List size virtual int size() const = 0; // Default empty check. Can be overridden in subclass for performance. virtual bool empty() const { return size() == 0; } virtual ~CelList() {} private: friend struct cel::interop_internal::CelListAccess; friend struct cel::NativeTypeTraits; virtual cel::NativeTypeId GetNativeTypeId() const { return cel::NativeTypeId(); } }; // CelMap is a base class for map accessors. class CelMap { public: // Map lookup. If value found, returns CelValue in return type. // // Per the protobuf specification, acceptable key types are bool, int64, // uint64, string. Any key type that is not supported should result in valued // response containing an absl::StatusCode::kInvalidArgument wrapped as a // CelError. 
// // Type specializations are permitted since CEL supports such distinctions // at type-check time. For example, the expression `1 in map_str` where the // variable `map_str` is of type map(string, string) will yield a type-check // error. To be consistent, the runtime should also yield an invalid argument // error if the type does not agree with the expected key types held by the // container. // TODO(issues/122): Make this method const correct. ABSL_DEPRECATED( "Unless you are sure of the underlying CelMap implementation, call Get " "and pass an arena instead") virtual absl::optional operator[](CelValue key) const = 0; // Like `operator[](CelValue)` above, but also accepts an arena. Prefer // calling this variant if the arena is known. virtual absl::optional Get(google::protobuf::Arena* arena, CelValue key) const { static_cast(arena); return (*this)[key]; } // Return whether the key is present within the map. // // Typically, key resolution will be a simple boolean result; however, there // are scenarios where the conversion of the input key to the underlying // key-type will produce an absl::StatusCode::kInvalidArgument. // // Evaluators are responsible for handling non-OK results by propagating the // error, as appropriate, up the evaluation stack either as a `StatusOr` or // as a `CelError` value, depending on the context. virtual absl::StatusOr Has(const CelValue& key) const { // This check safeguards against issues with invalid key types such as NaN. CEL_RETURN_IF_ERROR(CelValue::CheckMapKeyType(key)); google::protobuf::Arena arena; auto value = (*this).Get(&arena, key); if (!value.has_value()) { return false; } // This protects from issues that may occur when looking up a key value, // such as a failure to convert an int64 to an int32 map key. if (value->IsError()) { return *value->ErrorOrDie(); } return true; } // Map size virtual int size() const = 0; // Default empty check. Can be overridden in subclass for performance. 
virtual bool empty() const { return size() == 0; } // Return list of keys. CelList is owned by Arena, so no // ownership is passed. ABSL_DEPRECATED( "Unless you are sure of the underlying CelMap implementation, call " "ListKeys and pass an arena instead") virtual absl::StatusOr ListKeys() const = 0; // Like `ListKeys()` above, but also accepts an arena. Prefer calling this // variant if the arena is known. virtual absl::StatusOr ListKeys(google::protobuf::Arena* arena) const { static_cast(arena); return ListKeys(); } virtual ~CelMap() {} private: friend struct cel::interop_internal::CelMapAccess; friend struct cel::NativeTypeTraits; virtual cel::NativeTypeId GetNativeTypeId() const { return cel::NativeTypeId(); } }; // Utility method that generates CelValue containing CelError. // message an error message // error_code error code CelValue CreateErrorValue( cel::MemoryManagerRef manager ABSL_ATTRIBUTE_LIFETIME_BOUND, absl::string_view message, absl::StatusCode error_code = absl::StatusCode::kUnknown); CelValue CreateErrorValue( google::protobuf::Arena* arena, absl::string_view message, absl::StatusCode error_code = absl::StatusCode::kUnknown); // Utility method for generating a CelValue from an absl::Status. CelValue CreateErrorValue(cel::MemoryManagerRef manager ABSL_ATTRIBUTE_LIFETIME_BOUND, const absl::Status& status); // Utility method for generating a CelValue from an absl::Status. CelValue CreateErrorValue(google::protobuf::Arena* arena, const absl::Status& status); // Create an error for failed overload resolution, optionally including the name // of the function. 
// `fn` is the (optional) function name to include in the error.
// The resulting error carries absl::StatusCode::kUnknown (per
// cel_value_test.cc).
CelValue CreateNoMatchingOverloadError(cel::MemoryManagerRef manager
                                           ABSL_ATTRIBUTE_LIFETIME_BOUND,
                                       absl::string_view fn = "");
ABSL_DEPRECATED("Prefer using the generic MemoryManager overload")
CelValue CreateNoMatchingOverloadError(google::protobuf::Arena* arena,
                                       absl::string_view fn = "");
// Returns true if `value` is an error produced by
// CreateNoMatchingOverloadError.
bool CheckNoMatchingOverloadError(CelValue value);

// Returns an error for an access to a missing field `field`. The resulting
// error carries absl::StatusCode::kNotFound (per cel_value_test.cc).
CelValue CreateNoSuchFieldError(cel::MemoryManagerRef manager
                                    ABSL_ATTRIBUTE_LIFETIME_BOUND,
                                absl::string_view field = "");
ABSL_DEPRECATED("Prefer using the generic MemoryManager overload")
CelValue CreateNoSuchFieldError(google::protobuf::Arena* arena,
                                absl::string_view field = "");

// Returns an error for a lookup of a missing map key `key`. The resulting
// error carries absl::StatusCode::kNotFound (per cel_value_test.cc).
CelValue CreateNoSuchKeyError(cel::MemoryManagerRef manager
                                  ABSL_ATTRIBUTE_LIFETIME_BOUND,
                              absl::string_view key);
ABSL_DEPRECATED("Prefer using the generic MemoryManager overload")
CelValue CreateNoSuchKeyError(google::protobuf::Arena* arena,
                              absl::string_view key);

// Returns true if `value` is an error produced by CreateNoSuchKeyError.
bool CheckNoSuchKeyError(CelValue value);

// Returns an error indicating that evaluation has accessed an attribute whose
// value is undefined. For example, this may represent a field in a proto
// message bound to the activation whose value can't be determined by the
// hosting application.
// The tests observe code kInvalidArgument and a message of the form
// "MissingAttributeError: <missing_attribute_path>".
CelValue CreateMissingAttributeError(cel::MemoryManagerRef manager
                                         ABSL_ATTRIBUTE_LIFETIME_BOUND,
                                     absl::string_view missing_attribute_path);
ABSL_DEPRECATED("Prefer using the generic MemoryManager overload")
CelValue CreateMissingAttributeError(google::protobuf::Arena* arena,
                                     absl::string_view missing_attribute_path);

// Payload URL associated with missing-attribute errors (presumably used by
// IsMissingAttributeError to recognize them — confirm in cel_value.cc).
ABSL_CONST_INIT extern const absl::string_view kPayloadUrlMissingAttributePath;
// Returns true if `value` is an error produced by CreateMissingAttributeError.
bool IsMissingAttributeError(const CelValue& value);

// Returns error indicating the result of the function is unknown. This is used
// as a signal to create an unknown set if unknown function handling is opted
// into.
CelValue CreateUnknownFunctionResultError(cel::MemoryManagerRef manager ABSL_ATTRIBUTE_LIFETIME_BOUND, absl::string_view help_message); ABSL_DEPRECATED("Prefer using the generic MemoryManager overload") CelValue CreateUnknownFunctionResultError(google::protobuf::Arena* arena, absl::string_view help_message); // Returns true if this is unknown value error indicating that evaluation // called an extension function whose value is unknown for the given args. // This is used as a signal to convert to an UnknownSet if the behavior is opted // into. bool IsUnknownFunctionResult(const CelValue& value); } // namespace google::api::expr::runtime namespace cel { template <> struct NativeTypeTraits final { static NativeTypeId Id(const google::api::expr::runtime::CelList& cel_list) { return cel_list.GetNativeTypeId(); } }; template struct NativeTypeTraits< T, std::enable_if_t, std::negation>>>> final { static NativeTypeId Id(const google::api::expr::runtime::CelList& cel_list) { return NativeTypeTraits::Id(cel_list); } }; template <> struct NativeTypeTraits final { static NativeTypeId Id(const google::api::expr::runtime::CelMap& cel_map) { return cel_map.GetNativeTypeId(); } }; template struct NativeTypeTraits< T, std::enable_if_t, std::negation>>>> final { static NativeTypeId Id(const google::api::expr::runtime::CelMap& cel_map) { return NativeTypeTraits::Id(cel_map); } }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_H_ ================================================ FILE: eval/public/cel_value_internal.h ================================================ /* * Copyright 2018 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_INTERNAL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_INTERNAL_H_ #include #include "absl/base/macros.h" #include "absl/types/variant.h" #include "eval/public/message_wrapper.h" #include "google/protobuf/message.h" namespace google::api::expr::runtime::internal { // Helper classes needed for IndexOf metafunction implementation. template struct IndexDef {}; // This partial IndexDef type specialization provides additional constant // "value", associated with the type. template struct IndexDef { static constexpr int value = N; }; // TypeIndexer is a template class, representing metafunction to find the index // of a type in a type list. template struct TypeIndexer : public TypeIndexer, IndexDef::value> {}; template struct TypeIndexer : public IndexDef::value> {}; // ValueHolder class wraps absl::variant, adding IndexOf metafunction to it. template class ValueHolder { public: template explicit ValueHolder(T t) : value_(t) {} // Metafunction to find the index of a type in a type list. template using IndexOf = TypeIndexer<0, sizeof...(Args), T, Args...>; template const T* get() const { return absl::get_if(&value_); } template bool is() const { return absl::holds_alternative(value_); } int index() const { return value_.index(); } template ReturnType Visit(Op&& op) const { return absl::visit(std::forward(op), value_); } private: absl::variant value_; }; // Adapter for visitor clients that depend on google::protobuf::Message as a variant type. 
template struct MessageVisitAdapter { explicit MessageVisitAdapter(Op&& op) : op(std::forward(op)) {} template T operator()(const ArgT& arg) { return op(arg); } T operator()(const MessageWrapper& wrapper) { ABSL_ASSERT(wrapper.HasFullProto()); return op(static_cast(wrapper.message_ptr())); } Op op; }; } // namespace google::api::expr::runtime::internal #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_INTERNAL_H_ ================================================ FILE: eval/public/cel_value_producer.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_PRODUCER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_PRODUCER_H_ #include "eval/public/cel_value.h" namespace google::api::expr::runtime { // CelValueProducer produces CelValue during CEL Expression evaluation. // It is intended to be used with Activation, to provide on-demand CelValue // calculations. // ValueProducer serves as performance optimization. Multiple calls to value // producer during the execution of the same expression should return the same // value. class CelValueProducer { public: virtual ~CelValueProducer() {} // Produces CelValue. // If CelValue payload is not a primitive type, it must be owned by arena. 
virtual CelValue Produce(google::protobuf::Arena* arena) = 0; }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_VALUE_PRODUCER_H_ ================================================ FILE: eval/public/cel_value_test.cc ================================================ #include "eval/public/cel_value.h" #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/optional.h" #include "common/memory.h" #include "eval/internal/errors.h" #include "eval/public/structs/trivial_legacy_type_info.h" #include "eval/public/testing/matchers.h" #include "eval/public/unknown_set.h" #include "eval/testutil/test_message.pb.h" #include "extensions/protobuf/memory_manager.h" #include "internal/testing.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::extensions::ProtoMemoryManagerRef; using ::cel::runtime_internal::kDurationHigh; using ::cel::runtime_internal::kDurationLow; using ::testing::Eq; using ::testing::HasSubstr; using ::testing::NotNull; class DummyMap : public CelMap { public: absl::optional operator[](CelValue value) const override { return CelValue::CreateNull(); } absl::StatusOr ListKeys() const override { return absl::UnimplementedError("CelMap::ListKeys is not implemented"); } int size() const override { return 0; } }; class DummyList : public CelList { public: int size() const override { return 0; } CelValue operator[](int index) const override { return CelValue::CreateNull(); } }; TEST(CelValueTest, TestType) { ::google::protobuf::Arena arena; CelValue value_null = CelValue::CreateNull(); EXPECT_THAT(value_null.type(), Eq(CelValue::Type::kNullType)); CelValue value_bool = CelValue::CreateBool(false); EXPECT_THAT(value_bool.type(), 
Eq(CelValue::Type::kBool)); CelValue value_int64 = CelValue::CreateInt64(0); EXPECT_THAT(value_int64.type(), Eq(CelValue::Type::kInt64)); CelValue value_uint64 = CelValue::CreateUint64(1); EXPECT_THAT(value_uint64.type(), Eq(CelValue::Type::kUint64)); CelValue value_double = CelValue::CreateDouble(1.0); EXPECT_THAT(value_double.type(), Eq(CelValue::Type::kDouble)); std::string str = "test"; CelValue value_str = CelValue::CreateString(&str); EXPECT_THAT(value_str.type(), Eq(CelValue::Type::kString)); std::string bytes_str = "bytes"; CelValue value_bytes = CelValue::CreateBytes(&bytes_str); EXPECT_THAT(value_bytes.type(), Eq(CelValue::Type::kBytes)); UnknownSet unknown_set; CelValue value_unknown = CelValue::CreateUnknownSet(&unknown_set); EXPECT_THAT(value_unknown.type(), Eq(CelValue::Type::kUnknownSet)); CelValue missing_attribute_error = CreateMissingAttributeError(&arena, "destination.ip"); EXPECT_TRUE(IsMissingAttributeError(missing_attribute_error)); EXPECT_EQ(missing_attribute_error.ErrorOrDie()->code(), absl::StatusCode::kInvalidArgument); EXPECT_EQ(missing_attribute_error.ErrorOrDie()->message(), "MissingAttributeError: destination.ip"); } int CountTypeMatch(const CelValue& value) { int count = 0; bool value_bool; count += (value.GetValue(&value_bool)) ? 1 : 0; int64_t value_int64; count += (value.GetValue(&value_int64)) ? 1 : 0; uint64_t value_uint64; count += (value.GetValue(&value_uint64)) ? 1 : 0; double value_double; count += (value.GetValue(&value_double)) ? 1 : 0; std::string test = ""; CelValue::StringHolder value_str(&test); count += (value.GetValue(&value_str)) ? 1 : 0; CelValue::BytesHolder value_bytes(&test); count += (value.GetValue(&value_bytes)) ? 1 : 0; const google::protobuf::Message* value_msg; count += (value.GetValue(&value_msg)) ? 1 : 0; const CelList* value_list; count += (value.GetValue(&value_list)) ? 1 : 0; const CelMap* value_map; count += (value.GetValue(&value_map)) ? 
1 : 0; const CelError* value_error; count += (value.GetValue(&value_error)) ? 1 : 0; const UnknownSet* value_unknown; count += (value.GetValue(&value_unknown)) ? 1 : 0; return count; } // This test verifies CelValue support of bool type. TEST(CelValueTest, TestBool) { CelValue value = CelValue::CreateBool(true); EXPECT_TRUE(value.IsBool()); EXPECT_THAT(value.BoolOrDie(), Eq(true)); // test template getter bool value2 = false; EXPECT_TRUE(value.GetValue(&value2)); EXPECT_EQ(value2, true); EXPECT_THAT(CountTypeMatch(value), Eq(1)); } // This test verifies CelValue support of int64 type. TEST(CelValueTest, TestInt64) { int64_t v = 1; CelValue value = CelValue::CreateInt64(v); EXPECT_TRUE(value.IsInt64()); EXPECT_THAT(value.Int64OrDie(), Eq(1)); // test template getter int64_t value2 = 0; EXPECT_TRUE(value.GetValue(&value2)); EXPECT_EQ(value2, 1); EXPECT_THAT(CountTypeMatch(value), Eq(1)); } // This test verifies CelValue support of uint64 type. TEST(CelValueTest, TestUint64) { uint64_t v = 1; CelValue value = CelValue::CreateUint64(v); EXPECT_TRUE(value.IsUint64()); EXPECT_THAT(value.Uint64OrDie(), Eq(1)); // test template getter uint64_t value2 = 0; EXPECT_TRUE(value.GetValue(&value2)); EXPECT_EQ(value2, 1); EXPECT_THAT(CountTypeMatch(value), Eq(1)); } // This test verifies CelValue support of int64 type. 
TEST(CelValueTest, TestDouble) {
  double v0 = 1.;
  CelValue value = CelValue::CreateDouble(v0);
  EXPECT_TRUE(value.IsDouble());
  EXPECT_THAT(value.DoubleOrDie(), Eq(v0));

  // test template getter
  double value2 = 0;
  EXPECT_TRUE(value.GetValue(&value2));
  EXPECT_DOUBLE_EQ(value2, 1);
  EXPECT_THAT(CountTypeMatch(value), Eq(1));
}

// CreateDuration() range-checks its input: kDurationLow and kDurationHigh
// themselves are rejected, while values strictly inside the range are kept.
TEST(CelValueTest, TestDurationRangeCheck) {
  EXPECT_THAT(CelValue::CreateDuration(absl::Seconds(1)),
              test::IsCelDuration(absl::Seconds(1)));

  EXPECT_THAT(
      CelValue::CreateDuration(kDurationHigh),
      test::IsCelError(StatusIs(absl::StatusCode::kInvalidArgument,
                                HasSubstr("Duration is out of range"))));

  EXPECT_THAT(
      CelValue::CreateDuration(kDurationLow),
      test::IsCelError(StatusIs(absl::StatusCode::kInvalidArgument,
                                HasSubstr("Duration is out of range"))));

  EXPECT_THAT(CelValue::CreateDuration(kDurationLow + absl::Seconds(1)),
              test::IsCelDuration(kDurationLow + absl::Seconds(1)));
}

// This test verifies CelValue support of string type.
TEST(CelValueTest, TestString) {
  constexpr char kTestStr0[] = "test0";
  std::string v = kTestStr0;

  CelValue value = CelValue::CreateString(&v);
  EXPECT_TRUE(value.IsString());
  EXPECT_THAT(value.StringOrDie().value(), Eq(std::string(kTestStr0)));

  // test template getter
  std::string test = "";
  CelValue::StringHolder value2(&test);
  EXPECT_TRUE(value.GetValue(&value2));
  EXPECT_THAT(value2.value(), Eq(kTestStr0));
  EXPECT_THAT(CountTypeMatch(value), Eq(1));
}

// This test verifies CelValue support of Bytes type.
TEST(CelValueTest, TestBytes) {
  constexpr char kTestStr0[] = "test0";
  std::string v = kTestStr0;

  CelValue value = CelValue::CreateBytes(&v);
  EXPECT_TRUE(value.IsBytes());
  EXPECT_THAT(value.BytesOrDie().value(), Eq(std::string(kTestStr0)));

  // test template getter
  std::string test = "";
  CelValue::BytesHolder value2(&test);
  EXPECT_TRUE(value.GetValue(&value2));
  EXPECT_THAT(value2.value(), Eq(kTestStr0));
  EXPECT_THAT(CountTypeMatch(value), Eq(1));
}

// This test verifies CelValue support of List type.
TEST(CelValueTest, TestList) {
  DummyList dummy_list;

  CelValue value = CelValue::CreateList(&dummy_list);
  EXPECT_TRUE(value.IsList());
  EXPECT_THAT(value.ListOrDie(), Eq(&dummy_list));

  // test template getter
  const CelList* value2 = nullptr;
  EXPECT_TRUE(value.GetValue(&value2));
  EXPECT_THAT(value2, Eq(&dummy_list));
  EXPECT_THAT(CountTypeMatch(value), Eq(1));
}

TEST(CelValueTest, TestEmptyList) {
  ::google::protobuf::Arena arena;

  CelValue value = CelValue::CreateList();
  EXPECT_TRUE(value.IsList());
  const CelList* value2 = nullptr;
  // ASSERT: value2 is dereferenced below, so stop if the getter fails.
  ASSERT_TRUE(value.GetValue(&value2));
  EXPECT_TRUE(value2->empty());
  EXPECT_EQ(value2->size(), 0);
  EXPECT_THAT(value2->Get(&arena, 0),
              test::IsCelError(StatusIs(absl::StatusCode::kInvalidArgument)));
}

// This test verifies CelValue support of Map type.
TEST(CelValueTest, TestMap) { DummyMap dummy_map; CelValue value = CelValue::CreateMap(&dummy_map); EXPECT_TRUE(value.IsMap()); EXPECT_THAT(value.MapOrDie(), Eq(&dummy_map)); // test template getter const CelMap* value2; EXPECT_TRUE(value.GetValue(&value2)); EXPECT_THAT(value2, Eq(&dummy_map)); EXPECT_THAT(CountTypeMatch(value), Eq(1)); } TEST(CelValueTest, TestEmptyMap) { ::google::protobuf::Arena arena; CelValue value = CelValue::CreateMap(); EXPECT_TRUE(value.IsMap()); const CelMap* value2; EXPECT_TRUE(value.GetValue(&value2)); EXPECT_TRUE(value2->empty()); EXPECT_EQ(value2->size(), 0); EXPECT_THAT(value2->Has(CelValue::CreateBool(false)), IsOkAndHolds(false)); EXPECT_THAT(value2->Get(&arena, CelValue::CreateBool(false)), Eq(absl::nullopt)); EXPECT_THAT(value2->ListKeys(&arena), IsOkAndHolds(NotNull())); } TEST(CelValueTest, TestCelType) { CelValue value_null = CelValue::CreateNullTypedValue(); EXPECT_THAT(value_null.ObtainCelType().CelTypeOrDie().value(), Eq("null_type")); CelValue value_bool = CelValue::CreateBool(false); EXPECT_THAT(value_bool.ObtainCelType().CelTypeOrDie().value(), Eq("bool")); CelValue value_int64 = CelValue::CreateInt64(0); EXPECT_THAT(value_int64.ObtainCelType().CelTypeOrDie().value(), Eq("int")); CelValue value_uint64 = CelValue::CreateUint64(0); EXPECT_THAT(value_uint64.ObtainCelType().CelTypeOrDie().value(), Eq("uint")); CelValue value_double = CelValue::CreateDouble(1.0); EXPECT_THAT(value_double.ObtainCelType().CelTypeOrDie().value(), Eq("double")); std::string str = "test"; CelValue value_str = CelValue::CreateString(&str); EXPECT_THAT(value_str.ObtainCelType().CelTypeOrDie().value(), Eq("string")); std::string bytes_str = "bytes"; CelValue value_bytes = CelValue::CreateBytes(&bytes_str); EXPECT_THAT(value_bytes.type(), Eq(CelValue::Type::kBytes)); EXPECT_THAT(value_bytes.ObtainCelType().CelTypeOrDie().value(), Eq("bytes")); std::string msg_type_str = "google.api.expr.runtime.TestMessage"; CelValue msg_type = 
CelValue::CreateCelTypeView(msg_type_str); EXPECT_TRUE(msg_type.IsCelType()); EXPECT_THAT(msg_type.CelTypeOrDie().value(), Eq("google.api.expr.runtime.TestMessage")); EXPECT_THAT(msg_type.type(), Eq(CelValue::Type::kCelType)); UnknownSet unknown_set; CelValue value_unknown = CelValue::CreateUnknownSet(&unknown_set); EXPECT_THAT(value_unknown.type(), Eq(CelValue::Type::kUnknownSet)); EXPECT_TRUE(value_unknown.ObtainCelType().IsUnknownSet()); } // This test verifies CelValue support of Unknown type. TEST(CelValueTest, TestUnknownSet) { UnknownSet unknown_set; CelValue value = CelValue::CreateUnknownSet(&unknown_set); EXPECT_TRUE(value.IsUnknownSet()); EXPECT_THAT(value.UnknownSetOrDie(), Eq(&unknown_set)); // test template getter const UnknownSet* value2; EXPECT_TRUE(value.GetValue(&value2)); EXPECT_THAT(value2, Eq(&unknown_set)); EXPECT_THAT(CountTypeMatch(value), Eq(1)); } TEST(CelValueTest, SpecialErrorFactories) { google::protobuf::Arena arena; auto manager = ProtoMemoryManagerRef(&arena); CelValue error = CreateNoSuchKeyError(manager, "key"); EXPECT_THAT(error, test::IsCelError(StatusIs(absl::StatusCode::kNotFound))); EXPECT_TRUE(CheckNoSuchKeyError(error)); error = CreateNoSuchFieldError(manager, "field"); EXPECT_THAT(error, test::IsCelError(StatusIs(absl::StatusCode::kNotFound))); error = CreateNoMatchingOverloadError(manager, "function"); EXPECT_THAT(error, test::IsCelError(StatusIs(absl::StatusCode::kUnknown))); EXPECT_TRUE(CheckNoMatchingOverloadError(error)); absl::Status error_status = absl::InternalError("internal error"); error_status.SetPayload("CreateErrorValuePreservesFullStatusMessage", absl::Cord("more information")); error = CreateErrorValue(manager, error_status); EXPECT_THAT(error, test::IsCelError(error_status)); error = CreateErrorValue(&arena, error_status); EXPECT_THAT(error, test::IsCelError(error_status)); } TEST(CelValueTest, MissingAttributeErrorsDeprecated) { google::protobuf::Arena arena; CelValue missing_attribute_error = 
CreateMissingAttributeError(&arena, "destination.ip"); EXPECT_TRUE(IsMissingAttributeError(missing_attribute_error)); EXPECT_TRUE(missing_attribute_error.ObtainCelType().IsError()); } TEST(CelValueTest, MissingAttributeErrors) { google::protobuf::Arena arena; auto manager = ProtoMemoryManagerRef(&arena); CelValue missing_attribute_error = CreateMissingAttributeError(manager, "destination.ip"); EXPECT_TRUE(IsMissingAttributeError(missing_attribute_error)); EXPECT_TRUE(missing_attribute_error.ObtainCelType().IsError()); } TEST(CelValueTest, UnknownFunctionResultErrorsDeprecated) { google::protobuf::Arena arena; CelValue value = CreateUnknownFunctionResultError(&arena, "message"); EXPECT_TRUE(value.IsError()); EXPECT_TRUE(IsUnknownFunctionResult(value)); } TEST(CelValueTest, UnknownFunctionResultErrors) { google::protobuf::Arena arena; auto manager = ProtoMemoryManagerRef(&arena); CelValue value = CreateUnknownFunctionResultError(manager, "message"); EXPECT_TRUE(value.IsError()); EXPECT_TRUE(IsUnknownFunctionResult(value)); } TEST(CelValueTest, DebugString) { EXPECT_EQ(CelValue::CreateNull().DebugString(), "null_type: null"); EXPECT_EQ(CelValue::CreateBool(true).DebugString(), "bool: 1"); EXPECT_EQ(CelValue::CreateInt64(-12345).DebugString(), "int64: -12345"); EXPECT_EQ(CelValue::CreateUint64(12345).DebugString(), "uint64: 12345"); EXPECT_TRUE(absl::StartsWith(CelValue::CreateDouble(0.12345).DebugString(), "double: 0.12345")); const std::string abc("abc"); EXPECT_EQ(CelValue::CreateString(&abc).DebugString(), "string: abc"); EXPECT_EQ(CelValue::CreateBytes(&abc).DebugString(), "bytes: abc"); EXPECT_EQ(CelValue::CreateDuration(absl::Hours(24)).DebugString(), "Duration: 24h"); EXPECT_EQ( CelValue::CreateTimestamp(absl::FromUnixSeconds(86400)).DebugString(), "Timestamp: 1970-01-02T00:00:00+00:00"); UnknownSet unknown_set; EXPECT_EQ(CelValue::CreateUnknownSet(&unknown_set).DebugString(), "UnknownSet: ?"); absl::Status error = absl::InternalError("Blah..."); 
EXPECT_EQ(CelValue::CreateError(&error).DebugString(), "CelError: INTERNAL: Blah..."); // List and map DebugString() test coverage is in cel_proto_wrapper_test.cc. } TEST(CelValueTest, Message) { TestMessage message; auto value = CelValue::CreateMessageWrapper( CelValue::MessageWrapper(&message, TrivialTypeInfo::GetInstance())); EXPECT_TRUE(value.IsMessage()); CelValue::MessageWrapper held; ASSERT_TRUE(value.GetValue(&held)); EXPECT_TRUE(held.HasFullProto()); EXPECT_EQ(held.message_ptr(), static_cast(&message)); EXPECT_EQ(held.legacy_type_info(), TrivialTypeInfo::GetInstance()); // TrivialTypeInfo doesn't provide any details about the specific message. EXPECT_EQ(value.ObtainCelType().CelTypeOrDie().value(), "opaque"); EXPECT_EQ(value.DebugString(), "Message: opaque"); } TEST(CelValueTest, MessageLite) { TestMessage message; // Upcast to message lite. const google::protobuf::MessageLite* ptr = &message; auto value = CelValue::CreateMessageWrapper( CelValue::MessageWrapper(ptr, TrivialTypeInfo::GetInstance())); EXPECT_TRUE(value.IsMessage()); CelValue::MessageWrapper held; ASSERT_TRUE(value.GetValue(&held)); EXPECT_FALSE(held.HasFullProto()); EXPECT_EQ(held.message_ptr(), &message); EXPECT_EQ(held.legacy_type_info(), TrivialTypeInfo::GetInstance()); EXPECT_EQ(value.ObtainCelType().CelTypeOrDie().value(), "opaque"); EXPECT_EQ(value.DebugString(), "Message: opaque"); } TEST(CelValueTest, Size) { // CelValue performance degrades when it becomes larger. static_assert(sizeof(CelValue) <= 3 * sizeof(uintptr_t)); } } // namespace google::api::expr::runtime ================================================ FILE: eval/public/comparison_functions.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/public/comparison_functions.h" #include "absl/status/status.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" #include "runtime/standard/comparison_functions.h" namespace google::api::expr::runtime { absl::Status RegisterComparisonFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options) { cel::RuntimeOptions modern_options = ConvertToRuntimeOptions(options); cel::FunctionRegistry& modern_registry = registry->InternalGetRegistry(); return cel::RegisterComparisonFunctions(modern_registry, modern_options); } } // namespace google::api::expr::runtime ================================================ FILE: eval/public/comparison_functions.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_COMPARISON_FUNCTIONS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_COMPARISON_FUNCTIONS_H_ #include "absl/status/status.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" namespace google::api::expr::runtime { // Register built in comparison functions (<, <=, >, >=). // // Most users should prefer to use RegisterBuiltinFunctions. // // This is call is included in RegisterBuiltinFunctions -- calling both // RegisterBuiltinFunctions and RegisterComparisonFunctions directly on the same // registry will result in an error. absl::Status RegisterComparisonFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_COMPARISON_FUNCTIONS_H_ ================================================ FILE: eval/public/comparison_functions_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/public/comparison_functions.h" #include #include #include "cel/expr/syntax.pb.h" #include "google/rpc/context/attribute_context.pb.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "eval/public/activation.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/testing/matchers.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/parser.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { using ::cel::expr::ParsedExpr; using ::google::rpc::context::AttributeContext; using ::testing::Combine; using ::testing::ValuesIn; // Matches a CelFunctionRegistry that has a non-receiver-style overload of // `name` taking two arguments of `argument_type`. MATCHER_P2(DefinesHomogenousOverload, name, argument_type, absl::StrCat(name, " for ", CelValue::TypeName(argument_type))) { const CelFunctionRegistry& registry = arg; return !registry .FindOverloads(name, /*receiver_style=*/false, {argument_type, argument_type}) .empty(); } // A comparison expression and its expected boolean result. Expressions that // reference `lhs` / `rhs` see these values bound as variables (null by // default). struct ComparisonTestCase { absl::string_view expr; bool result; CelValue lhs = CelValue::CreateNull(); CelValue rhs = CelValue::CreateNull(); }; // Parameterized on (test case, enable_heterogeneous_equality). class ComparisonFunctionTest : public testing::TestWithParam> { public: ComparisonFunctionTest() { options_.enable_heterogeneous_equality = std::get<1>(GetParam()); options_.enable_empty_wrapper_null_unboxing = true; builder_ = CreateCelExpressionBuilder(options_); } CelFunctionRegistry& registry() { return *builder_->GetRegistry(); } // Parses `expr` and evaluates it with `lhs` and `rhs` bound in the activation. absl::StatusOr Evaluate(absl::string_view expr, const CelValue& lhs, const CelValue& rhs) { CEL_ASSIGN_OR_RETURN(ParsedExpr parsed_expr, parser::Parse(expr)); Activation activation; activation.InsertValue("lhs", lhs); activation.InsertValue("rhs", rhs); CEL_ASSIGN_OR_RETURN(auto expression, builder_->CreateExpression( &parsed_expr.expr(),
&parsed_expr.source_info())); return expression->Evaluate(activation, &arena_); } protected: std::unique_ptr builder_; InterpreterOptions options_; google::protobuf::Arena arena_; }; TEST_P(ComparisonFunctionTest, SmokeTest) { ComparisonTestCase test_case = std::get<0>(GetParam()); google::protobuf::LinkMessageReflection(); ASSERT_OK(RegisterComparisonFunctions(&registry(), options_)); ASSERT_OK_AND_ASSIGN(auto result, Evaluate(test_case.expr, test_case.lhs, test_case.rhs)); EXPECT_THAT(result, test::IsCelBool(test_case.result)); } INSTANTIATE_TEST_SUITE_P( LessThan, ComparisonFunctionTest, Combine(ValuesIn( {// less than {"false < true", true}, {"1 < 2", true}, {"-2 < -1", true}, {"1.1 < 1.2", true}, {"'a' < 'b'", true}, {"lhs < rhs", true, CelValue::CreateBytesView("a"), CelValue::CreateBytesView("b")}, {"lhs < rhs", true, CelValue::CreateDuration(absl::Seconds(1)), CelValue::CreateDuration(absl::Seconds(2))}, {"lhs < rhs", true, CelValue::CreateTimestamp(absl::FromUnixSeconds(20)), CelValue::CreateTimestamp(absl::FromUnixSeconds(30))}}), // heterogeneous equality enabled and disabled testing::Bool())); INSTANTIATE_TEST_SUITE_P( GreaterThan, ComparisonFunctionTest, testing::Combine( testing::ValuesIn( {{"false > true", false}, {"1 > 2", false}, {"-2 > -1", false}, {"1.1 > 1.2", false}, {"'a' > 'b'", false}, {"lhs > rhs", false, CelValue::CreateBytesView("a"), CelValue::CreateBytesView("b")}, {"lhs > rhs", false, CelValue::CreateDuration(absl::Seconds(1)), CelValue::CreateDuration(absl::Seconds(2))}, {"lhs > rhs", false, CelValue::CreateTimestamp(absl::FromUnixSeconds(20)), CelValue::CreateTimestamp(absl::FromUnixSeconds(30))}}), // heterogeneous equality enabled and disabled testing::Bool())); INSTANTIATE_TEST_SUITE_P( GreaterOrEqual, ComparisonFunctionTest, Combine(ValuesIn( {{"false >= true", false}, {"1 >= 2", false}, {"-2 >= -1", false}, {"1.1 >= 1.2", false}, {"'a' >= 'b'", false}, {"lhs >= rhs", false, CelValue::CreateBytesView("a"), CelValue::CreateBytesView("b")}, {"lhs >= rhs",
false, CelValue::CreateDuration(absl::Seconds(1)), CelValue::CreateDuration(absl::Seconds(2))}, {"lhs >= rhs", false, CelValue::CreateTimestamp(absl::FromUnixSeconds(20)), CelValue::CreateTimestamp(absl::FromUnixSeconds(30))}}), // heterogeneous equality enabled and disabled testing::Bool())); INSTANTIATE_TEST_SUITE_P( LessOrEqual, ComparisonFunctionTest, Combine(testing::ValuesIn( {{"false <= true", true}, {"1 <= 2", true}, {"-2 <= -1", true}, {"1.1 <= 1.2", true}, {"'a' <= 'b'", true}, {"lhs <= rhs", true, CelValue::CreateBytesView("a"), CelValue::CreateBytesView("b")}, {"lhs <= rhs", true, CelValue::CreateDuration(absl::Seconds(1)), CelValue::CreateDuration(absl::Seconds(2))}, {"lhs <= rhs", true, CelValue::CreateTimestamp(absl::FromUnixSeconds(20)), CelValue::CreateTimestamp(absl::FromUnixSeconds(30))}}), // heterogeneous equality enabled and disabled testing::Bool())); INSTANTIATE_TEST_SUITE_P(HeterogeneousNumericComparisons, ComparisonFunctionTest, Combine(testing::ValuesIn( { // less than {"1 < 2u", true}, // int < uint {"2 < 1u", false}, {"1 < 2.1", true}, // int < double {"3 < 2.1", false}, {"1u < 2", true}, // uint < int {"2u < 1", false}, {"1u < -1.1", false}, // uint < double {"1u < 2.1", true}, {"1.1 < 2", true}, // double < int {"1.1 < 1", false}, {"1.0 < 1u", false}, // double < uint {"1.0 < 3u", true}, // less than or equal {"1 <= 2u", true}, // int <= uint {"2 <= 1u", false}, {"1 <= 2.1", true}, // int <= double {"3 <= 2.1", false}, {"1u <= 2", true}, // uint <= int {"1u <= 0", false}, {"1u <= -1.1", false}, // uint <= double {"2u <= 1.0", false}, {"1.1 <= 2", true}, // double <= int {"2.1 <= 2", false}, {"1.0 <= 1u", true}, // double <= uint {"1.1 <= 1u", false}, // greater than {"3 > 2u", true}, // int > uint {"3 > 4u", false}, {"3 > 2.1", true}, // int > double {"3 > 4.1", false}, {"3u > 2", true}, // uint > int {"3u > 4", false}, {"3u > -1.1", true}, // uint > double {"3u > 4.1", false}, {"3.1 > 2", true}, // double > int {"3.1 > 4", false}, {"3.0 > 1u", true}, //
double > uint {"3.0 > 4u", false}, // greater than or equal {"3 >= 2u", true}, // int >= uint {"3 >= 4u", false}, {"3 >= 2.1", true}, // int >= double {"3 >= 4.1", false}, {"3u >= 2", true}, // uint >= int {"3u >= 4", false}, {"3u >= -1.1", true}, // uint >= double {"3u >= 4.1", false}, {"3.1 >= 2", true}, // double >= int {"3.1 >= 4", false}, {"3.0 >= 1u", true}, // double >= uint {"3.0 >= 4u", false}, {"1u >= -1", true}, {"1 >= 4u", false}, // edge cases {"-1 < 1u", true}, {"1 < 9223372036854775808u", true}}), // heterogeneous equality enabled only testing::Values(true))); } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/public/container_function_registrar.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "eval/public/container_function_registrar.h" #include "eval/public/cel_options.h" #include "runtime/runtime_options.h" #include "runtime/standard/container_functions.h" namespace google::api::expr::runtime { absl::Status RegisterContainerFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options) { cel::RuntimeOptions runtime_options = ConvertToRuntimeOptions(options); return cel::RegisterContainerFunctions(registry->InternalGetRegistry(), runtime_options); } } // namespace google::api::expr::runtime ================================================ FILE: eval/public/container_function_registrar.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINER_FUNCTION_REGISTRAR_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINER_FUNCTION_REGISTRAR_H_ #include "absl/status/status.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" namespace google::api::expr::runtime { // Register built in container functions. // // Most users should prefer to use RegisterBuiltinFunctions. // // This call is included in RegisterBuiltinFunctions -- calling both // RegisterBuiltinFunctions and RegisterContainerFunctions directly on the same // registry will result in an error. 
absl::Status RegisterContainerFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINER_FUNCTION_REGISTRAR_H_ ================================================ FILE: eval/public/container_function_registrar_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/public/container_function_registrar.h" #include #include #include "eval/public/activation.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/equality_function_registrar.h" #include "eval/public/testing/matchers.h" #include "internal/testing.h" #include "parser/parser.h" namespace google::api::expr::runtime { namespace { using cel::expr::Expr; using cel::expr::SourceInfo; using ::testing::ValuesIn; struct TestCase { std::string test_name; std::string expr; absl::StatusOr result = CelValue::CreateBool(true); }; const CelList& CelNumberListExample() { static ContainerBackedListImpl* example = new ContainerBackedListImpl({CelValue::CreateInt64(1)}); return *example; } void ExpectResult(const TestCase& test_case) { auto parsed_expr = parser::Parse(test_case.expr); ASSERT_OK(parsed_expr); const Expr& expr_ast = parsed_expr->expr(); const SourceInfo& 
source_info = parsed_expr->source_info(); InterpreterOptions options; options.enable_timestamp_duration_overflow_errors = true; options.enable_comprehension_list_append = true; std::unique_ptr builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterContainerFunctions(builder->GetRegistry(), options)); // Needed to avoid error - No overloads provided for FunctionStep creation. ASSERT_OK(RegisterEqualityFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN(auto cel_expression, builder->CreateExpression(&expr_ast, &source_info)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(auto value, cel_expression->Evaluate(activation, &arena)); EXPECT_THAT(value, test::EqualsCelValue(*test_case.result)); } using ContainerFunctionParamsTest = testing::TestWithParam; TEST_P(ContainerFunctionParamsTest, StandardFunctions) { ExpectResult(GetParam()); } INSTANTIATE_TEST_SUITE_P( ContainerFunctionParamsTest, ContainerFunctionParamsTest, ValuesIn( {{"FilterNumbers", "[1, 2, 3].filter(num, num == 1)", CelValue::CreateList(&CelNumberListExample())}, {"ListConcatEmptyInputs", "[] + [] == []", CelValue::CreateBool(true)}, {"ListConcatRightEmpty", "[1] + [] == [1]", CelValue::CreateBool(true)}, {"ListConcatLeftEmpty", "[] + [1] == [1]", CelValue::CreateBool(true)}, {"ListConcat", "[2] + [1] == [2, 1]", CelValue::CreateBool(true)}, {"ListSize", "[1, 2, 3].size() == 3", CelValue::CreateBool(true)}, {"MapSize", "{1: 2, 2: 4}.size() == 2", CelValue::CreateBool(true)}, {"EmptyListSize", "size({}) == 0", CelValue::CreateBool(true)}}), [](const testing::TestParamInfo& info) { return info.param.test_name; }); } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/public/containers/BUILD ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the 
License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # TODO(issues/69): Expose this in a public API. package_group( name = "cel_internal", packages = ["//eval/..."], ) # Helpers that read proto message fields into CelValue and write CelValue # into message fields. cc_library( name = "field_access", srcs = [ "field_access.cc", ], hdrs = [ "field_access.h", ], deps = [ "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", "//eval/public/structs:field_access_impl", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", ], ) # Header-only CelList backed by a vector of CelValue (srcs is empty). cc_library( name = "container_backed_list_impl", srcs = [ ], hdrs = [ "container_backed_list_impl.h", ], deps = [ "//eval/public:cel_value", "@com_google_protobuf//:protobuf", ], ) # CelMapBuilder / CreateContainerBackedMap: a CelMap backed by # absl::node_hash_map. cc_library( name = "container_backed_map_impl", srcs = [ "container_backed_map_impl.cc", ], hdrs = [ "container_backed_map_impl.h", ], deps = [ "//eval/public:cel_value", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) # Header-only wrapper over :internal_field_backed_list_impl. cc_library( name = "field_backed_list_impl", hdrs = [ "field_backed_list_impl.h", ], deps = [ ":internal_field_backed_list_impl", "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", ], ) # Header-only wrapper over :internal_field_backed_map_impl. cc_library( name = "field_backed_map_impl", hdrs = [ "field_backed_map_impl.h", ], deps = [ ":internal_field_backed_map_impl",
"//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "container_backed_map_impl_test", size = "small", srcs = [ "container_backed_map_impl_test.cc", ], deps = [ ":container_backed_map_impl", "//eval/public:cel_value", "//internal:testing", "@com_google_absl//absl/status", ], ) cc_test( name = "field_backed_list_impl_test", size = "small", srcs = [ "field_backed_list_impl_test.cc", ], deps = [ ":field_backed_list_impl", "//eval/testutil:test_message_cc_proto", "//internal:testing", "//testutil:util", ], ) cc_test( name = "field_backed_map_impl_test", size = "small", srcs = [ "field_backed_map_impl_test.cc", ], deps = [ ":field_backed_map_impl", "//eval/testutil:test_message_cc_proto", "//internal:testing", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], ) cc_test( name = "field_access_test", srcs = ["field_access_test.cc"], deps = [ ":field_access", "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", "//internal:testing", "//internal:time", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) # Implementation library behind :field_backed_list_impl. cc_library( name = "internal_field_backed_list_impl", srcs = [ "internal_field_backed_list_impl.cc", ], hdrs = [ "internal_field_backed_list_impl.h", ], deps = [ "//eval/public:cel_value", "//eval/public/structs:field_access_impl", "//eval/public/structs:protobuf_value_factory", ], ) cc_test( name = "internal_field_backed_list_impl_test", size = "small", srcs = [ "internal_field_backed_list_impl_test.cc", ], deps = [ ":internal_field_backed_list_impl", "//eval/public/structs:cel_proto_wrapper", "//eval/testutil:test_message_cc_proto", "//internal:testing", "//testutil:util", ], )
cc_library( name = "internal_field_backed_map_impl", srcs = [ "internal_field_backed_map_impl.cc", ], hdrs = [ "internal_field_backed_map_impl.h", ], deps = [ "//eval/public:cel_value", "//eval/public/structs:field_access_impl", "//eval/public/structs:protobuf_value_factory", "//extensions/protobuf/internal:map_reflection", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "internal_field_backed_map_impl_test", size = "small", srcs = [ "internal_field_backed_map_impl_test.cc", ], visibility = ["//visibility:private"], deps = [ ":internal_field_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", "//eval/testutil:test_message_cc_proto", "//internal:testing", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], ) ================================================ FILE: eval/public/containers/container_backed_list_impl.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_CONTAINER_BACKED_LIST_IMPL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_CONTAINER_BACKED_LIST_IMPL_H_ #include #include #include "eval/public/cel_value.h" #include "google/protobuf/arena.h" namespace google { namespace api { namespace expr { namespace runtime { // CelList implementation that uses a vector of CelValue as backing storage. // The contents are fixed at construction time. class ContainerBackedListImpl : public CelList { public: // Takes ownership of `values` (moved into the list). explicit ContainerBackedListImpl(std::vector values) : values_(std::move(values)) {} // List size. int size() const override { return values_.size(); } // List element access operator. No bounds checking is performed; `index` // must be in [0, size()). CelValue operator[](int index) const override { return values_[index]; } // List element access operator. The arena is unused because the elements are // already materialized; `index` must be in [0, size()).
CelValue Get(google::protobuf::Arena*, int index) const override { return values_[index]; } private: std::vector values_; }; } // namespace runtime } // namespace expr } // namespace api } // namespace google #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_CONTAINER_BACKED_LIST_IMPL_H_ ================================================ FILE: eval/public/containers/container_backed_map_impl.cc ================================================ #include "eval/public/containers/container_backed_map_impl.h" #include #include #include "absl/container/node_hash_map.h" #include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "eval/public/cel_value.h" namespace google { namespace api { namespace expr { namespace runtime { namespace { // Helper classes for the CelValue hasher function. // We only care about hash operations for integral/string types, but the other // overloads must be present as well for CelValue::Visit(HasherOp) to compile. class HasherOp { public: template size_t operator()(const T& arg) { return std::hash()(arg); } size_t operator()(const absl::Time arg) { return absl::Hash()(arg); } size_t operator()(const absl::Duration arg) { return absl::Hash()(arg); } size_t operator()(const CelValue::StringHolder& arg) { return absl::Hash()(arg.value()); } size_t operator()(const CelValue::BytesHolder& arg) { return absl::Hash()(arg.value()); } size_t operator()(const CelValue::CelTypeHolder& arg) { return absl::Hash()(arg.value()); } // Needed for successful compilation resolution.
size_t operator()(const std::nullptr_t&) { return 0; } }; // Helper classes to provide CelValue equality comparison operation. // EqualOp reports true only when the visited alternative has the same C++ // type as the captured value and compares equal to it. template class EqualOp { public: explicit EqualOp(const T& arg) : arg_(arg) {} template bool operator()(const U&) const { return false; } bool operator()(const T& other) const { return other == arg_; } private: const T& arg_; }; class CelValueEq { public: explicit CelValueEq(const CelValue& other) : other_(other) {} template bool operator()(const Type& arg) { return other_.template Visit(EqualOp(arg)); } private: const CelValue& other_; }; } // namespace // Map element access operator. Returns absl::nullopt if `cel_key` is absent. absl::optional CelMapBuilder::operator[](CelValue cel_key) const { auto item = values_map_.find(cel_key); if (item == values_map_.end()) { return absl::nullopt; } return item->second; } // Inserts `key` -> `value`. On a duplicate key the existing entry is left // unchanged and an InvalidArgument error is returned; otherwise the key is also // appended to the key list returned by ListKeys(). absl::Status CelMapBuilder::Add(CelValue key, CelValue value) { auto [unused, inserted] = values_map_.emplace(key, value); if (!inserted) { return absl::InvalidArgumentError("duplicate map keys"); } key_list_.Add(key); return absl::OkStatus(); } // CelValue hasher functor.
size_t CelMapBuilder::Hasher::operator()(const CelValue& key) const { return key.template Visit(HasherOp()); } // Keys of different CelValue types are never equal, so e.g. int64 1 and // uint64 1 are distinct map keys. bool CelMapBuilder::Equal::operator()(const CelValue& key1, const CelValue& key2) const { if (key1.type() != key2.type()) { return false; } return key1.template Visit(CelValueEq(key2)); } absl::StatusOr> CreateContainerBackedMap( absl::Span> key_values) { auto map = std::make_unique(); for (const auto& key_value : key_values) { CEL_RETURN_IF_ERROR(map->Add(key_value.first, key_value.second)); } return map; } } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: eval/public/containers/container_backed_map_impl.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_CONTAINER_BACKED_MAP_IMPL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_CONTAINER_BACKED_MAP_IMPL_H_ #include #include #include "absl/container/node_hash_map.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "eval/public/cel_value.h" namespace google::api::expr::runtime { // CelMap implementation that uses an absl::node_hash_map keyed on CelValue as // backing storage. Keys match only if they have the same CelValue type and // equal values. // After building, upcast to CelMap to prevent further additions. class CelMapBuilder : public CelMap { public: CelMapBuilder() {} // Tries to insert a key value pair into the map. Returns an InvalidArgument // error, leaving the existing entry unchanged, if the key already exists. absl::Status Add(CelValue key, CelValue value); int size() const override { return values_map_.size(); } absl::optional operator[](CelValue cel_key) const override; absl::StatusOr Has(const CelValue& cel_key) const override { return values_map_.contains(cel_key); } // Keys are listed in insertion order. absl::StatusOr ListKeys() const override { return &key_list_; } private: // CelList over the inserted keys, kept in insertion order to back ListKeys().
class KeyList : public CelList { public: KeyList() {} int size() const override { return keys_.size(); } CelValue operator[](int index) const override { return keys_[index]; } void Add(const CelValue& key) { keys_.push_back(key); } private: std::vector keys_; }; struct Hasher { size_t operator()(const CelValue& key) const; }; struct Equal { bool operator()(const CelValue& key1, const CelValue& key2) const; }; absl::node_hash_map values_map_; KeyList key_list_; }; // Factory method creating a container-backed CelMap from `key_values`. Returns // an InvalidArgument error if any key occurs more than once. absl::StatusOr> CreateContainerBackedMap( absl::Span> key_values); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_CONTAINER_BACKED_MAP_IMPL_H_ ================================================ FILE: eval/public/containers/container_backed_map_impl_test.cc ================================================ #include "eval/public/containers/container_backed_map_impl.h" #include #include #include #include "absl/status/status.h" #include "eval/public/cel_value.h" #include "internal/testing.h" namespace google::api::expr::runtime { namespace { using ::absl_testing::StatusIs; using ::testing::Eq; using ::testing::IsNull; using ::testing::Not; TEST(ContainerBackedMapImplTest, TestMapInt64) { std::vector> args = { {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, {CelValue::CreateInt64(2), CelValue::CreateInt64(3)}}; auto cel_map = CreateContainerBackedMap( absl::Span>(args.data(), args.size())) .value(); ASSERT_THAT(cel_map, Not(IsNull())); EXPECT_THAT(cel_map->size(), Eq(2)); // Test lookup with key == 1 ( should succeed ) auto lookup1 = (*cel_map)[CelValue::CreateInt64(1)]; ASSERT_TRUE(lookup1); CelValue cel_value = lookup1.value(); ASSERT_TRUE(cel_value.IsInt64()); EXPECT_THAT(cel_value.Int64OrDie(), 2); // Test lookup with key == 1, different type ( should fail ) auto lookup2 = (*cel_map)[CelValue::CreateUint64(1)]; ASSERT_FALSE(lookup2); // Test lookup with key == 3 ( should fail ) auto lookup3 =
(*cel_map)[CelValue::CreateInt64(3)]; ASSERT_FALSE(lookup3); } TEST(ContainerBackedMapImplTest, TestMapUint64) { std::vector> args = { {CelValue::CreateUint64(1), CelValue::CreateInt64(2)}, {CelValue::CreateUint64(2), CelValue::CreateInt64(3)}}; auto cel_map = CreateContainerBackedMap( absl::Span>(args.data(), args.size())) .value(); ASSERT_THAT(cel_map, Not(IsNull())); EXPECT_THAT(cel_map->size(), Eq(2)); // Test lookup with key == 1 ( should succeed ) auto lookup1 = (*cel_map)[CelValue::CreateUint64(1)]; ASSERT_TRUE(lookup1); CelValue cel_value = lookup1.value(); ASSERT_TRUE(cel_value.IsInt64()); EXPECT_THAT(cel_value.Int64OrDie(), 2); // Test lookup with key == 1, different type ( should fail ) auto lookup2 = (*cel_map)[CelValue::CreateInt64(1)]; ASSERT_FALSE(lookup2); // Test lookup with key == 3 ( should fail ) auto lookup3 = (*cel_map)[CelValue::CreateUint64(3)]; ASSERT_FALSE(lookup3); } TEST(ContainerBackedMapImplTest, TestMapString) { const std::string kKey1 = "1"; const std::string kKey2 = "2"; const std::string kKey3 = "3"; std::vector> args = { {CelValue::CreateString(&kKey1), CelValue::CreateInt64(2)}, {CelValue::CreateString(&kKey2), CelValue::CreateInt64(3)}}; auto cel_map = CreateContainerBackedMap( absl::Span>(args.data(), args.size())) .value(); ASSERT_THAT(cel_map, Not(IsNull())); EXPECT_THAT(cel_map->size(), Eq(2)); // Test lookup with key == 1 ( should succeed ) auto lookup1 = (*cel_map)[CelValue::CreateString(&kKey1)]; ASSERT_TRUE(lookup1); CelValue cel_value = lookup1.value(); ASSERT_TRUE(cel_value.IsInt64()); EXPECT_THAT(cel_value.Int64OrDie(), 2); // Test lookup with different type ( should fail ) auto lookup2 = (*cel_map)[CelValue::CreateInt64(1)]; ASSERT_FALSE(lookup2); // Test lookup with key3 ( should fail ) auto lookup3 = (*cel_map)[CelValue::CreateString(&kKey3)]; ASSERT_FALSE(lookup3); } TEST(CelMapBuilder, TestMapString) { const std::string kKey1 = "1"; const std::string kKey2 = "2"; const std::string kKey3 = "3";
CelMapBuilder builder; ASSERT_OK( builder.Add(CelValue::CreateString(&kKey1), CelValue::CreateInt64(2))); ASSERT_OK( builder.Add(CelValue::CreateString(&kKey2), CelValue::CreateInt64(3))); CelMap* cel_map = &builder; ASSERT_THAT(cel_map, Not(IsNull())); EXPECT_THAT(cel_map->size(), Eq(2)); // Test lookup with key == 1 ( should succeed ) auto lookup1 = (*cel_map)[CelValue::CreateString(&kKey1)]; ASSERT_TRUE(lookup1); CelValue cel_value = lookup1.value(); ASSERT_TRUE(cel_value.IsInt64()); EXPECT_THAT(cel_value.Int64OrDie(), 2); // Test lookup with different type ( should fail ) auto lookup2 = (*cel_map)[CelValue::CreateInt64(1)]; ASSERT_FALSE(lookup2); // Test lookup with key3 ( should fail ) auto lookup3 = (*cel_map)[CelValue::CreateString(&kKey3)]; ASSERT_FALSE(lookup3); } TEST(CelMapBuilder, RepeatKeysFail) { const std::string kKey1 = "1"; const std::string kKey2 = "2"; CelMapBuilder builder; ASSERT_OK( builder.Add(CelValue::CreateString(&kKey1), CelValue::CreateInt64(2))); ASSERT_OK( builder.Add(CelValue::CreateString(&kKey2), CelValue::CreateInt64(3))); EXPECT_THAT( builder.Add(CelValue::CreateString(&kKey2), CelValue::CreateInt64(3)), StatusIs(absl::StatusCode::kInvalidArgument, "duplicate map keys")); // A rejected duplicate must not change the map. EXPECT_THAT(builder.size(), Eq(2)); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/public/containers/field_access.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/public/containers/field_access.h" #include "absl/status/status.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/structs/field_access_impl.h" #include "internal/status_macros.h" #include "google/protobuf/arena.h" #include "google/protobuf/map_field.h" namespace google::api::expr::runtime { using ::google::protobuf::Arena; using ::google::protobuf::FieldDescriptor; using ::google::protobuf::MapValueConstRef; using ::google::protobuf::Message; absl::Status CreateValueFromSingleField(const google::protobuf::Message* msg, const FieldDescriptor* desc, google::protobuf::Arena* arena, CelValue* result) { return CreateValueFromSingleField( msg, desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, arena, result); } absl::Status CreateValueFromSingleField(const google::protobuf::Message* msg, const FieldDescriptor* desc, ProtoWrapperTypeOptions options, google::protobuf::Arena* arena, CelValue* result) { CEL_ASSIGN_OR_RETURN( *result, internal::CreateValueFromSingleField( msg, desc, options, &CelProtoWrapper::InternalWrapMessage, arena)); return absl::OkStatus(); } absl::Status CreateValueFromRepeatedField(const google::protobuf::Message* msg, const FieldDescriptor* desc, google::protobuf::Arena* arena, int index, CelValue* result) { CEL_ASSIGN_OR_RETURN( *result, internal::CreateValueFromRepeatedField( msg, desc, index, &CelProtoWrapper::InternalWrapMessage, arena)); return absl::OkStatus(); } absl::Status CreateValueFromMapValue(const google::protobuf::Message* msg, const FieldDescriptor* desc, const MapValueConstRef* 
value_ref, google::protobuf::Arena* arena, CelValue* result) { CEL_ASSIGN_OR_RETURN( *result, internal::CreateValueFromMapValue( msg, desc, value_ref, &CelProtoWrapper::InternalWrapMessage, arena)); return absl::OkStatus(); } absl::Status SetValueToSingleField(const CelValue& value, const FieldDescriptor* desc, Message* msg, Arena* arena) { return internal::SetValueToSingleField(value, desc, msg, arena); } absl::Status AddValueToRepeatedField(const CelValue& value, const FieldDescriptor* desc, Message* msg, Arena* arena) { return internal::AddValueToRepeatedField(value, desc, msg, arena); } } // namespace google::api::expr::runtime ================================================ FILE: eval/public/containers/field_access.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_FIELD_ACCESS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_FIELD_ACCESS_H_ #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" namespace google::api::expr::runtime { // Creates CelValue from singular message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. // options Option to enable treating unset wrapper type fields as null. // arena Arena object to allocate result on, if needed. // result pointer to CelValue to store the result in. absl::Status CreateValueFromSingleField(const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, google::protobuf::Arena* arena, CelValue* result); absl::Status CreateValueFromSingleField(const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, ProtoWrapperTypeOptions options, google::protobuf::Arena* arena, CelValue* result); // Creates CelValue from repeated message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. // arena Arena object to allocate result on, if needed. // index position in the repeated field. 
// result pointer to CelValue to store the result in. absl::Status CreateValueFromRepeatedField(const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, google::protobuf::Arena* arena, int index, CelValue* result); // Creates CelValue from map message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. // value_ref pointer to map value. // arena Arena object to allocate result on, if needed. // result pointer to CelValue to store the result in. absl::Status CreateValueFromMapValue(const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, const google::protobuf::MapValueConstRef* value_ref, google::protobuf::Arena* arena, CelValue* result); // Assigns content of CelValue to singular message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. // arena Arena to perform allocations, if necessary, when setting the field. absl::Status SetValueToSingleField(const CelValue& value, const google::protobuf::FieldDescriptor* desc, google::protobuf::Message* msg, google::protobuf::Arena* arena); // Adds content of CelValue to repeated message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. // arena Arena to perform allocations, if necessary, when adding the value. 
absl::Status AddValueToRepeatedField(const CelValue& value, const google::protobuf::FieldDescriptor* desc, google::protobuf::Message* msg, google::protobuf::Arena* arena); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_FIELD_ACCESS_H_ ================================================ FILE: eval/public/containers/field_access_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/public/containers/field_access.h"

#include <array>
#include <cstdint>
#include <limits>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "eval/public/cel_value.h"
#include "eval/public/structs/cel_proto_wrapper.h"
#include "eval/public/testing/matchers.h"
#include "eval/testutil/test_message.pb.h"
#include "internal/testing.h"
#include "internal/time.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

namespace google::api::expr::runtime {
namespace {

using ::absl_testing::StatusIs;
using ::cel::expr::conformance::proto3::TestAllTypes;
using ::cel::internal::MaxDuration;
using ::cel::internal::MaxTimestamp;
using ::google::protobuf::Arena;
using ::google::protobuf::FieldDescriptor;
using ::testing::HasSubstr;

TEST(FieldAccessTest, SetDuration) {
  Arena arena;
  TestAllTypes msg;
  const FieldDescriptor* field =
      TestAllTypes::descriptor()->FindFieldByName("single_duration");
  auto status = SetValueToSingleField(CelValue::CreateDuration(MaxDuration()),
                                      field, &msg, &arena);
  EXPECT_TRUE(status.ok());
}

TEST(FieldAccessTest, SetDurationBadDuration) {
  Arena arena;
  TestAllTypes msg;
  const FieldDescriptor* field =
      TestAllTypes::descriptor()->FindFieldByName("single_duration");
  auto status = SetValueToSingleField(
      CelValue::CreateDuration(MaxDuration() + absl::Seconds(1)), field, &msg,
      &arena);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
}

TEST(FieldAccessTest, SetDurationBadInputType) {
  Arena arena;
  TestAllTypes msg;
  const FieldDescriptor* field =
      TestAllTypes::descriptor()->FindFieldByName("single_duration");
  auto status =
      SetValueToSingleField(CelValue::CreateInt64(1), field, &msg, &arena);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
}

TEST(FieldAccessTest, SetTimestamp) {
  Arena arena;
  TestAllTypes msg;
  const FieldDescriptor* field =
      TestAllTypes::descriptor()->FindFieldByName("single_timestamp");
  auto status = SetValueToSingleField(CelValue::CreateTimestamp(MaxTimestamp()),
                                      field, &msg, &arena);
  EXPECT_TRUE(status.ok());
}

TEST(FieldAccessTest, SetTimestampBadTime) {
  Arena arena;
  TestAllTypes msg;
  const FieldDescriptor* field =
      TestAllTypes::descriptor()->FindFieldByName("single_timestamp");
  auto status = SetValueToSingleField(
      CelValue::CreateTimestamp(MaxTimestamp() + absl::Seconds(1)), field, &msg,
      &arena);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
}

TEST(FieldAccessTest, SetTimestampBadInputType) {
  Arena arena;
  TestAllTypes msg;
  const FieldDescriptor* field =
      TestAllTypes::descriptor()->FindFieldByName("single_timestamp");
  auto status =
      SetValueToSingleField(CelValue::CreateInt64(1), field, &msg, &arena);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
}

TEST(FieldAccessTest, SetInt32Overflow) {
  Arena arena;
  TestAllTypes msg;
  const FieldDescriptor* field =
      TestAllTypes::descriptor()->FindFieldByName("single_int32");
  // Widen to 64 bits before adding: `long` is only 32 bits on LLP64 targets
  // (e.g. Windows), where `INT32_MAX + 1L` would be signed overflow (UB).
  EXPECT_THAT(
      SetValueToSingleField(
          CelValue::CreateInt64(
              int64_t{std::numeric_limits<int32_t>::max()} + 1),
          field, &msg, &arena),
      StatusIs(absl::StatusCode::kInvalidArgument,
               HasSubstr("Could not assign")));
}

TEST(FieldAccessTest, SetUint32Overflow) {
  Arena arena;
  TestAllTypes msg;
  const FieldDescriptor* field =
      TestAllTypes::descriptor()->FindFieldByName("single_uint32");
  // Widen to 64 bits before adding; with a 32-bit `unsigned long`,
  // `UINT32_MAX + 1L` wraps to 0 and no longer overflows the field.
  EXPECT_THAT(
      SetValueToSingleField(
          CelValue::CreateUint64(
              uint64_t{std::numeric_limits<uint32_t>::max()} + 1),
          field, &msg, &arena),
      StatusIs(absl::StatusCode::kInvalidArgument,
               HasSubstr("Could not assign")));
}

TEST(FieldAccessTest, SetMessage) {
  Arena arena;
  TestAllTypes msg;
  const FieldDescriptor* field =
      TestAllTypes::descriptor()->FindFieldByName("standalone_message");
  TestAllTypes::NestedMessage* nested_msg =
      google::protobuf::Arena::Create<TestAllTypes::NestedMessage>(&arena);
  nested_msg->set_bb(1);
  auto status = SetValueToSingleField(
      CelProtoWrapper::CreateMessage(nested_msg, &arena), field, &msg, &arena);
  EXPECT_TRUE(status.ok());
}

TEST(FieldAccessTest, SetMessageWithNul) {
  Arena arena;
  TestAllTypes msg;
  const FieldDescriptor* field =
      TestAllTypes::descriptor()->FindFieldByName("standalone_message");
  auto status =
      SetValueToSingleField(CelValue::CreateNull(), field, &msg, &arena);
  EXPECT_TRUE(status.ok());
}

// Names of every wrapper-typed singular field in TestAllTypes exercised below.
constexpr std::array kWrapperFieldNames = {
    "single_bool_wrapper",   "single_int64_wrapper",  "single_int32_wrapper",
    "single_uint64_wrapper", "single_uint32_wrapper", "single_double_wrapper",
    "single_float_wrapper",  "single_string_wrapper", "single_bytes_wrapper"};

// Unset wrapper type fields are treated as null if accessed after option
// enabled.
TEST(CreateValueFromFieldTest, UnsetWrapperTypesNullIfEnabled) {
  CelValue result;
  TestAllTypes test_message;
  google::protobuf::Arena arena;

  for (const auto& field : kWrapperFieldNames) {
    ASSERT_OK(CreateValueFromSingleField(
        &test_message, TestAllTypes::GetDescriptor()->FindFieldByName(field),
        ProtoWrapperTypeOptions::kUnsetNull, &arena, &result))
        << field;
    ASSERT_TRUE(result.IsNull()) << field << ": " << result.DebugString();
  }
}

// Unset wrapper type fields are treated as proto default under old
// behavior.
TEST(CreateValueFromFieldTest, UnsetWrapperTypesDefaultValueIfDisabled) {
  CelValue result;
  TestAllTypes test_message;
  google::protobuf::Arena arena;

  for (const auto& field : kWrapperFieldNames) {
    ASSERT_OK(CreateValueFromSingleField(
        &test_message, TestAllTypes::GetDescriptor()->FindFieldByName(field),
        ProtoWrapperTypeOptions::kUnsetProtoDefault, &arena, &result))
        << field;
    ASSERT_FALSE(result.IsNull()) << field << ": " << result.DebugString();
  }
}

// If a wrapper type is set to default value, the corresponding CelValue is the
// proto default value.
TEST(CreateValueFromFieldTest, SetWrapperTypesDefaultValue) { CelValue result; TestAllTypes test_message; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( single_bool_wrapper {} single_int64_wrapper {} single_int32_wrapper {} single_uint64_wrapper {} single_uint32_wrapper {} single_double_wrapper {} single_float_wrapper {} single_string_wrapper {} single_bytes_wrapper {} )pb", &test_message)); ASSERT_OK(CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName("single_bool_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &arena, &result)); EXPECT_THAT(result, test::IsCelBool(false)); ASSERT_OK(CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName("single_int64_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &arena, &result)); EXPECT_THAT(result, test::IsCelInt64(0)); ASSERT_OK(CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName("single_int32_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &arena, &result)); EXPECT_THAT(result, test::IsCelInt64(0)); ASSERT_OK(CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName("single_uint64_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &arena, &result)); EXPECT_THAT(result, test::IsCelUint64(0)); ASSERT_OK(CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName("single_uint32_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &arena, &result)); EXPECT_THAT(result, test::IsCelUint64(0)); ASSERT_OK(CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName("single_double_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &arena, &result)); EXPECT_THAT(result, test::IsCelDouble(0.0f)); ASSERT_OK(CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName("single_float_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &arena, &result)); EXPECT_THAT(result, 
test::IsCelDouble(0.0f)); ASSERT_OK(CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName("single_string_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &arena, &result)); EXPECT_THAT(result, test::IsCelString("")); ASSERT_OK(CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName("single_bytes_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &arena, &result)); EXPECT_THAT(result, test::IsCelBytes("")); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/public/containers/field_backed_list_impl.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_FIELD_BACKED_LIST_IMPL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_FIELD_BACKED_LIST_IMPL_H_ #include "eval/public/cel_value.h" #include "eval/public/containers/internal_field_backed_list_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" namespace google { namespace api { namespace expr { namespace runtime { // CelList implementation that uses "repeated" message field // as backing storage. class FieldBackedListImpl : public internal::FieldBackedListImpl { public: // message contains the "repeated" field // descriptor FieldDescriptor for the field // arena is used for incidental allocations when unwrapping the field. 
FieldBackedListImpl(const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* descriptor, google::protobuf::Arena* arena) : internal::FieldBackedListImpl( message, descriptor, &CelProtoWrapper::InternalWrapMessage, arena) { } }; } // namespace runtime } // namespace expr } // namespace api } // namespace google #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_FIELD_BACKED_LIST_IMPL_H_ ================================================ FILE: eval/public/containers/field_backed_list_impl_test.cc ================================================ #include "eval/public/containers/field_backed_list_impl.h" #include #include #include "eval/testutil/test_message.pb.h" #include "internal/testing.h" #include "testutil/util.h" namespace google { namespace api { namespace expr { namespace runtime { namespace { using ::testing::Eq; using ::testing::DoubleEq; using testutil::EqualsProto; // Helper method. Creates simple pipeline containing Select step and runs it. std::unique_ptr CreateList(const TestMessage* message, const std::string& field, google::protobuf::Arena* arena) { const google::protobuf::FieldDescriptor* field_desc = message->GetDescriptor()->FindFieldByName(field); return std::make_unique(message, field_desc, arena); } TEST(FieldBackedListImplTest, BoolDatatypeTest) { TestMessage message; message.add_bool_list(true); message.add_bool_list(false); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "bool_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_EQ((*cel_list)[0].BoolOrDie(), true); EXPECT_EQ((*cel_list)[1].BoolOrDie(), false); } TEST(FieldBackedListImplTest, TestLength0) { TestMessage message; google::protobuf::Arena arena; auto cel_list = CreateList(&message, "int32_list", &arena); ASSERT_EQ(cel_list->size(), 0); } TEST(FieldBackedListImplTest, TestLength1) { TestMessage message; message.add_int32_list(1); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "int32_list", &arena); 
ASSERT_EQ(cel_list->size(), 1); EXPECT_EQ((*cel_list)[0].Int64OrDie(), 1); } TEST(FieldBackedListImplTest, TestLength100000) { TestMessage message; const int kLen = 100000; for (int i = 0; i < kLen; i++) { message.add_int32_list(i); } google::protobuf::Arena arena; auto cel_list = CreateList(&message, "int32_list", &arena); ASSERT_EQ(cel_list->size(), kLen); for (int i = 0; i < kLen; i++) { EXPECT_EQ((*cel_list)[i].Int64OrDie(), i); } } TEST(FieldBackedListImplTest, Int32DatatypeTest) { TestMessage message; message.add_int32_list(1); message.add_int32_list(2); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "int32_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_EQ((*cel_list)[0].Int64OrDie(), 1); EXPECT_EQ((*cel_list)[1].Int64OrDie(), 2); } TEST(FieldBackedListImplTest, Int64DatatypeTest) { TestMessage message; message.add_int64_list(1); message.add_int64_list(2); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "int64_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_EQ((*cel_list)[0].Int64OrDie(), 1); EXPECT_EQ((*cel_list)[1].Int64OrDie(), 2); } TEST(FieldBackedListImplTest, Uint32DatatypeTest) { TestMessage message; message.add_uint32_list(1); message.add_uint32_list(2); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "uint32_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_EQ((*cel_list)[0].Uint64OrDie(), 1); EXPECT_EQ((*cel_list)[1].Uint64OrDie(), 2); } TEST(FieldBackedListImplTest, Uint64DatatypeTest) { TestMessage message; message.add_uint64_list(1); message.add_uint64_list(2); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "uint64_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_EQ((*cel_list)[0].Uint64OrDie(), 1); EXPECT_EQ((*cel_list)[1].Uint64OrDie(), 2); } TEST(FieldBackedListImplTest, FloatDatatypeTest) { TestMessage message; message.add_float_list(1); message.add_float_list(2); google::protobuf::Arena arena; auto cel_list = CreateList(&message, 
"float_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_THAT((*cel_list)[0].DoubleOrDie(), DoubleEq(1)); EXPECT_THAT((*cel_list)[1].DoubleOrDie(), DoubleEq(2)); } TEST(FieldBackedListImplTest, DoubleDatatypeTest) { TestMessage message; message.add_double_list(1); message.add_double_list(2); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "double_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_THAT((*cel_list)[0].DoubleOrDie(), DoubleEq(1)); EXPECT_THAT((*cel_list)[1].DoubleOrDie(), DoubleEq(2)); } TEST(FieldBackedListImplTest, StringDatatypeTest) { TestMessage message; message.add_string_list("1"); message.add_string_list("2"); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "string_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_EQ((*cel_list)[0].StringOrDie().value(), "1"); EXPECT_EQ((*cel_list)[1].StringOrDie().value(), "2"); } TEST(FieldBackedListImplTest, BytesDatatypeTest) { TestMessage message; message.add_bytes_list("1"); message.add_bytes_list("2"); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "bytes_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_EQ((*cel_list)[0].BytesOrDie().value(), "1"); EXPECT_EQ((*cel_list)[1].BytesOrDie().value(), "2"); } TEST(FieldBackedListImplTest, MessageDatatypeTest) { TestMessage message; TestMessage* msg1 = message.add_message_list(); TestMessage* msg2 = message.add_message_list(); msg1->set_string_value("1"); msg2->set_string_value("2"); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "message_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_THAT(*msg1, EqualsProto(*((*cel_list)[0].MessageOrDie()))); EXPECT_THAT(*msg2, EqualsProto(*((*cel_list)[1].MessageOrDie()))); } TEST(FieldBackedListImplTest, EnumDatatypeTest) { TestMessage message; message.add_enum_list(TestMessage::TEST_ENUM_1); message.add_enum_list(TestMessage::TEST_ENUM_2); google::protobuf::Arena arena; auto cel_list = CreateList(&message, 
"enum_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_THAT((*cel_list)[0].Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); EXPECT_THAT((*cel_list)[1].Int64OrDie(), Eq(TestMessage::TEST_ENUM_2)); } } // namespace } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: eval/public/containers/field_backed_map_impl.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_FIELD_BACKED_MAP_IMPL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_FIELD_BACKED_MAP_IMPL_H_ #include "absl/status/statusor.h" #include "eval/public/cel_value.h" #include "eval/public/containers/internal_field_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace google::api::expr::runtime { // CelMap implementation that uses "map" message field // as backing storage. // // Trivial subclass of internal implementation to avoid API changes for clients // that use this directly. class FieldBackedMapImpl : public internal::FieldBackedMapImpl { public: // message contains the "map" field. Object stores the pointer // to the message, thus it is expected that message outlives the // object. // descriptor FieldDescriptor for the field // arena is used for incidental allocations from unpacking the field. 
FieldBackedMapImpl(const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* descriptor, google::protobuf::Arena* arena) : internal::FieldBackedMapImpl( message, descriptor, &CelProtoWrapper::InternalWrapMessage, arena) { } }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_FIELD_BACKED_MAP_IMPL_H_ ================================================ FILE: eval/public/containers/field_backed_map_impl_test.cc ================================================ #include "eval/public/containers/field_backed_map_impl.h" #include #include #include #include #include #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "eval/testutil/test_message.pb.h" #include "internal/testing.h" namespace google::api::expr::runtime { namespace { using ::absl_testing::StatusIs; using ::testing::Eq; using ::testing::HasSubstr; using ::testing::UnorderedPointwise; // Test factory for FieldBackedMaps from message and field name. std::unique_ptr CreateMap(const TestMessage* message, const std::string& field, google::protobuf::Arena* arena) { const google::protobuf::FieldDescriptor* field_desc = message->GetDescriptor()->FindFieldByName(field); return std::make_unique(message, field_desc, arena); } TEST(FieldBackedMapImplTest, BadKeyTypeTest) { TestMessage message; google::protobuf::Arena arena; constexpr std::array map_types = { "int64_int32_map", "uint64_int32_map", "string_int32_map", "bool_int32_map", "int32_int32_map", "uint32_uint32_map", }; for (auto map_type : map_types) { auto cel_map = CreateMap(&message, std::string(map_type), &arena); // Look up a boolean key. This should result in an error for both the // presence test and the value lookup. 
auto result = cel_map->Has(CelValue::CreateNull()); EXPECT_FALSE(result.ok()); EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kInvalidArgument)); EXPECT_FALSE(result.ok()); EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kInvalidArgument)); auto lookup = (*cel_map)[CelValue::CreateNull()]; EXPECT_TRUE(lookup.has_value()); EXPECT_TRUE(lookup->IsError()); EXPECT_THAT(lookup->ErrorOrDie()->code(), Eq(absl::StatusCode::kInvalidArgument)); } } TEST(FieldBackedMapImplTest, Int32KeyTest) { TestMessage message; auto field_map = message.mutable_int32_int32_map(); (*field_map)[0] = 1; (*field_map)[1] = 2; google::protobuf::Arena arena; auto cel_map = CreateMap(&message, "int32_int32_map", &arena); EXPECT_EQ((*cel_map)[CelValue::CreateInt64(0)]->Int64OrDie(), 1); EXPECT_EQ((*cel_map)[CelValue::CreateInt64(1)]->Int64OrDie(), 2); EXPECT_TRUE(cel_map->Has(CelValue::CreateInt64(1)).value_or(false)); // Look up nonexistent key EXPECT_FALSE((*cel_map)[CelValue::CreateInt64(3)].has_value()); EXPECT_FALSE(cel_map->Has(CelValue::CreateInt64(3)).value_or(true)); } TEST(FieldBackedMapImplTest, Int32KeyOutOfRangeTest) { TestMessage message; google::protobuf::Arena arena; auto cel_map = CreateMap(&message, "int32_int32_map", &arena); // Look up keys out of int32 range auto result = cel_map->Has( CelValue::CreateInt64(std::numeric_limits::max() + 1L)); EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kOutOfRange, HasSubstr("overflow"))); result = cel_map->Has( CelValue::CreateInt64(std::numeric_limits::lowest() - 1L)); EXPECT_FALSE(result.ok()); EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kOutOfRange)); } TEST(FieldBackedMapImplTest, Int64KeyTest) { TestMessage message; auto field_map = message.mutable_int64_int32_map(); (*field_map)[0] = 1; (*field_map)[1] = 2; google::protobuf::Arena arena; auto cel_map = CreateMap(&message, "int64_int32_map", &arena); EXPECT_EQ((*cel_map)[CelValue::CreateInt64(0)]->Int64OrDie(), 1); 
EXPECT_EQ((*cel_map)[CelValue::CreateInt64(1)]->Int64OrDie(), 2); EXPECT_TRUE(cel_map->Has(CelValue::CreateInt64(1)).value_or(false)); // Look up nonexistent key EXPECT_EQ((*cel_map)[CelValue::CreateInt64(3)].has_value(), false); } TEST(FieldBackedMapImplTest, BoolKeyTest) { TestMessage message; auto field_map = message.mutable_bool_int32_map(); (*field_map)[false] = 1; google::protobuf::Arena arena; auto cel_map = CreateMap(&message, "bool_int32_map", &arena); EXPECT_EQ((*cel_map)[CelValue::CreateBool(false)]->Int64OrDie(), 1); EXPECT_TRUE(cel_map->Has(CelValue::CreateBool(false)).value_or(false)); // Look up nonexistent key EXPECT_EQ((*cel_map)[CelValue::CreateBool(true)].has_value(), false); (*field_map)[true] = 2; EXPECT_EQ((*cel_map)[CelValue::CreateBool(true)]->Int64OrDie(), 2); } TEST(FieldBackedMapImplTest, Uint32KeyTest) { TestMessage message; auto field_map = message.mutable_uint32_uint32_map(); (*field_map)[0] = 1u; (*field_map)[1] = 2u; google::protobuf::Arena arena; auto cel_map = CreateMap(&message, "uint32_uint32_map", &arena); EXPECT_EQ((*cel_map)[CelValue::CreateUint64(0)]->Uint64OrDie(), 1UL); EXPECT_EQ((*cel_map)[CelValue::CreateUint64(1)]->Uint64OrDie(), 2UL); EXPECT_TRUE(cel_map->Has(CelValue::CreateUint64(1)).value_or(false)); // Look up nonexistent key EXPECT_EQ((*cel_map)[CelValue::CreateUint64(3)].has_value(), false); EXPECT_EQ(cel_map->Has(CelValue::CreateUint64(3)).value_or(true), false); } TEST(FieldBackedMapImplTest, Uint32KeyOutOfRangeTest) { TestMessage message; google::protobuf::Arena arena; auto cel_map = CreateMap(&message, "uint32_uint32_map", &arena); // Look up keys out of uint32 range auto result = cel_map->Has( CelValue::CreateUint64(std::numeric_limits::max() + 1UL)); EXPECT_FALSE(result.ok()); EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kOutOfRange)); } TEST(FieldBackedMapImplTest, Uint64KeyTest) { TestMessage message; auto field_map = message.mutable_uint64_int32_map(); (*field_map)[0] = 1; (*field_map)[1] = 2; 
google::protobuf::Arena arena; auto cel_map = CreateMap(&message, "uint64_int32_map", &arena); EXPECT_EQ((*cel_map)[CelValue::CreateUint64(0)]->Int64OrDie(), 1); EXPECT_EQ((*cel_map)[CelValue::CreateUint64(1)]->Int64OrDie(), 2); EXPECT_TRUE(cel_map->Has(CelValue::CreateUint64(1)).value_or(false)); // Look up nonexistent key EXPECT_EQ((*cel_map)[CelValue::CreateUint64(3)].has_value(), false); } TEST(FieldBackedMapImplTest, StringKeyTest) { TestMessage message; auto field_map = message.mutable_string_int32_map(); (*field_map)["test0"] = 1; (*field_map)["test1"] = 2; google::protobuf::Arena arena; auto cel_map = CreateMap(&message, "string_int32_map", &arena); std::string test0 = "test0"; std::string test1 = "test1"; std::string test_notfound = "test_notfound"; EXPECT_EQ((*cel_map)[CelValue::CreateString(&test0)]->Int64OrDie(), 1); EXPECT_EQ((*cel_map)[CelValue::CreateString(&test1)]->Int64OrDie(), 2); EXPECT_TRUE(cel_map->Has(CelValue::CreateString(&test1)).value_or(false)); // Look up nonexistent key EXPECT_EQ((*cel_map)[CelValue::CreateString(&test_notfound)].has_value(), false); } TEST(FieldBackedMapImplTest, EmptySizeTest) { TestMessage message; google::protobuf::Arena arena; auto cel_map = CreateMap(&message, "string_int32_map", &arena); EXPECT_EQ(cel_map->size(), 0); } TEST(FieldBackedMapImplTest, RepeatedAddTest) { TestMessage message; auto field_map = message.mutable_string_int32_map(); (*field_map)["test0"] = 1; (*field_map)["test1"] = 2; (*field_map)["test0"] = 3; google::protobuf::Arena arena; auto cel_map = CreateMap(&message, "string_int32_map", &arena); EXPECT_EQ(cel_map->size(), 2); } TEST(FieldBackedMapImplTest, KeyListTest) { TestMessage message; auto field_map = message.mutable_string_int32_map(); std::vector keys; std::vector keys1; for (int i = 0; i < 100; i++) { keys.push_back(absl::StrCat("test", i)); (*field_map)[keys.back()] = i; } google::protobuf::Arena arena; auto cel_map = CreateMap(&message, "string_int32_map", &arena); const CelList* 
key_list = cel_map->ListKeys().value(); EXPECT_EQ(key_list->size(), 100); for (int i = 0; i < key_list->size(); i++) { keys1.push_back(std::string((*key_list)[i].StringOrDie().value())); } EXPECT_THAT(keys, UnorderedPointwise(Eq(), keys1)); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/public/containers/internal_field_backed_list_impl.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/public/containers/internal_field_backed_list_impl.h" #include "eval/public/cel_value.h" #include "eval/public/structs/field_access_impl.h" namespace google::api::expr::runtime::internal { int FieldBackedListImpl::size() const { return reflection_->FieldSize(*message_, descriptor_); } CelValue FieldBackedListImpl::operator[](int index) const { auto result = CreateValueFromRepeatedField(message_, descriptor_, index, factory_, arena_); if (!result.ok()) { CreateErrorValue(arena_, result.status().ToString()); } return *result; } } // namespace google::api::expr::runtime::internal ================================================ FILE: eval/public/containers/internal_field_backed_list_impl.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_LIST_IMPL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_LIST_IMPL_H_ #include #include "eval/public/cel_value.h" #include "eval/public/structs/protobuf_value_factory.h" namespace google::api::expr::runtime::internal { // CelList implementation that uses "repeated" message field // as backing storage. // // The internal implementation allows for interface updates without breaking // clients that depend on this class for implementing custom CEL lists class FieldBackedListImpl : public CelList { public: // message contains the "repeated" field // descriptor FieldDescriptor for the field FieldBackedListImpl(const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* descriptor, ProtobufValueFactory factory, google::protobuf::Arena* arena) : message_(message), descriptor_(descriptor), reflection_(message_->GetReflection()), factory_(std::move(factory)), arena_(arena) {} // List size. int size() const override; // List element access operator. 
CelValue operator[](int index) const override; private: const google::protobuf::Message* message_; const google::protobuf::FieldDescriptor* descriptor_; const google::protobuf::Reflection* reflection_; ProtobufValueFactory factory_; google::protobuf::Arena* arena_; }; } // namespace google::api::expr::runtime::internal #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_LIST_IMPL_H_ ================================================ FILE: eval/public/containers/internal_field_backed_list_impl_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/public/containers/internal_field_backed_list_impl.h" #include #include #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/testutil/test_message.pb.h" #include "internal/testing.h" #include "testutil/util.h" namespace google::api::expr::runtime::internal { namespace { using ::google::api::expr::testutil::EqualsProto; using ::testing::DoubleEq; using ::testing::Eq; // Helper method. Creates simple pipeline containing Select step and runs it. 
std::unique_ptr CreateList(const TestMessage* message, const std::string& field, google::protobuf::Arena* arena) { const google::protobuf::FieldDescriptor* field_desc = message->GetDescriptor()->FindFieldByName(field); return std::make_unique( message, field_desc, &CelProtoWrapper::InternalWrapMessage, arena); } TEST(FieldBackedListImplTest, BoolDatatypeTest) { TestMessage message; message.add_bool_list(true); message.add_bool_list(false); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "bool_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_EQ((*cel_list)[0].BoolOrDie(), true); EXPECT_EQ((*cel_list)[1].BoolOrDie(), false); } TEST(FieldBackedListImplTest, TestLength0) { TestMessage message; google::protobuf::Arena arena; auto cel_list = CreateList(&message, "int32_list", &arena); ASSERT_EQ(cel_list->size(), 0); } TEST(FieldBackedListImplTest, TestLength1) { TestMessage message; message.add_int32_list(1); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "int32_list", &arena); ASSERT_EQ(cel_list->size(), 1); EXPECT_EQ((*cel_list)[0].Int64OrDie(), 1); } TEST(FieldBackedListImplTest, TestLength100000) { TestMessage message; const int kLen = 100000; for (int i = 0; i < kLen; i++) { message.add_int32_list(i); } google::protobuf::Arena arena; auto cel_list = CreateList(&message, "int32_list", &arena); ASSERT_EQ(cel_list->size(), kLen); for (int i = 0; i < kLen; i++) { EXPECT_EQ((*cel_list)[i].Int64OrDie(), i); } } TEST(FieldBackedListImplTest, Int32DatatypeTest) { TestMessage message; message.add_int32_list(1); message.add_int32_list(2); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "int32_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_EQ((*cel_list)[0].Int64OrDie(), 1); EXPECT_EQ((*cel_list)[1].Int64OrDie(), 2); } TEST(FieldBackedListImplTest, Int64DatatypeTest) { TestMessage message; message.add_int64_list(1); message.add_int64_list(2); google::protobuf::Arena arena; auto cel_list = 
CreateList(&message, "int64_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_EQ((*cel_list)[0].Int64OrDie(), 1); EXPECT_EQ((*cel_list)[1].Int64OrDie(), 2); } TEST(FieldBackedListImplTest, Uint32DatatypeTest) { TestMessage message; message.add_uint32_list(1); message.add_uint32_list(2); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "uint32_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_EQ((*cel_list)[0].Uint64OrDie(), 1); EXPECT_EQ((*cel_list)[1].Uint64OrDie(), 2); } TEST(FieldBackedListImplTest, Uint64DatatypeTest) { TestMessage message; message.add_uint64_list(1); message.add_uint64_list(2); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "uint64_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_EQ((*cel_list)[0].Uint64OrDie(), 1); EXPECT_EQ((*cel_list)[1].Uint64OrDie(), 2); } TEST(FieldBackedListImplTest, FloatDatatypeTest) { TestMessage message; message.add_float_list(1); message.add_float_list(2); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "float_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_THAT((*cel_list)[0].DoubleOrDie(), DoubleEq(1)); EXPECT_THAT((*cel_list)[1].DoubleOrDie(), DoubleEq(2)); } TEST(FieldBackedListImplTest, DoubleDatatypeTest) { TestMessage message; message.add_double_list(1); message.add_double_list(2); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "double_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_THAT((*cel_list)[0].DoubleOrDie(), DoubleEq(1)); EXPECT_THAT((*cel_list)[1].DoubleOrDie(), DoubleEq(2)); } TEST(FieldBackedListImplTest, StringDatatypeTest) { TestMessage message; message.add_string_list("1"); message.add_string_list("2"); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "string_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_EQ((*cel_list)[0].StringOrDie().value(), "1"); EXPECT_EQ((*cel_list)[1].StringOrDie().value(), "2"); } TEST(FieldBackedListImplTest, BytesDatatypeTest) { 
TestMessage message; message.add_bytes_list("1"); message.add_bytes_list("2"); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "bytes_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_EQ((*cel_list)[0].BytesOrDie().value(), "1"); EXPECT_EQ((*cel_list)[1].BytesOrDie().value(), "2"); } TEST(FieldBackedListImplTest, MessageDatatypeTest) { TestMessage message; TestMessage* msg1 = message.add_message_list(); TestMessage* msg2 = message.add_message_list(); msg1->set_string_value("1"); msg2->set_string_value("2"); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "message_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_THAT(*msg1, EqualsProto(*((*cel_list)[0].MessageOrDie()))); EXPECT_THAT(*msg2, EqualsProto(*((*cel_list)[1].MessageOrDie()))); } TEST(FieldBackedListImplTest, EnumDatatypeTest) { TestMessage message; message.add_enum_list(TestMessage::TEST_ENUM_1); message.add_enum_list(TestMessage::TEST_ENUM_2); google::protobuf::Arena arena; auto cel_list = CreateList(&message, "enum_list", &arena); ASSERT_EQ(cel_list->size(), 2); EXPECT_THAT((*cel_list)[0].Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); EXPECT_THAT((*cel_list)[1].Int64OrDie(), Eq(TestMessage::TEST_ENUM_2)); } } // namespace } // namespace google::api::expr::runtime::internal ================================================ FILE: eval/public/containers/internal_field_backed_map_impl.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "eval/public/containers/internal_field_backed_map_impl.h" #include #include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "eval/public/cel_value.h" #include "eval/public/structs/field_access_impl.h" #include "eval/public/structs/protobuf_value_factory.h" #include "extensions/protobuf/internal/map_reflection.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/map_field.h" #include "google/protobuf/message.h" namespace google::api::expr::runtime::internal { namespace { using google::protobuf::Descriptor; using google::protobuf::FieldDescriptor; using google::protobuf::MapValueConstRef; using google::protobuf::Message; // Map entries have two field tags // 1 - for key // 2 - for value constexpr int kKeyTag = 1; constexpr int kValueTag = 2; class KeyList : public CelList { public: // message contains the "repeated" field // descriptor FieldDescriptor for the field KeyList(const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* descriptor, const ProtobufValueFactory& factory, google::protobuf::Arena* arena) : message_(message), descriptor_(descriptor), reflection_(message_->GetReflection()), factory_(factory), arena_(arena) {} // List size. int size() const override { return reflection_->FieldSize(*message_, descriptor_); } // List element access operator. 
CelValue operator[](int index) const override { const Message* entry = &reflection_->GetRepeatedMessage(*message_, descriptor_, index); if (entry == nullptr) { return CelValue::CreateNull(); } const Descriptor* entry_descriptor = entry->GetDescriptor(); // Key Tag == 1 const FieldDescriptor* key_desc = entry_descriptor->FindFieldByNumber(kKeyTag); absl::StatusOr key_value = CreateValueFromSingleField( entry, key_desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, factory_, arena_); if (!key_value.ok()) { return CreateErrorValue(arena_, key_value.status()); } return *key_value; } private: const google::protobuf::Message* message_; const google::protobuf::FieldDescriptor* descriptor_; const google::protobuf::Reflection* reflection_; const ProtobufValueFactory& factory_; google::protobuf::Arena* arena_; }; bool MatchesMapKeyType(const FieldDescriptor* key_desc, const CelValue& key) { switch (key_desc->cpp_type()) { case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: return key.IsBool(); case google::protobuf::FieldDescriptor::CPPTYPE_INT32: // fall through case google::protobuf::FieldDescriptor::CPPTYPE_INT64: return key.IsInt64(); case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: // fall through case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: return key.IsUint64(); case google::protobuf::FieldDescriptor::CPPTYPE_STRING: return key.IsString(); default: return false; } } absl::Status InvalidMapKeyType(absl::string_view key_type) { return absl::InvalidArgumentError( absl::StrCat("Invalid map key type: '", key_type, "'")); } } // namespace FieldBackedMapImpl::FieldBackedMapImpl( const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* descriptor, ProtobufValueFactory factory, google::protobuf::Arena* arena) : message_(message), descriptor_(descriptor), key_desc_(descriptor_->message_type()->FindFieldByNumber(kKeyTag)), value_desc_(descriptor_->message_type()->FindFieldByNumber(kValueTag)), reflection_(message_->GetReflection()), 
factory_(std::move(factory)), arena_(arena), key_list_( std::make_unique(message, descriptor, factory_, arena)) {} int FieldBackedMapImpl::size() const { return reflection_->FieldSize(*message_, descriptor_); } absl::StatusOr FieldBackedMapImpl::ListKeys() const { return key_list_.get(); } absl::StatusOr FieldBackedMapImpl::Has(const CelValue& key) const { MapValueConstRef value_ref; return LookupMapValue(key, &value_ref); } absl::optional FieldBackedMapImpl::operator[](CelValue key) const { // Fast implementation which uses a friend method to do a hash-based key // lookup. MapValueConstRef value_ref; auto lookup_result = LookupMapValue(key, &value_ref); if (!lookup_result.ok()) { return CreateErrorValue(arena_, lookup_result.status()); } if (!*lookup_result) { return absl::nullopt; } // Get value descriptor treating it as a repeated field. // All values in protobuf map have the same type. // The map is not empty, because LookupMapValue returned true. absl::StatusOr result = CreateValueFromMapValue( message_, value_desc_, &value_ref, factory_, arena_); if (!result.ok()) { return CreateErrorValue(arena_, result.status()); } return *result; } absl::StatusOr FieldBackedMapImpl::LookupMapValue( const CelValue& key, MapValueConstRef* value_ref) const { if (!MatchesMapKeyType(key_desc_, key)) { return InvalidMapKeyType(key_desc_->cpp_type_name()); } std::string map_key_string; google::protobuf::MapKey proto_key; switch (key_desc_->cpp_type()) { case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { bool key_value; key.GetValue(&key_value); proto_key.SetBoolValue(key_value); } break; case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { int64_t key_value; key.GetValue(&key_value); if (key_value > std::numeric_limits::max() || key_value < std::numeric_limits::lowest()) { return absl::OutOfRangeError("integer overflow"); } proto_key.SetInt32Value(key_value); } break; case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { int64_t key_value; key.GetValue(&key_value); 
proto_key.SetInt64Value(key_value); } break; case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { CelValue::StringHolder key_value; key.GetValue(&key_value); map_key_string.assign(key_value.value().data(), key_value.value().size()); proto_key.SetStringValue(map_key_string); } break; case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { uint64_t key_value; key.GetValue(&key_value); if (key_value > std::numeric_limits::max()) { return absl::OutOfRangeError("unsigned integer overlow"); } proto_key.SetUInt32Value(key_value); } break; case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { uint64_t key_value; key.GetValue(&key_value); proto_key.SetUInt64Value(key_value); } break; default: return InvalidMapKeyType(key_desc_->cpp_type_name()); } // Look the value up return cel::extensions::protobuf_internal::LookupMapValue( *reflection_, *message_, *descriptor_, proto_key, value_ref); } absl::StatusOr FieldBackedMapImpl::LegacyHasMapValue( const CelValue& key) const { auto lookup_result = LegacyLookupMapValue(key); if (!lookup_result.has_value()) { return false; } auto result = *lookup_result; if (result.IsError()) { return *(result.ErrorOrDie()); } return true; } absl::optional FieldBackedMapImpl::LegacyLookupMapValue( const CelValue& key) const { // Ensure that the key matches the key type. 
if (!MatchesMapKeyType(key_desc_, key)) { return CreateErrorValue(arena_, InvalidMapKeyType(key_desc_->cpp_type_name())); } int map_size = size(); for (int i = 0; i < map_size; i++) { const Message* entry = &reflection_->GetRepeatedMessage(*message_, descriptor_, i); if (entry == nullptr) continue; // Key Tag == 1 absl::StatusOr key_value = CreateValueFromSingleField( entry, key_desc_, ProtoWrapperTypeOptions::kUnsetProtoDefault, factory_, arena_); if (!key_value.ok()) { return CreateErrorValue(arena_, key_value.status()); } bool match = false; switch (key_desc_->cpp_type()) { case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: match = key.BoolOrDie() == key_value->BoolOrDie(); break; case google::protobuf::FieldDescriptor::CPPTYPE_INT32: // fall through case google::protobuf::FieldDescriptor::CPPTYPE_INT64: match = key.Int64OrDie() == key_value->Int64OrDie(); break; case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: // fall through case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: match = key.Uint64OrDie() == key_value->Uint64OrDie(); break; case google::protobuf::FieldDescriptor::CPPTYPE_STRING: match = key.StringOrDie() == key_value->StringOrDie(); break; default: // this would normally indicate a bad key type, which should not be // possible based on the earlier test. break; } if (match) { absl::StatusOr value_cel_value = CreateValueFromSingleField( entry, value_desc_, ProtoWrapperTypeOptions::kUnsetProtoDefault, factory_, arena_); if (!value_cel_value.ok()) { return CreateErrorValue(arena_, value_cel_value.status()); } return *value_cel_value; } } return {}; } } // namespace google::api::expr::runtime::internal ================================================ FILE: eval/public/containers/internal_field_backed_map_impl.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_MAP_IMPL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_MAP_IMPL_H_ #include "absl/status/statusor.h" #include "eval/public/cel_value.h" #include "eval/public/structs/protobuf_value_factory.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace google::api::expr::runtime::internal { // CelMap implementation that uses "map" message field // as backing storage. class FieldBackedMapImpl : public CelMap { public: // message contains the "map" field. Object stores the pointer // to the message, thus it is expected that message outlives the // object. // descriptor FieldDescriptor for the field FieldBackedMapImpl(const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* descriptor, ProtobufValueFactory factory, google::protobuf::Arena* arena); // Map size. int size() const override; // Map element access operator. absl::optional operator[](CelValue key) const override; // Presence test function. absl::StatusOr Has(const CelValue& key) const override; absl::StatusOr ListKeys() const override; // Include base class definitions to avoid GCC warnings about hidden virtual // overloads. using CelMap::ListKeys; protected: // These methods are exposed as protected methods for testing purposes since // whether one or the other is used depends on build time flags, but each // should be tested accordingly. 
absl::StatusOr LookupMapValue( const CelValue& key, google::protobuf::MapValueConstRef* value_ref) const; absl::StatusOr LegacyHasMapValue(const CelValue& key) const; absl::optional LegacyLookupMapValue(const CelValue& key) const; private: const google::protobuf::Message* message_; const google::protobuf::FieldDescriptor* descriptor_; const google::protobuf::FieldDescriptor* key_desc_; const google::protobuf::FieldDescriptor* value_desc_; const google::protobuf::Reflection* reflection_; ProtobufValueFactory factory_; google::protobuf::Arena* arena_; std::unique_ptr key_list_; }; } // namespace google::api::expr::runtime::internal #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CONTAINERS_INTERNAL_FIELD_BACKED_MAP_IMPL_H_ ================================================ FILE: eval/public/containers/internal_field_backed_map_impl_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/public/containers/internal_field_backed_map_impl.h"

#include <array>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "eval/public/structs/cel_proto_wrapper.h"
#include "eval/testutil/test_message.pb.h"
#include "internal/testing.h"

namespace google::api::expr::runtime::internal {
namespace {

using ::absl_testing::StatusIs;
using ::testing::Eq;
using ::testing::HasSubstr;
using ::testing::UnorderedPointwise;

class FieldBackedMapTestImpl : public FieldBackedMapImpl {
 public:
  FieldBackedMapTestImpl(const google::protobuf::Message* message,
                         const google::protobuf::FieldDescriptor* descriptor,
                         google::protobuf::Arena* arena)
      : FieldBackedMapImpl(message, descriptor,
                           &CelProtoWrapper::InternalWrapMessage, arena) {}

  // For code coverage, expose fallback lookups used when not compiled with
  // support for optimized versions.
  using FieldBackedMapImpl::LegacyHasMapValue;
  using FieldBackedMapImpl::LegacyLookupMapValue;
};

// Helper method. Creates a FieldBackedMapTestImpl over the map field named
// `field` of `message`.
std::unique_ptr<FieldBackedMapTestImpl> CreateMap(
    const TestMessage* message, const std::string& field,
    google::protobuf::Arena* arena) {
  const google::protobuf::FieldDescriptor* field_desc =
      message->GetDescriptor()->FindFieldByName(field);

  return std::make_unique<FieldBackedMapTestImpl>(message, field_desc, arena);
}

TEST(FieldBackedMapImplTest, BadKeyTypeTest) {
  TestMessage message;
  google::protobuf::Arena arena;
  constexpr std::array map_types = {
      "int64_int32_map", "uint64_int32_map", "string_int32_map",
      "bool_int32_map",  "int32_int32_map",  "uint32_uint32_map",
  };

  for (auto map_type : map_types) {
    auto cel_map = CreateMap(&message, std::string(map_type), &arena);
    // Look up a null key. Null never matches a map key type, so this should
    // result in an error for both the presence test and the value lookup.
    auto result = cel_map->Has(CelValue::CreateNull());
    EXPECT_FALSE(result.ok());
    EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kInvalidArgument));

    result = cel_map->LegacyHasMapValue(CelValue::CreateNull());
    EXPECT_FALSE(result.ok());
    EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kInvalidArgument));

    auto lookup = (*cel_map)[CelValue::CreateNull()];
    EXPECT_TRUE(lookup.has_value());
    EXPECT_TRUE(lookup->IsError());
    EXPECT_THAT(lookup->ErrorOrDie()->code(),
                Eq(absl::StatusCode::kInvalidArgument));

    lookup = cel_map->LegacyLookupMapValue(CelValue::CreateNull());
    EXPECT_TRUE(lookup.has_value());
    EXPECT_TRUE(lookup->IsError());
    EXPECT_THAT(lookup->ErrorOrDie()->code(),
                Eq(absl::StatusCode::kInvalidArgument));
  }
}

TEST(FieldBackedMapImplTest, Int32KeyTest) {
  TestMessage message;
  auto field_map = message.mutable_int32_int32_map();
  (*field_map)[0] = 1;
  (*field_map)[1] = 2;

  google::protobuf::Arena arena;
  auto cel_map = CreateMap(&message, "int32_int32_map", &arena);

  EXPECT_EQ((*cel_map)[CelValue::CreateInt64(0)]->Int64OrDie(), 1);
  EXPECT_EQ((*cel_map)[CelValue::CreateInt64(1)]->Int64OrDie(), 2);
  EXPECT_TRUE(cel_map->Has(CelValue::CreateInt64(1)).value_or(false));
  EXPECT_TRUE(
      cel_map->LegacyHasMapValue(CelValue::CreateInt64(1)).value_or(false));

  // Look up nonexistent key
  EXPECT_FALSE((*cel_map)[CelValue::CreateInt64(3)].has_value());
  EXPECT_FALSE(cel_map->Has(CelValue::CreateInt64(3)).value_or(true));
  EXPECT_FALSE(
      cel_map->LegacyHasMapValue(CelValue::CreateInt64(3)).value_or(true));
}

TEST(FieldBackedMapImplTest, Int32KeyOutOfRangeTest) {
  TestMessage message;
  google::protobuf::Arena arena;
  auto cel_map = CreateMap(&message, "int32_int32_map", &arena);

  // Look up keys out of int32 range. The arithmetic is done in int64_t so it
  // cannot overflow on platforms where `long` is only 32 bits wide.
  auto result = cel_map->Has(CelValue::CreateInt64(
      int64_t{std::numeric_limits<int32_t>::max()} + 1));
  EXPECT_THAT(result.status(),
              StatusIs(absl::StatusCode::kOutOfRange, HasSubstr("overflow")));

  result = cel_map->Has(CelValue::CreateInt64(
      int64_t{std::numeric_limits<int32_t>::lowest()} - 1));
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kOutOfRange));
}

TEST(FieldBackedMapImplTest, Int64KeyTest) {
  TestMessage message;
  auto field_map = message.mutable_int64_int32_map();
  (*field_map)[0] = 1;
  (*field_map)[1] = 2;

  google::protobuf::Arena arena;
  auto cel_map = CreateMap(&message, "int64_int32_map", &arena);

  EXPECT_EQ((*cel_map)[CelValue::CreateInt64(0)]->Int64OrDie(), 1);
  EXPECT_EQ((*cel_map)[CelValue::CreateInt64(1)]->Int64OrDie(), 2);
  EXPECT_TRUE(cel_map->Has(CelValue::CreateInt64(1)).value_or(false));
  EXPECT_EQ(
      cel_map->LegacyLookupMapValue(CelValue::CreateInt64(1))->Int64OrDie(), 2);
  EXPECT_TRUE(
      cel_map->LegacyHasMapValue(CelValue::CreateInt64(1)).value_or(false));

  // Look up nonexistent key
  EXPECT_EQ((*cel_map)[CelValue::CreateInt64(3)].has_value(), false);
}

TEST(FieldBackedMapImplTest, BoolKeyTest) {
  TestMessage message;
  auto field_map = message.mutable_bool_int32_map();
  (*field_map)[false] = 1;

  google::protobuf::Arena arena;
  auto cel_map = CreateMap(&message, "bool_int32_map", &arena);

  EXPECT_EQ((*cel_map)[CelValue::CreateBool(false)]->Int64OrDie(), 1);
  EXPECT_TRUE(cel_map->Has(CelValue::CreateBool(false)).value_or(false));
  EXPECT_TRUE(
      cel_map->LegacyHasMapValue(CelValue::CreateBool(false)).value_or(false));
  // Look up nonexistent key
  EXPECT_EQ((*cel_map)[CelValue::CreateBool(true)].has_value(), false);

  (*field_map)[true] = 2;
  EXPECT_EQ((*cel_map)[CelValue::CreateBool(true)]->Int64OrDie(), 2);
}

TEST(FieldBackedMapImplTest, Uint32KeyTest) {
  TestMessage message;
  auto field_map = message.mutable_uint32_uint32_map();
  (*field_map)[0] = 1u;
  (*field_map)[1] = 2u;

  google::protobuf::Arena arena;
  auto cel_map = CreateMap(&message, "uint32_uint32_map", &arena);

  EXPECT_EQ((*cel_map)[CelValue::CreateUint64(0)]->Uint64OrDie(), 1UL);
  EXPECT_EQ((*cel_map)[CelValue::CreateUint64(1)]->Uint64OrDie(), 2UL);
  EXPECT_TRUE(cel_map->Has(CelValue::CreateUint64(1)).value_or(false));
  EXPECT_TRUE(
      cel_map->LegacyHasMapValue(CelValue::CreateUint64(1)).value_or(false));

  // Look up nonexistent key
  EXPECT_EQ((*cel_map)[CelValue::CreateUint64(3)].has_value(), false);
  EXPECT_EQ(cel_map->Has(CelValue::CreateUint64(3)).value_or(true), false);
}

TEST(FieldBackedMapImplTest, Uint32KeyOutOfRangeTest) {
  TestMessage message;
  google::protobuf::Arena arena;
  auto cel_map = CreateMap(&message, "uint32_uint32_map", &arena);

  // Look up keys out of uint32 range. The arithmetic is done in uint64_t: with
  // a 32-bit `unsigned long`, `max() + 1UL` would wrap around to 0.
  auto result = cel_map->Has(CelValue::CreateUint64(
      uint64_t{std::numeric_limits<uint32_t>::max()} + 1));
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().code(), Eq(absl::StatusCode::kOutOfRange));
}

TEST(FieldBackedMapImplTest, Uint64KeyTest) {
  TestMessage message;
  auto field_map = message.mutable_uint64_int32_map();
  (*field_map)[0] = 1;
  (*field_map)[1] = 2;

  google::protobuf::Arena arena;
  auto cel_map = CreateMap(&message, "uint64_int32_map", &arena);
  EXPECT_EQ((*cel_map)[CelValue::CreateUint64(0)]->Int64OrDie(), 1);
  EXPECT_EQ((*cel_map)[CelValue::CreateUint64(1)]->Int64OrDie(), 2);
  EXPECT_TRUE(cel_map->Has(CelValue::CreateUint64(1)).value_or(false));
  EXPECT_TRUE(
      cel_map->LegacyHasMapValue(CelValue::CreateUint64(1)).value_or(false));

  // Look up nonexistent key
  EXPECT_EQ((*cel_map)[CelValue::CreateUint64(3)].has_value(), false);
}

TEST(FieldBackedMapImplTest, StringKeyTest) {
  TestMessage message;
  auto field_map = message.mutable_string_int32_map();
  (*field_map)["test0"] = 1;
  (*field_map)["test1"] = 2;

  google::protobuf::Arena arena;
  auto cel_map = CreateMap(&message, "string_int32_map", &arena);

  std::string test0 = "test0";
  std::string test1 = "test1";
  std::string test_notfound = "test_notfound";

  EXPECT_EQ((*cel_map)[CelValue::CreateString(&test0)]->Int64OrDie(), 1);
  EXPECT_EQ((*cel_map)[CelValue::CreateString(&test1)]->Int64OrDie(), 2);
  EXPECT_TRUE(cel_map->Has(CelValue::CreateString(&test1)).value_or(false));
  EXPECT_TRUE(cel_map->LegacyHasMapValue(CelValue::CreateString(&test1))
                  .value_or(false));

  // Look up nonexistent key
  EXPECT_EQ((*cel_map)[CelValue::CreateString(&test_notfound)].has_value(),
            false);
}

TEST(FieldBackedMapImplTest, EmptySizeTest) {
  TestMessage message;
  google::protobuf::Arena arena;
  auto cel_map = CreateMap(&message, "string_int32_map", &arena);
  EXPECT_EQ(cel_map->size(), 0);
}

TEST(FieldBackedMapImplTest, RepeatedAddTest) {
  TestMessage message;
  auto field_map = message.mutable_string_int32_map();
  (*field_map)["test0"] = 1;
  (*field_map)["test1"] = 2;
  (*field_map)["test0"] = 3;

  google::protobuf::Arena arena;
  auto cel_map = CreateMap(&message, "string_int32_map", &arena);

  // Re-assigning "test0" must not add a second entry for the same key.
  EXPECT_EQ(cel_map->size(), 2);
}

TEST(FieldBackedMapImplTest, KeyListTest) {
  TestMessage message;
  auto field_map = message.mutable_string_int32_map();
  std::vector<std::string> keys;
  std::vector<std::string> keys1;
  for (int i = 0; i < 100; i++) {
    keys.push_back(absl::StrCat("test", i));
    (*field_map)[keys.back()] = i;
  }

  google::protobuf::Arena arena;
  auto cel_map = CreateMap(&message, "string_int32_map", &arena);
  const CelList* key_list = cel_map->ListKeys().value();

  EXPECT_EQ(key_list->size(), 100);
  for (int i = 0; i < key_list->size(); i++) {
    keys1.push_back(std::string((*key_list)[i].StringOrDie().value()));
  }

  EXPECT_THAT(keys, UnorderedPointwise(Eq(), keys1));
}

}  // namespace
}  // namespace google::api::expr::runtime::internal

================================================
FILE: eval/public/equality_function_registrar.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include "eval/public/equality_function_registrar.h" #include "absl/status/status.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "runtime/runtime_options.h" #include "runtime/standard/equality_functions.h" namespace google::api::expr::runtime { absl::Status RegisterEqualityFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options) { cel::RuntimeOptions runtime_options = ConvertToRuntimeOptions(options); return cel::RegisterEqualityFunctions(registry->InternalGetRegistry(), runtime_options); } } // namespace google::api::expr::runtime ================================================ FILE: eval/public/equality_function_registrar.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_EQUALITY_FUNCTION_REGISTRAR_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_EQUALITY_FUNCTION_REGISTRAR_H_ #include "absl/status/status.h" #include "eval/internal/cel_value_equal.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" namespace google::api::expr::runtime { // Implementation for general equality between CELValues. Exposed for // consistent behavior in set membership functions. // // Returns nullopt if the comparison is undefined between differently typed // values. 
using cel::interop_internal::CelValueEqualImpl; // Register built in comparison functions (==, !=). // // Most users should prefer to use RegisterBuiltinFunctions. // // This call is included in RegisterBuiltinFunctions -- calling both // RegisterBuiltinFunctions and RegisterComparisonFunctions directly on the same // registry will result in an error. absl::Status RegisterEqualityFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_EQUALITY_FUNCTION_REGISTRAR_H_ ================================================ FILE: eval/public/equality_function_registrar_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/public/equality_function_registrar.h"

#include <array>
#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "cel/expr/syntax.pb.h"
#include "google/protobuf/any.pb.h"
#include "google/rpc/context/attribute_context.pb.h"
#include "google/protobuf/descriptor.pb.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "eval/public/activation.h"
#include "eval/public/cel_builtins.h"
#include "eval/public/cel_expr_builder_factory.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"
#include "eval/public/cel_value.h"
#include "eval/public/containers/container_backed_list_impl.h"
#include "eval/public/containers/container_backed_map_impl.h"
#include "eval/public/message_wrapper.h"
#include "eval/public/structs/cel_proto_wrapper.h"
#include "eval/public/structs/trivial_legacy_type_info.h"
#include "eval/public/testing/matchers.h"
#include "eval/testutil/test_message.pb.h"  // IWYU pragma: keep
#include "internal/benchmark.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "parser/parser.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

namespace google::api::expr::runtime {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::expr::ParsedExpr;
using ::google::rpc::context::AttributeContext;
using ::testing::_;
using ::testing::Combine;
using ::testing::HasSubstr;
using ::testing::Optional;
using ::testing::Values;
using ::testing::ValuesIn;

// Matches a CelFunctionRegistry that has a non-receiver-style overload of
// `name` taking two arguments of type `argument_type`.
MATCHER_P2(DefinesHomogenousOverload, name, argument_type,
           absl::StrCat(name, " for ", CelValue::TypeName(argument_type))) {
  const CelFunctionRegistry& registry = arg;
  return !registry
              .FindOverloads(name, /*receiver_style=*/false,
                             {argument_type, argument_type})
              .empty();
}

// One case for EqualityFunctionTest: `expr` is evaluated with `lhs` and `rhs`
// bound as variables, and `result` holds either the expected boolean or the
// kind of error the evaluation is expected to produce.
struct EqualityTestCase {
  enum class ErrorKind { kMissingOverload, kMissingIdentifier };
  absl::string_view expr;
  absl::variant<bool, ErrorKind> result;
  CelValue lhs = CelValue::CreateNull();
  CelValue rhs = CelValue::CreateNull();
};

// True for the numeric CEL types: double, int64 and uint64.
bool IsNumeric(CelValue::Type type) {
  return type == CelValue::Type::kDouble || type == CelValue::Type::kInt64 ||
         type == CelValue::Type::kUint64;
}

// The list/map fixtures below are heap-allocated once and never freed, so
// CelValues referring to them stay valid for the rest of the process.
const CelList& CelListExample1() {
  static ContainerBackedListImpl* example =
      new ContainerBackedListImpl({CelValue::CreateInt64(1)});
  return *example;
}

const CelList& CelListExample2() {
  static ContainerBackedListImpl* example =
      new ContainerBackedListImpl({CelValue::CreateInt64(2)});
  return *example;
}

const CelMap& CelMapExample1() {
  static CelMap* example = []() {
    std::vector<std::pair<CelValue, CelValue>> values{
        {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}};
    // Implementation copies values into a hash map.
    auto map = CreateContainerBackedMap(absl::MakeSpan(values));
    return map->release();
  }();
  return *example;
}

const CelMap& CelMapExample2() {
  static CelMap* example = []() {
    std::vector<std::pair<CelValue, CelValue>> values{
        {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}};
    auto map = CreateContainerBackedMap(absl::MakeSpan(values));
    return map->release();
  }();
  return *example;
}

// One value of each CEL type; pairs of same-typed entries are equal.
const std::vector<CelValue>& ValueExamples1() {
  static std::vector<CelValue>* examples = []() {
    google::protobuf::Arena arena;
    auto result = std::make_unique<std::vector<CelValue>>();

    result->push_back(CelValue::CreateNull());
    result->push_back(CelValue::CreateBool(false));
    result->push_back(CelValue::CreateInt64(1));
    result->push_back(CelValue::CreateUint64(1));
    result->push_back(CelValue::CreateDouble(1.0));
    result->push_back(CelValue::CreateStringView("string"));
    result->push_back(CelValue::CreateBytesView("bytes"));
    // No arena allocs expected in this example.
    result->push_back(CelProtoWrapper::CreateMessage(
        std::make_unique<TestMessage>().release(), &arena));
    result->push_back(CelValue::CreateDuration(absl::Seconds(1)));
    result->push_back(CelValue::CreateTimestamp(absl::FromUnixSeconds(1)));
    result->push_back(CelValue::CreateList(&CelListExample1()));
    result->push_back(CelValue::CreateMap(&CelMapExample1()));
    result->push_back(CelValue::CreateCelTypeView("type"));

    return result.release();
  }();
  return *examples;
}

// Same types as ValueExamples1() in the same order, but every non-null value
// differs from its ValueExamples1() counterpart.
const std::vector<CelValue>& ValueExamples2() {
  static std::vector<CelValue>* examples = []() {
    google::protobuf::Arena arena;
    auto result = std::make_unique<std::vector<CelValue>>();
    auto message2 = std::make_unique<TestMessage>();
    message2->set_int64_value(2);

    result->push_back(CelValue::CreateNull());
    result->push_back(CelValue::CreateBool(true));
    result->push_back(CelValue::CreateInt64(2));
    result->push_back(CelValue::CreateUint64(2));
    result->push_back(CelValue::CreateDouble(2.0));
    result->push_back(CelValue::CreateStringView("string2"));
    result->push_back(CelValue::CreateBytesView("bytes2"));
    // No arena allocs expected in this example.
    result->push_back(
        CelProtoWrapper::CreateMessage(message2.release(), &arena));
    result->push_back(CelValue::CreateDuration(absl::Seconds(2)));
    result->push_back(CelValue::CreateTimestamp(absl::FromUnixSeconds(2)));
    result->push_back(CelValue::CreateList(&CelListExample2()));
    result->push_back(CelValue::CreateMap(&CelMapExample2()));
    result->push_back(CelValue::CreateCelTypeView("type2"));

    return result.release();
  }();
  return *examples;
}

// Parameterized over (lhs, rhs, whether same-typed/numeric pairs are equal).
class CelValueEqualImplTypesTest
    : public testing::TestWithParam<std::tuple<CelValue, CelValue, bool>> {
 public:
  CelValueEqualImplTypesTest() = default;

  const CelValue& lhs() { return std::get<0>(GetParam()); }

  const CelValue& rhs() { return std::get<1>(GetParam()); }

  bool should_be_equal() { return std::get<2>(GetParam()); }
};

std::string CelValueEqualTestName(
    const testing::TestParamInfo<std::tuple<CelValue, CelValue, bool>>&
        test_case) {
  return absl::StrCat(CelValue::TypeName(std::get<0>(test_case.param).type()),
                      CelValue::TypeName(std::get<1>(test_case.param).type()),
                      (std::get<2>(test_case.param)) ? "Equal" : "Inequal");
}

TEST_P(CelValueEqualImplTypesTest, Basic) {
  absl::optional<bool> result = CelValueEqualImpl(lhs(), rhs());

  if (lhs().IsNull() || rhs().IsNull()) {
    if (lhs().IsNull() && rhs().IsNull()) {
      EXPECT_THAT(result, Optional(true));
    } else {
      EXPECT_THAT(result, Optional(false));
    }
  } else if (lhs().type() == rhs().type() ||
             (IsNumeric(lhs().type()) && IsNumeric(rhs().type()))) {
    EXPECT_THAT(result, Optional(should_be_equal()));
  } else {
    EXPECT_THAT(result, Optional(false));
  }
}

INSTANTIATE_TEST_SUITE_P(EqualityBetweenTypes, CelValueEqualImplTypesTest,
                         Combine(ValuesIn(ValueExamples1()),
                                 ValuesIn(ValueExamples1()), Values(true)),
                         &CelValueEqualTestName);

INSTANTIATE_TEST_SUITE_P(InequalityBetweenTypes, CelValueEqualImplTypesTest,
                         Combine(ValuesIn(ValueExamples1()),
                                 ValuesIn(ValueExamples2()), Values(false)),
                         &CelValueEqualTestName);

struct NumericInequalityTestCase {
  std::string name;
  CelValue a;
  CelValue b;
};

// Numeric pairs (including boundary and NaN cases) that must compare unequal.
const std::vector<NumericInequalityTestCase>& NumericValuesNotEqualExample() {
  static std::vector<NumericInequalityTestCase>* examples = []() {
    auto result = std::make_unique<std::vector<NumericInequalityTestCase>>();
    result->push_back({"NegativeIntAndUint", CelValue::CreateInt64(-1),
                       CelValue::CreateUint64(2)});
    result->push_back(
        {"IntAndLargeUint", CelValue::CreateInt64(1),
         CelValue::CreateUint64(
             static_cast<uint64_t>(std::numeric_limits<int64_t>::max()) + 1)});
    result->push_back(
        {"IntAndLargeDouble", CelValue::CreateInt64(2),
         CelValue::CreateDouble(
             static_cast<double>(std::numeric_limits<int64_t>::max()) + 1025)});
    result->push_back(
        {"IntAndSmallDouble", CelValue::CreateInt64(2),
         CelValue::CreateDouble(
             static_cast<double>(std::numeric_limits<int64_t>::lowest()) -
             1025)});
    result->push_back(
        {"UintAndLargeDouble", CelValue::CreateUint64(2),
         CelValue::CreateDouble(
             static_cast<double>(std::numeric_limits<uint64_t>::max()) +
             2049)});
    result->push_back({"NegativeDoubleAndUint", CelValue::CreateDouble(-2.0),
                       CelValue::CreateUint64(123)});

    // NaN tests.
    result->push_back({"NanAndDouble", CelValue::CreateDouble(NAN),
                       CelValue::CreateDouble(1.0)});
    result->push_back({"NanAndNan", CelValue::CreateDouble(NAN),
                       CelValue::CreateDouble(NAN)});
    result->push_back({"DoubleAndNan", CelValue::CreateDouble(1.0),
                       CelValue::CreateDouble(NAN)});
    result->push_back(
        {"IntAndNan", CelValue::CreateInt64(1), CelValue::CreateDouble(NAN)});
    result->push_back(
        {"NanAndInt", CelValue::CreateDouble(NAN), CelValue::CreateInt64(1)});
    result->push_back(
        {"UintAndNan", CelValue::CreateUint64(1), CelValue::CreateDouble(NAN)});
    result->push_back(
        {"NanAndUint", CelValue::CreateDouble(NAN), CelValue::CreateUint64(1)});

    return result.release();
  }();
  return *examples;
}

using NumericInequalityTest = testing::TestWithParam<NumericInequalityTestCase>;

TEST_P(NumericInequalityTest, NumericValues) {
  const NumericInequalityTestCase& test_case = GetParam();
  absl::optional<bool> result = CelValueEqualImpl(test_case.a, test_case.b);
  // A single Optional(false) check fails cleanly on an empty result instead of
  // dereferencing a disengaged optional after a non-fatal EXPECT failure.
  EXPECT_THAT(result, Optional(false));
}

INSTANTIATE_TEST_SUITE_P(
    InequalityBetweenNumericTypesTest, NumericInequalityTest,
    ValuesIn(NumericValuesNotEqualExample()),
    [](const testing::TestParamInfo<NumericInequalityTest::ParamType>& info) {
      return info.param.name;
    });
// static_cast<double>(int64 max) rounds up to 2^63, and subtracting 1 cannot
// change a double of that magnitude; this pins that the lossy int/double pair
// compares equal.
TEST(CelValueEqualImplTest, LossyNumericEquality) {
  absl::optional<bool> result = CelValueEqualImpl(
      CelValue::CreateDouble(
          static_cast<double>(std::numeric_limits<int64_t>::max()) - 1),
      CelValue::CreateInt64(std::numeric_limits<int64_t>::max()));
  // Optional(true) also fails cleanly if `result` is empty, avoiding a
  // dereference of a disengaged optional.
  EXPECT_THAT(result, Optional(true));
}

TEST(CelValueEqualImplTest, ListMixedTypesInequal) {
  ContainerBackedListImpl lhs({CelValue::CreateInt64(1)});
  ContainerBackedListImpl rhs({CelValue::CreateStringView("abc")});

  EXPECT_THAT(
      CelValueEqualImpl(CelValue::CreateList(&lhs), CelValue::CreateList(&rhs)),
      Optional(false));
}

TEST(CelValueEqualImplTest, NestedList) {
  ContainerBackedListImpl inner_lhs({CelValue::CreateInt64(1)});
  ContainerBackedListImpl lhs({CelValue::CreateList(&inner_lhs)});
  ContainerBackedListImpl inner_rhs({CelValue::CreateNull()});
  ContainerBackedListImpl rhs({CelValue::CreateList(&inner_rhs)});

  EXPECT_THAT(
      CelValueEqualImpl(CelValue::CreateList(&lhs), CelValue::CreateList(&rhs)),
      Optional(false));
}

TEST(CelValueEqualImplTest, MapMixedValueTypesInequal) {
  std::vector<std::pair<CelValue, CelValue>> lhs_data{
      {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}};
  std::vector<std::pair<CelValue, CelValue>> rhs_data{
      {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}};
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelMap> lhs,
                       CreateContainerBackedMap(absl::MakeSpan(lhs_data)));
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelMap> rhs,
                       CreateContainerBackedMap(absl::MakeSpan(rhs_data)));

  EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()),
                                CelValue::CreateMap(rhs.get())),
              Optional(false));
}

TEST(CelValueEqualImplTest, MapMixedKeyTypesEqual) {
  std::vector<std::pair<CelValue, CelValue>> lhs_data{
      {CelValue::CreateUint64(1), CelValue::CreateStringView("abc")}};
  std::vector<std::pair<CelValue, CelValue>> rhs_data{
      {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}};
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelMap> lhs,
                       CreateContainerBackedMap(absl::MakeSpan(lhs_data)));
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelMap> rhs,
                       CreateContainerBackedMap(absl::MakeSpan(rhs_data)));

  EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()),
                                CelValue::CreateMap(rhs.get())),
              Optional(true));
}

TEST(CelValueEqualImplTest, MapMixedKeyTypesInequal) {
  std::vector<std::pair<CelValue, CelValue>> lhs_data{
      {CelValue::CreateInt64(1), CelValue::CreateStringView("abc")}};
  std::vector<std::pair<CelValue, CelValue>> rhs_data{
      {CelValue::CreateInt64(2), CelValue::CreateInt64(2)}};
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelMap> lhs,
                       CreateContainerBackedMap(absl::MakeSpan(lhs_data)));
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelMap> rhs,
                       CreateContainerBackedMap(absl::MakeSpan(rhs_data)));

  EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()),
                                CelValue::CreateMap(rhs.get())),
              Optional(false));
}

TEST(CelValueEqualImplTest, NestedMaps) {
  std::vector<std::pair<CelValue, CelValue>> inner_lhs_data{
      {CelValue::CreateInt64(2), CelValue::CreateStringView("abc")}};
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<CelMap> inner_lhs,
      CreateContainerBackedMap(absl::MakeSpan(inner_lhs_data)));
  std::vector<std::pair<CelValue, CelValue>> lhs_data{
      {CelValue::CreateInt64(1), CelValue::CreateMap(inner_lhs.get())}};

  std::vector<std::pair<CelValue, CelValue>> inner_rhs_data{
      {CelValue::CreateInt64(2), CelValue::CreateNull()}};
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<CelMap> inner_rhs,
      CreateContainerBackedMap(absl::MakeSpan(inner_rhs_data)));
  std::vector<std::pair<CelValue, CelValue>> rhs_data{
      {CelValue::CreateInt64(1), CelValue::CreateMap(inner_rhs.get())}};

  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelMap> lhs,
                       CreateContainerBackedMap(absl::MakeSpan(lhs_data)));
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelMap> rhs,
                       CreateContainerBackedMap(absl::MakeSpan(rhs_data)));

  EXPECT_THAT(CelValueEqualImpl(CelValue::CreateMap(lhs.get()),
                                CelValue::CreateMap(rhs.get())),
              Optional(false));
}

TEST(CelValueEqualImplTest, ProtoEqualityDifferingTypenameInequal) {
  // If message wrappers report a different typename, treat as inequal without
  // calling into the provided equal implementation.
  google::protobuf::Arena arena;
  TestMessage example;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( int32_value: 1 uint32_value: 2 string_value: "test" )",
                                                  &example));

  CelValue lhs = CelProtoWrapper::CreateMessage(&example, &arena);
  CelValue rhs = CelValue::CreateMessageWrapper(
      MessageWrapper(&example, TrivialTypeInfo::GetInstance()));

  EXPECT_THAT(CelValueEqualImpl(lhs, rhs), Optional(false));
}

TEST(CelValueEqualImplTest, ProtoEqualityNoAccessorInequal) {
  // If message wrappers report no access apis, then treat as inequal.
  TestMessage example;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( int32_value: 1 uint32_value: 2 string_value: "test" )",
                                                  &example));

  CelValue lhs = CelValue::CreateMessageWrapper(
      MessageWrapper(&example, TrivialTypeInfo::GetInstance()));
  CelValue rhs = CelValue::CreateMessageWrapper(
      MessageWrapper(&example, TrivialTypeInfo::GetInstance()));

  EXPECT_THAT(CelValueEqualImpl(lhs, rhs), Optional(false));
}

TEST(CelValueEqualImplTest, ProtoEqualityAny) {
  google::protobuf::Arena arena;
  TestMessage packed_value;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( int32_value: 1 uint32_value: 2 string_value: "test" )",
                                                  &packed_value));

  TestMessage lhs;
  lhs.mutable_any_value()->PackFrom(packed_value);

  TestMessage rhs;
  rhs.mutable_any_value()->PackFrom(packed_value);

  EXPECT_THAT(CelValueEqualImpl(CelProtoWrapper::CreateMessage(&lhs, &arena),
                                CelProtoWrapper::CreateMessage(&rhs, &arena)),
              Optional(true));

  // Equality falls back to bytewise comparison if type is missing.
  lhs.mutable_any_value()->clear_type_url();
  rhs.mutable_any_value()->clear_type_url();
  EXPECT_THAT(CelValueEqualImpl(CelProtoWrapper::CreateMessage(&lhs, &arena),
                                CelProtoWrapper::CreateMessage(&rhs, &arena)),
              Optional(true));
}

// Add transitive dependencies in appropriate order for the dynamic descriptor
// pool.
// Return false if the dependencies could not be added to the pool.
bool AddDepsToPool(const google::protobuf::FileDescriptor* descriptor, google::protobuf::DescriptorPool& pool) { for (int i = 0; i < descriptor->dependency_count(); i++) { if (!AddDepsToPool(descriptor->dependency(i), pool)) { return false; } } google::protobuf::FileDescriptorProto descriptor_proto; descriptor->CopyTo(&descriptor_proto); return pool.BuildFile(descriptor_proto) != nullptr; } // Equivalent descriptors managed by separate descriptor pools are not equal, so // the underlying messages are not considered equal. TEST(CelValueEqualImplTest, DynamicDescriptorAndGeneratedInequal) { // Simulate a dynamically loaded descriptor that happens to match the // compiled version. google::protobuf::DescriptorPool pool; google::protobuf::DynamicMessageFactory factory; google::protobuf::Arena arena; factory.SetDelegateToGeneratedFactory(false); ASSERT_TRUE(AddDepsToPool(TestMessage::descriptor()->file(), pool)); TestMessage example_message; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(R"pb( int64_value: 12345 bool_list: false bool_list: true message_value { float_value: 1.0 } )pb", &example_message)); // Messages from a loaded descriptor and generated versions can't be compared // via MessageDifferencer, so return false. 
std::unique_ptr example_dynamic_message( factory .GetPrototype(pool.FindMessageTypeByName( TestMessage::descriptor()->full_name())) ->New()); ASSERT_TRUE(example_dynamic_message->ParseFromString( example_message.SerializeAsString())); EXPECT_THAT(CelValueEqualImpl( CelProtoWrapper::CreateMessage(&example_message, &arena), CelProtoWrapper::CreateMessage(example_dynamic_message.get(), &arena)), Optional(false)); } TEST(CelValueEqualImplTest, DynamicMessageAndMessageEqual) { google::protobuf::DynamicMessageFactory factory; google::protobuf::Arena arena; factory.SetDelegateToGeneratedFactory(false); TestMessage example_message; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(R"pb( int64_value: 12345 bool_list: false bool_list: true message_value { float_value: 1.0 } )pb", &example_message)); // Dynamic message and generated Message subclass with the same generated // descriptor are comparable. std::unique_ptr example_dynamic_message( factory.GetPrototype(TestMessage::descriptor())->New()); ASSERT_TRUE(example_dynamic_message->ParseFromString( example_message.SerializeAsString())); EXPECT_THAT(CelValueEqualImpl( CelProtoWrapper::CreateMessage(&example_message, &arena), CelProtoWrapper::CreateMessage(example_dynamic_message.get(), &arena)), Optional(true)); } class EqualityFunctionTest : public testing::TestWithParam> { public: EqualityFunctionTest() { options_.enable_heterogeneous_equality = std::get<1>(GetParam()); options_.enable_empty_wrapper_null_unboxing = true; builder_ = CreateCelExpressionBuilder(options_); } CelFunctionRegistry& registry() { return *builder_->GetRegistry(); } absl::StatusOr Evaluate(absl::string_view expr, const CelValue& lhs, const CelValue& rhs) { CEL_ASSIGN_OR_RETURN(ParsedExpr parsed_expr, parser::Parse(expr)); Activation activation; activation.InsertValue("lhs", lhs); activation.InsertValue("rhs", rhs); CEL_ASSIGN_OR_RETURN(auto expression, builder_->CreateExpression( &parsed_expr.expr(), &parsed_expr.source_info())); return 
expression->Evaluate(activation, &arena_); } protected: std::unique_ptr builder_; InterpreterOptions options_; google::protobuf::Arena arena_; }; constexpr std::array kEqualableTypes = { CelValue::Type::kInt64, CelValue::Type::kUint64, CelValue::Type::kString, CelValue::Type::kDouble, CelValue::Type::kBytes, CelValue::Type::kDuration, CelValue::Type::kMap, CelValue::Type::kList, CelValue::Type::kBool, CelValue::Type::kTimestamp}; TEST(RegisterEqualityFunctionsTest, EqualDefined) { InterpreterOptions options; options.enable_fast_builtins = false; CelFunctionRegistry registry; ASSERT_THAT(RegisterEqualityFunctions(®istry, options), IsOk()); for (CelValue::Type type : kEqualableTypes) { EXPECT_THAT(registry, DefinesHomogenousOverload(builtin::kEqual, type)); } } TEST(RegisterEqualityFunctionsTest, InequalDefined) { InterpreterOptions options; options.enable_fast_builtins = false; CelFunctionRegistry registry; ASSERT_THAT(RegisterEqualityFunctions(®istry, options), IsOk()); for (CelValue::Type type : kEqualableTypes) { EXPECT_THAT(registry, DefinesHomogenousOverload(builtin::kInequal, type)); } } TEST_P(EqualityFunctionTest, SmokeTest) { EqualityTestCase test_case = std::get<0>(GetParam()); google::protobuf::LinkMessageReflection(); ASSERT_THAT(RegisterEqualityFunctions(®istry(), options_), IsOk()); ASSERT_OK_AND_ASSIGN(auto result, Evaluate(test_case.expr, test_case.lhs, test_case.rhs)); if (absl::holds_alternative(test_case.result)) { EXPECT_THAT(result, test::IsCelBool(absl::get(test_case.result))); } else { switch (absl::get(test_case.result)) { case EqualityTestCase::ErrorKind::kMissingOverload: EXPECT_THAT(result, test::IsCelError( StatusIs(absl::StatusCode::kUnknown, HasSubstr("No matching overloads")))) << test_case.expr; break; case EqualityTestCase::ErrorKind::kMissingIdentifier: EXPECT_THAT(result, test::IsCelError( StatusIs(absl::StatusCode::kUnknown, HasSubstr("found in Activation")))); break; default: EXPECT_THAT(result, test::IsCelError(_)); break; } } } 
INSTANTIATE_TEST_SUITE_P( Equality, EqualityFunctionTest, Combine(testing::ValuesIn( {{"null == null", true}, {"true == false", false}, {"1 == 1", true}, {"-2 == -1", false}, {"1.1 == 1.2", false}, {"'a' == 'a'", true}, {"lhs == rhs", false, CelValue::CreateBytesView("a"), CelValue::CreateBytesView("b")}, {"lhs == rhs", false, CelValue::CreateDuration(absl::Seconds(1)), CelValue::CreateDuration(absl::Seconds(2))}, {"lhs == rhs", true, CelValue::CreateTimestamp(absl::FromUnixSeconds(20)), CelValue::CreateTimestamp(absl::FromUnixSeconds(20))}, // This should fail before getting to the equal operator. {"no_such_identifier == 1", EqualityTestCase::ErrorKind::kMissingIdentifier}, {"{1: no_such_identifier} == {1: 1}", EqualityTestCase::ErrorKind::kMissingIdentifier}}), // heterogeneous equality enabled testing::Bool())); INSTANTIATE_TEST_SUITE_P( Inequality, EqualityFunctionTest, Combine(testing::ValuesIn( {{"null != null", false}, {"true != false", true}, {"1 != 1", false}, {"-2 != -1", true}, {"1.1 != 1.2", true}, {"'a' != 'a'", false}, {"lhs != rhs", true, CelValue::CreateBytesView("a"), CelValue::CreateBytesView("b")}, {"lhs != rhs", true, CelValue::CreateDuration(absl::Seconds(1)), CelValue::CreateDuration(absl::Seconds(2))}, {"lhs != rhs", true, CelValue::CreateTimestamp(absl::FromUnixSeconds(20)), CelValue::CreateTimestamp(absl::FromUnixSeconds(30))}, // This should fail before getting to the equal operator. 
{"no_such_identifier != 1", EqualityTestCase::ErrorKind::kMissingIdentifier}, {"{1: no_such_identifier} != {1: 1}", EqualityTestCase::ErrorKind::kMissingIdentifier}}), // heterogeneous equality enabled testing::Bool())); INSTANTIATE_TEST_SUITE_P(HeterogeneousNumericContainers, EqualityFunctionTest, Combine(testing::ValuesIn({ {"{1: 2} == {1u: 2}", true}, {"{1: 2} == {2u: 2}", false}, {"{1: 2} == {true: 2}", false}, {"{1: 2} != {1u: 2}", false}, {"{1: 2} != {2u: 2}", true}, {"{1: 2} != {true: 2}", true}, {"[1u, 2u, 3.0] != [1, 2.0, 3]", false}, {"[1u, 2u, 3.0] == [1, 2.0, 3]", true}, {"[1u, 2u, 3.0] != [1, 2.1, 3]", true}, {"[1u, 2u, 3.0] == [1, 2.1, 3]", false}, }), // heterogeneous equality enabled testing::Values(true))); INSTANTIATE_TEST_SUITE_P( HomogenousNumericContainers, EqualityFunctionTest, Combine(testing::ValuesIn({ {"{1: 2} == {1u: 2}", false}, {"{1: 2} == {2u: 2}", false}, {"{1: 2} == {true: 2}", false}, {"{1: 2} != {1u: 2}", true}, {"{1: 2} != {2u: 2}", true}, {"{1: 2} != {true: 2}", true}, {"[1u, 2u, 3.0] != [1, 2.0, 3]", EqualityTestCase::ErrorKind::kMissingOverload}, {"[1u, 2u, 3.0] == [1, 2.0, 3]", EqualityTestCase::ErrorKind::kMissingOverload}, {"[1u, 2u, 3.0] != [1, 2.1, 3]", EqualityTestCase::ErrorKind::kMissingOverload}, {"[1u, 2u, 3.0] == [1, 2.1, 3]", EqualityTestCase::ErrorKind::kMissingOverload}, }), // heterogeneous equality enabled testing::Values(false))); INSTANTIATE_TEST_SUITE_P( NullInequalityLegacy, EqualityFunctionTest, Combine(testing::ValuesIn( {{"null != null", false}, {"true != null", EqualityTestCase::ErrorKind::kMissingOverload}, {"1 != null", EqualityTestCase::ErrorKind::kMissingOverload}, {"-2 != null", EqualityTestCase::ErrorKind::kMissingOverload}, {"1.1 != null", EqualityTestCase::ErrorKind::kMissingOverload}, {"'a' != null", EqualityTestCase::ErrorKind::kMissingOverload}, {"lhs != null", EqualityTestCase::ErrorKind::kMissingOverload, CelValue::CreateBytesView("a")}, {"lhs != null", 
EqualityTestCase::ErrorKind::kMissingOverload, CelValue::CreateDuration(absl::Seconds(1))}, {"lhs != null", EqualityTestCase::ErrorKind::kMissingOverload, CelValue::CreateTimestamp(absl::FromUnixSeconds(20))}}), // heterogeneous equality enabled testing::Values(false))); INSTANTIATE_TEST_SUITE_P( NullEqualityLegacy, EqualityFunctionTest, Combine(testing::ValuesIn( {{"null == null", true}, {"true == null", EqualityTestCase::ErrorKind::kMissingOverload}, {"1 == null", EqualityTestCase::ErrorKind::kMissingOverload}, {"-2 == null", EqualityTestCase::ErrorKind::kMissingOverload}, {"1.1 == null", EqualityTestCase::ErrorKind::kMissingOverload}, {"'a' == null", EqualityTestCase::ErrorKind::kMissingOverload}, {"lhs == null", EqualityTestCase::ErrorKind::kMissingOverload, CelValue::CreateBytesView("a")}, {"lhs == null", EqualityTestCase::ErrorKind::kMissingOverload, CelValue::CreateDuration(absl::Seconds(1))}, {"lhs == null", EqualityTestCase::ErrorKind::kMissingOverload, CelValue::CreateTimestamp(absl::FromUnixSeconds(20))}}), // heterogeneous equality enabled testing::Values(false))); INSTANTIATE_TEST_SUITE_P( NullInequality, EqualityFunctionTest, Combine(testing::ValuesIn( {{"null != null", false}, {"true != null", true}, {"null != false", true}, {"1 != null", true}, {"null != 1", true}, {"-2 != null", true}, {"null != -2", true}, {"1.1 != null", true}, {"null != 1.1", true}, {"'a' != null", true}, {"lhs != null", true, CelValue::CreateBytesView("a")}, {"lhs != null", true, CelValue::CreateDuration(absl::Seconds(1))}, {"google.api.expr.runtime.TestMessage{} != null", true}, {"google.api.expr.runtime.TestMessage{}.string_wrapper_value" " != null", false}, {"google.api.expr.runtime.TestMessage{string_wrapper_value: " "google.protobuf.StringValue{}}.string_wrapper_value != null", true}, {"{} != null", true}, {"[] != null", true}}), // heterogeneous equality enabled testing::Values(true))); INSTANTIATE_TEST_SUITE_P( NullEquality, EqualityFunctionTest, 
Combine(testing::ValuesIn({ {"null == null", true}, {"true == null", false}, {"null == false", false}, {"1 == null", false}, {"null == 1", false}, {"-2 == null", false}, {"null == -2", false}, {"1.1 == null", false}, {"null == 1.1", false}, {"'a' == null", false}, {"lhs == null", false, CelValue::CreateBytesView("a")}, {"lhs == null", false, CelValue::CreateDuration(absl::Seconds(1))}, {"google.api.expr.runtime.TestMessage{} == null", false}, {"google.api.expr.runtime.TestMessage{}.string_wrapper_value" " == null", true}, {"google.api.expr.runtime.TestMessage{string_wrapper_value: " "google.protobuf.StringValue{}}.string_wrapper_value == null", false}, {"{} == null", false}, {"[] == null", false}, }), // heterogeneous equality enabled testing::Values(true))); INSTANTIATE_TEST_SUITE_P( ProtoEquality, EqualityFunctionTest, Combine(testing::ValuesIn({ {"google.api.expr.runtime.TestMessage{} == null", false}, {"google.api.expr.runtime.TestMessage{string_wrapper_value: " "google.protobuf.StringValue{}}.string_wrapper_value == ''", true}, {"google.api.expr.runtime.TestMessage{" "int64_wrapper_value: " "google.protobuf.Int64Value{value: 1}," "double_value: 1.1} == " "google.api.expr.runtime.TestMessage{" "int64_wrapper_value: " "google.protobuf.Int64Value{value: 1}," "double_value: 1.1}", true}, // ProtoDifferencer::Equals distinguishes set fields vs // defaulted {"google.api.expr.runtime.TestMessage{" "string_wrapper_value: google.protobuf.StringValue{}} == " "google.api.expr.runtime.TestMessage{}", false}, // Differently typed messages inequal. 
{"google.api.expr.runtime.TestMessage{} == " "google.rpc.context.AttributeContext{}", false}, }), // heterogeneous equality enabled testing::Values(true))); void RunBenchmark(absl::string_view expr, benchmark::State& benchmark) { InterpreterOptions opts; auto builder = CreateCelExpressionBuilder(opts); ASSERT_THAT(RegisterEqualityFunctions(builder->GetRegistry(), opts), IsOk()); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(expr)); google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(auto plan, builder->CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); for (auto _ : benchmark) { ASSERT_OK_AND_ASSIGN(auto result, plan->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); } } void RunIdentBenchmark(const CelValue& lhs, const CelValue& rhs, benchmark::State& benchmark) { InterpreterOptions opts; auto builder = CreateCelExpressionBuilder(opts); ASSERT_THAT(RegisterEqualityFunctions(builder->GetRegistry(), opts), IsOk()); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("lhs == rhs")); google::protobuf::Arena arena; Activation activation; activation.InsertValue("lhs", lhs); activation.InsertValue("rhs", rhs); ASSERT_OK_AND_ASSIGN(auto plan, builder->CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); for (auto _ : benchmark) { ASSERT_OK_AND_ASSIGN(auto result, plan->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); } } void BM_EqualsInt(benchmark::State& s) { RunBenchmark("42 == 43", s); } BENCHMARK(BM_EqualsInt); void BM_EqualsString(benchmark::State& s) { RunBenchmark("'1234' == '1235'", s); } BENCHMARK(BM_EqualsString); void BM_EqualsCreatedList(benchmark::State& s) { RunBenchmark("[1, 2, 3, 4, 5] == [1, 2, 3, 4, 6]", s); } BENCHMARK(BM_EqualsCreatedList); void BM_EqualsBoundLegacyList(benchmark::State& s) { ContainerBackedListImpl lhs( {CelValue::CreateInt64(1), CelValue::CreateInt64(2), CelValue::CreateInt64(3), CelValue::CreateInt64(4), CelValue::CreateInt64(5)}); 
ContainerBackedListImpl rhs( {CelValue::CreateInt64(1), CelValue::CreateInt64(2), CelValue::CreateInt64(3), CelValue::CreateInt64(4), CelValue::CreateInt64(6)}); RunIdentBenchmark(CelValue::CreateList(&lhs), CelValue::CreateList(&rhs), s); } BENCHMARK(BM_EqualsBoundLegacyList); void BM_EqualsCreatedMap(benchmark::State& s) { RunBenchmark("{1: 2, 2: 3, 3: 6} == {1: 2, 2: 3, 3: 6}", s); } BENCHMARK(BM_EqualsCreatedMap); } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/public/extension_func_registrar.cc ================================================ #include "eval/public/extension_func_registrar.h" #include #include #include #include "google/type/timeofday.pb.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/time/civil_time.h" #include "absl/time/time.h" #include "eval/public/cel_function_adapter.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "google/protobuf/arena.h" #include "google/protobuf/message.h" namespace google { namespace api { namespace expr { namespace runtime { using google::protobuf::Arena; CelValue BetweenTs(Arena* arena, absl::Time time_stamp, absl::Time start, absl::Time stop) { bool is_between = false; // check if timestamp paremeter is between start and stop parameters is_between = (start <= time_stamp) && (time_stamp < stop); return CelValue::CreateBool(is_between); } CelValue BetweenStr(Arena* arena, absl::Time time_stamp, absl::string_view start, absl::string_view stop) { // convert start and stop into timestamps absl::Time start_ts; absl::Time stop_ts; // check if timestamp parameter is between start and stop -> call BetweenTs if (!absl::ParseTime(absl::RFC3339_full, start, &start_ts, nullptr) || !absl::ParseTime(absl::RFC3339_full, stop, &stop_ts, nullptr)) { return CreateErrorValue(arena, "String to Timestamp conversion failed", 
absl::StatusCode::kInvalidArgument); } return BetweenTs(arena, time_stamp, start_ts, stop_ts); } CelValue GetDateTz(Arena* arena, absl::Time time_stamp, absl::TimeZone time_zone) { absl::Time ret_date; absl::CivilDay normalized_date; // convert absl time to civil day, which normalizes to midnight time, // convert the result to CivilSecond // convert CivilSecond from previous step back into absl::Time normalized_date = absl::ToCivilDay(time_stamp, time_zone); absl::CivilSecond normalized_date_cs(normalized_date); ret_date = absl::FromCivil(normalized_date_cs, time_zone); return CelValue::CreateTimestamp(ret_date); } CelValue GetDate(Arena* arena, absl::Time time_stamp, absl::string_view time_zone) { absl::TimeZone time_zone_tz; // convert timezone from string to TimeZone if (!absl::LoadTimeZone(time_zone, &time_zone_tz)) { return CreateErrorValue(arena, "String to Timezone conversion failed", absl::StatusCode::kInvalidArgument); } return GetDateTz(arena, time_stamp, time_zone_tz); } CelValue GetDateUTC(Arena* arena, absl::Time time_stamp) { absl::TimeZone time_zone = absl::UTCTimeZone(); return GetDateTz(arena, time_stamp, time_zone); } CelValue GetTimeOfDayTz(Arena* arena, absl::Time time_stamp, absl::TimeZone time_zone) { absl::CivilSecond date_civil_time = absl::ToCivilSecond(time_stamp, time_zone); google::type::TimeOfDay* tod_message = Arena::Create(arena); tod_message->set_seconds(date_civil_time.second()); tod_message->set_minutes(date_civil_time.minute()); tod_message->set_hours(date_civil_time.hour()); // transform into celvalue for return return CelProtoWrapper::CreateMessage(tod_message, arena); } CelValue GetTimeOfDay(Arena* arena, absl::Time time_stamp, absl::string_view time_zone) { absl::TimeZone time_zonetz; if (!absl::LoadTimeZone(time_zone, &time_zonetz)) { return CreateErrorValue(arena, "String to Timezone conversion failed", absl::StatusCode::kInvalidArgument); } return GetTimeOfDayTz(arena, time_stamp, time_zonetz); } CelValue 
GetTimeOfDayUTC(Arena* arena, absl::Time time_stamp) { absl::TimeZone utc = absl::UTCTimeZone(); // call to helper function GetTimeOfDayTz // return value from helper return GetTimeOfDayTz(arena, time_stamp, utc); } int ToSeconds(const google::type::TimeOfDay* time_of_day) { int seconds = 0; seconds += time_of_day->hours() * 60 * 60; seconds += time_of_day->minutes() * 60; seconds += time_of_day->seconds(); return seconds; } CelValue BetweenToD(Arena* arena, const google::protobuf::Message* time_of_day, const google::protobuf::Message* start, const google::protobuf::Message* stop) { bool is_between; const google::type::TimeOfDay* time_of_day_tod = google::protobuf::DynamicCastMessage(time_of_day); const google::type::TimeOfDay* start_tod = google::protobuf::DynamicCastMessage(start); const google::type::TimeOfDay* stop_tod = google::protobuf::DynamicCastMessage(stop); if ((time_of_day_tod == nullptr) || (start_tod == nullptr) || (stop_tod == nullptr)) { return CreateErrorValue(arena, "Message type downcast failed", absl::StatusCode::kInvalidArgument); } // resolution for TimeOfDay in this function is 1 second int start_time = ToSeconds(start_tod); int stop_time = ToSeconds(stop_tod); int tod_time = ToSeconds(time_of_day_tod); is_between = (tod_time >= start_time) && (tod_time < stop_time); return CelValue::CreateBool(is_between); } CelValue BetweenToDStr(Arena* arena, const google::protobuf::Message* time_of_day, absl::string_view start, absl::string_view stop) { std::string start_date_time = absl::StrCat("1970-01-01T", start, "+00:00"); std::string stop_date_time = absl::StrCat("1970-01-01T", stop, "+00:00"); absl::Time start_ts; absl::Time stop_ts; // format of time of day string: "HH:MM:SS" // Below we prepend a generic date string and append a generic timezone string // this generates a full timestamp string that can be parsed with ParseTime() if (!absl::ParseTime(absl::RFC3339_sec, start_date_time, absl::UTCTimeZone(), &start_ts, nullptr) || 
!absl::ParseTime(absl::RFC3339_sec, stop_date_time, absl::UTCTimeZone(), &stop_ts, nullptr)) { return CreateErrorValue(arena, "String to Timestamp conversion failed", absl::StatusCode::kInvalidArgument); } const google::protobuf::Message* start_msg = GetTimeOfDayUTC(arena, start_ts).MessageOrDie(); const google::protobuf::Message* stop_msg = GetTimeOfDayUTC(arena, stop_ts).MessageOrDie(); return BetweenToD(arena, time_of_day, start_msg, stop_msg); } absl::Status RegisterExtensionFunctions(CelFunctionRegistry* registry) { auto status = FunctionAdapter:: CreateAndRegister( "between", true, [](Arena* arena, absl::Time ts, absl::Time start, absl::Time stop) -> CelValue { return BetweenTs(arena, ts, start, stop); }, registry); if (!status.ok()) return status; status = FunctionAdapter:: CreateAndRegister( "between", true, [](Arena* arena, absl::Time ts, CelValue::StringHolder start, CelValue::StringHolder stop) -> CelValue { return BetweenStr(arena, ts, start.value(), stop.value()); }, registry); if (!status.ok()) return status; status = FunctionAdapter:: CreateAndRegister( "date", true, [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) -> CelValue { return GetDate(arena, ts, tz.value()); }, registry); if (!status.ok()) return status; status = FunctionAdapter::CreateAndRegister( "date", true, [](Arena* arena, absl::Time ts) -> CelValue { return GetDateUTC(arena, ts); }, registry); if (!status.ok()) return status; status = FunctionAdapter:: CreateAndRegister( "timeOfDay", true, [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) -> CelValue { return GetTimeOfDay(arena, ts, tz.value()); }, registry); if (!status.ok()) return status; status = FunctionAdapter::CreateAndRegister( "timeOfDay", true, [](Arena* arena, absl::Time ts) -> CelValue { return GetTimeOfDayUTC(arena, ts); }, registry); if (!status.ok()) return status; status = FunctionAdapter:: CreateAndRegister( "between", true, [](Arena* arena, const google::protobuf::Message* tod, const 
google::protobuf::Message* start, const google::protobuf::Message* stop) -> CelValue { return BetweenToD(arena, tod, start, stop); }, registry); if (!status.ok()) return status; status = FunctionAdapter:: CreateAndRegister( "between", true, [](Arena* arena, const google::protobuf::Message* tod, CelValue::StringHolder start, CelValue::StringHolder stop) -> CelValue { return BetweenToDStr(arena, tod, start.value(), stop.value()); }, registry); return status; } } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: eval/public/extension_func_registrar.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_EXTENSION_FUNC_REGISTRAR_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_EXTENSION_FUNC_REGISTRAR_H_ #include "eval/public/cel_function.h" #include "eval/public/cel_function_registry.h" namespace google { namespace api { namespace expr { namespace runtime { // Register generic/widely used extension functions. 
absl::Status RegisterExtensionFunctions(CelFunctionRegistry* registry); } // namespace runtime } // namespace expr } // namespace api } // namespace google #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_EXTENSION_FUNC_REGISTRAR_H_ ================================================ FILE: eval/public/extension_func_test.cc ================================================ #include #include #include "google/type/timeofday.pb.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/time/civil_time.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_value.h" #include "eval/public/extension_func_registrar.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "google/protobuf/message.h" #include "google/protobuf/util/time_util.h" namespace google { namespace api { namespace expr { namespace runtime { namespace { using google::protobuf::Arena; static const int kNanosPerSecond = 1000000000; class ExtensionTest : public ::testing::Test { protected: ExtensionTest() {} void SetUp() override { ASSERT_OK(RegisterBuiltinFunctions(®istry_)); ASSERT_OK(RegisterExtensionFunctions(®istry_)); } // Helper method to test string startsWith() function void TestStringInclusion(absl::string_view func_name, const std::vector& call_style, const std::string& test_string, const std::string& included, bool result) { std::vector call_styles = {true, false}; for (auto call_style : call_styles) { auto functions = registry_.FindOverloads( func_name, call_style, {CelValue::Type::kString, CelValue::Type::kString}); ASSERT_EQ(functions.size(), 1); auto func = functions[0]; std::vector args = {CelValue::CreateString(&test_string), 
CelValue::CreateString(&included)}; CelValue result_value = CelValue::CreateNull(); google::protobuf::Arena arena; absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, &result_value, &arena); ASSERT_OK(status); ASSERT_TRUE(result_value.IsBool()); ASSERT_EQ(result_value.BoolOrDie(), result); } } void TestStringStartsWith(const std::string& test_string, const std::string& prefix, bool result) { TestStringInclusion("startsWith", {true, false}, test_string, prefix, result); } void TestStringEndsWith(const std::string& test_string, const std::string& prefix, bool result) { TestStringInclusion("endsWith", {true, false}, test_string, prefix, result); } // Helper method to test timestamp() function void PerformTimestampConversion(Arena* arena, const std::string& ts_str, CelValue* result) { auto functions = registry_.FindOverloads("timestamp", false, {CelValue::Type::kString}); ASSERT_EQ(functions.size(), 1); auto func = functions[0]; std::vector args = {CelValue::CreateString(&ts_str)}; absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, result, arena); ASSERT_OK(status); } void PerformBetweenTest(Arena* arena, absl::Time time_stamp, absl::Time start_ts, absl::Time stop_ts, CelValue* result) { auto functions = registry_.FindOverloads( "between", true, {CelValue::Type::kTimestamp, CelValue::Type::kTimestamp, CelValue::Type::kTimestamp}); ASSERT_EQ(functions.size(), 1); auto func = functions[0]; std::vector args = {CelValue::CreateTimestamp(time_stamp), CelValue::CreateTimestamp(start_ts), CelValue::CreateTimestamp(stop_ts)}; absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, result, arena); ASSERT_OK(status); } void PerformBetweenStrTest(Arena* arena, absl::Time time_stamp, std::string* start, std::string* stop, CelValue* result) { auto functions = registry_.FindOverloads( "between", true, {CelValue::Type::kTimestamp, CelValue::Type::kString, CelValue::Type::kString}); 
ASSERT_EQ(functions.size(), 1); auto func = functions[0]; std::vector args = {CelValue::CreateTimestamp(time_stamp), CelValue::CreateString(start), CelValue::CreateString(stop)}; absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, result, arena); ASSERT_OK(status); } void PerformGetDateTest(Arena* arena, absl::Time time_stamp, std::string* time_zone, CelValue* result) { auto functions = registry_.FindOverloads( "date", true, {CelValue::Type::kTimestamp, CelValue::Type::kString}); ASSERT_EQ(functions.size(), 1); auto func = functions[0]; std::vector args = {CelValue::CreateTimestamp(time_stamp), CelValue::CreateString(time_zone)}; absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, result, arena); ASSERT_OK(status); } void PerformGetDateUTCTest(Arena* arena, absl::Time time_stamp, CelValue* result) { auto functions = registry_.FindOverloads("date", true, {CelValue::Type::kTimestamp}); ASSERT_EQ(functions.size(), 1); auto func = functions[0]; std::vector args = {CelValue::CreateTimestamp(time_stamp)}; absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, result, arena); ASSERT_OK(status); } void PerformGetTimeOfDayTest(Arena* arena, absl::Time time_stamp, std::string* time_zone, CelValue* result) { auto functions = registry_.FindOverloads( "timeOfDay", true, {CelValue::Type::kTimestamp, CelValue::Type::kString}); ASSERT_EQ(functions.size(), 1); auto func = functions[0]; std::vector args = {CelValue::CreateTimestamp(time_stamp), CelValue::CreateString(time_zone)}; absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, result, arena); ASSERT_OK(status); } void PerformGetTimeOfDayUTCTest(Arena* arena, absl::Time time_stamp, CelValue* result) { auto functions = registry_.FindOverloads("timeOfDay", true, {CelValue::Type::kTimestamp}); ASSERT_EQ(functions.size(), 1); auto func = functions[0]; std::vector args = {CelValue::CreateTimestamp(time_stamp)}; 
absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, result, arena); ASSERT_OK(status); } void PerformBetweenToDTest(Arena* arena, const google::protobuf::Message* time_of_day, const google::protobuf::Message* start, const google::protobuf::Message* stop, CelValue* result) { auto functions = registry_.FindOverloads( "between", true, {CelValue::Type::kMessage, CelValue::Type::kMessage, CelValue::Type::kMessage}); ASSERT_EQ(functions.size(), 1); auto func = functions[0]; std::vector args = { CelProtoWrapper::CreateMessage(time_of_day, arena), CelProtoWrapper::CreateMessage(start, arena), CelProtoWrapper::CreateMessage(stop, arena)}; absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, result, arena); ASSERT_OK(status); } void PerformBetweenToDStrTest(Arena* arena, const google::protobuf::Message* time_of_day, std::string* start, std::string* stop, CelValue* result) { auto functions = registry_.FindOverloads( "between", true, {CelValue::Type::kMessage, CelValue::Type::kString, CelValue::Type::kString}); ASSERT_EQ(functions.size(), 1); auto func = functions[0]; std::vector args = { CelProtoWrapper::CreateMessage(time_of_day, arena), CelValue::CreateString(start), CelValue::CreateString(stop)}; absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, result, arena); ASSERT_OK(status); } // Helper method to test duration() function void PerformDurationConversion(Arena* arena, const std::string& ts_str, CelValue* result) { auto functions = registry_.FindOverloads("duration", false, {CelValue::Type::kString}); ASSERT_EQ(functions.size(), 1); auto func = functions[0]; std::vector args = {CelValue::CreateString(&ts_str)}; absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, result, arena); ASSERT_OK(status); } // Function registry object CelFunctionRegistry registry_; Arena arena_; }; // Test string startsWith() function. 
TEST_F(ExtensionTest, TestStartsWithFunction) { // Empty string, non-empty prefix - never matches. EXPECT_NO_FATAL_FAILURE(TestStringStartsWith("", "p", false)); // Prefix of 0 length - always matches. EXPECT_NO_FATAL_FAILURE(TestStringStartsWith("", "", true)); EXPECT_NO_FATAL_FAILURE(TestStringStartsWith("prefixedString", "", true)); // Non-empty matching prefix. EXPECT_NO_FATAL_FAILURE( TestStringStartsWith("prefixedString", "prefix", true)); // Non-empty mismatching prefix. EXPECT_NO_FATAL_FAILURE(TestStringStartsWith("prefixedString", "x", false)); EXPECT_NO_FATAL_FAILURE( TestStringStartsWith("prefixedString", "prefixedString1", false)); } // Test string startsWith() function. TEST_F(ExtensionTest, TestEndsWithFunction) { // Empty string, non-empty postfix - never matches. EXPECT_NO_FATAL_FAILURE(TestStringEndsWith("", "p", false)); // Postfix of 0 length - always matches. EXPECT_NO_FATAL_FAILURE(TestStringEndsWith("", "", true)); EXPECT_NO_FATAL_FAILURE(TestStringEndsWith("postfixedString", "", true)); // Non-empty matching postfix. EXPECT_NO_FATAL_FAILURE( TestStringEndsWith("postfixedString", "String", true)); // Non-empty mismatching post. EXPECT_NO_FATAL_FAILURE(TestStringEndsWith("postfixedString", "x", false)); EXPECT_NO_FATAL_FAILURE( TestStringEndsWith("postfixedString", "1postfixedString", false)); } // Test timestamp conversion function. TEST_F(ExtensionTest, TestTimestampFromString) { CelValue result = CelValue::CreateNull(); Arena arena; // Valid timestamp - no fractions of seconds. EXPECT_NO_FATAL_FAILURE( PerformTimestampConversion(&arena, "2000-01-01T00:00:00Z", &result)); ASSERT_TRUE(result.IsTimestamp()); auto ts = result.TimestampOrDie(); ASSERT_EQ(absl::ToUnixSeconds(ts), 946684800L); ASSERT_EQ(absl::ToUnixNanos(ts), 946684800L * kNanosPerSecond); // Valid timestamp - with nanoseconds. 
EXPECT_NO_FATAL_FAILURE( PerformTimestampConversion(&arena, "2000-01-01T00:00:00.212Z", &result)); ASSERT_TRUE(result.IsTimestamp()); ts = result.TimestampOrDie(); ASSERT_EQ(absl::ToUnixSeconds(ts), 946684800L); ASSERT_EQ(absl::ToUnixNanos(ts), 946684800L * kNanosPerSecond + 212000000); // Valid timestamp - with timezone. EXPECT_NO_FATAL_FAILURE(PerformTimestampConversion( &arena, "2000-01-01T00:00:00.212-01:00", &result)); ASSERT_TRUE(result.IsTimestamp()); ts = result.TimestampOrDie(); ASSERT_EQ(absl::ToUnixSeconds(ts), 946688400L); ASSERT_EQ(absl::ToUnixNanos(ts), 946688400L * kNanosPerSecond + 212000000); // Invalid timestamp - empty string. EXPECT_NO_FATAL_FAILURE(PerformTimestampConversion(&arena, "", &result)); ASSERT_TRUE(result.IsError()); ASSERT_EQ(result.ErrorOrDie()->code(), absl::StatusCode::kInvalidArgument); // Invalid timestamp. EXPECT_NO_FATAL_FAILURE( PerformTimestampConversion(&arena, "2000-01-01TT00:00:00Z", &result)); ASSERT_TRUE(result.IsError()); } // Test duration conversion function. TEST_F(ExtensionTest, TestDurationFromString) { CelValue result = CelValue::CreateNull(); Arena arena; // Valid duration - no fractions of seconds. EXPECT_NO_FATAL_FAILURE(PerformDurationConversion(&arena, "1354s", &result)); ASSERT_TRUE(result.IsDuration()); auto d = result.DurationOrDie(); ASSERT_EQ(absl::ToInt64Seconds(d), 1354L); ASSERT_EQ(absl::ToInt64Nanoseconds(d), 1354L * kNanosPerSecond); // Valid duration - with nanoseconds. EXPECT_NO_FATAL_FAILURE(PerformDurationConversion(&arena, "15.11s", &result)); ASSERT_TRUE(result.IsDuration()); d = result.DurationOrDie(); ASSERT_EQ(absl::ToInt64Seconds(d), 15L); ASSERT_EQ(absl::ToInt64Nanoseconds(d), 15L * kNanosPerSecond + 110000000L); // Invalid duration - empty string. EXPECT_NO_FATAL_FAILURE(PerformDurationConversion(&arena, "", &result)); ASSERT_TRUE(result.IsError()); ASSERT_EQ(result.ErrorOrDie()->code(), absl::StatusCode::kInvalidArgument); // Invalid duration. 
EXPECT_NO_FATAL_FAILURE(PerformDurationConversion(&arena, "100", &result)); ASSERT_TRUE(result.IsError()); } TEST_F(ExtensionTest, TestBetweenTs) { absl::Time time_1; absl::Time time_2; absl::Time time_3; std::string time_stampstr = "1997-07-16T19:50:30.45+01:00"; std::string time_start = "1997-07-16T19:20:30.45+01:00"; std::string time_stop = "1997-07-16T20:20:30.45+01:00"; Arena arena; CelValue result; absl::ParseTime(absl::RFC3339_full, time_stampstr, &time_2, nullptr); absl::ParseTime(absl::RFC3339_full, time_start, &time_1, nullptr); absl::ParseTime(absl::RFC3339_full, time_stop, &time_3, nullptr); PerformBetweenTest(&arena, time_2, time_1, time_3, &result); ASSERT_EQ(result.BoolOrDie(), true); PerformBetweenTest(&arena, time_1, time_2, time_3, &result); ASSERT_EQ(result.BoolOrDie(), false); PerformBetweenTest(&arena, time_1, time_1, time_3, &result); ASSERT_EQ(result.BoolOrDie(), true); PerformBetweenTest(&arena, time_3, time_1, time_2, &result); ASSERT_EQ(result.BoolOrDie(), false); PerformBetweenTest(&arena, time_3, time_1, time_3, &result); ASSERT_EQ(result.BoolOrDie(), false); } TEST_F(ExtensionTest, TestBetweenStr) { Arena arena; absl::Time time_stamp; CelValue result; std::string time_stampstr = "1997-07-16T19:50:30.45+01:00"; std::string time_start = "1997-07-16T19:20:30.45+01:00"; std::string time_stop = "1997-07-16T20:20:30.45+01:00"; absl::ParseTime(absl::RFC3339_full, time_stampstr, &time_stamp, nullptr); PerformBetweenStrTest(&arena, time_stamp, &time_start, &time_stop, &result); ASSERT_EQ(result.BoolOrDie(), true); absl::ParseTime(absl::RFC3339_full, time_start, &time_stamp, nullptr); PerformBetweenStrTest(&arena, time_stamp, &time_start, &time_stop, &result); ASSERT_EQ(result.BoolOrDie(), true); absl::ParseTime(absl::RFC3339_full, time_stop, &time_stamp, nullptr); PerformBetweenStrTest(&arena, time_stamp, &time_start, &time_stop, &result); ASSERT_EQ(result.BoolOrDie(), false); time_stampstr = "1997-07-16T18:20:30.45+01:00"; 
absl::ParseTime(absl::RFC3339_full, time_stampstr, &time_stamp, nullptr); PerformBetweenStrTest(&arena, time_stamp, &time_start, &time_stop, &result); ASSERT_EQ(result.BoolOrDie(), false); time_stampstr = "1997-07-16T21:20:30.45+01:00"; absl::ParseTime(absl::RFC3339_full, time_stampstr, &time_stamp, nullptr); PerformBetweenStrTest(&arena, time_stamp, &time_start, &time_stop, &result); ASSERT_EQ(result.BoolOrDie(), false); } TEST_F(ExtensionTest, TestGetDate) { Arena arena; CelValue result; absl::CivilSecond date(2015, 2, 3, 4, 5, 6); absl::CivilSecond normal_date(2015, 2, 3); absl::TimeZone time_zone; std::string time_zonestr = "America/Los_Angeles"; absl::LoadTimeZone(time_zonestr, &time_zone); absl::Time expected_val = absl::FromCivil(normal_date, time_zone); absl::Time input_val = absl::FromCivil(date, time_zone); PerformGetDateTest(&arena, input_val, &time_zonestr, &result); ASSERT_EQ(result.TimestampOrDie(), expected_val); } TEST_F(ExtensionTest, TestGetDateUTC) { Arena arena; CelValue result; absl::CivilSecond date(2015, 2, 3, 4, 5, 6); absl::CivilSecond normal_date(2015, 2, 3); absl::TimeZone time_zone = absl::UTCTimeZone(); absl::Time expected_val = absl::FromCivil(normal_date, time_zone); absl::Time input_val = absl::FromCivil(date, time_zone); PerformGetDateUTCTest(&arena, input_val, &result); ASSERT_EQ(result.TimestampOrDie(), expected_val); } TEST_F(ExtensionTest, TestGetTimeOfDay) { Arena arena; CelValue result; absl::CivilSecond date(2015, 2, 3, 4, 5, 6); absl::TimeZone time_zone; std::string time_zonestr = "America/Los_Angeles"; google::type::TimeOfDay* tod_message = Arena::Create(&arena); absl::LoadTimeZone(time_zonestr, &time_zone); absl::Time input_val = absl::FromCivil(date, time_zone); tod_message->set_seconds(date.second()); tod_message->set_minutes(date.minute()); tod_message->set_hours(date.hour()); PerformGetTimeOfDayTest(&arena, input_val, &time_zonestr, &result); const google::type::TimeOfDay* time_of_day_tod = 
google::protobuf::DynamicCastMessage( result.MessageOrDie()); ASSERT_EQ(time_of_day_tod->seconds(), tod_message->seconds()); ASSERT_EQ(time_of_day_tod->minutes(), tod_message->minutes()); ASSERT_EQ(time_of_day_tod->hours(), tod_message->hours()); } TEST_F(ExtensionTest, TestGetTimeOfDayUTC) { Arena arena; CelValue result; absl::TimeZone time_zone = absl::UTCTimeZone(); absl::CivilSecond date(2015, 2, 3, 4, 5, 6); absl::Time input_time = absl::FromCivil(date, time_zone); google::type::TimeOfDay* tod_message = Arena::Create(&arena); tod_message->set_seconds(date.second()); tod_message->set_minutes(date.minute()); tod_message->set_hours(date.hour()); PerformGetTimeOfDayUTCTest(&arena, input_time, &result); const google::type::TimeOfDay* time_of_day_tod = google::protobuf::DynamicCastMessage( result.MessageOrDie()); ASSERT_EQ(time_of_day_tod->seconds(), tod_message->seconds()); ASSERT_EQ(time_of_day_tod->minutes(), tod_message->minutes()); ASSERT_EQ(time_of_day_tod->hours(), tod_message->hours()); } TEST_F(ExtensionTest, TestBetweenToD) { Arena arena; CelValue result; google::type::TimeOfDay* time_of_day = Arena::Create(&arena); google::type::TimeOfDay* start = Arena::Create(&arena); google::type::TimeOfDay* stop = Arena::Create(&arena); start->set_hours(20); start->set_minutes(0); start->set_seconds(0); stop->set_hours(21); stop->set_minutes(0); stop->set_seconds(0); time_of_day->set_hours(20); time_of_day->set_minutes(30); time_of_day->set_seconds(0); PerformBetweenToDTest(&arena, time_of_day, start, stop, &result); ASSERT_EQ(result.BoolOrDie(), true); time_of_day->set_minutes(0); PerformBetweenToDTest(&arena, time_of_day, start, stop, &result); ASSERT_EQ(result.BoolOrDie(), true); time_of_day->set_hours(19); PerformBetweenToDTest(&arena, time_of_day, start, stop, &result); ASSERT_EQ(result.BoolOrDie(), false); time_of_day->set_hours(21); PerformBetweenToDTest(&arena, time_of_day, start, stop, &result); ASSERT_EQ(result.BoolOrDie(), false); 
time_of_day->set_seconds(1); PerformBetweenToDTest(&arena, time_of_day, start, stop, &result); ASSERT_EQ(result.BoolOrDie(), false); } TEST_F(ExtensionTest, TestBetweenTodStr) { Arena arena; CelValue result; std::string start = "18:20:30"; std::string stop = "19:20:30"; google::type::TimeOfDay* time_of_day = Arena::Create(&arena); time_of_day->set_hours(19); time_of_day->set_minutes(0); time_of_day->set_seconds(0); PerformBetweenToDStrTest(&arena, time_of_day, &start, &stop, &result); ASSERT_EQ(result.BoolOrDie(), true); time_of_day->set_hours(18); time_of_day->set_minutes(20); time_of_day->set_seconds(30); PerformBetweenToDStrTest(&arena, time_of_day, &start, &stop, &result); ASSERT_EQ(result.BoolOrDie(), true); time_of_day->set_seconds(29); PerformBetweenToDStrTest(&arena, time_of_day, &start, &stop, &result); ASSERT_EQ(result.BoolOrDie(), false); time_of_day->set_hours(19); time_of_day->set_minutes(20); time_of_day->set_seconds(30); PerformBetweenToDStrTest(&arena, time_of_day, &start, &stop, &result); ASSERT_EQ(result.BoolOrDie(), false); time_of_day->set_seconds(29); PerformBetweenToDStrTest(&arena, time_of_day, &start, &stop, &result); ASSERT_EQ(result.BoolOrDie(), true); time_of_day->set_seconds(31); PerformBetweenToDStrTest(&arena, time_of_day, &start, &stop, &result); ASSERT_EQ(result.BoolOrDie(), false); } } // namespace } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: eval/public/logical_function_registrar.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/public/logical_function_registrar.h" #include "absl/status/status.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "runtime/standard/logical_functions.h" namespace google::api::expr::runtime { absl::Status RegisterLogicalFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options) { return cel::RegisterLogicalFunctions(registry->InternalGetRegistry(), ConvertToRuntimeOptions(options)); } } // namespace google::api::expr::runtime ================================================ FILE: eval/public/logical_function_registrar.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_LOGICAL_FUNCTION_REGISTRAR_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_LOGICAL_FUNCTION_REGISTRAR_H_ #include "absl/status/status.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" namespace google::api::expr::runtime { // Register logical operators ! 
and @not_strictly_false. // // &&, ||, ?: are special cased by the interpreter (not implemented via the // function registry.) // // Most users should use RegisterBuiltinFunctions, which includes these // definitions. absl::Status RegisterLogicalFunctions(CelFunctionRegistry* registry, const InterpreterOptions& options); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_LOGICAL_FUNCTION_REGISTRAR_H_ ================================================ FILE: eval/public/logical_function_registrar_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/public/logical_function_registrar.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "cel/expr/syntax.pb.h"
#include "absl/base/no_destructor.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "eval/public/activation.h"
#include "eval/public/cel_expr_builder_factory.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_options.h"
#include "eval/public/cel_value.h"
#include "eval/public/portable_cel_function_adapter.h"
#include "eval/public/testing/matchers.h"
#include "internal/testing.h"
#include "parser/parser.h"
#include "google/protobuf/arena.h"

namespace google::api::expr::runtime {
namespace {

using cel::expr::Expr;
using cel::expr::SourceInfo;
using ::absl_testing::StatusIs;
using ::testing::HasSubstr;

// One parameterized case. `expr` is parsed and evaluated. `result` holds
// either the expected CelValue or a non-OK status. For a non-OK status, the
// evaluated value must be an error with the same code and a message
// containing the status message.
struct TestCase {
  std::string test_name;
  std::string expr;
  absl::StatusOr<CelValue> result = CelValue::CreateBool(true);
};

// Error returned by the `toBool` helper for strings other than "true"/"false".
const CelError* ExampleError() {
  static absl::NoDestructor<CelError> error(
      absl::InternalError("test example error"));
  return &*error;
}

// Parses and evaluates `test_case.expr`. Only the logical functions and a
// `toBool(string)` helper are registered, and short-circuiting is enabled.
// The outcome is then checked against `test_case.result`.
void ExpectResult(const TestCase& test_case) {
  auto parsed_expr = parser::Parse(test_case.expr);
  ASSERT_OK(parsed_expr);
  const Expr& expr_ast = parsed_expr->expr();
  const SourceInfo& source_info = parsed_expr->source_info();

  InterpreterOptions options;
  options.short_circuiting = true;
  std::unique_ptr<CelExpressionBuilder> builder =
      CreateCelExpressionBuilder(options);
  ASSERT_OK(RegisterLogicalFunctions(builder->GetRegistry(), options));
  ASSERT_OK(builder->GetRegistry()->Register(
      PortableUnaryFunctionAdapter<CelValue, CelValue::StringHolder>::Create(
          "toBool", false,
          [](google::protobuf::Arena*, CelValue::StringHolder holder)
              -> CelValue {
            if (holder.value() == "true") {
              return CelValue::CreateBool(true);
            } else if (holder.value() == "false") {
              return CelValue::CreateBool(false);
            }
            return CelValue::CreateError(ExampleError());
          })));
  ASSERT_OK_AND_ASSIGN(auto cel_expression,
                       builder->CreateExpression(&expr_ast, &source_info));

  Activation activation;
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(auto value,
                       cel_expression->Evaluate(activation, &arena));

  if (!test_case.result.ok()) {
    EXPECT_TRUE(value.IsError());
    EXPECT_THAT(*value.ErrorOrDie(),
                StatusIs(test_case.result.status().code(),
                         HasSubstr(test_case.result.status().message())));
    return;
  }
  EXPECT_THAT(value, test::EqualsCelValue(*test_case.result));
}

using BuiltinFuncParamsTest = testing::TestWithParam<TestCase>;
TEST_P(BuiltinFuncParamsTest, StandardFunctions) { ExpectResult(GetParam()); }

INSTANTIATE_TEST_SUITE_P(
    BuiltinFuncParamsTest, BuiltinFuncParamsTest,
    testing::ValuesIn<TestCase>({
        // Logical negation.
        {"LogicalNotOfTrue", "!true", CelValue::CreateBool(false)},
        {"LogicalNotOfFalse", "!false", CelValue::CreateBool(true)},
        // Not strictly false is an internal function for implementing logical
        // shortcutting in comprehensions.
        {"NotStrictlyFalseTrue", "[true, true, true].all(x, x)",
         CelValue::CreateBool(true)},
        // List creation is eager so use an extension function to introduce an
        // error.
        {"NotStrictlyFalseErrorShortcircuit",
         "['true', 'false', 'error'].all(x, toBool(x))",
         CelValue::CreateBool(false)},
        {"NotStrictlyFalseError", "['true', 'true', 'error'].all(x, toBool(x))",
         CelValue::CreateError(ExampleError())},
        {"NotStrictlyFalseFalse", "[false, false, false].all(x, x)",
         CelValue::CreateBool(false)},
    }),
    [](const testing::TestParamInfo<TestCase>& info) {
      return info.param.test_name;
    });

}  // namespace
}  // namespace google::api::expr::runtime

================================================
FILE: eval/public/message_wrapper.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_MESSAGE_WRAPPER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_MESSAGE_WRAPPER_H_ #include #include "absl/base/attributes.h" #include "absl/base/macros.h" #include "absl/numeric/bits.h" #include "base/internal/message_wrapper.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" namespace cel::interop_internal { struct MessageWrapperAccess; } // namespace cel::interop_internal namespace google::api::expr::runtime { // Forward declare to resolve cycle. class LegacyTypeInfoApis; // Wrapper type for protobuf messages. This is used to limit internal usages of // proto APIs and to support working with the proto lite runtime. // // Provides operations for checking if down-casting to Message is safe. class ABSL_DEPRECATED("Use google::protobuf::Message directly") MessageWrapper { public: // Simple builder class. // // Wraps a tagged mutable message lite ptr. 
class ABSL_DEPRECATED("Use google::protobuf::Message directly") Builder { public: explicit Builder(google::protobuf::MessageLite* message) : message_ptr_(reinterpret_cast(message)) { ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= kTagSize); } explicit Builder(google::protobuf::Message* message) : message_ptr_(reinterpret_cast(message) | kMessageTag) { ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= kTagSize); } google::protobuf::MessageLite* message_ptr() const { return reinterpret_cast(message_ptr_ & kPtrMask); } bool HasFullProto() const { return (message_ptr_ & kTagMask) == kMessageTag; } MessageWrapper Build(const LegacyTypeInfoApis* type_info) { return MessageWrapper(message_ptr_, type_info); } private: friend class MessageWrapper; explicit Builder(uintptr_t message_ptr) : message_ptr_(message_ptr) {} uintptr_t message_ptr_; }; static_assert(alignof(google::protobuf::MessageLite) >= 2, "Assume that valid MessageLite ptrs have a free low-order bit"); MessageWrapper() : message_ptr_(0), legacy_type_info_(nullptr) {} MessageWrapper(const google::protobuf::MessageLite* message, const LegacyTypeInfoApis* legacy_type_info) : message_ptr_(reinterpret_cast(message)), legacy_type_info_(legacy_type_info) { ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= kTagSize); } MessageWrapper(const google::protobuf::Message* message, const LegacyTypeInfoApis* legacy_type_info) : message_ptr_(reinterpret_cast(message) | kMessageTag), legacy_type_info_(legacy_type_info) { ABSL_ASSERT(absl::countr_zero(reinterpret_cast(message)) >= kTagSize); } // If true, the message is using the full proto runtime and downcasting to // message should be safe. bool HasFullProto() const { return (message_ptr_ & kTagMask) == kMessageTag; } // Returns the underlying message. // // Clients must check HasFullProto before downcasting to Message. 
const google::protobuf::MessageLite* message_ptr() const { return reinterpret_cast(message_ptr_ & kPtrMask); } // Type information associated with this message. const LegacyTypeInfoApis* legacy_type_info() const { return legacy_type_info_; } private: friend struct ::cel::interop_internal::MessageWrapperAccess; MessageWrapper(uintptr_t message_ptr, const LegacyTypeInfoApis* legacy_type_info) : message_ptr_(message_ptr), legacy_type_info_(legacy_type_info) {} Builder ToBuilder() { return Builder(message_ptr_); } static constexpr int kTagSize = ::cel::base_internal::kMessageWrapperTagSize; static constexpr uintptr_t kTagMask = ::cel::base_internal::kMessageWrapperTagMask; static constexpr uintptr_t kPtrMask = ::cel::base_internal::kMessageWrapperPtrMask; static constexpr uintptr_t kMessageTag = ::cel::base_internal::kMessageWrapperTagMessageValue; uintptr_t message_ptr_; const LegacyTypeInfoApis* legacy_type_info_; }; static_assert(sizeof(MessageWrapper) <= 2 * sizeof(uintptr_t), "MessageWrapper must not increase CelValue size."); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_MESSAGE_WRAPPER_H_ ================================================ FILE: eval/public/message_wrapper_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/public/message_wrapper.h"

#include <cstdint>
#include <type_traits>

#include "eval/public/structs/trivial_legacy_type_info.h"
#include "eval/testutil/test_message.pb.h"
#include "internal/casts.h"
#include "internal/testing.h"
#include "google/protobuf/message.h"
#include "google/protobuf/message_lite.h"

namespace google::api::expr::runtime {
namespace {

// True when TestMessage was generated for the full (non-lite) proto runtime;
// MessageWrapper::HasFullProto() is expected to agree with this.
constexpr bool kIsFullProtoRuntime =
    std::is_base_of_v<google::protobuf::Message, TestMessage>;

TEST(MessageWrapper, Size) {
  static_assert(sizeof(MessageWrapper) <= 2 * sizeof(uintptr_t),
                "MessageWrapper must not increase CelValue size.");
}

TEST(MessageWrapper, WrapsMessage) {
  TestMessage test_message;
  test_message.set_int64_value(20);
  test_message.set_double_value(12.3);

  MessageWrapper wrapped_message(&test_message, TrivialTypeInfo::GetInstance());

  EXPECT_EQ(wrapped_message.message_ptr(),
            static_cast<const google::protobuf::MessageLite*>(&test_message));
  ASSERT_EQ(wrapped_message.HasFullProto(), kIsFullProtoRuntime);
}

TEST(MessageWrapperBuilder, Builder) {
  TestMessage test_message;

  MessageWrapper::Builder builder(&test_message);
  ASSERT_EQ(builder.HasFullProto(), kIsFullProtoRuntime);
  ASSERT_EQ(builder.message_ptr(),
            static_cast<google::protobuf::MessageLite*>(&test_message));

  // Writes through the builder's (untagged) pointer must reach test_message.
  auto* mutable_message =
      cel::internal::down_cast<TestMessage*>(builder.message_ptr());
  mutable_message->set_int64_value(20);
  mutable_message->set_double_value(12.3);

  MessageWrapper wrapped_message =
      builder.Build(TrivialTypeInfo::GetInstance());

  ASSERT_EQ(wrapped_message.message_ptr(),
            static_cast<const google::protobuf::MessageLite*>(&test_message));
  ASSERT_EQ(wrapped_message.HasFullProto(), kIsFullProtoRuntime);
  EXPECT_EQ(test_message.int64_value(), 20);
  EXPECT_EQ(test_message.double_value(), 12.3);
}

TEST(MessageWrapper, DefaultNull) {
  MessageWrapper wrapper;
  EXPECT_EQ(wrapper.message_ptr(), nullptr);
  EXPECT_EQ(wrapper.legacy_type_info(), nullptr);
}

}  // namespace
}  // namespace google::api::expr::runtime

================================================
FILE:
eval/public/portable_cel_function_adapter.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_PORTABLE_CEL_FUNCTION_ADAPTER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_PORTABLE_CEL_FUNCTION_ADAPTER_H_ #include "eval/public/cel_function_adapter.h" namespace google::api::expr::runtime { // Portable version of the FunctionAdapter template utility. // // The PortableFunctionAdapter variation provides the same interface, // but doesn't support unwrapping google::protobuf::Message values. See documentation on // Function adapter for example usage. // // Most users should prefer using the standard FunctionAdapter. template using PortableFunctionAdapter = FunctionAdapter; // PortableUnaryFunctionAdapter provides a factory for adapting 1 argument // functions to CEL extension functions. // // Static Methods: // // Create(absl::string_view function_name, bool receiver_style, // FunctionType func) -> std::unique_ptr // // Usage example: // // auto func = [](::google::protobuf::Arena* arena, int64_t i) -> int64_t { // return -i; // }; // // auto cel_func = // PortableUnaryFunctionAdapter::Create("negate", true, // func); template using PortableUnaryFunctionAdapter = UnaryFunctionAdapter; // PortableBinaryFunctionAdapter provides a factory for adapting 2 argument // functions to CEL extension functions. 
// // Create(absl::string_view function_name, bool receiver_style, // FunctionType func) -> std::unique_ptr // // Usage example: // // auto func = [](::google::protobuf::Arena* arena, int64_t i, int64_t j) -> bool { // return i < j; // }; // // auto cel_func = // PortableBinaryFunctionAdapter::Create("<", // false, func); template using PortableBinaryFunctionAdapter = BinaryFunctionAdapter; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_PORTABLE_CEL_FUNCTION_ADAPTER_H_ ================================================ FILE: eval/public/set_util.cc ================================================ #include "eval/public/set_util.h" #include #include namespace google::api::expr::runtime { namespace { // Default implementation is operator<. // Note: for UnknownSet, Error and Message, this is ptr less than. template int ComparisonImpl(T lhs, T rhs) { if (lhs < rhs) { return -1; } else if (lhs > rhs) { return 1; } else { return 0; } } template <> int ComparisonImpl(const CelError* lhs, const CelError* rhs) { if (*lhs == *rhs) { return 0; } return lhs < rhs ? -1 : 1; } // Message wrapper specialization template <> int ComparisonImpl(CelValue::MessageWrapper lhs_wrapper, CelValue::MessageWrapper rhs_wrapper) { auto* lhs = lhs_wrapper.message_ptr(); auto* rhs = rhs_wrapper.message_ptr(); if (lhs < rhs) { return -1; } else if (lhs > rhs) { return 1; } else { return 0; } } // List specialization -- compare size then elementwise compare. 
template <> int ComparisonImpl(const CelList* lhs, const CelList* rhs) { int size_comparison = ComparisonImpl(lhs->size(), rhs->size()); if (size_comparison != 0) { return size_comparison; } google::protobuf::Arena arena; for (int i = 0; i < lhs->size(); i++) { CelValue lhs_i = lhs->Get(&arena, i); CelValue rhs_i = rhs->Get(&arena, i); int value_comparison = CelValueCompare(lhs_i, rhs_i); if (value_comparison != 0) { return value_comparison; } } // equal return 0; } // Map specialization -- size then sorted elementwise compare (i.e. // < // // This is expensive, but hopefully maps will be rarely used in sets. template <> int ComparisonImpl(const CelMap* lhs, const CelMap* rhs) { int size_comparison = ComparisonImpl(lhs->size(), rhs->size()); if (size_comparison != 0) { return size_comparison; } google::protobuf::Arena arena; std::vector lhs_keys; std::vector rhs_keys; lhs_keys.reserve(lhs->size()); rhs_keys.reserve(lhs->size()); const CelList* lhs_key_view = lhs->ListKeys(&arena).value(); const CelList* rhs_key_view = rhs->ListKeys(&arena).value(); for (int i = 0; i < lhs->size(); i++) { lhs_keys.push_back(lhs_key_view->Get(&arena, i)); rhs_keys.push_back(rhs_key_view->Get(&arena, i)); } std::sort(lhs_keys.begin(), lhs_keys.end(), &CelValueLessThan); std::sort(rhs_keys.begin(), rhs_keys.end(), &CelValueLessThan); for (size_t i = 0; i < lhs_keys.size(); i++) { auto lhs_key_i = lhs_keys[i]; auto rhs_key_i = rhs_keys[i]; int key_comparison = CelValueCompare(lhs_key_i, rhs_key_i); if (key_comparison != 0) { return key_comparison; } // keys equal, compare values. 
auto lhs_value_i = lhs->Get(&arena, lhs_key_i).value(); auto rhs_value_i = rhs->Get(&arena, rhs_key_i).value(); int value_comparison = CelValueCompare(lhs_value_i, rhs_value_i); if (value_comparison != 0) { return value_comparison; } } // maps equal return 0; } struct ComparisonVisitor { explicit ComparisonVisitor(CelValue rhs) : rhs(rhs) {} template int operator()(T lhs_value) { T rhs_value; if (!rhs.GetValue(&rhs_value)) { return ComparisonImpl(CelValue::Type(CelValue::IndexOf::value), rhs.type()); } return ComparisonImpl(lhs_value, rhs_value); } CelValue rhs; }; } // namespace int CelValueCompare(CelValue lhs, CelValue rhs) { return lhs.InternalVisit(ComparisonVisitor(rhs)); } bool CelValueLessThan(CelValue lhs, CelValue rhs) { return lhs.InternalVisit(ComparisonVisitor(rhs)) < 0; } bool CelValueEqual(CelValue lhs, CelValue rhs) { return lhs.InternalVisit(ComparisonVisitor(rhs)) == 0; } bool CelValueGreaterThan(CelValue lhs, CelValue rhs) { return lhs.InternalVisit(ComparisonVisitor(rhs)) > 0; } } // namespace google::api::expr::runtime ================================================ FILE: eval/public/set_util.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_SET_UTIL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_SET_UTIL_H_ #include "eval/public/cel_value.h" namespace google { namespace api { namespace expr { namespace runtime { // Less than operator sufficient as a comparator in a set. This provides // a stable and consistent but not necessarily meaningful ordering. This should // not be used directly in the cel runtime (e.g. as an overload for _<_) as // it conflicts with some of the expected behaviors. // // Type is compared using the the enum ordering for CelValue::Type then // underlying values are compared: // // For lists, compares length first, then in-order elementwise compare. // // For maps, compares size first, then sorted key order elementwise compare // (i.e. ((k1, v1) < (k2, v2))). 
// // For other types, it defaults to the wrapped value's operator<. // Note that for For messages, errors, and unknown sets, this is a ptr // comparison. bool CelValueLessThan(CelValue lhs, CelValue rhs); bool CelValueEqual(CelValue lhs, CelValue rhs); bool CelValueGreaterThan(CelValue lhs, CelValue rhs); int CelValueCompare(CelValue lhs, CelValue rhs); // Convenience alias for using the CelValueLessThan function in sets providing // the stl interface. using CelValueLessThanComparator = decltype(&CelValueLessThan); using CelValueEqualComparator = decltype(&CelValueEqual); using CelValueGreaterThanComparator = decltype(&CelValueGreaterThan); } // namespace runtime } // namespace expr } // namespace api } // namespace google #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_SET_UTIL_H_ ================================================ FILE: eval/public/set_util_test.cc ================================================ #include "eval/public/set_util.h" #include #include #include #include #include #include "google/protobuf/empty.pb.h" #include "google/protobuf/struct.pb.h" #include "absl/status/status.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/unknown_set.h" #include "internal/testing.h" #include "google/protobuf/arena.h" namespace google { namespace api { namespace expr { namespace runtime { namespace { using google::protobuf::Arena; using protobuf::Empty; using protobuf::ListValue; using protobuf::Struct; constexpr char kExampleText[] = "abc"; constexpr char kExampleText2[] = "abd"; std::string* ExampleStr() { static std::string* example = new std::string(kExampleText); return example; } std::string* ExampleStr2() { static std::string* example = new std::string(kExampleText2); return example; } // Returns a vector that has an 
example for each type, ordered by the type // ordering in |CelValueLessThan|. Length 13 std::vector TypeExamples(Arena* arena) { Empty* empty = Arena::Create(arena); Struct* proto_map = Arena::Create(arena); ListValue* proto_list = Arena::Create(arena); UnknownSet* unknown_set = Arena::Create(arena); return {CelValue::CreateBool(false), CelValue::CreateInt64(0), CelValue::CreateUint64(0), CelValue::CreateDouble(0.0), CelValue::CreateStringView(kExampleText), CelValue::CreateBytes(ExampleStr()), CelProtoWrapper::CreateMessage(empty, arena), CelValue::CreateDuration(absl::ZeroDuration()), CelValue::CreateTimestamp(absl::Now()), CelProtoWrapper::CreateMessage(proto_list, arena), CelProtoWrapper::CreateMessage(proto_map, arena), CelValue::CreateUnknownSet(unknown_set), CreateErrorValue(arena, "test", absl::StatusCode::kInternal)}; } // Parameterized test for confirming type orderings are correct. Compares all // pairs of type examples to confirm the expected type priority. class TypeOrderingTest : public testing::TestWithParam> { public: TypeOrderingTest() { i_ = std::get<0>(GetParam()); j_ = std::get<1>(GetParam()); } protected: int i_; int j_; Arena arena_; }; TEST_P(TypeOrderingTest, TypeLessThan) { auto examples = TypeExamples(&arena_); CelValue lhs = examples[i_]; CelValue rhs = examples[j_]; // Strict less than. EXPECT_EQ(CelValueLessThan(lhs, rhs), i_ < j_); // Equality. 
EXPECT_EQ(CelValueEqual(lhs, rhs), i_ == j_); } std::string TypeOrderingTestName( testing::TestParamInfo> param) { int i = std::get<0>(param.param); int j = std::get<1>(param.param); return absl::StrCat(CelValue::TypeName(CelValue::Type(i)), "_", CelValue::TypeName(CelValue::Type(j))); } INSTANTIATE_TEST_SUITE_P(TypePairs, TypeOrderingTest, testing::Combine(testing::Range(0, 13), testing::Range(0, 13)), &TypeOrderingTestName); TEST(CelValueLessThanComparator, StdSetSupport) { Arena arena; auto examples = TypeExamples(&arena); std::set value_set(&CelValueLessThan); for (CelValue value : examples) { auto insert = value_set.insert(value); bool was_inserted = insert.second; EXPECT_TRUE(was_inserted) << absl::StrCat("Insertion failed ", CelValue::TypeName(value.type())); } for (CelValue value : examples) { auto insert = value_set.insert(value); bool was_inserted = insert.second; EXPECT_FALSE(was_inserted) << absl::StrCat( "Re-insertion succeeded ", CelValue::TypeName(value.type())); } } enum class ExpectedCmp { kEq, kLt, kGt }; struct PrimitiveCmpTestCase { CelValue lhs; CelValue rhs; ExpectedCmp expected; }; // Test for primitive types that just use operator< for the underlying value. 
class PrimitiveCmpTest : public testing::TestWithParam { public: PrimitiveCmpTest() { lhs_ = GetParam().lhs; rhs_ = GetParam().rhs; expected_ = GetParam().expected; } protected: CelValue lhs_; CelValue rhs_; ExpectedCmp expected_; }; TEST_P(PrimitiveCmpTest, Basic) { switch (expected_) { case ExpectedCmp::kLt: EXPECT_TRUE(CelValueLessThan(lhs_, rhs_)); break; case ExpectedCmp::kGt: EXPECT_TRUE(CelValueGreaterThan(lhs_, rhs_)); break; case ExpectedCmp::kEq: EXPECT_TRUE(CelValueEqual(lhs_, rhs_)); break; } } std::string PrimitiveCmpTestName( testing::TestParamInfo info) { absl::string_view cmp_name; switch (info.param.expected) { case ExpectedCmp::kEq: cmp_name = "Eq"; break; case ExpectedCmp::kLt: cmp_name = "Lt"; break; case ExpectedCmp::kGt: cmp_name = "Gt"; break; } return absl::StrCat(CelValue::TypeName(info.param.lhs.type()), "_", cmp_name); } INSTANTIATE_TEST_SUITE_P( Pairs, PrimitiveCmpTest, testing::ValuesIn(std::vector{ {CelValue::CreateStringView(kExampleText), CelValue::CreateStringView(kExampleText), ExpectedCmp::kEq}, {CelValue::CreateStringView(kExampleText), CelValue::CreateStringView(kExampleText2), ExpectedCmp::kLt}, {CelValue::CreateStringView(kExampleText2), CelValue::CreateStringView(kExampleText), ExpectedCmp::kGt}, {CelValue::CreateBytes(ExampleStr()), CelValue::CreateBytes(ExampleStr()), ExpectedCmp::kEq}, {CelValue::CreateBytes(ExampleStr()), CelValue::CreateBytes(ExampleStr2()), ExpectedCmp::kLt}, {CelValue::CreateBytes(ExampleStr2()), CelValue::CreateBytes(ExampleStr()), ExpectedCmp::kGt}, {CelValue::CreateBool(false), CelValue::CreateBool(false), ExpectedCmp::kEq}, {CelValue::CreateBool(false), CelValue::CreateBool(true), ExpectedCmp::kLt}, {CelValue::CreateBool(true), CelValue::CreateBool(false), ExpectedCmp::kGt}, {CelValue::CreateInt64(1), CelValue::CreateInt64(1), ExpectedCmp::kEq}, {CelValue::CreateInt64(1), CelValue::CreateInt64(2), ExpectedCmp::kLt}, {CelValue::CreateInt64(2), CelValue::CreateInt64(1), ExpectedCmp::kGt}, 
{CelValue::CreateUint64(1), CelValue::CreateUint64(1), ExpectedCmp::kEq}, {CelValue::CreateUint64(1), CelValue::CreateUint64(2), ExpectedCmp::kLt}, {CelValue::CreateUint64(2), CelValue::CreateUint64(1), ExpectedCmp::kGt}, {CelValue::CreateDuration(absl::Minutes(1)), CelValue::CreateDuration(absl::Minutes(1)), ExpectedCmp::kEq}, {CelValue::CreateDuration(absl::Minutes(1)), CelValue::CreateDuration(absl::Minutes(2)), ExpectedCmp::kLt}, {CelValue::CreateDuration(absl::Minutes(2)), CelValue::CreateDuration(absl::Minutes(1)), ExpectedCmp::kGt}, {CelValue::CreateTimestamp(absl::FromUnixSeconds(1)), CelValue::CreateTimestamp(absl::FromUnixSeconds(1)), ExpectedCmp::kEq}, {CelValue::CreateTimestamp(absl::FromUnixSeconds(1)), CelValue::CreateTimestamp(absl::FromUnixSeconds(2)), ExpectedCmp::kLt}, {CelValue::CreateTimestamp(absl::FromUnixSeconds(2)), CelValue::CreateTimestamp(absl::FromUnixSeconds(1)), ExpectedCmp::kGt}}), &PrimitiveCmpTestName); TEST(CelValueLessThan, PtrCmpMessage) { Arena arena; CelValue lhs = CelProtoWrapper::CreateMessage(Arena::Create(&arena), &arena); CelValue rhs = CelProtoWrapper::CreateMessage(Arena::Create(&arena), &arena); if (lhs.MessageOrDie() > rhs.MessageOrDie()) { std::swap(lhs, rhs); } EXPECT_TRUE(CelValueLessThan(lhs, rhs)); EXPECT_FALSE(CelValueLessThan(rhs, lhs)); EXPECT_FALSE(CelValueLessThan(lhs, lhs)); } TEST(CelValueLessThan, PtrCmpUnknownSet) { Arena arena; CelValue lhs = CelValue::CreateUnknownSet(Arena::Create(&arena)); CelValue rhs = CelValue::CreateUnknownSet(Arena::Create(&arena)); if (lhs.UnknownSetOrDie() > rhs.UnknownSetOrDie()) { std::swap(lhs, rhs); } EXPECT_TRUE(CelValueLessThan(lhs, rhs)); EXPECT_FALSE(CelValueLessThan(rhs, lhs)); EXPECT_FALSE(CelValueLessThan(lhs, lhs)); } TEST(CelValueLessThan, PtrCmpError) { Arena arena; CelValue lhs = CreateErrorValue(&arena, "test1", absl::StatusCode::kInternal); CelValue rhs = CreateErrorValue(&arena, "test2", absl::StatusCode::kInternal); if (lhs.ErrorOrDie() > rhs.ErrorOrDie()) { 
std::swap(lhs, rhs); } EXPECT_TRUE(CelValueLessThan(lhs, rhs)); EXPECT_FALSE(CelValueLessThan(rhs, lhs)); EXPECT_FALSE(CelValueLessThan(lhs, lhs)); } TEST(CelValueLessThan, CelListSameSize) { ContainerBackedListImpl cel_list_1(std::vector{ CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); ContainerBackedListImpl cel_list_2(std::vector{ CelValue::CreateInt64(1), CelValue::CreateInt64(3)}); EXPECT_TRUE(CelValueLessThan(CelValue::CreateList(&cel_list_1), CelValue::CreateList(&cel_list_2))); } TEST(CelValueLessThan, CelListDifferentSizes) { ContainerBackedListImpl cel_list_1( std::vector{CelValue::CreateInt64(2)}); ContainerBackedListImpl cel_list_2(std::vector{ CelValue::CreateInt64(1), CelValue::CreateInt64(3)}); EXPECT_TRUE(CelValueLessThan(CelValue::CreateList(&cel_list_1), CelValue::CreateList(&cel_list_2))); } TEST(CelValueLessThan, CelListEqual) { ContainerBackedListImpl cel_list_1(std::vector{ CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); ContainerBackedListImpl cel_list_2(std::vector{ CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); EXPECT_FALSE(CelValueLessThan(CelValue::CreateList(&cel_list_1), CelValue::CreateList(&cel_list_2))); EXPECT_TRUE(CelValueEqual(CelValue::CreateList(&cel_list_2), CelValue::CreateList(&cel_list_1))); } TEST(CelValueLessThan, CelListSupportProtoListCompatible) { Arena arena; ListValue list_value; list_value.add_values()->set_bool_value(true); list_value.add_values()->set_number_value(1.0); list_value.add_values()->set_string_value("abc"); CelValue proto_list = CelProtoWrapper::CreateMessage(&list_value, &arena); ASSERT_TRUE(proto_list.IsList()); std::vector list_values{CelValue::CreateBool(true), CelValue::CreateDouble(1.0), CelValue::CreateStringView("abd")}; ContainerBackedListImpl list_backing(list_values); CelValue cel_list = CelValue::CreateList(&list_backing); EXPECT_TRUE(CelValueLessThan(proto_list, cel_list)); } TEST(CelValueLessThan, CelMapSameSize) { std::vector> values{ {CelValue::CreateInt64(1), 
CelValue::CreateInt64(2)}, {CelValue::CreateInt64(3), CelValue::CreateInt64(6)}}; auto cel_map_backing_1 = CreateContainerBackedMap(absl::MakeSpan(values)).value(); std::vector> values2{ {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, {CelValue::CreateInt64(4), CelValue::CreateInt64(6)}}; auto cel_map_backing_2 = CreateContainerBackedMap(absl::MakeSpan(values2)).value(); std::vector> values3{ {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, {CelValue::CreateInt64(3), CelValue::CreateInt64(8)}}; auto cel_map_backing_3 = CreateContainerBackedMap(absl::MakeSpan(values3)).value(); CelValue map1 = CelValue::CreateMap(cel_map_backing_1.get()); CelValue map2 = CelValue::CreateMap(cel_map_backing_2.get()); CelValue map3 = CelValue::CreateMap(cel_map_backing_3.get()); EXPECT_TRUE(CelValueLessThan(map1, map2)); EXPECT_TRUE(CelValueLessThan(map1, map3)); EXPECT_TRUE(CelValueLessThan(map3, map2)); } TEST(CelValueLessThan, CelMapDifferentSizes) { std::vector> values{ {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}}; auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)).value(); std::vector> values2{ {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, {CelValue::CreateInt64(3), CelValue::CreateInt64(6)}}; auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values2)).value(); EXPECT_TRUE(CelValueLessThan(CelValue::CreateMap(cel_map_1.get()), CelValue::CreateMap(cel_map_2.get()))); } TEST(CelValueLessThan, CelMapEqual) { std::vector> values{ {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, {CelValue::CreateInt64(3), CelValue::CreateInt64(6)}}; auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)).value(); std::vector> values2{ {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, {CelValue::CreateInt64(3), 
CelValue::CreateInt64(6)}}; auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values2)).value(); EXPECT_FALSE(CelValueLessThan(CelValue::CreateMap(cel_map_1.get()), CelValue::CreateMap(cel_map_2.get()))); EXPECT_TRUE(CelValueEqual(CelValue::CreateMap(cel_map_2.get()), CelValue::CreateMap(cel_map_1.get()))); } TEST(CelValueLessThan, CelMapSupportProtoMapCompatible) { Arena arena; const std::vector kFields = {"field1", "field2", "field3"}; Struct value_struct; auto& value1 = (*value_struct.mutable_fields())[kFields[0]]; value1.set_bool_value(true); auto& value2 = (*value_struct.mutable_fields())[kFields[1]]; value2.set_number_value(1.0); auto& value3 = (*value_struct.mutable_fields())[kFields[2]]; value3.set_string_value("test"); CelValue proto_struct = CelProtoWrapper::CreateMessage(&value_struct, &arena); ASSERT_TRUE(proto_struct.IsMap()); std::vector> values{ {CelValue::CreateStringView(kFields[2]), CelValue::CreateStringView("test")}, {CelValue::CreateStringView(kFields[1]), CelValue::CreateDouble(1.0)}, {CelValue::CreateStringView(kFields[0]), CelValue::CreateBool(true)}}; auto backing_map = CreateContainerBackedMap(absl::MakeSpan(values)).value(); CelValue cel_map = CelValue::CreateMap(backing_map.get()); EXPECT_TRUE(!CelValueLessThan(cel_map, proto_struct) && !CelValueGreaterThan(cel_map, proto_struct)); } TEST(CelValueLessThan, NestedMap) { Arena arena; ListValue list_value; list_value.add_values()->set_bool_value(true); list_value.add_values()->set_number_value(1.0); list_value.add_values()->set_string_value("test"); std::vector list_values{CelValue::CreateBool(true), CelValue::CreateDouble(1.0), CelValue::CreateStringView("test")}; ContainerBackedListImpl list_backing(list_values); CelValue cel_list = CelValue::CreateList(&list_backing); Struct value_struct; *(value_struct.mutable_fields()->operator[]("field").mutable_list_value()) = list_value; std::vector> values{ {CelValue::CreateStringView("field"), cel_list}}; auto backing_map = 
CreateContainerBackedMap(absl::MakeSpan(values)).value(); CelValue cel_map = CelValue::CreateMap(backing_map.get()); CelValue proto_map = CelProtoWrapper::CreateMessage(&value_struct, &arena); EXPECT_TRUE(!CelValueLessThan(cel_map, proto_map) && !CelValueLessThan(proto_map, cel_map)); } } // namespace } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: eval/public/source_position.cc ================================================ // Copyright 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/public/source_position.h"

#include <cstdint>
#include <utility>

namespace google {
namespace api {
namespace expr {
namespace runtime {

using cel::expr::SourceInfo;

namespace {

// Returns {line, line_offset} for `position`.
//
// `line` is the number of entries in source_info->line_offsets() that are
// <= `position`. Scanning stops at the first entry greater than `position`,
// which assumes the offsets are increasing. `line_offset` is the last such
// entry. If source_info is null or no entry qualifies, the result is line 1
// at offset 0.
std::pair<int, int32_t> GetLineAndLineOffset(const SourceInfo* source_info,
                                             int32_t position) {
  int line = 0;
  int32_t line_offset = 0;
  if (source_info != nullptr) {
    for (const auto& curr_line_offset : source_info->line_offsets()) {
      if (curr_line_offset > position) {
        break;
      }
      line_offset = curr_line_offset;
      line++;
    }
  }
  if (line == 0) {
    line++;
  }
  return {line, line_offset};
}

}  // namespace

int32_t SourcePosition::line() const {
  return GetLineAndLineOffset(source_info_, character_offset()).first;
}

int32_t SourcePosition::column() const {
  int32_t position = character_offset();
  std::pair<int, int32_t> line_and_offset =
      GetLineAndLineOffset(source_info_, position);
  return 1 + (position - line_and_offset.second);
}

int32_t SourcePosition::character_offset() const {
  if (source_info_ == nullptr) {
    return 0;
  }
  auto position_it = source_info_->positions().find(expr_id_);
  return position_it != source_info_->positions().end() ? position_it->second
                                                        : 0;
}

}  // namespace runtime
}  // namespace expr
}  // namespace api
}  // namespace google

================================================
FILE: eval/public/source_position.h
================================================
/*
 * Copyright 2018 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_SOURCE_POSITION_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_SOURCE_POSITION_H_

#include <cstdint>

#include "cel/expr/syntax.pb.h"

namespace google {
namespace api {
namespace expr {
namespace runtime {

// Class representing the source position as well as line and column data for
// a given expression id.
class SourcePosition {
 public:
  // Constructor for a SourcePosition value. The source_info may be nullptr,
  // in which case character_offset() returns 0 and line() and column() both
  // return 1.
  SourcePosition(const int64_t expr_id,
                 const cel::expr::SourceInfo* source_info)
      : expr_id_(expr_id), source_info_(source_info) {}

  // Non-copyable
  SourcePosition(const SourcePosition& other) = delete;
  SourcePosition& operator=(const SourcePosition& other) = delete;

  virtual ~SourcePosition() = default;

  // Return the 1-based source line number for the expression.
  int32_t line() const;

  // Return the 1-based column offset within the source line for the
  // expression.
  int32_t column() const;

  // Return the 0-based character offset of the expression within source.
  // Returns 0 when the expression id has no recorded position.
  int32_t character_offset() const;

 private:
  // The expression identifier.
  const int64_t expr_id_;
  // The source information reference generated during expression parsing.
  // Not owned; must outlive this SourcePosition.
  const cel::expr::SourceInfo* source_info_;
};

}  // namespace runtime
}  // namespace expr
}  // namespace api
}  // namespace google

#endif  // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_SOURCE_POSITION_H_

================================================
FILE: eval/public/source_position_test.cc
================================================
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/public/source_position.h" #include "cel/expr/syntax.pb.h" #include "internal/testing.h" namespace google { namespace api { namespace expr { namespace runtime { namespace { using ::testing::Eq; using cel::expr::SourceInfo; class SourcePositionTest : public testing::Test { protected: void SetUp() override { // Simulate the expression positions : '\n\na\n&& b\n\n|| c' // // Within the ExprChecker, the line offset is the first character of the // line rather than the newline character. // // The tests outputs are affected by leading newlines, but not trailing // newlines, and the ExprChecker will actually always generate a trailing // newline entry for EOF; however, this offset is not included in the test // since there may be other parsers which generate newline information // slightly differently. 
source_info_.add_line_offsets(0); source_info_.add_line_offsets(1); source_info_.add_line_offsets(2); (*source_info_.mutable_positions())[1] = 2; source_info_.add_line_offsets(4); (*source_info_.mutable_positions())[2] = 4; (*source_info_.mutable_positions())[3] = 7; source_info_.add_line_offsets(9); source_info_.add_line_offsets(10); (*source_info_.mutable_positions())[4] = 10; (*source_info_.mutable_positions())[5] = 13; } SourceInfo source_info_; }; TEST_F(SourcePositionTest, TestNullSourceInfo) { SourcePosition position(3, nullptr); EXPECT_THAT(position.character_offset(), Eq(0)); EXPECT_THAT(position.line(), Eq(1)); EXPECT_THAT(position.column(), Eq(1)); } TEST_F(SourcePositionTest, TestNoNewlines) { source_info_.clear_line_offsets(); SourcePosition position(3, &source_info_); EXPECT_THAT(position.character_offset(), Eq(7)); EXPECT_THAT(position.line(), Eq(1)); EXPECT_THAT(position.column(), Eq(8)); } TEST_F(SourcePositionTest, TestPosition) { SourcePosition position(3, &source_info_); EXPECT_THAT(position.character_offset(), Eq(7)); } TEST_F(SourcePositionTest, TestLine) { SourcePosition position1(1, &source_info_); EXPECT_THAT(position1.line(), Eq(3)); SourcePosition position2(2, &source_info_); EXPECT_THAT(position2.line(), Eq(4)); SourcePosition position3(3, &source_info_); EXPECT_THAT(position3.line(), Eq(4)); SourcePosition position4(5, &source_info_); EXPECT_THAT(position4.line(), Eq(6)); } TEST_F(SourcePositionTest, TestColumn) { SourcePosition position1(1, &source_info_); EXPECT_THAT(position1.column(), Eq(1)); SourcePosition position2(2, &source_info_); EXPECT_THAT(position2.column(), Eq(1)); SourcePosition position3(3, &source_info_); EXPECT_THAT(position3.column(), Eq(4)); SourcePosition position4(5, &source_info_); EXPECT_THAT(position4.column(), Eq(4)); } } // namespace } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: 
eval/public/string_extension_func_registrar.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/public/string_extension_func_registrar.h" #include "absl/status/status.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "extensions/strings.h" namespace google::api::expr::runtime { absl::Status RegisterStringExtensionFunctions( CelFunctionRegistry* registry, const InterpreterOptions& options) { return cel::extensions::RegisterStringsFunctions(registry, options); } } // namespace google::api::expr::runtime ================================================ FILE: eval/public/string_extension_func_registrar.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRING_EXTENSION_FUNC_REGISTRAR_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRING_EXTENSION_FUNC_REGISTRAR_H_

#include "absl/status/status.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"

namespace google::api::expr::runtime {

// Register string related widely used extension functions.
//
// Adds the string extension overloads (for example `split`, `join` and
// `lowerAscii`) to `registry`. This legacy entry point forwards `registry` and
// `options` to cel::extensions::RegisterStringsFunctions and returns that
// call's status unchanged.
absl::Status RegisterStringExtensionFunctions(
    CelFunctionRegistry* registry,
    const InterpreterOptions& options = InterpreterOptions());

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRING_EXTENSION_FUNC_REGISTRAR_H_

================================================
FILE: eval/public/string_extension_func_registrar_test.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "eval/public/string_extension_func_registrar.h"

#include <cstdint>
#include <string>
#include <vector>

#include "cel/expr/checked.pb.h"
#include "absl/types/span.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_value.h"
#include "eval/public/containers/container_backed_list_impl.h"
#include "internal/testing.h"
#include "google/protobuf/arena.h"

namespace google::api::expr::runtime {
namespace {

using google::protobuf::Arena;

// Exercises the string extension overloads by looking them up in a registry
// populated with the builtin and string extension functions and invoking them
// directly.
class StringExtensionTest : public ::testing::Test {
 protected:
  void SetUp() override {
    ASSERT_OK(RegisterBuiltinFunctions(&registry_));
    ASSERT_OK(RegisterStringExtensionFunctions(&registry_));
  }

  // Evaluates the receiver-style `split(string, string)` overload and stores
  // the outcome in `result`.
  void PerformSplitStringTest(Arena* arena, std::string* value,
                              std::string* delimiter, CelValue* result) {
    auto function = registry_.FindOverloads(
        "split", true, {CelValue::Type::kString, CelValue::Type::kString});
    ASSERT_EQ(function.size(), 1);
    auto func = function[0];
    std::vector<CelValue> args = {CelValue::CreateString(value),
                                  CelValue::CreateString(delimiter)};
    absl::Span<CelValue> arg_span(&args[0], args.size());
    auto status = func->Evaluate(arg_span, result, arena);
    ASSERT_OK(status);
  }

  // Evaluates the receiver-style `split(string, string, int)` overload and
  // stores the outcome in `result`.
  void PerformSplitStringWithLimitTest(Arena* arena, std::string* value,
                                       std::string* delimiter, int64_t limit,
                                       CelValue* result) {
    auto function = registry_.FindOverloads(
        "split", true,
        {CelValue::Type::kString, CelValue::Type::kString,
         CelValue::Type::kInt64});
    ASSERT_EQ(function.size(), 1);
    auto func = function[0];
    std::vector<CelValue> args = {CelValue::CreateString(value),
                                  CelValue::CreateString(delimiter),
                                  CelValue::CreateInt64(limit)};
    absl::Span<CelValue> arg_span(&args[0], args.size());
    auto status = func->Evaluate(arg_span, result, arena);
    ASSERT_OK(status);
  }

  // Evaluates the receiver-style `join(list)` overload on a list built from
  // `values` and stores the outcome in `result`.
  void PerformJoinStringTest(Arena* arena,
                             const std::vector<std::string>& values,
                             CelValue* result) {
    auto function =
        registry_.FindOverloads("join", true, {CelValue::Type::kList});
    ASSERT_EQ(function.size(), 1);
    auto func = function[0];

    std::vector<CelValue> cel_list;
    cel_list.reserve(values.size());
    for (const std::string& value : values) {
      cel_list.push_back(
          CelValue::CreateString(Arena::Create<std::string>(arena, value)));
    }

    std::vector<CelValue> args = {CelValue::CreateList(
        Arena::Create<ContainerBackedListImpl>(arena, cel_list))};
    absl::Span<CelValue> arg_span(&args[0], args.size());
    auto status = func->Evaluate(arg_span, result, arena);
    ASSERT_OK(status);
  }

  // Evaluates the receiver-style `join(list, string)` overload on a list built
  // from `values` and stores the outcome in `result`.
  void PerformJoinStringWithSeparatorTest(
      Arena* arena, const std::vector<std::string>& values,
      std::string* separator, CelValue* result) {
    auto function = registry_.FindOverloads(
        "join", true, {CelValue::Type::kList, CelValue::Type::kString});
    ASSERT_EQ(function.size(), 1);
    auto func = function[0];

    std::vector<CelValue> cel_list;
    cel_list.reserve(values.size());
    for (const std::string& value : values) {
      cel_list.push_back(
          CelValue::CreateString(Arena::Create<std::string>(arena, value)));
    }

    std::vector<CelValue> args = {
        CelValue::CreateList(
            Arena::Create<ContainerBackedListImpl>(arena, cel_list)),
        CelValue::CreateString(separator)};
    absl::Span<CelValue> arg_span(&args[0], args.size());
    auto status = func->Evaluate(arg_span, result, arena);
    ASSERT_OK(status);
  }

  // Evaluates the receiver-style `lowerAscii(string)` overload and stores the
  // outcome in `result`.
  void PerformLowerAsciiTest(Arena* arena, std::string* value,
                             CelValue* result) {
    auto function =
        registry_.FindOverloads("lowerAscii", true, {CelValue::Type::kString});
    ASSERT_EQ(function.size(), 1);
    auto func = function[0];
    std::vector<CelValue> args = {CelValue::CreateString(value)};
    absl::Span<CelValue> arg_span(&args[0], args.size());
    auto status = func->Evaluate(arg_span, result, arena);
    ASSERT_OK(status);
  }

  // Checks that `result` is a list holding exactly the strings in `expected`,
  // in order. The size is checked with ASSERT_EQ before any element is read,
  // so a list that is too short fails cleanly instead of being indexed past
  // its end.
  void ExpectStringList(Arena* arena, const CelValue& result,
                        const std::vector<std::string>& expected) {
    ASSERT_EQ(result.type(), CelValue::Type::kList);
    const CelList* list = result.ListOrDie();
    ASSERT_EQ(list->size(), static_cast<int>(expected.size()));
    for (int i = 0; i < list->size(); ++i) {
      EXPECT_EQ(list->Get(arena, i).StringOrDie().value(), expected[i]);
    }
  }

  // Function registry
  CelFunctionRegistry registry_;
};

TEST_F(StringExtensionTest, TestStringSplit) {
  Arena arena;
  CelValue result;
  std::string value = "This!!Is!!Test";
  std::string delimiter = "!!";
  std::vector<std::string> expected = {"This", "Is", "Test"};

  ASSERT_NO_FATAL_FAILURE(
      PerformSplitStringTest(&arena, &value, &delimiter, &result));
  ExpectStringList(&arena, result, expected);
}

TEST_F(StringExtensionTest, TestStringSplitEmptyDelimiter) {
  Arena arena;
  CelValue result;
  std::string value = "TEST";
  std::string delimiter = "";
  std::vector<std::string> expected = {"T", "E", "S", "T"};

  ASSERT_NO_FATAL_FAILURE(
      PerformSplitStringTest(&arena, &value, &delimiter, &result));
  ExpectStringList(&arena, result, expected);
}

TEST_F(StringExtensionTest, TestStringSplitWithLimitTwo) {
  Arena arena;
  CelValue result;
  int64_t limit = 2;
  std::string value = "This!!Is!!Test";
  std::string delimiter = "!!";
  std::vector<std::string> expected = {"This", "Is!!Test"};

  ASSERT_NO_FATAL_FAILURE(PerformSplitStringWithLimitTest(
      &arena, &value, &delimiter, limit, &result));
  ExpectStringList(&arena, result, expected);
}

TEST_F(StringExtensionTest, TestStringSplitWithLimitOne) {
  Arena arena;
  CelValue result;
  int64_t limit = 1;
  std::string value = "This!!Is!!Test";
  std::string delimiter = "!!";

  ASSERT_NO_FATAL_FAILURE(PerformSplitStringWithLimitTest(
      &arena, &value, &delimiter, limit, &result));
  // A limit of one yields the whole input as the only element.
  ExpectStringList(&arena, result, {value});
}

TEST_F(StringExtensionTest, TestStringSplitWithLimitZero) {
  Arena arena;
  CelValue result;
  int64_t limit = 0;
  std::string value = "This!!Is!!Test";
  std::string delimiter = "!!";

  ASSERT_NO_FATAL_FAILURE(PerformSplitStringWithLimitTest(
      &arena, &value, &delimiter, limit, &result));
  ExpectStringList(&arena, result, {});
}

TEST_F(StringExtensionTest, TestStringSplitWithLimitNegative) {
  Arena arena;
  CelValue result;
  int64_t limit = -1;
  std::string value = "This!!Is!!Test";
  std::string delimiter = "!!";
  std::vector<std::string> expected = {"This", "Is", "Test"};

  ASSERT_NO_FATAL_FAILURE(PerformSplitStringWithLimitTest(
      &arena, &value, &delimiter, limit, &result));
  ExpectStringList(&arena, result, expected);
}

TEST_F(StringExtensionTest, TestStringSplitWithLimitAsMaxPossibleSplits) {
  Arena arena;
  CelValue result;
  int64_t limit = 3;
  std::string value = "This!!Is!!Test";
  std::string delimiter = "!!";
  std::vector<std::string> expected = {"This", "Is", "Test"};

  ASSERT_NO_FATAL_FAILURE(PerformSplitStringWithLimitTest(
      &arena, &value, &delimiter, limit, &result));
  ExpectStringList(&arena, result, expected);
}

TEST_F(StringExtensionTest,
       TestStringSplitWithLimitGreaterThanMaxPossibleSplits) {
  Arena arena;
  CelValue result;
  int64_t limit = 4;
  std::string value = "This!!Is!!Test";
  std::string delimiter = "!!";
  std::vector<std::string> expected = {"This", "Is", "Test"};

  ASSERT_NO_FATAL_FAILURE(PerformSplitStringWithLimitTest(
      &arena, &value, &delimiter, limit, &result));
  ExpectStringList(&arena, result, expected);
}

TEST_F(StringExtensionTest, TestStringJoin) {
  Arena arena;
  CelValue result;
  std::vector<std::string> value = {"This", "Is", "Test"};
  std::string expected = "ThisIsTest";

  ASSERT_NO_FATAL_FAILURE(PerformJoinStringTest(&arena, value, &result));
  ASSERT_EQ(result.type(), CelValue::Type::kString);
  EXPECT_EQ(result.StringOrDie().value(), expected);
}

TEST_F(StringExtensionTest, TestStringJoinEmptyInput) {
  Arena arena;
  CelValue result;
  std::vector<std::string> value = {};
  std::string expected = "";

  ASSERT_NO_FATAL_FAILURE(PerformJoinStringTest(&arena, value, &result));
  ASSERT_EQ(result.type(), CelValue::Type::kString);
  EXPECT_EQ(result.StringOrDie().value(), expected);
}

TEST_F(StringExtensionTest, TestStringJoinWithSeparator) {
  Arena arena;
  CelValue result;
  std::vector<std::string> value = {"This", "Is", "Test"};
  std::string separator = "-";
  std::string expected = "This-Is-Test";

  ASSERT_NO_FATAL_FAILURE(
      PerformJoinStringWithSeparatorTest(&arena, value, &separator, &result));
  ASSERT_EQ(result.type(), CelValue::Type::kString);
  EXPECT_EQ(result.StringOrDie().value(), expected);
}

TEST_F(StringExtensionTest, TestStringJoinWithMultiCharSeparator) {
  Arena arena;
  CelValue result;
  std::vector<std::string> value = {"This", "Is", "Test"};
  std::string separator = "--";
  std::string expected = "This--Is--Test";

  ASSERT_NO_FATAL_FAILURE(
      PerformJoinStringWithSeparatorTest(&arena, value, &separator, &result));
  ASSERT_EQ(result.type(), CelValue::Type::kString);
  EXPECT_EQ(result.StringOrDie().value(), expected);
}

TEST_F(StringExtensionTest, TestStringJoinWithEmptySeparator) {
  Arena arena;
  CelValue result;
  std::vector<std::string> value = {"This", "Is", "Test"};
  std::string separator = "";
  std::string expected = "ThisIsTest";

  ASSERT_NO_FATAL_FAILURE(
      PerformJoinStringWithSeparatorTest(&arena, value, &separator, &result));
  ASSERT_EQ(result.type(), CelValue::Type::kString);
  EXPECT_EQ(result.StringOrDie().value(), expected);
}

TEST_F(StringExtensionTest, TestStringJoinWithSeparatorEmptyInput) {
  Arena arena;
  CelValue result;
  std::vector<std::string> value = {};
  std::string separator = "-";
  std::string expected = "";

  ASSERT_NO_FATAL_FAILURE(
      PerformJoinStringWithSeparatorTest(&arena, value, &separator, &result));
  ASSERT_EQ(result.type(), CelValue::Type::kString);
  EXPECT_EQ(result.StringOrDie().value(), expected);
}

TEST_F(StringExtensionTest, TestLowerAscii) {
  Arena arena;
  CelValue result;
  std::string value = "ThisIs@Test!-5";
  std::string expected = "thisis@test!-5";

  ASSERT_NO_FATAL_FAILURE(PerformLowerAsciiTest(&arena, &value, &result));
  ASSERT_EQ(result.type(), CelValue::Type::kString);
  EXPECT_EQ(result.StringOrDie().value(), expected);
}

TEST_F(StringExtensionTest, TestLowerAsciiWithEmptyInput) {
  Arena arena;
  CelValue result;
  std::string value = "";
  std::string expected = "";

  ASSERT_NO_FATAL_FAILURE(PerformLowerAsciiTest(&arena, &value, &result));
  ASSERT_EQ(result.type(), CelValue::Type::kString);
  EXPECT_EQ(result.StringOrDie().value(), expected);
}

TEST_F(StringExtensionTest, TestLowerAsciiWithNonAsciiCharacter) {
  Arena arena;
  CelValue result;
  std::string value = "TacoCÆt";
  std::string expected = "tacocÆt";

  ASSERT_NO_FATAL_FAILURE(PerformLowerAsciiTest(&arena, &value, &result));
  ASSERT_EQ(result.type(), CelValue::Type::kString);
  EXPECT_EQ(result.StringOrDie().value(), expected);
}

}  // namespace
}  // namespace google::api::expr::runtime

================================================
FILE: eval/public/structs/BUILD
================================================
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

package(default_visibility = ["//visibility:public"])

licenses(["notice"])

cc_library(
    name = "cel_proto_wrapper",
    srcs = [
        "cel_proto_wrapper.cc",
    ],
    hdrs = [
        "cel_proto_wrapper.h",
    ],
    deps = [
        ":cel_proto_wrap_util",
        ":proto_message_type_adapter",
        "//eval/public:cel_value",
        "//eval/public:message_wrapper",
        "//internal:proto_time_encoding",
        "@com_google_absl//absl/types:optional",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:timestamp_cc_proto",
    ],
)

cc_library(
    name = "protobuf_value_factory",
    hdrs = [
        "protobuf_value_factory.h",
    ],
    deps = [
        "//eval/public:cel_value",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "cel_proto_wrap_util",
    srcs = [
        "cel_proto_wrap_util.cc",
    ],
    hdrs = [
        "cel_proto_wrap_util.h",
    ],
    deps = [
        ":protobuf_value_factory",
        "//eval/public:cel_value",
        "//internal:overflow",
        "//internal:proto_time_encoding",
        "//internal:status_macros",
        "//internal:time",
        "//internal:well_known_types",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:variant",
        "@com_google_protobuf//:any_cc_proto",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
        "@com_google_protobuf//:timestamp_cc_proto",
        "@com_google_protobuf//:wrappers_cc_proto",
    ],
)

cc_test(
    name = "cel_proto_wrap_util_test",
    size = "small",
    srcs = [
        "cel_proto_wrap_util_test.cc",
    ],
    deps = [
        ":cel_proto_wrap_util",
        ":protobuf_value_factory",
        ":trivial_legacy_type_info",
        "//eval/public:cel_value",
        "//eval/public:message_wrapper",
        "//eval/public/containers:container_backed_list_impl",
        "//eval/public/containers:container_backed_map_impl",
        "//eval/testutil:test_message_cc_proto",
        "//internal:proto_time_encoding",
        "//internal:status_macros",
        "//internal:testing",
        "//testutil:util",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
        "@com_google_protobuf//:any_cc_proto",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:empty_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
        "@com_google_protobuf//:wrappers_cc_proto",
    ],
)

cc_library(
    name = "field_access_impl",
    srcs = [
        "field_access_impl.cc",
    ],
    hdrs = [
        "field_access_impl.h",
    ],
    deps = [
        ":cel_proto_wrap_util",
        ":protobuf_value_factory",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "//internal:casts",
        "//internal:overflow",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_protobuf//:any_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
        "@com_google_protobuf//:wrappers_cc_proto",
    ],
)

cc_test(
    name = "field_access_impl_test",
    srcs = ["field_access_impl_test.cc"],
    deps = [
        ":cel_proto_wrapper",
        ":field_access_impl",
        "//eval/public:cel_value",
        "//eval/public/testing:matchers",
        "//eval/testutil:test_message_cc_proto",
        "//internal:testing",
        "//internal:time",
        "//testutil:util",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
        "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

# Adds the CEL standard well-known message types to a DescriptorPool, or
# exports them as a FileDescriptorSet.
cc_library(
    name = "cel_proto_descriptor_pool_builder",
    srcs = ["cel_proto_descriptor_pool_builder.cc"],
    hdrs = ["cel_proto_descriptor_pool_builder.h"],
    deps = [
        "//internal:proto_util",
        "//internal:status_macros",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/status",
        # The .cc calls absl::StrFormat directly; declare it rather than
        # relying on a transitive dependency.
        "@com_google_absl//absl/strings:str_format",
        "@com_google_protobuf//:any_cc_proto",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:empty_cc_proto",
        "@com_google_protobuf//:field_mask_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
        "@com_google_protobuf//:timestamp_cc_proto",
        "@com_google_protobuf//:wrappers_cc_proto",
    ],
)

cc_test(
    name = "cel_proto_descriptor_pool_builder_test",
    srcs = ["cel_proto_descriptor_pool_builder_test.cc"],
    deps = [
        ":cel_proto_descriptor_pool_builder",
        "//eval/testutil:test_message_cc_proto",
        "//internal:testing",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_protobuf//:any_cc_proto",
    ],
)

cc_test(
    name = "cel_proto_wrapper_test",
    size = "small",
    srcs = [
        "cel_proto_wrapper_test.cc",
    ],
    deps = [
        ":cel_proto_wrapper",
        "//eval/public:cel_value",
        "//eval/public/containers:container_backed_list_impl",
        "//eval/public/containers:container_backed_map_impl",
        "//eval/testutil:test_message_cc_proto",
        "//internal:proto_time_encoding",
        "//internal:status_macros",
        "//internal:testing",
        "//testutil:util",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
        "@com_google_protobuf//:any_cc_proto",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:empty_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
        "@com_google_protobuf//:wrappers_cc_proto",
    ],
)

cc_library(
    name = "legacy_type_provider",
    srcs = ["legacy_type_provider.cc"],
    hdrs = ["legacy_type_provider.h"],
    deps = [
        ":legacy_type_adapter",
        ":legacy_type_info_apis",
        "//common:legacy_value",
        "//common:memory",
        "//common:type",
        "//common:value",
        "//eval/public:message_wrapper",
        "//extensions/protobuf:memory_manager",
        "//internal:status_macros",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:optional",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "legacy_type_adapter",
    hdrs = ["legacy_type_adapter.h"],
    deps = [
        "//base:attributes",
        "//common:memory",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

cc_test(
    name = "legacy_type_adapter_test",
    srcs = ["legacy_type_adapter_test.cc"],
    deps = [
        ":legacy_type_adapter",
        ":trivial_legacy_type_info",
        "//eval/public:cel_value",
        "//eval/public/testing:matchers",
        "//eval/testutil:test_message_cc_proto",
        "//extensions/protobuf:memory_manager",
        "//internal:status_macros",
        "//internal:testing",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "proto_message_type_adapter",
    srcs = ["proto_message_type_adapter.cc"],
    hdrs = ["proto_message_type_adapter.h"],
    deps = [
        ":cel_proto_wrap_util",
        ":field_access_impl",
        ":legacy_type_adapter",
        ":legacy_type_info_apis",
        "//base:attributes",
        "//common:memory",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "//eval/public:message_wrapper",
        "//eval/public/containers:internal_field_backed_list_impl",
        "//eval/public/containers:internal_field_backed_map_impl",
        "//extensions/protobuf:memory_manager",
        "//extensions/protobuf/internal:qualify",
        "//internal:casts",
        "//internal:status_macros",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_protobuf//:differencer",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "proto_message_type_adapter_test",
    srcs = ["proto_message_type_adapter_test.cc"],
    deps = [
        ":legacy_type_adapter",
        ":legacy_type_info_apis",
        ":proto_message_type_adapter",
        "//base:attributes",
        "//common:value",
        "//eval/public:cel_value",
        "//eval/public:message_wrapper",
        "//eval/public/containers:container_backed_list_impl",
        "//eval/public/containers:container_backed_map_impl",
        "//eval/public/testing:matchers",
        "//eval/testutil:test_message_cc_proto",
        "//extensions/protobuf:memory_manager",
        "//internal:proto_matchers",
        "//internal:testing",
        "//runtime:runtime_options",
        "@com_google_absl//absl/status",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:wrappers_cc_proto",
    ],
)

cc_library(
    name = "protobuf_descriptor_type_provider",
    srcs = ["protobuf_descriptor_type_provider.cc"],
    hdrs = ["protobuf_descriptor_type_provider.h"],
    deps = [
        ":legacy_type_adapter",
        ":legacy_type_info_apis",
        ":legacy_type_provider",
        ":proto_message_type_adapter",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/types:optional",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "protobuf_descriptor_type_provider_test",
    srcs = ["protobuf_descriptor_type_provider_test.cc"],
    deps = [
        ":legacy_type_info_apis",
        ":protobuf_descriptor_type_provider",
        "//eval/public:cel_value",
        "//eval/public/testing:matchers",
        "//extensions/protobuf:memory_manager",
        "//internal:testing",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:wrappers_cc_proto",
    ],
)

cc_library(
    name = "legacy_type_info_apis",
    hdrs = ["legacy_type_info_apis.h"],
    deps = [
        "//eval/public:message_wrapper",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "trivial_legacy_type_info",
    testonly = True,
    hdrs = ["trivial_legacy_type_info.h"],
    deps = [
        ":legacy_type_info_apis",
        "//eval/public:message_wrapper",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/strings:string_view",
    ],
)

cc_test(
    name = "trivial_legacy_type_info_test",
    srcs = ["trivial_legacy_type_info_test.cc"],
    deps = [
        ":trivial_legacy_type_info",
        "//eval/public:message_wrapper",
        "//internal:testing",
    ],
)

cc_test(
    name = "legacy_type_provider_test",
    srcs = ["legacy_type_provider_test.cc"],
    deps = [
        ":legacy_type_info_apis",
        ":legacy_type_provider",
        "//internal:testing",
        "@com_google_absl//absl/strings:string_view",
    ],
)

cc_test(
    name = "dynamic_descriptor_pool_end_to_end_test",
    srcs = ["dynamic_descriptor_pool_end_to_end_test.cc"],
    deps = [
        ":cel_proto_descriptor_pool_builder",
        ":cel_proto_wrapper",
        "//eval/public:activation",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_options",
        "//eval/public/testing:matchers",
        "//internal:testing",
        "//parser",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto",
        "@com_google_protobuf//:differencer",
        "@com_google_protobuf//:protobuf",
    ],
)

================================================
FILE: eval/public/structs/cel_proto_descriptor_pool_builder.cc
================================================
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "eval/public/structs/cel_proto_descriptor_pool_builder.h"

#include
#include "google/protobuf/any.pb.h"
#include "google/protobuf/duration.pb.h"
#include "google/protobuf/empty.pb.h"
#include "google/protobuf/field_mask.pb.h"
#include "google/protobuf/struct.pb.h"
#include "google/protobuf/timestamp.pb.h"
#include "google/protobuf/wrappers.pb.h"
#include "absl/container/flat_hash_map.h"
#include "internal/proto_util.h"
#include "internal/status_macros.h"

namespace google::api::expr::runtime {
namespace {

// Makes the generated message type `MessageType` available in
// `descriptor_pool`.
//
// If the pool already contains a message with the same full name, that
// definition is checked with internal::ValidateStandardMessageType and the
// resulting status is returned. Otherwise the FileDescriptorProto of the file
// declaring the type is copied and built into the pool; an InternalError is
// returned if BuildFile fails.
template absl::Status AddOrValidateMessageType(google::protobuf::DescriptorPool& descriptor_pool) {
  const google::protobuf::Descriptor* descriptor = MessageType::descriptor();
  if (descriptor_pool.FindMessageTypeByName(descriptor->full_name()) !=
      nullptr) {
    // Already present: validate it instead of building the file again.
    return internal::ValidateStandardMessageType(descriptor_pool);
  }
  // BuildFile works on whole files, so copy the entire defining file.
  google::protobuf::FileDescriptorProto file_descriptor_proto;
  descriptor->file()->CopyTo(&file_descriptor_proto);
  if (descriptor_pool.BuildFile(file_descriptor_proto) == nullptr) {
    return absl::InternalError(
        absl::StrFormat("Failed to add descriptor '%s' to descriptor pool",
                        descriptor->full_name()));
  }
  return absl::OkStatus();
}

// Copies the FileDescriptorProto of the file declaring `MessageType` into
// `fdmap`, keyed by file name. A file that is already present is left
// untouched, so a file shared by several message types is copied only once.
template void AddStandardMessageTypeToMap(
    absl::flat_hash_map& fdmap) {
  const google::protobuf::Descriptor* descriptor = MessageType::descriptor();
  if (fdmap.contains(descriptor->file()->name())) return;
  descriptor->file()->CopyTo(&fdmap[descriptor->file()->name()]);
}

}  // namespace

absl::Status AddStandardMessageTypesToDescriptorPool(
    google::protobuf::DescriptorPool& descriptor_pool) {
  // The types below do not depend on each other, hence we can add them in any
  // order. Should that change with new messages add them in the proper order,
  // i.e., dependencies first.
  // Each CEL_RETURN_IF_ERROR returns early with the first non-OK status, so
  // types after a failing one are not processed.
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  CEL_RETURN_IF_ERROR(
      AddOrValidateMessageType(descriptor_pool));
  return absl::OkStatus();
}

google::protobuf::FileDescriptorSet GetStandardMessageTypesFileDescriptorSet() {
  // The types below do not depend on each other, hence we can add them to
  // an unordered map. Should that change with new messages being added here
  // adapt this to a sorted data structure and add in the proper order.
  absl::flat_hash_map files;
  AddStandardMessageTypeToMap(files);
  AddStandardMessageTypeToMap(files);
  AddStandardMessageTypeToMap(files);
  AddStandardMessageTypeToMap(files);
  AddStandardMessageTypeToMap(files);
  AddStandardMessageTypeToMap(files);
  AddStandardMessageTypeToMap(files);
  AddStandardMessageTypeToMap(files);
  AddStandardMessageTypeToMap(files);
  AddStandardMessageTypeToMap(files);
  AddStandardMessageTypeToMap(files);
  AddStandardMessageTypeToMap(files);
  AddStandardMessageTypeToMap(files);
  AddStandardMessageTypeToMap(files);
  AddStandardMessageTypeToMap(files);
  AddStandardMessageTypeToMap(files);
  AddStandardMessageTypeToMap(files);
  // The order of files in the result follows flat_hash_map iteration order,
  // which is unspecified.
  google::protobuf::FileDescriptorSet fdset;
  for (const auto& [name, fdproto] : files) {
    *fdset.add_file() = fdproto;
  }
  return fdset;
}

}  // namespace google::api::expr::runtime

================================================
FILE: eval/public/structs/cel_proto_descriptor_pool_builder.h
================================================
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_DESCRIPTOR_POOL_BUILDER_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_DESCRIPTOR_POOL_BUILDER_H_

#include "google/protobuf/descriptor.pb.h"
#include "absl/status/status.h"
#include "google/protobuf/descriptor.h"

namespace google::api::expr::runtime {

// Add standard message types required by CEL to given descriptor pool.
// This includes standard wrappers, timestamp, duration, any, etc. // This does not work for descriptor pools that have a fallback database. // Use GetStandardMessageTypesFileDescriptorSet() below instead to populate. absl::Status AddStandardMessageTypesToDescriptorPool( google::protobuf::DescriptorPool& descriptor_pool); // Get the standard message types required by CEL. // This includes standard wrappers, timestamp, duration, any, etc. These can be // used to, e.g., add them to a DescriptorDatabase backing a DescriptorPool. google::protobuf::FileDescriptorSet GetStandardMessageTypesFileDescriptorSet(); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_DESCRIPTOR_POOL_BUILDER_H_ ================================================ FILE: eval/public/structs/cel_proto_descriptor_pool_builder_test.cc ================================================ /* * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "eval/public/structs/cel_proto_descriptor_pool_builder.h" #include #include #include "google/protobuf/any.pb.h" #include "absl/container/flat_hash_map.h" #include "eval/testutil/test_message.pb.h" #include "internal/testing.h" namespace google::api::expr::runtime { namespace { using ::absl_testing::StatusIs; using ::testing::HasSubstr; using ::testing::UnorderedElementsAre; TEST(DescriptorPoolUtilsTest, PopulatesEmptyDescriptorPool) { google::protobuf::DescriptorPool descriptor_pool; ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Any"), nullptr); ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.BoolValue"), nullptr); ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.BytesValue"), nullptr); ASSERT_EQ( descriptor_pool.FindMessageTypeByName("google.protobuf.DoubleValue"), nullptr); ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Duration"), nullptr); ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.FloatValue"), nullptr); ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Int32Value"), nullptr); ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Int64Value"), nullptr); ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.ListValue"), nullptr); ASSERT_EQ( descriptor_pool.FindMessageTypeByName("google.protobuf.StringValue"), nullptr); ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Struct"), nullptr); ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Timestamp"), nullptr); ASSERT_EQ( descriptor_pool.FindMessageTypeByName("google.protobuf.UInt32Value"), nullptr); ASSERT_EQ( descriptor_pool.FindMessageTypeByName("google.protobuf.UInt64Value"), nullptr); ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.Value"), nullptr); ASSERT_EQ(descriptor_pool.FindMessageTypeByName("google.protobuf.FieldMask"), nullptr); ASSERT_OK(AddStandardMessageTypesToDescriptorPool(descriptor_pool)); 
EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Any"), nullptr); EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.BoolValue"), nullptr); EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.BytesValue"), nullptr); EXPECT_NE( descriptor_pool.FindMessageTypeByName("google.protobuf.DoubleValue"), nullptr); EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Duration"), nullptr); EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.FloatValue"), nullptr); EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Int32Value"), nullptr); EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Int64Value"), nullptr); EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.ListValue"), nullptr); EXPECT_NE( descriptor_pool.FindMessageTypeByName("google.protobuf.StringValue"), nullptr); EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Struct"), nullptr); EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Timestamp"), nullptr); EXPECT_NE( descriptor_pool.FindMessageTypeByName("google.protobuf.UInt32Value"), nullptr); EXPECT_NE( descriptor_pool.FindMessageTypeByName("google.protobuf.UInt64Value"), nullptr); EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Value"), nullptr); EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.FieldMask"), nullptr); EXPECT_NE(descriptor_pool.FindMessageTypeByName("google.protobuf.Empty"), nullptr); } TEST(DescriptorPoolUtilsTest, AcceptsPreAddedStandardTypes) { google::protobuf::DescriptorPool descriptor_pool; for (auto proto_name : std::vector{ "google.protobuf.Any", "google.protobuf.BoolValue", "google.protobuf.BytesValue", "google.protobuf.DoubleValue", "google.protobuf.Duration", "google.protobuf.FloatValue", "google.protobuf.Int32Value", "google.protobuf.Int64Value", "google.protobuf.ListValue", "google.protobuf.StringValue", "google.protobuf.Struct", "google.protobuf.Timestamp", 
"google.protobuf.UInt32Value", "google.protobuf.UInt64Value", "google.protobuf.Value", "google.protobuf.FieldMask", "google.protobuf.Empty"}) { const google::protobuf::Descriptor* descriptor = google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( proto_name); ASSERT_NE(descriptor, nullptr); google::protobuf::FileDescriptorProto file_descriptor_proto; descriptor->file()->CopyTo(&file_descriptor_proto); ASSERT_NE(descriptor_pool.BuildFile(file_descriptor_proto), nullptr); } EXPECT_OK(AddStandardMessageTypesToDescriptorPool(descriptor_pool)); } TEST(DescriptorPoolUtilsTest, RejectsModifiedStandardType) { google::protobuf::DescriptorPool descriptor_pool; const google::protobuf::Descriptor* descriptor = google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.protobuf.Duration"); ASSERT_NE(descriptor, nullptr); google::protobuf::FileDescriptorProto file_descriptor_proto; descriptor->file()->CopyTo(&file_descriptor_proto); // We emulate a modification by external code that replaced the nanos by a // millis field. 
google::protobuf::FieldDescriptorProto seconds_desc_proto; google::protobuf::FieldDescriptorProto nanos_desc_proto; descriptor->FindFieldByName("seconds")->CopyTo(&seconds_desc_proto); descriptor->FindFieldByName("nanos")->CopyTo(&nanos_desc_proto); nanos_desc_proto.set_name("millis"); file_descriptor_proto.mutable_message_type(0)->clear_field(); *file_descriptor_proto.mutable_message_type(0)->add_field() = seconds_desc_proto; *file_descriptor_proto.mutable_message_type(0)->add_field() = nanos_desc_proto; descriptor_pool.BuildFile(file_descriptor_proto); EXPECT_THAT( AddStandardMessageTypesToDescriptorPool(descriptor_pool), StatusIs(absl::StatusCode::kFailedPrecondition, HasSubstr("differs"))); } TEST(DescriptorPoolUtilsTest, GetStandardMessageTypesFileDescriptorSet) { google::protobuf::FileDescriptorSet fdset = GetStandardMessageTypesFileDescriptorSet(); std::vector file_names; for (int i = 0; i < fdset.file_size(); ++i) { file_names.push_back(fdset.file(i).name()); } EXPECT_THAT( file_names, UnorderedElementsAre( "google/protobuf/any.proto", "google/protobuf/struct.proto", "google/protobuf/wrappers.proto", "google/protobuf/timestamp.proto", "google/protobuf/duration.proto", "google/protobuf/field_mask.proto", "google/protobuf/empty.proto")); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/public/structs/cel_proto_wrap_util.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and
// limitations under the License.

#include "eval/public/structs/cel_proto_wrap_util.h"

// NOTE(review): the header names of the seven includes below are missing in
// this copy (angle-bracket text appears stripped), as are template arguments
// throughout this file (e.g. std::vector, absl::StatusOr, absl::optional,
// static_cast). Restore them from the repository before building.
#include
#include
#include
#include
#include
#include
#include

#include "google/protobuf/any.pb.h"
#include "google/protobuf/duration.pb.h"
#include "google/protobuf/struct.pb.h"
#include "google/protobuf/timestamp.pb.h"
#include "google/protobuf/wrappers.pb.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/functional/overload.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "eval/public/cel_value.h"
#include "eval/public/structs/protobuf_value_factory.h"
#include "internal/overflow.h"
#include "internal/proto_time_encoding.h"
#include "internal/status_macros.h"
#include "internal/time.h"
#include "internal/well_known_types.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"
#include "google/protobuf/message_lite.h"

namespace google::api::expr::runtime::internal {

namespace {

using cel::internal::DecodeDuration;
using cel::internal::DecodeTime;
using google::protobuf::Any;
using google::protobuf::BoolValue;
using google::protobuf::BytesValue;
using google::protobuf::DoubleValue;
using google::protobuf::Duration;
using google::protobuf::FloatValue;
using google::protobuf::Int32Value;
using google::protobuf::Int64Value;
using google::protobuf::ListValue;
using google::protobuf::StringValue;
using google::protobuf::Struct;
using google::protobuf::Timestamp;
using google::protobuf::UInt32Value;
using google::protobuf::UInt64Value;
using google::protobuf::Value;
using google::protobuf::Arena;
using google::protobuf::Descriptor;
using google::protobuf::DescriptorPool;
using google::protobuf::Message;
using google::protobuf::MessageFactory;

// kMaxIntJSON is defined as the Number.MAX_SAFE_INTEGER value per EcmaScript 6.
constexpr int64_t kMaxIntJSON = (1ll << 53) - 1;

// kMinIntJSON is defined as the Number.MIN_SAFE_INTEGER value per EcmaScript 6.
constexpr int64_t kMinIntJSON = -kMaxIntJSON;

// IsJSONSafe indicates whether the int is safely representable as a floating
// point value in JSON, i.e. lies within [kMinIntJSON, kMaxIntJSON].
static bool IsJSONSafe(int64_t i) {
  return i >= kMinIntJSON && i <= kMaxIntJSON;
}

// IsJSONSafe indicates whether the uint is safely representable as a floating
// point value in JSON, i.e. does not exceed kMaxIntJSON.
static bool IsJSONSafe(uint64_t i) {
  return i <= static_cast(kMaxIntJSON);
}

// List implementation wrapping google.protobuf.ListValue.
// Holds a non-owning pointer to `values`; the ListValue must outlive this list.
// Elements are converted to CelValue lazily, on each operator[] call.
class DynamicList : public CelList {
 public:
  DynamicList(const ListValue* values, ProtobufValueFactory factory,
              Arena* arena)
      : arena_(arena), factory_(std::move(factory)), values_(values) {}

  CelValue operator[](int index) const override;

  // List size
  int size() const override { return values_->values_size(); }

 private:
  Arena* arena_;
  ProtobufValueFactory factory_;
  const ListValue* values_;
};

// Map implementation wrapping google.protobuf.Struct.
// Holds a non-owning pointer to `values`; the Struct must outlive this map.
// Only string keys are supported; other key types produce InvalidArgument.
class DynamicMap : public CelMap {
 public:
  DynamicMap(const Struct* values, ProtobufValueFactory factory, Arena* arena)
      : arena_(arena),
        factory_(std::move(factory)),
        values_(values),
        key_list_(values) {}

  absl::StatusOr Has(const CelValue& key) const override {
    CelValue::StringHolder str_key;
    if (!key.GetValue(&str_key)) {
      // Not a string key.
      return absl::InvalidArgumentError(absl::StrCat(
          "Invalid map key type: '", CelValue::TypeName(key.type()), "'"));
    }

    return values_->fields().contains(std::string(str_key.value()));
  }

  absl::optional operator[](CelValue key) const override;

  int size() const override { return values_->fields_size(); }

  absl::StatusOr ListKeys() const override {
    return &key_list_;
  }

 private:
  // List of keys in Struct.fields map.
  // It utilizes lazy initialization, to avoid performance penalties.
  class DynamicMapKeyList : public CelList {
   public:
    explicit DynamicMapKeyList(const Struct* values)
        : values_(values), keys_(), initialized_(false) {}

    // Index access. `index` is not bounds-checked here.
    CelValue operator[](int index) const override {
      CheckInit();
      return keys_[index];
    }

    // List size
    // NOTE(review): the result comes straight from the Struct, so the
    // CheckInit() call (and its mutex acquisition) is not needed for it.
    int size() const override {
      CheckInit();
      return values_->fields_size();
    }

   private:
    // Builds keys_ on first use. keys_ and initialized_ are only written here,
    // under mutex_, and keys_ is never modified after initialized_ is set, so
    // callers may read keys_ after this returns. Each key CelValue points at
    // the Struct's own map key, so keys are valid only while the Struct lives.
    void CheckInit() const {
      absl::MutexLock lock(mutex_);
      if (!initialized_) {
        for (const auto& it : values_->fields()) {
          keys_.push_back(CelValue::CreateString(&it.first));
        }
        initialized_ = true;
      }
    }

    const Struct* values_;
    mutable absl::Mutex mutex_;
    mutable std::vector keys_;
    mutable bool initialized_;
  };

  Arena* arena_;
  ProtobufValueFactory factory_;
  const Struct* values_;
  const DynamicMapKeyList key_list_;
};

// Adapter for usage with CEL_RETURN_IF_ERROR and CEL_ASSIGN_OR_RETURN.
// Turns a (non-OK) status into an error CelValue whose absl::Status is
// allocated on `arena`.
class ReturnCelValueError {
 public:
  explicit ReturnCelValueError(google::protobuf::Arena* absl_nonnull arena)
      : arena_(arena) {}

  CelValue operator()(const absl::Status& status) const {
    ABSL_DCHECK(!status.ok());
    return CelValue::CreateError(
        google::protobuf::Arena::Create(arena_, status));
  }

 private:
  google::protobuf::Arena* absl_nonnull arena_;
};

// Adapter for the same macros that discards the status and makes the
// enclosing function return nullptr.
struct IgnoreErrorAndReturnNullptr {
  std::nullptr_t operator()(const absl::Status& status) const {
    status.IgnoreError();
    return nullptr;
  }
};

// ValueManager provides ValueFromMessage(....) function family.
// Functions of this family create CelValue object from specific subtypes of
// protobuf message.
class ValueManager {
  // Converts protobuf well-known-type messages into CelValues whose backing
  // storage (strings, lists, maps, errors) is allocated on `arena_`.
  // `value_factory_` is stored by reference, so the factory must outlive the
  // ValueManager; every use visible in this file constructs it as a temporary.
 public:
  ValueManager(const ProtobufValueFactory& value_factory,
               const google::protobuf::DescriptorPool* descriptor_pool,
               google::protobuf::Arena* arena,
               google::protobuf::MessageFactory* message_factory)
      : value_factory_(value_factory),
        descriptor_pool_(descriptor_pool),
        arena_(arena),
        message_factory_(message_factory) {}

  // Note: this overload should only be used in the context of accessing struct
  // value members, which have already been adapted to the generated message
  // types.
  ValueManager(const ProtobufValueFactory& value_factory,
               google::protobuf::Arena* arena)
      : value_factory_(value_factory),
        descriptor_pool_(DescriptorPool::generated_pool()),
        arena_(arena),
        message_factory_(MessageFactory::generated_factory()) {}

  // The ValueFromX(const google::protobuf::Message*) overloads below read the
  // message through cel::well_known_types reflection rather than a generated
  // class, presumably so dynamic (non-generated) messages are handled too. If
  // the reflection lookup for the message's descriptor fails, the status is
  // returned as an error CelValue via ReturnCelValueError.

  static CelValue ValueFromDuration(absl::Duration duration) {
    return CelValue::CreateDuration(duration);
  }

  CelValue ValueFromDuration(const google::protobuf::Message* message) {
    CEL_ASSIGN_OR_RETURN(
        auto reflection,
        cel::well_known_types::GetDurationReflection(message->GetDescriptor()),
        _.With(ReturnCelValueError(arena_)));
    return ValueFromDuration(reflection.UnsafeToAbslDuration(*message));
  }

  CelValue ValueFromMessage(const Duration* duration) {
    return ValueFromDuration(DecodeDuration(*duration));
  }

  CelValue ValueFromTimestamp(const google::protobuf::Message* message) {
    CEL_ASSIGN_OR_RETURN(
        auto reflection,
        cel::well_known_types::GetTimestampReflection(message->GetDescriptor()),
        _.With(ReturnCelValueError(arena_)));
    return ValueFromTimestamp(reflection.UnsafeToAbslTime(*message));
  }

  static CelValue ValueFromTimestamp(absl::Time timestamp) {
    return CelValue::CreateTimestamp(timestamp);
  }

  CelValue ValueFromMessage(const Timestamp* timestamp) {
    return ValueFromTimestamp(DecodeTime(*timestamp));
  }

  // Wraps the ListValue without copying it: the resulting list keeps a pointer
  // to `list_values`, which must therefore outlive the returned CelValue.
  CelValue ValueFromMessage(const ListValue* list_values) {
    return CelValue::CreateList(Arena::Create(
        arena_, list_values, value_factory_, arena_));
  }

  // Wraps the Struct without copying it: the resulting map keeps a pointer to
  // `struct_value`, which must therefore outlive the returned CelValue.
  CelValue ValueFromMessage(const Struct* struct_value) {
    return CelValue::CreateMap(Arena::Create(
        arena_, struct_value, value_factory_, arena_));
  }

  CelValue ValueFromAny(const google::protobuf::Message* message) {
    CEL_ASSIGN_OR_RETURN(
        auto reflection,
        cel::well_known_types::GetAnyReflection(message->GetDescriptor()),
        _.With(ReturnCelValueError(arena_)));
    // The scratch buffers may back the views passed below, so they must stay
    // alive until the nested ValueFromAny call returns.
    std::string type_url_scratch;
    std::string value_scratch;
    return ValueFromAny(reflection.GetTypeUrl(*message, type_url_scratch),
                        reflection.GetValue(*message, value_scratch),
                        descriptor_pool_, message_factory_);
  }

  // Unpacks an Any: the text after the last '/' of `type_url` is looked up as
  // a message name in `descriptor_pool`, a new message is created on `arena_`
  // from `message_factory`'s prototype, `payload` is parsed into it, and the
  // result is converted with UnwrapMessageToValue. Each failure step yields an
  // error CelValue.
  CelValue ValueFromAny(const cel::well_known_types::StringValue& type_url,
                        const cel::well_known_types::BytesValue& payload,
                        const DescriptorPool* descriptor_pool,
                        MessageFactory* message_factory) {
    // Obtain a contiguous view of the type URL; a fragmented Cord is copied
    // into the local scratch string first.
    std::string type_url_string_scratch;
    absl::string_view type_url_string = absl::visit(
        absl::Overload([](absl::string_view string)
                           -> absl::string_view { return string; },
                       [&type_url_string_scratch](
                           const absl::Cord& cord) -> absl::string_view {
                         if (auto flat = cord.TryFlat(); flat) {
                           return *flat;
                         }
                         absl::CopyCordToString(cord, &type_url_string_scratch);
                         return absl::string_view(type_url_string_scratch);
                       }),
        cel::well_known_types::AsVariant(type_url));
    auto pos = type_url_string.find_last_of('/');
    if (pos == type_url_string.npos) {
      // TODO(issues/25) What error code?
      // Malformed type_url
      return CreateErrorValue(arena_, "Malformed type_url string");
    }

    absl::string_view full_name = type_url_string.substr(pos + 1);
    const Descriptor* nested_descriptor =
        descriptor_pool->FindMessageTypeByName(full_name);

    if (nested_descriptor == nullptr) {
      // Descriptor not found for the type
      // TODO(issues/25) What error code?
      return CreateErrorValue(arena_, "Descriptor not found");
    }

    const Message* prototype = message_factory->GetPrototype(nested_descriptor);
    if (prototype == nullptr) {
      // Failed to obtain prototype for the descriptor
      // TODO(issues/25) What error code?
      return CreateErrorValue(arena_, "Prototype not found");
    }

    Message* nested_message = prototype->New(arena_);
    // Partial parsing: missing required fields do not cause a failure here.
    bool ok = absl::visit(
        absl::Overload(
            [nested_message](absl::string_view string) -> bool {
              return nested_message->ParsePartialFromString(string);
            },
            [nested_message](const absl::Cord& cord) -> bool {
              return nested_message->ParsePartialFromString(cord);
            }),
        cel::well_known_types::AsVariant(payload));
    if (!ok) {
      // Failed to unpack.
      // TODO(issues/25) What error code?
      return CreateErrorValue(arena_, "Failed to unpack Any into message");
    }

    return UnwrapMessageToValue(nested_message, value_factory_, arena_);
  }

  CelValue ValueFromMessage(const Any* any_value,
                            const DescriptorPool* descriptor_pool,
                            MessageFactory* message_factory) {
    return ValueFromAny(any_value->type_url(), absl::Cord(any_value->value()),
                        descriptor_pool, message_factory);
  }

  CelValue ValueFromMessage(const Any* any_value) {
    return ValueFromMessage(any_value, descriptor_pool_, message_factory_);
  }

  CelValue ValueFromBool(const google::protobuf::Message* message) {
    CEL_ASSIGN_OR_RETURN(
        auto reflection,
        cel::well_known_types::GetBoolValueReflection(message->GetDescriptor()),
        _.With(ReturnCelValueError(arena_)));
    return ValueFromBool(reflection.GetValue(*message));
  }

  static CelValue ValueFromBool(bool value) {
    return CelValue::CreateBool(value);
  }

  CelValue ValueFromMessage(const BoolValue* wrapper) {
    return ValueFromBool(wrapper->value());
  }

  CelValue ValueFromInt32(const google::protobuf::Message* message) {
    CEL_ASSIGN_OR_RETURN(auto reflection,
                         cel::well_known_types::GetInt32ValueReflection(
                             message->GetDescriptor()),
                         _.With(ReturnCelValueError(arena_)));
    return ValueFromInt32(reflection.GetValue(*message));
  }

  // CEL has a single signed integer type, so int32 widens to int64.
  static CelValue ValueFromInt32(int32_t value) {
    return CelValue::CreateInt64(value);
  }

  CelValue ValueFromMessage(const Int32Value* wrapper) {
    return ValueFromInt32(wrapper->value());
  }

  CelValue ValueFromUInt32(const google::protobuf::Message* message) {
    CEL_ASSIGN_OR_RETURN(auto reflection,
                         cel::well_known_types::GetUInt32ValueReflection(
                             message->GetDescriptor()),
                         _.With(ReturnCelValueError(arena_)));
    return ValueFromUInt32(reflection.GetValue(*message));
  }

  // uint32 widens to CEL's uint64.
  static CelValue ValueFromUInt32(uint32_t value) {
    return CelValue::CreateUint64(value);
  }

  CelValue ValueFromMessage(const UInt32Value* wrapper) {
    return ValueFromUInt32(wrapper->value());
  }

  CelValue ValueFromInt64(const google::protobuf::Message* message) {
    CEL_ASSIGN_OR_RETURN(auto reflection,
                         cel::well_known_types::GetInt64ValueReflection(
                             message->GetDescriptor()),
                         _.With(ReturnCelValueError(arena_)));
    return ValueFromInt64(reflection.GetValue(*message));
  }

  static CelValue ValueFromInt64(int64_t value) {
    return CelValue::CreateInt64(value);
  }

  CelValue ValueFromMessage(const Int64Value* wrapper) {
    return ValueFromInt64(wrapper->value());
  }

  CelValue ValueFromUInt64(const google::protobuf::Message* message) {
    CEL_ASSIGN_OR_RETURN(auto reflection,
                         cel::well_known_types::GetUInt64ValueReflection(
                             message->GetDescriptor()),
                         _.With(ReturnCelValueError(arena_)));
    return ValueFromUInt64(reflection.GetValue(*message));
  }

  static CelValue ValueFromUInt64(uint64_t value) {
    return CelValue::CreateUint64(value);
  }

  CelValue ValueFromMessage(const UInt64Value* wrapper) {
    return ValueFromUInt64(wrapper->value());
  }

  CelValue ValueFromFloat(const google::protobuf::Message* message) {
    CEL_ASSIGN_OR_RETURN(auto reflection,
                         cel::well_known_types::GetFloatValueReflection(
                             message->GetDescriptor()),
                         _.With(ReturnCelValueError(arena_)));
    return ValueFromFloat(reflection.GetValue(*message));
  }

  // float widens to CEL's double.
  static CelValue ValueFromFloat(float value) {
    return CelValue::CreateDouble(value);
  }

  CelValue ValueFromMessage(const FloatValue* wrapper) {
    return ValueFromFloat(wrapper->value());
  }

  CelValue ValueFromDouble(const google::protobuf::Message* message) {
    CEL_ASSIGN_OR_RETURN(auto reflection,
                         cel::well_known_types::GetDoubleValueReflection(
                             message->GetDescriptor()),
                         _.With(ReturnCelValueError(arena_)));
    return ValueFromDouble(reflection.GetValue(*message));
  }

  static CelValue ValueFromDouble(double value) {
    return CelValue::CreateDouble(value);
  }

  CelValue ValueFromMessage(const DoubleValue* wrapper) {
    return ValueFromDouble(wrapper->value());
  }

  // Copies the string into arena-owned storage. If reflection wrote the value
  // into `scratch` (the returned view aliases it), scratch is moved instead of
  // copied; a Cord result is flattened into a new arena string.
  CelValue ValueFromString(const google::protobuf::Message* message) {
    CEL_ASSIGN_OR_RETURN(auto reflection,
                         cel::well_known_types::GetStringValueReflection(
                             message->GetDescriptor()),
                         _.With(ReturnCelValueError(arena_)));
    std::string scratch;
    return absl::visit(
        absl::Overload(
            [&](absl::string_view string) -> CelValue {
              if (string.data() == scratch.data() &&
                  string.size() == scratch.size()) {
                return CelValue::CreateString(
                    google::protobuf::Arena::Create(arena_, std::move(scratch)));
              }
              return CelValue::CreateString(google::protobuf::Arena::Create(
                  arena_, std::string(string)));
            },
            [&](absl::Cord&& cord) -> CelValue {
              auto* string = google::protobuf::Arena::Create(arena_);
              absl::CopyCordToString(cord, string);
              return CelValue::CreateString(string);
            }),
        cel::well_known_types::AsVariant(
            reflection.GetValue(*message, scratch)));
  }

  CelValue ValueFromString(const absl::Cord& value) {
    return CelValue::CreateString(
        Arena::Create(arena_, static_cast(value)));
  }

  // Does not copy: the CelValue points at `*value`, which must outlive it.
  static CelValue ValueFromString(const std::string* value) {
    return CelValue::CreateString(value);
  }

  CelValue ValueFromMessage(const StringValue* wrapper) {
    return ValueFromString(&wrapper->value());
  }

  // Same strategy as ValueFromString(const Message*), producing bytes.
  CelValue ValueFromBytes(const google::protobuf::Message* message) {
    CEL_ASSIGN_OR_RETURN(auto reflection,
                         cel::well_known_types::GetBytesValueReflection(
                             message->GetDescriptor()),
                         _.With(ReturnCelValueError(arena_)));
    std::string scratch;
    return absl::visit(
        absl::Overload(
            [&](absl::string_view string) -> CelValue {
              if (string.data() == scratch.data() &&
                  string.size() == scratch.size()) {
                return CelValue::CreateBytes(google::protobuf::Arena::Create(
                    arena_, std::move(scratch)));
              }
              return CelValue::CreateBytes(google::protobuf::Arena::Create(
                  arena_, std::string(string)));
            },
            [&](absl::Cord&& cord) -> CelValue {
              auto* string = google::protobuf::Arena::Create(arena_);
              absl::CopyCordToString(cord, string);
              return CelValue::CreateBytes(string);
            }),
        cel::well_known_types::AsVariant(
            reflection.GetValue(*message, scratch)));
  }

  CelValue ValueFromBytes(const absl::Cord& value) {
    return CelValue::CreateBytes(
        Arena::Create(arena_, static_cast(value)));
  }

  static CelValue ValueFromBytes(google::protobuf::Arena* arena, std::string value) {
    return CelValue::CreateBytes(
        Arena::Create(arena, std::move(value)));
  }

  CelValue ValueFromMessage(const BytesValue* wrapper) {
    // BytesValue stores value as Cord
    // (the value is copied into an arena-owned std::string either way).
    return CelValue::CreateBytes(
        Arena::Create(arena_, std::string(wrapper->value())));
  }

  // Maps each google.protobuf.Value kind to the matching CelValue. String
  // values are not copied (they point into `*value`); Struct and ListValue
  // are wrapped via the non-copying overloads above. An unset kind, or any
  // other case, maps to null.
  CelValue ValueFromMessage(const Value* value) {
    switch (value->kind_case()) {
      case Value::KindCase::kNullValue:
        return CelValue::CreateNull();
      case Value::KindCase::kNumberValue:
        return CelValue::CreateDouble(value->number_value());
      case Value::KindCase::kStringValue:
        return CelValue::CreateString(&value->string_value());
      case Value::KindCase::kBoolValue:
        return CelValue::CreateBool(value->bool_value());
      case Value::KindCase::kStructValue:
        return ValueFromMessage(&value->struct_value());
      case Value::KindCase::kListValue:
        return ValueFromMessage(&value->list_value());
      default:
        return CelValue::CreateNull();
    }
  }

  // Converts `message` to the generated type and then to a CelValue. Fast path:
  // `message` already is the generated type. Otherwise (e.g. a dynamic message
  // from another pool) it is serialized and re-parsed into a new arena-allocated
  // generated message; either step failing yields an UnknownError CelValue.
  template
  CelValue ValueFromGeneratedMessageLite(const google::protobuf::Message* message) {
    const auto* downcast_message =
        google::protobuf::DynamicCastToGenerated(message);
    if (downcast_message != nullptr) {
      return ValueFromMessage(downcast_message);
    }
    auto* value = google::protobuf::Arena::Create(arena_);
    absl::Cord serialized;
    if (!message->SerializeToString(&serialized)) {
      return CreateErrorValue(
          arena_, absl::UnknownError(
                      absl::StrCat("failed to serialize dynamic message: ",
                                   message->GetTypeName())));
    }
    if (!value->ParseFromCord(serialized)) {
      return CreateErrorValue(arena_, absl::UnknownError(absl::StrCat(
                                          "failed to parse generated message: ",
                                          value->GetTypeName())));
    }
    return ValueFromMessage(value);
  }

  // Compile-time dispatch from a well-known message type to its converter.
  // Instantiating it with any other type reaches ABSL_UNREACHABLE().
  // NOTE(review): the template parameter list and the is_same_v arguments are
  // missing in this copy (angle-bracket text appears stripped).
  template
  CelValue ValueFromMessage(const google::protobuf::Message* message) {
    if constexpr (std::is_same_v) {
      return ValueFromAny(message);
    } else if constexpr (std::is_same_v) {
      return ValueFromBool(message);
    } else if constexpr (std::is_same_v) {
      return ValueFromBytes(message);
    } else if constexpr (std::is_same_v) {
      return ValueFromDouble(message);
    } else if constexpr (std::is_same_v) {
      return ValueFromDuration(message);
    } else if constexpr (std::is_same_v) {
      return ValueFromFloat(message);
    } else if constexpr (std::is_same_v) {
      return ValueFromInt32(message);
    } else if constexpr (std::is_same_v) {
      return ValueFromInt64(message);
    } else if constexpr (std::is_same_v) {
      return ValueFromGeneratedMessageLite(message);
    } else if constexpr (std::is_same_v) {
      return ValueFromString(message);
    } else if constexpr (std::is_same_v) {
      return ValueFromGeneratedMessageLite(message);
    } else if constexpr (std::is_same_v) {
      return ValueFromTimestamp(message);
    } else if constexpr (std::is_same_v) {
      return ValueFromUInt32(message);
    } else if constexpr (std::is_same_v) {
      return ValueFromUInt64(message);
    } else if constexpr (std::is_same_v) {
      return ValueFromGeneratedMessageLite(message);
    } else {
      ABSL_UNREACHABLE();
    }
  }

 private:
  const ProtobufValueFactory& value_factory_;
  const google::protobuf::DescriptorPool* descriptor_pool_;
  google::protobuf::Arena* arena_;
  MessageFactory* message_factory_;
};

// Builds a CelValue for a protobuf message that is one of the well-known types.
// CreateValue switches on Descriptor::well_known_type() and converts wrappers,
// Any, Duration, Timestamp, Value, ListValue and Struct through ValueManager.
// For any other message (including google.protobuf.FieldMask) it returns
// absl::nullopt, leaving generic message handling to the caller.
class ValueFromMessageMaker {
 public:
  template
  static CelValue CreateWellknownTypeValue(const google::protobuf::Message* msg,
                                           const ProtobufValueFactory& factory,
                                           Arena* arena) {
    // Copy the original descriptor pool and message factory for unpacking 'Any'
    // values.
    google::protobuf::MessageFactory* message_factory =
        msg->GetReflection()->GetMessageFactory();
    const google::protobuf::DescriptorPool* pool = msg->GetDescriptor()->file()->pool();
    return ValueManager(factory, pool, arena, message_factory)
        .ValueFromMessage(msg);
  }

  static absl::optional CreateValue(
      const google::protobuf::Message* message, const ProtobufValueFactory& factory,
      Arena* arena) {
    switch (message->GetDescriptor()->well_known_type()) {
      case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE:
        return CreateWellknownTypeValue(message, factory, arena);
      case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE:
        return CreateWellknownTypeValue(message, factory, arena);
      case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE:
        return CreateWellknownTypeValue(message, factory, arena);
      case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE:
        return CreateWellknownTypeValue(message, factory, arena);
      case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE:
        return CreateWellknownTypeValue(message, factory, arena);
      case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE:
        return CreateWellknownTypeValue(message, factory, arena);
      case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE:
        return CreateWellknownTypeValue(message, factory, arena);
      case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE:
        return CreateWellknownTypeValue(message, factory, arena);
      case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE:
        return CreateWellknownTypeValue(message, factory, arena);
      case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY:
        return CreateWellknownTypeValue(message, factory, arena);
      case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION:
        return CreateWellknownTypeValue(message, factory, arena);
      case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP:
        return CreateWellknownTypeValue(message, factory, arena);
      case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE:
        return CreateWellknownTypeValue(message, factory, arena);
      case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE:
        return CreateWellknownTypeValue(message, factory, arena);
      case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT:
        return CreateWellknownTypeValue(message, factory, arena);
      // WELLKNOWNTYPE_FIELDMASK has no special CelValue type
      default:
        return absl::nullopt;
    }
  }

  // Non-copyable, non-assignable
  ValueFromMessageMaker(const ValueFromMessageMaker&) = delete;
  ValueFromMessageMaker& operator=(const ValueFromMessageMaker&) = delete;
};

// Elements of a ListValue are google.protobuf.Value messages; they are
// converted on each access with a generated-pool ValueManager. `index` is not
// bounds-checked here.
CelValue DynamicList::operator[](int index) const {
  return ValueManager(factory_, arena_)
      .ValueFromMessage(&values_->values(index));
}

// Looks up a string key in the Struct. Returns an error CelValue for a
// non-string key and absl::nullopt when the key is absent.
absl::optional DynamicMap::operator[](CelValue key) const {
  CelValue::StringHolder str_key;
  if (!key.GetValue(&str_key)) {
    // Not a string key.
    return CreateErrorValue(arena_, absl::InvalidArgumentError(absl::StrCat(
                                        "Invalid map key type: '",
                                        CelValue::TypeName(key.type()), "'")));
  }

  auto it = values_->fields().find(std::string(str_key.value()));
  if (it == values_->fields().end()) {
    return absl::nullopt;
  }

  return ValueManager(factory_, arena_).ValueFromMessage(&it->second);
}

// The XFromValue functions below build a new message from `prototype`
// (allocated with prototype->New(arena)) holding the value carried by `value`.
// They return nullptr when `value` does not hold the expected CEL type, when
// the value is out of range for the target type (Duration, Int32), or when the
// well-known-type reflection lookup for the message's descriptor fails.
// NOTE(review): in that last case the freshly allocated `message` is not
// released; if `arena` can be null here, it leaks. Looking up the reflection
// from prototype->GetDescriptor() before calling New() would avoid this.

google::protobuf::Message* DurationFromValue(const google::protobuf::Message* prototype,
                                   const CelValue& value,
                                   google::protobuf::Arena* arena) {
  absl::Duration val;
  if (!value.GetValue(&val)) {
    return nullptr;
  }
  // Reject durations that cel::internal::ValidateDuration deems invalid.
  if (!cel::internal::ValidateDuration(val).ok()) {
    return nullptr;
  }
  auto* message = prototype->New(arena);
  CEL_ASSIGN_OR_RETURN(
      auto reflection,
      cel::well_known_types::GetDurationReflection(message->GetDescriptor()),
      _.With(IgnoreErrorAndReturnNullptr()));
  reflection.UnsafeSetFromAbslDuration(message, val);
  return message;
}

google::protobuf::Message* BoolFromValue(const google::protobuf::Message* prototype,
                               const CelValue& value,
                               google::protobuf::Arena* arena) {
  bool val;
  if (!value.GetValue(&val)) {
    return nullptr;
  }
  auto* message = prototype->New(arena);
  CEL_ASSIGN_OR_RETURN(
      auto reflection,
      cel::well_known_types::GetBoolValueReflection(message->GetDescriptor()),
      _.With(IgnoreErrorAndReturnNullptr()));
  reflection.SetValue(message, val);
  return message;
}

google::protobuf::Message* BytesFromValue(const google::protobuf::Message* prototype,
                                const CelValue& value,
                                google::protobuf::Arena* arena) {
  CelValue::BytesHolder view_val;
  if (!value.GetValue(&view_val)) {
    return nullptr;
  }
  auto* message = prototype->New(arena);
  CEL_ASSIGN_OR_RETURN(
      auto reflection,
      cel::well_known_types::GetBytesValueReflection(message->GetDescriptor()),
      _.With(IgnoreErrorAndReturnNullptr()));
  reflection.SetValue(message, view_val.value());
  return message;
}

google::protobuf::Message* DoubleFromValue(const google::protobuf::Message* prototype,
                                 const CelValue& value,
                                 google::protobuf::Arena* arena) {
  double val;
  if (!value.GetValue(&val)) {
    return nullptr;
  }
  auto* message = prototype->New(arena);
  CEL_ASSIGN_OR_RETURN(
      auto reflection,
      cel::well_known_types::GetDoubleValueReflection(message->GetDescriptor()),
      _.With(IgnoreErrorAndReturnNullptr()));
  reflection.SetValue(message, val);
  return message;
}

google::protobuf::Message* FloatFromValue(const google::protobuf::Message* prototype,
                                const CelValue& value,
                                google::protobuf::Arena* arena) {
  double val;
  if (!value.GetValue(&val)) {
    return nullptr;
  }
  float fval = val;
  // Out-of-range values do not abort the conversion: they saturate. A double
  // above the largest finite float becomes +infinity, and one below the lowest
  // finite float becomes -infinity.
  if (val > std::numeric_limits::max()) {
    fval = std::numeric_limits::infinity();
  } else if (val < std::numeric_limits::lowest()) {
    fval = -std::numeric_limits::infinity();
  }
  auto* message = prototype->New(arena);
  CEL_ASSIGN_OR_RETURN(
      auto reflection,
      cel::well_known_types::GetFloatValueReflection(message->GetDescriptor()),
      _.With(IgnoreErrorAndReturnNullptr()));
  reflection.SetValue(message, static_cast(fval));
  return message;
}

google::protobuf::Message* Int32FromValue(const google::protobuf::Message* prototype,
                                const CelValue& value,
                                google::protobuf::Arena* arena) {
  int64_t val;
  if (!value.GetValue(&val)) {
    return nullptr;
  }
  // Values outside the int32 range are rejected rather than truncated.
  if (!cel::internal::CheckedInt64ToInt32(val).ok()) {
    return nullptr;
  }
  int32_t ival = static_cast(val);
  auto* message = prototype->New(arena);
  CEL_ASSIGN_OR_RETURN(
      auto reflection,
      cel::well_known_types::GetInt32ValueReflection(message->GetDescriptor()),
      _.With(IgnoreErrorAndReturnNullptr()));
  reflection.SetValue(message, ival);
  return message;
}

google::protobuf::Message* Int64FromValue(const google::protobuf::Message* prototype,
                                const CelValue& value,
                                google::protobuf::Arena* arena) {
  int64_t val;
  if (!value.GetValue(&val)) {
    return nullptr;
  }
  auto* message = prototype->New(arena);
  CEL_ASSIGN_OR_RETURN(
      auto reflection,
      cel::well_known_types::GetInt64ValueReflection(message->GetDescriptor()),
      _.With(IgnoreErrorAndReturnNullptr()));
  reflection.SetValue(message, val);
  return message;
}

google::protobuf::Message* StringFromValue(const google::protobuf::Message* prototype,
                                 const CelValue& value,
                                 google::protobuf::Arena* arena) {
  CelValue::StringHolder view_val;
  if (!value.GetValue(&view_val)) {
    return nullptr;
  }
  auto* message = prototype->New(arena);
  CEL_ASSIGN_OR_RETURN(
      auto reflection,
      cel::well_known_types::GetStringValueReflection(message->GetDescriptor()),
      _.With(IgnoreErrorAndReturnNullptr()));
  reflection.SetValue(message, view_val.value());
  return message;
}

google::protobuf::Message* TimestampFromValue(const google::protobuf::Message*
prototype, const CelValue& value, google::protobuf::Arena* arena) { absl::Time val; if (!value.GetValue(&val)) { return nullptr; } if (!cel::internal::ValidateTimestamp(val).ok()) { return nullptr; } auto* message = prototype->New(arena); CEL_ASSIGN_OR_RETURN( auto reflection, cel::well_known_types::GetTimestampReflection(message->GetDescriptor()), _.With(IgnoreErrorAndReturnNullptr())); reflection.UnsafeSetFromAbslTime(message, val); return message; } google::protobuf::Message* UInt32FromValue(const google::protobuf::Message* prototype, const CelValue& value, google::protobuf::Arena* arena) { uint64_t val; if (!value.GetValue(&val)) { return nullptr; } if (!cel::internal::CheckedUint64ToUint32(val).ok()) { return nullptr; } uint32_t ival = static_cast(val); auto* message = prototype->New(arena); CEL_ASSIGN_OR_RETURN( auto reflection, cel::well_known_types::GetUInt32ValueReflection(message->GetDescriptor()), _.With(IgnoreErrorAndReturnNullptr())); reflection.SetValue(message, ival); return message; } google::protobuf::Message* UInt64FromValue(const google::protobuf::Message* prototype, const CelValue& value, google::protobuf::Arena* arena) { uint64_t val; if (!value.GetValue(&val)) { return nullptr; } auto* message = prototype->New(arena); CEL_ASSIGN_OR_RETURN( auto reflection, cel::well_known_types::GetUInt64ValueReflection(message->GetDescriptor()), _.With(IgnoreErrorAndReturnNullptr())); reflection.SetValue(message, val); return message; } google::protobuf::Message* ValueFromValue(google::protobuf::Message* message, const CelValue& value, google::protobuf::Arena* arena); google::protobuf::Message* ValueFromValue(const google::protobuf::Message* prototype, const CelValue& value, google::protobuf::Arena* arena) { return ValueFromValue(prototype->New(arena), value, arena); } google::protobuf::Message* ListFromValue(google::protobuf::Message* message, const CelValue& value, google::protobuf::Arena* arena) { if (!value.IsList()) { return nullptr; } const CelList& 
list = *value.ListOrDie(); CEL_ASSIGN_OR_RETURN( auto reflection, cel::well_known_types::GetListValueReflection(message->GetDescriptor()), _.With(IgnoreErrorAndReturnNullptr())); for (int i = 0; i < list.size(); i++) { auto e = list.Get(arena, i); auto* elem = reflection.AddValues(message); if (ValueFromValue(elem, e, arena) == nullptr) { return nullptr; } } return message; } google::protobuf::Message* ListFromValue(const google::protobuf::Message* prototype, const CelValue& value, google::protobuf::Arena* arena) { if (!value.IsList()) { return nullptr; } return ListFromValue(prototype->New(arena), value, arena); } google::protobuf::Message* StructFromValue(google::protobuf::Message* message, const CelValue& value, google::protobuf::Arena* arena) { if (!value.IsMap()) { return nullptr; } const CelMap& map = *value.MapOrDie(); absl::StatusOr keys_or = map.ListKeys(arena); if (!keys_or.ok()) { // If map doesn't support listing keys, it can't pack into a Struct value. // This will surface as a CEL error when the object creation expression // fails. return nullptr; } const CelList& keys = **keys_or; CEL_ASSIGN_OR_RETURN( auto reflection, cel::well_known_types::GetStructReflection(message->GetDescriptor()), _.With(IgnoreErrorAndReturnNullptr())); for (int i = 0; i < keys.size(); i++) { auto k = keys.Get(arena, i); // If the key is not a string type, abort the conversion. 
if (!k.IsString()) { return nullptr; } absl::string_view key = k.StringOrDie().value(); auto v = map.Get(arena, k); if (!v.has_value()) { return nullptr; } auto* field = reflection.InsertField(message, key); if (ValueFromValue(field, *v, arena) == nullptr) { return nullptr; } } return message; } google::protobuf::Message* StructFromValue(const google::protobuf::Message* prototype, const CelValue& value, google::protobuf::Arena* arena) { if (!value.IsMap()) { return nullptr; } return StructFromValue(prototype->New(arena), value, arena); } google::protobuf::Message* ValueFromValue(google::protobuf::Message* message, const CelValue& value, google::protobuf::Arena* arena) { CEL_ASSIGN_OR_RETURN( auto reflection, cel::well_known_types::GetValueReflection(message->GetDescriptor()), _.With(IgnoreErrorAndReturnNullptr())); switch (value.type()) { case CelValue::Type::kBool: { bool val; if (value.GetValue(&val)) { reflection.SetBoolValue(message, val); return message; } } break; case CelValue::Type::kBytes: { // Base64 encode byte strings to ensure they can safely be transported // in a JSON string. CelValue::BytesHolder val; if (value.GetValue(&val)) { reflection.SetStringValueFromBytes(message, val.value()); return message; } } break; case CelValue::Type::kDouble: { double val; if (value.GetValue(&val)) { reflection.SetNumberValue(message, val); return message; } } break; case CelValue::Type::kDuration: { // Convert duration values to a protobuf JSON format. absl::Duration val; if (value.GetValue(&val)) { CEL_RETURN_IF_ERROR(cel::internal::ValidateDuration(val)) .With(IgnoreErrorAndReturnNullptr()); reflection.SetStringValueFromDuration(message, val); return message; } } break; case CelValue::Type::kInt64: { int64_t val; // Convert int64_t values within the int53 range to doubles, otherwise // serialize the value to a string. 
if (value.GetValue(&val)) { reflection.SetNumberValue(message, val); return message; } } break; case CelValue::Type::kString: { CelValue::StringHolder val; if (value.GetValue(&val)) { reflection.SetStringValue(message, val.value()); return message; } } break; case CelValue::Type::kTimestamp: { // Convert timestamp values to a protobuf JSON format. absl::Time val; if (value.GetValue(&val)) { CEL_RETURN_IF_ERROR(cel::internal::ValidateTimestamp(val)) .With(IgnoreErrorAndReturnNullptr()); reflection.SetStringValueFromTimestamp(message, val); return message; } } break; case CelValue::Type::kUint64: { uint64_t val; // Convert uint64_t values within the int53 range to doubles, otherwise // serialize the value to a string. if (value.GetValue(&val)) { reflection.SetNumberValue(message, val); return message; } } break; case CelValue::Type::kList: { if (ListFromValue(reflection.MutableListValue(message), value, arena) != nullptr) { return message; } } break; case CelValue::Type::kMap: { if (StructFromValue(reflection.MutableStructValue(message), value, arena) != nullptr) { return message; } } break; case CelValue::Type::kNullType: reflection.SetNullValue(message); return message; break; default: return nullptr; } return nullptr; } bool ValueFromValue(Value* json, const CelValue& value, google::protobuf::Arena* arena); bool ListFromValue(ListValue* json_list, const CelValue& value, google::protobuf::Arena* arena) { if (!value.IsList()) { return false; } const CelList& list = *value.ListOrDie(); for (int i = 0; i < list.size(); i++) { auto e = list.Get(arena, i); Value* elem = json_list->add_values(); if (!ValueFromValue(elem, e, arena)) { return false; } } return true; } bool StructFromValue(Struct* json_struct, const CelValue& value, google::protobuf::Arena* arena) { if (!value.IsMap()) { return false; } const CelMap& map = *value.MapOrDie(); absl::StatusOr keys_or = map.ListKeys(arena); if (!keys_or.ok()) { // If map doesn't support listing keys, it can't pack into a Struct 
value. // This will surface as a CEL error when the object creation expression // fails. return false; } const CelList& keys = **keys_or; auto fields = json_struct->mutable_fields(); for (int i = 0; i < keys.size(); i++) { auto k = keys.Get(arena, i); // If the key is not a string type, abort the conversion. if (!k.IsString()) { return false; } absl::string_view key = k.StringOrDie().value(); auto v = map.Get(arena, k); if (!v.has_value()) { return false; } Value field_value; if (!ValueFromValue(&field_value, *v, arena)) { return false; } (*fields)[std::string(key)] = field_value; } return true; } bool ValueFromValue(Value* json, const CelValue& value, google::protobuf::Arena* arena) { switch (value.type()) { case CelValue::Type::kBool: { bool val; if (value.GetValue(&val)) { json->set_bool_value(val); return true; } } break; case CelValue::Type::kBytes: { // Base64 encode byte strings to ensure they can safely be transported // in a JSON string. CelValue::BytesHolder val; if (value.GetValue(&val)) { json->set_string_value(absl::Base64Escape(val.value())); return true; } } break; case CelValue::Type::kDouble: { double val; if (value.GetValue(&val)) { json->set_number_value(val); return true; } } break; case CelValue::Type::kDuration: { // Convert duration values to a protobuf JSON format. absl::Duration val; if (value.GetValue(&val)) { auto encode = cel::internal::EncodeDurationToString(val); if (!encode.ok()) { return false; } json->set_string_value(*encode); return true; } } break; case CelValue::Type::kInt64: { int64_t val; // Convert int64_t values within the int53 range to doubles, otherwise // serialize the value to a string. 
if (value.GetValue(&val)) { if (IsJSONSafe(val)) { json->set_number_value(val); } else { json->set_string_value(absl::StrCat(val)); } return true; } } break; case CelValue::Type::kString: { CelValue::StringHolder val; if (value.GetValue(&val)) { json->set_string_value(val.value()); return true; } } break; case CelValue::Type::kTimestamp: { // Convert timestamp values to a protobuf JSON format. absl::Time val; if (value.GetValue(&val)) { auto encode = cel::internal::EncodeTimeToString(val); if (!encode.ok()) { return false; } json->set_string_value(*encode); return true; } } break; case CelValue::Type::kUint64: { uint64_t val; // Convert uint64_t values within the int53 range to doubles, otherwise // serialize the value to a string. if (value.GetValue(&val)) { if (IsJSONSafe(val)) { json->set_number_value(val); } else { json->set_string_value(absl::StrCat(val)); } return true; } } break; case CelValue::Type::kList: return ListFromValue(json->mutable_list_value(), value, arena); case CelValue::Type::kMap: return StructFromValue(json->mutable_struct_value(), value, arena); case CelValue::Type::kNullType: json->set_null_value(protobuf::NULL_VALUE); return true; default: return false; } return false; } google::protobuf::Message* AnyFromValue(const google::protobuf::Message* prototype, const CelValue& value, google::protobuf::Arena* arena) { std::string type_name; absl::Cord payload; // In open source, any->PackFrom() returns void rather than boolean. 
switch (value.type()) { case CelValue::Type::kBool: { BoolValue v; type_name = v.GetTypeName(); v.set_value(value.BoolOrDie()); payload = v.SerializeAsCord(); } break; case CelValue::Type::kBytes: { BytesValue v; type_name = v.GetTypeName(); v.set_value(std::string(value.BytesOrDie().value())); payload = v.SerializeAsCord(); } break; case CelValue::Type::kDouble: { DoubleValue v; type_name = v.GetTypeName(); v.set_value(value.DoubleOrDie()); payload = v.SerializeAsCord(); } break; case CelValue::Type::kDuration: { Duration v; if (!cel::internal::EncodeDuration(value.DurationOrDie(), &v).ok()) { return nullptr; } type_name = v.GetTypeName(); payload = v.SerializeAsCord(); } break; case CelValue::Type::kInt64: { Int64Value v; type_name = v.GetTypeName(); v.set_value(value.Int64OrDie()); payload = v.SerializeAsCord(); } break; case CelValue::Type::kString: { StringValue v; type_name = v.GetTypeName(); v.set_value(std::string(value.StringOrDie().value())); payload = v.SerializeAsCord(); } break; case CelValue::Type::kTimestamp: { Timestamp v; if (!cel::internal::EncodeTime(value.TimestampOrDie(), &v).ok()) { return nullptr; } type_name = v.GetTypeName(); payload = v.SerializeAsCord(); } break; case CelValue::Type::kUint64: { UInt64Value v; type_name = v.GetTypeName(); v.set_value(value.Uint64OrDie()); payload = v.SerializeAsCord(); } break; case CelValue::Type::kList: { ListValue v; if (!ListFromValue(&v, value, arena)) { return nullptr; } type_name = v.GetTypeName(); payload = v.SerializeAsCord(); } break; case CelValue::Type::kMap: { Struct v; if (!StructFromValue(&v, value, arena)) { return nullptr; } type_name = v.GetTypeName(); payload = v.SerializeAsCord(); } break; case CelValue::Type::kNullType: { Value v; type_name = v.GetTypeName(); v.set_null_value(google::protobuf::NULL_VALUE); payload = v.SerializeAsCord(); } break; case CelValue::Type::kMessage: { type_name = value.MessageWrapperOrDie().message_ptr()->GetTypeName(); payload = 
value.MessageWrapperOrDie().message_ptr()->SerializeAsCord(); } break; default: return nullptr; } auto* message = prototype->New(arena); CEL_ASSIGN_OR_RETURN( auto reflection, cel::well_known_types::GetAnyReflection(message->GetDescriptor()), _.With(IgnoreErrorAndReturnNullptr())); reflection.SetTypeUrl(message, absl::StrCat("type.googleapis.com/", type_name)); reflection.SetValue(message, payload); return message; } bool IsAlreadyWrapped(google::protobuf::Descriptor::WellKnownType wkt, const CelValue& value) { if (value.IsMessage()) { const auto* msg = value.MessageOrDie(); if (wkt == msg->GetDescriptor()->well_known_type()) { return true; } } return false; } // MessageFromValueMaker makes a specific protobuf Message instance based on // the desired protobuf type name and an input CelValue. // // It holds a registry of CelValue factories for specific subtypes of Message. // If message does not match any of types stored in registry, an the factory // returns an absent value. class MessageFromValueMaker { public: // Non-copyable, non-assignable MessageFromValueMaker(const MessageFromValueMaker&) = delete; MessageFromValueMaker& operator=(const MessageFromValueMaker&) = delete; static google::protobuf::Message* MaybeWrapMessage(const google::protobuf::Descriptor* descriptor, google::protobuf::MessageFactory* factory, const CelValue& value, Arena* arena) { switch (descriptor->well_known_type()) { case google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { return nullptr; } return DoubleFromValue(factory->GetPrototype(descriptor), value, arena); case google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE: if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { return nullptr; } return FloatFromValue(factory->GetPrototype(descriptor), value, arena); case google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE: if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { return nullptr; } return 
Int64FromValue(factory->GetPrototype(descriptor), value, arena); case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE: if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { return nullptr; } return UInt64FromValue(factory->GetPrototype(descriptor), value, arena); case google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE: if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { return nullptr; } return Int32FromValue(factory->GetPrototype(descriptor), value, arena); case google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE: if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { return nullptr; } return UInt32FromValue(factory->GetPrototype(descriptor), value, arena); case google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE: if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { return nullptr; } return StringFromValue(factory->GetPrototype(descriptor), value, arena); case google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE: if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { return nullptr; } return BytesFromValue(factory->GetPrototype(descriptor), value, arena); case google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE: if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { return nullptr; } return BoolFromValue(factory->GetPrototype(descriptor), value, arena); case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { return nullptr; } return AnyFromValue(factory->GetPrototype(descriptor), value, arena); case google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION: if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { return nullptr; } return DurationFromValue(factory->GetPrototype(descriptor), value, arena); case google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP: if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { return nullptr; } return TimestampFromValue(factory->GetPrototype(descriptor), value, arena); case 
google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { return nullptr; } return ValueFromValue(factory->GetPrototype(descriptor), value, arena); case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { return nullptr; } return ListFromValue(factory->GetPrototype(descriptor), value, arena); case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: if (IsAlreadyWrapped(descriptor->well_known_type(), value)) { return nullptr; } return StructFromValue(factory->GetPrototype(descriptor), value, arena); // WELLKNOWNTYPE_FIELDMASK has no special CelValue type default: return nullptr; } } }; } // namespace CelValue UnwrapMessageToValue(const google::protobuf::Message* value, const ProtobufValueFactory& factory, Arena* arena) { // Messages are Nullable types if (value == nullptr) { return CelValue::CreateNull(); } absl::optional special_value = ValueFromMessageMaker::CreateValue(value, factory, arena); if (special_value.has_value()) { return *special_value; } return factory(value); } const google::protobuf::Message* MaybeWrapValueToMessage( const google::protobuf::Descriptor* descriptor, google::protobuf::MessageFactory* factory, const CelValue& value, Arena* arena) { google::protobuf::Message* msg = MessageFromValueMaker::MaybeWrapMessage( descriptor, factory, value, arena); return msg; } } // namespace google::api::expr::runtime::internal ================================================ FILE: eval/public/structs/cel_proto_wrap_util.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_WRAP_UTIL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_WRAP_UTIL_H_ #include "eval/public/cel_value.h" #include "eval/public/structs/protobuf_value_factory.h" #include "google/protobuf/message.h" namespace google::api::expr::runtime::internal { // UnwrapValue creates CelValue from google::protobuf::Message. // As some of CEL basic types are subclassing google::protobuf::Message, // this method contains type checking and downcasts. CelValue UnwrapMessageToValue(const google::protobuf::Message* value, const ProtobufValueFactory& factory, google::protobuf::Arena* arena); // MaybeWrapValue attempts to wrap the input value in a proto message with // the given type_name. If the value can be wrapped, it is returned as a // protobuf message. Otherwise, the result will be nullptr. // // This method is the complement to MaybeUnwrapValue which may unwrap a protobuf // message to native CelValue representation during a protobuf field read. // Just as CreateMessage should only be used when reading protobuf values, // MaybeWrapValue should only be used when assigning protobuf fields. 
const google::protobuf::Message* MaybeWrapValueToMessage( const google::protobuf::Descriptor* descriptor, google::protobuf::MessageFactory* factory, const CelValue& value, google::protobuf::Arena* arena); } // namespace google::api::expr::runtime::internal #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_WRAP_UTIL_H_ ================================================ FILE: eval/public/structs/cel_proto_wrap_util_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/public/structs/cel_proto_wrap_util.h"

#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "google/protobuf/any.pb.h"
#include "google/protobuf/duration.pb.h"
#include "google/protobuf/empty.pb.h"
#include "google/protobuf/struct.pb.h"
#include "google/protobuf/wrappers.pb.h"
#include "absl/base/no_destructor.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/time/time.h"
#include "eval/public/cel_value.h"
#include "eval/public/containers/container_backed_list_impl.h"
#include "eval/public/containers/container_backed_map_impl.h"
#include "eval/public/message_wrapper.h"
#include "eval/public/structs/protobuf_value_factory.h"
#include "eval/public/structs/trivial_legacy_type_info.h"
#include "eval/testutil/test_message.pb.h"
#include "internal/proto_time_encoding.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "testutil/util.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime::internal {

namespace {

using ::testing::Eq;
using ::testing::UnorderedPointwise;

using google::protobuf::Duration;
using google::protobuf::ListValue;
using google::protobuf::Struct;
using google::protobuf::Timestamp;
using google::protobuf::Value;

using google::protobuf::Any;
using google::protobuf::BoolValue;
using google::protobuf::BytesValue;
using google::protobuf::DoubleValue;
using google::protobuf::FloatValue;
using google::protobuf::Int32Value;
using google::protobuf::Int64Value;
using google::protobuf::StringValue;
using google::protobuf::UInt32Value;
using google::protobuf::UInt64Value;

using google::protobuf::Arena;

// Wraps any non-well-known message as an opaque CEL message value.
CelValue ProtobufValueFactoryImpl(const google::protobuf::Message* m) {
  return CelValue::CreateMessageWrapper(
      CelValue::MessageWrapper(m, TrivialTypeInfo::GetInstance()));
}

class CelProtoWrapperTest : public ::testing::Test {
 protected:
  CelProtoWrapperTest() {}

  // Checks that `value` wraps to a message equal to `message`, both for the
  // generated type and for a dynamic copy of it, and that an already-wrapped
  // message is not wrapped a second time.
  void ExpectWrappedMessage(const CelValue& value,
                            const google::protobuf::Message& message) {
    // Test the input value wraps to the destination message type.
    auto* result = MaybeWrapValueToMessage(
        message.GetDescriptor(), message.GetReflection()->GetMessageFactory(),
        value, arena());
    // ASSERT (not EXPECT): the matcher below dereferences `result`.
    ASSERT_NE(result, nullptr);
    EXPECT_THAT(result, testutil::EqualsProto(message));

    // Ensure that double wrapping results in the object being wrapped once.
    auto* identity = MaybeWrapValueToMessage(
        message.GetDescriptor(), message.GetReflection()->GetMessageFactory(),
        ProtobufValueFactoryImpl(result), arena());
    EXPECT_EQ(identity, nullptr);

    // Check to make sure that even dynamic messages can be used as input to
    // the wrapping call. Build the dynamic copy once and reuse it.
    std::unique_ptr<google::protobuf::Message> dynamic_message =
        ReflectedCopy(message);
    result = MaybeWrapValueToMessage(
        dynamic_message->GetDescriptor(),
        dynamic_message->GetReflection()->GetMessageFactory(), value, arena());
    ASSERT_NE(result, nullptr);
    EXPECT_THAT(result, testutil::EqualsProto(message));
  }

  // Checks that `value` cannot be wrapped into `message`'s type.
  void ExpectNotWrapped(const CelValue& value, const google::protobuf::Message& message) {
    // The wrap attempt must fail, i.e. return nullptr.
    auto result = MaybeWrapValueToMessage(
        message.GetDescriptor(), message.GetReflection()->GetMessageFactory(),
        value, arena());
    EXPECT_TRUE(result == nullptr);
  }

  // Checks that `message` (and a dynamic copy of it) unwraps to a primitive
  // CEL value equal to `result`.
  template <class T>
  void ExpectUnwrappedPrimitive(const google::protobuf::Message& message, T result) {
    CelValue cel_value =
        UnwrapMessageToValue(&message, &ProtobufValueFactoryImpl, arena());
    T value;
    // ASSERT: on failure `value` would be compared while unset.
    ASSERT_TRUE(cel_value.GetValue(&value));
    EXPECT_THAT(value, Eq(result));

    T dyn_value;
    CelValue cel_dyn_value = UnwrapMessageToValue(
        ReflectedCopy(message).get(), &ProtobufValueFactoryImpl, arena());
    EXPECT_THAT(cel_dyn_value.type(), Eq(cel_value.type()));
    ASSERT_TRUE(cel_dyn_value.GetValue(&dyn_value));
    EXPECT_THAT(value, Eq(dyn_value));
  }

  // Checks that `message` unwraps to a CEL message equal to `*result`, or to
  // CEL null when `result` is nullptr.
  void ExpectUnwrappedMessage(const google::protobuf::Message& message,
                              google::protobuf::Message* result) {
    CelValue cel_value =
        UnwrapMessageToValue(&message, &ProtobufValueFactoryImpl, arena());
    if (result == nullptr) {
      EXPECT_TRUE(cel_value.IsNull());
      return;
    }
    // ASSERT: MessageOrDie() would abort on a non-message value.
    ASSERT_TRUE(cel_value.IsMessage());
    EXPECT_THAT(cel_value.MessageOrDie(), testutil::EqualsProto(*result));
  }

  // Returns a DynamicMessage copy of `message` built by `factory_`.
  std::unique_ptr<google::protobuf::Message> ReflectedCopy(
      const google::protobuf::Message& message) {
    std::unique_ptr<google::protobuf::Message> dynamic_value(
        factory_.GetPrototype(message.GetDescriptor())->New());
    dynamic_value->CopyFrom(message);
    return dynamic_value;
  }

  Arena* arena() { return &arena_; }

 private:
  // Declared before `arena_` so it is destroyed after it: messages created
  // from this factory's prototypes may live on `arena_`, and such messages
  // must be destroyed before the DynamicMessageFactory.
  google::protobuf::DynamicMessageFactory factory_;
  Arena arena_;
};

TEST_F(CelProtoWrapperTest, TestType) {
  Duration msg_duration;
  msg_duration.set_seconds(2);
  msg_duration.set_nanos(3);

  CelValue value_duration2 =
      UnwrapMessageToValue(&msg_duration, &ProtobufValueFactoryImpl, arena());
  EXPECT_THAT(value_duration2.type(), Eq(CelValue::Type::kDuration));

  Timestamp msg_timestamp;
  msg_timestamp.set_seconds(2);
  msg_timestamp.set_nanos(3);

  CelValue value_timestamp2 =
      UnwrapMessageToValue(&msg_timestamp, &ProtobufValueFactoryImpl, arena());
  EXPECT_THAT(value_timestamp2.type(), Eq(CelValue::Type::kTimestamp));
}

// This test verifies CelValue support of Duration type.
TEST_F(CelProtoWrapperTest, TestDuration) {
  Duration msg_duration;
  msg_duration.set_seconds(2);
  msg_duration.set_nanos(3);
  CelValue value =
      UnwrapMessageToValue(&msg_duration, &ProtobufValueFactoryImpl, arena());
  // ASSERT: DurationOrDie() would abort on a non-duration value.
  ASSERT_THAT(value.type(), Eq(CelValue::Type::kDuration));

  Duration out;
  auto status = cel::internal::EncodeDuration(value.DurationOrDie(), &out);
  ASSERT_OK(status);
  EXPECT_THAT(out, testutil::EqualsProto(msg_duration));
}

// This test verifies CelValue support of Timestamp type.
TEST_F(CelProtoWrapperTest, TestTimestamp) {
  Timestamp msg_timestamp;
  msg_timestamp.set_seconds(2);
  msg_timestamp.set_nanos(3);

  CelValue value =
      UnwrapMessageToValue(&msg_timestamp, &ProtobufValueFactoryImpl, arena());

  // ASSERT: TimestampOrDie() would abort on a non-timestamp value.
  ASSERT_TRUE(value.IsTimestamp());
  Timestamp out;
  auto status = cel::internal::EncodeTime(value.TimestampOrDie(), &out);
  ASSERT_OK(status);
  EXPECT_THAT(out, testutil::EqualsProto(msg_timestamp));
}

// Dynamic Values tests

TEST_F(CelProtoWrapperTest, UnwrapMessageToValueNull) {
  Value json;
  json.set_null_value(google::protobuf::NullValue::NULL_VALUE);
  ExpectUnwrappedMessage(json, nullptr);
}

// Test support for unwrapping a google::protobuf::Value to a CEL value.
TEST_F(CelProtoWrapperTest, UnwrapDynamicValueNull) {
  Value value_msg;
  value_msg.set_null_value(protobuf::NULL_VALUE);

  CelValue value = UnwrapMessageToValue(ReflectedCopy(value_msg).get(),
                                        &ProtobufValueFactoryImpl, arena());
  EXPECT_TRUE(value.IsNull());
}

TEST_F(CelProtoWrapperTest, UnwrapMessageToValueBool) {
  bool value = true;

  Value json;
  json.set_bool_value(true);
  ExpectUnwrappedPrimitive(json, value);
}

TEST_F(CelProtoWrapperTest, UnwrapMessageToValueNumber) {
  double value = 1.0;

  Value json;
  json.set_number_value(value);
  ExpectUnwrappedPrimitive(json, value);
}

TEST_F(CelProtoWrapperTest, UnwrapMessageToValueString) {
  const std::string test = "test";
  auto value = CelValue::StringHolder(&test);

  Value json;
  json.set_string_value(test);
  ExpectUnwrappedPrimitive(json, value);
}

TEST_F(CelProtoWrapperTest, UnwrapMessageToValueStruct) {
  const std::vector<std::string> kFields = {"field1", "field2", "field3"};
  Struct value_struct;

  auto& value1 = (*value_struct.mutable_fields())[kFields[0]];
  value1.set_bool_value(true);

  auto& value2 = (*value_struct.mutable_fields())[kFields[1]];
  value2.set_number_value(1.0);

  auto& value3 = (*value_struct.mutable_fields())[kFields[2]];
  value3.set_string_value("test");

  CelValue value =
      UnwrapMessageToValue(&value_struct, &ProtobufValueFactoryImpl, arena());
  ASSERT_TRUE(value.IsMap());

  const CelMap* cel_map = value.MapOrDie();

  CelValue field1 = CelValue::CreateString(&kFields[0]);
  auto field1_presence = cel_map->Has(field1);
  ASSERT_OK(field1_presence);
  EXPECT_TRUE(*field1_presence);
  auto lookup1 = (*cel_map)[field1];
  ASSERT_TRUE(lookup1.has_value());
  ASSERT_TRUE(lookup1->IsBool());
  EXPECT_EQ(lookup1->BoolOrDie(), true);

  CelValue field2 = CelValue::CreateString(&kFields[1]);
  auto field2_presence = cel_map->Has(field2);
  ASSERT_OK(field2_presence);
  EXPECT_TRUE(*field2_presence);
  auto lookup2 = (*cel_map)[field2];
  ASSERT_TRUE(lookup2.has_value());
  ASSERT_TRUE(lookup2->IsDouble());
  EXPECT_DOUBLE_EQ(lookup2->DoubleOrDie(), 1.0);

  CelValue field3 = CelValue::CreateString(&kFields[2]);
  auto field3_presence = cel_map->Has(field3);
  ASSERT_OK(field3_presence);
  EXPECT_TRUE(*field3_presence);
  auto lookup3 = (*cel_map)[field3];
  ASSERT_TRUE(lookup3.has_value());
  ASSERT_TRUE(lookup3->IsString());
  EXPECT_EQ(lookup3->StringOrDie().value(), "test");

  std::string missing = "missing_field";
  CelValue missing_field = CelValue::CreateString(&missing);
  auto missing_field_presence = cel_map->Has(missing_field);
  ASSERT_OK(missing_field_presence);
  EXPECT_FALSE(*missing_field_presence);

  const CelList* key_list = cel_map->ListKeys().value();
  ASSERT_EQ(key_list->size(), kFields.size());

  std::vector<std::string> result_keys;
  for (int i = 0; i < key_list->size(); i++) {
    CelValue key = (*key_list)[i];
    ASSERT_TRUE(key.IsString());
    result_keys.push_back(std::string(key.StringOrDie().value()));
  }

  EXPECT_THAT(result_keys, UnorderedPointwise(Eq(), kFields));
}

// Test support for google::protobuf::Struct when it is created as dynamic
// message
TEST_F(CelProtoWrapperTest, UnwrapDynamicStruct) {
  Struct struct_msg;
  const std::string kFieldInt = "field_int";
  const std::string kFieldBool = "field_bool";
  (*struct_msg.mutable_fields())[kFieldInt].set_number_value(1.);
  (*struct_msg.mutable_fields())[kFieldBool].set_bool_value(true);
  CelValue value = UnwrapMessageToValue(ReflectedCopy(struct_msg).get(),
                                        &ProtobufValueFactoryImpl, arena());
  // ASSERT: MapOrDie() would abort on a non-map value.
  ASSERT_TRUE(value.IsMap());
  const CelMap* cel_map = value.MapOrDie();
  ASSERT_TRUE(cel_map != nullptr);

  {
    auto lookup = (*cel_map)[CelValue::CreateString(&kFieldInt)];
    ASSERT_TRUE(lookup.has_value());
    auto v = lookup.value();
    ASSERT_TRUE(v.IsDouble());
    EXPECT_THAT(v.DoubleOrDie(), testing::DoubleEq(1.));
  }
  {
    auto lookup = (*cel_map)[CelValue::CreateString(&kFieldBool)];
    ASSERT_TRUE(lookup.has_value());
    auto v = lookup.value();
    ASSERT_TRUE(v.IsBool());
    EXPECT_EQ(v.BoolOrDie(), true);
  }
  {
    auto presence = cel_map->Has(CelValue::CreateBool(true));
    ASSERT_FALSE(presence.ok());
    EXPECT_EQ(presence.status().code(), absl::StatusCode::kInvalidArgument);
    auto lookup = (*cel_map)[CelValue::CreateBool(true)];
    ASSERT_TRUE(lookup.has_value());
    auto v = lookup.value();
    ASSERT_TRUE(v.IsError());
  }
}

TEST_F(CelProtoWrapperTest, UnwrapDynamicValueStruct) {
  const std::string kField1 = "field1";
  const std::string kField2 = "field2";
  Value value_msg;
  (*value_msg.mutable_struct_value()->mutable_fields())[kField1]
      .set_number_value(1);
  (*value_msg.mutable_struct_value()->mutable_fields())[kField2]
      .set_number_value(2);

  CelValue value = UnwrapMessageToValue(ReflectedCopy(value_msg).get(),
                                        &ProtobufValueFactoryImpl, arena());
  // ASSERT: MapOrDie() would abort on a non-map value.
  ASSERT_TRUE(value.IsMap());
  EXPECT_TRUE(
      (*value.MapOrDie())[CelValue::CreateString(&kField1)].has_value());
  EXPECT_TRUE(
      (*value.MapOrDie())[CelValue::CreateString(&kField2)].has_value());
}

TEST_F(CelProtoWrapperTest, UnwrapMessageToValueList) {
  ListValue list_value;

  list_value.add_values()->set_bool_value(true);
  list_value.add_values()->set_number_value(1.0);
  list_value.add_values()->set_string_value("test");

  CelValue value =
      UnwrapMessageToValue(&list_value, &ProtobufValueFactoryImpl, arena());
  ASSERT_TRUE(value.IsList());

  const CelList* cel_list = value.ListOrDie();

  ASSERT_EQ(cel_list->size(), 3);

  CelValue value1 = (*cel_list)[0];
  ASSERT_TRUE(value1.IsBool());
  EXPECT_EQ(value1.BoolOrDie(), true);

  auto value2 = (*cel_list)[1];
  ASSERT_TRUE(value2.IsDouble());
  EXPECT_DOUBLE_EQ(value2.DoubleOrDie(), 1.0);

  auto value3 = (*cel_list)[2];
  ASSERT_TRUE(value3.IsString());
  EXPECT_EQ(value3.StringOrDie().value(), "test");
}

TEST_F(CelProtoWrapperTest, UnwrapDynamicValueListValue) {
  Value value_msg;
  value_msg.mutable_list_value()->add_values()->set_number_value(1.);
  value_msg.mutable_list_value()->add_values()->set_number_value(2.);

  CelValue value = UnwrapMessageToValue(ReflectedCopy(value_msg).get(),
                                        &ProtobufValueFactoryImpl, arena());
  // ASSERT: ListOrDie() would abort on a non-list value, and the indexing
  // below requires two elements.
  ASSERT_TRUE(value.IsList());
  ASSERT_EQ(value.ListOrDie()->size(), 2);
  EXPECT_THAT((*value.ListOrDie())[0].DoubleOrDie(), testing::DoubleEq(1));
  EXPECT_THAT((*value.ListOrDie())[1].DoubleOrDie(), testing::DoubleEq(2));
}

// Test support of google.protobuf.Any in CelValue.
TEST_F(CelProtoWrapperTest, UnwrapAnyValue) {
  TestMessage test_message;
  test_message.set_string_value("test");

  Any any;
  any.PackFrom(test_message);
  ExpectUnwrappedMessage(any, &test_message);
}

TEST_F(CelProtoWrapperTest, UnwrapInvalidAny) {
  Any any;
  CelValue value =
      UnwrapMessageToValue(&any, &ProtobufValueFactoryImpl, arena());
  ASSERT_TRUE(value.IsError());

  any.set_type_url("/");
  ASSERT_TRUE(
      UnwrapMessageToValue(&any, &ProtobufValueFactoryImpl, arena()).IsError());

  any.set_type_url("/invalid.proto.name");
  ASSERT_TRUE(
      UnwrapMessageToValue(&any, &ProtobufValueFactoryImpl, arena()).IsError());
}

// Test support of google.protobuf.Value wrappers in CelValue.
TEST_F(CelProtoWrapperTest, UnwrapBoolWrapper) {
  bool value = true;

  BoolValue wrapper;
  wrapper.set_value(value);
  ExpectUnwrappedPrimitive(wrapper, value);
}

TEST_F(CelProtoWrapperTest, UnwrapInt32Wrapper) {
  int64_t value = 12;

  Int32Value wrapper;
  wrapper.set_value(value);
  ExpectUnwrappedPrimitive(wrapper, value);
}

TEST_F(CelProtoWrapperTest, UnwrapUInt32Wrapper) {
  uint64_t value = 12;

  UInt32Value wrapper;
  wrapper.set_value(value);
  ExpectUnwrappedPrimitive(wrapper, value);
}

TEST_F(CelProtoWrapperTest, UnwrapInt64Wrapper) {
  int64_t value = 12;

  Int64Value wrapper;
  wrapper.set_value(value);
  ExpectUnwrappedPrimitive(wrapper, value);
}

TEST_F(CelProtoWrapperTest, UnwrapUInt64Wrapper) {
  uint64_t value = 12;

  UInt64Value wrapper;
  wrapper.set_value(value);
  ExpectUnwrappedPrimitive(wrapper, value);
}

TEST_F(CelProtoWrapperTest, UnwrapFloatWrapper) {
  double value = 42.5;

  FloatValue wrapper;
  wrapper.set_value(value);
  ExpectUnwrappedPrimitive(wrapper, value);
}

TEST_F(CelProtoWrapperTest, UnwrapDoubleWrapper) {
  double value = 42.5;

  DoubleValue wrapper;
  wrapper.set_value(value);
  ExpectUnwrappedPrimitive(wrapper, value);
}

TEST_F(CelProtoWrapperTest, UnwrapStringWrapper) {
std::string text = "42"; auto value = CelValue::StringHolder(&text); StringValue wrapper; wrapper.set_value(text); ExpectUnwrappedPrimitive(wrapper, value); } TEST_F(CelProtoWrapperTest, UnwrapBytesWrapper) { std::string text = "42"; auto value = CelValue::BytesHolder(&text); BytesValue wrapper; wrapper.set_value("42"); ExpectUnwrappedPrimitive(wrapper, value); } TEST_F(CelProtoWrapperTest, WrapNull) { auto cel_value = CelValue::CreateNull(); Value json; json.set_null_value(protobuf::NULL_VALUE); ExpectWrappedMessage(cel_value, json); Any any; any.PackFrom(json); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapBool) { auto cel_value = CelValue::CreateBool(true); Value json; json.set_bool_value(true); ExpectWrappedMessage(cel_value, json); BoolValue wrapper; wrapper.set_value(true); ExpectWrappedMessage(cel_value, wrapper); Any any; any.PackFrom(wrapper); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapBytes) { std::string str = "hello world"; auto cel_value = CelValue::CreateBytes(CelValue::BytesHolder(&str)); BytesValue wrapper; wrapper.set_value(str); ExpectWrappedMessage(cel_value, wrapper); Any any; any.PackFrom(wrapper); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapBytesToValue) { std::string str = "hello world"; auto cel_value = CelValue::CreateBytes(CelValue::BytesHolder(&str)); Value json; json.set_string_value("aGVsbG8gd29ybGQ="); ExpectWrappedMessage(cel_value, json); } TEST_F(CelProtoWrapperTest, WrapDuration) { auto cel_value = CelValue::CreateDuration(absl::Seconds(300)); Duration d; d.set_seconds(300); ExpectWrappedMessage(cel_value, d); Any any; any.PackFrom(d); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapDurationToValue) { auto cel_value = CelValue::CreateDuration(absl::Seconds(300)); Value json; json.set_string_value("300s"); ExpectWrappedMessage(cel_value, json); } TEST_F(CelProtoWrapperTest, WrapDouble) { double num = 1.5; auto cel_value = 
CelValue::CreateDouble(num); Value json; json.set_number_value(num); ExpectWrappedMessage(cel_value, json); DoubleValue wrapper; wrapper.set_value(num); ExpectWrappedMessage(cel_value, wrapper); Any any; any.PackFrom(wrapper); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapDoubleToFloatValue) { double num = 1.5; auto cel_value = CelValue::CreateDouble(num); FloatValue wrapper; wrapper.set_value(num); ExpectWrappedMessage(cel_value, wrapper); // Imprecise double -> float representation results in truncation. double small_num = -9.9e-100; wrapper.set_value(small_num); cel_value = CelValue::CreateDouble(small_num); ExpectWrappedMessage(cel_value, wrapper); } TEST_F(CelProtoWrapperTest, WrapDoubleOverflow) { double lowest_double = std::numeric_limits::lowest(); auto cel_value = CelValue::CreateDouble(lowest_double); // Double exceeds float precision, overflow to -infinity. FloatValue wrapper; wrapper.set_value(-std::numeric_limits::infinity()); ExpectWrappedMessage(cel_value, wrapper); double max_double = std::numeric_limits::max(); cel_value = CelValue::CreateDouble(max_double); wrapper.set_value(std::numeric_limits::infinity()); ExpectWrappedMessage(cel_value, wrapper); } TEST_F(CelProtoWrapperTest, WrapInt64) { int32_t num = std::numeric_limits::lowest(); auto cel_value = CelValue::CreateInt64(num); Value json; json.set_number_value(static_cast(num)); ExpectWrappedMessage(cel_value, json); Int64Value wrapper; wrapper.set_value(num); ExpectWrappedMessage(cel_value, wrapper); Any any; any.PackFrom(wrapper); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapInt64ToInt32Value) { int32_t num = std::numeric_limits::lowest(); auto cel_value = CelValue::CreateInt64(num); Int32Value wrapper; wrapper.set_value(num); ExpectWrappedMessage(cel_value, wrapper); } TEST_F(CelProtoWrapperTest, WrapFailureInt64ToInt32Value) { int64_t num = std::numeric_limits::lowest(); auto cel_value = CelValue::CreateInt64(num); Int32Value wrapper; 
ExpectNotWrapped(cel_value, wrapper); } TEST_F(CelProtoWrapperTest, WrapInt64ToValue) { int64_t max = std::numeric_limits::max(); auto cel_value = CelValue::CreateInt64(max); Value json; json.set_string_value(absl::StrCat(max)); ExpectWrappedMessage(cel_value, json); int64_t min = std::numeric_limits::min(); cel_value = CelValue::CreateInt64(min); json.set_string_value(absl::StrCat(min)); ExpectWrappedMessage(cel_value, json); } TEST_F(CelProtoWrapperTest, WrapUint64) { uint32_t num = std::numeric_limits::max(); auto cel_value = CelValue::CreateUint64(num); Value json; json.set_number_value(static_cast(num)); ExpectWrappedMessage(cel_value, json); UInt64Value wrapper; wrapper.set_value(num); ExpectWrappedMessage(cel_value, wrapper); Any any; any.PackFrom(wrapper); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapUint64ToUint32Value) { uint32_t num = std::numeric_limits::max(); auto cel_value = CelValue::CreateUint64(num); UInt32Value wrapper; wrapper.set_value(num); ExpectWrappedMessage(cel_value, wrapper); } TEST_F(CelProtoWrapperTest, WrapUint64ToValue) { uint64_t num = std::numeric_limits::max(); auto cel_value = CelValue::CreateUint64(num); Value json; json.set_string_value(absl::StrCat(num)); ExpectWrappedMessage(cel_value, json); } TEST_F(CelProtoWrapperTest, WrapFailureUint64ToUint32Value) { uint64_t num = std::numeric_limits::max(); auto cel_value = CelValue::CreateUint64(num); UInt32Value wrapper; ExpectNotWrapped(cel_value, wrapper); } TEST_F(CelProtoWrapperTest, WrapString) { std::string str = "test"; auto cel_value = CelValue::CreateString(CelValue::StringHolder(&str)); Value json; json.set_string_value(str); ExpectWrappedMessage(cel_value, json); StringValue wrapper; wrapper.set_value(str); ExpectWrappedMessage(cel_value, wrapper); Any any; any.PackFrom(wrapper); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapTimestamp) { absl::Time ts = absl::FromUnixSeconds(1615852799); auto cel_value = 
CelValue::CreateTimestamp(ts); Timestamp t; t.set_seconds(1615852799); ExpectWrappedMessage(cel_value, t); Any any; any.PackFrom(t); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapTimestampToValue) { absl::Time ts = absl::FromUnixSeconds(1615852799); auto cel_value = CelValue::CreateTimestamp(ts); Value json; json.set_string_value("2021-03-15T23:59:59Z"); ExpectWrappedMessage(cel_value, json); } TEST_F(CelProtoWrapperTest, WrapList) { std::vector list_elems = { CelValue::CreateDouble(1.5), CelValue::CreateInt64(-2L), }; ContainerBackedListImpl list(std::move(list_elems)); auto cel_value = CelValue::CreateList(&list); Value json; json.mutable_list_value()->add_values()->set_number_value(1.5); json.mutable_list_value()->add_values()->set_number_value(-2.); ExpectWrappedMessage(cel_value, json); ExpectWrappedMessage(cel_value, json.list_value()); Any any; any.PackFrom(json.list_value()); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapFailureListValueBadJSON) { TestMessage message; std::vector list_elems = { CelValue::CreateDouble(1.5), UnwrapMessageToValue(&message, &ProtobufValueFactoryImpl, arena()), }; ContainerBackedListImpl list(std::move(list_elems)); auto cel_value = CelValue::CreateList(&list); Value json; ExpectNotWrapped(cel_value, json); } TEST_F(CelProtoWrapperTest, WrapStruct) { const std::string kField1 = "field1"; std::vector> args = { {CelValue::CreateString(CelValue::StringHolder(&kField1)), CelValue::CreateBool(true)}}; auto cel_map = CreateContainerBackedMap( absl::Span>(args.data(), args.size())) .value(); auto cel_value = CelValue::CreateMap(cel_map.get()); Value json; (*json.mutable_struct_value()->mutable_fields())[kField1].set_bool_value( true); ExpectWrappedMessage(cel_value, json); ExpectWrappedMessage(cel_value, json.struct_value()); Any any; any.PackFrom(json.struct_value()); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapFailureStructBadKeyType) { std::vector> args 
= { {CelValue::CreateInt64(1L), CelValue::CreateBool(true)}}; auto cel_map = CreateContainerBackedMap( absl::Span>(args.data(), args.size())) .value(); auto cel_value = CelValue::CreateMap(cel_map.get()); Value json; ExpectNotWrapped(cel_value, json); } TEST_F(CelProtoWrapperTest, WrapFailureStructBadValueType) { const std::string kField1 = "field1"; TestMessage bad_value; std::vector> args = { {CelValue::CreateString(CelValue::StringHolder(&kField1)), UnwrapMessageToValue(&bad_value, &ProtobufValueFactoryImpl, arena())}}; auto cel_map = CreateContainerBackedMap( absl::Span>(args.data(), args.size())) .value(); auto cel_value = CelValue::CreateMap(cel_map.get()); Value json; ExpectNotWrapped(cel_value, json); } class TestMap : public CelMapBuilder { public: absl::StatusOr ListKeys() const override { return absl::UnimplementedError("test"); } }; TEST_F(CelProtoWrapperTest, WrapFailureStructListKeysUnimplemented) { const std::string kField1 = "field1"; TestMap map; ASSERT_OK(map.Add(CelValue::CreateString(CelValue::StringHolder(&kField1)), CelValue::CreateString(CelValue::StringHolder(&kField1)))); auto cel_value = CelValue::CreateMap(&map); Value json; ExpectNotWrapped(cel_value, json); } TEST_F(CelProtoWrapperTest, WrapFailureWrongType) { auto cel_value = CelValue::CreateNull(); std::vector wrong_types = { &BoolValue::default_instance(), &BytesValue::default_instance(), &DoubleValue::default_instance(), &Duration::default_instance(), &FloatValue::default_instance(), &Int32Value::default_instance(), &Int64Value::default_instance(), &ListValue::default_instance(), &StringValue::default_instance(), &Struct::default_instance(), &Timestamp::default_instance(), &UInt32Value::default_instance(), &UInt64Value::default_instance(), }; for (const auto* wrong_type : wrong_types) { ExpectNotWrapped(cel_value, *wrong_type); } } TEST_F(CelProtoWrapperTest, WrapFailureErrorToAny) { auto cel_value = CreateNoSuchFieldError(arena(), "error_field"); ExpectNotWrapped(cel_value, 
Any::default_instance()); } TEST_F(CelProtoWrapperTest, DebugString) { google::protobuf::Empty e; // Note: the value factory is trivial so the debug string for a message-typed // value is uninteresting. EXPECT_EQ(UnwrapMessageToValue(&e, &ProtobufValueFactoryImpl, arena()) .DebugString(), "Message: opaque"); ListValue list_value; list_value.add_values()->set_bool_value(true); list_value.add_values()->set_number_value(1.0); list_value.add_values()->set_string_value("test"); CelValue value = UnwrapMessageToValue(&list_value, &ProtobufValueFactoryImpl, arena()); EXPECT_EQ(value.DebugString(), "CelList: [bool: 1, double: 1.000000, string: test]"); Struct value_struct; auto& value1 = (*value_struct.mutable_fields())["a"]; value1.set_bool_value(true); auto& value2 = (*value_struct.mutable_fields())["b"]; value2.set_number_value(1.0); auto& value3 = (*value_struct.mutable_fields())["c"]; value3.set_string_value("test"); value = UnwrapMessageToValue(&value_struct, &ProtobufValueFactoryImpl, arena()); EXPECT_THAT( value.DebugString(), testing::AllOf(testing::StartsWith("CelMap: {"), testing::HasSubstr(": "), testing::HasSubstr(": : "))); } } // namespace } // namespace google::api::expr::runtime::internal ================================================ FILE: eval/public/structs/cel_proto_wrapper.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/public/structs/cel_proto_wrapper.h"

#include "absl/types/optional.h"
#include "eval/public/cel_value.h"
#include "eval/public/message_wrapper.h"
#include "eval/public/structs/cel_proto_wrap_util.h"
#include "eval/public/structs/proto_message_type_adapter.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {

namespace {

using ::google::protobuf::Arena;
using ::google::protobuf::Descriptor;
using ::google::protobuf::Message;

}  // namespace

// Wraps `message` as a message-typed CelValue using the generic proto type
// info, without any unwrapping of well-known types.
CelValue CelProtoWrapper::InternalWrapMessage(const Message* message) {
  return CelValue::CreateMessageWrapper(
      MessageWrapper(message, &GetGenericProtoTypeInfoInstance()));
}

// CreateMessage creates a CelValue from a google::protobuf::Message.
// Because some CEL basic types are represented by google::protobuf::Message
// subclasses, this delegates to internal::UnwrapMessageToValue, which checks
// the message type and unwraps such messages to native CelValues; any other
// message is wrapped via InternalWrapMessage.
CelValue CelProtoWrapper::CreateMessage(const Message* value, Arena* arena) {
  return internal::UnwrapMessageToValue(value, &InternalWrapMessage, arena);
}

// Returns the wrapped message as a CelValue, or nullopt when
// internal::MaybeWrapValueToMessage cannot wrap `value` into `descriptor`.
absl::optional<CelValue> CelProtoWrapper::MaybeWrapValue(
    const Descriptor* descriptor, google::protobuf::MessageFactory* factory,
    const CelValue& value, Arena* arena) {
  const Message* msg =
      internal::MaybeWrapValueToMessage(descriptor, factory, value, arena);
  if (msg == nullptr) {
    return absl::nullopt;
  }
  return InternalWrapMessage(msg);
}

}  // namespace google::api::expr::runtime
================================================
FILE: eval/public/structs/cel_proto_wrapper.h
================================================
#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_WRAPPER_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_WRAPPER_H_

#include "google/protobuf/duration.pb.h"
#include "google/protobuf/timestamp.pb.h"
#include "absl/types/optional.h"
#include "eval/public/cel_value.h"
#include "internal/proto_time_encoding.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {

class CelProtoWrapper {
 public:
  // CreateMessage creates a CelValue from a google::protobuf::Message.
  // Because some CEL basic types are represented by google::protobuf::Message
  // subclasses, this method performs type checking and unwraps such messages
  // to their native CelValue representation.
  static CelValue CreateMessage(const google::protobuf::Message* value,
                                google::protobuf::Arena* arena);

  // Internal utility for creating a CelValue wrapping a user defined type.
  // Assumes that the message has been properly unpacked.
  static CelValue InternalWrapMessage(const google::protobuf::Message* message);

  // CreateDuration creates a CelValue from a non-null protobuf duration value.
  static CelValue CreateDuration(const google::protobuf::Duration* value) {
    return CelValue(cel::internal::DecodeDuration(*value));
  }

  // CreateTimestamp creates a CelValue from a non-null protobuf timestamp
  // value.
  static CelValue CreateTimestamp(const google::protobuf::Timestamp* value) {
    return CelValue(cel::internal::DecodeTime(*value));
  }

  // MaybeWrapValue attempts to wrap the input value in a proto message of the
  // type described by `descriptor`. If the value can be wrapped, it is
  // returned as a CelValue pointing to the protobuf message. Otherwise, the
  // result will be empty.
  //
  // This method is the complement to CreateMessage which may unwrap a protobuf
  // message to native CelValue representation during a protobuf field read.
  // Just as CreateMessage should only be used when reading protobuf values,
  // MaybeWrapValue should only be used when assigning protobuf fields.
  static absl::optional<CelValue> MaybeWrapValue(
      const google::protobuf::Descriptor* descriptor,
      google::protobuf::MessageFactory* factory, const CelValue& value,
      google::protobuf::Arena* arena);
};

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_CEL_PROTO_WRAPPER_H_
================================================
FILE: eval/public/structs/cel_proto_wrapper_test.cc
================================================
#include "eval/public/structs/cel_proto_wrapper.h"

#include <cassert>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "google/protobuf/any.pb.h"
#include "google/protobuf/duration.pb.h"
#include "google/protobuf/empty.pb.h"
#include "google/protobuf/struct.pb.h"
#include "google/protobuf/wrappers.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/time/time.h"
#include "eval/public/cel_value.h"
#include "eval/public/containers/container_backed_list_impl.h"
#include "eval/public/containers/container_backed_map_impl.h"
#include "eval/testutil/test_message.pb.h"
#include "internal/proto_time_encoding.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "testutil/util.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {

namespace {

using ::testing::Eq;
using ::testing::UnorderedPointwise;

using google::protobuf::Duration;
using google::protobuf::ListValue;
using google::protobuf::Struct;
using google::protobuf::Timestamp;
using google::protobuf::Value;

using google::protobuf::Any;
using google::protobuf::BoolValue;
using google::protobuf::BytesValue;
using google::protobuf::DoubleValue;
using google::protobuf::FloatValue;
using google::protobuf::Int32Value;
using google::protobuf::Int64Value;
using google::protobuf::StringValue;
using google::protobuf::UInt32Value;
using google::protobuf::UInt64Value;

using google::protobuf::Arena;

class CelProtoWrapperTest : public ::testing::Test {
 protected:
  // Checks that `value` wraps into a message equal to `message`, that an
  // already-wrapped message is not wrapped a second time, and that wrapping
  // also works when the target type comes from a dynamic message.
  void ExpectWrappedMessage(const CelValue& value,
                            const google::protobuf::Message& message) {
    // Test the input value wraps to the destination message type.
    auto result = CelProtoWrapper::MaybeWrapValue(
        message.GetDescriptor(), message.GetReflection()->GetMessageFactory(),
        value, arena());
    // ASSERT (not EXPECT): dereferencing an empty optional or calling
    // MessageOrDie() on a non-message below would be undefined / abort.
    ASSERT_TRUE(result.has_value());
    ASSERT_TRUE(result->IsMessage());
    EXPECT_THAT(result->MessageOrDie(), testutil::EqualsProto(message));

    // Ensure that double wrapping results in the object being wrapped once.
    auto identity = CelProtoWrapper::MaybeWrapValue(
        message.GetDescriptor(), message.GetReflection()->GetMessageFactory(),
        *result, arena());
    EXPECT_FALSE(identity.has_value());

    // Check to make sure that even dynamic messages can be used as input to
    // the wrapping call.
    std::unique_ptr<google::protobuf::Message> dynamic_copy =
        ReflectedCopy(message);
    result = CelProtoWrapper::MaybeWrapValue(
        dynamic_copy->GetDescriptor(),
        dynamic_copy->GetReflection()->GetMessageFactory(), value, arena());
    ASSERT_TRUE(result.has_value());
    ASSERT_TRUE(result->IsMessage());
    EXPECT_THAT(result->MessageOrDie(), testutil::EqualsProto(message));
  }

  // Checks that `value` cannot be wrapped into the type of `message`.
  void ExpectNotWrapped(const CelValue& value,
                        const google::protobuf::Message& message) {
    auto result = CelProtoWrapper::MaybeWrapValue(
        message.GetDescriptor(), message.GetReflection()->GetMessageFactory(),
        value, arena());
    EXPECT_FALSE(result.has_value());
  }

  // Checks that `message` (and a dynamic copy of it) unwraps to a primitive
  // CelValue equal to `result`.
  template <typename T>
  void ExpectUnwrappedPrimitive(const google::protobuf::Message& message,
                                T result) {
    CelValue cel_value = CelProtoWrapper::CreateMessage(&message, arena());
    T value{};
    // ASSERT: `value` must not be compared if GetValue did not populate it.
    ASSERT_TRUE(cel_value.GetValue(&value));
    EXPECT_THAT(value, Eq(result));

    T dyn_value{};
    CelValue cel_dyn_value =
        CelProtoWrapper::CreateMessage(ReflectedCopy(message).get(), arena());
    EXPECT_THAT(cel_dyn_value.type(), Eq(cel_value.type()));
    ASSERT_TRUE(cel_dyn_value.GetValue(&dyn_value));
    EXPECT_THAT(value, Eq(dyn_value));
  }

  // Checks that `message` unwraps to a message equal to `*result`, or to null
  // when `result` is nullptr.
  void ExpectUnwrappedMessage(const google::protobuf::Message& message,
                              google::protobuf::Message* result) {
    CelValue cel_value = CelProtoWrapper::CreateMessage(&message, arena());
    if (result == nullptr) {
      EXPECT_TRUE(cel_value.IsNull());
      return;
    }
    ASSERT_TRUE(cel_value.IsMessage());
    EXPECT_THAT(cel_value.MessageOrDie(), testutil::EqualsProto(*result));
  }

  // Returns a DynamicMessage copy of `message` built from the fixture's
  // DynamicMessageFactory.
  std::unique_ptr<google::protobuf::Message> ReflectedCopy(
      const google::protobuf::Message& message) {
    std::unique_ptr<google::protobuf::Message> dynamic_value(
        factory_.GetPrototype(message.GetDescriptor())->New());
    dynamic_value->CopyFrom(message);
    return dynamic_value;
  }

  Arena* arena() { return &arena_; }

 private:
  Arena arena_;
  google::protobuf::DynamicMessageFactory factory_;
};

TEST_F(CelProtoWrapperTest, TestType) {
  Duration msg_duration;
  msg_duration.set_seconds(2);
  msg_duration.set_nanos(3);

  CelValue value_duration1 = CelProtoWrapper::CreateDuration(&msg_duration);
  EXPECT_THAT(value_duration1.type(), Eq(CelValue::Type::kDuration));

  CelValue value_duration2 =
      CelProtoWrapper::CreateMessage(&msg_duration, arena());
  EXPECT_THAT(value_duration2.type(), Eq(CelValue::Type::kDuration));

  Timestamp msg_timestamp;
  msg_timestamp.set_seconds(2);
  msg_timestamp.set_nanos(3);

  CelValue value_timestamp1 = CelProtoWrapper::CreateTimestamp(&msg_timestamp);
  EXPECT_THAT(value_timestamp1.type(), Eq(CelValue::Type::kTimestamp));

  CelValue value_timestamp2 =
      CelProtoWrapper::CreateMessage(&msg_timestamp, arena());
  EXPECT_THAT(value_timestamp2.type(), Eq(CelValue::Type::kTimestamp));
}

// This test verifies CelValue support of Duration type.
TEST_F(CelProtoWrapperTest, TestDuration) {
  Duration msg_duration;
  msg_duration.set_seconds(2);
  msg_duration.set_nanos(3);

  CelValue value_duration1 = CelProtoWrapper::CreateDuration(&msg_duration);
  EXPECT_THAT(value_duration1.type(), Eq(CelValue::Type::kDuration));

  CelValue value_duration2 =
      CelProtoWrapper::CreateMessage(&msg_duration, arena());
  EXPECT_THAT(value_duration2.type(), Eq(CelValue::Type::kDuration));

  CelValue value = CelProtoWrapper::CreateDuration(&msg_duration);
  // ASSERT so that DurationOrDie() is only reached for a duration value.
  ASSERT_TRUE(value.IsDuration());
  Duration out;
  auto status = cel::internal::EncodeDuration(value.DurationOrDie(), &out);
  EXPECT_TRUE(status.ok());
  EXPECT_THAT(out, testutil::EqualsProto(msg_duration));
}

// This test verifies CelValue support of Timestamp type.
TEST_F(CelProtoWrapperTest, TestTimestamp) {
  Timestamp msg_timestamp;
  msg_timestamp.set_seconds(2);
  msg_timestamp.set_nanos(3);

  CelValue value_timestamp1 = CelProtoWrapper::CreateTimestamp(&msg_timestamp);
  EXPECT_THAT(value_timestamp1.type(), Eq(CelValue::Type::kTimestamp));

  CelValue value_timestamp2 =
      CelProtoWrapper::CreateMessage(&msg_timestamp, arena());
  EXPECT_THAT(value_timestamp2.type(), Eq(CelValue::Type::kTimestamp));

  CelValue value = CelProtoWrapper::CreateTimestamp(&msg_timestamp);
  // ASSERT so that TimestampOrDie() is only reached for a timestamp value.
  ASSERT_TRUE(value.IsTimestamp());
  Timestamp out;
  auto status = cel::internal::EncodeTime(value.TimestampOrDie(), &out);
  EXPECT_TRUE(status.ok());
  EXPECT_THAT(out, testutil::EqualsProto(msg_timestamp));
}

// Dynamic Values test
//
TEST_F(CelProtoWrapperTest, UnwrapValueNull) {
  Value json;
  json.set_null_value(google::protobuf::NullValue::NULL_VALUE);
  ExpectUnwrappedMessage(json, nullptr);
}

// Test support for unwrapping a google::protobuf::Value to a CEL value.
TEST_F(CelProtoWrapperTest, UnwrapDynamicValueNull) {
  Value value_msg;
  value_msg.set_null_value(protobuf::NULL_VALUE);

  CelValue value =
      CelProtoWrapper::CreateMessage(ReflectedCopy(value_msg).get(), arena());
  EXPECT_TRUE(value.IsNull());
}

TEST_F(CelProtoWrapperTest, UnwrapValueBool) {
  bool value = true;

  Value json;
  json.set_bool_value(true);
  ExpectUnwrappedPrimitive(json, value);
}

TEST_F(CelProtoWrapperTest, UnwrapValueNumber) {
  double value = 1.0;

  Value json;
  json.set_number_value(value);
  ExpectUnwrappedPrimitive(json, value);
}

TEST_F(CelProtoWrapperTest, UnwrapValueString) {
  const std::string test = "test";
  auto value = CelValue::StringHolder(&test);

  Value json;
  json.set_string_value(test);
  ExpectUnwrappedPrimitive(json, value);
}

// Unwrapping a google.protobuf.Struct yields a CelMap keyed by the struct's
// field names; presence checks, lookups and key listing are all exercised.
TEST_F(CelProtoWrapperTest, UnwrapValueStruct) {
  const std::vector<std::string> kFields = {"field1", "field2", "field3"};
  Struct value_struct;

  auto& value1 = (*value_struct.mutable_fields())[kFields[0]];
  value1.set_bool_value(true);

  auto& value2 = (*value_struct.mutable_fields())[kFields[1]];
  value2.set_number_value(1.0);

  auto& value3 = (*value_struct.mutable_fields())[kFields[2]];
  value3.set_string_value("test");

  CelValue value = CelProtoWrapper::CreateMessage(&value_struct, arena());
  ASSERT_TRUE(value.IsMap());

  const CelMap* cel_map = value.MapOrDie();

  CelValue field1 = CelValue::CreateString(&kFields[0]);
  auto field1_presence = cel_map->Has(field1);
  ASSERT_OK(field1_presence);
  EXPECT_TRUE(*field1_presence);
  auto lookup1 = (*cel_map)[field1];
  ASSERT_TRUE(lookup1.has_value());
  ASSERT_TRUE(lookup1->IsBool());
  EXPECT_EQ(lookup1->BoolOrDie(), true);

  CelValue field2 = CelValue::CreateString(&kFields[1]);
  auto field2_presence = cel_map->Has(field2);
  ASSERT_OK(field2_presence);
  EXPECT_TRUE(*field2_presence);
  auto lookup2 = (*cel_map)[field2];
  ASSERT_TRUE(lookup2.has_value());
  ASSERT_TRUE(lookup2->IsDouble());
  EXPECT_DOUBLE_EQ(lookup2->DoubleOrDie(), 1.0);

  CelValue field3 = CelValue::CreateString(&kFields[2]);
  auto field3_presence = cel_map->Has(field3);
  ASSERT_OK(field3_presence);
  EXPECT_TRUE(*field3_presence);
  auto lookup3 = (*cel_map)[field3];
  ASSERT_TRUE(lookup3.has_value());
  ASSERT_TRUE(lookup3->IsString());
  EXPECT_EQ(lookup3->StringOrDie().value(), "test");

  std::string missing = "missing_field";
  CelValue missing_field = CelValue::CreateString(&missing);
  auto missing_field_presence = cel_map->Has(missing_field);
  ASSERT_OK(missing_field_presence);
  EXPECT_FALSE(*missing_field_presence);

  const CelList* key_list = cel_map->ListKeys().value();
  ASSERT_EQ(key_list->size(), kFields.size());

  std::vector<std::string> result_keys;
  for (int i = 0; i < key_list->size(); i++) {
    CelValue key = (*key_list)[i];
    ASSERT_TRUE(key.IsString());
    result_keys.push_back(std::string(key.StringOrDie().value()));
  }

  EXPECT_THAT(result_keys, UnorderedPointwise(Eq(), kFields));
}

// Test support for google::protobuf::Struct when it is created as a dynamic
// message.
TEST_F(CelProtoWrapperTest, UnwrapDynamicStruct) {
  Struct struct_msg;
  const std::string kFieldInt = "field_int";
  const std::string kFieldBool = "field_bool";
  (*struct_msg.mutable_fields())[kFieldInt].set_number_value(1.);
  (*struct_msg.mutable_fields())[kFieldBool].set_bool_value(true);
  CelValue value =
      CelProtoWrapper::CreateMessage(ReflectedCopy(struct_msg).get(), arena());
  // ASSERT (not EXPECT): MapOrDie() below would abort the whole test binary
  // if the value were not a map.
  ASSERT_TRUE(value.IsMap());
  const CelMap* cel_map = value.MapOrDie();
  ASSERT_TRUE(cel_map != nullptr);

  {
    auto lookup = (*cel_map)[CelValue::CreateString(&kFieldInt)];
    ASSERT_TRUE(lookup.has_value());
    auto v = lookup.value();
    ASSERT_TRUE(v.IsDouble());
    EXPECT_THAT(v.DoubleOrDie(), testing::DoubleEq(1.));
  }
  {
    auto lookup = (*cel_map)[CelValue::CreateString(&kFieldBool)];
    ASSERT_TRUE(lookup.has_value());
    auto v = lookup.value();
    ASSERT_TRUE(v.IsBool());
    EXPECT_EQ(v.BoolOrDie(), true);
  }
  {
    // Struct keys must be strings: Has() reports kInvalidArgument and the
    // index operator yields an error value for a non-string key.
    auto presence = cel_map->Has(CelValue::CreateBool(true));
    ASSERT_FALSE(presence.ok());
    EXPECT_EQ(presence.status().code(), absl::StatusCode::kInvalidArgument);
    auto lookup = (*cel_map)[CelValue::CreateBool(true)];
    ASSERT_TRUE(lookup.has_value());
    auto v = lookup.value();
    ASSERT_TRUE(v.IsError());
  }
}

TEST_F(CelProtoWrapperTest, UnwrapDynamicValueStruct) {
  const std::string kField1 = "field1";
  const std::string kField2 = "field2";
  Value value_msg;
  (*value_msg.mutable_struct_value()->mutable_fields())[kField1]
      .set_number_value(1);
  (*value_msg.mutable_struct_value()->mutable_fields())[kField2]
      .set_number_value(2);

  CelValue value =
      CelProtoWrapper::CreateMessage(ReflectedCopy(value_msg).get(), arena());
  // ASSERT so that MapOrDie() is only reached for a map value.
  ASSERT_TRUE(value.IsMap());
  EXPECT_TRUE(
      (*value.MapOrDie())[CelValue::CreateString(&kField1)].has_value());
  EXPECT_TRUE(
      (*value.MapOrDie())[CelValue::CreateString(&kField2)].has_value());
}

TEST_F(CelProtoWrapperTest, UnwrapValueList) {
  ListValue list_value;

  list_value.add_values()->set_bool_value(true);
  list_value.add_values()->set_number_value(1.0);
  list_value.add_values()->set_string_value("test");

  CelValue value = CelProtoWrapper::CreateMessage(&list_value, arena());
  ASSERT_TRUE(value.IsList());

  const CelList* cel_list = value.ListOrDie();

  ASSERT_EQ(cel_list->size(), 3);

  CelValue value1 = (*cel_list)[0];
  ASSERT_TRUE(value1.IsBool());
  EXPECT_EQ(value1.BoolOrDie(), true);

  auto value2 = (*cel_list)[1];
  ASSERT_TRUE(value2.IsDouble());
  EXPECT_DOUBLE_EQ(value2.DoubleOrDie(), 1.0);

  auto value3 = (*cel_list)[2];
  ASSERT_TRUE(value3.IsString());
  EXPECT_EQ(value3.StringOrDie().value(), "test");
}

TEST_F(CelProtoWrapperTest, UnwrapDynamicValueListValue) {
  Value value_msg;
  value_msg.mutable_list_value()->add_values()->set_number_value(1.);
  value_msg.mutable_list_value()->add_values()->set_number_value(2.);

  CelValue value =
      CelProtoWrapper::CreateMessage(ReflectedCopy(value_msg).get(), arena());
  // ASSERT so that ListOrDie() and the indexing below are only reached for a
  // two-element list.
  ASSERT_TRUE(value.IsList());
  const CelList* cel_list = value.ListOrDie();
  ASSERT_EQ(cel_list->size(), 2);
  EXPECT_THAT((*cel_list)[0].DoubleOrDie(), testing::DoubleEq(1));
  EXPECT_THAT((*cel_list)[1].DoubleOrDie(), testing::DoubleEq(2));
}

// Test support of google.protobuf.Any in CelValue.
// UnwrapAnyValue: a TestMessage packed into google.protobuf.Any is checked with
// the fixture helper ExpectUnwrappedMessage against the original message.
// UnwrapInvalidAny: an Any with an empty type URL, with "/", and with
// "/invalid.proto.name" must each produce an error CelValue.
// Unwrap*Wrapper: each protobuf wrapper message (BoolValue, Int32Value,
// UInt32Value, Int64Value, UInt64Value, FloatValue, DoubleValue, StringValue,
// BytesValue) is checked with ExpectUnwrappedPrimitive against the matching
// primitive. The 32-bit integer wrappers are compared against int64_t/uint64_t
// values, and FloatValue is compared against a double.
// The Wrap* tests that follow go the other way. A CelValue is converted into a
// target message and compared with ExpectWrappedMessage, or is expected to be
// rejected via ExpectNotWrapped. WrapNull, WrapBool, WrapBytes, WrapDuration
// and WrapDouble cover Value, the type-specific wrapper/WKT, and Any. Bytes
// become a base64 string in Value ("hello world" -> "aGVsbG8gd29ybGQ="), and a
// 300-second duration becomes the string "300s".
TEST_F(CelProtoWrapperTest, UnwrapAnyValue) { TestMessage test_message; test_message.set_string_value("test"); Any any; any.PackFrom(test_message); ExpectUnwrappedMessage(any, &test_message); } TEST_F(CelProtoWrapperTest, UnwrapInvalidAny) { Any any; CelValue value = CelProtoWrapper::CreateMessage(&any, arena()); ASSERT_TRUE(value.IsError()); any.set_type_url("/"); ASSERT_TRUE(CelProtoWrapper::CreateMessage(&any, arena()).IsError()); any.set_type_url("/invalid.proto.name"); ASSERT_TRUE(CelProtoWrapper::CreateMessage(&any, arena()).IsError()); } // Test support of google.protobuf.Value wrappers in CelValue. TEST_F(CelProtoWrapperTest, UnwrapBoolWrapper) { bool value = true; BoolValue wrapper; wrapper.set_value(value); ExpectUnwrappedPrimitive(wrapper, value); } TEST_F(CelProtoWrapperTest, UnwrapInt32Wrapper) { int64_t value = 12; Int32Value wrapper; wrapper.set_value(value); ExpectUnwrappedPrimitive(wrapper, value); } TEST_F(CelProtoWrapperTest, UnwrapUInt32Wrapper) { uint64_t value = 12; UInt32Value wrapper; wrapper.set_value(value); ExpectUnwrappedPrimitive(wrapper, value); } TEST_F(CelProtoWrapperTest, UnwrapInt64Wrapper) { int64_t value = 12; Int64Value wrapper; wrapper.set_value(value); ExpectUnwrappedPrimitive(wrapper, value); } TEST_F(CelProtoWrapperTest, UnwrapUInt64Wrapper) { uint64_t value = 12; UInt64Value wrapper; wrapper.set_value(value); ExpectUnwrappedPrimitive(wrapper, value); } TEST_F(CelProtoWrapperTest, UnwrapFloatWrapper) { double value = 42.5; FloatValue wrapper; wrapper.set_value(value); ExpectUnwrappedPrimitive(wrapper, value); } TEST_F(CelProtoWrapperTest, UnwrapDoubleWrapper) { double value = 42.5; DoubleValue wrapper; wrapper.set_value(value); ExpectUnwrappedPrimitive(wrapper, value); } TEST_F(CelProtoWrapperTest, UnwrapStringWrapper) { std::string text = "42"; auto value = CelValue::StringHolder(&text); StringValue wrapper; wrapper.set_value(text); ExpectUnwrappedPrimitive(wrapper, value); } TEST_F(CelProtoWrapperTest, UnwrapBytesWrapper)
{ std::string text = "42"; auto value = CelValue::BytesHolder(&text); BytesValue wrapper; wrapper.set_value("42"); ExpectUnwrappedPrimitive(wrapper, value); } TEST_F(CelProtoWrapperTest, WrapNull) { auto cel_value = CelValue::CreateNull(); Value json; json.set_null_value(protobuf::NULL_VALUE); ExpectWrappedMessage(cel_value, json); Any any; any.PackFrom(json); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapBool) { auto cel_value = CelValue::CreateBool(true); Value json; json.set_bool_value(true); ExpectWrappedMessage(cel_value, json); BoolValue wrapper; wrapper.set_value(true); ExpectWrappedMessage(cel_value, wrapper); Any any; any.PackFrom(wrapper); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapBytes) { std::string str = "hello world"; auto cel_value = CelValue::CreateBytes(CelValue::BytesHolder(&str)); BytesValue wrapper; wrapper.set_value(str); ExpectWrappedMessage(cel_value, wrapper); Any any; any.PackFrom(wrapper); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapBytesToValue) { std::string str = "hello world"; auto cel_value = CelValue::CreateBytes(CelValue::BytesHolder(&str)); Value json; json.set_string_value("aGVsbG8gd29ybGQ="); ExpectWrappedMessage(cel_value, json); } TEST_F(CelProtoWrapperTest, WrapDuration) { auto cel_value = CelValue::CreateDuration(absl::Seconds(300)); Duration d; d.set_seconds(300); ExpectWrappedMessage(cel_value, d); Any any; any.PackFrom(d); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapDurationToValue) { auto cel_value = CelValue::CreateDuration(absl::Seconds(300)); Value json; json.set_string_value("300s"); ExpectWrappedMessage(cel_value, json); } TEST_F(CelProtoWrapperTest, WrapDouble) { double num = 1.5; auto cel_value = CelValue::CreateDouble(num); Value json; json.set_number_value(num); ExpectWrappedMessage(cel_value, json); DoubleValue wrapper; wrapper.set_value(num); ExpectWrappedMessage(cel_value, wrapper); Any any;
// Next line: WrapDouble finishes with its Any case.
// - WrapDoubleToFloatValue narrows doubles into FloatValue. This includes
//   -9.9e-100, which is compared against a FloatValue set from the same double.
// - WrapDoubleOverflow expects the lowest and the max double to map to
//   -infinity and +infinity FloatValues.
// - WrapInt64 and WrapInt64ToInt32Value wrap an int32_t value into Value (as a
//   number), Int64Value, Any and Int32Value.
// - WrapFailureInt64ToInt32Value expects an int64_t lowest() value (presumably
//   numeric_limits<int64_t>) to be rejected by Int32Value.
any.PackFrom(wrapper); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapDoubleToFloatValue) { double num = 1.5; auto cel_value = CelValue::CreateDouble(num); FloatValue wrapper; wrapper.set_value(num); ExpectWrappedMessage(cel_value, wrapper); // Imprecise double -> float representation results in truncation. double small_num = -9.9e-100; wrapper.set_value(small_num); cel_value = CelValue::CreateDouble(small_num); ExpectWrappedMessage(cel_value, wrapper); } TEST_F(CelProtoWrapperTest, WrapDoubleOverflow) { double lowest_double = std::numeric_limits::lowest(); auto cel_value = CelValue::CreateDouble(lowest_double); // Double exceeds float precision, overflow to -infinity. FloatValue wrapper; wrapper.set_value(-std::numeric_limits::infinity()); ExpectWrappedMessage(cel_value, wrapper); double max_double = std::numeric_limits::max(); cel_value = CelValue::CreateDouble(max_double); wrapper.set_value(std::numeric_limits::infinity()); ExpectWrappedMessage(cel_value, wrapper); } TEST_F(CelProtoWrapperTest, WrapInt64) { int32_t num = std::numeric_limits::lowest(); auto cel_value = CelValue::CreateInt64(num); Value json; json.set_number_value(static_cast(num)); ExpectWrappedMessage(cel_value, json); Int64Value wrapper; wrapper.set_value(num); ExpectWrappedMessage(cel_value, wrapper); Any any; any.PackFrom(wrapper); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapInt64ToInt32Value) { int32_t num = std::numeric_limits::lowest(); auto cel_value = CelValue::CreateInt64(num); Int32Value wrapper; wrapper.set_value(num); ExpectWrappedMessage(cel_value, wrapper); } TEST_F(CelProtoWrapperTest, WrapFailureInt64ToInt32Value) { int64_t num = std::numeric_limits::lowest(); auto cel_value = CelValue::CreateInt64(num); Int32Value wrapper; ExpectNotWrapped(cel_value, wrapper); } TEST_F(CelProtoWrapperTest, WrapInt64ToValue) { int64_t max = std::numeric_limits::max(); auto cel_value = CelValue::CreateInt64(max); Value json;
// Next two lines:
// - WrapInt64ToValue expects int64_t max/min values in Value to be rendered as
//   decimal strings (absl::StrCat) rather than numbers.
// - WrapUint64 and WrapUint64ToUint32Value wrap a uint32_t max() value into
//   Value (as a number), UInt64Value, Any and UInt32Value.
// - WrapUint64ToValue renders a uint64_t max() as a decimal string.
// - WrapFailureUint64ToUint32Value expects that uint64_t max() to be rejected
//   by UInt32Value.
// - WrapString covers Value, StringValue and Any.
// - WrapTimestamp and WrapTimestampToValue wrap unix second 1615852799 into
//   Timestamp, Any and the string "2021-03-15T23:59:59Z".
// - WrapList wraps [1.5, -2] into Value, ListValue and Any.
// - WrapFailureListValueBadJSON rejects a list that holds a TestMessage.
// - WrapStruct wraps {"field1": true} into Value, Struct and Any.
json.set_string_value(absl::StrCat(max)); ExpectWrappedMessage(cel_value, json); int64_t min = std::numeric_limits::min(); cel_value = CelValue::CreateInt64(min); json.set_string_value(absl::StrCat(min)); ExpectWrappedMessage(cel_value, json); } TEST_F(CelProtoWrapperTest, WrapUint64) { uint32_t num = std::numeric_limits::max(); auto cel_value = CelValue::CreateUint64(num); Value json; json.set_number_value(static_cast(num)); ExpectWrappedMessage(cel_value, json); UInt64Value wrapper; wrapper.set_value(num); ExpectWrappedMessage(cel_value, wrapper); Any any; any.PackFrom(wrapper); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapUint64ToUint32Value) { uint32_t num = std::numeric_limits::max(); auto cel_value = CelValue::CreateUint64(num); UInt32Value wrapper; wrapper.set_value(num); ExpectWrappedMessage(cel_value, wrapper); } TEST_F(CelProtoWrapperTest, WrapUint64ToValue) { uint64_t num = std::numeric_limits::max(); auto cel_value = CelValue::CreateUint64(num); Value json; json.set_string_value(absl::StrCat(num)); ExpectWrappedMessage(cel_value, json); } TEST_F(CelProtoWrapperTest, WrapFailureUint64ToUint32Value) { uint64_t num = std::numeric_limits::max(); auto cel_value = CelValue::CreateUint64(num); UInt32Value wrapper; ExpectNotWrapped(cel_value, wrapper); } TEST_F(CelProtoWrapperTest, WrapString) { std::string str = "test"; auto cel_value = CelValue::CreateString(CelValue::StringHolder(&str)); Value json; json.set_string_value(str); ExpectWrappedMessage(cel_value, json); StringValue wrapper; wrapper.set_value(str); ExpectWrappedMessage(cel_value, wrapper); Any any; any.PackFrom(wrapper); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapTimestamp) { absl::Time ts = absl::FromUnixSeconds(1615852799); auto cel_value = CelValue::CreateTimestamp(ts); Timestamp t; t.set_seconds(1615852799); ExpectWrappedMessage(cel_value, t); Any any; any.PackFrom(t); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest,
WrapTimestampToValue) { absl::Time ts = absl::FromUnixSeconds(1615852799); auto cel_value = CelValue::CreateTimestamp(ts); Value json; json.set_string_value("2021-03-15T23:59:59Z"); ExpectWrappedMessage(cel_value, json); } TEST_F(CelProtoWrapperTest, WrapList) { std::vector list_elems = { CelValue::CreateDouble(1.5), CelValue::CreateInt64(-2L), }; ContainerBackedListImpl list(std::move(list_elems)); auto cel_value = CelValue::CreateList(&list); Value json; json.mutable_list_value()->add_values()->set_number_value(1.5); json.mutable_list_value()->add_values()->set_number_value(-2.); ExpectWrappedMessage(cel_value, json); ExpectWrappedMessage(cel_value, json.list_value()); Any any; any.PackFrom(json.list_value()); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapFailureListValueBadJSON) { TestMessage message; std::vector list_elems = { CelValue::CreateDouble(1.5), CelProtoWrapper::CreateMessage(&message, arena()), }; ContainerBackedListImpl list(std::move(list_elems)); auto cel_value = CelValue::CreateList(&list); Value json; ExpectNotWrapped(cel_value, json); } TEST_F(CelProtoWrapperTest, WrapStruct) { const std::string kField1 = "field1"; std::vector> args = { {CelValue::CreateString(CelValue::StringHolder(&kField1)), CelValue::CreateBool(true)}}; auto cel_map = CreateContainerBackedMap( absl::Span>(args.data(), args.size())) .value(); auto cel_value = CelValue::CreateMap(cel_map.get()); Value json; (*json.mutable_struct_value()->mutable_fields())[kField1].set_bool_value( true); ExpectWrappedMessage(cel_value, json); ExpectWrappedMessage(cel_value, json.struct_value()); Any any; any.PackFrom(json.struct_value()); ExpectWrappedMessage(cel_value, any); } TEST_F(CelProtoWrapperTest, WrapFailureStructBadKeyType) { std::vector> args = { {CelValue::CreateInt64(1L), CelValue::CreateBool(true)}}; auto cel_map = CreateContainerBackedMap( absl::Span>(args.data(), args.size())) .value(); auto cel_value = CelValue::CreateMap(cel_map.get()); Value json;
// Next line:
// - WrapFailureStructBadKeyType finishes: a map with an int64 key must not be
//   wrapped into Value.
// - WrapFailureStructBadValueType: a map with a TestMessage value must not be
//   wrapped into Value either.
// - WrapFailureWrongType expects a null CelValue to be rejected by every
//   listed non-Value message type.
// - WrapFailureErrorToAny expects a no-such-field error CelValue not to be
//   wrapped into Any.
ExpectNotWrapped(cel_value, json); } TEST_F(CelProtoWrapperTest, WrapFailureStructBadValueType) { const std::string kField1 = "field1"; TestMessage bad_value; std::vector> args = { {CelValue::CreateString(CelValue::StringHolder(&kField1)), CelProtoWrapper::CreateMessage(&bad_value, arena())}}; auto cel_map = CreateContainerBackedMap( absl::Span>(args.data(), args.size())) .value(); auto cel_value = CelValue::CreateMap(cel_map.get()); Value json; ExpectNotWrapped(cel_value, json); } TEST_F(CelProtoWrapperTest, WrapFailureWrongType) { auto cel_value = CelValue::CreateNull(); std::vector wrong_types = { &BoolValue::default_instance(), &BytesValue::default_instance(), &DoubleValue::default_instance(), &Duration::default_instance(), &FloatValue::default_instance(), &Int32Value::default_instance(), &Int64Value::default_instance(), &ListValue::default_instance(), &StringValue::default_instance(), &Struct::default_instance(), &Timestamp::default_instance(), &UInt32Value::default_instance(), &UInt64Value::default_instance(), }; for (const auto* wrong_type : wrong_types) { ExpectNotWrapped(cel_value, *wrong_type); } } TEST_F(CelProtoWrapperTest, WrapFailureErrorToAny) { auto cel_value = CreateNoSuchFieldError(arena(), "error_field"); ExpectNotWrapped(cel_value, Any::default_instance()); } // A CelMap implementation that returns an error for the ListKeys() method.
// InvalidListKeysCelMapBuilder is a CelMapBuilder whose ListKeys() always
// returns an InternalError. It is used to exercise the error path of
// DebugString.
// DebugString test:
// - an Empty message renders with the prefix "Message: ";
// - a ListValue of [true, 1.0, "test"] renders as
//   "CelList: [bool: 1, double: 1.000000, string: test]";
// - a Struct renders with the prefix "CelMap: {";
// - a map whose ListKeys() fails renders as "CelMap: invalid list keys".
// The line then closes the test namespaces and begins the next dumped file
// (eval/public/structs/dynamic_descriptor_pool_end_to_end_test.cc).
class InvalidListKeysCelMapBuilder : public CelMapBuilder { public: absl::StatusOr ListKeys() const override { return absl::InternalError("Error while invoking ListKeys()"); } }; TEST_F(CelProtoWrapperTest, DebugString) { google::protobuf::Empty e; EXPECT_THAT(CelProtoWrapper::CreateMessage(&e, arena()).DebugString(), testing::StartsWith("Message: ")); ListValue list_value; list_value.add_values()->set_bool_value(true); list_value.add_values()->set_number_value(1.0); list_value.add_values()->set_string_value("test"); CelValue value = CelProtoWrapper::CreateMessage(&list_value, arena()); EXPECT_EQ(value.DebugString(), "CelList: [bool: 1, double: 1.000000, string: test]"); Struct value_struct; auto& value1 = (*value_struct.mutable_fields())["a"]; value1.set_bool_value(true); auto& value2 = (*value_struct.mutable_fields())["b"]; value2.set_number_value(1.0); auto& value3 = (*value_struct.mutable_fields())["c"]; value3.set_string_value("test"); value = CelProtoWrapper::CreateMessage(&value_struct, arena()); EXPECT_THAT( value.DebugString(), testing::AllOf(testing::StartsWith("CelMap: {"), testing::HasSubstr(": "), testing::HasSubstr(": : "))); // DebugString of a CelMap with an invalid internal list. InvalidListKeysCelMapBuilder invalid_cel_map; auto cel_map_value = CelValue::CreateMap(&invalid_cel_map); EXPECT_EQ(cel_map_value.DebugString(), "CelMap: invalid list keys"); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/public/structs/dynamic_descriptor_pool_end_to_end_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include "cel/expr/syntax.pb.h" #include "google/protobuf/descriptor.pb.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_options.h" #include "eval/public/structs/cel_proto_descriptor_pool_builder.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/testing/matchers.h" #include "internal/testing.h" #include "parser/parser.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/dynamic_message.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" #include "google/protobuf/util/message_differencer.h" namespace google::api::expr::runtime { namespace { using ::cel::expr::conformance::proto3::TestAllTypes; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::protobuf::DescriptorPool; constexpr int32_t kStartingFieldNumber = 600; constexpr int32_t kIntFieldNumber = kStartingFieldNumber; constexpr int32_t kStringFieldNumber = kStartingFieldNumber + 1; constexpr int32_t kMessageFieldNumber = kStartingFieldNumber + 2; MATCHER_P(CelEqualsProto, msg, absl::StrCat("CEL Equals ", msg->ShortDebugString())) { const google::protobuf::Message* got = arg; const google::protobuf::Message* want = msg; return 
google::protobuf::util::MessageDifferencer::Equals(*got, *want); } // Simulate a dynamic descriptor pool with an alternate definition for a linked // type. absl::Status AddTestTypes(DescriptorPool& pool) { google::protobuf::FileDescriptorProto file_descriptor; TestAllTypes::descriptor()->file()->CopyTo(&file_descriptor); auto* message_type_entry = file_descriptor.mutable_message_type(0); auto* dynamic_int_field = message_type_entry->add_field(); dynamic_int_field->set_number(kIntFieldNumber); dynamic_int_field->set_name("dynamic_int_field"); dynamic_int_field->set_type(google::protobuf::FieldDescriptorProto::TYPE_INT64); auto* dynamic_string_field = message_type_entry->add_field(); dynamic_string_field->set_number(kStringFieldNumber); dynamic_string_field->set_name("dynamic_string_field"); dynamic_string_field->set_type(google::protobuf::FieldDescriptorProto::TYPE_STRING); auto* dynamic_message_field = message_type_entry->add_field(); dynamic_message_field->set_number(kMessageFieldNumber); dynamic_message_field->set_name("dynamic_message_field"); dynamic_message_field->set_type(google::protobuf::FieldDescriptorProto::TYPE_MESSAGE); dynamic_message_field->set_type_name( ".cel.expr.conformance.proto3.TestAllTypes"); CEL_RETURN_IF_ERROR(AddStandardMessageTypesToDescriptorPool(pool)); if (!pool.BuildFile(file_descriptor)) { return absl::InternalError( "failed initializing custom descriptor pool for test."); } return absl::OkStatus(); } class DynamicDescriptorPoolTest : public ::testing::Test { public: DynamicDescriptorPoolTest() : factory_(&descriptor_pool_) {} void SetUp() override { ASSERT_OK(AddTestTypes(descriptor_pool_)); } protected: absl::StatusOr> CreateMessageFromText( absl::string_view text_format) { const google::protobuf::Descriptor* dynamic_desc = descriptor_pool_.FindMessageTypeByName( "cel.expr.conformance.proto3.TestAllTypes"); auto message = absl::WrapUnique(factory_.GetPrototype(dynamic_desc)->New()); if 
(!google::protobuf::TextFormat::ParseFromString(text_format, message.get())) { return absl::InvalidArgumentError( "invalid text format for dynamic message"); } return message; } DescriptorPool descriptor_pool_; google::protobuf::DynamicMessageFactory factory_; google::protobuf::Arena arena_; }; TEST_F(DynamicDescriptorPoolTest, FieldAccess) { InterpreterOptions options; std::unique_ptr builder = CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN(std::unique_ptr message, CreateMessageFromText("dynamic_int_field: 42")); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("msg.dynamic_int_field < 50")); ASSERT_OK_AND_ASSIGN( auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); Activation act; CelValue val = CelProtoWrapper::CreateMessage(message.get(), &arena_); act.InsertValue("msg", val); ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(act, &arena_)); EXPECT_THAT(result, test::IsCelBool(true)); } TEST_F(DynamicDescriptorPoolTest, Create) { InterpreterOptions options; std::unique_ptr builder = CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); builder->set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse( R"cel( TestAllTypes{ dynamic_int_field: 42, dynamic_string_field: "string", dynamic_message_field: TestAllTypes{dynamic_int_field: 50 } } )cel")); ASSERT_OK_AND_ASSIGN( auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); Activation act; ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(act, &arena_)); ASSERT_OK_AND_ASSIGN(auto expected, CreateMessageFromText(R"pb( dynamic_int_field: 42 dynamic_string_field: "string" dynamic_message_field { dynamic_int_field: 50 } )pb")); EXPECT_THAT(result, test::IsCelMessage(CelEqualsProto(expected.get()))); } TEST_F(DynamicDescriptorPoolTest, AnyUnpack) { 
InterpreterOptions options; std::unique_ptr builder = CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN( auto message, CreateMessageFromText(R"pb( single_any { [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { dynamic_int_field: 45 } } )pb")); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("msg.single_any.dynamic_int_field < 50")); ASSERT_OK_AND_ASSIGN( auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); Activation act; CelValue val = CelProtoWrapper::CreateMessage(message.get(), &arena_); act.InsertValue("msg", val); ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(act, &arena_)); EXPECT_THAT(result, test::IsCelBool(true)); } TEST_F(DynamicDescriptorPoolTest, AnyWrapperUnpack) { InterpreterOptions options; std::unique_ptr builder = CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN( auto message, CreateMessageFromText(R"pb( single_any { [type.googleapis.com/google.protobuf.Int64Value] { value: 45 } } )pb")); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("msg.single_any < 50")); ASSERT_OK_AND_ASSIGN( auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); Activation act; CelValue val = CelProtoWrapper::CreateMessage(message.get(), &arena_); act.InsertValue("msg", val); ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(act, &arena_)); EXPECT_THAT(result, test::IsCelBool(true)); } TEST_F(DynamicDescriptorPoolTest, AnyUnpackRepeated) { InterpreterOptions options; std::unique_ptr builder = CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN( auto message, CreateMessageFromText(R"pb( repeated_any { [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { dynamic_int_field: 0 } } repeated_any 
{ [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { dynamic_int_field: 1 } } )pb")); ASSERT_OK_AND_ASSIGN( ParsedExpr expr, Parse("msg.repeated_any.exists(x, x.dynamic_int_field > 2)")); ASSERT_OK_AND_ASSIGN( auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); Activation act; CelValue val = CelProtoWrapper::CreateMessage(message.get(), &arena_); act.InsertValue("msg", val); ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(act, &arena_)); EXPECT_THAT(result, test::IsCelBool(false)); } TEST_F(DynamicDescriptorPoolTest, AnyPack) { InterpreterOptions options; std::unique_ptr builder = CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); builder->set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse(R"cel( TestAllTypes{ single_any: TestAllTypes{dynamic_int_field: 42} })cel")); ASSERT_OK_AND_ASSIGN( auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); Activation act; ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(act, &arena_)); ASSERT_OK_AND_ASSIGN( auto expected_message, CreateMessageFromText(R"pb( single_any { [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { dynamic_int_field: 42 } } )pb")); EXPECT_THAT(result, test::IsCelMessage(CelEqualsProto(expected_message.get()))); } TEST_F(DynamicDescriptorPoolTest, AnyWrapperPack) { InterpreterOptions options; std::unique_ptr builder = CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); builder->set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse(R"cel( TestAllTypes{ single_any: 42 })cel")); ASSERT_OK_AND_ASSIGN( auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); Activation act; ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(act, &arena_)); ASSERT_OK_AND_ASSIGN( auto 
expected_message, CreateMessageFromText(R"pb( single_any { [type.googleapis.com/google.protobuf.Int64Value] { value: 42 } } )pb")); EXPECT_THAT(result, test::IsCelMessage(CelEqualsProto(expected_message.get()))); } TEST_F(DynamicDescriptorPoolTest, AnyPackRepeated) { InterpreterOptions options; std::unique_ptr builder = CreateCelExpressionBuilder(&descriptor_pool_, &factory_, options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); builder->set_container("cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse(R"cel( TestAllTypes{ repeated_any: [ TestAllTypes{dynamic_int_field: 0}, TestAllTypes{dynamic_int_field: 1}, ] })cel")); ASSERT_OK_AND_ASSIGN( auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); Activation act; ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(act, &arena_)); ASSERT_OK_AND_ASSIGN( auto expected_message, CreateMessageFromText(R"pb( repeated_any { [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { dynamic_int_field: 0 } } repeated_any { [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { dynamic_int_field: 1 } } )pb")); EXPECT_THAT(result, test::IsCelMessage(CelEqualsProto(expected_message.get()))); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/public/structs/field_access_impl.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "eval/public/structs/field_access_impl.h" #include #include #include #include #include "google/protobuf/any.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/wrappers.pb.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "eval/public/structs/cel_proto_wrap_util.h" #include "internal/casts.h" #include "internal/overflow.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/map_field.h" #undef GetMessage namespace google::api::expr::runtime::internal { namespace { using ::google::protobuf::Arena; using ::google::protobuf::FieldDescriptor; using ::google::protobuf::MapValueConstRef; using ::google::protobuf::Message; using ::google::protobuf::Reflection; // Singular message fields and repeated message fields have similar access model // To provide common approach, we implement accessor classes, based on CRTP. // FieldAccessor is CRTP base class, specifying Get.. method family. 
// Each Get* method of FieldAccessor forwards to the derived accessor class via
// static_cast. CreateValueFromFieldAccessor switches on field_desc_->cpp_type():
//   BOOL -> bool; INT32, INT64 and ENUM -> int64; UINT32, UINT64 -> uint64;
//   FLOAT, DOUBLE -> double; MESSAGE -> the result of UnwrapMessageToValue;
//   STRING -> a string view for TYPE_STRING, or a bytes view for TYPE_BYTES.
// If GetString() returned a view of the local `buffer`, the buffer is first
// moved into an arena-allocated std::string so the view outlives this call.
// A STRING field whose type() is neither TYPE_STRING nor TYPE_BYTES, and any
// cpp_type not listed above, fall through to an InvalidArgument status.
// The enum value is kept as int64_t end to end, matching GetEnumValue() and
// CelValue::CreateInt64, so no narrowing conversion takes place.
template class FieldAccessor { public: bool GetBool() const { return static_cast(this)->GetBool(); } int64_t GetInt32() const { return static_cast(this)->GetInt32(); } uint64_t GetUInt32() const { return static_cast(this)->GetUInt32(); } int64_t GetInt64() const { return static_cast(this)->GetInt64(); } uint64_t GetUInt64() const { return static_cast(this)->GetUInt64(); } double GetFloat() const { return static_cast(this)->GetFloat(); } double GetDouble() const { return static_cast(this)->GetDouble(); } absl::string_view GetString(std::string* buffer) const { return static_cast(this)->GetString(buffer); } const Message* GetMessage() const { return static_cast(this)->GetMessage(); } int64_t GetEnumValue() const { return static_cast(this)->GetEnumValue(); } // This method provides message field content, wrapped in CelValue. // If value provided successfully, return a CelValue, otherwise returns a // status with non-ok status code. // // arena Arena to use for allocations if needed. absl::StatusOr CreateValueFromFieldAccessor(Arena* arena) { switch (field_desc_->cpp_type()) { case FieldDescriptor::CPPTYPE_BOOL: { bool value = GetBool(); return CelValue::CreateBool(value); } case FieldDescriptor::CPPTYPE_INT32: { int64_t value = GetInt32(); return CelValue::CreateInt64(value); } case FieldDescriptor::CPPTYPE_INT64: { int64_t value = GetInt64(); return CelValue::CreateInt64(value); } case FieldDescriptor::CPPTYPE_UINT32: { uint64_t value = GetUInt32(); return CelValue::CreateUint64(value); } case FieldDescriptor::CPPTYPE_UINT64: { uint64_t value = GetUInt64(); return CelValue::CreateUint64(value); } case FieldDescriptor::CPPTYPE_FLOAT: { double value = GetFloat(); return CelValue::CreateDouble(value); } case FieldDescriptor::CPPTYPE_DOUBLE: { double value = GetDouble(); return CelValue::CreateDouble(value); } case FieldDescriptor::CPPTYPE_STRING: { std::string buffer; absl::string_view value = GetString(&buffer); if (value.data() == buffer.data() && value.size() ==
buffer.size()) { value = absl::string_view( *google::protobuf::Arena::Create(arena, std::move(buffer))); } switch (field_desc_->type()) { case FieldDescriptor::TYPE_STRING: return CelValue::CreateStringView(value); case FieldDescriptor::TYPE_BYTES: return CelValue::CreateBytesView(value); default: break; } break; } case FieldDescriptor::CPPTYPE_MESSAGE: { const google::protobuf::Message* msg_value = GetMessage(); return UnwrapMessageToValue(msg_value, protobuf_value_factory_, arena); } case FieldDescriptor::CPPTYPE_ENUM: { int64_t enum_value = GetEnumValue(); return CelValue::CreateInt64(enum_value); } default: break; } return absl::Status(absl::StatusCode::kInvalidArgument, "Unhandled C++ type conversion"); } protected: FieldAccessor(const Message* msg, const FieldDescriptor* field_desc, const ProtobufValueFactory& protobuf_value_factory) : msg_(msg), field_desc_(field_desc), protobuf_value_factory_(protobuf_value_factory) {} const Message* msg_; const FieldDescriptor* field_desc_; const ProtobufValueFactory& protobuf_value_factory_; };
// WellKnownWrapperTypes returns the full names of the protobuf wrapper
// messages. The set is built once on first use and never freed.
// IsWrapperType reports whether a field's message type is one of those
// wrappers, using a single hash lookup. It dereferences
// field_descriptor->message_type(), so it must only be called for
// message-typed fields, as in the CPPTYPE_MESSAGE path.
const absl::flat_hash_set& WellKnownWrapperTypes() { static auto* wrapper_types = new absl::flat_hash_set{ "google.protobuf.BoolValue", "google.protobuf.DoubleValue", "google.protobuf.FloatValue", "google.protobuf.Int64Value", "google.protobuf.Int32Value", "google.protobuf.UInt64Value", "google.protobuf.UInt32Value", "google.protobuf.StringValue", "google.protobuf.BytesValue", }; return *wrapper_types; } bool IsWrapperType(const FieldDescriptor* field_descriptor) { return WellKnownWrapperTypes().contains( field_descriptor->message_type()->full_name()); }
// ScalarFieldAccessor reads a singular field through the message's
// Reflection. When constructed with unset_wrapper_as_null, GetMessage()
// returns nullptr for an unset field whose type is a wrapper message, and
// CreateValueFromFieldAccessor passes that nullptr on to UnwrapMessageToValue.
// Otherwise GetMessage() returns the field's message from
// Reflection::GetMessage.
// Accessor class, to work with singular fields class ScalarFieldAccessor : public FieldAccessor { public: ScalarFieldAccessor(const Message* msg, const FieldDescriptor* field_desc, bool unset_wrapper_as_null, const ProtobufValueFactory& factory) : FieldAccessor(msg, field_desc, factory), unset_wrapper_as_null_(unset_wrapper_as_null) {} bool GetBool() const
{ return GetReflection()->GetBool(*msg_, field_desc_); } int64_t GetInt32() const { return GetReflection()->GetInt32(*msg_, field_desc_); } uint64_t GetUInt32() const { return GetReflection()->GetUInt32(*msg_, field_desc_); } int64_t GetInt64() const { return GetReflection()->GetInt64(*msg_, field_desc_); } uint64_t GetUInt64() const { return GetReflection()->GetUInt64(*msg_, field_desc_); } double GetFloat() const { return GetReflection()->GetFloat(*msg_, field_desc_); } double GetDouble() const { return GetReflection()->GetDouble(*msg_, field_desc_); } absl::string_view GetString(std::string* buffer) const { return GetReflection()->GetStringReference(*msg_, field_desc_, buffer); } const Message* GetMessage() const { // Unset wrapper types have special semantics. // If set, return the unwrapped value, else return 'null'. if (unset_wrapper_as_null_ && !GetReflection()->HasField(*msg_, field_desc_) && IsWrapperType(field_desc_)) { return nullptr; } return &GetReflection()->GetMessage(*msg_, field_desc_); } int64_t GetEnumValue() const { return GetReflection()->GetEnumValue(*msg_, field_desc_); } const Reflection* GetReflection() const { return msg_->GetReflection(); } private: bool unset_wrapper_as_null_; }; // Accessor class, to work with repeated fields.
class RepeatedFieldAccessor : public FieldAccessor { public: RepeatedFieldAccessor(const Message* msg, const FieldDescriptor* field_desc, int index, const ProtobufValueFactory& factory) : FieldAccessor(msg, field_desc, factory), index_(index) {} bool GetBool() const { return GetReflection()->GetRepeatedBool(*msg_, field_desc_, index_); } int64_t GetInt32() const { return GetReflection()->GetRepeatedInt32(*msg_, field_desc_, index_); } uint64_t GetUInt32() const { return GetReflection()->GetRepeatedUInt32(*msg_, field_desc_, index_); } int64_t GetInt64() const { return GetReflection()->GetRepeatedInt64(*msg_, field_desc_, index_); } uint64_t GetUInt64() const { return GetReflection()->GetRepeatedUInt64(*msg_, field_desc_, index_); } double GetFloat() const { return GetReflection()->GetRepeatedFloat(*msg_, field_desc_, index_); } double GetDouble() const { return GetReflection()->GetRepeatedDouble(*msg_, field_desc_, index_); } absl::string_view GetString(std::string* buffer) const { return GetReflection()->GetRepeatedStringReference(*msg_, field_desc_, index_, buffer); } const Message* GetMessage() const { return &GetReflection()->GetRepeatedMessage(*msg_, field_desc_, index_); } int64_t GetEnumValue() const { return GetReflection()->GetRepeatedEnumValue(*msg_, field_desc_, index_); } const Reflection* GetReflection() const { return msg_->GetReflection(); } private: int index_; }; // Accessor class, to work with map values class MapValueAccessor : public FieldAccessor { public: MapValueAccessor(const Message* msg, const FieldDescriptor* field_desc, const MapValueConstRef* value_ref, const ProtobufValueFactory& factory) : FieldAccessor(msg, field_desc, factory), value_ref_(value_ref) {} bool GetBool() const { return value_ref_->GetBoolValue(); } int64_t GetInt32() const { return value_ref_->GetInt32Value(); } uint64_t GetUInt32() const { return value_ref_->GetUInt32Value(); } int64_t GetInt64() const { return value_ref_->GetInt64Value(); } uint64_t GetUInt64() const 
{ return value_ref_->GetUInt64Value(); } double GetFloat() const { return value_ref_->GetFloatValue(); } double GetDouble() const { return value_ref_->GetDoubleValue(); } absl::string_view GetString(std::string* /*buffer*/) const { return value_ref_->GetStringValue(); } const Message* GetMessage() const { return &value_ref_->GetMessageValue(); } int64_t GetEnumValue() const { return value_ref_->GetEnumValue(); } const Reflection* GetReflection() const { return msg_->GetReflection(); } private: const MapValueConstRef* value_ref_; }; // Singular message fields and repeated message fields have similar access model // To provide common approach, we implement field setter classes, based on CRTP. // FieldAccessor is CRTP base class, specifying Get.. method family. template class FieldSetter { public: bool AssignBool(const CelValue& cel_value) const { bool value; if (!cel_value.GetValue(&value)) { return false; } static_cast(this)->SetBool(value); return true; } bool AssignInt32(const CelValue& cel_value) const { int64_t value; if (!cel_value.GetValue(&value)) { return false; } absl::StatusOr checked_cast = cel::internal::CheckedInt64ToInt32(value); if (!checked_cast.ok()) { return false; } static_cast(this)->SetInt32(*checked_cast); return true; } bool AssignUInt32(const CelValue& cel_value) const { uint64_t value; if (!cel_value.GetValue(&value)) { return false; } if (!cel::internal::CheckedUint64ToUint32(value).ok()) { return false; } static_cast(this)->SetUInt32(value); return true; } bool AssignInt64(const CelValue& cel_value) const { int64_t value; if (!cel_value.GetValue(&value)) { return false; } static_cast(this)->SetInt64(value); return true; } bool AssignUInt64(const CelValue& cel_value) const { uint64_t value; if (!cel_value.GetValue(&value)) { return false; } static_cast(this)->SetUInt64(value); return true; } bool AssignFloat(const CelValue& cel_value) const { double value; if (!cel_value.GetValue(&value)) { return false; } 
static_cast(this)->SetFloat(value); return true; } bool AssignDouble(const CelValue& cel_value) const { double value; if (!cel_value.GetValue(&value)) { return false; } static_cast(this)->SetDouble(value); return true; } bool AssignString(const CelValue& cel_value) const { CelValue::StringHolder value; if (!cel_value.GetValue(&value)) { return false; } static_cast(this)->SetString(value); return true; } bool AssignBytes(const CelValue& cel_value) const { CelValue::BytesHolder value; if (!cel_value.GetValue(&value)) { return false; } static_cast(this)->SetBytes(value); return true; } bool AssignEnum(const CelValue& cel_value) const { int64_t value; if (!cel_value.GetValue(&value)) { return false; } if (!cel::internal::CheckedInt64ToInt32(value).ok()) { return false; } static_cast(this)->SetEnum(value); return true; } bool AssignMessage(const google::protobuf::Message* message) const { return static_cast(this)->SetMessage(message); } // This method provides message field content, wrapped in CelValue. // If value provided successfully, returns Ok. // arena Arena to use for allocations if needed. // result pointer to object to store value in. 
bool SetFieldFromCelValue(const CelValue& value) { switch (field_desc_->cpp_type()) { case FieldDescriptor::CPPTYPE_BOOL: { return AssignBool(value); } case FieldDescriptor::CPPTYPE_INT32: { return AssignInt32(value); } case FieldDescriptor::CPPTYPE_INT64: { return AssignInt64(value); } case FieldDescriptor::CPPTYPE_UINT32: { return AssignUInt32(value); } case FieldDescriptor::CPPTYPE_UINT64: { return AssignUInt64(value); } case FieldDescriptor::CPPTYPE_FLOAT: { return AssignFloat(value); } case FieldDescriptor::CPPTYPE_DOUBLE: { return AssignDouble(value); } case FieldDescriptor::CPPTYPE_STRING: { switch (field_desc_->type()) { case FieldDescriptor::TYPE_STRING: return AssignString(value); case FieldDescriptor::TYPE_BYTES: return AssignBytes(value); default: return false; } break; } case FieldDescriptor::CPPTYPE_MESSAGE: { // When the field is a message, it might be a well-known type with a // non-proto representation that requires special handling before it // can be set on the field. const google::protobuf::Message* wrapped_value = MaybeWrapValueToMessage( field_desc_->message_type(), msg_->GetReflection()->GetMessageFactory(), value, arena_); if (wrapped_value == nullptr) { // It we aren't unboxing to a protobuf null representation, setting a // field to null is a no-op. 
if (value.IsNull()) { return true; } if (CelValue::MessageWrapper wrapper; value.GetValue(&wrapper) && wrapper.HasFullProto()) { wrapped_value = static_cast(wrapper.message_ptr()); } else { return false; } } return AssignMessage(wrapped_value); } case FieldDescriptor::CPPTYPE_ENUM: { return AssignEnum(value); } default: return false; } return true; } protected: FieldSetter(Message* msg, const FieldDescriptor* field_desc, Arena* arena) : msg_(msg), field_desc_(field_desc), arena_(arena) {} Message* msg_; const FieldDescriptor* field_desc_; Arena* arena_; }; bool MergeFromWithSerializeFallback(const google::protobuf::Message& value, google::protobuf::Message& field) { if (field.GetDescriptor() == value.GetDescriptor()) { field.MergeFrom(value); return true; } // TODO(uncreated-issue/26): this indicates means we're mixing dynamic messages with // generated messages. This is expected for WKTs where CEL explicitly requires // wire format compatibility, but this may not be the expected behavior for // other types. 
return field.MergeFromString(value.SerializeAsString()); } // Accessor class, to work with singular fields class ScalarFieldSetter : public FieldSetter { public: ScalarFieldSetter(Message* msg, const FieldDescriptor* field_desc, Arena* arena) : FieldSetter(msg, field_desc, arena) {} bool SetBool(bool value) const { GetReflection()->SetBool(msg_, field_desc_, value); return true; } bool SetInt32(int32_t value) const { GetReflection()->SetInt32(msg_, field_desc_, value); return true; } bool SetUInt32(uint32_t value) const { GetReflection()->SetUInt32(msg_, field_desc_, value); return true; } bool SetInt64(int64_t value) const { GetReflection()->SetInt64(msg_, field_desc_, value); return true; } bool SetUInt64(uint64_t value) const { GetReflection()->SetUInt64(msg_, field_desc_, value); return true; } bool SetFloat(float value) const { GetReflection()->SetFloat(msg_, field_desc_, value); return true; } bool SetDouble(double value) const { GetReflection()->SetDouble(msg_, field_desc_, value); return true; } bool SetString(CelValue::StringHolder value) const { GetReflection()->SetString(msg_, field_desc_, std::string(value.value())); return true; } bool SetBytes(CelValue::BytesHolder value) const { GetReflection()->SetString(msg_, field_desc_, std::string(value.value())); return true; } bool SetMessage(const Message* value) const { if (!value) { ABSL_LOG(ERROR) << "Message is NULL"; return true; } if (value->GetDescriptor()->full_name() == field_desc_->message_type()->full_name()) { auto* assignable_field_msg = GetReflection()->MutableMessage(msg_, field_desc_); return MergeFromWithSerializeFallback(*value, *assignable_field_msg); } return false; } bool SetEnum(const int64_t value) const { GetReflection()->SetEnumValue(msg_, field_desc_, value); return true; } const Reflection* GetReflection() const { return msg_->GetReflection(); } }; // Appender class, to work with repeated fields class RepeatedFieldSetter : public FieldSetter { public: RepeatedFieldSetter(Message* 
msg, const FieldDescriptor* field_desc, Arena* arena) : FieldSetter(msg, field_desc, arena) {} bool SetBool(bool value) const { GetReflection()->AddBool(msg_, field_desc_, value); return true; } bool SetInt32(int32_t value) const { GetReflection()->AddInt32(msg_, field_desc_, value); return true; } bool SetUInt32(uint32_t value) const { GetReflection()->AddUInt32(msg_, field_desc_, value); return true; } bool SetInt64(int64_t value) const { GetReflection()->AddInt64(msg_, field_desc_, value); return true; } bool SetUInt64(uint64_t value) const { GetReflection()->AddUInt64(msg_, field_desc_, value); return true; } bool SetFloat(float value) const { GetReflection()->AddFloat(msg_, field_desc_, value); return true; } bool SetDouble(double value) const { GetReflection()->AddDouble(msg_, field_desc_, value); return true; } bool SetString(CelValue::StringHolder value) const { GetReflection()->AddString(msg_, field_desc_, std::string(value.value())); return true; } bool SetBytes(CelValue::BytesHolder value) const { GetReflection()->AddString(msg_, field_desc_, std::string(value.value())); return true; } bool SetMessage(const Message* value) const { if (!value) return true; if (value->GetDescriptor()->full_name() != field_desc_->message_type()->full_name()) { return false; } auto* assignable_message = GetReflection()->AddMessage(msg_, field_desc_); return MergeFromWithSerializeFallback(*value, *assignable_message); } bool SetEnum(const int64_t value) const { GetReflection()->AddEnumValue(msg_, field_desc_, value); return true; } private: const Reflection* GetReflection() const { return msg_->GetReflection(); } }; } // namespace absl::StatusOr CreateValueFromSingleField( const google::protobuf::Message* msg, const FieldDescriptor* desc, ProtoWrapperTypeOptions options, const ProtobufValueFactory& factory, google::protobuf::Arena* arena) { ScalarFieldAccessor accessor( msg, desc, (options == ProtoWrapperTypeOptions::kUnsetNull), factory); return 
accessor.CreateValueFromFieldAccessor(arena); } absl::StatusOr CreateValueFromRepeatedField( const google::protobuf::Message* msg, const FieldDescriptor* desc, int index, const ProtobufValueFactory& factory, google::protobuf::Arena* arena) { RepeatedFieldAccessor accessor(msg, desc, index, factory); return accessor.CreateValueFromFieldAccessor(arena); } absl::StatusOr CreateValueFromMapValue( const google::protobuf::Message* msg, const FieldDescriptor* desc, const MapValueConstRef* value_ref, const ProtobufValueFactory& factory, google::protobuf::Arena* arena) { MapValueAccessor accessor(msg, desc, value_ref, factory); return accessor.CreateValueFromFieldAccessor(arena); } absl::Status SetValueToSingleField(const CelValue& value, const FieldDescriptor* desc, Message* msg, Arena* arena) { ScalarFieldSetter setter(msg, desc, arena); return (setter.SetFieldFromCelValue(value)) ? absl::OkStatus() : absl::InvalidArgumentError(absl::Substitute( "Could not assign supplied argument to message \"$0\" field " "\"$1\" of type $2: value type \"$3\"", msg->GetDescriptor()->name(), desc->name(), desc->type_name(), CelValue::TypeName(value.type()))); } absl::Status AddValueToRepeatedField(const CelValue& value, const FieldDescriptor* desc, Message* msg, Arena* arena) { RepeatedFieldSetter setter(msg, desc, arena); return (setter.SetFieldFromCelValue(value)) ? absl::OkStatus() : absl::InvalidArgumentError(absl::Substitute( "Could not add supplied argument to message \"$0\" field " "\"$1\" of type $2: value type \"$3\"", msg->GetDescriptor()->name(), desc->name(), desc->type_name(), CelValue::TypeName(value.type()))); } } // namespace google::api::expr::runtime::internal ================================================ FILE: eval/public/structs/field_access_impl.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_FIELD_ACCESS_IMPL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_FIELD_ACCESS_IMPL_H_ #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/structs/protobuf_value_factory.h" namespace google::api::expr::runtime::internal { // Creates CelValue from singular message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. // options Option to enable treating unset wrapper type fields as null. // arena Arena object to allocate result on, if needed. // result pointer to CelValue to store the result in. absl::StatusOr CreateValueFromSingleField( const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, ProtoWrapperTypeOptions options, const ProtobufValueFactory& factory, google::protobuf::Arena* arena); // Creates CelValue from repeated message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. // arena Arena object to allocate result on, if needed. // index position in the repeated field. absl::StatusOr CreateValueFromRepeatedField( const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, int index, const ProtobufValueFactory& factory, google::protobuf::Arena* arena); // Creates CelValue from map message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. // value_ref pointer to map value. 
// arena Arena object to allocate result on, if needed. // TODO(uncreated-issue/7): This should be inlined into the FieldBackedMap // implementation. absl::StatusOr CreateValueFromMapValue( const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, const google::protobuf::MapValueConstRef* value_ref, const ProtobufValueFactory& factory, google::protobuf::Arena* arena); // Assigns content of CelValue to singular message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. // arena Arena to perform allocations, if necessary, when setting the field. absl::Status SetValueToSingleField(const CelValue& value, const google::protobuf::FieldDescriptor* desc, google::protobuf::Message* msg, google::protobuf::Arena* arena); // Adds content of CelValue to repeated message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. // arena Arena to perform allocations, if necessary, when adding the value. absl::Status AddValueToRepeatedField(const CelValue& value, const google::protobuf::FieldDescriptor* desc, google::protobuf::Message* msg, google::protobuf::Arena* arena); } // namespace google::api::expr::runtime::internal #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_FIELD_ACCESS_IMPL_H_ ================================================ FILE: eval/public/structs/field_access_impl_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "eval/public/structs/field_access_impl.h" #include #include #include #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/testing/matchers.h" #include "eval/testutil/test_message.pb.h" #include "internal/testing.h" #include "internal/time.h" #include "testutil/util.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" namespace google::api::expr::runtime::internal { namespace { using ::absl_testing::StatusIs; using ::cel::expr::conformance::proto3::TestAllTypes; using ::cel::internal::MaxDuration; using ::cel::internal::MaxTimestamp; using ::google::protobuf::Arena; using ::google::protobuf::FieldDescriptor; using ::testing::HasSubstr; using testutil::EqualsProto; TEST(FieldAccessTest, SetDuration) { Arena arena; TestAllTypes msg; const FieldDescriptor* field = TestAllTypes::descriptor()->FindFieldByName("single_duration"); auto status = SetValueToSingleField(CelValue::CreateDuration(MaxDuration()), field, &msg, &arena); EXPECT_TRUE(status.ok()); } TEST(FieldAccessTest, SetDurationBadDuration) { Arena arena; TestAllTypes msg; const FieldDescriptor* field = TestAllTypes::descriptor()->FindFieldByName("single_duration"); auto status = SetValueToSingleField( CelValue::CreateDuration(MaxDuration() + absl::Seconds(1)), field, &msg, &arena); EXPECT_FALSE(status.ok()); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); } TEST(FieldAccessTest, SetDurationBadInputType) { Arena arena; TestAllTypes msg; const FieldDescriptor* field = TestAllTypes::descriptor()->FindFieldByName("single_duration"); auto status = 
SetValueToSingleField(CelValue::CreateInt64(1), field, &msg, &arena); EXPECT_FALSE(status.ok()); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); } TEST(FieldAccessTest, SetTimestamp) { Arena arena; TestAllTypes msg; const FieldDescriptor* field = TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); auto status = SetValueToSingleField(CelValue::CreateTimestamp(MaxTimestamp()), field, &msg, &arena); EXPECT_TRUE(status.ok()); } TEST(FieldAccessTest, SetTimestampBadTime) { Arena arena; TestAllTypes msg; const FieldDescriptor* field = TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); auto status = SetValueToSingleField( CelValue::CreateTimestamp(MaxTimestamp() + absl::Seconds(1)), field, &msg, &arena); EXPECT_FALSE(status.ok()); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); } TEST(FieldAccessTest, SetTimestampBadInputType) { Arena arena; TestAllTypes msg; const FieldDescriptor* field = TestAllTypes::descriptor()->FindFieldByName("single_timestamp"); auto status = SetValueToSingleField(CelValue::CreateInt64(1), field, &msg, &arena); EXPECT_FALSE(status.ok()); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); } TEST(FieldAccessTest, SetInt32Overflow) { Arena arena; TestAllTypes msg; const FieldDescriptor* field = TestAllTypes::descriptor()->FindFieldByName("single_int32"); EXPECT_THAT( SetValueToSingleField( CelValue::CreateInt64(std::numeric_limits::max() + 1L), field, &msg, &arena), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Could not assign"))); } TEST(FieldAccessTest, SetUint32Overflow) { Arena arena; TestAllTypes msg; const FieldDescriptor* field = TestAllTypes::descriptor()->FindFieldByName("single_uint32"); EXPECT_THAT( SetValueToSingleField( CelValue::CreateUint64(std::numeric_limits::max() + 1L), field, &msg, &arena), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Could not assign"))); } TEST(FieldAccessTest, SetMessage) { Arena arena; TestAllTypes msg; const 
FieldDescriptor* field = TestAllTypes::descriptor()->FindFieldByName("standalone_message"); TestAllTypes::NestedMessage* nested_msg = google::protobuf::Arena::Create(&arena); nested_msg->set_bb(1); auto status = SetValueToSingleField( CelProtoWrapper::CreateMessage(nested_msg, &arena), field, &msg, &arena); EXPECT_TRUE(status.ok()); } TEST(FieldAccessTest, SetMessageWithNull) { Arena arena; TestAllTypes msg; const FieldDescriptor* field = TestAllTypes::descriptor()->FindFieldByName("standalone_message"); auto status = SetValueToSingleField(CelValue::CreateNull(), field, &msg, &arena); EXPECT_TRUE(status.ok()); } struct AccessFieldTestParam { absl::string_view field_name; absl::string_view message_textproto; CelValue cel_value; }; std::string GetTestName( const testing::TestParamInfo& info) { return std::string(info.param.field_name); } class SingleFieldTest : public testing::TestWithParam { public: absl::string_view field_name() const { return GetParam().field_name; } absl::string_view message_textproto() const { return GetParam().message_textproto; } CelValue cel_value() const { return GetParam().cel_value; } }; TEST_P(SingleFieldTest, Getter) { TestAllTypes test_message; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(message_textproto(), &test_message)); google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN( CelValue accessed_value, CreateValueFromSingleField( &test_message, test_message.GetDescriptor()->FindFieldByName(field_name()), ProtoWrapperTypeOptions::kUnsetProtoDefault, &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(accessed_value, test::EqualsCelValue(cel_value())); } TEST_P(SingleFieldTest, Setter) { TestAllTypes test_message; CelValue to_set = cel_value(); google::protobuf::Arena arena; ASSERT_OK(SetValueToSingleField( to_set, test_message.GetDescriptor()->FindFieldByName(field_name()), &test_message, &arena)); EXPECT_THAT(test_message, EqualsProto(message_textproto())); } INSTANTIATE_TEST_SUITE_P( AllTypes, SingleFieldTest, 
testing::ValuesIn({ {"single_int32", "single_int32: 1", CelValue::CreateInt64(1)}, {"single_int64", "single_int64: 1", CelValue::CreateInt64(1)}, {"single_uint32", "single_uint32: 1", CelValue::CreateUint64(1)}, {"single_uint64", "single_uint64: 1", CelValue::CreateUint64(1)}, {"single_sint32", "single_sint32: 1", CelValue::CreateInt64(1)}, {"single_sint64", "single_sint64: 1", CelValue::CreateInt64(1)}, {"single_fixed32", "single_fixed32: 1", CelValue::CreateUint64(1)}, {"single_fixed64", "single_fixed64: 1", CelValue::CreateUint64(1)}, {"single_sfixed32", "single_sfixed32: 1", CelValue::CreateInt64(1)}, {"single_sfixed64", "single_sfixed64: 1", CelValue::CreateInt64(1)}, {"single_float", "single_float: 1.0", CelValue::CreateDouble(1.0)}, {"single_double", "single_double: 1.0", CelValue::CreateDouble(1.0)}, {"single_bool", "single_bool: true", CelValue::CreateBool(true)}, {"single_string", "single_string: 'abcd'", CelValue::CreateStringView("abcd")}, {"single_bytes", "single_bytes: 'asdf'", CelValue::CreateBytesView("asdf")}, {"standalone_enum", "standalone_enum: BAZ", CelValue::CreateInt64(2)}, // Basic coverage for unwrapping -- specifics are managed by the // wrapping library. 
{"single_int64_wrapper", "single_int64_wrapper { value: 20 }", CelValue::CreateInt64(20)}, {"single_value", "single_value { null_value: NULL_VALUE }", CelValue::CreateNull()}, }), &GetTestName); TEST(CreateValueFromSingleFieldTest, GetMessage) { TestAllTypes test_message; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( "standalone_message { bb: 10 }", &test_message)); ASSERT_OK_AND_ASSIGN( CelValue accessed_value, CreateValueFromSingleField( &test_message, test_message.GetDescriptor()->FindFieldByName("standalone_message"), ProtoWrapperTypeOptions::kUnsetProtoDefault, &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(accessed_value, test::IsCelMessage(EqualsProto("bb: 10"))); } TEST(SetValueToSingleFieldTest, WrongType) { TestAllTypes test_message; google::protobuf::Arena arena; EXPECT_THAT(SetValueToSingleField( CelValue::CreateDouble(1.0), test_message.GetDescriptor()->FindFieldByName("single_int32"), &test_message, &arena), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(SetValueToSingleFieldTest, IntOutOfRange) { CelValue out_of_range = CelValue::CreateInt64(1LL << 31); TestAllTypes test_message; const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); google::protobuf::Arena arena; EXPECT_THAT(SetValueToSingleField(out_of_range, descriptor->FindFieldByName("single_int32"), &test_message, &arena), StatusIs(absl::StatusCode::kInvalidArgument)); // proto enums are are represented as int32, but CEL converts to/from int64. 
EXPECT_THAT(SetValueToSingleField( out_of_range, descriptor->FindFieldByName("standalone_enum"), &test_message, &arena), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(SetValueToSingleFieldTest, UintOutOfRange) { CelValue out_of_range = CelValue::CreateUint64(1LL << 32); TestAllTypes test_message; const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); google::protobuf::Arena arena; EXPECT_THAT(SetValueToSingleField( out_of_range, descriptor->FindFieldByName("single_uint32"), &test_message, &arena), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(SetValueToSingleFieldTest, SetMessage) { TestAllTypes::NestedMessage nested_message; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( bb: 42 )", &nested_message)); google::protobuf::Arena arena; CelValue nested_value = CelProtoWrapper::CreateMessage(&nested_message, &arena); TestAllTypes test_message; const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); ASSERT_OK(SetValueToSingleField( nested_value, descriptor->FindFieldByName("standalone_message"), &test_message, &arena)); EXPECT_THAT(test_message, EqualsProto("standalone_message { bb: 42 }")); } TEST(SetValueToSingleFieldTest, SetAnyMessage) { TestAllTypes::NestedMessage nested_message; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( bb: 42 )", &nested_message)); google::protobuf::Arena arena; CelValue nested_value = CelProtoWrapper::CreateMessage(&nested_message, &arena); TestAllTypes test_message; const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); ASSERT_OK(SetValueToSingleField(nested_value, descriptor->FindFieldByName("single_any"), &test_message, &arena)); TestAllTypes::NestedMessage unpacked; test_message.single_any().UnpackTo(&unpacked); EXPECT_THAT(unpacked, EqualsProto("bb: 42")); } TEST(SetValueToSingleFieldTest, SetMessageToNullNoop) { google::protobuf::Arena arena; TestAllTypes test_message; const google::protobuf::Descriptor* descriptor = 
test_message.GetDescriptor(); ASSERT_OK(SetValueToSingleField( CelValue::CreateNull(), descriptor->FindFieldByName("standalone_message"), &test_message, &arena)); EXPECT_THAT(test_message, EqualsProto(test_message.default_instance())); } class RepeatedFieldTest : public testing::TestWithParam { public: absl::string_view field_name() const { return GetParam().field_name; } absl::string_view message_textproto() const { return GetParam().message_textproto; } CelValue cel_value() const { return GetParam().cel_value; } }; TEST_P(RepeatedFieldTest, GetFirstElem) { TestAllTypes test_message; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(message_textproto(), &test_message)); google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN( CelValue accessed_value, CreateValueFromRepeatedField( &test_message, test_message.GetDescriptor()->FindFieldByName(field_name()), 0, &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(accessed_value, test::EqualsCelValue(cel_value())); } TEST_P(RepeatedFieldTest, AppendElem) { TestAllTypes test_message; CelValue to_add = cel_value(); google::protobuf::Arena arena; ASSERT_OK(AddValueToRepeatedField( to_add, test_message.GetDescriptor()->FindFieldByName(field_name()), &test_message, &arena)); EXPECT_THAT(test_message, EqualsProto(message_textproto())); } INSTANTIATE_TEST_SUITE_P( AllTypes, RepeatedFieldTest, testing::ValuesIn( {{"repeated_int32", "repeated_int32: 1", CelValue::CreateInt64(1)}, {"repeated_int64", "repeated_int64: 1", CelValue::CreateInt64(1)}, {"repeated_uint32", "repeated_uint32: 1", CelValue::CreateUint64(1)}, {"repeated_uint64", "repeated_uint64: 1", CelValue::CreateUint64(1)}, {"repeated_sint32", "repeated_sint32: 1", CelValue::CreateInt64(1)}, {"repeated_sint64", "repeated_sint64: 1", CelValue::CreateInt64(1)}, {"repeated_fixed32", "repeated_fixed32: 1", CelValue::CreateUint64(1)}, {"repeated_fixed64", "repeated_fixed64: 1", CelValue::CreateUint64(1)}, {"repeated_sfixed32", "repeated_sfixed32: 1", 
CelValue::CreateInt64(1)}, {"repeated_sfixed64", "repeated_sfixed64: 1", CelValue::CreateInt64(1)}, {"repeated_float", "repeated_float: 1.0", CelValue::CreateDouble(1.0)}, {"repeated_double", "repeated_double: 1.0", CelValue::CreateDouble(1.0)}, {"repeated_bool", "repeated_bool: true", CelValue::CreateBool(true)}, {"repeated_string", "repeated_string: 'abcd'", CelValue::CreateStringView("abcd")}, {"repeated_bytes", "repeated_bytes: 'asdf'", CelValue::CreateBytesView("asdf")}, {"repeated_nested_enum", "repeated_nested_enum: BAZ", CelValue::CreateInt64(2)}}), &GetTestName); TEST(RepeatedFieldTest, GetMessage) { TestAllTypes test_message; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( "repeated_nested_message { bb: 30 }", &test_message)); google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue accessed_value, CreateValueFromRepeatedField( &test_message, test_message.GetDescriptor()->FindFieldByName( "repeated_nested_message"), 0, &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(accessed_value, test::IsCelMessage(EqualsProto("bb: 30"))); } TEST(AddValueToRepeatedFieldTest, WrongType) { TestAllTypes test_message; google::protobuf::Arena arena; EXPECT_THAT( AddValueToRepeatedField( CelValue::CreateDouble(1.0), test_message.GetDescriptor()->FindFieldByName("repeated_int32"), &test_message, &arena), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(AddValueToRepeatedFieldTest, IntOutOfRange) { CelValue out_of_range = CelValue::CreateInt64(1LL << 31); TestAllTypes test_message; const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); google::protobuf::Arena arena; EXPECT_THAT(AddValueToRepeatedField( out_of_range, descriptor->FindFieldByName("repeated_int32"), &test_message, &arena), StatusIs(absl::StatusCode::kInvalidArgument)); // proto enums are are represented as int32, but CEL converts to/from int64. 
EXPECT_THAT( AddValueToRepeatedField( out_of_range, descriptor->FindFieldByName("repeated_nested_enum"), &test_message, &arena), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(AddValueToRepeatedFieldTest, UintOutOfRange) { CelValue out_of_range = CelValue::CreateUint64(1LL << 32); TestAllTypes test_message; const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); google::protobuf::Arena arena; EXPECT_THAT(AddValueToRepeatedField( out_of_range, descriptor->FindFieldByName("repeated_uint32"), &test_message, &arena), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(AddValueToRepeatedFieldTest, AddMessage) { TestAllTypes::NestedMessage nested_message; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( bb: 42 )", &nested_message)); google::protobuf::Arena arena; CelValue nested_value = CelProtoWrapper::CreateMessage(&nested_message, &arena); TestAllTypes test_message; const google::protobuf::Descriptor* descriptor = test_message.GetDescriptor(); ASSERT_OK(AddValueToRepeatedField( nested_value, descriptor->FindFieldByName("repeated_nested_message"), &test_message, &arena)); EXPECT_THAT(test_message, EqualsProto("repeated_nested_message { bb: 42 }")); } constexpr std::array kWrapperFieldNames = { "single_bool_wrapper", "single_int64_wrapper", "single_int32_wrapper", "single_uint64_wrapper", "single_uint32_wrapper", "single_double_wrapper", "single_float_wrapper", "single_string_wrapper", "single_bytes_wrapper"}; // Unset wrapper type fields are treated as null if accessed after option // enabled. 
TEST(CreateValueFromFieldTest, UnsetWrapperTypesNullIfEnabled) {
  // kUnsetNull: every wrapper field left unset in an empty message must read
  // back as CEL null.
  google::protobuf::Arena arena;
  TestAllTypes empty_message;
  const google::protobuf::Descriptor* const descriptor =
      TestAllTypes::GetDescriptor();

  for (const char* const field_name : kWrapperFieldNames) {
    ASSERT_OK_AND_ASSIGN(
        CelValue value,
        CreateValueFromSingleField(
            &empty_message, descriptor->FindFieldByName(field_name),
            ProtoWrapperTypeOptions::kUnsetNull,
            &CelProtoWrapper::InternalWrapMessage, &arena));
    ASSERT_TRUE(value.IsNull()) << field_name << ": " << value.DebugString();
  }
}

// Unset wrapper type fields are treated as proto default under old
// behavior.
TEST(CreateValueFromFieldTest, UnsetWrapperTypesDefaultValueIfDisabled) {
  // kUnsetProtoDefault: the same unset wrapper fields must NOT read back as
  // null; they surface the wrapped type's proto default instead.
  google::protobuf::Arena arena;
  TestAllTypes empty_message;
  const google::protobuf::Descriptor* const descriptor =
      TestAllTypes::GetDescriptor();

  for (const char* const field_name : kWrapperFieldNames) {
    ASSERT_OK_AND_ASSIGN(
        CelValue value,
        CreateValueFromSingleField(
            &empty_message, descriptor->FindFieldByName(field_name),
            ProtoWrapperTypeOptions::kUnsetProtoDefault,
            &CelProtoWrapper::InternalWrapMessage, &arena));
    ASSERT_FALSE(value.IsNull()) << field_name << ": " << value.DebugString();
  }
}

// If a wrapper type is set to default value, the corresponding CelValue is the
// proto default value.
TEST(CreateValueFromFieldTest, SetWrapperTypesDefaultValue) { CelValue result; TestAllTypes test_message; google::protobuf::Arena arena; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( single_bool_wrapper {} single_int64_wrapper {} single_int32_wrapper {} single_uint64_wrapper {} single_uint32_wrapper {} single_double_wrapper {} single_float_wrapper {} single_string_wrapper {} single_bytes_wrapper {} )pb", &test_message)); ASSERT_OK_AND_ASSIGN( result, CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName("single_bool_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(result, test::IsCelBool(false)); ASSERT_OK_AND_ASSIGN(result, CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName( "single_int64_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(result, test::IsCelInt64(0)); ASSERT_OK_AND_ASSIGN(result, CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName( "single_int32_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(result, test::IsCelInt64(0)); ASSERT_OK_AND_ASSIGN( result, CreateValueFromSingleField(&test_message, TestAllTypes::GetDescriptor()->FindFieldByName( "single_uint64_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(result, test::IsCelUint64(0)); ASSERT_OK_AND_ASSIGN( result, CreateValueFromSingleField(&test_message, TestAllTypes::GetDescriptor()->FindFieldByName( "single_uint32_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(result, test::IsCelUint64(0)); ASSERT_OK_AND_ASSIGN(result, CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName( "single_double_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, 
&CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(result, test::IsCelDouble(0.0f)); ASSERT_OK_AND_ASSIGN(result, CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName( "single_float_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(result, test::IsCelDouble(0.0f)); ASSERT_OK_AND_ASSIGN(result, CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName( "single_string_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(result, test::IsCelString("")); ASSERT_OK_AND_ASSIGN(result, CreateValueFromSingleField( &test_message, TestAllTypes::GetDescriptor()->FindFieldByName( "single_bytes_wrapper"), ProtoWrapperTypeOptions::kUnsetNull, &CelProtoWrapper::InternalWrapMessage, &arena)); EXPECT_THAT(result, test::IsCelBytes("")); } } // namespace } // namespace google::api::expr::runtime::internal ================================================ FILE: eval/public/structs/legacy_type_adapter.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Definitions for legacy type APIs to emulate the behavior of the new type // system. 
#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_ADPATER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_ADPATER_H_ #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "base/attribute.h" #include "common/memory.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" namespace google::api::expr::runtime { // Interface for mutation apis. // Note: in the new type system, a type provider represents this by returning // a cel::Type and cel::ValueManager for the type. class LegacyTypeMutationApis { public: virtual ~LegacyTypeMutationApis() = default; // Return whether the type defines the given field. // TODO(uncreated-issue/3): This is only used to eagerly fail during the planning // phase. Check if it's safe to remove this behavior and fail at runtime. virtual bool DefinesField(absl::string_view field_name) const = 0; // Create a new empty instance of the type. // May return a status if the type is not possible to create. virtual absl::StatusOr NewInstance( cel::MemoryManagerRef memory_manager) const = 0; // Normalize special types to a native CEL value after building. // The interpreter guarantees that instance is uniquely owned by the // interpreter, and can be safely mutated. virtual absl::StatusOr AdaptFromWellKnownType( cel::MemoryManagerRef memory_manager, CelValue::MessageWrapper::Builder instance) const = 0; // Set field on instance to value. // The interpreter guarantees that instance is uniquely owned by the // interpreter, and can be safely mutated. 
virtual absl::Status SetField( absl::string_view field_name, const CelValue& value, cel::MemoryManagerRef memory_manager, CelValue::MessageWrapper::Builder& instance) const = 0; virtual absl::Status SetFieldByNumber( int64_t field_number [[maybe_unused]], const CelValue& value [[maybe_unused]], cel::MemoryManagerRef memory_manager [[maybe_unused]], CelValue::MessageWrapper::Builder& instance [[maybe_unused]]) const { return absl::UnimplementedError("SetFieldByNumber is not yet implemented"); } }; // Interface for access apis. // Note: in new type system this is integrated into the StructValue (via // dynamic dispatch to concrete implementations). class LegacyTypeAccessApis { public: struct LegacyQualifyResult { // The possibly intermediate result of the select operation. CelValue value; // Number of qualifiers applied. int qualifier_count; }; virtual ~LegacyTypeAccessApis() = default; // Return whether an instance of the type has field set to a non-default // value. virtual absl::StatusOr HasField( absl::string_view field_name, const CelValue::MessageWrapper& value) const = 0; // Access field on instance. virtual absl::StatusOr GetField( absl::string_view field_name, const CelValue::MessageWrapper& instance, ProtoWrapperTypeOptions unboxing_option, cel::MemoryManagerRef memory_manager) const = 0; // Apply a series of select operations on the given instance. // // Each select qualifier may represent either a singular field access ( // FieldSpecifier) or an index into a container (AttributeQualifier). // // The Qualify implementation should return an appropriate CelError when // intermediate fields or indexes are not found, or the given qualifier // doesn't apply to operand. // // A Status with a non-ok error code may be returned for other errors. // absl::StatusCode::kUnimplemented signals that Qualify is unsupported and // the evaluator should emulate the default select behavior. 
// // - presence_test controls whether to treat the call as a 'has' call, // returning // whether the leaf field is set to a non-default value. virtual absl::StatusOr Qualify( absl::Span, const CelValue::MessageWrapper& instance [[maybe_unused]], bool presence_test [[maybe_unused]], cel::MemoryManagerRef memory_manager [[maybe_unused]]) const { return absl::UnimplementedError("Qualify unsupported."); } // Interface for equality operator. // The interpreter will check that both instances report to be the same type, // but implementations should confirm that both instances are actually of the // same type. // If the two instances are of different type, return false. Otherwise, // return whether they are equal. // To conform to the CEL spec, message equality should follow the behavior of // MessageDifferencer::Equals. virtual bool IsEqualTo(const CelValue::MessageWrapper&, const CelValue::MessageWrapper&) const { return false; } virtual std::vector ListFields( const CelValue::MessageWrapper& instance) const = 0; }; // Type information about a legacy Struct type. // Provides methods to the interpreter for interacting with a custom type. // // mutation_apis() provide equivalent behavior to a cel::Type and // cel::ValueManager (resolved from a type name). // // access_apis() provide equivalent behavior to cel::StructValue accessors // (virtual dispatch to a concrete implementation for accessing underlying // values). // // This class is a simple wrapper around (nullable) pointers to the interface // implementations. The underlying pointers are expected to be valid as long as // the type provider that returned this object. class LegacyTypeAdapter { public: LegacyTypeAdapter(const LegacyTypeAccessApis* access, const LegacyTypeMutationApis* mutation) : access_apis_(access), mutation_apis_(mutation) {} // Apis for access for the represented type. // If null, access is not supported (this is an opaque type). 
const LegacyTypeAccessApis* access_apis() { return access_apis_; } // Apis for mutation for the represented type. // If null, mutation is not supported (this type cannot be created). const LegacyTypeMutationApis* mutation_apis() { return mutation_apis_; } private: const LegacyTypeAccessApis* access_apis_; const LegacyTypeMutationApis* mutation_apis_; }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_ADPATER_H_ ================================================ FILE: eval/public/structs/legacy_type_adapter_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/public/structs/legacy_type_adapter.h"

#include

#include "eval/public/cel_value.h"
#include "eval/public/structs/trivial_legacy_type_info.h"
#include "eval/public/testing/matchers.h"
#include "eval/testutil/test_message.pb.h"
#include "extensions/protobuf/memory_manager.h"
#include "internal/status_macros.h"
#include "internal/testing.h"

namespace google::api::expr::runtime {
namespace {

// Bare-bones LegacyTypeAccessApis: every pure-virtual hook is a stub, so the
// only real behavior left to exercise is what the base class provides.
class TestAccessApiImpl : public LegacyTypeAccessApis {
 public:
  TestAccessApiImpl() = default;

  absl::StatusOr HasField(
      absl::string_view field_name [[maybe_unused]],
      const CelValue::MessageWrapper& value [[maybe_unused]]) const override {
    return absl::UnimplementedError("Not implemented");
  }

  absl::StatusOr GetField(
      absl::string_view field_name [[maybe_unused]],
      const CelValue::MessageWrapper& instance [[maybe_unused]],
      ProtoWrapperTypeOptions unboxing_option [[maybe_unused]],
      cel::MemoryManagerRef memory_manager [[maybe_unused]]) const override {
    return absl::UnimplementedError("Not implemented");
  }

  std::vector ListFields(
      const CelValue::MessageWrapper& instance
      [[maybe_unused]]) const override {
    return {};
  }
};

TEST(LegacyTypeAdapterAccessApis, DefaultAlwaysInequal) {
  TestMessage message;
  // Both wrappers refer to the very same message; the inherited IsEqualTo
  // must still answer "not equal".
  MessageWrapper lhs(&message, nullptr);
  MessageWrapper rhs(&message, nullptr);

  TestAccessApiImpl access_apis;

  EXPECT_FALSE(access_apis.IsEqualTo(lhs, rhs));
}

}  // namespace
}  // namespace google::api::expr::runtime

================================================
FILE: eval/public/structs/legacy_type_info_apis.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_INFO_APIS_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_INFO_APIS_H_ #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "eval/public/message_wrapper.h" #include "google/protobuf/descriptor.h" namespace google::api::expr::runtime { // Forward declared to resolve cyclic dependency. class LegacyTypeAccessApis; class LegacyTypeMutationApis; // Interface for providing type info from a user defined type (represented as a // message). // // Provides ability to obtain field access apis, type info, and debug // representation of a message. // // The message parameter may wrap a nullptr to request generic accessors / // mutators for the TypeInfo instance if it is available. // // This is implemented as a separate class from LegacyTypeAccessApis to resolve // cyclic dependency between CelValue (which needs to access these apis to // provide DebugString and ObtainCelTypename) and LegacyTypeAccessApis (which // needs to return CelValue type for field access). class LegacyTypeInfoApis { public: struct FieldDescription { int number; absl::string_view name; }; virtual ~LegacyTypeInfoApis() = default; // Return a debug representation of the wrapped message. virtual std::string DebugString( const MessageWrapper& wrapped_message) const = 0; // Return a reference to the typename for the wrapped message's type. // The CEL interpreter assumes that the typename is owned externally and will // outlive any CelValues created by the interpreter. 
virtual absl::string_view GetTypename( const MessageWrapper& wrapped_message) const = 0; virtual const google::protobuf::Descriptor* absl_nullable GetDescriptor( const MessageWrapper& wrapped_message [[maybe_unused]]) const { return nullptr; } // Return a pointer to the wrapped message's access api implementation. // // The CEL interpreter assumes that the returned pointer is owned externally // and will outlive any CelValues created by the interpreter. // // Nullptr signals that the value does not provide access apis. For field // access, the interpreter will treat this the same as accessing a field that // is not defined for the type. virtual const LegacyTypeAccessApis* GetAccessApis( const MessageWrapper& wrapped_message) const = 0; // Return a pointer to the wrapped message's mutation api implementation. // // The CEL interpreter assumes that the returned pointer is owned externally // and will outlive any CelValues created by the interpreter. // // Nullptr signals that the value does not provide mutation apis. virtual const LegacyTypeMutationApis* GetMutationApis( const MessageWrapper& wrapped_message [[maybe_unused]]) const { return nullptr; } // Return a description of the underlying field if defined. // // The underlying string is expected to remain valid as long as the // LegacyTypeInfoApis instance. virtual absl::optional FindFieldByName( absl::string_view name [[maybe_unused]]) const { return absl::nullopt; } }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_LEGACY_TYPE_INFO_APIS_H_ ================================================ FILE: eval/public/structs/legacy_type_provider.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "eval/public/structs/legacy_type_provider.h" #include #include #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/legacy_value.h" #include "common/memory.h" #include "common/type.h" #include "common/type_introspector.h" #include "common/value.h" #include "eval/public/message_wrapper.h" #include "eval/public/structs/legacy_type_adapter.h" #include "eval/public/structs/legacy_type_info_apis.h" #include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "google/protobuf/arena.h" #include "google/protobuf/message.h" namespace google::api::expr::runtime { namespace { using google::api::expr::runtime::LegacyTypeAdapter; using google::api::expr::runtime::MessageWrapper; class LegacyStructValueBuilder final : public cel::StructValueBuilder { public: LegacyStructValueBuilder(cel::MemoryManagerRef memory_manager, LegacyTypeAdapter adapter, MessageWrapper::Builder builder) : memory_manager_(memory_manager), adapter_(adapter), builder_(std::move(builder)) {} absl::StatusOr> SetFieldByName( absl::string_view name, cel::Value value) override { CEL_ASSIGN_OR_RETURN( auto legacy_value, LegacyValue(cel::extensions::ProtoMemoryManagerArena(memory_manager_), value), _.With(cel::ErrorValueReturn())); CEL_RETURN_IF_ERROR(adapter_.mutation_apis()->SetField( name, legacy_value, memory_manager_, builder_)) .With(cel::ErrorValueReturn()); return absl::nullopt; } 
absl::StatusOr> SetFieldByNumber( int64_t number, cel::Value value) override { CEL_ASSIGN_OR_RETURN( auto legacy_value, LegacyValue(cel::extensions::ProtoMemoryManagerArena(memory_manager_), value), _.With(cel::ErrorValueReturn())); CEL_RETURN_IF_ERROR(adapter_.mutation_apis()->SetFieldByNumber( number, legacy_value, memory_manager_, builder_)) .With(cel::ErrorValueReturn()); return absl::nullopt; } absl::StatusOr Build() && override { CEL_ASSIGN_OR_RETURN(auto message, adapter_.mutation_apis()->AdaptFromWellKnownType( memory_manager_, std::move(builder_))); if (!message.IsMessage()) { return absl::FailedPreconditionError("expected MessageWrapper"); } auto message_wrapper = message.MessageWrapperOrDie(); return cel::common_internal::LegacyStructValue( google::protobuf::DownCastMessage(message_wrapper.message_ptr()), message_wrapper.legacy_type_info()); } private: cel::MemoryManagerRef memory_manager_; LegacyTypeAdapter adapter_; MessageWrapper::Builder builder_; }; class LegacyValueBuilder final : public cel::ValueBuilder { public: LegacyValueBuilder(cel::MemoryManagerRef memory_manager, LegacyTypeAdapter adapter, MessageWrapper::Builder builder) : memory_manager_(memory_manager), adapter_(adapter), builder_(std::move(builder)) {} absl::StatusOr> SetFieldByName( absl::string_view name, cel::Value value) override { CEL_ASSIGN_OR_RETURN( auto legacy_value, LegacyValue(cel::extensions::ProtoMemoryManagerArena(memory_manager_), value), _.With(cel::ErrorValueReturn())); CEL_RETURN_IF_ERROR(adapter_.mutation_apis()->SetField( name, legacy_value, memory_manager_, builder_)) .With(cel::ErrorValueReturn()); return absl::nullopt; } absl::StatusOr> SetFieldByNumber( int64_t number, cel::Value value) override { CEL_ASSIGN_OR_RETURN( auto legacy_value, LegacyValue(cel::extensions::ProtoMemoryManagerArena(memory_manager_), value), _.With(cel::ErrorValueReturn())); CEL_RETURN_IF_ERROR(adapter_.mutation_apis()->SetFieldByNumber( number, legacy_value, memory_manager_, builder_)) 
.With(cel::ErrorValueReturn()); return absl::nullopt; } absl::StatusOr Build() && override { CEL_ASSIGN_OR_RETURN(auto value, adapter_.mutation_apis()->AdaptFromWellKnownType( memory_manager_, std::move(builder_)), _.With(cel::ErrorValueReturn())); CEL_ASSIGN_OR_RETURN( auto result, cel::ModernValue( cel::extensions::ProtoMemoryManagerArena(memory_manager_), value), _.With(cel::ErrorValueReturn())); return result; } private: cel::MemoryManagerRef memory_manager_; LegacyTypeAdapter adapter_; MessageWrapper::Builder builder_; }; } // namespace absl::StatusOr LegacyTypeProvider::NewValueBuilder( absl::string_view name, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { if (auto type_adapter = ProvideLegacyType(name); type_adapter.has_value()) { const auto* mutation_apis = type_adapter->mutation_apis(); if (mutation_apis == nullptr) { return absl::FailedPreconditionError( absl::StrCat("LegacyTypeMutationApis missing for type: ", name)); } CEL_ASSIGN_OR_RETURN( auto builder, mutation_apis->NewInstance(cel::MemoryManagerRef::Pooling(arena))); return std::make_unique( cel::MemoryManagerRef::Pooling(arena), *type_adapter, std::move(builder)); } return nullptr; } absl::StatusOr> LegacyTypeProvider::FindTypeImpl( absl::string_view name) const { if (auto type = cel::FindWellKnownType(name); type.has_value()) { return type; } if (auto type_info = ProvideLegacyTypeInfo(name); type_info.has_value()) { const auto* descriptor = (*type_info)->GetDescriptor(MessageWrapper()); if (descriptor != nullptr) { return cel::MessageType(descriptor); } return cel::common_internal::MakeBasicStructType( (*type_info)->GetTypename(MessageWrapper())); } return absl::nullopt; } absl::StatusOr> LegacyTypeProvider::FindStructTypeFieldByNameImpl( absl::string_view type, absl::string_view name) const { if (auto result = cel::FindWellKnownTypeFieldByName(type, name); result.has_value()) { return result; } if (auto type_info = 
ProvideLegacyTypeInfo(type); type_info.has_value()) { if (auto field_desc = (*type_info)->FindFieldByName(name); field_desc.has_value()) { return cel::common_internal::BasicStructTypeField( field_desc->name, field_desc->number, cel::DynType{}); } else { const auto* mutation_apis = (*type_info)->GetMutationApis(MessageWrapper()); if (mutation_apis == nullptr || !mutation_apis->DefinesField(name)) { return absl::nullopt; } return cel::common_internal::BasicStructTypeField(name, 0, cel::DynType{}); } } return absl::nullopt; } } // namespace google::api::expr::runtime ================================================ FILE: eval/public/structs/legacy_type_provider.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TYPE_PROVIDER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TYPE_PROVIDER_H_ #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/type.h" #include "common/type_reflector.h" #include "common/value.h" #include "eval/public/structs/legacy_type_adapter.h" #include "eval/public/structs/legacy_type_info_apis.h" #include "google/protobuf/arena.h" #include "google/protobuf/message.h" namespace google::api::expr::runtime { // An internal extension of cel::TypeProvider that also deals with legacy types. 
// // Note: This API is not finalized. Consult the CEL team before introducing new // implementations. class LegacyTypeProvider : public cel::TypeReflector { public: virtual ~LegacyTypeProvider() = default; // Return LegacyTypeAdapter for the fully qualified type name if available. // // nullopt values are interpreted as not present. // // Returned non-null pointers from the adapter implemententation must remain // valid as long as the type provider. // TODO(uncreated-issue/3): add alternative for new type system. virtual absl::optional ProvideLegacyType( absl::string_view name) const = 0; // Return LegacyTypeInfoApis for the fully qualified type name if available. // // nullopt values are interpreted as not present. // // Since custom type providers should create values compatible with evaluator // created ones, the TypeInfoApis returned from this method should be the same // as the ones used in value creation. virtual absl::optional ProvideLegacyTypeInfo( ABSL_ATTRIBUTE_UNUSED absl::string_view name) const { return absl::nullopt; } absl::StatusOr NewValueBuilder( absl::string_view name, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const final; protected: absl::StatusOr> FindTypeImpl( absl::string_view name) const final; absl::StatusOr> FindStructTypeFieldByNameImpl(absl::string_view type, absl::string_view name) const final; }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TYPE_PROVIDER_H_ ================================================ FILE: eval/public/structs/legacy_type_provider_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "eval/public/structs/legacy_type_provider.h"

#include
#include

#include "absl/strings/string_view.h"
#include "eval/public/structs/legacy_type_info_apis.h"
#include "internal/testing.h"

namespace google::api::expr::runtime {
namespace {

// Provider that knows no types; ProvideLegacyTypeInfo is left at the base
// class default.
class LegacyTypeProviderTestEmpty : public LegacyTypeProvider {
 public:
  absl::optional ProvideLegacyType(
      absl::string_view name) const override {
    return absl::nullopt;
  }
};

// Type info stub: empty debug string, fixed typename "test", no access apis.
class LegacyTypeInfoApisEmpty : public LegacyTypeInfoApis {
 public:
  std::string DebugString(
      const MessageWrapper& wrapped_message) const override {
    return "";
  }
  absl::string_view GetTypename(
      const MessageWrapper& wrapped_message) const override {
    return test_string_;
  }
  const LegacyTypeAccessApis* GetAccessApis(
      const MessageWrapper& wrapped_message) const override {
    return nullptr;
  }

 private:
  // Owned here so the string_view returned by GetTypename stays valid for the
  // lifetime of this object.
  const std::string test_string_ = "test";
};

// Provider that recognizes only the type name "test": it returns an adapter
// with null access/mutation apis and the injected (non-owned) type info.
class LegacyTypeProviderTestImpl : public LegacyTypeProvider {
 public:
  explicit LegacyTypeProviderTestImpl(const LegacyTypeInfoApis* test_type_info)
      : test_type_info_(test_type_info) {}
  absl::optional ProvideLegacyType(
      absl::string_view name) const override {
    if (name == "test") {
      return LegacyTypeAdapter(nullptr, nullptr);
    }
    return absl::nullopt;
  }
  absl::optional ProvideLegacyTypeInfo(
      absl::string_view name) const override {
    if (name == "test") {
      return test_type_info_;
    }
    return absl::nullopt;
  }

 private:
  const LegacyTypeInfoApis* test_type_info_ = nullptr;
};

TEST(LegacyTypeProviderTest, EmptyTypeProviderHasProvideTypeInfo) {
  LegacyTypeProviderTestEmpty provider;
  EXPECT_EQ(provider.ProvideLegacyType("test"), absl::nullopt);
  // Exercises the base-class default of ProvideLegacyTypeInfo.
  EXPECT_EQ(provider.ProvideLegacyTypeInfo("test"), absl::nullopt);
}

TEST(LegacyTypeProviderTest, NonEmptyTypeProviderProvidesSomeTypes) {
  LegacyTypeInfoApisEmpty test_type_info;
  LegacyTypeProviderTestImpl provider(&test_type_info);
  EXPECT_TRUE(provider.ProvideLegacyType("test").has_value());
  EXPECT_TRUE(provider.ProvideLegacyTypeInfo("test").has_value());

  // Names other than "test" are unknown to this provider.
  EXPECT_EQ(provider.ProvideLegacyType("other"), absl::nullopt);
  EXPECT_EQ(provider.ProvideLegacyTypeInfo("other"), absl::nullopt);
}

}  // namespace
}  // namespace google::api::expr::runtime

================================================
FILE: eval/public/structs/proto_message_type_adapter.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "eval/public/structs/proto_message_type_adapter.h"

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/no_destructor.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "base/attribute.h"
#include "common/memory.h"
#include "eval/public/cel_options.h"
#include "eval/public/cel_value.h"
#include "eval/public/containers/internal_field_backed_list_impl.h"
#include "eval/public/containers/internal_field_backed_map_impl.h"
#include "eval/public/message_wrapper.h"
#include "eval/public/structs/cel_proto_wrap_util.h"
#include "eval/public/structs/field_access_impl.h"
#include "eval/public/structs/legacy_type_adapter.h"
#include "eval/public/structs/legacy_type_info_apis.h"
#include "extensions/protobuf/internal/qualify.h"
#include "extensions/protobuf/memory_manager.h"
#include "internal/casts.h"
#include "internal/status_macros.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/map_field.h"
#include "google/protobuf/message.h"
#include "google/protobuf/util/message_differencer.h"

namespace google::api::expr::runtime {
namespace {

using ::cel::extensions::ProtoMemoryManagerArena;
using ::cel::extensions::ProtoMemoryManagerRef;
using ::google::protobuf::FieldDescriptor;
using ::google::protobuf::Message;
using ::google::protobuf::Reflection;

using LegacyQualifyResult = LegacyTypeAccessApis::LegacyQualifyResult;

// Type name / debug string reported for wrapped values that do not hold a
// full (reflection-capable) protobuf message.
const std::string& UnsupportedTypeName() {
  static absl::NoDestructor<std::string> kUnsupportedTypeName(
      "<unknown message>");
  return *kUnsupportedTypeName;
}

CelValue MessageCelValueFactory(const google::protobuf::Message* message);

// Returns the full protobuf message held by `value`.
// Returns InternalError mentioning `op` if `value` holds no message or only a
// lite message.
inline absl::StatusOr<const google::protobuf::Message*> UnwrapMessage(
    const MessageWrapper& value, absl::string_view op) {
  if (!value.HasFullProto() || value.message_ptr() == nullptr) {
    return absl::InternalError(
        absl::StrCat(op, " called on non-message type."));
  }
  return static_cast<const google::protobuf::Message*>(value.message_ptr());
}

// Mutable counterpart of the overload above, for message builders.
inline absl::StatusOr<google::protobuf::Message*> UnwrapMessage(
    const MessageWrapper::Builder& value, absl::string_view op) {
  if (!value.HasFullProto() || value.message_ptr() == nullptr) {
    return absl::InternalError(
        absl::StrCat(op, " called on non-message type."));
  }
  return static_cast<google::protobuf::Message*>(value.message_ptr());
}

bool ProtoEquals(const google::protobuf::Message& m1, const google::protobuf::Message& m2) {
  // Equality behavior is undefined for message differencer if input messages
  // have different descriptors. For CEL just return false.
  if (m1.GetDescriptor() != m2.GetDescriptor()) {
    return false;
  }
  return google::protobuf::util::MessageDifferencer::Equals(m1, m2);
}

// Implements CEL's notion of field presence for protobuf.
// Assumes all arguments non-null.
bool CelFieldIsPresent(const google::protobuf::Message* message,
                       const google::protobuf::FieldDescriptor* field_desc,
                       const google::protobuf::Reflection* reflection) {
  if (field_desc->is_map()) {
    // When the map field appears in a has(msg.map_field) expression, the map
    // is considered 'present' when it is non-empty. Since maps are repeated
    // fields they don't participate with standard proto presence testing since
    // the repeated field is always at least empty.
    return reflection->FieldSize(*message, field_desc) != 0;
  }

  if (field_desc->is_repeated()) {
    // When the list field appears in a has(msg.list_field) expression, the list
    // is considered 'present' when it is non-empty.
    return reflection->FieldSize(*message, field_desc) != 0;
  }

  // Standard proto presence test for non-repeated fields.
  return reflection->HasField(*message, field_desc);
}

// Shared implementation for HasField.
// Handles list or map specific behavior before calling reflection helpers.
// Returns whether `field_name` (a regular field or a known extension of
// `descriptor`) is present on `message`, using CelFieldIsPresent semantics.
// Returns NotFound if no such field exists, or FailedPrecondition if the
// message has no reflection.
absl::StatusOr<bool> HasFieldImpl(const google::protobuf::Message* message,
                                  const google::protobuf::Descriptor* descriptor,
                                  absl::string_view field_name) {
  ABSL_ASSERT(descriptor == message->GetDescriptor());
  const Reflection* reflection = message->GetReflection();
  const FieldDescriptor* field_desc = descriptor->FindFieldByName(field_name);

  if (field_desc == nullptr && reflection != nullptr) {
    // Search to see whether the field name is referring to an extension.
    field_desc = reflection->FindKnownExtensionByName(field_name);
  }

  if (field_desc == nullptr) {
    return absl::NotFoundError(absl::StrCat("no_such_field : ", field_name));
  }

  if (reflection == nullptr) {
    return absl::FailedPreconditionError(
        "google::protobuf::Reflection unavailble in CEL field access.");
  }
  return CelFieldIsPresent(message, field_desc, reflection);
}

// Converts `field_desc` on `message` into a CelValue.
// Maps and repeated fields become arena-allocated, field-backed containers
// that read through to `message`. Singular fields go through
// CreateValueFromSingleField with the requested wrapper-unboxing behavior.
absl::StatusOr<CelValue> CreateCelValueFromField(
    const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field_desc,
    ProtoWrapperTypeOptions unboxing_option, google::protobuf::Arena* arena) {
  if (field_desc->is_map()) {
    auto* map = google::protobuf::Arena::Create<internal::FieldBackedMapImpl>(
        arena, message, field_desc, &MessageCelValueFactory, arena);

    return CelValue::CreateMap(map);
  }
  if (field_desc->is_repeated()) {
    auto* list = google::protobuf::Arena::Create<internal::FieldBackedListImpl>(
        arena, message, field_desc, &MessageCelValueFactory, arena);
    return CelValue::CreateList(list);
  }

  CEL_ASSIGN_OR_RETURN(
      CelValue result,
      internal::CreateValueFromSingleField(message, field_desc, unboxing_option,
                                           &MessageCelValueFactory, arena));
  return result;
}

// Shared implementation for GetField.
// Handles list or map specific behavior before calling reflection helpers.
// Looks up `field_name` on `message`, checking regular fields first and then
// known extensions, and converts the field to a CelValue.
// An unknown name yields a CEL no-such-field error value rather than a
// non-OK status.
absl::StatusOr<CelValue> GetFieldImpl(const google::protobuf::Message* message,
                                      const google::protobuf::Descriptor* descriptor,
                                      absl::string_view field_name,
                                      ProtoWrapperTypeOptions unboxing_option,
                                      cel::MemoryManagerRef memory_manager) {
  ABSL_ASSERT(descriptor == message->GetDescriptor());
  const Reflection* reflection = message->GetReflection();
  const FieldDescriptor* field_desc = descriptor->FindFieldByName(field_name);

  if (field_desc == nullptr && reflection != nullptr) {
    // Search to see whether the field name is referring to an extension.
    // Passed directly, as in HasFieldImpl; no temporary std::string needed.
    field_desc = reflection->FindKnownExtensionByName(field_name);
  }

  if (field_desc == nullptr) {
    return CreateNoSuchFieldError(memory_manager, field_name);
  }

  google::protobuf::Arena* arena = ProtoMemoryManagerArena(memory_manager);

  return CreateCelValueFromField(message, field_desc, unboxing_option, arena);
}

// State machine for incrementally applying qualifiers.
//
// Reusing the state machine to represent intermediate states (as opposed to
// returning the intermediates) is more efficient for longer select chains while
// still allowing decomposition of the qualify routine.
class LegacyQualifyState final : public cel::extensions::protobuf_internal::ProtoQualifyState { public: using ProtoQualifyState::ProtoQualifyState; LegacyQualifyState(const LegacyQualifyState&) = delete; LegacyQualifyState& operator=(const LegacyQualifyState&) = delete; absl::optional& result() { return result_; } private: void SetResultFromError(absl::Status status, cel::MemoryManagerRef memory_manager) override { result_ = CreateErrorValue(memory_manager, status); } void SetResultFromBool(bool value) override { result_ = CelValue::CreateBool(value); } absl::Status SetResultFromField( const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field, ProtoWrapperTypeOptions unboxing_option, cel::MemoryManagerRef memory_manager) override { CEL_ASSIGN_OR_RETURN(result_, CreateCelValueFromField( message, field, unboxing_option, ProtoMemoryManagerArena(memory_manager))); return absl::OkStatus(); } absl::Status SetResultFromRepeatedField( const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field, int index, cel::MemoryManagerRef memory_manager) override { CEL_ASSIGN_OR_RETURN(result_, internal::CreateValueFromRepeatedField( message, field, index, &MessageCelValueFactory, ProtoMemoryManagerArena(memory_manager))); return absl::OkStatus(); } absl::Status SetResultFromMapField( const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field, const google::protobuf::MapValueConstRef& value, cel::MemoryManagerRef memory_manager) override { CEL_ASSIGN_OR_RETURN(result_, internal::CreateValueFromMapValue( message, field, &value, &MessageCelValueFactory, ProtoMemoryManagerArena(memory_manager))); return absl::OkStatus(); } absl::optional result_; }; absl::StatusOr QualifyImpl( const google::protobuf::Message* message, const google::protobuf::Descriptor* descriptor, absl::Span path, bool presence_test, cel::MemoryManagerRef memory_manager) { google::protobuf::Arena* arena = 
ProtoMemoryManagerArena(memory_manager); ABSL_DCHECK(descriptor == message->GetDescriptor()); LegacyQualifyState qualify_state(message, descriptor, message->GetReflection()); for (int i = 0; i < path.size() - 1; i++) { const auto& qualifier = path.at(i); CEL_RETURN_IF_ERROR(qualify_state.ApplySelectQualifier( qualifier, ProtoMemoryManagerRef(arena))); if (qualify_state.result().has_value()) { LegacyQualifyResult result; result.value = std::move(qualify_state.result()).value(); result.qualifier_count = result.value.IsError() ? -1 : i + 1; return result; } } const auto& last_qualifier = path.back(); LegacyQualifyResult result; result.qualifier_count = -1; if (presence_test) { CEL_RETURN_IF_ERROR(qualify_state.ApplyLastQualifierHas( last_qualifier, ProtoMemoryManagerRef(arena))); } else { CEL_RETURN_IF_ERROR(qualify_state.ApplyLastQualifierGet( last_qualifier, ProtoMemoryManagerRef(arena))); } result.value = *qualify_state.result(); return result; } std::vector ListFieldsImpl( const CelValue::MessageWrapper& instance) { if (instance.message_ptr() == nullptr) { return std::vector(); } ABSL_ASSERT(instance.HasFullProto()); const auto* message = static_cast(instance.message_ptr()); const auto* reflect = message->GetReflection(); std::vector fields; reflect->ListFields(*message, &fields); std::vector field_names; field_names.reserve(fields.size()); for (const auto* field : fields) { field_names.emplace_back(field->name()); } return field_names; } class DucktypedMessageAdapter : public LegacyTypeAccessApis, public LegacyTypeMutationApis, public LegacyTypeInfoApis { public: // Implement field access APIs. 
absl::StatusOr HasField( absl::string_view field_name, const CelValue::MessageWrapper& value) const override { CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, UnwrapMessage(value, "HasField")); return HasFieldImpl(message, message->GetDescriptor(), field_name); } absl::StatusOr GetField( absl::string_view field_name, const CelValue::MessageWrapper& instance, ProtoWrapperTypeOptions unboxing_option, cel::MemoryManagerRef memory_manager) const override { CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, UnwrapMessage(instance, "GetField")); return GetFieldImpl(message, message->GetDescriptor(), field_name, unboxing_option, memory_manager); } absl::StatusOr Qualify( absl::Span qualifiers, const CelValue::MessageWrapper& instance, bool presence_test, cel::MemoryManagerRef memory_manager) const override { CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, UnwrapMessage(instance, "Qualify")); return QualifyImpl(message, message->GetDescriptor(), qualifiers, presence_test, memory_manager); } bool IsEqualTo( const CelValue::MessageWrapper& instance, const CelValue::MessageWrapper& other_instance) const override { absl::StatusOr lhs = UnwrapMessage(instance, "IsEqualTo"); absl::StatusOr rhs = UnwrapMessage(other_instance, "IsEqualTo"); if (!lhs.ok() || !rhs.ok()) { // Treat this as though the underlying types are different, just return // false. 
return false; } return ProtoEquals(**lhs, **rhs); } // Implement TypeInfo Apis absl::string_view GetTypename( const MessageWrapper& wrapped_message) const override { if (!wrapped_message.HasFullProto() || wrapped_message.message_ptr() == nullptr) { return UnsupportedTypeName(); } auto* message = static_cast(wrapped_message.message_ptr()); return message->GetDescriptor()->full_name(); } std::string DebugString( const MessageWrapper& wrapped_message) const override { if (!wrapped_message.HasFullProto() || wrapped_message.message_ptr() == nullptr) { return UnsupportedTypeName(); } auto* message = static_cast(wrapped_message.message_ptr()); return message->ShortDebugString(); } bool DefinesField(absl::string_view field_name) const override { // Pretend all our fields exist. Real errors will be returned from field // getters and setters. return true; } absl::StatusOr NewInstance( cel::MemoryManagerRef memory_manager) const override { return absl::UnimplementedError("NewInstance is not implemented"); } absl::StatusOr AdaptFromWellKnownType( cel::MemoryManagerRef memory_manager, CelValue::MessageWrapper::Builder instance) const override { if (!instance.HasFullProto() || instance.message_ptr() == nullptr) { return absl::UnimplementedError( "MessageLite is not supported, descriptor is required"); } return ProtoMessageTypeAdapter( static_cast(instance.message_ptr()) ->GetDescriptor(), nullptr) .AdaptFromWellKnownType(memory_manager, instance); } absl::Status SetField( absl::string_view field_name, const CelValue& value, cel::MemoryManagerRef memory_manager, CelValue::MessageWrapper::Builder& instance) const override { if (!instance.HasFullProto() || instance.message_ptr() == nullptr) { return absl::UnimplementedError( "MessageLite is not supported, descriptor is required"); } return ProtoMessageTypeAdapter( static_cast(instance.message_ptr()) ->GetDescriptor(), nullptr) .SetField(field_name, value, memory_manager, instance); } std::vector ListFields( const 
CelValue::MessageWrapper& instance) const override { return ListFieldsImpl(instance); } const LegacyTypeAccessApis* GetAccessApis( const MessageWrapper& wrapped_message) const override { return this; } const LegacyTypeMutationApis* GetMutationApis( const MessageWrapper& wrapped_message) const override { return this; } static const DucktypedMessageAdapter& GetSingleton() { static absl::NoDestructor instance; return *instance; } }; CelValue MessageCelValueFactory(const google::protobuf::Message* message) { return CelValue::CreateMessageWrapper( MessageWrapper(message, &DucktypedMessageAdapter::GetSingleton())); } } // namespace std::string ProtoMessageTypeAdapter::DebugString( const MessageWrapper& wrapped_message) const { if (!wrapped_message.HasFullProto() || wrapped_message.message_ptr() == nullptr) { return UnsupportedTypeName(); } auto* message = static_cast(wrapped_message.message_ptr()); return message->ShortDebugString(); } absl::string_view ProtoMessageTypeAdapter::GetTypename( const MessageWrapper& wrapped_message) const { return descriptor_->full_name(); } const LegacyTypeMutationApis* ProtoMessageTypeAdapter::GetMutationApis( const MessageWrapper& wrapped_message) const { // Defer checks for misuse on wrong message kind in the accessor calls. return this; } const LegacyTypeAccessApis* ProtoMessageTypeAdapter::GetAccessApis( const MessageWrapper& wrapped_message) const { // Defer checks for misuse on wrong message kind in the builder calls. 
return this; } absl::optional ProtoMessageTypeAdapter::FindFieldByName(absl::string_view field_name) const { if (descriptor_ == nullptr) { return absl::nullopt; } const google::protobuf::FieldDescriptor* field_descriptor = descriptor_->FindFieldByName(field_name); if (field_descriptor == nullptr) { return absl::nullopt; } return LegacyTypeInfoApis::FieldDescription{field_descriptor->number(), field_descriptor->name()}; } absl::Status ProtoMessageTypeAdapter::ValidateSetFieldOp( bool assertion, absl::string_view field, absl::string_view detail) const { if (!assertion) { return absl::InvalidArgumentError( absl::Substitute("SetField failed on message $0, field '$1': $2", descriptor_->full_name(), field, detail)); } return absl::OkStatus(); } absl::StatusOr ProtoMessageTypeAdapter::NewInstance( cel::MemoryManagerRef memory_manager) const { if (message_factory_ == nullptr) { return absl::UnimplementedError( absl::StrCat("Cannot create message ", descriptor_->name())); } // This implementation requires arena-backed memory manager. google::protobuf::Arena* arena = ProtoMemoryManagerArena(memory_manager); const Message* prototype = message_factory_->GetPrototype(descriptor_); Message* msg = (prototype != nullptr) ? 
prototype->New(arena) : nullptr; if (msg == nullptr) { return absl::InvalidArgumentError( absl::StrCat("Failed to create message ", descriptor_->name())); } return MessageWrapper::Builder(msg); } bool ProtoMessageTypeAdapter::DefinesField(absl::string_view field_name) const { return descriptor_->FindFieldByName(field_name) != nullptr; } absl::StatusOr ProtoMessageTypeAdapter::HasField( absl::string_view field_name, const CelValue::MessageWrapper& value) const { CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, UnwrapMessage(value, "HasField")); return HasFieldImpl(message, descriptor_, field_name); } absl::StatusOr ProtoMessageTypeAdapter::GetField( absl::string_view field_name, const CelValue::MessageWrapper& instance, ProtoWrapperTypeOptions unboxing_option, cel::MemoryManagerRef memory_manager) const { CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, UnwrapMessage(instance, "GetField")); return GetFieldImpl(message, descriptor_, field_name, unboxing_option, memory_manager); } absl::StatusOr ProtoMessageTypeAdapter::Qualify( absl::Span qualifiers, const CelValue::MessageWrapper& instance, bool presence_test, cel::MemoryManagerRef memory_manager) const { CEL_ASSIGN_OR_RETURN(const google::protobuf::Message* message, UnwrapMessage(instance, "Qualify")); return QualifyImpl(message, descriptor_, qualifiers, presence_test, memory_manager); } absl::Status ProtoMessageTypeAdapter::SetField( const google::protobuf::FieldDescriptor* field, const CelValue& value, google::protobuf::Arena* arena, google::protobuf::Message* message) const { if (field->is_map()) { constexpr int kKeyField = 1; constexpr int kValueField = 2; const CelMap* cel_map; CEL_RETURN_IF_ERROR(ValidateSetFieldOp( value.GetValue(&cel_map) && cel_map != nullptr, field->name(), absl::StrCat("value is not CelMap - value is ", CelValue::TypeName(value.type())))); auto entry_descriptor = field->message_type(); CEL_RETURN_IF_ERROR( ValidateSetFieldOp(entry_descriptor != nullptr, 
field->name(), "failed to find map entry descriptor")); auto key_field_descriptor = entry_descriptor->FindFieldByNumber(kKeyField); auto value_field_descriptor = entry_descriptor->FindFieldByNumber(kValueField); CEL_RETURN_IF_ERROR( ValidateSetFieldOp(key_field_descriptor != nullptr, field->name(), "failed to find key field descriptor")); CEL_RETURN_IF_ERROR( ValidateSetFieldOp(value_field_descriptor != nullptr, field->name(), "failed to find value field descriptor")); CEL_ASSIGN_OR_RETURN(const CelList* key_list, cel_map->ListKeys(arena)); for (int i = 0; i < key_list->size(); i++) { CelValue key = (*key_list).Get(arena, i); auto value = (*cel_map).Get(arena, key); CEL_RETURN_IF_ERROR(ValidateSetFieldOp(value.has_value(), field->name(), "error serializing CelMap")); Message* entry_msg = message->GetReflection()->AddMessage(message, field); CEL_RETURN_IF_ERROR(internal::SetValueToSingleField( key, key_field_descriptor, entry_msg, arena)); CEL_RETURN_IF_ERROR(internal::SetValueToSingleField( value.value(), value_field_descriptor, entry_msg, arena)); } } else if (field->is_repeated()) { const CelList* cel_list; CEL_RETURN_IF_ERROR(ValidateSetFieldOp( value.GetValue(&cel_list) && cel_list != nullptr, field->name(), absl::StrCat("expected CelList value - value is", CelValue::TypeName(value.type())))); for (int i = 0; i < cel_list->size(); i++) { CEL_RETURN_IF_ERROR(internal::AddValueToRepeatedField( (*cel_list).Get(arena, i), field, message, arena)); } } else { CEL_RETURN_IF_ERROR( internal::SetValueToSingleField(value, field, message, arena)); } return absl::OkStatus(); } absl::Status ProtoMessageTypeAdapter::SetField( absl::string_view field_name, const CelValue& value, cel::MemoryManagerRef memory_manager, CelValue::MessageWrapper::Builder& instance) const { // Assume proto arena implementation if this provider is used. 
google::protobuf::Arena* arena = cel::extensions::ProtoMemoryManagerArena(memory_manager); CEL_ASSIGN_OR_RETURN(google::protobuf::Message * mutable_message, UnwrapMessage(instance, "SetField")); const google::protobuf::FieldDescriptor* field_descriptor = descriptor_->FindFieldByName(field_name); CEL_RETURN_IF_ERROR( ValidateSetFieldOp(field_descriptor != nullptr, field_name, "not found")); return SetField(field_descriptor, value, arena, mutable_message); } absl::Status ProtoMessageTypeAdapter::SetFieldByNumber( int64_t field_number, const CelValue& value, cel::MemoryManagerRef memory_manager, CelValue::MessageWrapper::Builder& instance) const { // Assume proto arena implementation if this provider is used. google::protobuf::Arena* arena = cel::extensions::ProtoMemoryManagerArena(memory_manager); CEL_ASSIGN_OR_RETURN(google::protobuf::Message * mutable_message, UnwrapMessage(instance, "SetField")); const google::protobuf::FieldDescriptor* field_descriptor = descriptor_->FindFieldByNumber(field_number); CEL_RETURN_IF_ERROR(ValidateSetFieldOp( field_descriptor != nullptr, absl::StrCat(field_number), "not found")); return SetField(field_descriptor, value, arena, mutable_message); } absl::StatusOr ProtoMessageTypeAdapter::AdaptFromWellKnownType( cel::MemoryManagerRef memory_manager, CelValue::MessageWrapper::Builder instance) const { // Assume proto arena implementation if this provider is used. 
google::protobuf::Arena* arena = cel::extensions::ProtoMemoryManagerArena(memory_manager); CEL_ASSIGN_OR_RETURN(google::protobuf::Message * message, UnwrapMessage(instance, "AdaptFromWellKnownType")); return internal::UnwrapMessageToValue(message, &MessageCelValueFactory, arena); } bool ProtoMessageTypeAdapter::IsEqualTo( const CelValue::MessageWrapper& instance, const CelValue::MessageWrapper& other_instance) const { absl::StatusOr lhs = UnwrapMessage(instance, "IsEqualTo"); absl::StatusOr rhs = UnwrapMessage(other_instance, "IsEqualTo"); if (!lhs.ok() || !rhs.ok()) { // Treat this as though the underlying types are different, just return // false. return false; } return ProtoEquals(**lhs, **rhs); } std::vector ProtoMessageTypeAdapter::ListFields( const CelValue::MessageWrapper& instance) const { return ListFieldsImpl(instance); } const LegacyTypeInfoApis& GetGenericProtoTypeInfoInstance() { return DucktypedMessageAdapter::GetSingleton(); } } // namespace google::api::expr::runtime ================================================ FILE: eval/public/structs/proto_message_type_adapter.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTO_MESSAGE_TYPE_ADAPTER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTO_MESSAGE_TYPE_ADAPTER_H_ #include #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "common/memory.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/structs/legacy_type_adapter.h" #include "eval/public/structs/legacy_type_info_apis.h" #include "google/protobuf/descriptor.h" namespace google::api::expr::runtime { // Implementation for legacy struct (message) type apis using reflection. // // Note: The type info API implementation attached to message values is // generally the duck-typed instance to support the default behavior of // deferring to the protobuf reflection apis on the message instance. class ProtoMessageTypeAdapter : public LegacyTypeInfoApis, public LegacyTypeAccessApis, public LegacyTypeMutationApis { public: ProtoMessageTypeAdapter(const google::protobuf::Descriptor* descriptor, google::protobuf::MessageFactory* message_factory) : message_factory_(message_factory), descriptor_(descriptor) {} ~ProtoMessageTypeAdapter() override = default; // Implement LegacyTypeInfoApis std::string DebugString(const MessageWrapper& wrapped_message) const override; absl::string_view GetTypename( const MessageWrapper& wrapped_message) const override; const google::protobuf::Descriptor* absl_nullable GetDescriptor( const MessageWrapper& wrapped_message [[maybe_unused]]) const override { return descriptor_; } const LegacyTypeAccessApis* GetAccessApis( const MessageWrapper& wrapped_message) const override; const LegacyTypeMutationApis* GetMutationApis( const MessageWrapper& wrapped_message) const override; absl::optional FindFieldByName( absl::string_view field_name) const override; // Implement LegacyTypeMutation APIs. 
absl::StatusOr NewInstance( cel::MemoryManagerRef memory_manager) const override; bool DefinesField(absl::string_view field_name) const override; absl::Status SetField( absl::string_view field_name, const CelValue& value, cel::MemoryManagerRef memory_manager, CelValue::MessageWrapper::Builder& instance) const override; absl::Status SetFieldByNumber( int64_t field_number, const CelValue& value, cel::MemoryManagerRef memory_manager, CelValue::MessageWrapper::Builder& instance) const override; absl::StatusOr AdaptFromWellKnownType( cel::MemoryManagerRef memory_manager, CelValue::MessageWrapper::Builder instance) const override; // Implement LegacyTypeAccessAPIs. absl::StatusOr GetField( absl::string_view field_name, const CelValue::MessageWrapper& instance, ProtoWrapperTypeOptions unboxing_option, cel::MemoryManagerRef memory_manager) const override; absl::StatusOr HasField( absl::string_view field_name, const CelValue::MessageWrapper& value) const override; absl::StatusOr Qualify( absl::Span qualifiers, const CelValue::MessageWrapper& instance, bool presence_test, cel::MemoryManagerRef memory_manager) const override; bool IsEqualTo(const CelValue::MessageWrapper& instance, const CelValue::MessageWrapper& other_instance) const override; std::vector ListFields( const CelValue::MessageWrapper& instance) const override; private: // Helper for standardizing error messages for SetField operation. absl::Status ValidateSetFieldOp(bool assertion, absl::string_view field, absl::string_view detail) const; absl::Status SetField(const google::protobuf::FieldDescriptor* field, const CelValue& value, google::protobuf::Arena* arena, google::protobuf::Message* message) const; google::protobuf::MessageFactory* message_factory_; const google::protobuf::Descriptor* descriptor_; }; // Returns a TypeInfo provider representing an arbitrary message. 
// This allows for the legacy duck-typed behavior of messages on field access // instead of expecting a particular message type given a TypeInfo. const LegacyTypeInfoApis& GetGenericProtoTypeInfoInstance(); } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTO_MESSAGE_TYPE_ADAPTER_H_ ================================================ FILE: eval/public/structs/proto_message_type_adapter_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/public/structs/proto_message_type_adapter.h"

#include <vector>

#include "google/protobuf/wrappers.pb.h"
#include "google/protobuf/descriptor.pb.h"
#include "absl/status/status.h"
#include "base/attribute.h"
#include "common/value.h"
#include "eval/public/cel_value.h"
#include "eval/public/containers/container_backed_list_impl.h"
#include "eval/public/containers/container_backed_map_impl.h"
#include "eval/public/message_wrapper.h"
#include "eval/public/structs/legacy_type_adapter.h"
#include "eval/public/structs/legacy_type_info_apis.h"
#include "eval/public/testing/matchers.h"
#include "eval/testutil/test_message.pb.h"
#include "extensions/protobuf/memory_manager.h"
#include "internal/proto_matchers.h"
#include "internal/testing.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {
namespace {

using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::cel::ProtoWrapperTypeOptions;
using ::cel::extensions::ProtoMemoryManagerRef;
using ::cel::internal::test::EqualsProto;
using ::google::protobuf::Int64Value;
using ::testing::_;
using ::testing::AllOf;
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::Field;
using ::testing::HasSubstr;
using ::testing::Optional;
using ::testing::Truly;

using LegacyQualifyResult = LegacyTypeAccessApis::LegacyQualifyResult;

// Value-parameterized fixture that runs every accessor test against two
// LegacyTypeAccessApis implementations:
//   GetParam() == true  -> the generic (duck-typed) proto accessor returned by
//                          GetGenericProtoTypeInfoInstance().
//   GetParam() == false -> a ProtoMessageTypeAdapter bound to the
//                          google.api.expr.runtime.TestMessage descriptor.
class ProtoMessageTypeAccessorTest : public testing::TestWithParam<bool> {
 public:
  ProtoMessageTypeAccessorTest()
      : type_specific_instance_(
            google::protobuf::DescriptorPool::generated_pool()
                ->FindMessageTypeByName("google.api.expr.runtime.TestMessage"),
            google::protobuf::MessageFactory::generated_factory()) {}

  // Returns the accessor implementation selected by the test parameter.
  const LegacyTypeAccessApis& GetAccessApis() {
    bool use_generic_instance = GetParam();
    if (use_generic_instance) {
      // implementation detail: in general, type info implementations may
      // return a different accessor object based on the message instance, but
      // this implementation returns the same one no matter the message.
      return *GetGenericProtoTypeInfoInstance().GetAccessApis(dummy_);
    } else {
      return type_specific_instance_;
    }
  }

 private:
  ProtoMessageTypeAdapter type_specific_instance_;
  // Placeholder wrapper passed to GetAccessApis(); only used to obtain the
  // generic accessor, never read as a message.
  CelValue::MessageWrapper dummy_;
};

// HasField on a singular scalar reflects whether the field has been set.
TEST_P(ProtoMessageTypeAccessorTest, HasFieldSingular) {
  const LegacyTypeAccessApis& accessor = GetAccessApis();
  TestMessage example;
  MessageWrapper value(&example, nullptr);
  EXPECT_THAT(accessor.HasField("int64_value", value), IsOkAndHolds(false));
  example.set_int64_value(10);
  EXPECT_THAT(accessor.HasField("int64_value", value), IsOkAndHolds(true));
}

// HasField on a repeated field is true once it has at least one element.
TEST_P(ProtoMessageTypeAccessorTest, HasFieldRepeated) {
  const LegacyTypeAccessApis& accessor = GetAccessApis();
  TestMessage example;
  MessageWrapper value(&example, nullptr);
  EXPECT_THAT(accessor.HasField("int64_list", value), IsOkAndHolds(false));
  example.add_int64_list(10);
  EXPECT_THAT(accessor.HasField("int64_list", value), IsOkAndHolds(true));
}

// HasField on a map field is true once it has at least one entry; other set
// fields on the message do not affect the result.
TEST_P(ProtoMessageTypeAccessorTest, HasFieldMap) {
  const LegacyTypeAccessApis& accessor = GetAccessApis();
  TestMessage example;
  example.set_int64_value(10);
  MessageWrapper value(&example, nullptr);
  EXPECT_THAT(accessor.HasField("int64_int32_map", value), IsOkAndHolds(false));
  (*example.mutable_int64_int32_map())[2] = 3;
  EXPECT_THAT(accessor.HasField("int64_int32_map", value), IsOkAndHolds(true));
}

// HasField on a name the message does not define is a kNotFound status.
TEST_P(ProtoMessageTypeAccessorTest, HasFieldUnknownField) {
  const LegacyTypeAccessApis& accessor = GetAccessApis();
  TestMessage example;
  example.set_int64_value(10);
  MessageWrapper value(&example, nullptr);
  EXPECT_THAT(accessor.HasField("unknown_field", value),
              StatusIs(absl::StatusCode::kNotFound));
}

// HasField on a wrapper holding a (null) MessageLite, i.e. without
// google::protobuf::Message reflection, is expected to fail with kInternal.
TEST_P(ProtoMessageTypeAccessorTest, HasFieldNonMessageType) {
  const LegacyTypeAccessApis& accessor = GetAccessApis();
  MessageWrapper value(static_cast<const google::protobuf::MessageLite*>(nullptr),
                       nullptr);
  EXPECT_THAT(accessor.HasField("unknown_field", value),
              StatusIs(absl::StatusCode::kInternal));
}
// GetField on a set singular int64 field yields the matching CEL int.
TEST_P(ProtoMessageTypeAccessorTest, GetFieldSingular) {
  google::protobuf::Arena arena;
  const LegacyTypeAccessApis& accessor = GetAccessApis();
  auto manager = ProtoMemoryManagerRef(&arena);
  TestMessage example;
  example.set_int64_value(10);
  MessageWrapper value(&example, nullptr);
  EXPECT_THAT(accessor.GetField("int64_value", value,
                                ProtoWrapperTypeOptions::kUnsetNull, manager),
              IsOkAndHolds(test::IsCelInt64(10)));
}

// GetField on an undefined field name succeeds at the status level but
// returns a CEL error value (kNotFound) that mentions the field name.
TEST_P(ProtoMessageTypeAccessorTest, GetFieldNoSuchField) {
  google::protobuf::Arena arena;
  const LegacyTypeAccessApis& accessor = GetAccessApis();
  auto manager = ProtoMemoryManagerRef(&arena);
  TestMessage example;
  example.set_int64_value(10);
  MessageWrapper value(&example, nullptr);
  EXPECT_THAT(accessor.GetField("unknown_field", value,
                                ProtoWrapperTypeOptions::kUnsetNull, manager),
              IsOkAndHolds(test::IsCelError(StatusIs(
                  absl::StatusCode::kNotFound, HasSubstr("unknown_field")))));
}

// GetField on a wrapper holding a (null) MessageLite, i.e. without
// google::protobuf::Message reflection, is expected to fail with kInternal.
TEST_P(ProtoMessageTypeAccessorTest, GetFieldNotAMessage) {
  google::protobuf::Arena arena;
  const LegacyTypeAccessApis& accessor = GetAccessApis();
  auto manager = ProtoMemoryManagerRef(&arena);
  MessageWrapper value(static_cast<const google::protobuf::MessageLite*>(nullptr),
                       nullptr);
  EXPECT_THAT(accessor.GetField("int64_value", value,
                                ProtoWrapperTypeOptions::kUnsetNull, manager),
              StatusIs(absl::StatusCode::kInternal));
}

// GetField on a repeated field yields a CelList with the elements in order.
TEST_P(ProtoMessageTypeAccessorTest, GetFieldRepeated) {
  google::protobuf::Arena arena;
  const LegacyTypeAccessApis& accessor = GetAccessApis();
  auto manager = ProtoMemoryManagerRef(&arena);
  TestMessage example;
  example.add_int64_list(10);
  example.add_int64_list(20);
  MessageWrapper value(&example, nullptr);
  ASSERT_OK_AND_ASSIGN(
      CelValue result,
      accessor.GetField("int64_list", value,
                        ProtoWrapperTypeOptions::kUnsetNull, manager));
  // Initialized so the pointer is never indeterminate, even if GetValue fails.
  const CelList* held_value = nullptr;
  ASSERT_TRUE(result.GetValue(&held_value)) << result.DebugString();
  EXPECT_EQ(held_value->size(), 2);
  EXPECT_THAT((*held_value)[0], test::IsCelInt64(10));
  EXPECT_THAT((*held_value)[1], test::IsCelInt64(20));
}
TEST_P(ProtoMessageTypeAccessorTest, GetFieldMap) { google::protobuf::Arena arena; const LegacyTypeAccessApis& accessor = GetAccessApis(); auto manager = ProtoMemoryManagerRef(&arena); TestMessage example; (*example.mutable_int64_int32_map())[10] = 20; MessageWrapper value(&example, nullptr); ASSERT_OK_AND_ASSIGN( CelValue result, accessor.GetField("int64_int32_map", value, ProtoWrapperTypeOptions::kUnsetNull, manager)); const CelMap* held_value; ASSERT_TRUE(result.GetValue(&held_value)) << result.DebugString(); EXPECT_EQ(held_value->size(), 1); EXPECT_THAT((*held_value)[CelValue::CreateInt64(10)], Optional(test::IsCelInt64(20))); } TEST_P(ProtoMessageTypeAccessorTest, GetFieldWrapperType) { google::protobuf::Arena arena; const LegacyTypeAccessApis& accessor = GetAccessApis(); auto manager = ProtoMemoryManagerRef(&arena); TestMessage example; example.mutable_int64_wrapper_value()->set_value(10); MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.GetField("int64_wrapper_value", value, ProtoWrapperTypeOptions::kUnsetNull, manager), IsOkAndHolds(test::IsCelInt64(10))); } TEST_P(ProtoMessageTypeAccessorTest, GetFieldWrapperTypeUnsetNullUnbox) { google::protobuf::Arena arena; const LegacyTypeAccessApis& accessor = GetAccessApis(); auto manager = ProtoMemoryManagerRef(&arena); TestMessage example; MessageWrapper value(&example, nullptr); EXPECT_THAT(accessor.GetField("int64_wrapper_value", value, ProtoWrapperTypeOptions::kUnsetNull, manager), IsOkAndHolds(test::IsCelNull())); // Wrapper field present, but default value. 
example.mutable_int64_wrapper_value()->clear_value(); EXPECT_THAT(accessor.GetField("int64_wrapper_value", value, ProtoWrapperTypeOptions::kUnsetNull, manager), IsOkAndHolds(test::IsCelInt64(_))); } TEST_P(ProtoMessageTypeAccessorTest, GetFieldWrapperTypeUnsetDefaultValueUnbox) { google::protobuf::Arena arena; const LegacyTypeAccessApis& accessor = GetAccessApis(); auto manager = ProtoMemoryManagerRef(&arena); TestMessage example; MessageWrapper value(&example, nullptr); EXPECT_THAT( accessor.GetField("int64_wrapper_value", value, ProtoWrapperTypeOptions::kUnsetProtoDefault, manager), IsOkAndHolds(test::IsCelInt64(_))); // Wrapper field present with unset value is used to signal Null, but legacy // behavior just returns the proto default value. example.mutable_int64_wrapper_value()->clear_value(); // Same behavior for this option. EXPECT_THAT( accessor.GetField("int64_wrapper_value", value, ProtoWrapperTypeOptions::kUnsetProtoDefault, manager), IsOkAndHolds(test::IsCelInt64(_))); } TEST_P(ProtoMessageTypeAccessorTest, IsEqualTo) { const LegacyTypeAccessApis& accessor = GetAccessApis(); TestMessage example; example.mutable_int64_wrapper_value()->set_value(10); TestMessage example2; example2.mutable_int64_wrapper_value()->set_value(10); MessageWrapper value(&example, nullptr); MessageWrapper value2(&example2, nullptr); EXPECT_TRUE(accessor.IsEqualTo(value, value2)); EXPECT_TRUE(accessor.IsEqualTo(value2, value)); } TEST_P(ProtoMessageTypeAccessorTest, IsEqualToSameTypeInequal) { const LegacyTypeAccessApis& accessor = GetAccessApis(); TestMessage example; example.mutable_int64_wrapper_value()->set_value(10); TestMessage example2; example2.mutable_int64_wrapper_value()->set_value(12); MessageWrapper value(&example, nullptr); MessageWrapper value2(&example2, nullptr); EXPECT_FALSE(accessor.IsEqualTo(value, value2)); EXPECT_FALSE(accessor.IsEqualTo(value2, value)); } TEST_P(ProtoMessageTypeAccessorTest, IsEqualToDifferentTypeInequal) { const LegacyTypeAccessApis& 
accessor = GetAccessApis(); TestMessage example; example.mutable_int64_wrapper_value()->set_value(10); Int64Value example2; example2.set_value(10); MessageWrapper value(&example, nullptr); MessageWrapper value2(&example2, nullptr); EXPECT_FALSE(accessor.IsEqualTo(value, value2)); EXPECT_FALSE(accessor.IsEqualTo(value2, value)); } TEST_P(ProtoMessageTypeAccessorTest, IsEqualToNonMessageInequal) { const LegacyTypeAccessApis& accessor = GetAccessApis(); TestMessage example; example.mutable_int64_wrapper_value()->set_value(10); TestMessage example2; example2.mutable_int64_wrapper_value()->set_value(10); MessageWrapper value(&example, nullptr); // Upcast to message lite to prevent unwrapping to message. MessageWrapper value2(static_cast(&example2), nullptr); EXPECT_FALSE(accessor.IsEqualTo(value, value2)); EXPECT_FALSE(accessor.IsEqualTo(value2, value)); } INSTANTIATE_TEST_SUITE_P(GenericAndSpecific, ProtoMessageTypeAccessorTest, testing::Bool()); TEST(GetGenericProtoTypeInfoInstance, GetTypeName) { const LegacyTypeInfoApis& info_api = GetGenericProtoTypeInfoInstance(); TestMessage test_message; CelValue::MessageWrapper wrapped_message(&test_message, nullptr); EXPECT_EQ(info_api.GetTypename(wrapped_message), test_message.GetTypeName()); } TEST(GetGenericProtoTypeInfoInstance, DebugString) { const LegacyTypeInfoApis& info_api = GetGenericProtoTypeInfoInstance(); TestMessage test_message; test_message.set_string_value("abcd"); CelValue::MessageWrapper wrapped_message(&test_message, nullptr); EXPECT_EQ(info_api.DebugString(wrapped_message), test_message.ShortDebugString()); } TEST(GetGenericProtoTypeInfoInstance, GetAccessApis) { const LegacyTypeInfoApis& info_api = GetGenericProtoTypeInfoInstance(); TestMessage test_message; test_message.set_string_value("abcd"); CelValue::MessageWrapper wrapped_message(&test_message, nullptr); auto* accessor = info_api.GetAccessApis(wrapped_message); google::protobuf::Arena arena; auto manager = ProtoMemoryManagerRef(&arena); 
ASSERT_OK_AND_ASSIGN( CelValue result, accessor->GetField("string_value", wrapped_message, ProtoWrapperTypeOptions::kUnsetNull, manager)); EXPECT_THAT(result, test::IsCelString("abcd")); } TEST(GetGenericProtoTypeInfoInstance, FallbackForNonMessage) { const LegacyTypeInfoApis& info_api = GetGenericProtoTypeInfoInstance(); TestMessage test_message; test_message.set_string_value("abcd"); // Upcast to signal no google::protobuf::Message / reflection support. CelValue::MessageWrapper wrapped_message( static_cast(&test_message), nullptr); EXPECT_EQ(info_api.GetTypename(wrapped_message), ""); EXPECT_EQ(info_api.DebugString(wrapped_message), ""); // Check for not-null. CelValue::MessageWrapper null_message( static_cast(nullptr), nullptr); EXPECT_EQ(info_api.GetTypename(null_message), ""); EXPECT_EQ(info_api.DebugString(null_message), ""); } TEST(ProtoMessageTypeAdapter, NewInstance) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder result, adapter.NewInstance(manager)); EXPECT_EQ(result.message_ptr()->SerializeAsString(), ""); } TEST(ProtoMessageTypeAdapter, NewInstanceUnsupportedDescriptor) { google::protobuf::Arena arena; google::protobuf::DescriptorPool pool; google::protobuf::FileDescriptorProto faked_file; faked_file.set_name("faked.proto"); faked_file.set_syntax("proto3"); faked_file.set_package("google.api.expr.runtime"); auto msg_descriptor = faked_file.add_message_type(); msg_descriptor->set_name("FakeMessage"); pool.BuildFile(faked_file); ProtoMessageTypeAdapter adapter( pool.FindMessageTypeByName("google.api.expr.runtime.FakeMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); // Message factory doesn't know how to 
create our custom message, even though // we provided a descriptor for it. EXPECT_THAT( adapter.NewInstance(manager), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("FakeMessage"))); } TEST(ProtoMessageTypeAdapter, DefinesField) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); EXPECT_TRUE(adapter.DefinesField("int64_value")); EXPECT_FALSE(adapter.DefinesField("not_a_field")); } TEST(ProtoMessageTypeAdapter, SetFieldSingular) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder value, adapter.NewInstance(manager)); ASSERT_OK(adapter.SetField("int64_value", CelValue::CreateInt64(10), manager, value)); TestMessage message; message.set_int64_value(10); EXPECT_EQ(value.message_ptr()->SerializeAsString(), message.SerializeAsString()); ASSERT_THAT(adapter.SetField("not_a_field", CelValue::CreateInt64(10), manager, value), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("field 'not_a_field': not found"))); } TEST(ProtoMessageTypeAdapter, SetFieldRepeated) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); ContainerBackedListImpl list( {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); CelValue value_to_set = CelValue::CreateList(&list); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder instance, adapter.NewInstance(manager)); ASSERT_OK(adapter.SetField("int64_list", value_to_set, manager, 
instance)); TestMessage message; message.add_int64_list(1); message.add_int64_list(2); EXPECT_EQ(instance.message_ptr()->SerializeAsString(), message.SerializeAsString()); } TEST(ProtoMessageTypeAdapter, SetFieldNotAField) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder instance, adapter.NewInstance(manager)); ASSERT_THAT(adapter.SetField("not_a_field", CelValue::CreateInt64(10), manager, instance), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("field 'not_a_field': not found"))); } TEST(ProtoMesssageTypeAdapter, SetFieldWrongType) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); ContainerBackedListImpl list( {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); CelValue list_value = CelValue::CreateList(&list); CelMapBuilder builder; ASSERT_OK(builder.Add(CelValue::CreateInt64(1), CelValue::CreateInt64(2))); ASSERT_OK(builder.Add(CelValue::CreateInt64(2), CelValue::CreateInt64(4))); CelValue map_value = CelValue::CreateMap(&builder); CelValue int_value = CelValue::CreateInt64(42); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder instance, adapter.NewInstance(manager)); EXPECT_THAT(adapter.SetField("int64_value", map_value, manager, instance), StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT(adapter.SetField("int64_value", list_value, manager, instance), StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT( adapter.SetField("int64_int32_map", list_value, manager, instance), StatusIs(absl::StatusCode::kInvalidArgument)); 
EXPECT_THAT(adapter.SetField("int64_int32_map", int_value, manager, instance), StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT(adapter.SetField("int64_list", int_value, manager, instance), StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT(adapter.SetField("int64_list", map_value, manager, instance), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(ProtoMesssageTypeAdapter, SetFieldNotAMessage) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); CelValue int_value = CelValue::CreateInt64(42); CelValue::MessageWrapper::Builder instance( static_cast(nullptr)); EXPECT_THAT(adapter.SetField("int64_value", int_value, manager, instance), StatusIs(absl::StatusCode::kInternal)); } TEST(ProtoMesssageTypeAdapter, SetFieldNullMessage) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); CelValue int_value = CelValue::CreateInt64(42); CelValue::MessageWrapper::Builder instance( static_cast(nullptr)); EXPECT_THAT(adapter.SetField("int64_value", int_value, manager, instance), StatusIs(absl::StatusCode::kInternal)); } TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownType) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.protobuf.Int64Value"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder instance, adapter.NewInstance(manager)); ASSERT_OK( adapter.SetField("value", CelValue::CreateInt64(42), 
manager, instance)); ASSERT_OK_AND_ASSIGN(CelValue value, adapter.AdaptFromWellKnownType(manager, instance)); EXPECT_THAT(value, test::IsCelInt64(42)); } TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeUnspecial) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder instance, adapter.NewInstance(manager)); ASSERT_OK(adapter.SetField("int64_value", CelValue::CreateInt64(42), manager, instance)); ASSERT_OK_AND_ASSIGN(CelValue value, adapter.AdaptFromWellKnownType(manager, instance)); // TestMessage should not be converted to a CEL primitive type. EXPECT_THAT(value, test::IsCelMessage(EqualsProto("int64_value: 42"))); } TEST(ProtoMessageTypeAdapter, AdaptFromWellKnownTypeNotAMessageError) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); CelValue::MessageWrapper::Builder instance( static_cast(nullptr)); // Interpreter guaranteed to call this with a message type, otherwise, // something has broken. 
EXPECT_THAT(adapter.AdaptFromWellKnownType(manager, instance), StatusIs(absl::StatusCode::kInternal)); } TEST(ProtoMesssageTypeAdapter, TypeInfoDebug) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); TestMessage message; message.set_int64_value(42); EXPECT_THAT(adapter.DebugString(MessageWrapper(&message, &adapter)), HasSubstr(message.ShortDebugString())); EXPECT_THAT(adapter.DebugString(MessageWrapper()), HasSubstr("")); } TEST(ProtoMesssageTypeAdapter, TypeInfoName) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); EXPECT_EQ(adapter.GetTypename(MessageWrapper()), "google.api.expr.runtime.TestMessage"); } TEST(ProtoMesssageTypeAdapter, FindFieldFound) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); EXPECT_THAT( adapter.FindFieldByName("int64_value"), Optional(Truly([](const LegacyTypeInfoApis::FieldDescription& desc) { return desc.name == "int64_value" && desc.number == 2; }))) << "expected field int64_value: 2"; } TEST(ProtoMesssageTypeAdapter, FindFieldNotFound) { ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); EXPECT_EQ(adapter.FindFieldByName("foo_not_a_field"), absl::nullopt); } TEST(ProtoMesssageTypeAdapter, TypeInfoMutator) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), 
google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); const LegacyTypeMutationApis* api = adapter.GetMutationApis(MessageWrapper()); ASSERT_NE(api, nullptr); ASSERT_OK_AND_ASSIGN(MessageWrapper::Builder builder, api->NewInstance(manager)); EXPECT_NE(dynamic_cast(builder.message_ptr()), nullptr); } TEST(ProtoMesssageTypeAdapter, TypeInfoAccesor) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; message.set_int64_value(42); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); EXPECT_THAT(api->GetField("int64_value", wrapped, ProtoWrapperTypeOptions::kUnsetNull, manager), IsOkAndHolds(test::IsCelInt64(42))); } TEST(ProtoMesssageTypeAdapter, Qualify) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; message.mutable_message_value()->set_int64_value(42); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{12, "message_value"}, cel::FieldSpecifier{2, "int64_value"}}; EXPECT_THAT( api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelInt64(42)))); } TEST(ProtoMesssageTypeAdapter, QualifyDynamicFieldAccessUnsupported) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( 
google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; message.mutable_message_value()->set_int64_value(42); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{12, "message_value"}, cel::AttributeQualifier::OfString("int64_value")}; EXPECT_THAT(api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), StatusIs(absl::StatusCode::kUnimplemented)); } TEST(ProtoMesssageTypeAdapter, QualifyNoSuchField) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; message.mutable_message_value()->set_int64_value(42); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{12, "message_value"}, cel::FieldSpecifier{99, "not_a_field"}, cel::FieldSpecifier{2, "int64_value"}}; EXPECT_THAT(api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), IsOkAndHolds(Field( &LegacyQualifyResult::value, test::IsCelError(StatusIs(absl::StatusCode::kNotFound, HasSubstr("no_such_field")))))); } TEST(ProtoMesssageTypeAdapter, QualifyHasNoSuchField) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; 
message.mutable_message_value()->set_int64_value(42); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{12, "message_value"}, cel::FieldSpecifier{99, "not_a_field"}}; EXPECT_THAT(api->Qualify(qualfiers, wrapped, /*presence_test=*/true, manager), IsOkAndHolds(Field( &LegacyQualifyResult::value, test::IsCelError(StatusIs(absl::StatusCode::kNotFound, HasSubstr("no_such_field")))))); } TEST(ProtoMesssageTypeAdapter, QualifyNoSuchFieldLeaf) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; message.mutable_message_value()->set_int64_value(42); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{12, "message_value"}, cel::FieldSpecifier{99, "not_a_field"}}; EXPECT_THAT(api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), IsOkAndHolds(Field( &LegacyQualifyResult::value, test::IsCelError(StatusIs(absl::StatusCode::kNotFound, HasSubstr("no_such_field")))))); } TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalSupport) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; (*message.mutable_string_message_map())["@key"].set_int64_value(42); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, 
nullptr); std::vector qualfiers{ cel::FieldSpecifier{210, "string_message_map"}, cel::AttributeQualifier::OfString("@key"), cel::FieldSpecifier{2, "int64_value"}}; EXPECT_THAT( api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelInt64(42)))); } TEST(ProtoMesssageTypeAdapter, TypedFieldAccessOnMapUnsupported) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; (*message.mutable_string_message_map())["@key"].set_int64_value(42); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{210, "string_message_map"}, // This is probably a bug, but defer to evaluator for consistent handling. 
cel::FieldSpecifier{2, "value"}, cel::FieldSpecifier{2, "int64_value"}}; EXPECT_THAT(api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), StatusIs(absl::StatusCode::kUnimplemented)); } TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalWrongKeyType) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; (*message.mutable_string_message_map())["@key"].set_int64_value(42); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{210, "string_message_map"}, cel::AttributeQualifier::OfInt(0), cel::FieldSpecifier{2, "int64_value"}}; EXPECT_THAT(api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelError(StatusIs( absl::StatusCode::kInvalidArgument, HasSubstr("Invalid map key type")))))); } TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalHasWrongKeyType) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; (*message.mutable_string_message_map())["@key"].set_int64_value(42); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{210, "string_message_map"}, cel::AttributeQualifier::OfInt(0)}; EXPECT_THAT(api->Qualify(qualfiers, wrapped, /*presence_test=*/true, manager), IsOkAndHolds(Field(&LegacyQualifyResult::value, 
test::IsCelError(StatusIs( absl::StatusCode::kUnknown, HasSubstr("No matching overloads")))))); } TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalSupportNoSuchKey) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; (*message.mutable_string_message_map())["@key"].set_int64_value(42); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{210, "string_message_map"}, cel::AttributeQualifier::OfString("bad_key"), cel::FieldSpecifier{2, "int64_value"}}; EXPECT_THAT(api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), IsOkAndHolds(Field( &LegacyQualifyResult::value, test::IsCelError(StatusIs(absl::StatusCode::kNotFound, HasSubstr("Key not found")))))); } TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalInt32Key) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; (*message.mutable_int32_int32_map())[0] = 42; CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{205, "int32_int32_map"}, cel::AttributeQualifier::OfInt(0)}; EXPECT_THAT( api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelInt64(42)))); } TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalIntOutOfRange) { google::protobuf::Arena arena; 
ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; (*message.mutable_int32_int32_map())[0] = 42; CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{205, "int32_int32_map"}, cel::AttributeQualifier::OfInt(1LL << 32)}; EXPECT_THAT(api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), IsOkAndHolds(Field( &LegacyQualifyResult::value, test::IsCelError(StatusIs(absl::StatusCode::kOutOfRange, HasSubstr("integer overflow")))))); } TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalUint32Key) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; (*message.mutable_uint32_uint32_map())[0] = 42; CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{206, "uint32_uint32_map"}, cel::AttributeQualifier::OfUint(0)}; EXPECT_THAT( api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelUint64(42)))); } TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalUintOutOfRange) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage 
message; (*message.mutable_uint32_uint32_map())[0] = 42; CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{206, "uint32_uint32_map"}, cel::AttributeQualifier::OfUint(1LL << 32)}; EXPECT_THAT(api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), IsOkAndHolds(Field( &LegacyQualifyResult::value, test::IsCelError(StatusIs(absl::StatusCode::kOutOfRange, HasSubstr("integer overflow")))))); } TEST(ProtoMesssageTypeAdapter, QualifyMapTraversalUnexpectedFieldAccess) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; (*message.mutable_string_message_map())["@key"].set_int64_value(42); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{210, "string_message_map"}, // For coverage check that qualify gives up if there's a strong field // access requested for a map. 
cel::FieldSpecifier{0, "field_like_key"}}; auto result = api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager); EXPECT_THAT(api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), StatusIs(absl::StatusCode::kUnimplemented, _)); } TEST(ProtoMesssageTypeAdapter, UntypedQualifiersNotYetSupported) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; (*message.mutable_string_message_map())["@key"].set_int64_value(42); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::AttributeQualifier::OfString("string_message_map"), cel::AttributeQualifier::OfString("@key"), cel::AttributeQualifier::OfString("int64_value")}; EXPECT_THAT(api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), StatusIs(absl::StatusCode::kUnimplemented, _)); } TEST(ProtoMesssageTypeAdapter, QualifyRepeatedIndexWrongType) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; message.add_message_list()->add_int64_list(1); message.add_message_list()->add_int64_list(2); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{112, "message_list"}, cel::AttributeQualifier::OfBool(false), cel::FieldSpecifier{102, "int64_list"}, cel::AttributeQualifier::OfInt(0)}; EXPECT_THAT( api->Qualify(qualfiers, wrapped, 
/*presence_test=*/false, manager), IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelError(StatusIs( absl::StatusCode::kUnknown, HasSubstr("No matching overloads found")))))); } TEST(ProtoMesssageTypeAdapter, QualifyRepeatedTypeCheckError) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; message.add_int64_list(1); message.add_int64_list(2); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{102, "int64_list"}, cel::AttributeQualifier::OfInt(0), // index on an int. cel::AttributeQualifier::OfInt(1)}; EXPECT_THAT(api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), StatusIs(absl::StatusCode::kInternal, HasSubstr("Unexpected qualify intermediate type"))); } TEST(ProtoMesssageTypeAdapter, QualifyRepeatedLeaf) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; auto* nested = message.mutable_message_value(); nested->add_int64_list(1); nested->add_int64_list(2); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{12, "message_value"}, cel::FieldSpecifier{102, "int64_list"}, }; EXPECT_THAT( api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelList(ElementsAre(test::IsCelInt64(1), 
test::IsCelInt64(2)))))); } TEST(ProtoMesssageTypeAdapter, QualifyRepeatedIndexLeaf) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; auto* nested = message.mutable_message_value(); nested->add_int64_list(1); nested->add_int64_list(2); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{12, "message_value"}, cel::FieldSpecifier{102, "int64_list"}, cel::AttributeQualifier::OfInt(1)}; EXPECT_THAT( api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelInt64(2)))); } TEST(ProtoMesssageTypeAdapter, QualifyRepeatedIndexLeafOutOfBounds) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; auto* nested = message.mutable_message_value(); nested->add_int64_list(1); nested->add_int64_list(2); CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{12, "message_value"}, cel::FieldSpecifier{102, "int64_list"}, cel::AttributeQualifier::OfInt(2)}; EXPECT_THAT(api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelError(StatusIs( absl::StatusCode::kInvalidArgument, HasSubstr("index out of bounds")))))); } TEST(ProtoMesssageTypeAdapter, QualifyMapLeaf) { 
google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; auto* nested_map = message.mutable_message_value()->mutable_string_int32_map(); (*nested_map)["@key"] = 42; (*nested_map)["@key2"] = -42; CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{12, "message_value"}, cel::FieldSpecifier{203, "string_int32_map"}, }; EXPECT_THAT(api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), IsOkAndHolds(Field( &LegacyQualifyResult::value, Truly([](const CelValue& v) { return v.IsMap() && v.MapOrDie()->size() == 2; })))); } TEST(ProtoMesssageTypeAdapter, QualifyMapIndexLeaf) { google::protobuf::Arena arena; ProtoMessageTypeAdapter adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; auto* nested_map = message.mutable_message_value()->mutable_string_int32_map(); (*nested_map)["@key"] = 42; (*nested_map)["@key2"] = -42; CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{12, "message_value"}, cel::FieldSpecifier{203, "string_int32_map"}, cel::AttributeQualifier::OfString("@key")}; EXPECT_THAT( api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelInt64(42)))); } TEST(ProtoMesssageTypeAdapter, QualifyMapIndexLeafWrongType) { google::protobuf::Arena arena; ProtoMessageTypeAdapter 
adapter( google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); auto manager = ProtoMemoryManagerRef(&arena); TestMessage message; auto* nested_map = message.mutable_message_value()->mutable_string_int32_map(); (*nested_map)["@key"] = 42; (*nested_map)["@key2"] = -42; CelValue::MessageWrapper wrapped(&message, &adapter); const LegacyTypeAccessApis* api = adapter.GetAccessApis(MessageWrapper()); ASSERT_NE(api, nullptr); std::vector qualfiers{ cel::FieldSpecifier{12, "message_value"}, cel::FieldSpecifier{203, "string_int32_map"}, cel::AttributeQualifier::OfInt(0)}; EXPECT_THAT(api->Qualify(qualfiers, wrapped, /*presence_test=*/false, manager), IsOkAndHolds(Field(&LegacyQualifyResult::value, test::IsCelError(StatusIs( absl::StatusCode::kInvalidArgument, HasSubstr("Invalid map key type")))))); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/public/structs/protobuf_descriptor_type_provider.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/public/structs/protobuf_descriptor_type_provider.h"

#include 
#include 

#include "absl/synchronization/mutex.h"
#include "eval/public/structs/proto_message_type_adapter.h"
#include "google/protobuf/descriptor.h"

namespace google::api::expr::runtime {

// Looks up `name` and, when the descriptor pool knows it, exposes the cached
// adapter through both the access and the mutation slots (the proto adapter
// implements both API sets). Unknown names yield nullopt.
absl::optional ProtobufDescriptorProvider::ProvideLegacyType(
    absl::string_view name) const {
  if (const ProtoMessageTypeAdapter* adapter = GetTypeAdapter(name);
      adapter != nullptr) {
    return LegacyTypeAdapter(adapter, adapter);
  }
  return absl::nullopt;
}

// Same lookup as ProvideLegacyType, but returns only the type-info view of the
// cached adapter, or nullopt for unknown names.
absl::optional ProtobufDescriptorProvider::ProvideLegacyTypeInfo(
    absl::string_view name) const {
  if (const ProtoMessageTypeAdapter* adapter = GetTypeAdapter(name);
      adapter != nullptr) {
    return adapter;
  }
  return absl::nullopt;
}

// Builds a fresh adapter for the message type `name` from the configured pool
// and factory; returns nullptr when the pool has no such message type.
std::unique_ptr
ProtobufDescriptorProvider::CreateTypeAdapter(absl::string_view name) const {
  if (const google::protobuf::Descriptor* message_descriptor =
          descriptor_pool_->FindMessageTypeByName(name);
      message_descriptor != nullptr) {
    return std::make_unique(message_descriptor, message_factory_);
  }
  return nullptr;
}

// Returns the memoized adapter for `name`, creating it on the first request.
// The whole lookup-or-create sequence runs under `mu_`.
const ProtoMessageTypeAdapter* ProtobufDescriptorProvider::GetTypeAdapter(
    absl::string_view name) const {
  absl::MutexLock lock(mu_);
  if (auto cached = type_cache_.find(name); cached != type_cache_.end()) {
    return cached->second.get();
  }
  // Misses are stored too (as nullptr), so each name hits the pool only once.
  auto created = CreateTypeAdapter(name);
  const ProtoMessageTypeAdapter* const adapter = created.get();
  type_cache_[name] = std::move(created);
  return adapter;
}

}  // namespace google::api::expr::runtime
================================================
FILE: eval/public/structs/protobuf_descriptor_type_provider.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_DESCRIPTOR_TYPE_PROVIDER_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_DESCRIPTOR_TYPE_PROVIDER_H_

#include 
#include 

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "eval/public/structs/legacy_type_adapter.h"
#include "eval/public/structs/legacy_type_info_apis.h"
#include "eval/public/structs/legacy_type_provider.h"
#include "eval/public/structs/proto_message_type_adapter.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {

// Implementation of a type provider that generates types from protocol buffer
// descriptors.
//
// Adapters are created lazily, one per requested type name, and memoized in
// `type_cache_`. Names the descriptor pool does not know are memoized as well
// (as nullptr). Access to the cache is serialized by `mu_`.
class ProtobufDescriptorProvider : public LegacyTypeProvider {
 public:
  // `pool` and `factory` are stored as raw pointers and not owned.
  // NOTE(review): presumably both must outlive this provider — confirm with
  // callers.
  ProtobufDescriptorProvider(const google::protobuf::DescriptorPool* pool,
                             google::protobuf::MessageFactory* factory)
      : descriptor_pool_(pool), message_factory_(factory) {}

  // Returns an adapter whose access and mutation APIs are both backed by the
  // cached ProtoMessageTypeAdapter for `name`, or nullopt if the descriptor
  // pool has no message type with that name.
  absl::optional ProvideLegacyType(
      absl::string_view name) const final;

  // Returns the type-info view of the cached adapter for `name`, or nullopt if
  // the descriptor pool has no message type with that name.
  absl::optional ProvideLegacyTypeInfo(
      absl::string_view name) const final;

 private:
  // Create a new type instance if found in the registered descriptor pool.
  // Otherwise, returns nullptr.
  std::unique_ptr CreateTypeAdapter(
      absl::string_view name) const;
  // Returns the memoized adapter for `name`, calling CreateTypeAdapter on the
  // first request. The returned pointer is owned by `type_cache_`; nullptr
  // means the type is unknown to the pool. Acquires `mu_`.
  const ProtoMessageTypeAdapter* GetTypeAdapter(absl::string_view name) const;

  const google::protobuf::DescriptorPool* descriptor_pool_;
  google::protobuf::MessageFactory* message_factory_;
  // Mutable because lookups through the const Provide* methods populate it.
  mutable absl::flat_hash_map>
      type_cache_ ABSL_GUARDED_BY(mu_);
  mutable absl::Mutex mu_;
};

}  // namespace google::api::expr::runtime

#endif  // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_DESCRIPTOR_TYPE_PROVIDER_H_
================================================
FILE: eval/public/structs/protobuf_descriptor_type_provider_test.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "eval/public/structs/protobuf_descriptor_type_provider.h"

#include 

#include "google/protobuf/wrappers.pb.h"
#include "eval/public/cel_value.h"
#include "eval/public/structs/legacy_type_info_apis.h"
#include "eval/public/testing/matchers.h"
#include "extensions/protobuf/memory_manager.h"
#include "internal/testing.h"
#include "google/protobuf/arena.h"

namespace google::api::expr::runtime {
namespace {

using ::cel::extensions::ProtoMemoryManager;

// End-to-end check: a well-known wrapper type can be looked up, described,
// instantiated, populated and adapted back to a primitive CelValue.
TEST(ProtobufDescriptorProvider, Basic) {
  ProtobufDescriptorProvider provider(
      google::protobuf::DescriptorPool::generated_pool(),
      google::protobuf::MessageFactory::generated_factory());
  google::protobuf::Arena arena;
  auto manager = ProtoMemoryManager(&arena);
  auto type_adapter = provider.ProvideLegacyType("google.protobuf.Int64Value");
  absl::optional type_info =
      provider.ProvideLegacyTypeInfo("google.protobuf.Int64Value");

  ASSERT_TRUE(type_adapter.has_value());
  ASSERT_NE(type_adapter->mutation_apis(), nullptr);
  ASSERT_TRUE(type_info.has_value());
  // Check the wrapped pointer, not the optional: it is dereferenced below.
  ASSERT_NE(*type_info, nullptr);

  google::protobuf::Int64Value int64_value;
  CelValue::MessageWrapper int64_cel_value(&int64_value, *type_info);
  EXPECT_EQ((*type_info)->GetTypename(int64_cel_value),
            "google.protobuf.Int64Value");

  ASSERT_TRUE(type_adapter->mutation_apis()->DefinesField("value"));
  ASSERT_OK_AND_ASSIGN(CelValue::MessageWrapper::Builder value,
                       type_adapter->mutation_apis()->NewInstance(manager));

  ASSERT_OK(type_adapter->mutation_apis()->SetField(
      "value", CelValue::CreateInt64(10), manager, value));

  ASSERT_OK_AND_ASSIGN(
      CelValue adapted,
      type_adapter->mutation_apis()->AdaptFromWellKnownType(manager, value));

  EXPECT_THAT(adapted, test::IsCelInt64(10));
}

// This is an implementation detail, but testing for coverage.
TEST(ProtobufDescriptorProvider, MemoizesAdapters) {
  ProtobufDescriptorProvider provider(
      google::protobuf::DescriptorPool::generated_pool(),
      google::protobuf::MessageFactory::generated_factory());
  auto type_adapter = provider.ProvideLegacyType("google.protobuf.Int64Value");

  ASSERT_TRUE(type_adapter.has_value());
  ASSERT_NE(type_adapter->mutation_apis(), nullptr);

  auto type_adapter2 = provider.ProvideLegacyType("google.protobuf.Int64Value");
  ASSERT_TRUE(type_adapter2.has_value());

  EXPECT_EQ(type_adapter->mutation_apis(), type_adapter2->mutation_apis());
  EXPECT_EQ(type_adapter->access_apis(), type_adapter2->access_apis());
}

TEST(ProtobufDescriptorProvider, NotFound) {
  ProtobufDescriptorProvider provider(
      google::protobuf::DescriptorPool::generated_pool(),
      google::protobuf::MessageFactory::generated_factory());
  auto type_adapter = provider.ProvideLegacyType("UnknownType");
  auto type_info = provider.ProvideLegacyTypeInfo("UnknownType");

  ASSERT_FALSE(type_adapter.has_value());
  ASSERT_FALSE(type_info.has_value());

  // Misses are memoized; a repeated lookup is served from the cache and must
  // still report "not found".
  EXPECT_FALSE(provider.ProvideLegacyType("UnknownType").has_value());
}

}  // namespace
}  // namespace google::api::expr::runtime
================================================
FILE: eval/public/structs/protobuf_value_factory.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_VALUE_FACTORY_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_VALUE_FACTORY_H_

#include 

#include "eval/public/cel_value.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime::internal {

// Definition for a factory producing a properly initialized message-typed
// CelValue.
//
// The google::protobuf::Message is assumed to have already been adapted as far
// as possible, so the factory only associates it with the appropriate type
// information.
//
// Used to break the cyclic dependency between field access and message
// wrapping -- not intended for general use.
//
// This is a plain function pointer, so only captureless callables (free or
// static functions, captureless lambdas) can be used as factories.
using ProtobufValueFactory = CelValue (*)(const google::protobuf::Message*);

}  // namespace google::api::expr::runtime::internal

#endif  // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_PROTOBUF_VALUE_FACTORY_H_
================================================
FILE: eval/public/structs/trivial_legacy_type_info.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TRIVIAL_LEGACY_TYPE_INFO_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TRIVIAL_LEGACY_TYPE_INFO_H_ #include #include "absl/base/no_destructor.h" #include "absl/strings/string_view.h" #include "eval/public/message_wrapper.h" #include "eval/public/structs/legacy_type_info_apis.h" namespace google::api::expr::runtime { // Implementation of type info APIs suitable for testing where no message // operations need to be supported. class TrivialTypeInfo : public LegacyTypeInfoApis { public: absl::string_view GetTypename(const MessageWrapper& wrapper) const override { return "opaque"; } std::string DebugString(const MessageWrapper& wrapper) const override { return "opaque"; } const LegacyTypeAccessApis* GetAccessApis( const MessageWrapper& wrapper) const override { // Accessors unsupported -- caller should treat this as an opaque type (no // fields defined, field access always results in a CEL error). return nullptr; } static const TrivialTypeInfo* GetInstance() { static absl::NoDestructor kInstance; return &*kInstance; } }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_STRUCTS_TRIVIAL_LEGACY_TYPE_INFO_H_ ================================================ FILE: eval/public/structs/trivial_legacy_type_info_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "eval/public/structs/trivial_legacy_type_info.h"

#include "eval/public/message_wrapper.h"
#include "internal/testing.h"

namespace google::api::expr::runtime {
namespace {

TEST(TrivialTypeInfo, GetTypename) {
  TrivialTypeInfo info;
  MessageWrapper wrapper;

  EXPECT_EQ(info.GetTypename(wrapper), "opaque");
  EXPECT_EQ(TrivialTypeInfo::GetInstance()->GetTypename(wrapper), "opaque");
}

TEST(TrivialTypeInfo, DebugString) {
  TrivialTypeInfo info;
  MessageWrapper wrapper;

  EXPECT_EQ(info.DebugString(wrapper), "opaque");
  EXPECT_EQ(TrivialTypeInfo::GetInstance()->DebugString(wrapper), "opaque");
}

TEST(TrivialTypeInfo, GetAccessApis) {
  TrivialTypeInfo info;
  MessageWrapper wrapper;

  EXPECT_EQ(info.GetAccessApis(wrapper), nullptr);
  EXPECT_EQ(TrivialTypeInfo::GetInstance()->GetAccessApis(wrapper), nullptr);
}

TEST(TrivialTypeInfo, GetMutationApis) {
  TrivialTypeInfo info;
  MessageWrapper wrapper;

  EXPECT_EQ(info.GetMutationApis(wrapper), nullptr);
  EXPECT_EQ(TrivialTypeInfo::GetInstance()->GetMutationApis(wrapper), nullptr);
}

TEST(TrivialTypeInfo, FindFieldByName) {
  TrivialTypeInfo info;

  EXPECT_EQ(info.FindFieldByName("foo"), absl::nullopt);
  EXPECT_EQ(TrivialTypeInfo::GetInstance()->FindFieldByName("foo"),
            absl::nullopt);
}

}  // namespace
}  // namespace google::api::expr::runtime
================================================
FILE: eval/public/testing/BUILD
================================================
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

package(
    default_testonly = True,
    default_visibility = ["//visibility:public"],
)

licenses(["notice"])

cc_library(
    name = "matchers",
    srcs = ["matchers.cc"],
    hdrs = ["matchers.h"],
    deps = [
        "//eval/public:cel_value",
        "//eval/public:set_util",
        "//internal:casts",
        "//internal:testing",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "matchers_test",
    srcs = ["matchers_test.cc"],
    deps = [
        ":matchers",
        "//eval/public/containers:container_backed_list_impl",
        "//eval/public/structs:cel_proto_wrapper",
        "//eval/testutil:test_message_cc_proto",
        "//internal:testing",
        "//testutil:util",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/time",
    ],
)
================================================
FILE: eval/public/testing/matchers.cc
================================================
#include "eval/public/testing/matchers.h"

#include 
#include 
#include 

#include "absl/strings/string_view.h"
#include "eval/public/cel_value.h"
#include "eval/public/set_util.h"
#include "internal/casts.h"
#include "internal/testing.h"
#include "google/protobuf/message.h"

namespace google::api::expr::runtime {

void PrintTo(const CelValue& value, std::ostream* os) {
  *os << value.DebugString();
}

namespace test {
namespace {

using ::testing::_;
using ::testing::MatcherInterface;
using ::testing::MatchResultListener;

// Matches CelValues that compare equal, via set_util's CelValueEqual, to a
// reference value captured at construction.
class CelValueEqualImpl : public MatcherInterface {
 public:
  explicit CelValueEqualImpl(const CelValue& v) : value_(v) {}

  bool MatchAndExplain(CelValue arg,
                       MatchResultListener* listener) const override {
    return CelValueEqual(arg, value_);
  }

  void DescribeTo(std::ostream* os) const override {
    *os << value_.DebugString();
  }

 private:
  // Held by value, not by reference: the matcher can outlive the argument
  // passed to EqualsCelValue (e.g. a temporary CelValue stored in a named
  // matcher), and a reference member would dangle in that case.
  // NOTE(review): presumably any payload the CelValue only refers to (e.g. a
  // view created with CreateStringView) must still outlive the matcher.
  const CelValue value_;
};

// used to implement matchers for CelValues
template 
class CelValueMatcherImpl : public testing::MatcherInterface {
 public:
  explicit CelValueMatcherImpl(testing::Matcher m)
      : underlying_type_matcher_(std::move(m)) {}
  bool MatchAndExplain(const CelValue& v,
                       testing::MatchResultListener* listener) const override {
    UnderlyingType arg;
    return v.GetValue(&arg) && underlying_type_matcher_.Matches(arg);
  }

  void DescribeTo(std::ostream* os) const override {
    CelValue::Type type =
        static_cast(CelValue::IndexOf::value);
    *os << absl::StrCat("type is ", CelValue::TypeName(type), " and ");
    underlying_type_matcher_.DescribeTo(os);
  }

 private:
  const testing::Matcher underlying_type_matcher_;
};

// Template specialization for google::protobuf::Message.
template <>
class CelValueMatcherImpl : public testing::MatcherInterface {
 public:
  explicit CelValueMatcherImpl(testing::Matcher m)
      : underlying_type_matcher_(std::move(m)) {}
  bool MatchAndExplain(const CelValue& v,
                       testing::MatchResultListener* listener) const override {
    CelValue::MessageWrapper arg;
    return v.GetValue(&arg) && arg.HasFullProto() &&
           underlying_type_matcher_.Matches(
               cel::internal::down_cast(
                   arg.message_ptr()));
  }

  void DescribeTo(std::ostream* os) const override {
    *os << absl::StrCat("type is ",
                        CelValue::TypeName(CelValue::Type::kMessage), " and ");
    underlying_type_matcher_.DescribeTo(os);
  }

 private:
  const testing::Matcher underlying_type_matcher_;
};

}  // namespace

CelValueMatcher EqualsCelValue(const CelValue& v) {
  return CelValueMatcher(new CelValueEqualImpl(v));
}

CelValueMatcher IsCelNull() {
  return CelValueMatcher(new CelValueMatcherImpl(_));
}

CelValueMatcher IsCelBool(testing::Matcher m) {
  return CelValueMatcher(new CelValueMatcherImpl(std::move(m)));
}

CelValueMatcher IsCelInt64(testing::Matcher m) {
  return CelValueMatcher(new CelValueMatcherImpl(std::move(m)));
}

CelValueMatcher IsCelUint64(testing::Matcher m) {
  return CelValueMatcher(new CelValueMatcherImpl(std::move(m)));
}

CelValueMatcher IsCelDouble(testing::Matcher m) {
  return CelValueMatcher(new CelValueMatcherImpl(std::move(m)));
}

CelValueMatcher IsCelString(testing::Matcher m) {
  return CelValueMatcher(new CelValueMatcherImpl(
      testing::Property(&CelValue::StringHolder::value, m)));
}

CelValueMatcher IsCelBytes(testing::Matcher m) {
  return CelValueMatcher(new CelValueMatcherImpl(
      testing::Property(&CelValue::BytesHolder::value, m)));
}

CelValueMatcher IsCelMessage(testing::Matcher m) {
  return CelValueMatcher(
      new CelValueMatcherImpl(std::move(m)));
}

CelValueMatcher IsCelDuration(testing::Matcher m) {
  return CelValueMatcher(new CelValueMatcherImpl(std::move(m)));
}

CelValueMatcher IsCelTimestamp(testing::Matcher m) {
  return CelValueMatcher(new CelValueMatcherImpl(std::move(m)));
}

CelValueMatcher IsCelError(testing::Matcher m) {
  return CelValueMatcher(
      new CelValueMatcherImpl(
          testing::AllOf(testing::NotNull(), testing::Pointee(m))));
}

}  // namespace test
}  // namespace google::api::expr::runtime
================================================
FILE: eval/public/testing/matchers.h
================================================
#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_TESTING_MATCHERS_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_TESTING_MATCHERS_H_

#include 
#include 
#include 

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "eval/public/cel_value.h"
#include "internal/testing.h"
#include "google/protobuf/message.h"

namespace google {
namespace api {
namespace expr {
namespace runtime {

// GTest Printer
void PrintTo(const CelValue& value, std::ostream* os);

namespace test {

// readability alias
using CelValueMatcher = testing::Matcher;

// Tests equality to CelValue v using the set_util implementation. The matcher
// keeps its own copy of v, so it may be stored and used after v is gone.
CelValueMatcher EqualsCelValue(const CelValue& v);

// Matches CelValues of type null.
CelValueMatcher IsCelNull();

// Matches CelValues of type bool whose held value matches |m|.
CelValueMatcher IsCelBool(testing::Matcher m);

// Matches CelValues of type int64 whose held value matches |m|.
CelValueMatcher IsCelInt64(testing::Matcher m);

// Matches CelValues of type uint64_t whose held value matches |m|.
CelValueMatcher IsCelUint64(testing::Matcher m);

// Matches CelValues of type double whose held value matches |m|.
CelValueMatcher IsCelDouble(testing::Matcher m);

// Matches CelValues of type string whose held value matches |m|.
CelValueMatcher IsCelString(testing::Matcher m);

// Matches CelValues of type bytes whose held value matches |m|.
CelValueMatcher IsCelBytes(testing::Matcher m);

// Matches CelValues of type message whose held value matches |m|.
CelValueMatcher IsCelMessage(testing::Matcher m);

// Matches CelValues of type duration whose held value matches |m|.
CelValueMatcher IsCelDuration(testing::Matcher m);

// Matches CelValues of type timestamp whose held value matches |m|.
CelValueMatcher IsCelTimestamp(testing::Matcher m);

// Matches CelValues of type error whose held value matches |m|.
// The matcher |m| is wrapped to allow using the testing::status::... matchers.
CelValueMatcher IsCelError(testing::Matcher m);

// A matcher that wraps a Container matcher so that container matchers can be
// used for matching CelList.
//
// This matcher can be avoided if CelList supported the iterators needed by the
// standard container matchers but given that it is an interface it is a much
// larger project.
//
// TODO(issues/73): Re-use CelValueMatcherImpl. There are template details
// that need to be worked out specifically on how CelValueMatcherImpl can accept
// a generic matcher for CelList instead of testing::Matcher.
//
// NOTE(review): the template parameter list and the MatcherInterface /
// std::vector / Matcher / static_cast / IndexOf type arguments were stripped
// in this extract; restore from upstream before compiling.
template class CelListMatcher : public testing::MatcherInterface {
 public:
  explicit CelListMatcher(ContainerMatcher m) : container_matcher_(m) {}

  // Copies every element of the CelList (via operator[]) into a vector and
  // runs the container matcher on that copy. Non-list values and null list
  // pointers never match. |listener| is not written to, so failures carry no
  // per-element explanation.
  bool MatchAndExplain(const CelValue& v,
                       testing::MatchResultListener* listener) const override {
    const CelList* cel_list;
    if (!v.GetValue(&cel_list) || cel_list == nullptr) return false;

    std::vector cel_vector;
    cel_vector.reserve(cel_list->size());
    for (int i = 0; i < cel_list->size(); ++i) {
      cel_vector.push_back((*cel_list)[i]);
    }
    return container_matcher_.Matches(cel_vector);
  }

  // Emits "type is <list type name> and " followed by the container
  // matcher's description.
  void DescribeTo(std::ostream* os) const override {
    CelValue::Type type =
        static_cast(CelValue::IndexOf::value);
    *os << absl::StrCat("type is ", CelValue::TypeName(type), " and ");
    container_matcher_.DescribeTo(os);
  }

 private:
  const testing::Matcher> container_matcher_;
};

// Matches CelValues of type list whose elements, gathered into a vector,
// satisfy the container matcher |m| (e.g. ElementsAre, Contains).
template
CelValueMatcher IsCelList(ContainerMatcher m) {
  return CelValueMatcher(new CelListMatcher(m));
}
// TODO(issues/73): add helpers for working with maps and unknown sets.
}  // namespace test
}  // namespace runtime
}  // namespace expr
}  // namespace api
}  // namespace google

#endif  // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_TESTING_MATCHERS_H_
================================================
FILE: eval/public/testing/matchers_test.cc
================================================
#include "eval/public/testing/matchers.h"

#include "absl/status/status.h"
#include "absl/time/time.h"
#include "eval/public/containers/container_backed_list_impl.h"
#include "eval/public/structs/cel_proto_wrapper.h"
#include "eval/testutil/test_message.pb.h"
#include "internal/testing.h"
#include "testutil/util.h"

namespace google::api::expr::runtime::test {
namespace {

using ::testing::Contains;
using ::testing::DoubleEq;
using ::testing::DoubleNear;
using ::testing::ElementsAre;
using ::testing::Gt;
using ::testing::Lt;
using ::testing::Not;
using ::testing::UnorderedElementsAre;
using testutil::EqualsProto;

// EqualsCelValue matches identical values of each primitive kind, and does not
// match when the two values have different types.
TEST(IsCelValue, EqualitySmoketest) {
  EXPECT_THAT(CelValue::CreateBool(true),
              EqualsCelValue(CelValue::CreateBool(true)));
  EXPECT_THAT(CelValue::CreateInt64(-1),
              EqualsCelValue(CelValue::CreateInt64(-1)));
  EXPECT_THAT(CelValue::CreateUint64(2),
              EqualsCelValue(CelValue::CreateUint64(2)));
  EXPECT_THAT(CelValue::CreateDouble(1.25),
              EqualsCelValue(CelValue::CreateDouble(1.25)));
  EXPECT_THAT(CelValue::CreateStringView("abc"),
              EqualsCelValue(CelValue::CreateStringView("abc")));
  EXPECT_THAT(CelValue::CreateBytesView("def"),
              EqualsCelValue(CelValue::CreateBytesView("def")));
  EXPECT_THAT(CelValue::CreateDuration(absl::Seconds(2)),
              EqualsCelValue(CelValue::CreateDuration(absl::Seconds(2))));
  EXPECT_THAT(
      CelValue::CreateTimestamp(absl::FromUnixSeconds(1)),
      EqualsCelValue(CelValue::CreateTimestamp(absl::FromUnixSeconds(1))));

  // Cross-type comparisons never match.
  EXPECT_THAT(CelValue::CreateInt64(-1),
              Not(EqualsCelValue(CelValue::CreateBool(true))));
  EXPECT_THAT(CelValue::CreateUint64(2),
              Not(EqualsCelValue(CelValue::CreateInt64(-1))));
  EXPECT_THAT(CelValue::CreateDouble(1.25),
              Not(EqualsCelValue(CelValue::CreateUint64(2))));
  EXPECT_THAT(CelValue::CreateStringView("abc"),
              Not(EqualsCelValue(CelValue::CreateDouble(1.25))));
  EXPECT_THAT(CelValue::CreateBytesView("def"),
              Not(EqualsCelValue(CelValue::CreateStringView("abc"))));
  EXPECT_THAT(CelValue::CreateDuration(absl::Seconds(2)),
              Not(EqualsCelValue(CelValue::CreateBytesView("def"))));
  EXPECT_THAT(CelValue::CreateTimestamp(absl::FromUnixSeconds(1)),
              Not(EqualsCelValue(CelValue::CreateDuration(absl::Seconds(2)))));
  EXPECT_THAT(
      CelValue::CreateBool(true),
      Not(EqualsCelValue(CelValue::CreateTimestamp(absl::FromUnixSeconds(1)))));
}

// Each IsCel* matcher forwards the held value to an arbitrary inner matcher.
TEST(PrimitiveMatchers, Smoketest) {
  EXPECT_THAT(CelValue::CreateNull(), IsCelNull());
  EXPECT_THAT(CelValue::CreateBool(false), Not(IsCelNull()));

  EXPECT_THAT(CelValue::CreateBool(true), IsCelBool(true));
  EXPECT_THAT(CelValue::CreateBool(false), IsCelBool(Not(true)));

  EXPECT_THAT(CelValue::CreateInt64(1), IsCelInt64(1));
  EXPECT_THAT(CelValue::CreateInt64(-1), IsCelInt64(Not(Gt(0))));

  EXPECT_THAT(CelValue::CreateUint64(1), IsCelUint64(1));
  EXPECT_THAT(CelValue::CreateUint64(2), IsCelUint64(Not(Lt(2))));

  EXPECT_THAT(CelValue::CreateDouble(1.5), IsCelDouble(DoubleEq(1.5)));
  EXPECT_THAT(CelValue::CreateDouble(1.0 + 0.8),
              IsCelDouble(DoubleNear(1.8, 1e-5)));

  EXPECT_THAT(CelValue::CreateStringView("abc"), IsCelString("abc"));
  EXPECT_THAT(CelValue::CreateStringView("abcdef"),
              IsCelString(testing::HasSubstr("def")));

  EXPECT_THAT(CelValue::CreateBytesView("abc"), IsCelBytes("abc"));
  EXPECT_THAT(CelValue::CreateBytesView("abcdef"),
              IsCelBytes(testing::HasSubstr("def")));

  EXPECT_THAT(CelValue::CreateDuration(absl::Seconds(2)),
              IsCelDuration(Lt(absl::Minutes(1))));
  EXPECT_THAT(CelValue::CreateTimestamp(absl::FromUnixSeconds(20)),
              IsCelTimestamp(Lt(absl::FromUnixSeconds(30))));
}

// A matcher for one type rejects values of any other type, even when the
// inner matcher would accept the held value.
TEST(PrimitiveMatchers, WrongType) {
  EXPECT_THAT(CelValue::CreateBool(true), Not(IsCelInt64(1)));
  EXPECT_THAT(CelValue::CreateInt64(1), Not(IsCelUint64(1)));
  EXPECT_THAT(CelValue::CreateUint64(1), Not(IsCelDouble(1.0)));
  EXPECT_THAT(CelValue::CreateDouble(1.5), Not(IsCelString("abc")));
  EXPECT_THAT(CelValue::CreateStringView("abc"), Not(IsCelBytes("abc")));
  EXPECT_THAT(CelValue::CreateBytesView("abc"),
              Not(IsCelDuration(Lt(absl::Minutes(1)))));
  EXPECT_THAT(CelValue::CreateDuration(absl::Seconds(2)),
              Not(IsCelTimestamp(Lt(absl::FromUnixSeconds(30)))));
  EXPECT_THAT(CelValue::CreateTimestamp(absl::FromUnixSeconds(20)),
              Not(IsCelBool(true)));
}

// IsCelError matches on the pointed-to status; IsCelMessage matches on the
// wrapped proto.
TEST(SpecialMatchers, SmokeTest) {
  auto status = absl::InternalError("error");
  CelValue error = CelValue::CreateError(&status);
  EXPECT_THAT(error, IsCelError(testing::Eq(
                         absl::Status(absl::StatusCode::kInternal, "error"))));

  TestMessage proto_message;
  proto_message.add_bool_list(true);
  proto_message.add_bool_list(false);
  proto_message.add_int64_list(1);
  proto_message.add_int64_list(-1);
  CelValue message = CelProtoWrapper::CreateMessage(&proto_message, nullptr);
  EXPECT_THAT(message, IsCelMessage(EqualsProto(proto_message)));
}

// A non-list value never satisfies IsCelList.
TEST(ListMatchers, NotList) {
  EXPECT_THAT(CelValue::CreateInt64(1),
              Not(IsCelList(Contains(IsCelInt64(1)))));
}

// Standard gMock container matchers work through IsCelList.
TEST(ListMatchers, All) {
  ContainerBackedListImpl list({
      CelValue::CreateInt64(1),
      CelValue::CreateInt64(2),
      CelValue::CreateInt64(3),
      CelValue::CreateInt64(4),
  });
  CelValue cel_list = CelValue::CreateList(&list);
  EXPECT_THAT(cel_list, IsCelList(Contains(IsCelInt64(3))));
  EXPECT_THAT(cel_list, IsCelList(Not(Contains(IsCelInt64(0)))));

  EXPECT_THAT(cel_list, IsCelList(ElementsAre(IsCelInt64(1), IsCelInt64(2),
                                              IsCelInt64(3), IsCelInt64(4))));
  EXPECT_THAT(cel_list,
              IsCelList(Not(ElementsAre(IsCelInt64(2), IsCelInt64(1),
                                        IsCelInt64(3), IsCelInt64(4)))));

  EXPECT_THAT(cel_list,
              IsCelList(UnorderedElementsAre(IsCelInt64(2), IsCelInt64(1),
                                             IsCelInt64(4), IsCelInt64(3))));
  EXPECT_THAT(
      cel_list,
      IsCelList(Not(UnorderedElementsAre(IsCelInt64(2), IsCelInt64(1),
                                         IsCelInt64(4), IsCelInt64(0)))));
}

}  // namespace
}  // namespace google::api::expr::runtime::test
================================================ FILE: eval/public/transform_utility.cc ================================================ #include "eval/public/transform_utility.h" #include #include #include #include #include "cel/expr/value.pb.h" #include "google/protobuf/any.pb.h" #include "google/protobuf/struct.pb.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "internal/proto_time_encoding.h" #include "internal/status_macros.h" namespace google { namespace api { namespace expr { namespace runtime { absl::Status CelValueToValue(const CelValue& value, Value* result, google::protobuf::Arena* arena) { switch (value.type()) { case CelValue::Type::kBool: result->set_bool_value(value.BoolOrDie()); break; case CelValue::Type::kInt64: result->set_int64_value(value.Int64OrDie()); break; case CelValue::Type::kUint64: result->set_uint64_value(value.Uint64OrDie()); break; case CelValue::Type::kDouble: result->set_double_value(value.DoubleOrDie()); break; case CelValue::Type::kString: result->set_string_value(value.StringOrDie().value().data(), value.StringOrDie().value().size()); break; case CelValue::Type::kBytes: result->set_bytes_value(value.BytesOrDie().value().data(), value.BytesOrDie().value().size()); break; case CelValue::Type::kDuration: { google::protobuf::Duration duration; auto status = cel::internal::EncodeDuration(value.DurationOrDie(), &duration); if (!status.ok()) { return status; } result->mutable_object_value()->PackFrom(duration); break; } case CelValue::Type::kTimestamp: { google::protobuf::Timestamp timestamp; auto status = cel::internal::EncodeTime(value.TimestampOrDie(), ×tamp); if (!status.ok()) { return status; } result->mutable_object_value()->PackFrom(timestamp); break; } case 
CelValue::Type::kNullType: result->set_null_value(google::protobuf::NullValue::NULL_VALUE); break; case CelValue::Type::kMessage: if (value.IsNull()) { result->set_null_value(google::protobuf::NullValue::NULL_VALUE); } else { result->mutable_object_value()->PackFrom(*value.MessageOrDie()); } break; case CelValue::Type::kList: { auto& list = *value.ListOrDie(); auto* list_value = result->mutable_list_value(); for (int i = 0; i < list.size(); ++i) { CEL_RETURN_IF_ERROR(CelValueToValue(list.Get(arena, i), list_value->add_values(), arena)); } break; } case CelValue::Type::kMap: { auto* map_value = result->mutable_map_value(); auto& cel_map = *value.MapOrDie(); CEL_ASSIGN_OR_RETURN(const auto* keys, cel_map.ListKeys(arena)); for (int i = 0; i < keys->size(); ++i) { CelValue key = (*keys).Get(arena, i); auto* entry = map_value->add_entries(); CEL_RETURN_IF_ERROR(CelValueToValue(key, entry->mutable_key(), arena)); auto optional_value = cel_map.Get(arena, key); if (!optional_value) { return absl::Status(absl::StatusCode::kInternal, "key not found in map"); } CEL_RETURN_IF_ERROR( CelValueToValue(*optional_value, entry->mutable_value(), arena)); } break; } case CelValue::Type::kError: // TODO(issues/87): Migrate to google.api.expr.ExprValue result->set_string_value("CelValue::Type::kError"); break; case CelValue::Type::kCelType: result->set_type_value(value.CelTypeOrDie().value().data(), value.CelTypeOrDie().value().size()); break; case CelValue::Type::kAny: // kAny is a special value used in function descriptors. 
return absl::Status(absl::StatusCode::kInternal, "CelValue has type kAny"); default: return absl::Status( absl::StatusCode::kUnimplemented, absl::StrCat("Can't convert ", CelValue::TypeName(value.type()), " to Constant.")); } return absl::OkStatus(); } absl::StatusOr ValueToCelValue(const Value& value, google::protobuf::Arena* arena) { switch (value.kind_case()) { case Value::kBoolValue: return CelValue::CreateBool(value.bool_value()); case Value::kBytesValue: return CelValue::CreateBytes(CelValue::BytesHolder( arena->Create(arena, value.bytes_value()))); case Value::kDoubleValue: return CelValue::CreateDouble(value.double_value()); case Value::kEnumValue: return CelValue::CreateInt64(value.enum_value().value()); case Value::kInt64Value: return CelValue::CreateInt64(value.int64_value()); case Value::kListValue: { std::vector list; for (const auto& subvalue : value.list_value().values()) { CEL_ASSIGN_OR_RETURN(auto list_value, ValueToCelValue(subvalue, arena)); list.push_back(list_value); } return CelValue::CreateList( arena->Create(arena, list)); } case Value::kMapValue: { std::vector> key_values; for (const auto& entry : value.map_value().entries()) { CEL_ASSIGN_OR_RETURN(auto map_key, ValueToCelValue(entry.key(), arena)); CEL_RETURN_IF_ERROR(CelValue::CheckMapKeyType(map_key)); CEL_ASSIGN_OR_RETURN(auto map_value, ValueToCelValue(entry.value(), arena)); key_values.push_back(std::pair(map_key, map_value)); } CEL_ASSIGN_OR_RETURN( auto cel_map, CreateContainerBackedMap(absl::Span>( key_values.data(), key_values.size()))); auto* cel_map_ptr = cel_map.release(); arena->Own(cel_map_ptr); return CelValue::CreateMap(cel_map_ptr); } case Value::kNullValue: return CelValue::CreateNull(); case Value::kObjectValue: { auto cel_value = CelProtoWrapper::CreateMessage(&value.object_value(), arena); if (cel_value.IsError()) return *cel_value.ErrorOrDie(); return cel_value; } case Value::kStringValue: return CelValue::CreateString(CelValue::StringHolder( arena->Create(arena, 
value.string_value()))); case Value::kTypeValue: return CelValue::CreateCelType(CelValue::CelTypeHolder( arena->Create(arena, value.type_value()))); case Value::kUint64Value: return CelValue::CreateUint64(value.uint64_value()); case Value::KIND_NOT_SET: default: return absl::InvalidArgumentError("Value proto is not set"); } } } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: eval/public/transform_utility.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_TRANSFORM_UTILITY_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_TRANSFORM_UTILITY_H_ #include "cel/expr/value.pb.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "eval/public/cel_value.h" #include "google/protobuf/arena.h" namespace google { namespace api { namespace expr { namespace runtime { using cel::expr::Value; // Translates a CelValue into a cel::expr::Value. Returns an error if // translation is not supported. absl::Status CelValueToValue(const CelValue& value, Value* result, google::protobuf::Arena* arena); inline absl::Status CelValueToValue(const CelValue& value, Value* result) { google::protobuf::Arena arena; return CelValueToValue(value, result, &arena); } // Translates a cel::expr::Value into a CelValue. Allocates any required // external data on the provided arena. Returns an error if translation is not // supported. 
// NOTE(review): the StatusOr type argument (presumably CelValue) and other
// <...> contents in this span were stripped in this extract.
absl::StatusOr ValueToCelValue(const Value& value,
                                         google::protobuf::Arena* arena);

}  // namespace runtime
}  // namespace expr
}  // namespace api
}  // namespace google

#endif  // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_TRANSFORM_UTILITY_H_
================================================
FILE: eval/public/unknown_attribute_set.h
================================================
#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_ATTRIBUTE_SET_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_ATTRIBUTE_SET_H_

#include "base/attribute_set.h"

namespace google {
namespace api {
namespace expr {
namespace runtime {

// UnknownAttributeSet is a container for CEL attributes that are identified as
// unknown during expression evaluation.
using UnknownAttributeSet = ::cel::AttributeSet;

}  // namespace runtime
}  // namespace expr
}  // namespace api
}  // namespace google

#endif  // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_ATTRIBUTE_SET_H_
================================================
FILE: eval/public/unknown_attribute_set_test.cc
================================================
#include "eval/public/unknown_attribute_set.h"

#include
#include
#include

#include "eval/public/cel_attribute.h"
#include "eval/public/cel_value.h"
#include "internal/testing.h"

namespace google {
namespace api {
namespace expr {
namespace runtime {
namespace {

using ::testing::Eq;

// NOTE(review): Expr is not referenced by the tests below.
using cel::expr::Expr;

// A set built from one attribute has size 1 and yields that attribute.
// NOTE(review): kAttr2 and kAttr3 are unused in this test.
TEST(UnknownAttributeSetTest, TestCreate) {
  const std::string kAttr1 = "a1";
  const std::string kAttr2 = "a2";
  const std::string kAttr3 = "a3";

  std::shared_ptr cel_attr = std::make_shared(
      "root", std::vector(
                  {CreateCelAttributeQualifier(CelValue::CreateString(&kAttr1)),
                   CreateCelAttributeQualifier(CelValue::CreateInt64(1)),
                   CreateCelAttributeQualifier(CelValue::CreateUint64(2)),
                   CreateCelAttributeQualifier(CelValue::CreateBool(true))}));

  UnknownAttributeSet unknown_set({*cel_attr});
  EXPECT_THAT(unknown_set.size(), Eq(1));
  EXPECT_THAT(*(unknown_set.begin()), Eq(*cel_attr));
}

// Merge de-duplicates structurally equal attributes: cel_attr1 and
// cel_attr1_copy collapse into one entry, leaving three in total.
// NOTE(review): kAttr2 and kAttr3 are unused in this test.
TEST(UnknownAttributeSetTest, TestMergeSets) {
  const std::string kAttr1 = "a1";
  const std::string kAttr2 = "a2";
  const std::string kAttr3 = "a3";

  CelAttribute cel_attr1(
      "root", std::vector(
                  {CreateCelAttributeQualifier(CelValue::CreateString(&kAttr1)),
                   CreateCelAttributeQualifier(CelValue::CreateInt64(1)),
                   CreateCelAttributeQualifier(CelValue::CreateUint64(2)),
                   CreateCelAttributeQualifier(CelValue::CreateBool(true))}));

  CelAttribute cel_attr1_copy(
      "root", std::vector(
                  {CreateCelAttributeQualifier(CelValue::CreateString(&kAttr1)),
                   CreateCelAttributeQualifier(CelValue::CreateInt64(1)),
                   CreateCelAttributeQualifier(CelValue::CreateUint64(2)),
                   CreateCelAttributeQualifier(CelValue::CreateBool(true))}));

  CelAttribute cel_attr2(
      "root", std::vector(
                  {CreateCelAttributeQualifier(CelValue::CreateString(&kAttr1)),
                   CreateCelAttributeQualifier(CelValue::CreateInt64(2)),
                   CreateCelAttributeQualifier(CelValue::CreateUint64(2)),
                   CreateCelAttributeQualifier(CelValue::CreateBool(true))}));

  CelAttribute cel_attr3(
      "root", std::vector(
                  {CreateCelAttributeQualifier(CelValue::CreateString(&kAttr1)),
                   CreateCelAttributeQualifier(CelValue::CreateInt64(2)),
                   CreateCelAttributeQualifier(CelValue::CreateUint64(2)),
                   CreateCelAttributeQualifier(CelValue::CreateBool(false))}));

  UnknownAttributeSet unknown_set1({cel_attr1, cel_attr2});
  UnknownAttributeSet unknown_set2({cel_attr1_copy, cel_attr3});

  UnknownAttributeSet unknown_set3 =
      UnknownAttributeSet::Merge(unknown_set1, unknown_set2);

  EXPECT_THAT(unknown_set3.size(), Eq(3));
  std::vector attrs1;
  for (const auto& attr_ptr : unknown_set3) {
    attrs1.push_back(attr_ptr);
  }

  std::vector attrs2 = {cel_attr1, cel_attr2, cel_attr3};

  EXPECT_THAT(attrs1, testing::UnorderedPointwise(Eq(), attrs2));
}

}  // namespace
}  // namespace runtime
}  // namespace expr
}  // namespace api
}  // namespace google
================================================
FILE: eval/public/unknown_function_result_set.cc
================================================
#include "eval/public/unknown_function_result_set.h"
================================================
FILE: eval/public/unknown_function_result_set.h
================================================
#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_FUNCTION_RESULT_SET_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_FUNCTION_RESULT_SET_H_

#include "base/function_result.h"
#include "base/function_result_set.h"

namespace google {
namespace api {
namespace expr {
namespace runtime {

// Represents a function result that is unknown at the time of execution. This
// allows for lazy evaluation of expensive functions.
using UnknownFunctionResult = ::cel::FunctionResult;

// Represents a collection of unknown function results at a particular point in
// execution. Execution should advance further if this set of unknowns are
// provided. It may not advance if only a subset are provided.
// Set semantics use |IsEqualTo()| defined on |UnknownFunctionResult|.
using UnknownFunctionResultSet = ::cel::FunctionResultSet;

}  // namespace runtime
}  // namespace expr
}  // namespace api
}  // namespace google

#endif  // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_FUNCTION_RESULT_SET_H_
================================================
FILE: eval/public/unknown_function_result_set_test.cc
================================================
#include "eval/public/unknown_function_result_set.h"

#include
#include
#include

#include "google/protobuf/duration.pb.h"
#include "google/protobuf/empty.pb.h"
#include "google/protobuf/struct.pb.h"
#include "google/protobuf/timestamp.pb.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "eval/public/cel_function.h"
#include "eval/public/cel_value.h"
#include "eval/public/containers/container_backed_list_impl.h"
#include "eval/public/containers/container_backed_map_impl.h"
#include "eval/public/structs/cel_proto_wrapper.h"
#include "internal/testing.h"
#include "google/protobuf/arena.h"

namespace google {
namespace api {
namespace expr {
namespace runtime {
namespace {

// NOTE(review): ListValue, Struct, Arena and SizeIs are not referenced by the
// tests in this file.
using ::google::protobuf::ListValue;
using ::google::protobuf::Struct;
using ::google::protobuf::Arena;
using ::testing::Eq;
using ::testing::SizeIs;

CelFunctionDescriptor kTwoInt("TwoInt", false,
                              {CelValue::Type::kInt64, CelValue::Type::kInt64});

CelFunctionDescriptor kOneInt("OneInt", false, {CelValue::Type::kInt64});

// Results with the same descriptor are equal; the set keeps one entry per
// distinct descriptor, in the order asserted below.
TEST(UnknownFunctionResult, Equals) {
  UnknownFunctionResult call1(kTwoInt, /*expr_id=*/0);

  UnknownFunctionResult call2(kTwoInt, /*expr_id=*/0);

  EXPECT_TRUE(call1.IsEqualTo(call2));

  UnknownFunctionResult call3(kOneInt, /*expr_id=*/0);

  UnknownFunctionResult call4(kOneInt, /*expr_id=*/0);

  EXPECT_TRUE(call3.IsEqualTo(call4));

  UnknownFunctionResultSet call_set({call1, call3});
  EXPECT_EQ(call_set.size(), 2);
  EXPECT_EQ(*call_set.begin(), call3);
  EXPECT_EQ(*(++call_set.begin()), call1);
}

// Descriptors that differ by name/arity or by argument type make results
// unequal, so all three are kept in the set.
TEST(UnknownFunctionResult, InequalDescriptor) {
  UnknownFunctionResult call1(kTwoInt, /*expr_id=*/0);

  UnknownFunctionResult call2(kOneInt, /*expr_id=*/0);

  EXPECT_FALSE(call1.IsEqualTo(call2));

  CelFunctionDescriptor one_uint("OneInt", false, {CelValue::Type::kUint64});

  UnknownFunctionResult call3(kOneInt, /*expr_id=*/0);

  UnknownFunctionResult call4(one_uint, /*expr_id=*/0);

  EXPECT_FALSE(call3.IsEqualTo(call4));

  UnknownFunctionResultSet call_set({call1, call3, call4});
  EXPECT_EQ(call_set.size(), 3);
  auto it = call_set.begin();
  EXPECT_EQ(*it++, call3);
  EXPECT_EQ(*it++, call4);
  EXPECT_EQ(*it++, call1);
}

}  // namespace
}  // namespace runtime
}  // namespace expr
}  // namespace api
}  // namespace google
================================================
FILE: eval/public/unknown_set.h
================================================
#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_SET_H_
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_SET_H_

#include "base/internal/unknown_set.h"
#include "eval/public/unknown_attribute_set.h"  // IWYU pragma: keep
#include "eval/public/unknown_function_result_set.h"  // IWYU pragma: keep

namespace google {
namespace api {
namespace expr {
namespace runtime {

// Class representing a collection of unknowns from a single evaluation pass of
// a CEL expression.
using UnknownSet = ::cel::base_internal::UnknownSet;

}  // namespace runtime
}  // namespace expr
}  // namespace api
}  // namespace google

#endif  // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_SET_H_
================================================
FILE: eval/public/unknown_set_test.cc
================================================
#include "eval/public/unknown_set.h"

#include

#include "cel/expr/syntax.pb.h"
#include "eval/public/cel_attribute.h"
#include "eval/public/cel_function.h"
#include "eval/public/unknown_attribute_set.h"
#include "eval/public/unknown_function_result_set.h"
#include "internal/testing.h"
#include "google/protobuf/arena.h"

namespace google {
namespace api {
namespace expr {
namespace runtime {
namespace {

using ::google::protobuf::Arena;
using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;

// Builds a one-element function-result set for a fixed "OneInt" descriptor.
// NOTE(review): neither |arena| nor |id| is used; every result gets
// expr_id 0, so results made with different ids are indistinguishable.
UnknownFunctionResultSet MakeFunctionResult(Arena* arena, int64_t id) {
  CelFunctionDescriptor desc("OneInt", false, {CelValue::Type::kInt64});
  return UnknownFunctionResultSet(UnknownFunctionResult(desc, /*expr_id=*/0));
}

// Builds a one-element attribute set for attribute "x" qualified by int |id|.
// |arena| is unused.
UnknownAttributeSet MakeAttribute(Arena* arena, int64_t id) {
  std::vector attr_trail{
      CreateCelAttributeQualifier(CelValue::CreateInt64(id))};

  return UnknownAttributeSet({CelAttribute("x", std::move(attr_trail))});
}

// Matches an attribute whose qualifier path is exactly one int64 key == id.
MATCHER_P(UnknownAttributeIs, id, "") {
  const CelAttribute& attr = arg;
  if (attr.qualifier_path().size() != 1) {
    return false;
  }
  auto maybe_qualifier = attr.qualifier_path()[0].GetInt64Key();
  if (!maybe_qualifier.has_value()) {
    return false;
  }
  return maybe_qualifier.value() == id;
}

// Merging sets unions their attributes and drops duplicates.
TEST(UnknownSet, AttributesMerge) {
  Arena arena;
  UnknownSet a(MakeAttribute(&arena, 1));
  UnknownSet b(MakeAttribute(&arena, 2));
  UnknownSet c(MakeAttribute(&arena, 2));
  UnknownSet d(a, b);
  UnknownSet e(c, d);

  EXPECT_THAT(
      d.unknown_attributes(),
      UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2)));
  EXPECT_THAT(
      e.unknown_attributes(),
      UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2)));
}

// A default-constructed UnknownSet holds no attributes or function results.
TEST(UnknownSet, DefaultEmpty) {
  UnknownSet empty_set;
  EXPECT_THAT(empty_set.unknown_attributes(), IsEmpty());
  EXPECT_THAT(empty_set.unknown_function_results(), IsEmpty());
}

// Merging sets that mix attributes and function results keeps the attributes.
// NOTE(review): function results are not asserted here.
TEST(UnknownSet, MixedMerges) {
  Arena arena;

  UnknownSet a(MakeAttribute(&arena, 1), MakeFunctionResult(&arena, 1));
  UnknownSet b(MakeFunctionResult(&arena, 2));
  UnknownSet c(MakeAttribute(&arena, 2));
  UnknownSet d(a, b);
  UnknownSet e(c, d);

  EXPECT_THAT(d.unknown_attributes(),
              UnorderedElementsAre(UnknownAttributeIs(1)));
  EXPECT_THAT(
      e.unknown_attributes(),
      UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2)));
}

}  // namespace
}  // namespace runtime
}  // namespace expr
}  // namespace api
}  // namespace google
================================================
FILE: eval/public/value_export_util.cc
================================================
#include "eval/public/value_export_util.h"

#include

#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "internal/proto_time_encoding.h"
#include "google/protobuf/util/json_util.h"
#include "google/protobuf/util/time_util.h"

namespace google::api::expr::runtime {

using google::protobuf::Duration;
using google::protobuf::Timestamp;
using google::protobuf::Value;
using google::protobuf::util::TimeUtil;

// Renders a map key as a JSON object field name in |*key|. int64 and uint64
// keys are formatted with absl::StrCat and strings are copied verbatim. Any
// other key type, such as bool, yields InvalidArgument("Unsupported map type").
absl::Status KeyAsString(const CelValue& value, std::string* key) {
  switch (value.type()) {
    case CelValue::Type::kInt64: {
      *key = absl::StrCat(value.Int64OrDie());
      break;
    }
    case CelValue::Type::kUint64: {
      *key = absl::StrCat(value.Uint64OrDie());
      break;
    }
    case CelValue::Type::kString: {
      key->assign(value.StringOrDie().value().data(),
                  value.StringOrDie().value().size());
      break;
    }
    default: {
      return absl::InvalidArgumentError("Unsupported map type");
    }
  }
  return absl::OkStatus();
}

// Export content of CelValue as google.protobuf.Value.
absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value, google::protobuf::Arena* arena) { if (in_value.IsNull()) { out_value->set_null_value(google::protobuf::NULL_VALUE); return absl::OkStatus(); } switch (in_value.type()) { case CelValue::Type::kBool: { out_value->set_bool_value(in_value.BoolOrDie()); break; } case CelValue::Type::kInt64: { out_value->set_number_value(static_cast(in_value.Int64OrDie())); break; } case CelValue::Type::kUint64: { out_value->set_number_value(static_cast(in_value.Uint64OrDie())); break; } case CelValue::Type::kDouble: { out_value->set_number_value(in_value.DoubleOrDie()); break; } case CelValue::Type::kString: { auto value = in_value.StringOrDie().value(); out_value->set_string_value(value.data(), value.size()); break; } case CelValue::Type::kBytes: { *out_value->mutable_string_value() = absl::Base64Escape(in_value.BytesOrDie().value()); break; } case CelValue::Type::kDuration: { Duration duration; auto status = cel::internal::EncodeDuration(in_value.DurationOrDie(), &duration); if (!status.ok()) { return status; } out_value->set_string_value(TimeUtil::ToString(duration)); break; } case CelValue::Type::kTimestamp: { Timestamp timestamp; auto status = cel::internal::EncodeTime(in_value.TimestampOrDie(), ×tamp); if (!status.ok()) { return status; } out_value->set_string_value(TimeUtil::ToString(timestamp)); break; } case CelValue::Type::kMessage: { google::protobuf::util::JsonPrintOptions json_options; json_options.preserve_proto_field_names = true; std::string json; auto status = google::protobuf::util::MessageToJsonString(*in_value.MessageOrDie(), &json, json_options); if (!status.ok()) { return absl::InternalError(status.ToString()); } google::protobuf::util::JsonParseOptions json_parse_options; status = google::protobuf::util::JsonStringToMessage(json, out_value, json_parse_options); if (!status.ok()) { return absl::InternalError(status.ToString()); } break; } case CelValue::Type::kList: { const CelList* 
cel_list = in_value.ListOrDie(); auto out_values = out_value->mutable_list_value(); for (int i = 0; i < cel_list->size(); i++) { auto status = ExportAsProtoValue((*cel_list).Get(arena, i), out_values->add_values(), arena); if (!status.ok()) { return status; } } break; } case CelValue::Type::kMap: { const CelMap* cel_map = in_value.MapOrDie(); CEL_ASSIGN_OR_RETURN(auto keys_list, cel_map->ListKeys(arena)); auto out_values = out_value->mutable_struct_value()->mutable_fields(); for (int i = 0; i < keys_list->size(); i++) { std::string key; CelValue map_key = (*keys_list).Get(arena, i); auto status = KeyAsString(map_key, &key); if (!status.ok()) { return status; } auto map_value_ref = (*cel_map).Get(arena, map_key); CelValue map_value = (map_value_ref) ? map_value_ref.value() : CelValue(); status = ExportAsProtoValue(map_value, &((*out_values)[key]), arena); if (!status.ok()) { return status; } } break; } default: { return absl::InvalidArgumentError("Unsupported value type"); } } return absl::OkStatus(); } } // namespace google::api::expr::runtime ================================================ FILE: eval/public/value_export_util.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_VALUE_EXPORT_UTIL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_VALUE_EXPORT_UTIL_H_ #include "google/protobuf/struct.pb.h" #include "absl/status/status.h" #include "eval/public/cel_value.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { // Exports content of CelValue as google.protobuf.Value. // Current limitations: // - exports integer values as doubles (Value.number_value); // - exports integer keys in maps as strings; // - handles Duration and Timestamp as generic messages. 
absl::Status ExportAsProtoValue(const CelValue& in_value, google::protobuf::Value* out_value, google::protobuf::Arena* arena); inline absl::Status ExportAsProtoValue(const CelValue& in_value, google::protobuf::Value* out_value) { google::protobuf::Arena arena; return ExportAsProtoValue(in_value, out_value, &arena); } } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_VALUE_EXPORT_UTIL_H_ ================================================ FILE: eval/public/value_export_util_test.cc ================================================ #include "eval/public/value_export_util.h" #include #include #include #include "absl/strings/str_cat.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/testutil/test_message.pb.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "testutil/util.h" namespace google::api::expr::runtime { namespace { using google::protobuf::Duration; using google::protobuf::ListValue; using google::protobuf::Struct; using google::protobuf::Timestamp; using google::protobuf::Value; using google::protobuf::Arena; TEST(ValueExportUtilTest, ConvertBoolValue) { CelValue cel_value = CelValue::CreateBool(true); Value value; EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kBoolValue); EXPECT_EQ(value.bool_value(), true); } TEST(ValueExportUtilTest, ConvertInt64Value) { CelValue cel_value = CelValue::CreateInt64(-1); Value value; EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kNumberValue); EXPECT_DOUBLE_EQ(value.number_value(), -1); } TEST(ValueExportUtilTest, ConvertUint64Value) { CelValue cel_value = CelValue::CreateUint64(1); Value value; EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kNumberValue); 
EXPECT_DOUBLE_EQ(value.number_value(), 1); } TEST(ValueExportUtilTest, ConvertDoubleValue) { CelValue cel_value = CelValue::CreateDouble(1.3); Value value; EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kNumberValue); EXPECT_DOUBLE_EQ(value.number_value(), 1.3); } TEST(ValueExportUtilTest, ConvertStringValue) { std::string test = "test"; CelValue cel_value = CelValue::CreateString(&test); Value value; EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStringValue); EXPECT_EQ(value.string_value(), "test"); } TEST(ValueExportUtilTest, ConvertBytesValue) { std::string test = "test"; CelValue cel_value = CelValue::CreateBytes(&test); Value value; EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStringValue); // Check that the result is BASE64 encoded. EXPECT_EQ(value.string_value(), "dGVzdA=="); } TEST(ValueExportUtilTest, ConvertDurationValue) { Duration duration; duration.set_seconds(2); duration.set_nanos(3); CelValue cel_value = CelProtoWrapper::CreateDuration(&duration); Value value; EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStringValue); EXPECT_EQ(value.string_value(), "2.000000003s"); } TEST(ValueExportUtilTest, ConvertTimestampValue) { Timestamp timestamp; timestamp.set_seconds(1000000000); timestamp.set_nanos(3); CelValue cel_value = CelProtoWrapper::CreateTimestamp(×tamp); Value value; EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStringValue); EXPECT_EQ(value.string_value(), "2001-09-09T01:46:40.000000003Z"); } TEST(ValueExportUtilTest, ConvertStructMessage) { Struct struct_msg; (*struct_msg.mutable_fields())["string_value"].set_string_value("test"); Arena arena; CelValue cel_value = CelProtoWrapper::CreateMessage(&struct_msg, &arena); Value value; EXPECT_OK(ExportAsProtoValue(cel_value, &value)); 
EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); EXPECT_THAT(value.struct_value(), testutil::EqualsProto(struct_msg)); } TEST(ValueExportUtilTest, ConvertValueMessage) { Value value_in; // key-based access forces value to be a struct. (*value_in.mutable_struct_value()->mutable_fields())["boolean_value"] .set_bool_value(true); Arena arena; CelValue cel_value = CelProtoWrapper::CreateMessage(&value_in, &arena); Value value_out; EXPECT_OK(ExportAsProtoValue(cel_value, &value_out)); EXPECT_THAT(value_in, testutil::EqualsProto(value_out)); } TEST(ValueExportUtilTest, ConvertListValueMessage) { ListValue list_value; list_value.add_values()->set_string_value("test"); list_value.add_values()->set_bool_value(true); Arena arena; CelValue cel_value = CelProtoWrapper::CreateMessage(&list_value, &arena); Value value_out; EXPECT_OK(ExportAsProtoValue(cel_value, &value_out)); EXPECT_THAT(list_value, testutil::EqualsProto(value_out.list_value())); } TEST(ValueExportUtilTest, ConvertRepeatedBoolValue) { Arena arena; Value value; TestMessage* msg = Arena::Create(&arena); msg->add_bool_list(true); msg->add_bool_list(false); CelValue cel_value = CelProtoWrapper::CreateMessage(msg, &arena); EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("bool_list"); EXPECT_TRUE(list_value.has_list_value()); EXPECT_EQ(list_value.list_value().values(0).bool_value(), true); EXPECT_EQ(list_value.list_value().values(1).bool_value(), false); } TEST(ValueExportUtilTest, ConvertRepeatedInt32Value) { Arena arena; Value value; TestMessage* msg = Arena::Create(&arena); msg->add_int32_list(2); msg->add_int32_list(3); CelValue cel_value = CelProtoWrapper::CreateMessage(msg, &arena); EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("int32_list"); 
EXPECT_TRUE(list_value.has_list_value()); EXPECT_DOUBLE_EQ(list_value.list_value().values(0).number_value(), 2); EXPECT_DOUBLE_EQ(list_value.list_value().values(1).number_value(), 3); } TEST(ValueExportUtilTest, ConvertRepeatedInt64Value) { Arena arena; Value value; TestMessage* msg = Arena::Create(&arena); msg->add_int64_list(2); msg->add_int64_list(3); CelValue cel_value = CelProtoWrapper::CreateMessage(msg, &arena); EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("int64_list"); EXPECT_TRUE(list_value.has_list_value()); EXPECT_EQ(list_value.list_value().values(0).string_value(), "2"); EXPECT_EQ(list_value.list_value().values(1).string_value(), "3"); } TEST(ValueExportUtilTest, ConvertRepeatedUint64Value) { Arena arena; Value value; TestMessage* msg = Arena::Create(&arena); msg->add_uint64_list(2); msg->add_uint64_list(3); CelValue cel_value = CelProtoWrapper::CreateMessage(msg, &arena); EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("uint64_list"); EXPECT_TRUE(list_value.has_list_value()); EXPECT_EQ(list_value.list_value().values(0).string_value(), "2"); EXPECT_EQ(list_value.list_value().values(1).string_value(), "3"); } TEST(ValueExportUtilTest, ConvertRepeatedDoubleValue) { Arena arena; Value value; TestMessage* msg = Arena::Create(&arena); msg->add_double_list(2); msg->add_double_list(3); CelValue cel_value = CelProtoWrapper::CreateMessage(msg, &arena); EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("double_list"); EXPECT_TRUE(list_value.has_list_value()); EXPECT_DOUBLE_EQ(list_value.list_value().values(0).number_value(), 2); EXPECT_DOUBLE_EQ(list_value.list_value().values(1).number_value(), 3); } 
TEST(ValueExportUtilTest, ConvertRepeatedStringValue) { Arena arena; Value value; TestMessage* msg = Arena::Create(&arena); msg->add_string_list("test1"); msg->add_string_list("test2"); CelValue cel_value = CelProtoWrapper::CreateMessage(msg, &arena); EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("string_list"); EXPECT_TRUE(list_value.has_list_value()); EXPECT_EQ(list_value.list_value().values(0).string_value(), "test1"); EXPECT_EQ(list_value.list_value().values(1).string_value(), "test2"); } TEST(ValueExportUtilTest, ConvertRepeatedBytesValue) { Arena arena; Value value; TestMessage* msg = Arena::Create(&arena); msg->add_bytes_list("test1"); msg->add_bytes_list("test2"); CelValue cel_value = CelProtoWrapper::CreateMessage(msg, &arena); EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("bytes_list"); EXPECT_TRUE(list_value.has_list_value()); EXPECT_EQ(list_value.list_value().values(0).string_value(), "dGVzdDE="); EXPECT_EQ(list_value.list_value().values(1).string_value(), "dGVzdDI="); } TEST(ValueExportUtilTest, ConvertCelList) { Arena arena; Value value; std::vector values; values.push_back(CelValue::CreateInt64(2)); values.push_back(CelValue::CreateInt64(3)); CelList *cel_list = Arena::Create(&arena, values); CelValue cel_value = CelValue::CreateList(cel_list); EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kListValue); EXPECT_DOUBLE_EQ(value.list_value().values(0).number_value(), 2); EXPECT_DOUBLE_EQ(value.list_value().values(1).number_value(), 3); } TEST(ValueExportUtilTest, ConvertCelMapWithStringKey) { Value value; std::vector> map_entries; std::string key1 = "key1"; std::string key2 = "key2"; std::string value1 = "value1"; std::string value2 = "value2"; map_entries.push_back( 
{CelValue::CreateString(&key1), CelValue::CreateString(&value1)}); map_entries.push_back( {CelValue::CreateString(&key2), CelValue::CreateString(&value2)}); auto cel_map = CreateContainerBackedMap( absl::Span>(map_entries)) .value(); CelValue cel_value = CelValue::CreateMap(cel_map.get()); EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); const auto& fields = value.struct_value().fields(); EXPECT_EQ(fields.at(key1).string_value(), value1); EXPECT_EQ(fields.at(key2).string_value(), value2); } TEST(ValueExportUtilTest, ConvertCelMapWithInt64Key) { Value value; std::vector> map_entries; int key1 = -1; int key2 = 2; std::string value1 = "value1"; std::string value2 = "value2"; map_entries.push_back( {CelValue::CreateInt64(key1), CelValue::CreateString(&value1)}); map_entries.push_back( {CelValue::CreateInt64(key2), CelValue::CreateString(&value2)}); auto cel_map = CreateContainerBackedMap( absl::Span>(map_entries)) .value(); CelValue cel_value = CelValue::CreateMap(cel_map.get()); EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); const auto& fields = value.struct_value().fields(); EXPECT_EQ(fields.at(absl::StrCat(key1)).string_value(), value1); EXPECT_EQ(fields.at(absl::StrCat(key2)).string_value(), value2); } } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/tests/BUILD ================================================ # This package contains CEL evaluator tests (end-to-end, benchmark etc.) 
#
#

load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library")
load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library")
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

package(default_visibility = ["//visibility:public"])

licenses(["notice"])

# Export the package LICENSE file so targets in other packages can refer to it
# by label.
exports_files(["LICENSE"])

# The benchmark targets in this package are tagged "manual". Bazel therefore
# leaves them out of wildcard target patterns such as //eval/tests/...; build
# or run them by explicit label (see README.md).

# Benchmarks whose deps target the eval/public (CelExpressionBuilder) API.
cc_test(
    name = "benchmark_test",
    srcs = [
        "benchmark_test.cc",
    ],
    tags = [
        "benchmark",
        "manual",
    ],
    deps = [
        ":request_context_cc_proto",
        "//eval/public:activation",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "//eval/public/containers:container_backed_list_impl",
        "//eval/public/containers:container_backed_map_impl",
        "//eval/public/structs:cel_proto_wrapper",
        "//internal:benchmark",
        "//internal:status_macros",
        "//internal:testing",
        "//parser",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:btree",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:node_hash_set",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
    ],
)

# Benchmarks whose deps target the //runtime API (standard runtime builder).
cc_test(
    name = "modern_benchmark_test",
    srcs = [
        "modern_benchmark_test.cc",
    ],
    tags = [
        "benchmark",
        "manual",
    ],
    deps = [
        ":request_context_cc_proto",
        "//common:allocator",
        "//common:casting",
        "//common:legacy_value",
        "//common:memory",
        "//common:native_type",
        "//common:value",
        "//extensions:comprehensions_v2_functions",
        "//extensions:comprehensions_v2_macros",
        "//extensions/protobuf:runtime_adapter",
        "//extensions/protobuf:value",
        "//internal:benchmark",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//internal:testing_message_factory",
        "//parser",
        "//parser:macro",
        "//parser:macro_registry",
        "//runtime",
        "//runtime:activation",
        "//runtime:constant_folding",
        "//runtime:runtime_options",
        "//runtime:standard_runtime_builder_factory",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:btree",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:node_hash_set",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
    ],
)

# Arena-allocation benchmarks: string concatenation, errors, maps, and
# messages, as defined in allocation_benchmark_test.cc.
cc_test(
    name = "allocation_benchmark_test",
    size = "small",
    srcs = [
        "allocation_benchmark_test.cc",
    ],
    tags = [
        "benchmark",
        "manual",
    ],
    deps = [
        ":request_context_cc_proto",
        "//eval/public:activation",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_value",
        "//internal:benchmark",
        "//internal:testing",
        "//parser",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "memory_safety_test",
    srcs = [
        "memory_safety_test.cc",
    ],
    deps = [
        "//eval/public:activation",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_function_adapter",
        "//eval/public:cel_options",
        "//eval/public/testing:matchers",
        "//internal:testing",
        "//parser",
        "//testutil:util",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_googleapis//google/rpc/context:attribute_context_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "expression_builder_benchmark_test",
    size = "small",
    srcs = [
        "expression_builder_benchmark_test.cc",
    ],
    tags = [
        "benchmark",
        "manual",
    ],
    deps = [
        ":request_context_cc_proto",
        "//common:minimal_descriptor_pool",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_options",
        "//eval/public:cel_type_registry",
        "//internal:benchmark",
        "//internal:status_macros",
        "//internal:testing",
        "//parser",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "end_to_end_test",
    size = "small",
    srcs = [
        "end_to_end_test.cc",
    ],
    deps = [
        "//eval/public:activation",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_value",
        "//eval/public/structs:cel_proto_wrapper",
        "//eval/testutil:test_message_cc_proto",
        "//internal:status_macros",
        "//internal:testing",
        "//testutil:util",
        "@com_google_absl//absl/status",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
    ],
)

cc_test(
    name = "unknowns_end_to_end_test",
    size = "small",
    srcs = [
        "unknowns_end_to_end_test.cc",
    ],
    deps = [
        "//base:attributes",
        "//base:function_result",
        "//eval/public:activation",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_attribute",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_function",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "//eval/public:unknown_set",
        "//eval/public/containers:container_backed_map_impl",
        "//eval/public/structs:cel_proto_wrapper",
        "//internal:status_macros",
        "//internal:testing",
        "//parser",
        "//runtime/internal:activation_attribute_matcher_access",
        "//runtime/internal:attribute_matcher",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
    ],
)

# Proto fixtures used by the tests and benchmarks above. For example,
# allocation_benchmark_test.cc builds a RequestContext message.
proto_library(
    name = "request_context_protos",
    srcs = [
        "request_context.proto",
    ],
)

cc_proto_library(
    name = "request_context_cc_proto",
    deps = [":request_context_protos"],
)

# Test-only helper library. It presumably provides a mock CelExpression (see
# mock_cel_expression.h) — confirm.
cc_library(
    name = "mock_cel_expression",
    testonly = 1,
    hdrs = ["mock_cel_expression.h"],
    deps = [
        "//eval/public:base_activation",
        "//eval/public:cel_expression",
        "//internal:testing_no_main",
        "@com_google_absl//absl/status:statusor",
    ],
)

================================================
FILE: eval/tests/LICENSE
================================================
Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity.
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: eval/tests/README.md ================================================ # Integration tests for the C++ CEL Runtime ## Benchmarks To run the benchmark tests: `bazel run -c opt --dynamic_mode=off //eval/tests:benchmark_test -- --benchmark_filter=all` or `bazel run -c opt --dynamic_mode=off //eval/tests:modern_benchmark_test -- --benchmark_filter=all` (arguments after `--` are passed to the benchmark binary rather than to Bazel). See the Google Benchmark user guide (https://github.com/google/benchmark/blob/main/docs/user_guide.md) for more options. For CSV formatting: `awk '{print $1 "," $2 "," $3 "," $4}'` ================================================ FILE: eval/tests/allocation_benchmark_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include #include "cel/expr/syntax.pb.h" #include "google/rpc/context/attribute_context.pb.h" #include "absl/status/status.h" #include "absl/strings/substitute.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_value.h" #include "eval/tests/request_context.pb.h" #include "internal/benchmark.h" #include "internal/testing.h" #include "parser/parser.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { using ::absl_testing::StatusIs; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::testing::HasSubstr; // Evaluates cel expression: // '"1" + "1" + ...' static void BM_StrCatLocalArena(benchmark::State& state) { std::string expr("'1'"); int len = state.range(0); auto builder = CreateCelExpressionBuilder(); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); for (int i = 0; i < len; i++) { expr = absl::Substitute("($0 + $0)", expr); } ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); for (auto _ : state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); CelValue::StringHolder holder; ASSERT_TRUE(result.GetValue(&holder)); ASSERT_EQ(holder.value().length(), 1 << len); } } BENCHMARK(BM_StrCatLocalArena)->DenseRange(0, 8, 2); // Evaluates cel expression: // '("1" + "1") + ...' 
static void BM_StrCatSharedArena(benchmark::State& state) { google::protobuf::Arena arena; std::string expr("'1'"); int len = state.range(0); auto builder = CreateCelExpressionBuilder(); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); for (int i = 0; i < len; i++) { expr = absl::Substitute("($0 + $0)", expr); } ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); for (auto _ : state) { Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); CelValue::StringHolder holder; ASSERT_TRUE(result.GetValue(&holder)); ASSERT_EQ(holder.value().length(), 1 << len); } } // Expression grows exponentially. BENCHMARK(BM_StrCatSharedArena)->DenseRange(0, 8, 2); // Series of simple expressions that are expected to require an allocation. static void BM_AllocateString(benchmark::State& state) { google::protobuf::Arena arena; std::string expr("'1' + '1'"); auto builder = CreateCelExpressionBuilder(); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); for (auto _ : state) { Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); CelValue::StringHolder holder; ASSERT_TRUE(result.GetValue(&holder)); ASSERT_EQ(holder.value(), "11"); } } BENCHMARK(BM_AllocateString); static void BM_AllocateError(benchmark::State& state) { google::protobuf::Arena arena; std::string expr("1 / 0"); auto builder = CreateCelExpressionBuilder(); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); for (auto _ : state) { Activation 
activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); const CelError* value; ASSERT_TRUE(result.GetValue(&value)); ASSERT_THAT(*value, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("divide by zero"))); } } BENCHMARK(BM_AllocateError); static void BM_AllocateMap(benchmark::State& state) { google::protobuf::Arena arena; std::string expr("{1: 2, 3: 4}"); auto builder = CreateCelExpressionBuilder(); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); for (auto _ : state) { Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsMap()); } } BENCHMARK(BM_AllocateMap); static void BM_AllocateMessage(benchmark::State& state) { google::protobuf::Arena arena; std::string expr( "google.api.expr.runtime.RequestContext{" "ip: '192.168.0.1'," "path: '/root'}"); // Make sure RequestContext is loaded in the generated descriptor pool. RequestContext context; static_cast(context); auto builder = CreateCelExpressionBuilder(); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); for (auto _ : state) { Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsMessage()); } } BENCHMARK(BM_AllocateMessage); static void BM_AllocateLargeMessage(benchmark::State& state) { // Make sure attribute context is loaded in the generated descriptor pool. 
rpc::context::AttributeContext context; static_cast(context); google::protobuf::Arena arena; std::string expr(R"( google.rpc.context.AttributeContext{ source: google.rpc.context.AttributeContext.Peer{ ip: '192.168.0.1', port: 1025, labels: {"abc": "123", "def": "456"} }, request: google.rpc.context.AttributeContext.Request{ method: 'GET', path: 'root', host: 'www.example.com' }, resource: google.rpc.context.AttributeContext.Resource{ labels: {"abc": "123", "def": "456"}, } })"); auto builder = CreateCelExpressionBuilder(); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); for (auto _ : state) { Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsMessage()); } } BENCHMARK(BM_AllocateLargeMessage); static void BM_AllocateList(benchmark::State& state) { google::protobuf::Arena arena; std::string expr("[1, 2, 3, 4]"); auto builder = CreateCelExpressionBuilder(); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); for (auto _ : state) { Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsList()); } } BENCHMARK(BM_AllocateList); } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/tests/benchmark_test.cc ================================================ #include "internal/benchmark.h" #include #include #include #include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "google/rpc/context/attribute_context.pb.h" #include "absl/base/attributes.h" #include "absl/container/btree_map.h" #include 
"absl/container/flat_hash_set.h" #include "absl/container/node_hash_set.h" #include "absl/flags/flag.h" #include "absl/strings/match.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/tests/request_context.pb.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/parser.h" #include "google/protobuf/arena.h" #include "google/protobuf/text_format.h" ABSL_FLAG(bool, enable_optimizations, false, "enable const folding opt"); ABSL_FLAG(bool, enable_recursive_planning, false, "enable recursive planning"); namespace google { namespace api { namespace expr { namespace runtime { namespace { using ::cel::expr::Expr; using ::cel::expr::ParsedExpr; using ::cel::expr::SourceInfo; using ::google::rpc::context::AttributeContext; InterpreterOptions GetOptions(google::protobuf::Arena& arena) { InterpreterOptions options; if (absl::GetFlag(FLAGS_enable_optimizations)) { options.constant_arena = &arena; options.constant_folding = true; } if (absl::GetFlag(FLAGS_enable_recursive_planning)) { options.max_recursion_depth = -1; } return options; } // Benchmark test // Evaluates cel expression: // '1 + 1 + 1 .... 
+1' static void BM_Eval(benchmark::State& state) { google::protobuf::Arena arena; InterpreterOptions options = GetOptions(arena); auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); int len = state.range(0); Expr root_expr; Expr* cur_expr = &root_expr; for (int i = 0; i < len; i++) { Expr::Call* call = cur_expr->mutable_call_expr(); call->set_function("_+_"); call->add_args()->mutable_const_expr()->set_int64_value(1); cur_expr = call->add_args(); } cur_expr->mutable_const_expr()->set_int64_value(1); SourceInfo source_info; ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&root_expr, &source_info)); for (auto _ : state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); ASSERT_TRUE(result.Int64OrDie() == len + 1); } } BENCHMARK(BM_Eval)->Range(1, 10000); absl::Status EmptyCallback(int64_t expr_id, const CelValue& value, google::protobuf::Arena* arena) { return absl::OkStatus(); } // Benchmark test // Traces cel expression with an empty callback: // '1 + 1 + 1 .... 
+1' static void BM_Eval_Trace(benchmark::State& state) { google::protobuf::Arena arena; InterpreterOptions options = GetOptions(arena); options.enable_recursive_tracing = true; auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); int len = state.range(0); Expr root_expr; Expr* cur_expr = &root_expr; for (int i = 0; i < len; i++) { Expr::Call* call = cur_expr->mutable_call_expr(); call->set_function("_+_"); call->add_args()->mutable_const_expr()->set_int64_value(1); cur_expr = call->add_args(); } cur_expr->mutable_const_expr()->set_int64_value(1); SourceInfo source_info; ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&root_expr, &source_info)); for (auto _ : state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Trace(activation, &arena, EmptyCallback)); ASSERT_TRUE(result.IsInt64()); ASSERT_TRUE(result.Int64OrDie() == len + 1); } } // A number higher than 10k leads to a stack overflow due to the recursive // nature of the proto to native type conversion. BENCHMARK(BM_Eval_Trace)->Range(1, 10000); // Benchmark test // Evaluates cel expression: // '"a" + "a" + "a" .... 
+ "a"' static void BM_EvalString(benchmark::State& state) { google::protobuf::Arena arena; InterpreterOptions options = GetOptions(arena); auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); int len = state.range(0); Expr root_expr; Expr* cur_expr = &root_expr; for (int i = 0; i < len; i++) { Expr::Call* call = cur_expr->mutable_call_expr(); call->set_function("_+_"); call->add_args()->mutable_const_expr()->set_string_value("a"); cur_expr = call->add_args(); } cur_expr->mutable_const_expr()->set_string_value("a"); SourceInfo source_info; ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&root_expr, &source_info)); for (auto _ : state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsString()); ASSERT_TRUE(result.StringOrDie().value().size() == len + 1); } } // A number higher than 10k leads to a stack overflow due to the recursive // nature of the proto to native type conversion. BENCHMARK(BM_EvalString)->Range(1, 10000); // Benchmark test // Traces cel expression with an empty callback: // '"a" + "a" + "a" .... 
+ "a"' static void BM_EvalString_Trace(benchmark::State& state) { google::protobuf::Arena arena; InterpreterOptions options = GetOptions(arena); options.enable_recursive_tracing = true; auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); int len = state.range(0); Expr root_expr; Expr* cur_expr = &root_expr; for (int i = 0; i < len; i++) { Expr::Call* call = cur_expr->mutable_call_expr(); call->set_function("_+_"); call->add_args()->mutable_const_expr()->set_string_value("a"); cur_expr = call->add_args(); } cur_expr->mutable_const_expr()->set_string_value("a"); SourceInfo source_info; ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&root_expr, &source_info)); for (auto _ : state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Trace(activation, &arena, EmptyCallback)); ASSERT_TRUE(result.IsString()); ASSERT_TRUE(result.StringOrDie().value().size() == len + 1); } } // A number higher than 10k leads to a stack overflow due to the recursive // nature of the proto to native type conversion. 
BENCHMARK(BM_EvalString_Trace)->Range(1, 10000); const char kIP[] = "10.0.1.2"; const char kPath[] = "/admin/edit"; const char kToken[] = "admin"; ABSL_ATTRIBUTE_NOINLINE bool NativeCheck(absl::btree_map& attributes, const absl::flat_hash_set& denylists, const absl::flat_hash_set& allowlists) { auto& ip = attributes["ip"]; auto& path = attributes["path"]; auto& token = attributes["token"]; if (denylists.find(ip) != denylists.end()) { return false; } if (absl::StartsWith(path, "v1")) { if (token == "v1" || token == "v2" || token == "admin") { return true; } } else if (absl::StartsWith(path, "v2")) { if (token == "v2" || token == "admin") { return true; } } else if (absl::StartsWith(path, "/admin")) { if (token == "admin") { if (allowlists.find(ip) != allowlists.end()) { return true; } } } return false; } void BM_PolicyNative(benchmark::State& state) { const auto denylists = absl::flat_hash_set{"10.0.1.4", "10.0.1.5", "10.0.1.6"}; const auto allowlists = absl::flat_hash_set{"10.0.1.1", "10.0.1.2", "10.0.1.3"}; auto attributes = absl::btree_map{ {"ip", kIP}, {"token", kToken}, {"path", kPath}}; for (auto _ : state) { auto result = NativeCheck(attributes, denylists, allowlists); ASSERT_TRUE(result); } } BENCHMARK(BM_PolicyNative); void BM_PolicySymbolic(benchmark::State& state) { google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( !(ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && ((path.startsWith("v1") && token in ["v1", "v2", "admin"]) || (path.startsWith("v2") && token in ["v2", "admin"]) || (path.startsWith("/admin") && token == "admin" && ip in [ "10.0.1.1", "10.0.1.2", "10.0.1.3" ]) ))cel")); InterpreterOptions options = GetOptions(arena); options.constant_folding = true; options.constant_arena = &arena; auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); SourceInfo source_info; ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression( 
&parsed_expr.expr(), &source_info)); Activation activation; activation.InsertValue("ip", CelValue::CreateStringView(kIP)); activation.InsertValue("path", CelValue::CreateStringView(kPath)); activation.InsertValue("token", CelValue::CreateStringView(kToken)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.BoolOrDie()); } } BENCHMARK(BM_PolicySymbolic); class RequestMap : public CelMap { public: absl::optional operator[](CelValue key) const override { if (!key.IsString()) { return {}; } auto value = key.StringOrDie().value(); if (value == "ip") { return CelValue::CreateStringView(kIP); } else if (value == "path") { return CelValue::CreateStringView(kPath); } else if (value == "token") { return CelValue::CreateStringView(kToken); } return {}; } int size() const override { return 3; } absl::StatusOr ListKeys() const override { return absl::UnimplementedError("CelMap::ListKeys is not implemented"); } }; // Uses a lazily constructed map container for "ip", "path", and "token". 
void BM_PolicySymbolicMap(benchmark::State& state) {
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( !(request.ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && ((request.path.startsWith("v1") && request.token in ["v1", "v2", "admin"]) || (request.path.startsWith("v2") && request.token in ["v2", "admin"]) || (request.path.startsWith("/admin") && request.token == "admin" && request.ip in ["10.0.1.1", "10.0.1.2", "10.0.1.3"]) ))cel"));

  InterpreterOptions options = GetOptions(arena);

  auto builder = CreateCelExpressionBuilder(options);
  ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options));

  // Use the parser's source info rather than an empty placeholder so the
  // planned expression carries the real source metadata.
  ASSERT_OK_AND_ASSIGN(auto cel_expr,
                       builder->CreateExpression(&parsed_expr.expr(),
                                                 &parsed_expr.source_info()));

  Activation activation;
  // `request` must outlive every evaluation: the activation holds a pointer.
  RequestMap request;
  activation.InsertValue("request", CelValue::CreateMap(&request));

  for (auto _ : state) {
    ASSERT_OK_AND_ASSIGN(CelValue result,
                         cel_expr->Evaluate(activation, &arena));
    ASSERT_TRUE(result.BoolOrDie());
  }
}

BENCHMARK(BM_PolicySymbolicMap);

// Uses a protobuf container for "ip", "path", and "token".
void BM_PolicySymbolicProto(benchmark::State& state) { google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( !(request.ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && ((request.path.startsWith("v1") && request.token in ["v1", "v2", "admin"]) || (request.path.startsWith("v2") && request.token in ["v2", "admin"]) || (request.path.startsWith("/admin") && request.token == "admin" && request.ip in ["10.0.1.1", "10.0.1.2", "10.0.1.3"]) ))cel")); InterpreterOptions options = GetOptions(arena); auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); SourceInfo source_info; ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression( &parsed_expr.expr(), &source_info)); Activation activation; RequestContext request; request.set_ip(kIP); request.set_path(kPath); request.set_token(kToken); activation.InsertValue("request", CelProtoWrapper::CreateMessage(&request, &arena)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.BoolOrDie()); } } BENCHMARK(BM_PolicySymbolicProto); // This expression has no equivalent CEL constexpr char kListSum[] = R"( id: 1 comprehension_expr: < accu_var: "__result__" iter_var: "x" iter_range: < id: 2 ident_expr: < name: "list_var" > > accu_init: < id: 3 const_expr: < int64_value: 0 > > loop_step: < id: 4 call_expr: < function: "_+_" args: < id: 5 ident_expr: < name: "__result__" > > args: < id: 6 ident_expr: < name: "x" > > > > loop_condition: < id: 7 const_expr: < bool_value: true > > result: < id: 8 ident_expr: < name: "__result__" > > >)"; void BM_Comprehension(benchmark::State& state) { google::protobuf::Arena arena; Expr expr; Activation activation; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kListSum, &expr)); int len = state.range(0); std::vector list; list.reserve(len); for (int i = 0; i < len; i++) { list.push_back(CelValue::CreateInt64(1)); } 
ContainerBackedListImpl cel_list(std::move(list)); activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); InterpreterOptions options = GetOptions(arena); options.comprehension_max_iterations = 10000000; auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr, nullptr)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); ASSERT_EQ(result.Int64OrDie(), len); } } BENCHMARK(BM_Comprehension)->Range(1, 1 << 20); void BM_Comprehension_Trace(benchmark::State& state) { google::protobuf::Arena arena; Expr expr; Activation activation; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kListSum, &expr)); int len = state.range(0); std::vector list; list.reserve(len); for (int i = 0; i < len; i++) { list.push_back(CelValue::CreateInt64(1)); } ContainerBackedListImpl cel_list(std::move(list)); activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); InterpreterOptions options = GetOptions(arena); options.enable_recursive_tracing = true; options.comprehension_max_iterations = 10000000; auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr, nullptr)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Trace(activation, &arena, EmptyCallback)); ASSERT_TRUE(result.IsInt64()); ASSERT_EQ(result.Int64OrDie(), len); } } BENCHMARK(BM_Comprehension_Trace)->Range(1, 1 << 20); void BM_HasMap(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("has(request.path) && !has(request.ip)")); InterpreterOptions options = GetOptions(arena); auto builder = CreateCelExpressionBuilder(options); 
ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), nullptr)); std::vector> map_pairs{ {CelValue::CreateStringView("path"), CelValue::CreateStringView("path")}}; auto cel_map = CreateContainerBackedMap(absl::Span>( map_pairs.data(), map_pairs.size())); activation.InsertValue("request", CelValue::CreateMap((*cel_map).get())); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); } } BENCHMARK(BM_HasMap); void BM_HasProto(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("has(request.path) && !has(request.ip)")); InterpreterOptions options = GetOptions(arena); auto builder = CreateCelExpressionBuilder(options); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry(), options); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), nullptr)); RequestContext request; request.set_path(kPath); request.set_token(kToken); activation.InsertValue("request", CelProtoWrapper::CreateMessage(&request, &arena)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); } } BENCHMARK(BM_HasProto); void BM_HasProtoMap(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("has(request.headers.create_time) && " "!has(request.headers.update_time)")); InterpreterOptions options = GetOptions(arena); auto builder = CreateCelExpressionBuilder(options); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry(), options); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), nullptr)); RequestContext request; 
request.mutable_headers()->insert({"create_time", "2021-01-01"}); activation.InsertValue("request", CelProtoWrapper::CreateMessage(&request, &arena)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); } } BENCHMARK(BM_HasProtoMap); void BM_ReadProtoMap(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( request.headers.create_time == "2021-01-01" )cel")); InterpreterOptions options = GetOptions(arena); auto builder = CreateCelExpressionBuilder(options); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry(), options); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), nullptr)); RequestContext request; request.mutable_headers()->insert({"create_time", "2021-01-01"}); activation.InsertValue("request", CelProtoWrapper::CreateMessage(&request, &arena)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); } } BENCHMARK(BM_ReadProtoMap); void BM_NestedProtoFieldRead(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( !request.a.b.c.d.e )cel")); InterpreterOptions options = GetOptions(arena); auto builder = CreateCelExpressionBuilder(options); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry(), options); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), nullptr)); RequestContext request; request.mutable_a()->mutable_b()->mutable_c()->mutable_d()->set_e(false); activation.InsertValue("request", CelProtoWrapper::CreateMessage(&request, &arena)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); 
ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); } } BENCHMARK(BM_NestedProtoFieldRead); void BM_NestedProtoFieldReadDefaults(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( !request.a.b.c.d.e )cel")); InterpreterOptions options = GetOptions(arena); auto builder = CreateCelExpressionBuilder(options); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry(), options); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), nullptr)); RequestContext request; activation.InsertValue("request", CelProtoWrapper::CreateMessage(&request, &arena)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); } } BENCHMARK(BM_NestedProtoFieldReadDefaults); void BM_ProtoStructAccess(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( has(request.auth.claims.iss) && request.auth.claims.iss == 'accounts.google.com' )cel")); InterpreterOptions options = GetOptions(arena); auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), nullptr)); AttributeContext::Request request; auto* auth = request.mutable_auth(); (*auth->mutable_claims()->mutable_fields())["iss"].set_string_value( "accounts.google.com"); activation.InsertValue("request", CelProtoWrapper::CreateMessage(&request, &arena)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); } } BENCHMARK(BM_ProtoStructAccess); void BM_ProtoListAccess(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; 
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(R"cel( "//.../accessLevels/MY_LEVEL_4" in request.auth.access_levels )cel")); InterpreterOptions options = GetOptions(arena); auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&parsed_expr.expr(), nullptr)); AttributeContext::Request request; auto* auth = request.mutable_auth(); auth->add_access_levels("//.../accessLevels/MY_LEVEL_0"); auth->add_access_levels("//.../accessLevels/MY_LEVEL_1"); auth->add_access_levels("//.../accessLevels/MY_LEVEL_2"); auth->add_access_levels("//.../accessLevels/MY_LEVEL_3"); auth->add_access_levels("//.../accessLevels/MY_LEVEL_4"); activation.InsertValue("request", CelProtoWrapper::CreateMessage(&request, &arena)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.BoolOrDie()); } } BENCHMARK(BM_ProtoListAccess); // This expression has no equivalent CEL expression. 
// Sum a square with a nested comprehension constexpr char kNestedListSum[] = R"( id: 1 comprehension_expr: < accu_var: "__result__" iter_var: "x" iter_range: < id: 2 ident_expr: < name: "list_var" > > accu_init: < id: 3 const_expr: < int64_value: 0 > > loop_step: < id: 4 call_expr: < function: "_+_" args: < id: 5 ident_expr: < name: "__result__" > > args: < id: 6 comprehension_expr: < accu_var: "__result__" iter_var: "x" iter_range: < id: 9 ident_expr: < name: "list_var" > > accu_init: < id: 10 const_expr: < int64_value: 0 > > loop_step: < id: 11 call_expr: < function: "_+_" args: < id: 12 ident_expr: < name: "__result__" > > args: < id: 13 ident_expr: < name: "x" > > > > loop_condition: < id: 14 const_expr: < bool_value: true > > result: < id: 15 ident_expr: < name: "__result__" > > > > > > loop_condition: < id: 7 const_expr: < bool_value: true > > result: < id: 8 ident_expr: < name: "__result__" > > >)"; void BM_NestedComprehension(benchmark::State& state) { google::protobuf::Arena arena; Expr expr; Activation activation; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kNestedListSum, &expr)); int len = state.range(0); std::vector list; list.reserve(len); for (int i = 0; i < len; i++) { list.push_back(CelValue::CreateInt64(1)); } ContainerBackedListImpl cel_list(std::move(list)); activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); InterpreterOptions options = GetOptions(arena); options.comprehension_max_iterations = 10000000; auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr, nullptr)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); ASSERT_EQ(result.Int64OrDie(), len * len); } } BENCHMARK(BM_NestedComprehension)->Range(1, 1 << 10); void BM_NestedComprehension_Trace(benchmark::State& state) { 
google::protobuf::Arena arena; Expr expr; Activation activation; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kNestedListSum, &expr)); int len = state.range(0); std::vector list; list.reserve(len); for (int i = 0; i < len; i++) { list.push_back(CelValue::CreateInt64(1)); } ContainerBackedListImpl cel_list(std::move(list)); activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); InterpreterOptions options = GetOptions(arena); options.comprehension_max_iterations = 10000000; options.enable_comprehension_list_append = true; options.enable_recursive_tracing = true; auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr, nullptr)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Trace(activation, &arena, EmptyCallback)); ASSERT_TRUE(result.IsInt64()); ASSERT_EQ(result.Int64OrDie(), len * len); } } BENCHMARK(BM_NestedComprehension_Trace)->Range(1, 1 << 10); void BM_ListComprehension(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("list_var.map(x, x * 2)")); int len = state.range(0); std::vector list; list.reserve(len); for (int i = 0; i < len; i++) { list.push_back(CelValue::CreateInt64(1)); } ContainerBackedListImpl cel_list(std::move(list)); activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); InterpreterOptions options = GetOptions(arena); options.comprehension_max_iterations = 10000000; options.enable_comprehension_list_append = true; auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN( auto cel_expr, builder->CreateExpression(&(parsed_expr.expr()), nullptr)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsList()); 
ASSERT_EQ(result.ListOrDie()->size(), len); } } BENCHMARK(BM_ListComprehension)->Range(1, 1 << 16); void BM_ListComprehension_Trace(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("list_var.map(x, x * 2)")); int len = state.range(0); std::vector list; list.reserve(len); for (int i = 0; i < len; i++) { list.push_back(CelValue::CreateInt64(1)); } ContainerBackedListImpl cel_list(std::move(list)); activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); InterpreterOptions options = GetOptions(arena); options.comprehension_max_iterations = 10000000; options.enable_comprehension_list_append = true; options.enable_recursive_tracing = true; auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN( auto cel_expr, builder->CreateExpression(&(parsed_expr.expr()), nullptr)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Trace(activation, &arena, EmptyCallback)); ASSERT_TRUE(result.IsList()); ASSERT_EQ(result.ListOrDie()->size(), len); } } BENCHMARK(BM_ListComprehension_Trace)->Range(1, 1 << 16); void BM_ListComprehension_Opt(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("list_var.map(x, x * 2)")); int len = state.range(0); std::vector list; list.reserve(len); for (int i = 0; i < len; i++) { list.push_back(CelValue::CreateInt64(1)); } ContainerBackedListImpl cel_list(std::move(list)); activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); InterpreterOptions options; options.constant_arena = &arena; options.constant_folding = true; options.comprehension_max_iterations = 10000000; options.enable_comprehension_list_append = true; auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); ASSERT_OK_AND_ASSIGN( 
auto cel_expr, builder->CreateExpression(&(parsed_expr.expr()), nullptr)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsList()); ASSERT_EQ(result.ListOrDie()->size(), len); } } BENCHMARK(BM_ListComprehension_Opt)->Range(1, 1 << 16); void BM_ComprehensionCpp(benchmark::State& state) { Activation activation; int len = state.range(0); std::vector list; list.reserve(len); for (int i = 0; i < len; i++) { list.push_back(CelValue::CreateInt64(1)); } auto op = [&list]() { int sum = 0; for (const auto& value : list) { sum += value.Int64OrDie(); } return sum; }; for (auto _ : state) { int result = op(); ASSERT_EQ(result, len); } } BENCHMARK(BM_ComprehensionCpp)->Range(1, 1 << 20); } // namespace } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: eval/tests/end_to_end_test.cc ================================================ #include #include #include #include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "absl/status/status.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/testutil/test_message.pb.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "testutil/util.h" #include "google/protobuf/text_format.h" namespace google { namespace api { namespace expr { namespace runtime { namespace { using ::absl_testing::StatusIs; using ::cel::expr::Expr; using ::cel::expr::SourceInfo; using ::google::protobuf::Arena; using ::google::protobuf::TextFormat; // Simple one parameter function that records the message argument it receives. 
class RecordArgFunction : public CelFunction { public: explicit RecordArgFunction(const std::string& name, std::vector* output) : CelFunction( CelFunctionDescriptor{name, false, {CelValue::Type::kMessage}}), output_(*output) {} absl::Status Evaluate(absl::Span args, CelValue* result, google::protobuf::Arena* arena) const override { if (args.size() != 1) { return absl::Status(absl::StatusCode::kInvalidArgument, "Bad arguments number"); } output_.push_back(args.at(0)); *result = CelValue::CreateBool(true); return absl::OkStatus(); } std::vector& output_; }; // Simple end-to-end test, which also serves as usage example. TEST(EndToEndTest, SimpleOnePlusOne) { // AST CEL equivalent of "1+var" constexpr char kExpr0[] = R"( call_expr: < function: "_+_" args: < ident_expr: < name: "var" > > args: < const_expr: < int64_value: 1 > > > )"; Expr expr; SourceInfo source_info; TextFormat::ParseFromString(kExpr0, &expr); // Obtain CEL Expression builder. std::unique_ptr builder = CreateCelExpressionBuilder(); // Builtin registration. ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr, &source_info)); Activation activation; // Bind value to "var" parameter. activation.InsertValue("var", CelValue::CreateInt64(1)); Arena arena; // Run evaluation. ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 2); } // Simple end-to-end test, which also serves as usage example. 
TEST(EndToEndTest, EmptyStringCompare) { // AST CEL equivalent of "var.string_value == '' && var.int64_value == 0" constexpr char kExpr0[] = R"( call_expr: < function: "_&&_" args: < call_expr: < function: "_==_" args: < select_expr: < operand: < ident_expr: < name: "var" > > field: "string_value" > > args: < const_expr: < string_value: "" > > > > args: < call_expr: < function: "_==_" args: < select_expr: < operand: < ident_expr: < name: "var" > > field: "int64_value" > > args: < const_expr: < int64_value: 0 > > > > > )"; Expr expr; SourceInfo source_info; TextFormat::ParseFromString(kExpr0, &expr); // Obtain CEL Expression builder. std::unique_ptr builder = CreateCelExpressionBuilder(); // Builtin registration. ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr, &source_info)); Activation activation; // Bind value to "var" parameter. constexpr char kData[] = R"( string_value: "" int64_value: 0 )"; TestMessage data; TextFormat::ParseFromString(kData, &data); Arena arena; activation.InsertValue("var", CelProtoWrapper::CreateMessage(&data, &arena)); // Run evaluation. ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); EXPECT_TRUE(result.BoolOrDie()); } TEST(EndToEndTest, NullLiteral) { // AST CEL equivalent of "Value{null_value: NullValue.NULL_VALUE}" constexpr char kExpr0[] = R"( struct_expr: < message_name: "Value" entries: < field_key: "null_value" value: < select_expr: < operand: < ident_expr: < name: "NullValue" > > field: "NULL_VALUE" > > > > )"; Expr expr; SourceInfo source_info; TextFormat::ParseFromString(kExpr0, &expr); // Obtain CEL Expression builder. std::unique_ptr builder = CreateCelExpressionBuilder(); builder->set_container("google.protobuf"); // Builtin registration. 
ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr, &source_info)); Activation activation; Arena arena; // Run evaluation. ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsNull()); } // Equivalent to 'RecordArg(test_message)' constexpr char kNullMessageHandlingExpr[] = R"pb( id: 1 call_expr: < function: "RecordArg" args: < ident_expr: < name: "test_message" > id: 2 > > )pb"; TEST(EndToEndTest, StrictNullHandling) { InterpreterOptions options; Expr expr; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(kNullMessageHandlingExpr, &expr)); SourceInfo info; auto builder = CreateCelExpressionBuilder(options); std::vector extension_calls; ASSERT_OK(builder->GetRegistry()->Register( std::make_unique("RecordArg", &extension_calls))); ASSERT_OK_AND_ASSIGN(auto expression, builder->CreateExpression(&expr, &info)); Activation activation; google::protobuf::Arena arena; activation.InsertValue("test_message", CelValue::CreateNull()); ASSERT_OK_AND_ASSIGN(CelValue result, expression->Evaluate(activation, &arena)); const CelError* result_value; ASSERT_TRUE(result.GetValue(&result_value)) << result.DebugString(); EXPECT_THAT(*result_value, StatusIs(absl::StatusCode::kUnknown, testing::HasSubstr("No matching overloads"))); } TEST(EndToEndTest, OutOfRangeDurationConstant) { InterpreterOptions options; options.enable_timestamp_duration_overflow_errors = true; Expr expr; // Duration representable in absl::Duration, but out of range for CelValue ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"( call_expr { function: "type" args { const_expr { duration_value { seconds: 28552639587287040 } } } })", &expr)); SourceInfo info; auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN(auto expression, 
builder->CreateExpression(&expr, &info)); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue result, expression->Evaluate(activation, &arena)); const CelError* result_value; ASSERT_TRUE(result.GetValue(&result_value)) << result.DebugString(); EXPECT_THAT(*result_value, StatusIs(absl::StatusCode::kInvalidArgument, testing::HasSubstr("Duration is out of range"))); } } // namespace } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: eval/tests/expression_builder_benchmark_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include #include #include #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" #include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "common/minimal_descriptor_pool.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_options.h" #include "eval/public/cel_type_registry.h" #include "eval/tests/request_context.pb.h" #include "internal/benchmark.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/parser.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { using cel::expr::CheckedExpr; using cel::expr::ParsedExpr; using google::api::expr::parser::Parse; enum BenchmarkParam : int { kDefault = 0, kFoldConstants = 1, kRecursivePlanning = 2, kRecursivePlanningWithConstantFolding = 3, }; std::string LabelForParam(BenchmarkParam param) { switch (param) { case BenchmarkParam::kDefault: return "default"; case BenchmarkParam::kFoldConstants: return "fold_constants"; case BenchmarkParam::kRecursivePlanning: return "recursive_planning"; case BenchmarkParam::kRecursivePlanningWithConstantFolding: return "recursive_planning_with_constant_folding"; } return "unknown"; } void BM_RegisterBuiltins(benchmark::State& state) { for (auto _ : state) { auto builder = CreateCelExpressionBuilder(); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); ASSERT_OK(reg_status); } } BENCHMARK(BM_RegisterBuiltins); InterpreterOptions OptionsForParam(BenchmarkParam param, google::protobuf::Arena& arena) { InterpreterOptions options; switch (param) { case BenchmarkParam::kFoldConstants: case BenchmarkParam::kRecursivePlanningWithConstantFolding: options.constant_arena = &arena; options.constant_folding = true; break; case BenchmarkParam::kDefault: case 
BenchmarkParam::kRecursivePlanning: options.constant_folding = false; break; } switch (param) { case BenchmarkParam::kRecursivePlanning: case BenchmarkParam::kRecursivePlanningWithConstantFolding: options.max_recursion_depth = 48; break; case BenchmarkParam::kDefault: case BenchmarkParam::kFoldConstants: options.max_recursion_depth = 0; break; } return options; } void BM_SymbolicPolicy(benchmark::State& state) { auto param = static_cast(state.range(0)); state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"cel( !(request.ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && ((request.path.startsWith("v1") && request.token in ["v1", "v2", "admin"]) || (request.path.startsWith("v2") && request.token in ["v2", "admin"]) || (request.path.startsWith("/admin") && request.token == "admin" && request.ip in ["10.0.1.1", "10.0.1.2", "10.0.1.3"]) ))cel")); google::protobuf::Arena arena; InterpreterOptions options = OptionsForParam(param, arena); auto builder = CreateCelExpressionBuilder(options); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); ASSERT_OK(reg_status); for (auto _ : state) { ASSERT_OK_AND_ASSIGN( auto expression, builder->CreateExpression(&expr.expr(), &expr.source_info())); arena.Reset(); } } BENCHMARK(BM_SymbolicPolicy) ->Arg(BenchmarkParam::kDefault) ->Arg(BenchmarkParam::kFoldConstants) ->Arg(BenchmarkParam::kRecursivePlanning) ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); absl::StatusOr> MakeBuilderForEnums( absl::string_view container, absl::string_view enum_type, int num_enum_values) { auto builder = CreateCelExpressionBuilder(cel::GetMinimalDescriptorPool(), nullptr, {}); builder->set_container(std::string(container)); CelTypeRegistry* type_registry = builder->GetTypeRegistry(); std::vector enumerators; enumerators.reserve(num_enum_values); for (int i = 0; i < num_enum_values; ++i) { enumerators.push_back( CelTypeRegistry::Enumerator{absl::StrCat("ENUM_VALUE_", i), i}); } 
type_registry->RegisterEnum(enum_type, std::move(enumerators)); CEL_RETURN_IF_ERROR(RegisterBuiltinFunctions(builder->GetRegistry())); return builder; } void BM_EnumResolutionSimple(benchmark::State& state) { static const CelExpressionBuilder* builder = []() { auto builder = MakeBuilderForEnums("", "com.example.TestEnum", 4); ABSL_CHECK_OK(builder.status()); return builder->release(); }(); ASSERT_OK_AND_ASSIGN(auto expr, Parse("com.example.TestEnum.ENUM_VALUE_0")); for (auto _ : state) { ASSERT_OK_AND_ASSIGN( auto expression, builder->CreateExpression(&expr.expr(), &expr.source_info())); benchmark::DoNotOptimize(expression); } } BENCHMARK(BM_EnumResolutionSimple)->ThreadRange(1, 32); void BM_EnumResolutionContainer(benchmark::State& state) { static const CelExpressionBuilder* builder = []() { auto builder = MakeBuilderForEnums("com.example", "com.example.TestEnum", 4); ABSL_CHECK_OK(builder.status()); return builder->release(); }(); ASSERT_OK_AND_ASSIGN(auto expr, Parse("TestEnum.ENUM_VALUE_0")); for (auto _ : state) { ASSERT_OK_AND_ASSIGN( auto expression, builder->CreateExpression(&expr.expr(), &expr.source_info())); benchmark::DoNotOptimize(expression); } } BENCHMARK(BM_EnumResolutionContainer)->ThreadRange(1, 32); void BM_EnumResolution32Candidate(benchmark::State& state) { static const CelExpressionBuilder* builder = []() { auto builder = MakeBuilderForEnums("com.example.foo", "com.example.foo.TestEnum", 8); ABSL_CHECK_OK(builder.status()); return builder->release(); }(); ASSERT_OK_AND_ASSIGN(auto expr, Parse("com.example.foo.TestEnum.ENUM_VALUE_0")); for (auto _ : state) { ASSERT_OK_AND_ASSIGN( auto expression, builder->CreateExpression(&expr.expr(), &expr.source_info())); benchmark::DoNotOptimize(expression); } } BENCHMARK(BM_EnumResolution32Candidate)->ThreadRange(1, 32); void BM_EnumResolution256Candidate(benchmark::State& state) { static const CelExpressionBuilder* builder = []() { auto builder = MakeBuilderForEnums("com.example.foo", 
"com.example.foo.TestEnum", 64); ABSL_CHECK_OK(builder.status()); return builder->release(); }(); ASSERT_OK_AND_ASSIGN(auto expr, Parse("com.example.foo.TestEnum.ENUM_VALUE_0")); for (auto _ : state) { ASSERT_OK_AND_ASSIGN( auto expression, builder->CreateExpression(&expr.expr(), &expr.source_info())); benchmark::DoNotOptimize(expression); } } BENCHMARK(BM_EnumResolution256Candidate)->ThreadRange(1, 32); void BM_NestedComprehension(benchmark::State& state) { auto param = static_cast(state.range(0)); state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( [4, 5, 6].all(x, [1, 2, 3].all(y, x > y) && [7, 8, 9].all(z, x < z)) )")); google::protobuf::Arena arena; InterpreterOptions options = OptionsForParam(param, arena); auto builder = CreateCelExpressionBuilder(options); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); ASSERT_OK(reg_status); for (auto _ : state) { ASSERT_OK_AND_ASSIGN( auto expression, builder->CreateExpression(&expr.expr(), &expr.source_info())); arena.Reset(); } } BENCHMARK(BM_NestedComprehension) ->Arg(BenchmarkParam::kDefault) ->Arg(BenchmarkParam::kFoldConstants) ->Arg(BenchmarkParam::kRecursivePlanning) ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_Comparisons(benchmark::State& state) { auto param = static_cast(state.range(0)); state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( v11 < v12 && v12 < v13 && v21 > v22 && v22 > v23 && v31 == v32 && v32 == v33 && v11 != v12 && v12 != v13 )")); google::protobuf::Arena arena; InterpreterOptions options = OptionsForParam(param, arena); auto builder = CreateCelExpressionBuilder(options); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); ASSERT_OK(reg_status); for (auto _ : state) { ASSERT_OK_AND_ASSIGN( auto expression, builder->CreateExpression(&expr.expr(), &expr.source_info())); arena.Reset(); } } BENCHMARK(BM_Comparisons) ->Arg(BenchmarkParam::kDefault) 
->Arg(BenchmarkParam::kFoldConstants) ->Arg(BenchmarkParam::kRecursivePlanning) ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_ComparisonsConcurrent(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( v11 < v12 && v12 < v13 && v21 > v22 && v22 > v23 && v31 == v32 && v32 == v33 && v11 != v12 && v12 != v13 )")); static const CelExpressionBuilder* builder = [] { InterpreterOptions options; auto builder = CreateCelExpressionBuilder(options); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); ABSL_CHECK_OK(reg_status); return builder.release(); }(); for (auto _ : state) { ASSERT_OK_AND_ASSIGN( auto expression, builder->CreateExpression(&expr.expr(), &expr.source_info())); benchmark::DoNotOptimize(expression); } } BENCHMARK(BM_ComparisonsConcurrent)->ThreadRange(1, 32); void RegexPrecompilationBench(bool enabled, benchmark::State& state) { auto param = static_cast(state.range(0)); state.SetLabel(absl::StrCat(LabelForParam(param), "_", enabled ? "enabled" : "disabled")); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"cel( input_str.matches(r'192\.168\.' + '[0-9]{1,3}' + r'\.' + '[0-9]{1,3}') || input_str.matches(r'10(\.[0-9]{1,3}){3}') )cel")); // Fake a checked expression with enough reference information for the expr // builder to identify the regex as optimize-able. 
CheckedExpr checked_expr; checked_expr.mutable_expr()->Swap(expr.mutable_expr()); checked_expr.mutable_source_info()->Swap(expr.mutable_source_info()); (*checked_expr.mutable_reference_map())[2].add_overload_id("matches_string"); (*checked_expr.mutable_reference_map())[11].add_overload_id("matches_string"); google::protobuf::Arena arena; InterpreterOptions options = OptionsForParam(param, arena); options.enable_regex_precompilation = enabled; auto builder = CreateCelExpressionBuilder(options); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); ASSERT_OK(reg_status); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(auto expression, builder->CreateExpression(&checked_expr)); arena.Reset(); } } void BM_RegexPrecompilationDisabled(benchmark::State& state) { RegexPrecompilationBench(false, state); } BENCHMARK(BM_RegexPrecompilationDisabled) ->Arg(BenchmarkParam::kDefault) ->Arg(BenchmarkParam::kFoldConstants) ->Arg(BenchmarkParam::kRecursivePlanning) ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_RegexPrecompilationEnabled(benchmark::State& state) { RegexPrecompilationBench(true, state); } BENCHMARK(BM_RegexPrecompilationEnabled) ->Arg(BenchmarkParam::kDefault) ->Arg(BenchmarkParam::kFoldConstants) ->Arg(BenchmarkParam::kRecursivePlanning) ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_StringConcat(benchmark::State& state) { auto param = static_cast(state.range(0)); state.SetLabel(LabelForParam(param)); auto size = state.range(1); std::string source = "'1234567890' + '1234567890'"; auto height = static_cast(std::log2(size)); for (int i = 1; i < height; i++) { // Force the parse to be a binary tree, otherwise we can hit // recursion limits. source = absl::StrCat("(", source, " + ", source, ")"); } // add a non const branch to the expression. 
absl::StrAppend(&source, " + identifier"); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(source)); google::protobuf::Arena arena; InterpreterOptions options = OptionsForParam(param, arena); auto builder = CreateCelExpressionBuilder(options); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); ASSERT_OK(reg_status); for (auto _ : state) { ASSERT_OK_AND_ASSIGN( auto expression, builder->CreateExpression(&expr.expr(), &expr.source_info())); arena.Reset(); } } BENCHMARK(BM_StringConcat) ->Args({BenchmarkParam::kDefault, 2}) ->Args({BenchmarkParam::kDefault, 4}) ->Args({BenchmarkParam::kDefault, 8}) ->Args({BenchmarkParam::kDefault, 16}) ->Args({BenchmarkParam::kDefault, 32}) ->Args({BenchmarkParam::kFoldConstants, 2}) ->Args({BenchmarkParam::kFoldConstants, 4}) ->Args({BenchmarkParam::kFoldConstants, 8}) ->Args({BenchmarkParam::kFoldConstants, 16}) ->Args({BenchmarkParam::kFoldConstants, 32}) ->Args({BenchmarkParam::kRecursivePlanning, 2}) ->Args({BenchmarkParam::kRecursivePlanning, 4}) ->Args({BenchmarkParam::kRecursivePlanning, 8}) ->Args({BenchmarkParam::kRecursivePlanning, 16}) ->Args({BenchmarkParam::kRecursivePlanning, 32}) ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 2}) ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 4}) ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 8}) ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 16}) ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 32}); void BM_StringConcat32Concurrent(benchmark::State& state) { std::string source = "'1234567890' + '1234567890'"; auto height = static_cast(std::log2(32)); for (int i = 1; i < height; i++) { // Force the parse to be a binary tree, otherwise we can hit // recursion limits. source = absl::StrCat("(", source, " + ", source, ")"); } // add a non const branch to the expression. 
absl::StrAppend(&source, " + identifier"); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(source)); static const CelExpressionBuilder* builder = [] { InterpreterOptions options; auto builder = CreateCelExpressionBuilder(options); auto reg_status = RegisterBuiltinFunctions(builder->GetRegistry()); ABSL_CHECK_OK(reg_status); return builder.release(); }(); for (auto _ : state) { ASSERT_OK_AND_ASSIGN( auto expression, builder->CreateExpression(&expr.expr(), &expr.source_info())); benchmark::DoNotOptimize(expression); } } BENCHMARK(BM_StringConcat32Concurrent)->ThreadRange(1, 32); } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/tests/memory_safety_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Tests for memory safety using the CEL Evaluator. 
#include #include #include #include #include "cel/expr/syntax.pb.h" #include "google/rpc/context/attribute_context.pb.h" #include "absl/status/status.h" #include "absl/strings/match.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_function_adapter.h" #include "eval/public/cel_options.h" #include "eval/public/testing/matchers.h" #include "internal/testing.h" #include "parser/parser.h" #include "testutil/util.h" #include "google/protobuf/arena.h" #include "google/protobuf/message.h" namespace google::api::expr::runtime { namespace { using ::absl_testing::IsOkAndHolds; using ::cel::expr::ParsedExpr; using ::google::rpc::context::AttributeContext; using testutil::EqualsProto; struct TestCase { std::string name; std::string expression; absl::flat_hash_map activation; test::CelValueMatcher expected_matcher; bool reference_resolver_enabled = false; }; enum Options { kDefault, kExhaustive, kFoldConstants }; using ParamType = std::tuple; std::string TestCaseName(const testing::TestParamInfo& param_info) { const ParamType& param = param_info.param; absl::string_view opt; switch (std::get<1>(param)) { case Options::kDefault: opt = "default"; break; case Options::kExhaustive: opt = "exhaustive"; break; case Options::kFoldConstants: opt = "opt"; break; } return absl::StrCat(std::get<0>(param).name, "_", opt); } class EvaluatorMemorySafetyTest : public testing::TestWithParam { public: EvaluatorMemorySafetyTest() { google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); } protected: const TestCase& GetTestCase() { return std::get<0>(GetParam()); } InterpreterOptions GetOptions() { InterpreterOptions options; options.constant_arena = &arena_; switch (std::get<1>(GetParam())) { case Options::kDefault: options.enable_regex_precompilation = false; 
options.constant_folding = false; options.enable_comprehension_list_append = false; options.enable_comprehension_vulnerability_check = true; options.short_circuiting = true; break; case Options::kExhaustive: options.enable_regex_precompilation = false; options.constant_folding = false; options.enable_comprehension_list_append = false; options.enable_comprehension_vulnerability_check = true; options.short_circuiting = false; break; case Options::kFoldConstants: options.enable_regex_precompilation = true; options.constant_folding = true; options.enable_comprehension_list_append = true; options.enable_comprehension_vulnerability_check = false; options.short_circuiting = true; break; } options.enable_qualified_identifier_rewrites = GetTestCase().reference_resolver_enabled; return options; } google::protobuf::Arena arena_; }; bool IsPrivateIpv4Impl(google::protobuf::Arena* arena, CelValue::StringHolder addr) { // Implementation for demonstration, this is simple but incomplete and // brittle. 
return absl::StartsWith(addr.value(), "192.168.") || absl::StartsWith(addr.value(), "10."); } TEST_P(EvaluatorMemorySafetyTest, Basic) { const auto& test_case = GetTestCase(); InterpreterOptions options = GetOptions(); std::unique_ptr builder = CreateCelExpressionBuilder(options); builder->set_container("google.rpc.context"); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); absl::string_view function_name = "IsPrivate"; if (test_case.reference_resolver_enabled) { function_name = "net.IsPrivate"; } ASSERT_OK((FunctionAdapter::CreateAndRegister( function_name, false, &IsPrivateIpv4Impl, builder->GetRegistry()))); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(test_case.expression)); ASSERT_OK_AND_ASSIGN( auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); Activation activation; for (const auto& [key, value] : test_case.activation) { activation.InsertValue(key, value); } absl::StatusOr got = plan->Evaluate(activation, &arena_); EXPECT_THAT(got, IsOkAndHolds(test_case.expected_matcher)); } // Check no use after free errors if evaluated after AST is freed. 
TEST_P(EvaluatorMemorySafetyTest, NoAstDependency) { const auto& test_case = GetTestCase(); InterpreterOptions options = GetOptions(); std::unique_ptr builder = CreateCelExpressionBuilder(options); builder->set_container("google.rpc.context"); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); absl::string_view function_name = "IsPrivate"; if (test_case.reference_resolver_enabled) { function_name = "net.IsPrivate"; } ASSERT_OK((FunctionAdapter::CreateAndRegister( function_name, false, &IsPrivateIpv4Impl, builder->GetRegistry()))); auto parsed_expr = parser::Parse(test_case.expression); ASSERT_OK(parsed_expr.status()); auto expr = std::make_unique(std::move(parsed_expr).value()); ASSERT_OK_AND_ASSIGN( std::unique_ptr plan, builder->CreateExpression(&expr->expr(), &expr->source_info())); expr.reset(); // ParsedExpr expr freed Activation activation; for (const auto& [key, value] : test_case.activation) { activation.InsertValue(key, value); } absl::StatusOr got = plan->Evaluate(activation, &arena_); EXPECT_THAT(got, IsOkAndHolds(test_case.expected_matcher)); } // TODO(uncreated-issue/25): make expression plan memory safe after builder is freed. // TEST_P(EvaluatorMemorySafetyTest, NoBuilderDependency) INSTANTIATE_TEST_SUITE_P( Expression, EvaluatorMemorySafetyTest, testing::Combine( testing::ValuesIn(std::vector{ { "bool", "(true && false) || x || y == 'test_str'", {{"x", CelValue::CreateBool(false)}, {"y", CelValue::CreateStringView("test_str")}}, test::IsCelBool(true), }, { "const_str", "condition ? 'left_hand_string' : 'right_hand_string'", {{"condition", CelValue::CreateBool(false)}}, test::IsCelString("right_hand_string"), }, { "long_const_string", "condition ? 'left_hand_string' : " "'long_right_hand_string_0123456789'", {{"condition", CelValue::CreateBool(false)}}, test::IsCelString("long_right_hand_string_0123456789"), }, { "computed_string", "(condition ? 
'a.b' : 'b.c') + '.d.e.f'", {{"condition", CelValue::CreateBool(false)}}, test::IsCelString("b.c.d.e.f"), }, { "regex", R"('192.168.128.64'.matches(r'^192\.168\.[0-2]?[0-9]?[0-9]\.[0-2]?[0-9]?[0-9]') )", {}, test::IsCelBool(true), }, { "list_create", "[1, 2, 3, 4, 5, 6][3] == 4", {}, test::IsCelBool(true), }, { "list_create_strings", "['1', '2', '3', '4', '5', '6'][2] == '3'", {}, test::IsCelBool(true), }, { "map_create", "{'1': 'one', '2': 'two'}['2']", {}, test::IsCelString("two"), }, { "struct_create", R"( AttributeContext{ request: AttributeContext.Request{ method: 'GET', path: '/index' }, origin: AttributeContext.Peer{ ip: '10.0.0.1' } } )", {}, test::IsCelMessage(EqualsProto(R"pb( request { method: "GET" path: "/index" } origin { ip: "10.0.0.1" } )pb")), }, {"extension_function", "IsPrivate('8.8.8.8')", {}, test::IsCelBool(false), /*enable_reference_resolver=*/false}, {"namespaced_function", "net.IsPrivate('192.168.0.1')", {}, test::IsCelBool(true), /*enable_reference_resolver=*/true}, { "comprehension", "['abc', 'def', 'ghi', 'jkl'].exists(el, el == 'mno')", {}, test::IsCelBool(false), }, { "comprehension_complex", "['a' + 'b' + 'c', 'd' + 'ef', 'g' + 'hi', 'j' + 'kl']" ".exists(el, el.startsWith('g'))", {}, test::IsCelBool(true), }}), testing::Values(Options::kDefault, Options::kExhaustive, Options::kFoldConstants)), &TestCaseName); } // namespace } // namespace google::api::expr::runtime ================================================ FILE: eval/tests/mock_cel_expression.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_EVAL_TESTS_MOCK_CEL_EXPRESION_H_ #define THIRD_PARTY_CEL_CPP_EVAL_TESTS_MOCK_CEL_EXPRESION_H_ #include #include "absl/status/statusor.h" #include "eval/public/base_activation.h" #include "eval/public/cel_expression.h" #include "internal/testing.h" namespace google::api::expr::runtime { class MockCelExpression : public CelExpression { public: MOCK_METHOD(std::unique_ptr, InitializeState, (google::protobuf::Arena 
* arena), (const, override)); MOCK_METHOD(absl::StatusOr, Evaluate, (const BaseActivation& activation, google::protobuf::Arena* arena), (const, override)); MOCK_METHOD(absl::StatusOr, Evaluate, (const BaseActivation& activation, CelEvaluationState* state), (const, override)); MOCK_METHOD(absl::StatusOr, Trace, (const BaseActivation& activation, google::protobuf::Arena* arena, CelEvaluationListener callback), (const, override)); MOCK_METHOD(absl::StatusOr, Trace, (const BaseActivation& activation, CelEvaluationState* state, CelEvaluationListener callback), (const, override)); }; } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_TESTS_MOCK_CEL_EXPRESION_H_ ================================================ FILE: eval/tests/modern_benchmark_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // General benchmarks for CEL evaluator. 
#include #include #include #include #include #include #include #include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "google/rpc/context/attribute_context.pb.h" #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/container/btree_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_set.h" #include "absl/flags/flag.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "common/allocator.h" #include "common/casting.h" #include "common/native_type.h" #include "common/value.h" #include "eval/tests/request_context.pb.h" #include "extensions/comprehensions_v2_functions.h" #include "extensions/comprehensions_v2_macros.h" #include "extensions/protobuf/runtime_adapter.h" #include "extensions/protobuf/value.h" #include "internal/benchmark.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" #include "parser/macro_registry.h" #include "parser/parser.h" #include "runtime/activation.h" #include "runtime/constant_folding.h" #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" ABSL_FLAG(bool, enable_recursive_planning, false, "enable recursive planning"); namespace cel { namespace { using ::absl_testing::IsOk; using ::absl_testing::IsOkAndHolds; using ::cel::extensions::ProtobufRuntimeAdapter; using ::cel::expr::Expr; using ::cel::expr::ParsedExpr; using ::cel::expr::SourceInfo; using ::google::api::expr::parser::EnrichedParse; using ::google::api::expr::parser::Parse; using ::google::api::expr::runtime::RequestContext; using ::google::rpc::context::AttributeContext; 
RuntimeOptions GetOptions() { RuntimeOptions options; if (absl::GetFlag(FLAGS_enable_recursive_planning)) { options.max_recursion_depth = -1; } return options; } enum class ConstFoldingEnabled { kNo, kYes }; std::unique_ptr StandardRuntimeOrDie( const cel::RuntimeOptions& options, google::protobuf::Arena* arena = nullptr, ConstFoldingEnabled const_folding = ConstFoldingEnabled::kNo) { auto builder = CreateStandardRuntimeBuilder( internal::GetTestingDescriptorPool(), options); ABSL_CHECK_OK(builder.status()); switch (const_folding) { case ConstFoldingEnabled::kNo: break; case ConstFoldingEnabled::kYes: ABSL_CHECK(arena != nullptr); ABSL_CHECK_OK(extensions::EnableConstantFolding(*builder)); break; } auto runtime = std::move(builder).value().Build(); ABSL_CHECK_OK(runtime.status()); return std::move(runtime).value(); } template Value WrapMessageOrDie(const T& message, google::protobuf::Arena* absl_nonnull arena) { auto value = extensions::ProtoMessageToValue( message, internal::GetTestingDescriptorPool(), internal::GetTestingMessageFactory(), arena); ABSL_CHECK_OK(value.status()); return std::move(value).value(); } // Benchmark test // Evaluates cel expression: // '1 + 1 + 1 .... 
+1' static void BM_Eval(benchmark::State& state) { RuntimeOptions options = GetOptions(); auto runtime = StandardRuntimeOrDie(options); int len = state.range(0); Expr root_expr; Expr* cur_expr = &root_expr; for (int i = 0; i < len; i++) { Expr::Call* call = cur_expr->mutable_call_expr(); call->set_function("_+_"); call->add_args()->mutable_const_expr()->set_int64_value(1); cur_expr = call->add_args(); } cur_expr->mutable_const_expr()->set_int64_value(1); SourceInfo source_info; ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, root_expr)); for (auto _ : state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result)); ASSERT_TRUE(Cast(result) == len + 1); } } BENCHMARK(BM_Eval)->Range(1, 10000); absl::Status EmptyCallback(int64_t expr_id, const Value&, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) { return absl::OkStatus(); } // Benchmark test // Traces cel expression with an empty callback: // '1 + 1 + 1 .... 
+1' static void BM_Eval_Trace(benchmark::State& state) { RuntimeOptions options = GetOptions(); options.enable_recursive_tracing = true; auto runtime = StandardRuntimeOrDie(options); int len = state.range(0); Expr root_expr; Expr* cur_expr = &root_expr; for (int i = 0; i < len; i++) { Expr::Call* call = cur_expr->mutable_call_expr(); call->set_function("_+_"); call->add_args()->mutable_const_expr()->set_int64_value(1); cur_expr = call->add_args(); } cur_expr->mutable_const_expr()->set_int64_value(1); SourceInfo source_info; ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, root_expr)); for (auto _ : state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Trace(&arena, activation, EmptyCallback)); ASSERT_TRUE(InstanceOf(result)); ASSERT_TRUE(Cast(result) == len + 1); } } // A number higher than 10k leads to a stack overflow due to the recursive // nature of the proto to native type conversion. BENCHMARK(BM_Eval_Trace)->Range(1, 10000); // Benchmark test // Evaluates cel expression: // '"a" + "a" + "a" .... 
+ "a"' static void BM_EvalString(benchmark::State& state) { RuntimeOptions options = GetOptions(); auto runtime = StandardRuntimeOrDie(options); int len = state.range(0); Expr root_expr; Expr* cur_expr = &root_expr; for (int i = 0; i < len; i++) { Expr::Call* call = cur_expr->mutable_call_expr(); call->set_function("_+_"); call->add_args()->mutable_const_expr()->set_string_value("a"); cur_expr = call->add_args(); } cur_expr->mutable_const_expr()->set_string_value("a"); SourceInfo source_info; ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, root_expr)); for (auto _ : state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result)); ASSERT_TRUE(Cast(result).Size() == len + 1); } } // A number higher than 10k leads to a stack overflow due to the recursive // nature of the proto to native type conversion. BENCHMARK(BM_EvalString)->Range(1, 10000); // Benchmark test // Traces cel expression with an empty callback: // '"a" + "a" + "a" .... 
+ "a"' static void BM_EvalString_Trace(benchmark::State& state) { RuntimeOptions options = GetOptions(); options.enable_recursive_tracing = true; auto runtime = StandardRuntimeOrDie(options); int len = state.range(0); Expr root_expr; Expr* cur_expr = &root_expr; for (int i = 0; i < len; i++) { Expr::Call* call = cur_expr->mutable_call_expr(); call->set_function("_+_"); call->add_args()->mutable_const_expr()->set_string_value("a"); cur_expr = call->add_args(); } cur_expr->mutable_const_expr()->set_string_value("a"); SourceInfo source_info; ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, root_expr)); for (auto _ : state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Trace(&arena, activation, EmptyCallback)); ASSERT_TRUE(InstanceOf(result)); ASSERT_TRUE(Cast(result).Size() == len + 1); } } // A number higher than 10k leads to a stack overflow due to the recursive // nature of the proto to native type conversion. 
BENCHMARK(BM_EvalString_Trace)->Range(1, 10000); const char kIP[] = "10.0.1.2"; const char kPath[] = "/admin/edit"; const char kToken[] = "admin"; ABSL_ATTRIBUTE_NOINLINE bool NativeCheck(absl::btree_map& attributes, const absl::flat_hash_set& denylists, const absl::flat_hash_set& allowlists) { auto& ip = attributes["ip"]; auto& path = attributes["path"]; auto& token = attributes["token"]; if (denylists.find(ip) != denylists.end()) { return false; } if (absl::StartsWith(path, "v1")) { if (token == "v1" || token == "v2" || token == "admin") { return true; } } else if (absl::StartsWith(path, "v2")) { if (token == "v2" || token == "admin") { return true; } } else if (absl::StartsWith(path, "/admin")) { if (token == "admin") { if (allowlists.find(ip) != allowlists.end()) { return true; } } } return false; } void BM_PolicyNative(benchmark::State& state) { const auto denylists = absl::flat_hash_set{"10.0.1.4", "10.0.1.5", "10.0.1.6"}; const auto allowlists = absl::flat_hash_set{"10.0.1.1", "10.0.1.2", "10.0.1.3"}; auto attributes = absl::btree_map{ {"ip", kIP}, {"token", kToken}, {"path", kPath}}; for (auto _ : state) { auto result = NativeCheck(attributes, denylists, allowlists); ASSERT_TRUE(result); } } BENCHMARK(BM_PolicyNative); void BM_PolicySymbolic(benchmark::State& state) { google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(R"cel( !(ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && ((path.startsWith("v1") && token in ["v1", "v2", "admin"]) || (path.startsWith("v2") && token in ["v2", "admin"]) || (path.startsWith("/admin") && token == "admin" && ip in [ "10.0.1.1", "10.0.1.2", "10.0.1.3" ]) ))cel")); RuntimeOptions options = GetOptions(); auto runtime = StandardRuntimeOrDie(options, &arena, ConstFoldingEnabled::kYes); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); Activation activation; activation.InsertOrAssignValue("ip", StringValue(&arena, kIP)); 
activation.InsertOrAssignValue("path", StringValue(&arena, kPath)); activation.InsertOrAssignValue("token", StringValue(&arena, kToken)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); auto result_bool = As(result); ASSERT_TRUE(result_bool && result_bool->NativeValue()); } } BENCHMARK(BM_PolicySymbolic); class RequestMapImpl : public CustomMapValueInterface { public: size_t Size() const override { return 3; } absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const override { return absl::UnimplementedError("Unsupported"); } absl::StatusOr NewIterator() const override { return absl::UnimplementedError("Unsupported"); } std::string DebugString() const override { return "RequestMapImpl"; } absl::Status ConvertToJsonObject( const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Message* absl_nonnull) const override { return absl::UnimplementedError("Unsupported"); } CustomMapValue Clone(google::protobuf::Arena* absl_nonnull arena) const override { return CustomMapValue(google::protobuf::Arena::Create(arena), arena); } protected: // Called by `Find` after performing various argument checks. 
absl::StatusOr Find( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const override { auto string_value = As(key); if (!string_value) { return false; } if (string_value->Equals("ip")) { *result = StringValue(kIP); } else if (string_value->Equals("path")) { *result = StringValue(kPath); } else if (string_value->Equals("token")) { *result = StringValue(kToken); } else { return false; } return true; } // Called by `Has` after performing various argument checks. absl::StatusOr Has( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const override { return absl::UnimplementedError("Unsupported."); } private: NativeTypeId GetNativeTypeId() const override { return NativeTypeId::For(); } }; // Uses a lazily constructed map container for "ip", "path", and "token". 
void BM_PolicySymbolicMap(benchmark::State& state) {
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(R"cel( !(request.ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && ((request.path.startsWith("v1") && request.token in ["v1", "v2", "admin"]) || (request.path.startsWith("v2") && request.token in ["v2", "admin"]) || (request.path.startsWith("/admin") && request.token == "admin" && request.ip in ["10.0.1.1", "10.0.1.2", "10.0.1.3"]) ))cel"));

  RuntimeOptions options = GetOptions();
  auto runtime = StandardRuntimeOrDie(options);
  ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram(
                                          *runtime, parsed_expr));

  // "request" is bound to a RequestMapImpl, so each `request.<field>` access
  // goes through RequestMapImpl::Find().
  Activation activation;
  CustomMapValue map_value(google::protobuf::Arena::Create<RequestMapImpl>(&arena),
                           &arena);
  activation.InsertOrAssignValue("request", std::move(map_value));

  for (auto _ : state) {
    ASSERT_OK_AND_ASSIGN(cel::Value result,
                         cel_expr->Evaluate(&arena, activation));
    ASSERT_TRUE(InstanceOf<BoolValue>(result) &&
                Cast<BoolValue>(result).NativeValue());
  }
}

BENCHMARK(BM_PolicySymbolicMap);

// Uses a protobuf container for "ip", "path", and "token".
void BM_PolicySymbolicProto(benchmark::State& state) { google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(R"cel( !(request.ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && ((request.path.startsWith("v1") && request.token in ["v1", "v2", "admin"]) || (request.path.startsWith("v2") && request.token in ["v2", "admin"]) || (request.path.startsWith("/admin") && request.token == "admin" && request.ip in ["10.0.1.1", "10.0.1.2", "10.0.1.3"]) ))cel")); RuntimeOptions options = GetOptions(); auto runtime = StandardRuntimeOrDie(options); SourceInfo source_info; ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); Activation activation; RequestContext request; request.set_ip(kIP); request.set_path(kPath); request.set_token(kToken); activation.InsertOrAssignValue("request", WrapMessageOrDie(request, &arena)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } } BENCHMARK(BM_PolicySymbolicProto); // This expression has no equivalent CEL constexpr char kListSum[] = R"( id: 1 comprehension_expr: < accu_var: "__result__" iter_var: "x" iter_range: < id: 2 ident_expr: < name: "list_var" > > accu_init: < id: 3 const_expr: < int64_value: 0 > > loop_step: < id: 4 call_expr: < function: "_+_" args: < id: 5 ident_expr: < name: "__result__" > > args: < id: 6 ident_expr: < name: "x" > > > > loop_condition: < id: 7 const_expr: < bool_value: true > > result: < id: 8 ident_expr: < name: "__result__" > > >)"; void BM_Comprehension(benchmark::State& state) { RuntimeOptions options = GetOptions(); options.comprehension_max_iterations = 10000000; auto runtime = StandardRuntimeOrDie(options); Expr expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kListSum, &expr)); google::protobuf::Arena arena; Activation activation; auto list_builder = cel::NewListValueBuilder(&arena); int len = 
state.range(0); list_builder->Reserve(len); for (int i = 0; i < len; i++) { ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); } activation.InsertOrAssignValue("list_var", std::move(*list_builder).Build()); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result)); ASSERT_EQ(Cast(result), len); } } BENCHMARK(BM_Comprehension)->Range(1, 1 << 20); void BM_Comprehension_Trace(benchmark::State& state) { RuntimeOptions options = GetOptions(); options.enable_recursive_tracing = true; options.comprehension_max_iterations = 10000000; auto runtime = StandardRuntimeOrDie(options); google::protobuf::Arena arena; Expr expr; Activation activation; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kListSum, &expr)); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); auto list_builder = cel::NewListValueBuilder(&arena); int len = state.range(0); list_builder->Reserve(len); for (int i = 0; i < len; i++) { ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); } activation.InsertOrAssignValue("list_var", std::move(*list_builder).Build()); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Trace(&arena, activation, EmptyCallback)); ASSERT_TRUE(InstanceOf(result)); ASSERT_EQ(Cast(result), len); } } BENCHMARK(BM_Comprehension_Trace)->Range(1, 1 << 20); void BM_HasMap(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("has(request.path) && !has(request.ip)")); RuntimeOptions options = GetOptions(); auto runtime = StandardRuntimeOrDie(options); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); auto map_builder = cel::NewMapValueBuilder(&arena); ASSERT_THAT( map_builder->Put(cel::StringValue("path"), 
cel::StringValue("path")), IsOk()); activation.InsertOrAssignValue("request", std::move(*map_builder).Build()); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } } BENCHMARK(BM_HasMap); void BM_HasProto(benchmark::State& state) { RuntimeOptions options = GetOptions(); auto runtime = StandardRuntimeOrDie(options); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("has(request.path) && !has(request.ip)")); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); google::protobuf::Arena arena; Activation activation; RequestContext request; request.set_path(kPath); request.set_token(kToken); activation.InsertOrAssignValue("request", WrapMessageOrDie(request, &arena)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } } BENCHMARK(BM_HasProto); void BM_HasProtoMap(benchmark::State& state) { RuntimeOptions options = GetOptions(); auto runtime = StandardRuntimeOrDie(options); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("has(request.headers.create_time) && " "!has(request.headers.update_time)")); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); google::protobuf::Arena arena; Activation activation; RequestContext request; request.mutable_headers()->insert({"create_time", "2021-01-01"}); activation.InsertOrAssignValue("request", WrapMessageOrDie(request, &arena)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } } BENCHMARK(BM_HasProtoMap); void BM_ReadProtoMap(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(R"cel( request.headers.create_time == "2021-01-01" )cel")); RuntimeOptions options = 
GetOptions(); auto runtime = StandardRuntimeOrDie(options); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); google::protobuf::Arena arena; Activation activation; RequestContext request; request.mutable_headers()->insert({"create_time", "2021-01-01"}); activation.InsertOrAssignValue("request", WrapMessageOrDie(request, &arena)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } } BENCHMARK(BM_ReadProtoMap); void BM_NestedProtoFieldRead(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(R"cel( !request.a.b.c.d.e )cel")); RuntimeOptions options = GetOptions(); auto runtime = StandardRuntimeOrDie(options); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); google::protobuf::Arena arena; Activation activation; RequestContext request; request.mutable_a()->mutable_b()->mutable_c()->mutable_d()->set_e(false); activation.InsertOrAssignValue("request", WrapMessageOrDie(request, &arena)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } } BENCHMARK(BM_NestedProtoFieldRead); void BM_NestedProtoFieldReadDefaults(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(R"cel( !request.a.b.c.d.e )cel")); RuntimeOptions options = GetOptions(); auto runtime = StandardRuntimeOrDie(options); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); google::protobuf::Arena arena; Activation activation; RequestContext request; activation.InsertOrAssignValue("request", WrapMessageOrDie(request, &arena)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && 
Cast(result).NativeValue()); } } BENCHMARK(BM_NestedProtoFieldReadDefaults); void BM_ProtoStructAccess(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(R"cel( has(request.auth.claims.iss) && request.auth.claims.iss == 'accounts.google.com' )cel")); RuntimeOptions options = GetOptions(); auto runtime = StandardRuntimeOrDie(options); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); google::protobuf::Arena arena; Activation activation; AttributeContext::Request request; auto* auth = request.mutable_auth(); (*auth->mutable_claims()->mutable_fields())["iss"].set_string_value( "accounts.google.com"); activation.InsertOrAssignValue("request", WrapMessageOrDie(request, &arena)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } } BENCHMARK(BM_ProtoStructAccess); void BM_ProtoListAccess(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(R"cel( "//.../accessLevels/MY_LEVEL_4" in request.auth.access_levels )cel")); RuntimeOptions options = GetOptions(); auto runtime = StandardRuntimeOrDie(options); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); google::protobuf::Arena arena; Activation activation; AttributeContext::Request request; auto* auth = request.mutable_auth(); auth->add_access_levels("//.../accessLevels/MY_LEVEL_0"); auth->add_access_levels("//.../accessLevels/MY_LEVEL_1"); auth->add_access_levels("//.../accessLevels/MY_LEVEL_2"); auth->add_access_levels("//.../accessLevels/MY_LEVEL_3"); auth->add_access_levels("//.../accessLevels/MY_LEVEL_4"); activation.InsertOrAssignValue("request", WrapMessageOrDie(request, &arena)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result) && Cast(result).NativeValue()); } } 
BENCHMARK(BM_ProtoListAccess); // This expression has no equivalent CEL expression. // Sum a square with a nested comprehension constexpr char kNestedListSum[] = R"( id: 1 comprehension_expr: < accu_var: "__result__" iter_var: "x" iter_range: < id: 2 ident_expr: < name: "list_var" > > accu_init: < id: 3 const_expr: < int64_value: 0 > > loop_step: < id: 4 call_expr: < function: "_+_" args: < id: 5 ident_expr: < name: "__result__" > > args: < id: 6 comprehension_expr: < accu_var: "__result__" iter_var: "x" iter_range: < id: 9 ident_expr: < name: "list_var" > > accu_init: < id: 10 const_expr: < int64_value: 0 > > loop_step: < id: 11 call_expr: < function: "_+_" args: < id: 12 ident_expr: < name: "__result__" > > args: < id: 13 ident_expr: < name: "x" > > > > loop_condition: < id: 14 const_expr: < bool_value: true > > result: < id: 15 ident_expr: < name: "__result__" > > > > > > loop_condition: < id: 7 const_expr: < bool_value: true > > result: < id: 8 ident_expr: < name: "__result__" > > >)"; void BM_NestedComprehension(benchmark::State& state) { Expr expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kNestedListSum, &expr)); RuntimeOptions options = GetOptions(); options.comprehension_max_iterations = 10000000; auto runtime = StandardRuntimeOrDie(options); google::protobuf::Arena arena; Activation activation; auto list_builder = cel::NewListValueBuilder(&arena); int len = state.range(0); list_builder->Reserve(len); for (int i = 0; i < len; i++) { ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); } activation.InsertOrAssignValue("list_var", std::move(*list_builder).Build()); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result)); ASSERT_EQ(Cast(result), len * len); } } BENCHMARK(BM_NestedComprehension)->Range(1, 1 << 10); void BM_NestedComprehension_Trace(benchmark::State& state) { 
Expr expr; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kNestedListSum, &expr)); RuntimeOptions options = GetOptions(); options.comprehension_max_iterations = 10000000; options.enable_comprehension_list_append = true; options.enable_recursive_tracing = true; auto runtime = StandardRuntimeOrDie(options); google::protobuf::Arena arena; Activation activation; auto list_builder = cel::NewListValueBuilder(&arena); int len = state.range(0); list_builder->Reserve(len); for (int i = 0; i < len; i++) { ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); } activation.InsertOrAssignValue("list_var", std::move(*list_builder).Build()); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Trace(&arena, activation, &EmptyCallback)); ASSERT_TRUE(InstanceOf(result)); ASSERT_EQ(Cast(result), len * len); } } BENCHMARK(BM_NestedComprehension_Trace)->Range(1, 1 << 10); void BM_ListComprehension(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("list_var.map(x, x * 2)")); RuntimeOptions options = GetOptions(); options.comprehension_max_iterations = 10000000; options.enable_comprehension_list_append = true; auto runtime = StandardRuntimeOrDie(options); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); google::protobuf::Arena arena; Activation activation; auto list_builder = cel::NewListValueBuilder(&arena); int len = state.range(0); list_builder->Reserve(len); for (int i = 0; i < len; i++) { ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); } activation.InsertOrAssignValue("list_var", std::move(*list_builder).Build()); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result)); ASSERT_THAT(Cast(result).Size(), IsOkAndHolds(len)); } } BENCHMARK(BM_ListComprehension)->Range(1, 1 << 16); void 
BM_ListComprehension_Trace(benchmark::State& state) { google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("list_var.map(x, x * 2)")); RuntimeOptions options = GetOptions(); options.comprehension_max_iterations = 10000000; options.enable_comprehension_list_append = true; options.enable_recursive_tracing = true; auto runtime = StandardRuntimeOrDie(options); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); Activation activation; auto list_builder = cel::NewListValueBuilder(&arena); int len = state.range(0); list_builder->Reserve(len); for (int i = 0; i < len; i++) { ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); } activation.InsertOrAssignValue("list_var", std::move(*list_builder).Build()); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Trace(&arena, activation, EmptyCallback)); ASSERT_TRUE(InstanceOf(result)); ASSERT_THAT(Cast(result).Size(), IsOkAndHolds(len)); } } BENCHMARK(BM_ListComprehension_Trace)->Range(1, 1 << 16); void BM_ExistsComprehensionBestCase(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("my_int_list.exists(x, x == 1)")); RuntimeOptions options = GetOptions(); auto runtime = StandardRuntimeOrDie(options); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); google::protobuf::Arena arena; Activation activation; auto list_builder = cel::NewListValueBuilder(&arena); ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); activation.InsertOrAssignValue("my_int_list", std::move(*list_builder).Build()); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.GetBool().NativeValue()); } } BENCHMARK(BM_ExistsComprehensionBestCase); void BM_ExistsComprehensionWorstCase(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("my_int_list.exists(x, x == 
-1)")); RuntimeOptions options = GetOptions(); auto runtime = StandardRuntimeOrDie(options); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); google::protobuf::Arena arena; Activation activation; auto list_builder = cel::NewListValueBuilder(&arena); int len = state.range(0); list_builder->Reserve(len); for (int i = 0; i < len; i++) { ASSERT_THAT(list_builder->Add(IntValue(i)), IsOk()); } activation.InsertOrAssignValue("my_int_list", std::move(*list_builder).Build()); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.GetBool().NativeValue()); } } BENCHMARK(BM_ExistsComprehensionWorstCase)->Range(1, 1 << 10); void BM_AllComprehensionBestCase(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("my_int_list.exists(x, x != 1)")); RuntimeOptions options = GetOptions(); auto runtime = StandardRuntimeOrDie(options); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); google::protobuf::Arena arena; Activation activation; auto list_builder = cel::NewListValueBuilder(&arena); ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); activation.InsertOrAssignValue("my_int_list", std::move(*list_builder).Build()); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(result.IsBool()); ASSERT_FALSE(result.GetBool().NativeValue()); } } BENCHMARK(BM_AllComprehensionBestCase); void BM_AllComprehensionWorstCase(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("my_int_list.all(x, x != -1)")); RuntimeOptions options = GetOptions(); auto runtime = StandardRuntimeOrDie(options); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); google::protobuf::Arena arena; Activation activation; auto list_builder = 
cel::NewListValueBuilder(&arena); int len = state.range(0); list_builder->Reserve(len); for (int i = 0; i < len; i++) { ASSERT_THAT(list_builder->Add(IntValue(i)), IsOk()); } activation.InsertOrAssignValue("my_int_list", std::move(*list_builder).Build()); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(result.IsBool()); ASSERT_TRUE(result.GetBool().NativeValue()); } } BENCHMARK(BM_AllComprehensionWorstCase)->Range(1, 1 << 10); void BM_ListComprehension_Opt(benchmark::State& state) { google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("list_var.map(x, x * 2)")); RuntimeOptions options = GetOptions(); options.comprehension_max_iterations = 10000000; options.enable_comprehension_list_append = true; auto runtime = StandardRuntimeOrDie(options, &arena, ConstFoldingEnabled::kYes); Activation activation; auto list_builder = cel::NewListValueBuilder(&arena); int len = state.range(0); list_builder->Reserve(len); for (int i = 0; i < len; i++) { ASSERT_THAT(list_builder->Add(IntValue(1)), IsOk()); } activation.InsertOrAssignValue("list_var", std::move(*list_builder).Build()); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result)); ASSERT_THAT(Cast(result).Size(), IsOkAndHolds(len)); } } BENCHMARK(BM_ListComprehension_Opt)->Range(1, 1 << 16); void BM_ComprehensionCpp(benchmark::State& state) { Activation activation; std::vector list; int len = state.range(0); list.reserve(len); for (int i = 0; i < len; i++) { list.push_back(IntValue(1)); } auto op = [&list]() { int sum = 0; for (const auto& value : list) { sum += Cast(value).NativeValue(); } return sum; }; for (auto _ : state) { int result = op(); ASSERT_EQ(result, len); } } BENCHMARK(BM_ComprehensionCpp)->Range(1, 1 << 20); void 
BM_MapTransformComprehension(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(auto source, NewSource("map_var.transformMapEntry(k, v, {v:k})")); MacroRegistry registry; ASSERT_THAT( extensions::RegisterComprehensionsV2Macros(registry, ParserOptions()), IsOk()); ASSERT_OK_AND_ASSIGN(auto parsed_expr, EnrichedParse(*source, registry, ParserOptions())); RuntimeOptions options = GetOptions(); options.comprehension_max_iterations = 10000000; // This is a critical optimization: it allows the comprehension to accumulate // results in a mutable map instead of cloning and augmenting an unmodifiable // map on every iteration. options.enable_comprehension_mutable_map = true; ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( internal::GetTestingDescriptorPool(), options)); ASSERT_THAT(extensions::RegisterComprehensionsV2Functions( builder.function_registry(), options), IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); google::protobuf::Arena arena; Activation activation; auto map_builder = cel::NewMapValueBuilder(&arena); int len = state.range(0); map_builder->Reserve(len); for (int i = 0; i < len; i++) { ASSERT_THAT(map_builder->Put(IntValue(i), IntValue(i)), IsOk()); } activation.InsertOrAssignValue("map_var", std::move(*map_builder).Build()); ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram( *runtime, parsed_expr.parsed_expr())); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(cel::Value result, cel_expr->Evaluate(&arena, activation)); ASSERT_TRUE(InstanceOf(result)); ASSERT_THAT(Cast(result).Size(), IsOkAndHolds(len)); } } BENCHMARK(BM_MapTransformComprehension)->Range(1, 1 << 16); } // namespace } // namespace cel ================================================ FILE: eval/tests/request_context.proto ================================================ syntax = "proto3"; package google.api.expr.runtime; option cc_enable_arenas = true; // Message representing a sample request context message RequestContext { // Example 
for deeply nested messages. message D { bool e = 1; } message C { D d = 1; } message B { C c = 1; } message A { B b = 1; } string ip = 1; string path = 2; string token = 3; map headers = 4; A a = 5; } ================================================ FILE: eval/tests/unknowns_end_to_end_test.cc ================================================ // Integration tests for unknown processing in the C++ CEL runtime. The // semantics of some of the tested expressions can be complicated because isn't // possible to represent unknown values or errors directly in CEL -- declaring // the unknowns is particular to the runtime. #include #include #include #include #include "cel/expr/syntax.pb.h" #include "google/protobuf/struct.pb.h" #include "absl/base/no_destructor.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "base/attribute.h" #include "base/function_result.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_function.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/unknown_set.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/parser.h" #include "runtime/internal/activation_attribute_matcher_access.h" #include "runtime/internal/attribute_matcher.h" #include "google/protobuf/arena.h" #include "google/protobuf/text_format.h" namespace google { namespace api { namespace expr { namespace runtime { namespace { using ::absl_testing::IsOk; using ::cel::runtime_internal::ActivationAttributeMatcherAccess; using ::cel::expr::Expr; using 
::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::protobuf::Arena; using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; absl::StatusOr MakeCelMap(absl::string_view expr, google::protobuf::Arena* arena) { static CelExpressionBuilder* builder = []() { return CreateCelExpressionBuilder(InterpreterOptions()).release(); }(); static absl::NoDestructor activation; CEL_ASSIGN_OR_RETURN(ParsedExpr parsed_expr, Parse(expr)); CEL_ASSIGN_OR_RETURN(auto plan, builder->CreateExpression(&parsed_expr.expr(), nullptr)); absl::StatusOr result = plan->Evaluate(*activation, arena); if (!result.ok()) { return result.status(); } if (!result->IsMap()) { return absl::FailedPreconditionError( absl::StrCat("expression did not evaluate to a map: ", expr)); } return result; } enum class FunctionResponse { kUnknown, kTrue, kFalse }; CelFunctionDescriptor CreateDescriptor( absl::string_view name, CelValue::Type type = CelValue::Type::kString) { return CelFunctionDescriptor(std::string(name), false, {type}); } class FunctionImpl : public CelFunction { public: FunctionImpl(absl::string_view name, FunctionResponse response, CelValue::Type type = CelValue::Type::kString) : CelFunction(CreateDescriptor(name, type)), response_(response) {} absl::Status Evaluate(absl::Span arguments, CelValue* result, Arena* arena) const override { switch (response_) { case FunctionResponse::kUnknown: *result = CreateUnknownFunctionResultError(arena, "help message"); break; case FunctionResponse::kTrue: *result = CelValue::CreateBool(true); break; case FunctionResponse::kFalse: *result = CelValue::CreateBool(false); break; } return absl::OkStatus(); } private: FunctionResponse response_; }; // Text fixture for unknowns. Holds on to state needed for execution to work // correctly. 
class UnknownsTest : public testing::Test { public: void PrepareBuilder(UnknownProcessingOptions opts) { InterpreterOptions options; options.unknown_processing = opts; builder_ = CreateCelExpressionBuilder(options); ASSERT_THAT(RegisterBuiltinFunctions(builder_->GetRegistry()), IsOk()); ASSERT_THAT( builder_->GetRegistry()->RegisterLazyFunction(CreateDescriptor("F1")), IsOk()); ASSERT_THAT( builder_->GetRegistry()->RegisterLazyFunction(CreateDescriptor("F2")), IsOk()); } protected: Arena arena_; Activation activation_; std::unique_ptr builder_; }; MATCHER_P(FunctionCallIs, fn_name, "") { const cel::FunctionResult& result = arg; return result.descriptor().name() == fn_name; } MATCHER_P(AttributeIs, attr, "") { const cel::Attribute& result = arg; return result.AsString().value_or("") == attr; } TEST_F(UnknownsTest, NoUnknowns) { PrepareBuilder(UnknownProcessingOptions::kDisabled); activation_.InsertValue("var1", CelValue::CreateInt64(3)); activation_.InsertValue("var2", CelValue::CreateInt64(5)); ASSERT_THAT(activation_.InsertFunction(std::make_unique( "F1", FunctionResponse::kFalse)), IsOk()); ASSERT_THAT(activation_.InsertFunction(std::make_unique( "F2", FunctionResponse::kTrue)), IsOk()); ASSERT_OK_AND_ASSIGN( ParsedExpr expr, Parse("var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2')")); auto plan = builder_->CreateExpression(&expr.expr(), nullptr); ASSERT_THAT(plan, IsOk()); ASSERT_OK_AND_ASSIGN(CelValue response, plan.value()->Evaluate(activation_, &arena_)); ASSERT_TRUE(response.IsBool()) << response.DebugString(); EXPECT_TRUE(response.BoolOrDie()); } TEST_F(UnknownsTest, UnknownAttributes) { PrepareBuilder(UnknownProcessingOptions::kAttributeOnly); activation_.set_unknown_attribute_patterns({CelAttributePattern("var1", {})}); activation_.InsertValue("var2", CelValue::CreateInt64(3)); ASSERT_THAT(activation_.InsertFunction(std::make_unique( "F1", FunctionResponse::kTrue)), IsOk()); ASSERT_THAT(activation_.InsertFunction(std::make_unique( "F2", 
FunctionResponse::kFalse)), IsOk()); ASSERT_OK_AND_ASSIGN( ParsedExpr expr, Parse("var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2')")); auto plan = builder_->CreateExpression(&expr.expr(), nullptr); ASSERT_THAT(plan, IsOk()); ASSERT_OK_AND_ASSIGN(CelValue response, plan.value()->Evaluate(activation_, &arena_)); ASSERT_TRUE(response.IsUnknownSet()); EXPECT_THAT(response.UnknownSetOrDie()->unknown_attributes(), ElementsAre(AttributeIs("var1"))); } TEST_F(UnknownsTest, UnknownAttributesPruning) { PrepareBuilder(UnknownProcessingOptions::kAttributeOnly); activation_.set_unknown_attribute_patterns({CelAttributePattern("var1", {})}); activation_.InsertValue("var2", CelValue::CreateInt64(5)); ASSERT_THAT(activation_.InsertFunction(std::make_unique( "F1", FunctionResponse::kTrue)), IsOk()); ASSERT_THAT(activation_.InsertFunction(std::make_unique( "F2", FunctionResponse::kTrue)), IsOk()); ASSERT_OK_AND_ASSIGN( ParsedExpr expr, Parse("var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2')")); auto plan = builder_->CreateExpression(&expr.expr(), nullptr); ASSERT_THAT(plan, IsOk()); ASSERT_OK_AND_ASSIGN(CelValue response, plan.value()->Evaluate(activation_, &arena_)); ASSERT_TRUE(response.IsBool()); EXPECT_TRUE(response.BoolOrDie()); } class CustomMatcher : public cel::runtime_internal::AttributeMatcher { public: MatchResult CheckForUnknown(const cel::Attribute& attr) const override { // Rendering to a string just for ease of testing. 
std::string name = attr.AsString().value_or(""); if (name == "var1") { return MatchResult::PARTIAL; } else if (name == "var1.foo") { return MatchResult::FULL; } return MatchResult::NONE; } }; TEST_F(UnknownsTest, UnknownAttributesCustomMatcher) { PrepareBuilder(UnknownProcessingOptions::kAttributeOnly); ASSERT_OK_AND_ASSIGN(auto var1, MakeCelMap("{'bar': 1}", &arena_)); activation_.InsertValue("var1", var1); CustomMatcher matcher; ActivationAttributeMatcherAccess::SetAttributeMatcher(activation_, &matcher); ASSERT_THAT(activation_.InsertFunction(std::make_unique( "F1", FunctionResponse::kTrue, CelValue::Type::kMap)), IsOk()); ASSERT_THAT(activation_.InsertFunction(std::make_unique( "F2", FunctionResponse::kTrue)), IsOk()); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("F1(var1) || var1.foo || var1.bar")); auto plan = builder_->CreateExpression(&expr.expr(), nullptr); ASSERT_THAT(plan, IsOk()); ASSERT_OK_AND_ASSIGN(CelValue response, plan.value()->Evaluate(activation_, &arena_)); ASSERT_TRUE(response.IsUnknownSet()) << response.DebugString(); EXPECT_THAT( response.UnknownSetOrDie()->unknown_attributes(), UnorderedElementsAre(AttributeIs("var1"), AttributeIs("var1.foo"))); } TEST_F(UnknownsTest, UnknownFunctionsWithoutOptionError) { PrepareBuilder(UnknownProcessingOptions::kAttributeOnly); activation_.InsertValue("var1", CelValue::CreateInt64(5)); activation_.InsertValue("var2", CelValue::CreateInt64(3)); ASSERT_THAT(activation_.InsertFunction(std::make_unique( "F1", FunctionResponse::kUnknown)), IsOk()); ASSERT_THAT(activation_.InsertFunction(std::make_unique( "F2", FunctionResponse::kFalse)), IsOk()); ASSERT_OK_AND_ASSIGN( ParsedExpr expr, Parse("var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2')")); auto plan = builder_->CreateExpression(&expr.expr(), nullptr); ASSERT_THAT(plan, IsOk()); ASSERT_OK_AND_ASSIGN(CelValue response, plan.value()->Evaluate(activation_, &arena_)); ASSERT_TRUE(response.IsError()); EXPECT_EQ(response.ErrorOrDie()->code(), 
absl::StatusCode::kUnavailable); } TEST_F(UnknownsTest, UnknownFunctions) { PrepareBuilder(UnknownProcessingOptions::kAttributeAndFunction); activation_.InsertValue("var1", CelValue::CreateInt64(5)); activation_.InsertValue("var2", CelValue::CreateInt64(5)); ASSERT_THAT(activation_.InsertFunction(std::make_unique( "F1", FunctionResponse::kUnknown)), IsOk()); ASSERT_THAT(activation_.InsertFunction(std::make_unique( "F2", FunctionResponse::kFalse)), IsOk()); ASSERT_OK_AND_ASSIGN( ParsedExpr expr, Parse("var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2')")); auto plan = builder_->CreateExpression(&expr.expr(), nullptr); ASSERT_THAT(plan, IsOk()); ASSERT_OK_AND_ASSIGN(CelValue response, plan.value()->Evaluate(activation_, &arena_)); ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); EXPECT_THAT(response.UnknownSetOrDie()->unknown_function_results(), ElementsAre(FunctionCallIs("F1"))); } TEST_F(UnknownsTest, UnknownsMerge) { PrepareBuilder(UnknownProcessingOptions::kAttributeAndFunction); activation_.InsertValue("var1", CelValue::CreateInt64(5)); activation_.set_unknown_attribute_patterns({CelAttributePattern("var2", {})}); ASSERT_THAT(activation_.InsertFunction(std::make_unique( "F1", FunctionResponse::kUnknown)), IsOk()); ASSERT_THAT(activation_.InsertFunction(std::make_unique( "F2", FunctionResponse::kTrue)), IsOk()); ASSERT_OK_AND_ASSIGN( ParsedExpr expr, Parse("var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2')")); auto plan = builder_->CreateExpression(&expr.expr(), nullptr); ASSERT_THAT(plan, IsOk()); ASSERT_OK_AND_ASSIGN(CelValue response, plan.value()->Evaluate(activation_, &arena_)); ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); EXPECT_THAT(response.UnknownSetOrDie()->unknown_function_results(), ElementsAre(FunctionCallIs("F1"))); EXPECT_THAT(response.UnknownSetOrDie()->unknown_attributes(), ElementsAre(AttributeIs("var2"))); } constexpr char kListCompExistsExpr[] = R"pb( id: 25 comprehension_expr { iter_var: "x" iter_range { id: 1 
list_expr { elements { id: 2 const_expr { int64_value: 1 } } elements { id: 3 const_expr { int64_value: 2 } } elements { id: 4 const_expr { int64_value: 3 } } elements { id: 5 const_expr { int64_value: 4 } } elements { id: 6 const_expr { int64_value: 5 } } elements { id: 7 const_expr { int64_value: 6 } } elements { id: 8 const_expr { int64_value: 7 } } elements { id: 9 const_expr { int64_value: 8 } } elements { id: 10 const_expr { int64_value: 9 } } elements { id: 11 const_expr { int64_value: 10 } } } } accu_var: "__result__" accu_init { id: 18 const_expr { bool_value: false } } loop_condition { id: 21 call_expr { function: "@not_strictly_false" args { id: 20 call_expr { function: "!_" args { id: 19 ident_expr { name: "__result__" } } } } } } loop_step { id: 23 call_expr { function: "_||_" args { id: 22 ident_expr { name: "__result__" } } args { id: 16 call_expr { function: "_>_" args { id: 14 call_expr { function: "Fn" args { id: 15 ident_expr { name: "x" } } } } args { id: 17 const_expr { int64_value: 2 } } } } } } result { id: 24 ident_expr { name: "__result__" } } })pb"; // Text fixture for comprehension tests. Holds on to state needed for execution // to work correctly. 
class UnknownsCompTest : public testing::Test { public: void PrepareBuilder(UnknownProcessingOptions opts) { InterpreterOptions options; options.unknown_processing = opts; builder_ = CreateCelExpressionBuilder(options); ASSERT_THAT(RegisterBuiltinFunctions(builder_->GetRegistry()), IsOk()); ASSERT_THAT(builder_->GetRegistry()->RegisterLazyFunction( CreateDescriptor("Fn", CelValue::Type::kInt64)), IsOk()); ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(kListCompExistsExpr, &expr_)) << "error parsing expr"; } protected: Arena arena_; Activation activation_; std::unique_ptr builder_; Expr expr_; }; TEST_F(UnknownsCompTest, UnknownsMerge) { PrepareBuilder(UnknownProcessingOptions::kAttributeAndFunction); ASSERT_THAT(activation_.InsertFunction(std::make_unique( "Fn", FunctionResponse::kUnknown, CelValue::Type::kInt64)), IsOk()); // [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].exists(x, Fn(x) > 5) auto build_status = builder_->CreateExpression(&expr_, nullptr); ASSERT_THAT(build_status, IsOk()); auto eval_status = build_status.value()->Evaluate(activation_, &arena_); ASSERT_THAT(eval_status, IsOk()); CelValue response = eval_status.value(); ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); EXPECT_THAT(response.UnknownSetOrDie()->unknown_function_results(), testing::SizeIs(1)); } constexpr char kListCompCondExpr[] = R"pb( id: 25 comprehension_expr { iter_var: "x" iter_range { id: 1 list_expr { elements { id: 2 const_expr { int64_value: 1 } } elements { id: 3 const_expr { int64_value: 2 } } elements { id: 11 const_expr { int64_value: 3 } } } } accu_var: "__result__" accu_init { id: 18 const_expr { int64_value: 0 } } loop_condition { id: 21 call_expr { function: "_<=_" args { id: 20 ident_expr { name: "__result__" } } args { id: 19 const_expr { int64_value: 1 } } } } loop_step { id: 23 call_expr { function: "_?_:_" args { id: 22 call_expr: { function: "Fn" args { id: 4 ident_expr { name: "x" } } } } args { id: 14 call_expr { function: "_+_" args { id: 15 
ident_expr { name: "__result__" } } args { id: 17 const_expr { int64_value: 1 } } } } args { id: 16 ident_expr { name: "__result__" } } } } result { id: 24 call_expr { function: "_==_" args { id: 27 ident_expr { name: "__result__" } } args { id: 26 const_expr { int64_value: 1 } } } } })pb"; // Text fixture for comprehension tests affecting the condition step. // Holds on to state needed for execution to work correctly. class UnknownsCompCondTest : public testing::Test { public: void PrepareBuilder(UnknownProcessingOptions opts) { InterpreterOptions options; options.unknown_processing = opts; builder_ = CreateCelExpressionBuilder(options); ASSERT_THAT(RegisterBuiltinFunctions(builder_->GetRegistry()), IsOk()); ASSERT_THAT(builder_->GetRegistry()->RegisterLazyFunction( CreateDescriptor("Fn", CelValue::Type::kInt64)), IsOk()); ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kListCompCondExpr, &expr_)) << "error parsing expr"; } protected: Arena arena_; Activation activation_; std::unique_ptr builder_; Expr expr_; }; TEST_F(UnknownsCompCondTest, UnknownConditionReturned) { PrepareBuilder(UnknownProcessingOptions::kAttributeAndFunction); ASSERT_THAT(activation_.InsertFunction(std::make_unique( "Fn", FunctionResponse::kUnknown, CelValue::Type::kInt64)), IsOk()); // [1, 2, 3].exists_one(x, Fn(x)) auto build_status = builder_->CreateExpression(&expr_, nullptr); ASSERT_THAT(build_status, IsOk()); auto eval_status = build_status.value()->Evaluate(activation_, &arena_); ASSERT_THAT(eval_status, IsOk()); CelValue response = eval_status.value(); ASSERT_TRUE(response.IsUnknownSet()) << *response.ErrorOrDie(); // The comprehension ends on the first non-bool condition, so we only get one // call captured in the UnknownSet. 
EXPECT_THAT(response.UnknownSetOrDie()->unknown_function_results(), testing::SizeIs(1)); } TEST_F(UnknownsCompCondTest, ErrorConditionReturned) { PrepareBuilder(UnknownProcessingOptions::kAttributeAndFunction); // No implementation for Fn(int64) provided in activation -- this turns into a // CelError. // [1, 2, 3].exists_one(x, Fn(x)) auto build_status = builder_->CreateExpression(&expr_, nullptr); ASSERT_THAT(build_status, IsOk()); auto eval_status = build_status.value()->Evaluate(activation_, &arena_); ASSERT_THAT(eval_status, IsOk()); CelValue response = eval_status.value(); ASSERT_TRUE(response.IsError()) << CelValue::TypeName(response.type()); EXPECT_TRUE(CheckNoMatchingOverloadError(response)); } constexpr char kListCompExistsWithAttrExpr[] = R"pb( id: 25 comprehension_expr { iter_var: "x" iter_range { id: 1 ident_expr { name: "var" } } accu_var: "__result__" accu_init { id: 18 const_expr { bool_value: false } } loop_condition { id: 21 call_expr { function: "@not_strictly_false" args { id: 20 call_expr { function: "!_" args { id: 19 ident_expr { name: "__result__" } } } } } } loop_step { id: 23 call_expr { function: "_||_" args { id: 22 ident_expr { name: "__result__" } } args { id: 16 call_expr { function: "Fn" args { id: 15 ident_expr { name: "x" } } } } } } result { id: 24 ident_expr { name: "__result__" } } })pb"; TEST(UnknownsIterAttrTest, IterAttributeTrail) { InterpreterOptions options; Expr expr; Activation activation; Arena arena; protobuf::Value element; protobuf::Value& value = element.mutable_struct_value()->mutable_fields()->operator[]("elem1"); value.set_number_value(1); protobuf::ListValue list; *list.add_values() = element; *list.add_values() = element; *list.add_values() = element; options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; auto builder = CreateCelExpressionBuilder(options); ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry()), IsOk()); ASSERT_THAT(builder->GetRegistry()->RegisterLazyFunction( 
CreateDescriptor("Fn", CelValue::Type::kMap)), IsOk()); ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(kListCompExistsWithAttrExpr, &expr)) << "error parsing expr"; // var.exists(x, Fn(x)) auto plan = builder->CreateExpression(&expr, nullptr).value(); activation.InsertValue("var", CelProtoWrapper::CreateMessage(&list, &arena)); // var[1]['elem1'] is unknown activation.set_unknown_attribute_patterns({CelAttributePattern( "var", { CreateCelAttributeQualifierPattern(CelValue::CreateInt64(1)), CreateCelAttributeQualifierPattern( CelValue::CreateStringView("elem1")), })}); ASSERT_THAT(activation.InsertFunction(std::make_unique( "Fn", FunctionResponse::kFalse, CelValue::Type::kMap)), IsOk()); CelValue response = plan->Evaluate(activation, &arena).value(); ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); ASSERT_EQ(response.UnknownSetOrDie()->unknown_attributes().size(), 1); // 'var[1]' is partially unknown when we make the function call so we treat it // as unknown. 
ASSERT_EQ(response.UnknownSetOrDie() ->unknown_attributes() .begin() ->qualifier_path() .size(), 1); } TEST(UnknownsIterAttrTest, IterAttributeTrailMapKeyTypes) { InterpreterOptions options; Expr expr; Activation activation; Arena arena; UnknownSet unknown_set; CelError error = absl::CancelledError(); std::vector> backing; backing.push_back( {CelValue::CreateUnknownSet(&unknown_set), CelValue::CreateBool(false)}); backing.push_back( {CelValue::CreateError(&error), CelValue::CreateBool(false)}); backing.push_back({CelValue::CreateBool(true), CelValue::CreateBool(false)}); auto map_impl = CreateContainerBackedMap(absl::MakeSpan(backing)).value(); options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; auto builder = CreateCelExpressionBuilder(options); ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry()), IsOk()); ASSERT_THAT(builder->GetRegistry()->RegisterLazyFunction( CreateDescriptor("Fn", CelValue::Type::kBool)), IsOk()); ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(kListCompExistsWithAttrExpr, &expr)) << "error parsing expr"; // var.exists(x, Fn(x)) auto plan = builder->CreateExpression(&expr, nullptr).value(); activation.InsertValue("var", CelValue::CreateMap(map_impl.get())); ASSERT_THAT(activation.InsertFunction(std::make_unique( "Fn", FunctionResponse::kFalse, CelValue::Type::kBool)), IsOk()); CelValue response = plan->Evaluate(activation, &arena).value(); ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); ASSERT_EQ(*response.UnknownSetOrDie(), unknown_set); } TEST(UnknownsIterAttrTest, IterAttributeTrailMapKeyTypesShortcutted) { InterpreterOptions options; Expr expr; Activation activation; Arena arena; UnknownSet unknown_set; CelError error = absl::CancelledError(); std::vector> backing; backing.push_back( {CelValue::CreateUnknownSet(&unknown_set), CelValue::CreateBool(false)}); backing.push_back( {CelValue::CreateError(&error), CelValue::CreateBool(false)}); 
backing.push_back({CelValue::CreateBool(true), CelValue::CreateBool(false)}); auto map_impl = CreateContainerBackedMap(absl::MakeSpan(backing)).value(); options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; auto builder = CreateCelExpressionBuilder(options); ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry()), IsOk()); ASSERT_THAT(builder->GetRegistry()->RegisterLazyFunction( CreateDescriptor("Fn", CelValue::Type::kBool)), IsOk()); ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(kListCompExistsWithAttrExpr, &expr)) << "error parsing expr"; // var.exists(x, Fn(x)) auto plan = builder->CreateExpression(&expr, nullptr).value(); activation.InsertValue("var", CelValue::CreateMap(map_impl.get())); ASSERT_THAT(activation.InsertFunction(std::make_unique( "Fn", FunctionResponse::kTrue, CelValue::Type::kBool)), IsOk()); CelValue response = plan->Evaluate(activation, &arena).value(); ASSERT_TRUE(response.IsBool()) << CelValue::TypeName(response.type()); ASSERT_TRUE(response.BoolOrDie()); } constexpr char kMapElementsComp[] = R"pb( id: 25 comprehension_expr { iter_var: "x" iter_range { id: 1 ident_expr { name: "var" } } accu_var: "__result__" accu_init { id: 2 list_expr {} } loop_condition { id: 3 const_expr { bool_value: true } } loop_step { id: 4 call_expr { function: "_+_" args { id: 5 ident_expr { name: "__result__" } } args { id: 6 list_expr { elements { id: 9 call_expr { function: "Fn" args { id: 7 select_expr { field: "key" operand { id: 8 ident_expr { name: "x" } } } } } } } } } } result { id: 9 ident_expr { name: "__result__" } } })pb"; // TODO(issues/67): Expected behavior for maps with unknown keys/values in a // comprehension is a little unclear and the test coverage is a bit sparse. // A few more tests should be added for coverage and to help document. 
TEST(UnknownsIterAttrTest, IterAttributeTrailMap) { InterpreterOptions options; Expr expr; Activation activation; Arena arena; protobuf::Value element; protobuf::Value& value = element.mutable_struct_value()->mutable_fields()->operator[]("key"); value.set_number_value(1); protobuf::ListValue list; *list.add_values() = element; *list.add_values() = element; *list.add_values() = element; options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; auto builder = CreateCelExpressionBuilder(options); ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry()), IsOk()); ASSERT_THAT(builder->GetRegistry()->RegisterLazyFunction( CreateDescriptor("Fn", CelValue::Type::kDouble)), IsOk()); ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kMapElementsComp, &expr)) << "error parsing expr"; activation.InsertValue("var", CelProtoWrapper::CreateMessage(&list, &arena)); // var[1]['key'] is unknown activation.set_unknown_attribute_patterns({CelAttributePattern( "var", { CreateCelAttributeQualifierPattern(CelValue::CreateInt64(1)), CreateCelAttributeQualifierPattern(CelValue::CreateStringView("key")), })}); ASSERT_THAT(activation.InsertFunction(std::make_unique( "Fn", FunctionResponse::kFalse, CelValue::Type::kDouble)), IsOk()); auto plan = builder->CreateExpression(&expr, nullptr).value(); CelValue response = plan->Evaluate(activation, &arena).value(); ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); ASSERT_EQ(response.UnknownSetOrDie()->unknown_attributes().size(), 1); // 'var[1].key' is unknown when we make the Fn function call. 
// comprehension is: ((([] + false) + unk) + false) -> unk ASSERT_EQ(response.UnknownSetOrDie() ->unknown_attributes() .begin() ->qualifier_path() .size(), 2); } constexpr char kFilterElementsComp[] = R"pb( id: 25 comprehension_expr { iter_var: "x" iter_range { id: 1 ident_expr { name: "var" } } accu_var: "__result__" accu_init { id: 2 list_expr {} } loop_condition { id: 3 const_expr { bool_value: true } } loop_step { id: 4 call_expr { function: "_?_:_" args { id: 5 select_expr { field: "filter_key" operand { id: 6 ident_expr { name: "x" } } } } args { id: 7 call_expr { function: "_+_" args { id: 8 ident_expr { name: "__result__" } } args { id: 9 list_expr { elements { id: 10 select_expr { field: "value_key" operand { id: 12 ident_expr { name: "x" } } } } } } } } args { id: 13 ident_expr { name: "__result__" } } } } result { id: 14 ident_expr { name: "__result__" } } })pb"; TEST(UnknownsIterAttrTest, IterAttributeTrailExact) { InterpreterOptions options; Activation activation; Arena arena; ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("list_var.exists(x, x)")); protobuf::Value element; element.set_bool_value(false); protobuf::ListValue list; *list.add_values() = element; *list.add_values() = element; *list.add_values() = element; (*list.mutable_values())[0].set_bool_value(true); options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; auto builder = CreateCelExpressionBuilder(options); ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry()), IsOk()); activation.InsertValue("list_var", CelProtoWrapper::CreateMessage(&list, &arena)); // list_var[0] std::vector unknown_attribute_patterns; unknown_attribute_patterns.push_back(CelAttributePattern( "list_var", {CreateCelAttributeQualifierPattern(CelValue::CreateInt64(0))})); activation.set_unknown_attribute_patterns( std::move(unknown_attribute_patterns)); ASSERT_OK_AND_ASSIGN( auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); CelValue response = plan->Evaluate(activation, 
&arena).value(); ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); ASSERT_EQ(response.UnknownSetOrDie()->unknown_attributes().size(), 1); ASSERT_EQ(response.UnknownSetOrDie() ->unknown_attributes() .begin() ->qualifier_path() .size(), 1); } TEST(UnknownsIterAttrTest, IterAttributeTrailFilterValues) { InterpreterOptions options; Expr expr; Activation activation; Arena arena; protobuf::Value element; protobuf::Value* value = &element.mutable_struct_value()->mutable_fields()->operator[]( "filter_key"); value->set_bool_value(true); value = &element.mutable_struct_value()->mutable_fields()->operator[]( "value_key"); value->set_number_value(1.0); protobuf::ListValue list; *list.add_values() = element; *list.add_values() = element; *list.add_values() = element; options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; auto builder = CreateCelExpressionBuilder(options); ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry()), IsOk()); ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kFilterElementsComp, &expr)) << "error parsing expr"; activation.InsertValue("var", CelProtoWrapper::CreateMessage(&list, &arena)); // var[1]['value_key'] is unknown activation.set_unknown_attribute_patterns({CelAttributePattern( "var", { CreateCelAttributeQualifierPattern(CelValue::CreateInt64(1)), CreateCelAttributeQualifierPattern( CelValue::CreateStringView("value_key")), })}); auto plan = builder->CreateExpression(&expr, nullptr).value(); CelValue response = plan->Evaluate(activation, &arena).value(); ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); ASSERT_EQ(response.UnknownSetOrDie()->unknown_attributes().size(), 1); // 'var[1].value_key' is unknown when we make the cons function call. 
// comprehension is: ((([] + [1]) + unk) + [1]) -> unk ASSERT_EQ(response.UnknownSetOrDie() ->unknown_attributes() .begin() ->qualifier_path() .size(), 2); } TEST(UnknownsIterAttrTest, IterAttributeTrailFilterConditions) { InterpreterOptions options; Expr expr; Activation activation; Arena arena; protobuf::Value element; protobuf::Value* value = &element.mutable_struct_value()->mutable_fields()->operator[]( "filter_key"); value->set_bool_value(true); value = &element.mutable_struct_value()->mutable_fields()->operator[]( "value_key"); value->set_number_value(1.0); protobuf::ListValue list; *list.add_values() = element; *list.add_values() = element; *list.add_values() = element; options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; auto builder = CreateCelExpressionBuilder(options); ASSERT_THAT(RegisterBuiltinFunctions(builder->GetRegistry()), IsOk()); ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kFilterElementsComp, &expr)) << "error parsing expr"; activation.InsertValue("var", CelProtoWrapper::CreateMessage(&list, &arena)); // var[1]['value_key'] is unknown activation.set_unknown_attribute_patterns( {CelAttributePattern( "var", { CreateCelAttributeQualifierPattern(CelValue::CreateInt64(1)), CreateCelAttributeQualifierPattern( CelValue::CreateStringView("filter_key")), }), CelAttributePattern( "var", { CreateCelAttributeQualifierPattern(CelValue::CreateInt64(0)), CreateCelAttributeQualifierPattern( CelValue::CreateStringView("filter_key")), })}); auto plan = builder->CreateExpression(&expr, nullptr).value(); CelValue response = plan->Evaluate(activation, &arena).value(); // 'var[1].filter_key' is unknown when we make the ternary call. // Since the unknown is expressed in a conditional jump, the behavior is to // ignore the possible outcomes // loop0: (unk{0})? [] + [1] : [] -> unk{0} // loop1: (unk{1})? unk{0} + [1] : unk{0} -> unk{1} // loop2: (true)? 
unk{1} + [1] : unk{1} -> unk{1} // result: unk{1} ASSERT_TRUE(response.IsUnknownSet()) << CelValue::TypeName(response.type()); ASSERT_EQ(response.UnknownSetOrDie()->unknown_attributes().size(), 1); ASSERT_EQ(response.UnknownSetOrDie() ->unknown_attributes() .begin() ->qualifier_path() .size(), 2); } } // namespace } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: eval/testutil/BUILD ================================================ load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library") load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") # This package contains testing utility code package(default_visibility = ["//visibility:public"]) licenses(["notice"]) proto_library( name = "test_message_proto", srcs = [ "test_message.proto", ], deps = [ "@com_google_protobuf//:any_proto", "@com_google_protobuf//:duration_proto", "@com_google_protobuf//:struct_proto", "@com_google_protobuf//:timestamp_proto", "@com_google_protobuf//:wrappers_proto", ], ) cc_proto_library( name = "test_message_cc_proto", deps = [":test_message_proto"], ) proto_library( name = "test_extensions_proto", srcs = [ "test_extensions.proto", ], deps = ["@com_google_protobuf//:wrappers_proto"], ) cc_proto_library( name = "test_extensions_cc_proto", deps = [":test_extensions_proto"], ) ================================================ FILE: eval/testutil/test_extensions.proto ================================================ syntax = "proto2"; package google.api.expr.runtime; import "google/protobuf/wrappers.proto"; option cc_enable_arenas = true; option java_multiple_files = true; enum TestExtEnum { TEST_EXT_UNSPECIFIED = 0; TEST_EXT_1 = 10; TEST_EXT_2 = 20; TEST_EXT_3 = 30; } // This proto is used to show how extensions are tracked as fields // with fully qualified names. message TestExtensions { optional string name = 1; extensions 100 to max; } // Package scoped extensions. 
extend TestExtensions {
  optional TestExtensions nested_ext = 100;
  optional int32 int32_ext = 101;
  optional google.protobuf.Int32Value int32_wrapper_ext = 102;
}

// Message scoped extensions.
message TestMessageExtensions {
  extend TestExtensions {
    repeated string repeated_string_exts = 103;
    optional TestExtEnum enum_ext = 104;
  }
}

================================================
FILE: eval/testutil/test_message.proto
================================================
syntax = "proto3";

package google.api.expr.runtime;

import "google/protobuf/any.proto";
import "google/protobuf/duration.proto";
import "google/protobuf/struct.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";

option cc_enable_arenas = true;

enum TestEnum {
  TEST_ENUM_UNSPECIFIED = 0;
  TEST_ENUM_1 = 10;
  TEST_ENUM_2 = 20;
  TEST_ENUM_3 = 30;
}

message TestMessage {
  int32 int32_value = 1;
  int64 int64_value = 2;
  uint32 uint32_value = 3;
  uint64 uint64_value = 4;
  float float_value = 5;
  double double_value = 6;
  string string_value = 7;
  string cord_value = 8 [ctype = CORD];
  bytes bytes_value = 9;
  bool bool_value = 10;

  // Within TestMessage the unqualified name `TestEnum` resolves to this nested
  // enum (innermost scope wins), not to the package-level TestEnum above, so
  // enum_value, enum_list and int64_enum_map all use TestMessage.TestEnum.
  enum TestEnum {
    TEST_ENUM_UNSPECIFIED = 0;
    TEST_ENUM_1 = 1;
    TEST_ENUM_2 = 2;
  }

  TestEnum enum_value = 11;
  TestMessage message_value = 12;

  reserved 99;

  repeated int32 int32_list = 101;
  repeated int64 int64_list = 102;
  repeated uint32 uint32_list = 103;
  repeated uint64 uint64_list = 104;
  repeated float float_list = 105;
  repeated double double_list = 106;
  repeated string string_list = 107;
  repeated string cord_list = 108 [ctype = CORD];
  repeated bytes bytes_list = 109;
  repeated bool bool_list = 110;
  repeated TestEnum enum_list = 111;
  repeated TestMessage message_list = 112;
  repeated google.protobuf.Timestamp timestamp_list = 113;

  // Map fields; each field name spells out its <key>_<value> types.
  map<int64, int32> int64_int32_map = 201;
  map<uint64, int32> uint64_int32_map = 202;
  map<string, int32> string_int32_map = 203;
  map<bool, int32> bool_int32_map = 204;
  map<int32, int32> int32_int32_map = 205;
  map<uint32, uint32> uint32_uint32_map = 206;
  map<int32, float> int32_float_map = 207;
  map<int64, TestEnum> int64_enum_map = 208;
  map<string, google.protobuf.Timestamp> string_timestamp_map = 209;
  map<string, TestMessage> string_message_map = 210;
  map<int64, google.protobuf.Timestamp> int64_timestamp_map = 211;

  // Well-known types.
  google.protobuf.Any any_value = 300;
  google.protobuf.Duration duration_value = 301;
  google.protobuf.Timestamp timestamp_value = 302;
  google.protobuf.Struct struct_value = 303;
  google.protobuf.Value value_value = 304;
  google.protobuf.Int64Value int64_wrapper_value = 305;
  google.protobuf.Int32Value int32_wrapper_value = 306;
  google.protobuf.DoubleValue double_wrapper_value = 307;
  google.protobuf.FloatValue float_wrapper_value = 308;
  google.protobuf.UInt64Value uint64_wrapper_value = 309;
  google.protobuf.UInt32Value uint32_wrapper_value = 310;
  google.protobuf.StringValue string_wrapper_value = 311;
  google.protobuf.BoolValue bool_wrapper_value = 312;
  google.protobuf.BytesValue bytes_wrapper_value = 313;
}

================================================
FILE: extensions/BUILD
================================================
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

package(default_visibility = ["//visibility:public"])

# Encoders extension: checker declarations plus runtime function registration.
cc_library(
    name = "encoders",
    srcs = ["encoders.cc"],
    hdrs = ["encoders.h"],
    deps = [
        "//checker:type_checker_builder",
        "//common:decl",
        "//common:type",
        "//common:value",
        "//compiler",
        "//eval/public:cel_function_registry",
        "//eval/public:cel_options",
        "//internal:status_macros",
        "//runtime:function_adapter",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "encoders_test",
    srcs = ["encoders_test.cc"],
    deps = [
        ":encoders",
        "//checker:standard_library",
        "//checker:validation_result",
        "//compiler",
        "//compiler:compiler_factory",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//runtime",
        "//runtime:activation",
        "//runtime:runtime_options",
        "//runtime:standard_runtime_builder_factory",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_protobuf//:protobuf",
    ],
)

# Proto extension: parser macros only (no runtime dependencies).
cc_library(
    name = "proto_ext",
    srcs = ["proto_ext.cc"],
    hdrs = ["proto_ext.h"],
    deps = [
        "//common:expr",
        "//compiler",
        "//internal:status_macros",
        "//parser:macro",
        "//parser:macro_expr_factory",
        "//parser:macro_registry",
        "//parser:options",
        "//parser:parser_interface",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:variant",
    ],
)

# Math extension, split into runtime functions (math_ext), parser macros
# (math_ext_macros) and type-checker declarations (math_ext_decls).
cc_library(
    name = "math_ext",
    srcs = ["math_ext.cc"],
    hdrs = ["math_ext.h"],
    deps = [
        ":math_ext_decls",
        "//common:casting",
        "//common:value",
        "//eval/public:cel_function_registry",
        "//eval/public:cel_number",
        "//eval/public:cel_options",
        "//internal:status_macros",
        "//runtime:function_adapter",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "math_ext_macros",
    srcs = ["math_ext_macros.cc"],
    hdrs = ["math_ext_macros.h"],
    deps = [
        "//common:ast",
        "//common:constant",
        "//parser:macro",
        "//parser:macro_expr_factory",
        "//parser:macro_registry",
        "//parser:options",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:variant",
    ],
)

cc_library(
    name = "math_ext_decls",
    srcs = ["math_ext_decls.cc"],
    hdrs = ["math_ext_decls.h"],
    deps = [
        ":math_ext_macros",
        "//checker:type_checker_builder",
        "//checker/internal:builtins_arena",
        "//common:decl",
        "//common:type",
        "//common:type_kind",
        "//compiler",
        "//internal:status_macros",
        "//parser:parser_interface",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
)

cc_test(
    name = "math_ext_test",
    srcs = ["math_ext_test.cc"],
    deps = [
        ":math_ext",
        ":math_ext_decls",
        ":math_ext_macros",
        "//checker:standard_library",
        "//checker:type_check_issue",
        "//checker:validation_result",
        "//common:decl",
        "//common:function_descriptor",
        "//common:type",
        "//compiler",
        "//compiler:compiler_factory",
        "//compiler:standard_library",
        "//eval/public:activation",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_function",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "//eval/public/containers:container_backed_list_impl",
        "//eval/public/testing:matchers",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//parser",
        "//runtime:activation",
        "//runtime:runtime_options",
        "//runtime:standard_runtime_builder_factory",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

# New users should use ":regex_ext" instead.
# Regex extension functions backed by RE2.
cc_library(
    name = "regex_functions",
    srcs = ["regex_functions.cc"],
    hdrs = ["regex_functions.h"],
    deps = [
        "//checker:type_checker_builder",
        "//checker/internal:builtins_arena",
        "//common:decl",
        "//common:type",
        "//common:value",
        "//eval/public:cel_function_registry",
        "//eval/public:cel_options",
        "//internal:re2_options",
        "//internal:status_macros",
        "//runtime:function_adapter",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/functional:bind_front",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
        "@com_googlesource_code_re2//:re2",
    ],
)

# cel.bind() local variable binding macro. Parser/compiler level only: no
# runtime dependencies.
cc_library(
    name = "bindings_ext",
    srcs = ["bindings_ext.cc"],
    hdrs = ["bindings_ext.h"],
    deps = [
        "//common:ast",
        "//compiler",
        "//internal:status_macros",
        "//parser:macro",
        "//parser:macro_expr_factory",
        "//parser:macro_registry",
        "//parser:options",
        "//parser:parser_interface",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
    ],
)

cc_test(
    name = "regex_functions_test",
    srcs = [
        "regex_functions_test.cc",
    ],
    deps = [
        ":regex_functions",
        "//checker:standard_library",
        "//checker:validation_result",
        "//common:value",
        "//common:value_testing",
        "//compiler",
        "//compiler:compiler_factory",
        "//extensions/protobuf:runtime_adapter",
        "//internal:status_macros",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//parser",
        "//runtime",
        "//runtime:activation",
        "//runtime:reference_resolver",
        "//runtime:runtime_builder",
        "//runtime:runtime_options",
        "//runtime:standard_runtime_builder_factory",
        "@com_google_absl//absl/log:absl_log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "bindings_ext_test",
    srcs = ["bindings_ext_test.cc"],
    deps = [
        ":bindings_ext",
        "//base:attributes",
        "//eval/public:activation",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_function",
        "//eval/public:cel_function_adapter",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "//eval/public/structs:cel_proto_wrapper",
        "//eval/public/testing:matchers",
        "//internal:testing",
        "//parser",
        "//parser:macro",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

# Benchmark for cel.bind(); tagged "benchmark".
cc_test(
    name = "bindings_ext_benchmark_test",
    srcs = ["bindings_ext_benchmark_test.cc"],
    tags = ["benchmark"],
    deps = [
        ":bindings_ext",
        "//eval/public:activation",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "//eval/public/testing:matchers",
        "//internal:benchmark",
        "//internal:testing",
        "//parser",
        "//parser:macro",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

# Select-expression optimization; plugs into the flat expression builder
# (//eval/compiler:flat_expr_builder_extensions).
cc_library(
    name = "select_optimization",
    srcs = ["select_optimization.cc"],
    hdrs = ["select_optimization.h"],
    deps = [
        "//base:attributes",
        "//base:builtins",
        "//common:ast",
        "//common:ast_rewrite",
        "//common:casting",
        "//common:constant",
        "//common:expr",
        "//common:function_descriptor",
        "//common:kind",
        "//common:native_type",
        "//common:type",
        "//common:value",
        "//eval/compiler:flat_expr_builder",
        "//eval/compiler:flat_expr_builder_extensions",
        "//eval/eval:attribute_trail",
        "//eval/eval:direct_expression_step",
        "//eval/eval:evaluator_core",
        "//eval/eval:expression_step_base",
        "//internal:casts",
        "//internal:number",
        "//internal:status_macros",
        "//runtime:runtime_builder",
        "//runtime/internal:errors",
        "//runtime/internal:runtime_friend_access",
        "//runtime/internal:runtime_impl",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:variant",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "select_optimization_test",
    srcs = ["select_optimization_test.cc"],
    deps = [
        ":select_optimization",
        "//base:ast",
        "//base:attributes",
        "//base:builtins",
        "//checker:type_checker_builder",
        "//checker:validation_result",
        "//common:ast",
        "//common:decl",
        "//common:decl_proto",
        "//common:expr",
        "//common:kind",
        "//common:memory",
        "//common:value",
        "//compiler",
        "//compiler:compiler_factory",
        "//compiler:optional",
        "//compiler:standard_library",
        "//eval/compiler:flat_expr_builder",
        "//eval/compiler:flat_expr_builder_extensions",
        "//eval/compiler:resolver",
        "//eval/eval:evaluator_core",
        "//eval/internal:interop",
        "//eval/public:cel_type_registry",
        "//eval/public:cel_value",
        "//eval/public/structs:cel_proto_wrapper",
        "//eval/public/structs:legacy_type_adapter",
        "//eval/public/structs:legacy_type_info_apis",
        "//extensions/protobuf:ast_converters",
        "//internal:number",
        "//internal:status_macros",
        "//internal:testing",
        "//parser",
        "//runtime:activation",
        "//runtime:function_adapter",
        "//runtime:function_registry",
        "//runtime:runtime_issue",
        "//runtime:runtime_options",
        "//runtime:type_registry",
        "//runtime/internal:issue_collector",
        "//runtime/internal:runtime_env",
        "//runtime/internal:runtime_env_testing",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto",
        "@com_google_protobuf//:empty_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

# Lists extension: runtime functions plus parser macros.
cc_library(
    name = "lists_functions",
    srcs = ["lists_functions.cc"],
    hdrs = ["lists_functions.h"],
    deps = [
        "//checker:type_checker_builder",
        "//checker/internal:builtins_arena",
        "//common:decl",
        "//common:expr",
        "//common:operators",
        "//common:type",
        "//common:value",
        "//common:value_kind",
        "//compiler",
        "//internal:status_macros",
        "//parser:macro",
        "//parser:macro_expr_factory",
        "//parser:macro_registry",
        "//parser:options",
        "//parser:parser_interface",
        "//runtime:function_adapter",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "lists_functions_test",
    srcs = ["lists_functions_test.cc"],
    deps = [
        ":lists_functions",
        "//checker:type_check_issue",
        "//checker:validation_result",
        "//common:source",
        "//common:value",
        "//common:value_testing",
        "//compiler",
        "//compiler:compiler_factory",
        "//compiler:standard_library",
        "//extensions/protobuf:runtime_adapter",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//parser",
        "//parser:macro_registry",
        "//parser:options",
        "//parser:standard_macros",
        "//runtime",
        "//runtime:activation",
        "//runtime:reference_resolver",
        "//runtime:runtime_builder",
        "//runtime:runtime_options",
        "//runtime:standard_runtime_builder_factory",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

# Sets extension functions.
cc_library(
    name = "sets_functions",
    srcs = ["sets_functions.cc"],
    hdrs = ["sets_functions.h"],
    deps = [
        "//base:function_adapter",
        "//checker:type_checker_builder",
        "//common:decl",
        "//common:type",
        "//common:value",
        "//compiler",
        "//eval/public:cel_function_registry",
        "//eval/public:cel_options",
        "//internal:status_macros",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "sets_functions_test",
    srcs = ["sets_functions_test.cc"],
    deps = [
        ":sets_functions",
        "//checker:standard_library",
        "//checker:validation_result",
        "//common:ast_proto",
        "//common:minimal_descriptor_pool",
        "//compiler:compiler_factory",
        "//eval/public:activation",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_function_adapter",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "//internal:testing",
        "//runtime:runtime_options",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

# Benchmark for the sets extension; tagged "benchmark".
cc_test(
    name = "sets_functions_benchmark_test",
    srcs = ["sets_functions_benchmark_test.cc"],
    tags = ["benchmark"],
    deps = [
        ":sets_functions",
        "//common:value",
        "//eval/internal:interop",
        "//eval/public:activation",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_options",
        "//eval/public:cel_value",
        "//eval/public/containers:container_backed_list_impl",
        "//internal:benchmark",
        "//internal:status_macros",
        "//internal:testing",
        "//parser",
        "//runtime:runtime_options",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

# Strings extension functions; formatting support comes from :formatting.
cc_library(
    name = "strings",
    srcs = ["strings.cc"],
    hdrs = ["strings.h"],
    deps = [
        ":formatting",
        "//checker:type_checker_builder",
        "//checker/internal:builtins_arena",
        "//common:decl",
        "//common:type",
        "//common:value",
        "//compiler",
        "//eval/public:cel_function_registry",
        "//eval/public:cel_options",
        "//internal:status_macros",
        "//runtime:function_adapter",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "strings_test",
    srcs = ["strings_test.cc"],
    deps = [
        ":strings",
        "//checker:standard_library",
        "//checker:type_check_issue",
        "//checker:type_checker_builder",
        "//checker:validation_result",
        "//common:ast",
        "//common:decl",
        "//common:type",
        "//common:value",
        "//compiler",
        "//compiler:compiler_factory",
        "//compiler:standard_library",
        "//extensions/protobuf:runtime_adapter",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//parser",
        "//parser:options",
        "//runtime",
        "//runtime:activation",
        "//runtime:runtime_builder",
        "//runtime:runtime_options",
        "//runtime:standard_runtime_builder_factory",
        "//testutil:baseline_tests",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

# Comprehensions v2 extension: runtime functions, parser macros, and the
# compiler library (:comprehensions_v2) that combines them.
cc_library(
    name = "comprehensions_v2_functions",
    srcs = ["comprehensions_v2_functions.cc"],
    hdrs = ["comprehensions_v2_functions.h"],
    deps = [
        "//common:value",
        "//eval/public:cel_function_registry",
        "//eval/public:cel_options",
        "//internal:status_macros",
        "//runtime:function_adapter",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "comprehensions_v2_macros",
    srcs = ["comprehensions_v2_macros.cc"],
    hdrs = ["comprehensions_v2_macros.h"],
    deps = [
        "//common:expr",
        "//common:operators",
        "//compiler",
        "//internal:status_macros",
        "//parser:macro",
        "//parser:macro_expr_factory",
        "//parser:macro_registry",
        "//parser:options",
        "//parser:parser_interface",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "comprehensions_v2",
    srcs = ["comprehensions_v2.cc"],
    hdrs = ["comprehensions_v2.h"],
    deps = [
        ":comprehensions_v2_functions",
        ":comprehensions_v2_macros",
        "//checker:type_checker_builder",
        "//checker/internal:builtins_arena",
        "//common:decl",
        "//common:type",
        "//compiler",
        "//internal:status_macros",
        "//parser:macro_registry",
        "//parser:options",
        "//parser:parser_interface",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/status",
    ],
)

cc_test(
    name = "comprehensions_v2_test",
    srcs = ["comprehensions_v2_test.cc"],
    deps = [
        ":bindings_ext",
        ":comprehensions_v2",
        ":comprehensions_v2_functions",
        ":strings",
        "//checker:standard_library",
        "//checker:validation_result",
        "//common:value",
        "//common:value_testing",
        "//compiler:compiler_factory",
        "//compiler:optional",
        "//internal:status_macros",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//runtime",
        "//runtime:activation",
        "//runtime:optional_types",
        "//runtime:runtime_options",
        "//runtime:standard_runtime_builder_factory",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_protobuf//:protobuf",
    ],
)

# String formatting support, used by :strings.
cc_library(
    name = "formatting",
    srcs = ["formatting.cc"],
    hdrs = ["formatting.h"],
    deps = [
        "//common:value",
        "//common:value_kind",
        "//internal:status_macros",
        "//runtime:function_adapter",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:btree",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/numeric:bits",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/time",
        "@com_google_protobuf//:protobuf",
    ],
)

# Regex extension recommended for new users (over :regex_functions). Also
# wires in a regex validator (//validator:regex_validator).
cc_library(
    name = "regex_ext",
    srcs = ["regex_ext.cc"],
    hdrs = ["regex_ext.h"],
    deps = [
        "//checker:type_checker_builder",
        "//checker/internal:builtins_arena",
        "//common:decl",
        "//common:type",
        "//common:value",
        "//compiler",
        "//eval/public:cel_function_registry",
        "//eval/public:cel_options",
        "//internal:casts",
        "//internal:re2_options",
        "//internal:status_macros",
        "//runtime:function_adapter",
        "//runtime:function_registry",
        "//runtime:runtime_builder",
        "//runtime/internal:runtime_friend_access",
        "//runtime/internal:runtime_impl",
        "//validator",
        "//validator:regex_validator",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/functional:bind_front",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
        "@com_googlesource_code_re2//:re2",
    ],
)

cc_test(
    name = "regex_ext_test",
    srcs = ["regex_ext_test.cc"],
    deps = [
        ":regex_ext",
        "//checker:standard_library",
        "//checker:validation_result",
        "//common:kind",
        "//common:value",
        "//common:value_testing",
        "//compiler",
        "//compiler:compiler_factory",
        "//eval/public:activation",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_function_registry",
        "//eval/public:cel_options",
        "//extensions/protobuf:runtime_adapter",
        "//internal:status_macros",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//parser",
        "//runtime",
        "//runtime:activation",
        "//runtime:optional_types",
        "//runtime:reference_resolver",
        "//runtime:runtime_options",
        "//runtime:standard_runtime_builder_factory",
        "//validator",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "formatting_test",
    srcs = ["formatting_test.cc"],
    deps = [
        ":formatting",
        "//common:value",
        "//extensions/protobuf:runtime_adapter",
        "//internal:parse_text_proto",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//internal:testing_message_factory",
        "//parser",
        "//parser:options",
        "//runtime",
        "//runtime:activation",
        "//runtime:runtime_builder",
        "//runtime:runtime_options",
        "//runtime:standard_runtime_builder_factory",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/time",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

================================================
FILE: extensions/bindings_ext.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/bindings_ext.h"

#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "common/ast.h"
#include "compiler/compiler.h"
#include "internal/status_macros.h"
#include "parser/macro.h"
#include "parser/macro_expr_factory.h"
#include "parser/parser_interface.h"

namespace cel::extensions {

namespace {

// Identifier the macro must be invoked on: only `cel.bind(...)` is expanded.
constexpr char kCelNamespace[] = "cel";

// Receiver function name recognized by the macro.
constexpr char kBind[] = "bind";

// Iteration variable of the generated comprehension. The comprehension ranges
// over an empty list, so this variable is never bound to a value.
constexpr char kUnusedIterVar[] = "#unused";

// Returns true when `target` is the plain identifier `cel`.
bool IsTargetNamespace(const Expr& target) {
  return target.has_ident_expr() &&
         target.ident_expr().name() == kCelNamespace;
}

// Adds every macro returned by bindings_macros() to `parser_builder`,
// stopping at and returning the first registration error.
absl::Status ConfigureParser(ParserBuilder& parser_builder) {
  for (const Macro& macro : bindings_macros()) {
    CEL_RETURN_IF_ERROR(parser_builder.AddMacro(macro));
  }
  return absl::OkStatus();
}

}  // namespace

// Builds the three-argument receiver macro `cel.bind(var, init, expr)`.
//
// The call is rewritten into a comprehension that never iterates: its range is
// an empty list and its loop condition is the constant `false`. The
// accumulator is named after `var` and initialized from `init` (args[1]); the
// result is `expr` (args[2]). NOTE(review): this reading assumes
// NewComprehension takes (iter_var, iter_range, accu_var, accu_init,
// loop_condition, loop_step, result), i.e. the comprehension field order —
// confirm against MacroExprFactory.
std::vector<Macro> bindings_macros() {
  absl::StatusOr<Macro> cel_bind = Macro::Receiver(
      kBind, 3,
      [](MacroExprFactory& factory, Expr& target,
         absl::Span<Expr> args) -> absl::optional<Expr> {
        // Some other `<target>.bind(...)` call: leave it unexpanded.
        if (!IsTargetNamespace(target)) {
          return absl::nullopt;
        }
        if (!args[0].has_ident_expr()) {
          return factory.ReportErrorAt(
              args[0], "cel.bind() variable name must be a simple identifier");
        }
        // Copy the name out first. args[0] is itself moved into the call
        // below, and the order in which arguments are constructed is
        // unspecified, so reading args[0] inside that call would be unsafe.
        auto var_name = args[0].ident_expr().name();
        return factory.NewComprehension(kUnusedIterVar, factory.NewList(),
                                        std::move(var_name), std::move(args[1]),
                                        factory.NewBoolConst(false),
                                        std::move(args[0]), std::move(args[2]));
      });
  // NOTE(review): the StatusOr is dereferenced without a check; this assumes
  // Macro::Receiver always succeeds for these constant arguments.
  return {*cel_bind};
}

// Compiler library "cel.lib.ext.bindings": installs the bindings macros into
// the parser via ConfigureParser.
CompilerLibrary BindingsCompilerLibrary() {
  return CompilerLibrary("cel.lib.ext.bindings", &ConfigureParser);
}

}  // namespace cel::extensions

================================================
FILE: extensions/bindings_ext.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_BINDINGS_EXT_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_BINDINGS_EXT_H_ #include #include "absl/status/status.h" #include "compiler/compiler.h" #include "parser/macro.h" #include "parser/macro_registry.h" #include "parser/options.h" namespace cel::extensions { // bindings_macros() returns a macro for cel.bind() which can be used to support // local variable bindings within expressions. std::vector bindings_macros(); inline absl::Status RegisterBindingsMacros(MacroRegistry& registry, const ParserOptions&) { return registry.RegisterMacros(bindings_macros()); } // Declarations for the bindings extension library. CompilerLibrary BindingsCompilerLibrary(); } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_BINDINGS_EXT_H_ ================================================ FILE: extensions/bindings_ext_benchmark_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include #include "cel/expr/syntax.pb.h" #include "absl/base/no_destructor.h" #include "absl/log/absl_check.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/testing/matchers.h" #include "extensions/bindings_ext.h" #include "internal/benchmark.h" #include "internal/testing.h" #include "parser/macro.h" #include "parser/parser.h" #include "google/protobuf/arena.h" namespace cel::extensions { namespace { using ::google::api::expr::parser::ParseWithMacros; using ::google::api::expr::runtime::Activation; using ::google::api::expr::runtime::CelValue; using ::google::api::expr::runtime::InterpreterOptions; using ::google::api::expr::runtime::test::CelValueMatcher; using ::google::api::expr::runtime::test::IsCelBool; using ::google::api::expr::runtime::test::IsCelString; struct BenchmarkCase { std::string name; std::string expression; CelValueMatcher matcher; }; const std::vector& BenchmarkCases() { static absl::NoDestructor> cases( std::vector{ {"simple", R"(cel.bind(x, "ab", x))", IsCelString("ab")}, {"multiple_references", R"(cel.bind(x, "ab", x + x + x + x))", IsCelString("abababab")}, {"nested", R"( cel.bind( x, "ab", cel.bind( y, "cd", x + y + "ef")))", IsCelString("abcdef")}, {"nested_defintion", R"( cel.bind( x, "ab", cel.bind( y, x + "cd", y + "ef" )))", IsCelString("abcdef")}, {"bind_outside_loop", R"( cel.bind( outer_value, [1, 2, 3], [3, 2, 1].all( value, value in outer_value) ))", IsCelBool(true)}, {"bind_inside_loop", R"( [3, 2, 1].all( x, cel.bind(value, x * x, value < 16) ))", IsCelBool(true)}, {"bind_loop_bind", R"( cel.bind( outer_value, {1: 2, 2: 3, 3: 4}, outer_value.all( key, cel.bind( value, outer_value[key], value == key + 1 ) )))", IsCelBool(true)}, {"ternary_depends_on_bind", R"( cel.bind( a, "ab", (true && 
a.startsWith("c")) ? a : "cd" ))", IsCelString("cd")}, {"ternary_does_not_depend_on_bind", R"( cel.bind( a, "ab", (false && a.startsWith("c")) ? a : "cd" ))", IsCelString("cd")}, {"twice_nested_defintion", R"( cel.bind( x, "ab", cel.bind( y, x + "cd", cel.bind( z, y + "ef", z))) )", IsCelString("abcdef")}, }); return *cases; } class BindingsBenchmarkTest : public ::testing::TestWithParam { protected: google::protobuf::Arena arena_; }; TEST_P(BindingsBenchmarkTest, CheckBenchmarkCaseWorks) { const BenchmarkCase& benchmark = GetParam(); std::vector all_macros = Macro::AllMacros(); std::vector bindings_macros = cel::extensions::bindings_macros(); all_macros.insert(all_macros.end(), bindings_macros.begin(), bindings_macros.end()); ASSERT_OK_AND_ASSIGN( auto expr, ParseWithMacros(benchmark.expression, all_macros, "")); InterpreterOptions options; auto builder = google::api::expr::runtime::CreateCelExpressionBuilder(options); ASSERT_OK(google::api::expr::runtime::RegisterBuiltinFunctions( builder->GetRegistry())); ASSERT_OK_AND_ASSIGN(auto program, builder->CreateExpression( &expr.expr(), &expr.source_info())); Activation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(CelValue result, program->Evaluate(activation, &arena)); EXPECT_THAT(result, benchmark.matcher); } void RunBenchmark(const BenchmarkCase& benchmark, benchmark::State& state) { std::vector all_macros = Macro::AllMacros(); std::vector bindings_macros = cel::extensions::bindings_macros(); all_macros.insert(all_macros.end(), bindings_macros.begin(), bindings_macros.end()); ASSERT_OK_AND_ASSIGN( auto expr, ParseWithMacros(benchmark.expression, all_macros, "")); InterpreterOptions options; auto builder = google::api::expr::runtime::CreateCelExpressionBuilder(options); ASSERT_OK(google::api::expr::runtime::RegisterBuiltinFunctions( builder->GetRegistry())); ASSERT_OK_AND_ASSIGN(auto program, builder->CreateExpression( &expr.expr(), &expr.source_info())); Activation activation; 
google::protobuf::Arena arena; for (auto _ : state) { auto result = program->Evaluate(activation, &arena); benchmark::DoNotOptimize(result); ABSL_DCHECK_OK(result); ABSL_DCHECK(benchmark.matcher.Matches(*result)); } } void BM_Simple(benchmark::State& state) { RunBenchmark(BenchmarkCases()[0], state); } void BM_MultipleReferences(benchmark::State& state) { RunBenchmark(BenchmarkCases()[1], state); } void BM_Nested(benchmark::State& state) { RunBenchmark(BenchmarkCases()[2], state); } void BM_NestedDefinition(benchmark::State& state) { RunBenchmark(BenchmarkCases()[3], state); } void BM_BindOusideLoop(benchmark::State& state) { RunBenchmark(BenchmarkCases()[4], state); } void BM_BindInsideLoop(benchmark::State& state) { RunBenchmark(BenchmarkCases()[5], state); } void BM_BindLoopBind(benchmark::State& state) { RunBenchmark(BenchmarkCases()[6], state); } void BM_TernaryDependsOnBind(benchmark::State& state) { RunBenchmark(BenchmarkCases()[7], state); } void BM_TernaryDoesNotDependOnBind(benchmark::State& state) { RunBenchmark(BenchmarkCases()[8], state); } void BM_TwiceNestedDefinition(benchmark::State& state) { RunBenchmark(BenchmarkCases()[9], state); } BENCHMARK(BM_Simple); BENCHMARK(BM_MultipleReferences); BENCHMARK(BM_Nested); BENCHMARK(BM_NestedDefinition); BENCHMARK(BM_BindOusideLoop); BENCHMARK(BM_BindInsideLoop); BENCHMARK(BM_BindLoopBind); BENCHMARK(BM_TernaryDependsOnBind); BENCHMARK(BM_TernaryDoesNotDependOnBind); BENCHMARK(BM_TwiceNestedDefinition); INSTANTIATE_TEST_SUITE_P(BindingsBenchmarkTest, BindingsBenchmarkTest, ::testing::ValuesIn(BenchmarkCases())); } // namespace } // namespace cel::extensions ================================================ FILE: extensions/bindings_ext_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "extensions/bindings_ext.h" #include #include #include #include #include #include "cel/expr/syntax.pb.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "base/attribute.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_adapter.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/testing/matchers.h" #include "internal/testing.h" #include "parser/macro.h" #include "parser/parser.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/text_format.h" namespace cel::extensions { namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::expr::conformance::proto2::NestedTestAllTypes; using ::cel::expr::CheckedExpr; using ::cel::expr::Expr; using ::cel::expr::ParsedExpr; using ::cel::expr::SourceInfo; using ::google::api::expr::parser::ParseWithMacros; using ::google::api::expr::runtime::Activation; using ::google::api::expr::runtime::CelExpressionBuilder; using ::google::api::expr::runtime::CelFunction; using ::google::api::expr::runtime::CelFunctionDescriptor; using ::google::api::expr::runtime::CelProtoWrapper; using ::google::api::expr::runtime::CelValue; using 
::google::api::expr::runtime::CreateCelExpressionBuilder; using ::google::api::expr::runtime::FunctionAdapter; using ::google::api::expr::runtime::InterpreterOptions; using ::google::api::expr::runtime::RegisterBuiltinFunctions; using ::google::api::expr::runtime::UnknownProcessingOptions; using ::google::api::expr::runtime::test::IsCelInt64; using ::google::protobuf::Arena; using ::google::protobuf::TextFormat; using ::testing::Contains; using ::testing::HasSubstr; using ::testing::Pair; struct TestInfo { std::string expr; std::string err = ""; }; class TestFunction : public CelFunction { public: explicit TestFunction(absl::string_view name) : CelFunction(CelFunctionDescriptor( name, true, {CelValue::Type::kBool, CelValue::Type::kBool, CelValue::Type::kBool, CelValue::Type::kBool})) {} absl::Status Evaluate(absl::Span args, CelValue* result, Arena* arena) const override { *result = CelValue::CreateBool(true); return absl::OkStatus(); } }; // Test function used to test macro collision and non-expansion. 
constexpr absl::string_view kBind = "bind"; std::unique_ptr CreateBindFunction() { return std::make_unique(kBind); } class BindingsExtTest : public testing::TestWithParam> { protected: const TestInfo& GetTestInfo() { return std::get<0>(GetParam()); } bool GetEnableConstantFolding() { return std::get<1>(GetParam()); } bool GetEnableRecursivePlan() { return std::get<2>(GetParam()); } }; TEST_P(BindingsExtTest, Default) { const TestInfo& test_info = GetTestInfo(); Arena arena; std::vector all_macros = Macro::AllMacros(); std::vector bindings_macros = cel::extensions::bindings_macros(); all_macros.insert(all_macros.end(), bindings_macros.begin(), bindings_macros.end()); auto result = ParseWithMacros(test_info.expr, all_macros, ""); if (!test_info.err.empty()) { EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(test_info.err))); return; } EXPECT_THAT(result, IsOk()); ParsedExpr parsed_expr = *result; Expr expr = parsed_expr.expr(); SourceInfo source_info = parsed_expr.source_info(); // Obtain CEL Expression builder. InterpreterOptions options; options.enable_heterogeneous_equality = true; options.enable_empty_wrapper_null_unboxing = true; options.constant_folding = GetEnableConstantFolding(); options.constant_arena = &arena; options.max_recursion_depth = GetEnableRecursivePlan() ? -1 : 0; std::unique_ptr builder = CreateCelExpressionBuilder(options); ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); // Register builtins and configure the execution environment. ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr, &source_info)); Activation activation; // Run evaluation. 
ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(out.IsBool()) << out.DebugString(); EXPECT_EQ(out.BoolOrDie(), true); } TEST_P(BindingsExtTest, Tracing) { const TestInfo& test_info = GetTestInfo(); Arena arena; std::vector all_macros = Macro::AllMacros(); std::vector bindings_macros = cel::extensions::bindings_macros(); all_macros.insert(all_macros.end(), bindings_macros.begin(), bindings_macros.end()); auto result = ParseWithMacros(test_info.expr, all_macros, ""); if (!test_info.err.empty()) { EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(test_info.err))); return; } EXPECT_THAT(result, IsOk()); ParsedExpr parsed_expr = *result; Expr expr = parsed_expr.expr(); SourceInfo source_info = parsed_expr.source_info(); // Obtain CEL Expression builder. InterpreterOptions options; options.enable_heterogeneous_equality = true; options.enable_empty_wrapper_null_unboxing = true; options.constant_folding = GetEnableConstantFolding(); options.constant_arena = &arena; options.max_recursion_depth = GetEnableRecursivePlan() ? -1 : 0; std::unique_ptr builder = CreateCelExpressionBuilder(options); ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); // Register builtins and configure the execution environment. ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr, &source_info)); Activation activation; // Run evaluation. 
ASSERT_OK_AND_ASSIGN( CelValue out, cel_expr->Trace(activation, &arena, [](int64_t, const CelValue&, google::protobuf::Arena*) { return absl::OkStatus(); })); ASSERT_TRUE(out.IsBool()) << out.DebugString(); EXPECT_EQ(out.BoolOrDie(), true); } INSTANTIATE_TEST_SUITE_P( CelBindingsExtTest, BindingsExtTest, testing::Combine( testing::ValuesIn( {{"cel.bind(t, true, t)"}, {"cel.bind(msg, \"hello\", msg + msg + msg) == " "\"hellohellohello\""}, {"cel.bind(t1, true, cel.bind(t2, true, t1 && t2))"}, {"cel.bind(valid_elems, [1, 2, 3], " "[3, 4, 5].exists(e, e in valid_elems))"}, {"cel.bind(valid_elems, [1, 2, 3], " "![4, 5].exists(e, e in valid_elems))"}, // Implementation detail: bind variables and comprehension // variables get mapped to an int index in the same space. Check // that mixing them works. {R"( cel.bind( my_list, ['a', 'b', 'c'].map(x, x + '_'), [0, 1, 2].map(y, my_list[y] + string(y))) == ['a_0', 'b_1', 'c_2'])"}, // Check scoping rules. {"cel.bind(x, 1, " " cel.bind(x, x + 1, x)) == 2"}, // Testing a bound function with the same macro name, but non-cel // namespace. The function mirrors the macro signature, but just // returns true. {"false.bind(false, false, false)"}, // Error case where the variable name is not a simple identifier. 
{"cel.bind(bad.name, true, bad.name)", "variable name must be a simple identifier"}}), /*constant_folding*/ testing::Bool(), /*recursive_plan*/ testing::Bool())); constexpr absl::string_view kTraceExpr = R"pb( expr: { id: 11 comprehension_expr: { iter_var: "#unused" iter_range: { id: 8 list_expr: {} } accu_var: "x" accu_init: { id: 4 const_expr: { int64_value: 20 } } loop_condition: { id: 9 const_expr: { bool_value: false } } loop_step: { id: 10 ident_expr: { name: "x" } } result: { id: 6 call_expr: { function: "_*_" args: { id: 5 ident_expr: { name: "x" } } args: { id: 7 ident_expr: { name: "x" } } } } } })pb"; TEST(BindingsExtTest, TraceSupport) { ParsedExpr expr; ASSERT_TRUE(TextFormat::ParseFromString(kTraceExpr, &expr)); InterpreterOptions options; options.enable_heterogeneous_equality = true; options.enable_empty_wrapper_null_unboxing = true; std::unique_ptr builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); ASSERT_OK_AND_ASSIGN( auto plan, builder->CreateExpression(&expr.expr(), &expr.source_info())); Activation activation; google::protobuf::Arena arena; absl::flat_hash_map ids; ASSERT_OK_AND_ASSIGN( auto result, plan->Trace(activation, &arena, [&](int64_t id, const CelValue& value, google::protobuf::Arena* arena) { ids[id] = value; return absl::OkStatus(); })); EXPECT_TRUE(result.IsInt64() && result.Int64OrDie() == 400) << result.DebugString(); EXPECT_THAT(ids, Contains(Pair(4, IsCelInt64(20)))); EXPECT_THAT(ids, Contains(Pair(7, IsCelInt64(20)))); } // Test bind expression with nested field selection. // // cel.bind(submsg, // msg.child.child, // (false) ? 
// TestAllTypes{single_int64: -42}.single_int64 : // submsg.payload.single_int64) constexpr absl::string_view kFieldSelectTestExpr = R"pb( reference_map: { key: 4 value: { name: "msg" } } reference_map: { key: 8 value: { overload_id: "conditional" } } reference_map: { key: 9 value: { name: "cel.expr.conformance.proto2.TestAllTypes" } } reference_map: { key: 13 value: { name: "submsg" } } reference_map: { key: 18 value: { name: "submsg" } } type_map: { key: 4 value: { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } } type_map: { key: 5 value: { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } } type_map: { key: 6 value: { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } } type_map: { key: 7 value: { primitive: BOOL } } type_map: { key: 8 value: { primitive: INT64 } } type_map: { key: 9 value: { message_type: "cel.expr.conformance.proto2.TestAllTypes" } } type_map: { key: 11 value: { primitive: INT64 } } type_map: { key: 12 value: { primitive: INT64 } } type_map: { key: 13 value: { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } } type_map: { key: 14 value: { message_type: "cel.expr.conformance.proto2.TestAllTypes" } } type_map: { key: 15 value: { primitive: INT64 } } type_map: { key: 16 value: { list_type: { elem_type: { dyn: {} } } } } type_map: { key: 17 value: { primitive: BOOL } } type_map: { key: 18 value: { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } } type_map: { key: 19 value: { primitive: INT64 } } source_info: { location: "" line_offsets: 120 positions: { key: 1 value: 0 } positions: { key: 2 value: 8 } positions: { key: 3 value: 9 } positions: { key: 4 value: 17 } positions: { key: 5 value: 20 } positions: { key: 6 value: 26 } positions: { key: 7 value: 35 } positions: { key: 8 value: 42 } positions: { key: 9 value: 56 } positions: { key: 10 value: 69 } positions: { key: 11 value: 71 } positions: { key: 12 value: 75 } positions: { key: 13 value: 91 } positions: { key: 
14 value: 97 } positions: { key: 15 value: 105 } positions: { key: 16 value: 8 } positions: { key: 17 value: 8 } positions: { key: 18 value: 8 } positions: { key: 19 value: 8 } macro_calls: { key: 19 value: { call_expr: { target: { id: 1 ident_expr: { name: "cel" } } function: "bind" args: { id: 3 ident_expr: { name: "submsg" } } args: { id: 6 select_expr: { operand: { id: 5 select_expr: { operand: { id: 4 ident_expr: { name: "msg" } } field: "child" } } field: "child" } } args: { id: 8 call_expr: { function: "_?_:_" args: { id: 7 const_expr: { bool_value: false } } args: { id: 12 select_expr: { operand: { id: 9 struct_expr: { message_name: "cel.expr.conformance.proto2.TestAllTypes" entries: { id: 10 field_key: "single_int64" value: { id: 11 const_expr: { int64_value: -42 } } } } } field: "single_int64" } } args: { id: 15 select_expr: { operand: { id: 14 select_expr: { operand: { id: 13 ident_expr: { name: "submsg" } } field: "payload" } } field: "single_int64" } } } } } } } } expr: { id: 19 comprehension_expr: { iter_var: "#unused" iter_range: { id: 16 list_expr: {} } accu_var: "submsg" accu_init: { id: 6 select_expr: { operand: { id: 5 select_expr: { operand: { id: 4 ident_expr: { name: "msg" } } field: "child" } } field: "child" } } loop_condition: { id: 17 const_expr: { bool_value: false } } loop_step: { id: 18 ident_expr: { name: "submsg" } } result: { id: 8 call_expr: { function: "_?_:_" args: { id: 7 const_expr: { bool_value: false } } args: { id: 12 select_expr: { operand: { id: 9 struct_expr: { message_name: "cel.expr.conformance.proto2.TestAllTypes" entries: { id: 10 field_key: "single_int64" value: { id: 11 const_expr: { int64_value: -42 } } } } } field: "single_int64" } } args: { id: 15 select_expr: { operand: { id: 14 select_expr: { operand: { id: 13 ident_expr: { name: "submsg" } } field: "payload" } } field: "single_int64" } } } } } })pb"; class BindingsExtInteractionsTest : public testing::TestWithParam { protected: bool 
GetEnableSelectOptimization() { return GetParam(); } }; TEST_P(BindingsExtInteractionsTest, SelectOptimization) { CheckedExpr expr; ASSERT_TRUE(TextFormat::ParseFromString(kFieldSelectTestExpr, &expr)); InterpreterOptions options; options.enable_empty_wrapper_null_unboxing = true; options.enable_select_optimization = GetEnableSelectOptimization(); std::unique_ptr builder = CreateCelExpressionBuilder(options); ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); // Register builtins and configure the execution environment. ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr)); Arena arena; Activation activation; NestedTestAllTypes msg; msg.mutable_child()->mutable_child()->mutable_payload()->set_single_int64(42); activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&msg, &arena)); // Run evaluation. ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(out.IsInt64()); EXPECT_EQ(out.Int64OrDie(), 42); } TEST_P(BindingsExtInteractionsTest, UnknownAttributesSelectOptimization) { CheckedExpr expr; ASSERT_TRUE(TextFormat::ParseFromString(kFieldSelectTestExpr, &expr)); InterpreterOptions options; options.enable_empty_wrapper_null_unboxing = true; options.unknown_processing = UnknownProcessingOptions::kAttributeOnly; options.enable_select_optimization = GetEnableSelectOptimization(); std::unique_ptr builder = CreateCelExpressionBuilder(options); ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); // Register builtins and configure the execution environment. ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). 
ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr)); Arena arena; Activation activation; activation.set_unknown_attribute_patterns({AttributePattern( "msg", {AttributeQualifierPattern::OfString("child"), AttributeQualifierPattern::OfString("child")})}); NestedTestAllTypes msg; msg.mutable_child()->mutable_child()->mutable_payload()->set_single_int64(42); activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&msg, &arena)); // Run evaluation. ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(out.IsUnknownSet()); EXPECT_THAT(out.UnknownSetOrDie()->unknown_attributes(), testing::ElementsAre( Attribute("msg", {AttributeQualifier::OfString("child"), AttributeQualifier::OfString("child")}))); } TEST_P(BindingsExtInteractionsTest, UnknownAttributeSelectOptimizationReturnValue) { CheckedExpr expr; ASSERT_TRUE(TextFormat::ParseFromString(kFieldSelectTestExpr, &expr)); InterpreterOptions options; options.enable_empty_wrapper_null_unboxing = true; options.unknown_processing = UnknownProcessingOptions::kAttributeOnly; options.enable_select_optimization = GetEnableSelectOptimization(); std::unique_ptr builder = CreateCelExpressionBuilder(options); ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); // Register builtins and configure the execution environment. ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). 
ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr)); Arena arena; Activation activation; activation.set_unknown_attribute_patterns({AttributePattern( "msg", {AttributeQualifierPattern::OfString("child"), AttributeQualifierPattern::OfString("child"), AttributeQualifierPattern::OfString("payload"), AttributeQualifierPattern::OfString("single_int64")})}); NestedTestAllTypes msg; msg.mutable_child()->mutable_child()->mutable_payload()->set_single_int64(42); activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&msg, &arena)); // Run evaluation. ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(out.IsUnknownSet()) << out.DebugString(); EXPECT_THAT(out.UnknownSetOrDie()->unknown_attributes(), testing::ElementsAre(Attribute( "msg", {AttributeQualifier::OfString("child"), AttributeQualifier::OfString("child"), AttributeQualifier::OfString("payload"), AttributeQualifier::OfString("single_int64")}))); } TEST_P(BindingsExtInteractionsTest, MissingAttributesSelectOptimization) { CheckedExpr expr; ASSERT_TRUE(TextFormat::ParseFromString(kFieldSelectTestExpr, &expr)); InterpreterOptions options; options.enable_empty_wrapper_null_unboxing = true; options.enable_missing_attribute_errors = true; options.enable_select_optimization = GetEnableSelectOptimization(); std::unique_ptr builder = CreateCelExpressionBuilder(options); ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); // Register builtins and configure the execution environment. ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). 
ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&expr)); Arena arena; Activation activation; activation.set_missing_attribute_patterns({AttributePattern( "msg", {AttributeQualifierPattern::OfString("child"), AttributeQualifierPattern::OfString("child"), AttributeQualifierPattern::OfString("payload"), AttributeQualifierPattern::OfString("single_int64")})}); NestedTestAllTypes msg; msg.mutable_child()->mutable_child()->mutable_payload()->set_single_int64(42); activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&msg, &arena)); // Run evaluation. ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(out.IsError()) << out.DebugString(); EXPECT_THAT(out.ErrorOrDie()->ToString(), HasSubstr("msg.child.child.payload.single_int64")); } TEST_P(BindingsExtInteractionsTest, UnknownAttribute) { std::vector all_macros = Macro::AllMacros(); std::vector bindings_macros = cel::extensions::bindings_macros(); all_macros.insert(all_macros.end(), bindings_macros.begin(), bindings_macros.end()); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithMacros( R"( cel.bind( x, msg.child.payload.single_int64, x < 42 || 1 == 1))", all_macros)); InterpreterOptions options; options.enable_empty_wrapper_null_unboxing = true; options.unknown_processing = UnknownProcessingOptions::kAttributeOnly; options.enable_select_optimization = GetEnableSelectOptimization(); std::unique_ptr builder = CreateCelExpressionBuilder(options); ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); // Register builtins and configure the execution environment. ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). 
ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression( &expr.expr(), &expr.source_info())); Arena arena; Activation activation; activation.set_unknown_attribute_patterns({AttributePattern( "msg", {AttributeQualifierPattern::OfString("child"), AttributeQualifierPattern::OfString("payload"), AttributeQualifierPattern::OfString("single_int64")})}); NestedTestAllTypes msg; msg.mutable_child()->mutable_child()->mutable_payload()->set_single_int64(42); activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&msg, &arena)); // Run evaluation. ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(out.IsBool()) << out.DebugString(); EXPECT_TRUE(out.BoolOrDie()); } TEST_P(BindingsExtInteractionsTest, UnknownAttributeReturnValue) { std::vector all_macros = Macro::AllMacros(); std::vector bindings_macros = cel::extensions::bindings_macros(); all_macros.insert(all_macros.end(), bindings_macros.begin(), bindings_macros.end()); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithMacros( R"( cel.bind( x, msg.child.payload.single_int64, x))", all_macros)); InterpreterOptions options; options.enable_empty_wrapper_null_unboxing = true; options.unknown_processing = UnknownProcessingOptions::kAttributeOnly; options.enable_select_optimization = GetEnableSelectOptimization(); std::unique_ptr builder = CreateCelExpressionBuilder(options); ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); // Register builtins and configure the execution environment. ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). 
ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression( &expr.expr(), &expr.source_info())); Arena arena; Activation activation; activation.set_unknown_attribute_patterns({AttributePattern( "msg", {AttributeQualifierPattern::OfString("child"), AttributeQualifierPattern::OfString("payload"), AttributeQualifierPattern::OfString("single_int64")})}); NestedTestAllTypes msg; msg.mutable_child()->mutable_child()->mutable_payload()->set_single_int64(42); activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&msg, &arena)); // Run evaluation. ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(out.IsUnknownSet()) << out.DebugString(); EXPECT_THAT(out.UnknownSetOrDie()->unknown_attributes(), testing::ElementsAre(Attribute( "msg", {AttributeQualifier::OfString("child"), AttributeQualifier::OfString("payload"), AttributeQualifier::OfString("single_int64")}))); } TEST_P(BindingsExtInteractionsTest, MissingAttribute) { std::vector all_macros = Macro::AllMacros(); std::vector bindings_macros = cel::extensions::bindings_macros(); all_macros.insert(all_macros.end(), bindings_macros.begin(), bindings_macros.end()); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithMacros( R"( cel.bind( x, msg.child.payload.single_int64, x < 42 || 1 == 2))", all_macros)); InterpreterOptions options; options.enable_empty_wrapper_null_unboxing = true; options.enable_missing_attribute_errors = true; options.enable_select_optimization = GetEnableSelectOptimization(); std::unique_ptr builder = CreateCelExpressionBuilder(options); ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); // Register builtins and configure the execution environment. ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). 
ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression( &expr.expr(), &expr.source_info())); Arena arena; Activation activation; activation.set_missing_attribute_patterns({AttributePattern( "msg", {AttributeQualifierPattern::OfString("child"), AttributeQualifierPattern::OfString("payload"), AttributeQualifierPattern::OfString("single_int64")})}); NestedTestAllTypes msg; msg.mutable_child()->mutable_child()->mutable_payload()->set_single_int64(42); activation.InsertValue("msg", CelProtoWrapper::CreateMessage(&msg, &arena)); // Run evaluation. ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(out.IsError()) << out.DebugString(); EXPECT_THAT(out.ErrorOrDie()->ToString(), HasSubstr("msg.child.payload.single_int64")); } INSTANTIATE_TEST_SUITE_P(BindingsExtInteractionsTest, BindingsExtInteractionsTest, /*enable_select_optimization=*/testing::Bool()); } // namespace } // namespace cel::extensions ================================================ FILE: extensions/comprehensions_v2.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "extensions/comprehensions_v2.h"

#include "absl/base/no_destructor.h"
#include "absl/status/status.h"
#include "checker/internal/builtins_arena.h"
#include "checker/type_checker_builder.h"
#include "common/decl.h"
#include "common/type.h"
#include "compiler/compiler.h"
#include "extensions/comprehensions_v2_macros.h"
#include "internal/status_macros.h"
#include "parser/parser_interface.h"

using ::cel::checker_internal::BuiltinsArena;

namespace cel::extensions {
namespace {

// Type parameter `A`; used as the key type of MapOfAB().
TypeParamType TypeParamA() { return TypeParamType("A"); }

// Type parameter `B`; used as the value type of MapOfAB().
TypeParamType TypeParamB() { return TypeParamType("B"); }

// Returns the type `map(A, B)`. The MapType is constructed once (backed by
// BuiltinsArena()) and a copy is handed out on every call.
Type MapOfAB() {
  static const absl::NoDestructor<MapType> kMapOfAB(
      MapType(BuiltinsArena(), TypeParamA(), TypeParamB()));
  return *kMapOfAB;
}

// Declares `cel.@mapInsert`, the internal helper that the transformMap() and
// transformMapEntry() macro expansions call, with two overloads:
//   cel.@mapInsert(map(A, B), A, B)         -> map(A, B)
//   cel.@mapInsert(map(A, B), map(A, B))    -> map(A, B)
absl::Status AddComprehensionsV2Functions(TypeCheckerBuilder& builder) {
  const Type map_type = MapOfAB();

  FunctionDecl decl;
  decl.set_name("cel.@mapInsert");
  CEL_RETURN_IF_ERROR(decl.AddOverload(
      MakeOverloadDecl("@mapInsert_map_key_value", map_type, map_type,
                       TypeParamA(), TypeParamB())));
  CEL_RETURN_IF_ERROR(decl.AddOverload(MakeOverloadDecl(
      "@mapInsert_map_map", map_type, map_type, map_type)));
  return builder.AddFunction(decl);
}

// Parser hook for the compiler library: installs the two-variable
// comprehension macros.
absl::Status ConfigureParser(ParserBuilder& parser_builder) {
  return RegisterComprehensionsV2Macros(parser_builder);
}

}  // namespace

CompilerLibrary ComprehensionsV2CompilerLibrary() {
  // Macros for the parser plus `cel.@mapInsert` declarations for the checker.
  return CompilerLibrary("cel.lib.ext.comprev2", &ConfigureParser,
                         &AddComprehensionsV2Functions);
}

CheckerLibrary ComprehensionsV2CheckerLibrary() {
  // Checker-only variant: declarations without the parser macros.
  return CheckerLibrary{"cel.lib.ext.comprev2", &AddComprehensionsV2Functions};
}

}  // namespace cel::extensions
================================================
FILE: extensions/comprehensions_v2.h
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_H_ #include "absl/status/status.h" #include "checker/type_checker_builder.h" #include "compiler/compiler.h" #include "extensions/comprehensions_v2_functions.h" // IWYU pragma: export #include "parser/macro_registry.h" #include "parser/options.h" namespace cel::extensions { // Registers the macros defined by the comprehension v2 extension. absl::Status RegisterComprehensionsV2Macros(MacroRegistry& registry, const ParserOptions& options); // Declarations for the comprehensions v2 extension library. CompilerLibrary ComprehensionsV2CompilerLibrary(); // Declarations for the comprehensions v2 extension library. CheckerLibrary ComprehensionsV2CheckerLibrary(); } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_H_ ================================================ FILE: extensions/comprehensions_v2_functions.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "extensions/comprehensions_v2_functions.h" #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "common/value.h" #include "common/values/map_value_builder.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "internal/status_macros.h" #include "runtime/function_adapter.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel::extensions { namespace { absl::StatusOr MapInsertKeyValue( const MapValue& map, const Value& key, const Value& value, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { if (auto mutable_map_value = common_internal::AsMutableMapValue(map); mutable_map_value) { // Fast path, runtime has given us a mutable map. We can mutate it directly // and return it. CEL_RETURN_IF_ERROR(mutable_map_value->Put(key, value)) .With(ErrorValueReturn()); return map; } // Slow path, we have to make a copy. 
auto builder = NewMapValueBuilder(arena); if (auto size = map.Size(); size.ok()) { builder->Reserve(*size + 1); } else { size.IgnoreError(); } CEL_RETURN_IF_ERROR( map.ForEach( [&builder](const Value& key, const Value& value) -> absl::StatusOr { CEL_RETURN_IF_ERROR(builder->Put(key, value)); return true; }, descriptor_pool, message_factory, arena)) .With(ErrorValueReturn()); CEL_RETURN_IF_ERROR(builder->Put(key, value)).With(ErrorValueReturn()); return std::move(*builder).Build(); } absl::StatusOr MapInsertMap( const MapValue& map, const MapValue& value, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { if (auto mutable_map_value = common_internal::AsMutableMapValue(map); mutable_map_value) { // Fast path, runtime has given us a mutable map. We can mutate it directly // and return it. CEL_RETURN_IF_ERROR( value.ForEach( [&mutable_map_value](const Value& key, const Value& value) -> absl::StatusOr { CEL_RETURN_IF_ERROR(mutable_map_value->Put(key, value)); return true; }, descriptor_pool, message_factory, arena)) .With(ErrorValueReturn()); return map; } // Slow path, we have to make a copy. 
auto builder = NewMapValueBuilder(arena); if (auto size = map.Size(); size.ok()) { builder->Reserve(*size + 1); } else { size.IgnoreError(); } CEL_RETURN_IF_ERROR( map.ForEach( [&builder](const Value& key, const Value& value) -> absl::StatusOr { CEL_RETURN_IF_ERROR(builder->Put(key, value)); return true; }, descriptor_pool, message_factory, arena)) .With(ErrorValueReturn()); CEL_RETURN_IF_ERROR( value.ForEach( [&builder](const Value& key, const Value& value) -> absl::StatusOr { CEL_RETURN_IF_ERROR(builder->Put(key, value)); return true; }, descriptor_pool, message_factory, arena)) .With(ErrorValueReturn()); return std::move(*builder).Build(); } } // namespace absl::Status RegisterComprehensionsV2Functions(FunctionRegistry& registry, const RuntimeOptions& options) { CEL_RETURN_IF_ERROR(registry.Register( TernaryFunctionAdapter, MapValue, Value, Value>::CreateDescriptor("cel.@mapInsert", /*receiver_style=*/false), TernaryFunctionAdapter, MapValue, Value, Value>::WrapFunction(&MapInsertKeyValue))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, MapValue, MapValue>:: CreateDescriptor("cel.@mapInsert", /*receiver_style=*/false), BinaryFunctionAdapter, MapValue, MapValue>::WrapFunction(&MapInsertMap))); return absl::OkStatus(); } absl::Status RegisterComprehensionsV2Functions( google::api::expr::runtime::CelFunctionRegistry* registry, const google::api::expr::runtime::InterpreterOptions& options) { return RegisterComprehensionsV2Functions( registry->InternalGetRegistry(), google::api::expr::runtime::ConvertToRuntimeOptions(options)); } } // namespace cel::extensions ================================================ FILE: extensions/comprehensions_v2_functions.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_FUNCTIONS_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_FUNCTIONS_H_ #include "absl/status/status.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" namespace cel::extensions { // Register comprehension v2 functions. absl::Status RegisterComprehensionsV2Functions(FunctionRegistry& registry, const RuntimeOptions& options); absl::Status RegisterComprehensionsV2Functions( google::api::expr::runtime::CelFunctionRegistry* registry, const google::api::expr::runtime::InterpreterOptions& options); } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_FUNCTIONS_H_ ================================================ FILE: extensions/comprehensions_v2_macros.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "extensions/comprehensions_v2_macros.h"

#include <cstddef>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/no_destructor.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "common/expr.h"
#include "common/operators.h"
#include "internal/status_macros.h"
#include "parser/macro.h"
#include "parser/macro_expr_factory.h"
#include "parser/macro_registry.h"
#include "parser/options.h"
#include "parser/parser_interface.h"

namespace cel::extensions {

namespace {

using ::google::api::expr::common::CelOperator;

// Function-pointer type of the receiver-style expanders defined below.
using ReceiverExpanderFn = absl::optional<Expr> (*)(MacroExprFactory&, Expr&,
                                                   absl::Span<Expr>);

// Validates the arguments common to every two-variable comprehension macro
// `target.<macro_name>(var1, var2, ...)`:
//   * exactly `expected_args` arguments are present;
//   * the first two arguments are non-empty simple identifiers;
//   * the two identifiers are different; and
//   * neither identifier is kDeprecatedAccumulatorVariableName.
// On the first failed check the problem is reported through `factory` (at the
// offending argument when there is one) and the resulting error expression is
// returned. Returns absl::nullopt when every check passes.
absl::optional<Expr> ValidateTwoVarMacroArgs(MacroExprFactory& factory,
                                             absl::string_view macro_name,
                                             size_t expected_args,
                                             absl::Span<Expr> args) {
  if (args.size() != expected_args) {
    return factory.ReportError(
        absl::StrCat(macro_name, "() requires ", expected_args, " arguments"));
  }
  if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) {
    return factory.ReportErrorAt(
        args[0], absl::StrCat(macro_name,
                              "() first variable name must be a simple "
                              "identifier"));
  }
  if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) {
    return factory.ReportErrorAt(
        args[1], absl::StrCat(macro_name,
                              "() second variable name must be a simple "
                              "identifier"));
  }
  if (args[0].ident_expr().name() == args[1].ident_expr().name()) {
    return factory.ReportErrorAt(
        args[0], absl::StrCat(macro_name,
                              "() second variable must be different from the "
                              "first variable"));
  }
  if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) {
    return factory.ReportErrorAt(
        args[0], absl::StrCat(macro_name, "() first variable name cannot be ",
                              kDeprecatedAccumulatorVariableName));
  }
  if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) {
    return factory.ReportErrorAt(
        args[1], absl::StrCat(macro_name, "() second variable name cannot be ",
                              kDeprecatedAccumulatorVariableName));
  }
  return absl::nullopt;
}

// Expands `target.all(var1, var2, predicate)` into a comprehension whose
// accumulator starts at `true`, is ANDed with `predicate` on each step, and
// stops iterating once the accumulator is strictly false.
absl::optional<Expr> ExpandAllMacro2(MacroExprFactory& factory, Expr& target,
                                     absl::Span<Expr> args) {
  absl::optional<Expr> error =
      ValidateTwoVarMacroArgs(factory, "all", 3, args);
  if (error.has_value()) {
    return error;
  }
  auto init = factory.NewBoolConst(true);
  auto condition = factory.NewCall(CelOperator::NOT_STRICTLY_FALSE,
                                   factory.NewAccuIdent());
  auto step = factory.NewCall(CelOperator::LOGICAL_AND, factory.NewAccuIdent(),
                              std::move(args[2]));
  auto result = factory.NewAccuIdent();
  return factory.NewComprehension(
      args[0].ident_expr().name(), args[1].ident_expr().name(),
      std::move(target), factory.AccuVarName(), std::move(init),
      std::move(condition), std::move(step), std::move(result));
}

// Expands `target.exists(var1, var2, predicate)` into a comprehension whose
// accumulator starts at `false`, is ORed with `predicate` on each step, and
// stops iterating once the accumulator is strictly true.
absl::optional<Expr> ExpandExistsMacro2(MacroExprFactory& factory,
                                        Expr& target, absl::Span<Expr> args) {
  absl::optional<Expr> error =
      ValidateTwoVarMacroArgs(factory, "exists", 3, args);
  if (error.has_value()) {
    return error;
  }
  auto init = factory.NewBoolConst(false);
  auto condition = factory.NewCall(
      CelOperator::NOT_STRICTLY_FALSE,
      factory.NewCall(CelOperator::LOGICAL_NOT, factory.NewAccuIdent()));
  auto step = factory.NewCall(CelOperator::LOGICAL_OR, factory.NewAccuIdent(),
                              std::move(args[2]));
  auto result = factory.NewAccuIdent();
  return factory.NewComprehension(
      args[0].ident_expr().name(), args[1].ident_expr().name(),
      std::move(target), factory.AccuVarName(), std::move(init),
      std::move(condition), std::move(step), std::move(result));
}

// Expands `target.existsOne(var1, var2, predicate)` into a comprehension that
// counts (without short-circuiting) the entries for which `predicate` holds
// and yields `count == 1`.
absl::optional<Expr> ExpandExistsOneMacro2(MacroExprFactory& factory,
                                           Expr& target,
                                           absl::Span<Expr> args) {
  absl::optional<Expr> error =
      ValidateTwoVarMacroArgs(factory, "existsOne", 3, args);
  if (error.has_value()) {
    return error;
  }
  auto init = factory.NewIntConst(0);
  auto condition = factory.NewBoolConst(true);
  auto step =
      factory.NewCall(CelOperator::CONDITIONAL, std::move(args[2]),
                      factory.NewCall(CelOperator::ADD, factory.NewAccuIdent(),
                                      factory.NewIntConst(1)),
                      factory.NewAccuIdent());
  auto result = factory.NewCall(CelOperator::EQUALS, factory.NewAccuIdent(),
                                factory.NewIntConst(1));
  return factory.NewComprehension(
      args[0].ident_expr().name(), args[1].ident_expr().name(),
      std::move(target), factory.AccuVarName(), std::move(init),
      std::move(condition), std::move(step), std::move(result));
}

// Expands `target.transformList(var1, var2, expr)` into a comprehension that
// appends `expr` for every entry to an initially empty list.
absl::optional<Expr> ExpandTransformList3Macro(MacroExprFactory& factory,
                                               Expr& target,
                                               absl::Span<Expr> args) {
  absl::optional<Expr> error =
      ValidateTwoVarMacroArgs(factory, "transformList", 3, args);
  if (error.has_value()) {
    return error;
  }
  std::string iter_var = args[0].ident_expr().name();
  std::string iter_var2 = args[1].ident_expr().name();
  Expr step = factory.NewCall(
      CelOperator::ADD, factory.NewAccuIdent(),
      factory.NewList(factory.NewListElement(std::move(args[2]))));
  return factory.NewComprehension(std::move(iter_var), std::move(iter_var2),
                                  std::move(target), factory.AccuVarName(),
                                  factory.NewList(), factory.NewBoolConst(true),
                                  std::move(step), factory.NewAccuIdent());
}

// Expands `target.transformList(var1, var2, filter, expr)`: like the
// three-argument form, but `expr` is appended only when `filter` is true.
absl::optional<Expr> ExpandTransformList4Macro(MacroExprFactory& factory,
                                               Expr& target,
                                               absl::Span<Expr> args) {
  absl::optional<Expr> error =
      ValidateTwoVarMacroArgs(factory, "transformList", 4, args);
  if (error.has_value()) {
    return error;
  }
  std::string iter_var = args[0].ident_expr().name();
  std::string iter_var2 = args[1].ident_expr().name();
  Expr step = factory.NewCall(
      CelOperator::ADD, factory.NewAccuIdent(),
      factory.NewList(factory.NewListElement(std::move(args[3]))));
  step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[2]),
                         std::move(step), factory.NewAccuIdent());
  return factory.NewComprehension(std::move(iter_var), std::move(iter_var2),
                                  std::move(target), factory.AccuVarName(),
                                  factory.NewList(), factory.NewBoolConst(true),
                                  std::move(step), factory.NewAccuIdent());
}

// Expands `target.transformMap(var1, var2, expr)` into a comprehension that
// inserts `var1: expr` for every entry into an initially empty map via
// `cel.@mapInsert(accumulator, var1, expr)`.
absl::optional<Expr> ExpandTransformMap3Macro(MacroExprFactory& factory,
                                              Expr& target,
                                              absl::Span<Expr> args) {
  absl::optional<Expr> error =
      ValidateTwoVarMacroArgs(factory, "transformMap", 3, args);
  if (error.has_value()) {
    return error;
  }
  std::string iter_var = args[0].ident_expr().name();
  std::string iter_var2 = args[1].ident_expr().name();
  Expr step = factory.NewCall("cel.@mapInsert", factory.NewAccuIdent(),
                              std::move(args[0]), std::move(args[2]));
  return factory.NewComprehension(std::move(iter_var), std::move(iter_var2),
                                  std::move(target), factory.AccuVarName(),
                                  factory.NewMap(), factory.NewBoolConst(true),
                                  std::move(step), factory.NewAccuIdent());
}

// Expands `target.transformMap(var1, var2, filter, expr)`: like the
// three-argument form, but the entry is inserted only when `filter` is true.
absl::optional<Expr> ExpandTransformMap4Macro(MacroExprFactory& factory,
                                              Expr& target,
                                              absl::Span<Expr> args) {
  absl::optional<Expr> error =
      ValidateTwoVarMacroArgs(factory, "transformMap", 4, args);
  if (error.has_value()) {
    return error;
  }
  std::string iter_var = args[0].ident_expr().name();
  std::string iter_var2 = args[1].ident_expr().name();
  Expr step = factory.NewCall("cel.@mapInsert", factory.NewAccuIdent(),
                              std::move(args[0]), std::move(args[3]));
  step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[2]),
                         std::move(step), factory.NewAccuIdent());
  return factory.NewComprehension(std::move(iter_var), std::move(iter_var2),
                                  std::move(target), factory.AccuVarName(),
                                  factory.NewMap(), factory.NewBoolConst(true),
                                  std::move(step), factory.NewAccuIdent());
}

// Expands `target.transformMapEntry(var1, var2, entry)` into a comprehension
// that merges the map `entry` produced for every element into an initially
// empty map via `cel.@mapInsert(accumulator, entry)`.
absl::optional<Expr> ExpandTransformMapEntry3Macro(MacroExprFactory& factory,
                                                   Expr& target,
                                                   absl::Span<Expr> args) {
  absl::optional<Expr> error =
      ValidateTwoVarMacroArgs(factory, "transformMapEntry", 3, args);
  if (error.has_value()) {
    return error;
  }
  std::string iter_var = args[0].ident_expr().name();
  std::string iter_var2 = args[1].ident_expr().name();
  Expr step = factory.NewCall("cel.@mapInsert", factory.NewAccuIdent(),
                              std::move(args[2]));
  return factory.NewComprehension(std::move(iter_var), std::move(iter_var2),
                                  std::move(target), factory.AccuVarName(),
                                  factory.NewMap(), factory.NewBoolConst(true),
                                  std::move(step), factory.NewAccuIdent());
}

// Expands `target.transformMapEntry(var1, var2, filter, entry)`: like the
// three-argument form, but `entry` is merged only when `filter` is true.
absl::optional<Expr> ExpandTransformMapEntry4Macro(MacroExprFactory& factory,
                                                   Expr& target,
                                                   absl::Span<Expr> args) {
  absl::optional<Expr> error =
      ValidateTwoVarMacroArgs(factory, "transformMapEntry", 4, args);
  if (error.has_value()) {
    return error;
  }
  std::string iter_var = args[0].ident_expr().name();
  std::string iter_var2 = args[1].ident_expr().name();
  Expr step = factory.NewCall("cel.@mapInsert", factory.NewAccuIdent(),
                              std::move(args[3]));
  step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[2]),
                         std::move(step), factory.NewAccuIdent());
  return factory.NewComprehension(std::move(iter_var), std::move(iter_var2),
                                  std::move(target), factory.AccuVarName(),
                                  factory.NewMap(), factory.NewBoolConst(true),
                                  std::move(step), factory.NewAccuIdent());
}

// Builds a receiver-style macro. A failure indicates an invalid macro
// definition in this file, so it is fatal.
Macro MakeReceiverMacro(absl::string_view name, size_t argument_count,
                        ReceiverExpanderFn expander) {
  auto status_or_macro = Macro::Receiver(name, argument_count, expander);
  ABSL_CHECK_OK(status_or_macro);  // Crash OK
  return std::move(*status_or_macro);
}

const Macro& AllMacro2() {
  static const absl::NoDestructor<Macro> macro(
      MakeReceiverMacro(CelOperator::ALL, 3, &ExpandAllMacro2));
  return *macro;
}

const Macro& ExistsMacro2() {
  static const absl::NoDestructor<Macro> macro(
      MakeReceiverMacro(CelOperator::EXISTS, 3, &ExpandExistsMacro2));
  return *macro;
}

const Macro& ExistsOneMacro2() {
  static const absl::NoDestructor<Macro> macro(
      MakeReceiverMacro("existsOne", 3, &ExpandExistsOneMacro2));
  return *macro;
}

const Macro& TransformList3Macro() {
  static const absl::NoDestructor<Macro> macro(
      MakeReceiverMacro("transformList", 3, &ExpandTransformList3Macro));
  return *macro;
}

const Macro& TransformList4Macro() {
  static const absl::NoDestructor<Macro> macro(
      MakeReceiverMacro("transformList", 4, &ExpandTransformList4Macro));
  return *macro;
}

const Macro& TransformMap3Macro() {
  static const absl::NoDestructor<Macro> macro(
      MakeReceiverMacro("transformMap", 3, &ExpandTransformMap3Macro));
  return *macro;
}

const Macro& TransformMap4Macro() {
  static const absl::NoDestructor<Macro> macro(
      MakeReceiverMacro("transformMap", 4, &ExpandTransformMap4Macro));
  return *macro;
}

const Macro& TransformMapEntry3Macro() {
  static const absl::NoDestructor<Macro> macro(MakeReceiverMacro(
      "transformMapEntry", 3, &ExpandTransformMapEntry3Macro));
  return *macro;
}

const Macro& TransformMapEntry4Macro() {
  static const absl::NoDestructor<Macro> macro(MakeReceiverMacro(
      "transformMapEntry", 4, &ExpandTransformMapEntry4Macro));
  return *macro;
}

}  // namespace

// Returns every macro provided by the comprehensions v2 extension.
std::vector<Macro> AllMacros() {
  return {AllMacro2(),          ExistsMacro2(),
          ExistsOneMacro2(),    TransformList3Macro(),
          TransformList4Macro(), TransformMap3Macro(),
          TransformMap4Macro(), TransformMapEntry3Macro(),
          TransformMapEntry4Macro()};
}

// Registers the macros defined by the comprehension v2 extension.
absl::Status RegisterComprehensionsV2Macros(MacroRegistry& registry,
                                            const ParserOptions&) {
  for (const Macro& macro : AllMacros()) {
    CEL_RETURN_IF_ERROR(registry.RegisterMacro(macro));
  }
  return absl::OkStatus();
}

// Registers the macros defined by the comprehension v2 extension.
absl::Status RegisterComprehensionsV2Macros(ParserBuilder& parser_builder) {
  for (const Macro& macro : AllMacros()) {
    CEL_RETURN_IF_ERROR(parser_builder.AddMacro(macro));
  }
  return absl::OkStatus();
}

}  // namespace cel::extensions
================================================
FILE: extensions/comprehensions_v2_macros.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_MACROS_H_
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_MACROS_H_

#include "absl/status/status.h"
#include "compiler/compiler.h"
#include "parser/macro_registry.h"
#include "parser/options.h"

namespace cel::extensions {

// Registers the comprehension v2 macros with a MacroRegistry. These are the
// two-variable receiver forms of all(), exists(), existsOne(), and the
// three- and four-argument forms of transformList(), transformMap() and
// transformMapEntry(). Stops at, and returns, the first registration error.
// `options` is currently unused.
absl::Status RegisterComprehensionsV2Macros(MacroRegistry& registry,
                                            const ParserOptions& options);

// Same macro set as above, added to a ParserBuilder instead (this is what the
// comprehensions v2 compiler library uses). Stops at, and returns, the first
// error from AddMacro().
absl::Status RegisterComprehensionsV2Macros(ParserBuilder& parser_builder);

}  // namespace cel::extensions

#endif  // THIRD_PARTY_CEL_CPP_EXTENSIONS_COMPREHENSIONS_V2_MACROS_H_
================================================
FILE: extensions/comprehensions_v2_test.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/comprehensions_v2.h"

#include <memory>
#include <string>
#include <tuple>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "checker/standard_library.h"
#include "checker/validation_result.h"
#include "common/value_testing.h"
#include "common/values/list_value_builder.h"
#include "common/values/map_value_builder.h"
#include "compiler/compiler_factory.h"
#include "compiler/optional.h"
#include "extensions/bindings_ext.h"
#include "extensions/comprehensions_v2_functions.h"
#include "extensions/strings.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "runtime/activation.h"
#include "runtime/optional_types.h"
#include "runtime/runtime.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "google/protobuf/arena.h"

namespace cel::extensions {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::cel::test::BoolValueIs;
using ::cel::test::ErrorValueIs;
using ::testing::HasSubstr;
using ::testing::TestWithParam;

// Compiles `expression` with the standard, optional, bindings, strings and
// comprehensions v2 libraries, then plans it on a runtime configured with the
// given accumulator mode and recursion limit. Type-check failures are
// reported as kInvalidArgument carrying the formatted checker issues.
absl::StatusOr<std::unique_ptr<Program>> CreateProgram(
    const std::string& expression, bool enable_mutable_accumulator,
    int max_recursion_depth) {
  // Configure the compiler
  CEL_ASSIGN_OR_RETURN(
      auto compiler_builder,
      NewCompilerBuilder(internal::GetTestingDescriptorPool()));
  CEL_RETURN_IF_ERROR(compiler_builder->AddLibrary(StandardCheckerLibrary()));
  CEL_RETURN_IF_ERROR(compiler_builder->AddLibrary(OptionalCompilerLibrary()));
  CEL_RETURN_IF_ERROR(compiler_builder->AddLibrary(BindingsCompilerLibrary()));
  CEL_RETURN_IF_ERROR(compiler_builder->AddLibrary(StringsCompilerLibrary()));
  CEL_RETURN_IF_ERROR(compiler_builder->AddLibrary(
      extensions::ComprehensionsV2CompilerLibrary()));
  CEL_ASSIGN_OR_RETURN(auto compiler, std::move(*compiler_builder).Build());

  // Configure the runtime
  cel::RuntimeOptions options;
  options.enable_qualified_type_identifiers = true;
  options.enable_comprehension_list_append = enable_mutable_accumulator;
  options.enable_comprehension_mutable_map = enable_mutable_accumulator;
  options.max_recursion_depth = max_recursion_depth;
  CEL_ASSIGN_OR_RETURN(auto runtime_builder,
                       CreateStandardRuntimeBuilder(
                           internal::GetTestingDescriptorPool(), options));
  CEL_RETURN_IF_ERROR(EnableOptionalTypes(runtime_builder));
  CEL_RETURN_IF_ERROR(
      RegisterStringsFunctions(runtime_builder.function_registry(), options));
  CEL_RETURN_IF_ERROR(RegisterComprehensionsV2Functions(
      runtime_builder.function_registry(), options));
  CEL_ASSIGN_OR_RETURN(std::unique_ptr<Runtime> runtime,
                       std::move(runtime_builder).Build());

  CEL_ASSIGN_OR_RETURN(ValidationResult result, compiler->Compile(expression));
  if (!result.IsValid()) {
    return absl::Status(absl::StatusCode::kInvalidArgument,
                        result.FormatError());
  }
  return runtime->CreateProgram(*result.ReleaseAst());
}

struct TestOptions {
  bool enable_mutable_accumulator;
  int max_recursion_depth;
};

// One expression under test. The fields encode three kinds of expectation:
//  * both defaulted: the expression compiles and evaluates to `true`;
//  * `expected_error` set, status code left at kOk: compilation must fail
//    with a kInvalidArgument status containing `expected_error`;
//  * status code set to a non-OK value: the expression compiles and evaluates
//    to an error value with that code and message.
struct ComprehensionsV2TestCase {
  std::string expression;
  absl::StatusCode expected_status_code = absl::StatusCode::kOk;
  std::string expected_error;
};

class ComprehensionsV2Test
    : public TestWithParam<std::tuple<ComprehensionsV2TestCase, TestOptions>> {
};

TEST_P(ComprehensionsV2Test, Basic) {
  const ComprehensionsV2TestCase& test_case = std::get<0>(GetParam());
  const TestOptions& options = std::get<1>(GetParam());
  const bool expects_compile_error =
      test_case.expected_status_code == absl::StatusCode::kOk &&
      !test_case.expected_error.empty();

  absl::StatusOr<std::unique_ptr<Program>> program =
      CreateProgram(test_case.expression, options.enable_mutable_accumulator,
                    options.max_recursion_depth);
  if (expects_compile_error) {
    // Always check the status here: inspecting it only when compilation
    // happened to fail would let a case pass even if the expected error was
    // never raised.
    EXPECT_THAT(program, StatusIs(absl::StatusCode::kInvalidArgument,
                                  HasSubstr(test_case.expected_error)))
        << test_case.expression;
    return;
  }
  // Conversely, an unexpected compile error must fail the test instead of
  // being matched against an empty (always-matching) substring.
  ASSERT_THAT(program, IsOk()) << test_case.expression;

  google::protobuf::Arena arena;
  Activation activation;
  if (test_case.expected_status_code == absl::StatusCode::kOk) {
    EXPECT_THAT(program.value()->Evaluate(&arena, activation),
                IsOkAndHolds(BoolValueIs(true)))
        << test_case.expression;
  } else {
    EXPECT_THAT(program.value()->Evaluate(&arena, activation),
                IsOkAndHolds(ErrorValueIs(StatusIs(
                    test_case.expected_status_code, test_case.expected_error))))
        << test_case.expression;
  }
}

INSTANTIATE_TEST_SUITE_P(
    ComprehensionsV2Test, ComprehensionsV2Test,
    ::testing::Combine(
        ::testing::ValuesIn<ComprehensionsV2TestCase>({
            // list.all()
            {.expression = "[1, 2, 3, 4].all(i, v, i < 5 && v > 0)"},
            {.expression = "[1, 2, 3, 4].all(i, v, i < v)"},
            {.expression = "[1, 2, 3, 4].all(i, v, i > v) == false"},
            {
                .expression =
                    R"cel(cel.bind(listA, [1, 2, 3, 4], cel.bind(listB, [1, 2, 3, 4, 5], listA.all(i, v, listB[?i].hasValue() && listB[i] == v))))cel",
            },
            {
                .expression =
                    R"cel(cel.bind(listA, [1, 2, 3, 4, 5, 6], cel.bind(listB, [1, 2, 3, 4, 5], listA.all(i, v, listB[?i].hasValue() && listB[i] == v))) == false)cel",
            },
            {
                .expression = "[].all(__result__, v, v == 0)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "[].all(i, __result__, i == 0)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "[].all(e, e, e == e)",
                .expected_error =
                    "second variable must be different from the first variable",
            },
            {
                .expression = "[].all(foo.bar, e, true)",
                .expected_error =
                    "first variable name must be a simple identifier",
            },
            {
                .expression = "[].all(e, foo.bar, true)",
                .expected_error =
                    "second variable name must be a simple identifier",
            },
            // list.exists()
            {
                .expression =
                    R"cel(cel.bind(l, ['hello', 'world', 'hello!', 'worlds'], l.exists(i, v, v.startsWith('hello') && l[?(i+1)].optMap(next, next.endsWith('world')).orValue(false))))cel",
            },
            {
                .expression = "[].exists(__result__, v, v == 0)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "[].exists(i, __result__, i == 0)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "[].exists(e, e, e == e)",
                .expected_error =
                    "second variable must be different from the first variable",
            },
            {
                .expression = "[].exists(foo.bar, e, true)",
                .expected_error =
                    "first variable name must be a simple identifier",
            },
            {
                .expression = "[].exists(e, foo.bar, true)",
                .expected_error =
                    "second variable name must be a simple identifier",
            },
            // list.existsOne()
            {
                .expression =
                    R"cel(cel.bind(l, ['hello', 'world', 'hello!', 'worlds'], l.existsOne(i, v, v.startsWith('hello') && l[?(i+1)].optMap(next, next.endsWith('world')).orValue(false))))cel",
            },
            {
                .expression =
                    R"cel(cel.bind(l, ['hello', 'goodbye', 'hello!', 'goodbye'], l.existsOne(i, v, v.startsWith('hello') && l[?(i+1)].optMap(next, next == 'goodbye').orValue(false))) == false)cel",
            },
            {
                .expression = "[].existsOne(__result__, v, v == 0)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "[].existsOne(i, __result__, i == 0)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "[].existsOne(e, e, e == e)",
                .expected_error =
                    "second variable must be different from the first variable",
            },
            {
                .expression = "[].existsOne(foo.bar, e, true)",
                .expected_error =
                    "first variable name must be a simple identifier",
            },
            {
                .expression = "[].existsOne(e, foo.bar, true)",
                .expected_error =
                    "second variable name must be a simple identifier",
            },
            // list.transformList()
            {
                .expression =
                    R"cel(['Hello', 'world'].transformList(i, v, '[' + string(i) + ']' + v.lowerAscii()) == ['[0]hello', '[1]world'])cel",
            },
            {
                .expression =
                    R"cel(['hello', 'world'].transformList(i, v, v.startsWith('greeting'), '[' + string(i) + ']' + v) == [])cel",
            },
            {
                .expression =
                    R"cel([1, 2, 3].transformList(indexVar, valueVar, (indexVar * valueVar) + valueVar) == [1, 4, 9])cel",
            },
            {
                .expression =
                    R"cel([1, 2, 3].transformList(indexVar, valueVar, indexVar % 2 == 0, (indexVar * valueVar) + valueVar) == [1, 9])cel",
            },
            {
                .expression = "[].transformList(__result__, v, v)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "[].transformList(i, __result__, v)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "[].transformList(e, e, e)",
                .expected_error =
                    "second variable must be different from the first variable",
            },
            {
                .expression = "[].transformList(foo.bar, e, e)",
                .expected_error =
                    "first variable name must be a simple identifier",
            },
            {
                .expression = "[].transformList(e, foo.bar, e)",
                .expected_error =
                    "second variable name must be a simple identifier",
            },
            {
                .expression = "[].transformList(__result__, v, v == 0, v)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "[].transformList(i, __result__, i == 0, v)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "[].transformList(e, e, e == e, e)",
                .expected_error =
                    "second variable must be different from the first variable",
            },
            {
                .expression = "[].transformList(foo.bar, e, true, e)",
                .expected_error =
                    "first variable name must be a simple identifier",
            },
            {
                .expression = "[].transformList(e, foo.bar, true, e)",
                .expected_error =
                    "second variable name must be a simple identifier",
            },
            // list.transformMap()
            {
                .expression =
                    R"cel(['Hello', 'world'].transformMap(i, v, [v.lowerAscii()]) == {0: ['hello'], 1: ['world']})cel",
            },
            {
                .expression =
                    R"cel([1, 2, 3].transformMap(indexVar, valueVar, (indexVar * valueVar) + valueVar) == {0: 1, 1: 4, 2: 9})cel",
            },
            {
                .expression =
                    R"cel([1, 2, 3].transformMap(indexVar, valueVar, indexVar % 2 == 0, (indexVar * valueVar) + valueVar) == {0: 1, 2: 9})cel",
            },
            // map.all()
            {
                .expression =
                    R"cel({'hello': 'world', 'hello!': 'world'}.all(k, v, k.startsWith('hello') && v == 'world'))cel",
            },
            {
                .expression =
                    R"cel({'hello': 'world', 'hello!': 'worlds'}.all(k, v, k.startsWith('hello') && v.endsWith('world')) == false)cel",
            },
            // map.exists()
            {
                .expression =
                    R"cel({'hello': 'world', 'hello!': 'worlds'}.exists(k, v, k.startsWith('hello') && v.endsWith('world')))cel",
            },
            // map.existsOne()
            {
                .expression =
                    R"cel({'hello': 'world', 'hello!': 'worlds'}.existsOne(k, v, k.startsWith('hello') && v.endsWith('world')))cel",
            },
            {
                .expression =
                    R"cel({'hello': 'world', 'hello!': 'wow, world'}.existsOne(k, v, k.startsWith('hello') && v.endsWith('world')) == false)cel",
            },
            // map.transformList()
            {
                .expression =
                    R"cel({'Hello': 'world'}.transformList(k, v, k.lowerAscii() + "=" + v) == ['hello=world'])cel",
            },
            {
                .expression =
                    R"cel({'hello': 'world'}.transformList(k, v, k.startsWith('greeting'), k + "=" + v) == [])cel",
            },
            {
                .expression =
                    R"cel(cel.bind(m, {'farewell': 'goodbye', 'greeting': 'hello'}.transformList(k, _, k), m == ['farewell', 'greeting'] || m == ['greeting', 'farewell']))cel",
            },
            {
                .expression =
                    R"cel(cel.bind(m, {'greeting': 'hello', 'farewell': 'goodbye'}.transformList(_, v, v), m == ['goodbye', 'hello'] || m == ['hello', 'goodbye']))cel",
            },
            // map.transformMap()
            {
                .expression =
                    R"cel({'hello': 'world', 'goodbye': 'cruel world'}.transformMap(k, v, k + ', ' + v + '!') == {'hello': 'hello, world!', 'goodbye': 'goodbye, cruel world!'})cel",
            },
            {
                .expression =
                    R"cel({'hello': 'world', 'goodbye': 'cruel world'}.transformMap(k, v, v.startsWith('world'), k + ", " + v + "!") == {'hello': 'hello, world!'})cel",
            },
            {
                .expression = "{}.transformMap(__result__, v, v)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "{}.transformMap(k, __result__, v)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "{}.transformMap(e, e, e)",
                .expected_error =
                    "second variable must be different from the first variable",
            },
            {
                .expression = "{}.transformMap(foo.bar, e, e)",
                .expected_error =
                    "first variable name must be a simple identifier",
            },
            {
                .expression = "{}.transformMap(e, foo.bar, e)",
                .expected_error =
                    "second variable name must be a simple identifier",
            },
            {
                .expression = "{}.transformMap(__result__, v, v == 0, v)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "{}.transformMap(k, __result__, k == 0, v)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "{}.transformMap(e, e, e == e, e)",
                .expected_error =
                    "second variable must be different from the first variable",
            },
            {
                .expression = "{}.transformMap(foo.bar, e, true, e)",
                .expected_error =
                    "first variable name must be a simple identifier",
            },
            {
                .expression = "{}.transformMap(e, foo.bar, true, e)",
                .expected_error =
                    "second variable name must be a simple identifier",
            },
            // map.transformMapEntry
            {
                .expression =
                    R"cel({'hello': 'world', 'greetings': 'tacocat'}.transformMapEntry(k, v, {v: k}) == {'world': 'hello', 'tacocat': 'greetings'})cel",
            },
            {
                .expression =
                    R"cel({'hello': 'world', 'greetings': 'tacocat'}.transformMapEntry(k, v, {}) == {})cel",
            },
            {
                .expression =
                    R"cel({'a': 'same', 'c': 'same'}.transformMapEntry(k, v, {v: k}))cel",
                .expected_status_code = absl::StatusCode::kAlreadyExists,
                .expected_error = "duplicate key in map",
            },
            {
                .expression = "{}.transformMapEntry(__result__, v, v)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "{}.transformMapEntry(k, __result__, v)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "{}.transformMapEntry(e, e, e)",
                .expected_error =
                    "second variable must be different from the first variable",
            },
            {
                .expression = "{}.transformMapEntry(foo.bar, e, e)",
                .expected_error =
                    "first variable name must be a simple identifier",
            },
            {
                .expression = "{}.transformMapEntry(e, foo.bar, e)",
                .expected_error =
                    "second variable name must be a simple identifier",
            },
            // transformMapEntry(k, v, filter, expr)
            {
                .expression =
                    R"cel({'hello': 'world', 'same': 'same'}.transformMapEntry(k, v, k != v, {v: k}) == {'world': 'hello'})cel",
            },
            {
                .expression = "{}.transformMapEntry(__result__, v, v == 0, v)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "{}.transformMapEntry(k, __result__, k == 0, v)",
                .expected_error = "variable name cannot be __result__",
            },
            {
                .expression = "{}.transformMapEntry(e, e, e == e, e)",
                .expected_error =
                    "second variable must be different from the first variable",
            },
            {
                .expression = "{}.transformMapEntry(foo.bar, e, true, e)",
                .expected_error =
                    "first variable name must be a simple identifier",
            },
            {
                .expression = "{}.transformMapEntry(e, foo.bar, true, e)",
                .expected_error =
                    "second variable name must be a simple identifier",
            },
            // list.transformMapEntry
            {
                .expression =
                    R"cel(['one', 'two'].transformMapEntry(k, v, {k + 1: 'is ' + v}) == {1: 'is one', 2: 'is two'})cel",
            },
        }),
        ::testing::ValuesIn<TestOptions>({
            {
                .enable_mutable_accumulator = true,
                .max_recursion_depth = 0,
            },
            {
                .enable_mutable_accumulator = false,
                .max_recursion_depth = 0,
            },
            {
                .enable_mutable_accumulator = true,
                .max_recursion_depth = -1,
            },
            {
                .enable_mutable_accumulator = false,
                .max_recursion_depth = -1,
            },
        })));

class ComprehensionsV2TestMutableAccumulator
    : public TestWithParam<std::tuple<ComprehensionsV2TestCase, TestOptions>> {
};

// The accumulator produced by transform macros must be a mutable list/map
// exactly when the runtime enables mutable comprehension accumulators.
TEST_P(ComprehensionsV2TestMutableAccumulator, MutableAccumulator) {
  const ComprehensionsV2TestCase& test_case = std::get<0>(GetParam());
  const TestOptions& options = std::get<1>(GetParam());
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<Program> program,
      CreateProgram(test_case.expression, options.enable_mutable_accumulator,
                    options.max_recursion_depth));

  google::protobuf::Arena arena;
  Activation activation;
  ASSERT_OK_AND_ASSIGN(auto result, program->Evaluate(&arena, activation));
  bool is_mutable_accumulator = common_internal::IsMutableListValue(result) ||
                                common_internal::IsMutableMapValue(result);
  EXPECT_EQ(is_mutable_accumulator, options.enable_mutable_accumulator);
}

INSTANTIATE_TEST_SUITE_P(
    ComprehensionsV2Test, ComprehensionsV2TestMutableAccumulator,
    ::testing::Combine(
        ::testing::ValuesIn<ComprehensionsV2TestCase>({
            {.expression =
                 R"cel(['Hello', 'world'].transformList(i, v, i))cel"},
            {
                .expression =
                    R"cel({'hello': 'world'}.transformMap(k, v, k + v))cel",
            },
            {
                .expression =
                    R"cel(['hello', 'world'].transformMap(k, v, v))cel",
            },
            {
                .expression =
                    R"cel({'hello': 'world'}.transformMapEntry(k, v, {v: k}))cel",
            },
            {
                .expression =
                    R"cel(['hello', 'world'].transformMapEntry(k, v, {v: k}))cel",
            },
        }),
        ::testing::ValuesIn<TestOptions>({
            {
                .enable_mutable_accumulator = true,
                .max_recursion_depth = 0,
            },
            {
                .enable_mutable_accumulator = false,
                .max_recursion_depth = 0,
            },
            {
                .enable_mutable_accumulator = true,
                .max_recursion_depth = -1,
            },
            {
                .enable_mutable_accumulator = false,
                .max_recursion_depth = -1,
            },
        })));

}  // namespace
}  // namespace cel::extensions
================================================
FILE: extensions/encoders.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/encoders.h"

#include <string>
#include <utility>

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/escaping.h"
#include "checker/type_checker_builder.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/value.h"
#include "compiler/compiler.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"
#include "internal/status_macros.h"
#include "runtime/function_adapter.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::extensions {
namespace {

// Runtime implementation of `base64.decode(string) -> bytes`.
// Malformed input is reported as a CEL ErrorValue, not as a failed status.
absl::StatusOr<Value> Base64Decode(
    const StringValue& value,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  std::string scratch;
  std::string decoded;
  const bool decoded_ok =
      absl::Base64Unescape(value.NativeString(scratch), &decoded);
  if (!decoded_ok) {
    return ErrorValue{absl::InvalidArgumentError("invalid base64 data")};
  }
  return BytesValue(arena, std::move(decoded));
}

// Runtime implementation of `base64.encode(bytes) -> string`.
absl::StatusOr<Value> Base64Encode(
    const BytesValue& value,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  std::string scratch;
  std::string encoded = absl::Base64Escape(value.NativeString(scratch));
  return StringValue(arena, std::move(encoded));
}

// Declares `base64.decode` and `base64.encode` to the type checker. Both
// declarations are built before either is added, so a construction failure
// leaves `builder` untouched.
absl::Status RegisterEncodersDecls(TypeCheckerBuilder& builder) {
  CEL_ASSIGN_OR_RETURN(
      auto decode_decl,
      MakeFunctionDecl("base64.decode",
                       MakeOverloadDecl("base64_decode_string", BytesType(),
                                        StringType())));
  CEL_ASSIGN_OR_RETURN(
      auto encode_decl,
      MakeFunctionDecl("base64.encode",
                       MakeOverloadDecl("base64_encode_bytes", StringType(),
                                        BytesType())));
  CEL_RETURN_IF_ERROR(builder.AddFunction(decode_decl));
  CEL_RETURN_IF_ERROR(builder.AddFunction(encode_decl));
  return absl::OkStatus();
}

}  // namespace

absl::Status RegisterEncodersFunctions(FunctionRegistry& registry,
                                       const RuntimeOptions&) {
  using DecodeAdapter =
      UnaryFunctionAdapter<absl::StatusOr<Value>, StringValue>;
  using EncodeAdapter = UnaryFunctionAdapter<absl::StatusOr<Value>, BytesValue>;

  // Both functions are global (non-receiver) calls.
  CEL_RETURN_IF_ERROR(registry.Register(
      DecodeAdapter::CreateDescriptor("base64.decode", false),
      DecodeAdapter::WrapFunction(&Base64Decode)));
  CEL_RETURN_IF_ERROR(registry.Register(
      EncodeAdapter::CreateDescriptor("base64.encode", false),
      EncodeAdapter::WrapFunction(&Base64Encode)));
  return absl::OkStatus();
}

absl::Status RegisterEncodersFunctions(
    google::api::expr::runtime::CelFunctionRegistry* absl_nonnull registry,
    const google::api::expr::runtime::InterpreterOptions& options) {
  // Legacy entry point: translate the options and delegate to the modern
  // registry wrapped by `registry`.
  return RegisterEncodersFunctions(
      registry->InternalGetRegistry(),
      google::api::expr::runtime::ConvertToRuntimeOptions(options));
}

CheckerLibrary EncodersCheckerLibrary() {
  return {"cel.lib.ext.encoders", &RegisterEncodersDecls};
}

CompilerLibrary EncodersCompilerLibrary() {
  return CompilerLibrary::FromCheckerLibrary(EncodersCheckerLibrary());
}

}  // namespace cel::extensions
================================================
FILE: extensions/encoders.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_ENCODERS_H_
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_ENCODERS_H_

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "checker/type_checker_builder.h"
#include "compiler/compiler.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"

namespace cel::extensions {

// Register encoders functions.
//
// Adds runtime implementations of `base64.decode(string) -> bytes` and
// `base64.encode(bytes) -> string` to `registry`. Decoding malformed input
// yields a CEL error value ("invalid base64 data") rather than a failed
// status. Returns the first registration error, if any.
// NOTE(review): the current implementation does not consult `options`.
absl::Status RegisterEncodersFunctions(FunctionRegistry& registry,
                                       const RuntimeOptions& options);

// Legacy-API overload: converts `options` to RuntimeOptions and registers the
// same functions into the registry wrapped by `registry`.
absl::Status RegisterEncodersFunctions(
    google::api::expr::runtime::CelFunctionRegistry* absl_nonnull registry,
    const google::api::expr::runtime::InterpreterOptions& options);

// Declarations for the encoders extension library.
//
// Checker library with id "cel.lib.ext.encoders" that declares the
// `base64.decode` and `base64.encode` functions.
CheckerLibrary EncodersCheckerLibrary();

// Compiler library for the encoders extension.
//
// Built from EncodersCheckerLibrary() via CompilerLibrary::FromCheckerLibrary.
CompilerLibrary EncodersCompilerLibrary();

}  // namespace cel::extensions

#endif  // THIRD_PARTY_CEL_CPP_EXTENSIONS_ENCODERS_H_
================================================
FILE: extensions/encoders_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/encoders.h"

#include <memory>
#include <string>
#include <utility>

#include "absl/status/status_matchers.h"
#include "checker/standard_library.h"
#include "checker/validation_result.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "runtime/activation.h"
#include "runtime/runtime.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "google/protobuf/arena.h"

namespace cel::extensions {
namespace {

using ::absl_testing::IsOk;

// A CEL expression that must type-check and evaluate to `true`.
struct TestCase {
  std::string expr;
};

class EncodersTest : public ::testing::TestWithParam<TestCase> {};

TEST_P(EncodersTest, ParseCheckEval) {
  const TestCase& test_case = GetParam();

  // Configure the compiler.
  ASSERT_OK_AND_ASSIGN(
      auto compiler_builder,
      NewCompilerBuilder(internal::GetTestingDescriptorPool()));
  ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk());
  ASSERT_THAT(
      compiler_builder->AddLibrary(extensions::EncodersCheckerLibrary()),
      IsOk());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<Compiler> compiler,
                       std::move(*compiler_builder).Build());

  // Configure the runtime.
  cel::RuntimeOptions runtime_options;
  ASSERT_OK_AND_ASSIGN(
      auto runtime_builder,
      CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(),
                                   runtime_options));
  ASSERT_THAT(RegisterEncodersFunctions(runtime_builder.function_registry(),
                                        runtime_options),
              IsOk());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<Runtime> runtime,
                       std::move(runtime_builder).Build());

  // Compile, plan, evaluate.
  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       compiler->Compile(test_case.expr));
  // Surface the checker issues so a failing case is diagnosable.
  ASSERT_TRUE(result.IsValid())
      << test_case.expr << ": " << result.FormatError();
  ASSERT_OK_AND_ASSIGN(auto program,
                       runtime->CreateProgram(*result.ReleaseAst()));

  google::protobuf::Arena arena;
  Activation activation;
  ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation));
  // Report the actual value (e.g. an error value) when it is not a bool.
  ASSERT_TRUE(value.IsBool())
      << test_case.expr << " evaluated to " << value.DebugString();
  EXPECT_TRUE(value.GetBool().NativeValue()) << test_case.expr;
}

INSTANTIATE_TEST_SUITE_P(
    EncodersTest, EncodersTest,
    testing::Values(
        TestCase{"base64.encode(b'hello') == 'aGVsbG8='"},
        TestCase{"base64.decode('aGVsbG8=') == b'hello'"},
        TestCase{"base64.decode(base64.encode(b'abc')) == b'abc'"}));

}  // namespace
}  // namespace cel::extensions
================================================
FILE: extensions/formatting.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/formatting.h"

// NOTE(review): in this copy of the file the standard-library header names on
// the following `#include` lines (and template argument lists such as those
// of absl::StatusOr, std::optional, std::numeric_limits, std::make_unsigned
// and static_cast) appear to have been stripped; restore them from the
// canonical source — TODO confirm.
#include
#include
#include
#include
#include
#include
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/container/btree_map.h"
#include "absl/numeric/bits.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/escaping.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "internal/status_macros.h"
#include "runtime/function_adapter.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::extensions {
namespace {

// Used to pick 3, 6 or 9 fractional digits when formatting durations.
static constexpr int32_t kNanosPerMillisecond = 1000000;
static constexpr int32_t kNanosPerMicrosecond = 1000;
// Upper bound accepted for a `.N` precision specifier.
static constexpr int kMaxPrecision = 1000;

// Forward declaration: FormatList and FormatMap format their elements by
// recursing through FormatString.
absl::StatusOr FormatString(
    const Value& value,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND);

// Parses an optional precision specifier of the form `.<digits>` at the start
// of `format`.
//
// If `format` is empty or does not start with '.', returns {0, nullopt}.
// Otherwise returns {index of the first character after the digits, parsed
// precision}. Fails if the digits run to the end of `format`, if they cannot
// be converted to an int (this includes a '.' followed by no digits), or if
// the precision exceeds `max_precision`.
absl::StatusOr>> ParsePrecision(
    absl::string_view format, int max_precision) {
  if (format.empty() || format[0] != '.') return std::pair{0, std::nullopt};
  int64_t i = 1;
  while (i < format.size() && absl::ascii_isdigit(format[i])) {
    ++i;
  }
  if (i == format.size()) {
    return absl::InvalidArgumentError(
        "unable to find end of precision specifier");
  }
  int precision;
  if (!absl::SimpleAtoi(format.substr(1, i - 1), &precision)) {
    return absl::InvalidArgumentError(
        "unable to convert precision specifier to integer");
  }
  if (precision > max_precision) {
    return absl::InvalidArgumentError(
        absl::StrCat("precision specifier exceeds maximum of ",
                     max_precision));
  }
  return std::pair{i, precision};
}

// Formats a duration value as whole seconds plus an optional fraction and an
// "s" suffix, e.g. "3s", "-1.500s", "0.000001s".
//
// The fraction uses 3, 6 or 9 digits depending on whether the sub-second
// remainder is a whole number of milliseconds, microseconds, or neither.
// A zero duration returns the literal "0s". Otherwise the text is appended to
// `scratch` (which is not cleared first) and a view of `scratch` is returned.
absl::StatusOr FormatDuration(
    const Value& value, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  absl::Duration duration = value.GetDuration();
  if (duration == absl::ZeroDuration()) {
    return "0s";
  }
  if (duration < absl::ZeroDuration()) {
    scratch.append("-");
    duration = absl::AbsDuration(duration);
  }
  int64_t seconds = absl::ToInt64Seconds(duration);
  absl::StrAppend(&scratch, seconds);
  int64_t nanos = absl::ToInt64Nanoseconds(duration - absl::Seconds(seconds));
  if (nanos != 0) {
    scratch.append(".");
    if (nanos % kNanosPerMillisecond == 0) {
      scratch.append(absl::StrFormat("%03d", nanos / kNanosPerMillisecond));
    } else if (nanos % kNanosPerMicrosecond == 0) {
      scratch.append(absl::StrFormat("%06d", nanos / kNanosPerMicrosecond));
    } else {
      scratch.append(absl::StrFormat("%09d", nanos));
    }
  }
  scratch.append("s");
  return scratch;
}

// Formats a double in fixed ("%.Nf") or scientific ("%.Ne") notation with
// `precision` digits (default 6). NaN and +/-infinity return the literals
// "NaN", "Infinity" and "-Infinity". Otherwise `scratch` is overwritten.
absl::StatusOr FormatDouble(
    double value, std::optional precision, bool use_scientific_notation,
    std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  static constexpr int kDefaultPrecision = 6;
  if (std::isnan(value)) {
    return "NaN";
  } else if (value == std::numeric_limits::infinity()) {
    return "Infinity";
  } else if (value == -std::numeric_limits::infinity()) {
    return "-Infinity";
  }

  auto format = absl::StrCat("%.", precision.value_or(kDefaultPrecision),
                             use_scientific_notation ? "e" : "f");
  if (use_scientific_notation) {
    scratch = absl::StrFormat(*absl::ParsedFormat<'e'>::New(format), value);
  } else {
    scratch = absl::StrFormat(*absl::ParsedFormat<'f'>::New(format), value);
  }
  return scratch;
}

// Formats a list as "[e1, e2, ...]" (or "[]" when empty), formatting each
// element with FormatString. Clears and overwrites `scratch`.
absl::StatusOr FormatList(
    const Value& value,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  CEL_ASSIGN_OR_RETURN(auto it, value.GetList().NewIterator());
  scratch.clear();
  scratch.push_back('[');
  // Separate buffer for element text so `scratch` is not clobbered.
  std::string value_scratch;
  while (it->HasNext()) {
    CEL_ASSIGN_OR_RETURN(auto next,
                         it->Next(descriptor_pool, message_factory, arena));
    absl::string_view next_str;
    value_scratch.clear();
    CEL_ASSIGN_OR_RETURN(
        next_str, FormatString(next, descriptor_pool, message_factory, arena,
                               value_scratch));
    absl::StrAppend(&scratch, next_str);
    absl::StrAppend(&scratch, ", ");
  }
  // Drop the trailing ", " left by the last element, if any.
  if (scratch.size() > 1) {
    scratch.resize(scratch.size() - 2);
  }
  scratch.push_back(']');
  return scratch;
}

// Formats a map as "{k1: v1, k2: v2, ...}" (or "{}" when empty).
//
// Keys must be strings, bools, ints or uints; any other key kind is an
// InvalidArgument error. Entries are emitted in the iteration order of the
// btree_map keyed by the formatted key text — presumably lexicographic by
// that text; confirm against the stripped key type. Because `emplace` keeps
// the first entry, two keys that format to the same text (e.g. int 1 and
// uint 1) keep only the first-visited value. Clears and overwrites `scratch`.
absl::StatusOr FormatMap(
    const Value& value,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  absl::btree_map value_map;
  std::string value_scratch;
  CEL_RETURN_IF_ERROR(value.GetMap().ForEach(
      [&](const Value& key, const Value& value) -> absl::StatusOr {
        if (key.kind() != ValueKind::kString &&
            key.kind() != ValueKind::kBool && key.kind() != ValueKind::kInt &&
            key.kind() != ValueKind::kUint) {
          return absl::InvalidArgumentError(
              absl::StrCat("map keys must be strings, booleans, integers, or "
                           "unsigned integers, was given ",
                           key.GetTypeName()));
        }
        value_scratch.clear();
        CEL_ASSIGN_OR_RETURN(auto key_str,
                             FormatString(key, descriptor_pool, message_factory,
                                          arena, value_scratch));
        value_map.emplace(key_str, value);
        // Returning true continues the iteration.
        return true;
      },
      descriptor_pool, message_factory, arena));

  scratch.clear();
  scratch.push_back('{');
  for (const auto& [key, value] : value_map) {
    value_scratch.clear();
    CEL_ASSIGN_OR_RETURN(auto value_str,
                         FormatString(value, descriptor_pool, message_factory,
                                      arena, value_scratch));
    absl::StrAppend(&scratch, key, ": ");
    absl::StrAppend(&scratch, value_str);
    absl::StrAppend(&scratch, ", ");
  }
  // Drop the trailing ", " left by the last entry, if any.
  if (scratch.size() > 1) {
    scratch.resize(scratch.size() - 2);
  }
  scratch.push_back('}');
  return scratch;
}

// Formats `value` as the text used by the string (`%s`) clause.
//
// Lists and maps recurse through FormatList / FormatMap. Strings and bytes
// return their native contents, possibly viewed through `scratch`. Null,
// bools, non-finite doubles and types return literals or views that do not
// use `scratch`. For ints, uints, finite doubles and timestamps the text is
// APPENDED to `scratch` without clearing it, so callers should pass an empty
// buffer; FormatList and FormatMap clear theirs before each call. Any other
// kind is an InvalidArgument error.
absl::StatusOr FormatString(
    const Value& value,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  switch (value.kind()) {
    case ValueKind::kList:
      return FormatList(value, descriptor_pool, message_factory, arena,
                        scratch);
    case ValueKind::kMap:
      return FormatMap(value, descriptor_pool, message_factory, arena, scratch);
    case ValueKind::kString:
      return value.GetString().NativeString(scratch);
    case ValueKind::kBytes:
      return value.GetBytes().NativeString(scratch);
    case ValueKind::kNull:
      return "null";
    case ValueKind::kInt:
      absl::StrAppend(&scratch, value.GetInt().NativeValue());
      return scratch;
    case ValueKind::kUint:
      absl::StrAppend(&scratch, value.GetUint().NativeValue());
      return scratch;
    case ValueKind::kDouble: {
      auto number = value.GetDouble().NativeValue();
      if (std::isnan(number)) {
        return "NaN";
      }
      if (number == std::numeric_limits::infinity()) {
        return "Infinity";
      }
      if (number == -std::numeric_limits::infinity()) {
        return "-Infinity";
      }
      absl::StrAppend(&scratch, number);
      return scratch;
    }
    case ValueKind::kTimestamp:
      absl::StrAppend(&scratch, value.DebugString());
      return scratch;
    case ValueKind::kDuration:
      return FormatDuration(value, scratch);
    case ValueKind::kBool:
      if (value.GetBool().NativeValue()) {
        return "true";
      }
      return "false";
    case ValueKind::kType:
      return value.GetType().name();
    default:
      return absl::InvalidArgumentError(absl::StrFormat(
          "could not convert argument %s to string", value.GetTypeName()));
  }
}

// Formats a number for the decimal (`%d`) clause. Ints and uints are printed
// in base 10. Doubles use fixed notation with the default precision. Other
// kinds are an InvalidArgument error. Clears `scratch` first.
absl::StatusOr FormatDecimal(
    const Value& value, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  scratch.clear();
  switch (value.kind()) {
    case ValueKind::kInt:
      absl::StrAppend(&scratch, value.GetInt().NativeValue());
      return scratch;
    case ValueKind::kUint:
      absl::StrAppend(&scratch, value.GetUint().NativeValue());
      return scratch;
    case ValueKind::kDouble:
      return FormatDouble(value.GetDouble().NativeValue(),
                          /*precision=*/std::nullopt,
                          /*use_scientific_notation=*/false, scratch);
    default:
      return absl::InvalidArgumentError(
          absl::StrCat("decimal clause can only be used on numbers, was given ",
                       value.GetTypeName()));
  }
}

// Formats an int, uint or bool for the binary (`%b`) clause.
//
// Bools return "1"/"0" and zero returns "0" (literals). Negative ints are
// printed as '-' followed by the binary magnitude. Otherwise `scratch` is
// resized to exactly the output length and fully overwritten.
absl::StatusOr FormatBinary(
    const Value& value, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  decltype(value.GetUint().NativeValue()) unsigned_value;
  bool sign_bit = false;
  switch (value.kind()) {
    case ValueKind::kInt: {
      auto tmp = value.GetInt().NativeValue();
      if (tmp < 0) {
        sign_bit = true;
        // Negating min int is undefined behavior, so we need to use unsigned
        // arithmetic.
        using unsigned_type = std::make_unsigned::type;
        unsigned_value = -static_cast(tmp);
      } else {
        unsigned_value = tmp;
      }
      break;
    }
    case ValueKind::kUint:
      unsigned_value = value.GetUint().NativeValue();
      break;
    case ValueKind::kBool:
      if (value.GetBool().NativeValue()) {
        return "1";
      }
      return "0";
    default:
      return absl::InvalidArgumentError(absl::StrCat(
          "binary clause can only be used on integers and bools, was given ",
          value.GetTypeName()));
  }

  if (unsigned_value == 0) {
    return "0";
  }

  // One character per significant bit, plus one for the '-' sign if needed.
  int size = absl::bit_width(unsigned_value) + sign_bit;
  scratch.resize(size);
  // Fill from the least-significant bit at the end of the buffer.
  for (int i = size - 1; i >= 0; --i) {
    if (unsigned_value & 1) {
      scratch[i] = '1';
    } else {
      scratch[i] = '0';
    }
    unsigned_value >>= 1;
  }
  if (sign_bit) {
    scratch[0] = '-';
  }
  return scratch;
}

// Formats a value for the hex (`%x` / `%X`) clause.
//
// Strings and bytes are hex-encoded byte by byte. Ints and uints are printed
// in base 16, with a leading '-' for negative ints. Output is upper-cased
// when `use_upper_case` is true. Other kinds are an InvalidArgument error.
// Overwrites `scratch`.
absl::StatusOr FormatHex(
    const Value& value, bool use_upper_case,
    std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  switch (value.kind()) {
    case ValueKind::kString:
      scratch = absl::BytesToHexString(value.GetString().NativeString(scratch));
      break;
    case ValueKind::kBytes:
      scratch = absl::BytesToHexString(value.GetBytes().NativeString(scratch));
      break;
    case ValueKind::kInt: {
      // Golang supports signed hex, but absl::StrFormat does not. To be
      // compatible, we need to add a leading '-' if the value is negative.
      auto tmp = value.GetInt().NativeValue();
      if (tmp < 0) {
        // Negating min int is undefined behavior, so we need to use unsigned
        // arithmetic.
        using unsigned_type = std::make_unsigned::type;
        scratch = absl::StrFormat("-%x", -static_cast(tmp));
      } else {
        scratch = absl::StrFormat("%x", tmp);
      }
      break;
    }
    case ValueKind::kUint:
      scratch = absl::StrFormat("%x", value.GetUint().NativeValue());
      break;
    default:
      return absl::InvalidArgumentError(
          absl::StrCat("hex clause can only be used on integers, byte buffers, "
                       "and strings, was given ",
                       value.GetTypeName()));
  }
  if (use_upper_case) {
    absl::AsciiStrToUpper(&scratch);
  }
  return scratch;
}

// Formats an int or uint for the octal (`%o`) clause, with a leading '-' for
// negative ints. Other kinds are an InvalidArgument error. Overwrites
// `scratch`.
absl::StatusOr FormatOctal(
    const Value& value, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  switch (value.kind()) {
    case ValueKind::kInt: {
      // Golang supports signed octals, but absl::StrFormat does not. To be
      // compatible, we need to add a leading '-' if the value is negative.
      auto tmp = value.GetInt().NativeValue();
      if (tmp < 0) {
        // Negating min int is undefined behavior, so we need to use unsigned
        // arithmetic.
        using unsigned_type = std::make_unsigned::type;
        scratch = absl::StrFormat("-%o", -static_cast(tmp));
      } else {
        scratch = absl::StrFormat("%o", tmp);
      }
      return scratch;
    }
    case ValueKind::kUint:
      scratch = absl::StrFormat("%o", value.GetUint().NativeValue());
      return scratch;
    default:
      return absl::InvalidArgumentError(
          absl::StrCat("octal clause can only be used on integers, was given ",
                       value.GetTypeName()));
  }
}

absl::StatusOr GetDouble(const Value& value, std::string& scratch) {
  if (value.kind() == ValueKind::kString) {
    auto str = value.GetString().NativeString(scratch);
    if (str == "NaN") {
      return std::nan("");
    } else if (str == "Infinity") {
      return std::numeric_limits::infinity();
    } else if (str == "-Infinity") {
      return -std::numeric_limits::infinity();
    } else {
      return absl::InvalidArgumentError(
          absl::StrCat("only \"NaN\", \"Infinity\", and \"-Infinity\" are "
                       "supported for conversion to double: ",
                       str));
    }
  }
  if (value.kind() != ValueKind::kDouble) {
    return absl::InvalidArgumentError(
        absl::StrCat("expected a double but got a ", value.GetTypeName()));
  }
return value.GetDouble().NativeValue(); } absl::StatusOr FormatFixed( const Value& value, std::optional precision, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { CEL_ASSIGN_OR_RETURN(auto number, GetDouble(value, scratch)); return FormatDouble(number, precision, /*use_scientific_notation=*/false, scratch); } absl::StatusOr FormatScientific( const Value& value, std::optional precision, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { CEL_ASSIGN_OR_RETURN(auto number, GetDouble(value, scratch)); return FormatDouble(number, precision, /*use_scientific_notation=*/true, scratch); } absl::StatusOr> ParseAndFormatClause( absl::string_view format, const Value& value, int max_precision, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { CEL_ASSIGN_OR_RETURN(auto precision_pair, ParsePrecision(format, max_precision)); auto [read, precision] = precision_pair; switch (format[read]) { case 's': { CEL_ASSIGN_OR_RETURN(auto result, FormatString(value, descriptor_pool, message_factory, arena, scratch)); return std::pair{read, result}; } case 'd': { CEL_ASSIGN_OR_RETURN(auto result, FormatDecimal(value, scratch)); return std::pair{read, result}; } case 'f': { CEL_ASSIGN_OR_RETURN(auto result, FormatFixed(value, precision, scratch)); return std::pair{read, result}; } case 'e': { CEL_ASSIGN_OR_RETURN(auto result, FormatScientific(value, precision, scratch)); return std::pair{read, result}; } case 'b': { CEL_ASSIGN_OR_RETURN(auto result, FormatBinary(value, scratch)); return std::pair{read, result}; } case 'x': case 'X': { CEL_ASSIGN_OR_RETURN( auto result, FormatHex(value, /*use_upper_case=*/format[read] == 'X', scratch)); return std::pair{read, result}; } case 'o': { CEL_ASSIGN_OR_RETURN(auto result, FormatOctal(value, scratch)); return std::pair{read, result}; } default: return 
absl::InvalidArgumentError(absl::StrFormat( "unrecognized formatting clause \"%c\"", format[read])); } } absl::StatusOr Format( const StringValue& format_value, const ListValue& args, int max_precision, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { std::string format_scratch, clause_scratch; absl::string_view format = format_value.NativeString(format_scratch); std::string result; result.reserve(format.size()); int64_t arg_index = 0; CEL_ASSIGN_OR_RETURN(int64_t args_size, args.Size()); for (int64_t i = 0; i < format.size(); ++i) { clause_scratch.clear(); if (format[i] != '%') { result.push_back(format[i]); continue; } ++i; if (i >= format.size()) { return ErrorValue( absl::InvalidArgumentError("unexpected end of format string")); } if (format[i] == '%') { result.push_back('%'); continue; } if (arg_index >= args_size) { return ErrorValue(absl::InvalidArgumentError( absl::StrFormat("index %d out of range", arg_index))); } CEL_ASSIGN_OR_RETURN(auto value, args.Get(arg_index++, descriptor_pool, message_factory, arena)); auto clause = ParseAndFormatClause(format.substr(i), value, max_precision, descriptor_pool, message_factory, arena, clause_scratch); if (!clause.ok()) { return ErrorValue(std::move(clause).status()); } absl::StrAppend(&result, clause->second); i += clause->first; } return StringValue::From(std::move(result), arena); } } // namespace absl::Status RegisterStringFormattingFunctions( FunctionRegistry& registry, const RuntimeOptions& options, StringsExtensionFormatOptions format_options) { const int max_precision = std::clamp(format_options.max_precision, 0, kMaxPrecision); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StringValue, ListValue>:: CreateDescriptor("format", /*receiver_style=*/true), BinaryFunctionAdapter, StringValue, ListValue>:: WrapFunction( [max_precision]( const StringValue& format, const 
ListValue& args, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { return Format(format, args, max_precision, descriptor_pool, message_factory, arena); }))); return absl::OkStatus(); } } // namespace cel::extensions ================================================ FILE: extensions/formatting.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_FORMATTING_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_FORMATTING_H_ #include "absl/status/status.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" namespace cel::extensions { struct StringsExtensionFormatOptions { // The maximum precision to permit for formatting floating-point numbers. int max_precision = 1000; }; // Register extension functions for string formatting. // // This implements (string).format([args...]) in the strings extension. Most // users should add these functions via `extensions/strings.h` instead. 
absl::Status RegisterStringFormattingFunctions( FunctionRegistry& registry, const RuntimeOptions& options, StringsExtensionFormatOptions format_options = {}); } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_FORMATTING_H_ ================================================ FILE: extensions/formatting_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "extensions/formatting.h" #include #include #include #include #include #include #include #include #include "cel/expr/syntax.pb.h" #include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status_matchers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "common/value.h" #include "extensions/protobuf/runtime_adapter.h" #include "internal/parse_text_proto.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" #include "parser/options.h" #include "parser/parser.h" #include "runtime/activation.h" #include "runtime/runtime.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" namespace cel::extensions { namespace { using 
::absl_testing::IsOk; using ::cel::expr::conformance::proto3::TestAllTypes; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::parser::ParserOptions; using ::testing::HasSubstr; using ::testing::TestWithParam; using ::testing::ValuesIn; using StringFormatLimitsTest = TestWithParam; // Check that formatted floating points are reversible. TEST_P(StringFormatLimitsTest, FormatLimits) { google::protobuf::Arena arena; const RuntimeOptions options; ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( internal::GetTestingDescriptorPool(), options)); ASSERT_THAT( RegisterStringFormattingFunctions(builder.function_registry(), options), IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse(GetParam(), "", ParserOptions{})); ASSERT_OK_AND_ASSIGN(std::unique_ptr program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); Activation activation; static_assert(std::numeric_limits::min_exponent == -1021); for (double x : { 0x1p-1021, 0x3p-1021, std::numeric_limits::epsilon() * 0x1p-3, std::numeric_limits::epsilon() * 0x7p-3, 1.1 / 7.0 * 1e-101, 1.2 / 7.0 * 1e-101, }) { activation.InsertOrAssignValue("x", DoubleValue(x)); ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); ASSERT_TRUE(value.Is()); EXPECT_TRUE(value.GetBool().NativeValue()); } } TEST(StringFormatLimitsTest, MaxPrecisionOption) { google::protobuf::Arena arena; const RuntimeOptions options; StringsExtensionFormatOptions format_options; format_options.max_precision = 99; ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( internal::GetTestingDescriptorPool(), options)); ASSERT_THAT(RegisterStringFormattingFunctions(builder.function_registry(), options, format_options), IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("'%.100f'.format([1.123])", "", ParserOptions{})); 
ASSERT_OK_AND_ASSIGN(std::unique_ptr program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); Activation activation; ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); ASSERT_TRUE(value.Is()); EXPECT_THAT(value.GetError().ToStatus().message(), HasSubstr("precision specifier exceeds maximum of 99")); } INSTANTIATE_TEST_SUITE_P(StringFormatLimitsTest, StringFormatLimitsTest, ValuesIn({ "double('%.326f'.format([x])) == x", "double('%.17e'.format([x])) == x", })); struct FormattingTestCase { std::string name; std::string format; std::string format_args; absl::flat_hash_map> dyn_args; std::string expected; std::optional error = std::nullopt; }; google::protobuf::Arena* GetTestArena() { static absl::NoDestructor arena; return &*arena; } template ParsedMessageValue MakeMessage(absl::string_view text) { return ParsedMessageValue( internal::DynamicParseTextProto(GetTestArena(), text, internal::GetTestingDescriptorPool(), internal::GetTestingMessageFactory()), GetTestArena()); } using StringFormatTest = TestWithParam; TEST_P(StringFormatTest, TestStringFormatting) { const FormattingTestCase& test_case = GetParam(); google::protobuf::Arena arena; const RuntimeOptions options; ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( internal::GetTestingDescriptorPool(), options)); auto registration_status = RegisterStringFormattingFunctions(builder.function_registry(), options); if (test_case.error.has_value() && !registration_status.ok()) { EXPECT_THAT(registration_status.message(), HasSubstr(*test_case.error)); return; } else { ASSERT_THAT(registration_status, IsOk()); } ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); auto expr_str = absl::StrFormat("'''%s'''.format([%s])", test_case.format, test_case.format_args); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse(expr_str, "", ParserOptions{})); ASSERT_OK_AND_ASSIGN(std::unique_ptr program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); Activation activation; for 
(const auto& [name, value] : test_case.dyn_args) { if (std::holds_alternative(value)) { activation.InsertOrAssignValue(name, StringValue{std::get(value)}); } else if (std::holds_alternative(value)) { activation.InsertOrAssignValue(name, BoolValue{std::get(value)}); } else if (std::holds_alternative(value)) { activation.InsertOrAssignValue(name, IntValue{std::get(value)}); } else if (std::holds_alternative(value)) { activation.InsertOrAssignValue(name, IntValue{std::get(value)}); } else if (std::holds_alternative(value)) { activation.InsertOrAssignValue(name, UintValue{std::get(value)}); } else if (std::holds_alternative(value)) { activation.InsertOrAssignValue(name, DoubleValue{std::get(value)}); } else if (std::holds_alternative(value)) { activation.InsertOrAssignValue( name, DurationValue{std::get(value)}); } else if (std::holds_alternative(value)) { activation.InsertOrAssignValue( name, TimestampValue{std::get(value)}); } else if (std::holds_alternative(value)) { activation.InsertOrAssignValue(name, std::get(value)); } } auto result = program->Evaluate(&arena, activation); if (test_case.error.has_value()) { if (result.ok()) { EXPECT_THAT(result->DebugString(), HasSubstr(*test_case.error)); } else { EXPECT_THAT(result.status().message(), HasSubstr(*test_case.error)); } } else { if (!result.ok()) { // Make it easier to debug the test case. ASSERT_THAT(result.status().message(), ""); // Make sure test case stops here. 
ASSERT_TRUE(result.ok()); } ASSERT_TRUE(result->Is()); EXPECT_THAT(result->GetString().ToString(), test_case.expected); } } INSTANTIATE_TEST_SUITE_P( TestStringFormatting, StringFormatTest, ValuesIn({ { .name = "Basic", .format = "%s %s!", .format_args = "'hello', 'world'", .expected = "hello world!", }, { .name = "EscapedPercentSign", .format = "Percent sign %%!", .format_args = "'hello', 'world'", .expected = "Percent sign %!", }, { .name = "IncompleteCase", .format = "%", .format_args = "'hello'", .error = "unexpected end of format string", }, { .name = "MissingFormatArg", .format = "%s", .format_args = "", .error = "index 0 out of range", }, { .name = "MissingFormatArg2", .format = "%s, %s", .format_args = "'hello'", .error = "index 1 out of range", }, { .name = "InvalidPrecision", .format = "%.6", .format_args = "'hello'", .error = "unable to find end of precision specifier", }, { .name = "InvalidPrecision2", .format = "%.f", .format_args = "'hello'", .error = "unable to convert precision specifier to integer", }, { .name = "InvalidPrecision3", .format = "%.", .format_args = "'hello'", .error = "unable to find end of precision specifier", }, { .name = "InvalidPrecisionOutOfRange", .format = "%.1001f", .format_args = "1.2345", .error = "precision specifier exceeds maximum of 100", }, { .name = "DecimalFormatingClause", .format = "int %d, uint %d", .format_args = "-1, uint(2)", .expected = R"(int -1, uint 2)", }, { .name = "OctalFormatingClause", .format = "int %o, uint %o", .format_args = "-10, uint(20)", .expected = R"(int -12, uint 24)", }, { .name = "OctalDoesNotWorkWithDouble", .format = "double %o", .format_args = "double(\"-Inf\")", .error = "octal clause can only be used on integers, was given double", }, { .name = "HexFormatingClause", .format = "int %x, uint %X, string %x, bytes %X", .format_args = "-10, uint(255), 'hello', b'world'", .expected = "int -a, uint FF, string 68656c6c6f, bytes 776F726C64", }, { .name = "HexFormatingClauseLeadingZero", 
.format = "string: %x", .format_args = R"(b'\x00\x00hello\x00')", .expected = "string: 000068656c6c6f00", }, { .name = "HexDoesNotWorkWithDouble", .format = "double %x", .format_args = "double(\"-Inf\")", .error = "hex clause can only be used on integers, byte buffers, " "and strings, was given double", }, { .name = "BinaryFormatingClause", .format = "int %b, uint %b, bool %b, bool %b", .format_args = "-32, uint(20), false, true", .expected = "int -100000, uint 10100, bool 0, bool 1", }, { .name = "BinaryFormatingClauseLimits", .format = "min_int %b, max_int %b, max_uint %b", .format_args = absl::StrCat(std::numeric_limits::min(), ",", std::numeric_limits::max(), ",", std::numeric_limits::max(), "u"), .expected = "min_int " "-10000000000000000000000000000000000000000000000000000" "00000000000, max_int " "111111111111111111111111111111111111111111111111111111" "111111111, max_uint " "111111111111111111111111111111111111111111111111111111" "1111111111", }, { .name = "BinaryFormatingClauseZero", .format = "zero %b", .format_args = "0", .expected = "zero 0", }, { .name = "HexFormatingClauseLimits", .format = "min_int %x, max_int %x, max_uint %x", .format_args = absl::StrCat(std::numeric_limits::min(), ",", std::numeric_limits::max(), ",", std::numeric_limits::max(), "u"), .expected = "min_int -8000000000000000, max_int 7fffffffffffffff, " "max_uint ffffffffffffffff", }, { .name = "OctalFormatingClauseLimits", .format = "min_int %o, max_int %o, max_uint %o", .format_args = absl::StrCat(std::numeric_limits::min(), ",", std::numeric_limits::max(), ",", std::numeric_limits::max(), "u"), .expected = "min_int -1000000000000000000000, max_int " "777777777777777777777, max_uint 1777777777777777777777", }, { .name = "FixedClauseFormatting", .format = "%f", .format_args = "10000.1234", .expected = "10000.123400", }, { .name = "FixedClauseFormattingWithPrecision", .format = "%.2f", .format_args = "10000.1234", .expected = "10000.12", }, { .name = "ListSupportForStringWithQuotes", 
.format = "%s", .format_args = R"(["a\"b","a\\b"])", .expected = "[a\"b, a\\b]", }, { .name = "ListSupportForStringWithDouble", .format = "%s", .format_args = R"([double("NaN"),double("Infinity"), double("-Infinity")])", .expected = "[NaN, Infinity, -Infinity]", }, FormattingTestCase{ .name = "FixedClauseFormattingWithDynArgs", .format = "%.2f %d", .format_args = "arg, message.single_int32", .dyn_args = { {"arg", 10000.1234}, {"message", MakeMessage(R"pb(single_int32: 42)pb")}, }, .expected = "10000.12 42", }, { .name = "NoOp", .format = "no substitution", .expected = "no substitution", }, { .name = "MidStringSubstitution", .format = "str is %s and some more", .format_args = "'filler'", .expected = "str is filler and some more", }, { .name = "PercentEscaping", .format = "%% and also %%", .expected = "% and also %", }, { .name = "SubstitutionInsideEscapedPercentSigns", .format = "%%%s%%", .format_args = "'text'", .expected = "%text%", }, { .name = "SubstitutionWithOneEscapedPercentSignOnTheRight", .format = "%s%%", .format_args = "'percent on the right'", .expected = "percent on the right%", }, { .name = "SubstitutionWithOneEscapedPercentSignOnTheLeft", .format = "%%%s", .format_args = "'percent on the left'", .expected = "%percent on the left", }, { .name = "MultipleSubstitutions", .format = "%d %d %d, %s %s %s, %d %d %d, %s %s %s", .format_args = "1, 2, 3, 'A', 'B', 'C', 4, 5, 6, 'D', 'E', 'F'", .expected = "1 2 3, A B C, 4 5 6, D E F", }, { .name = "PercentSignEscapeSequenceSupport", .format = "\u0025\u0025escaped \u0025s\u0025\u0025", .format_args = "'percent'", .expected = "%escaped percent%", }, { .name = "FixedPointFormattingClause", .format = "%.3f", .format_args = "1.2345", .expected = "1.234", }, { .name = "BinaryFormattingClause", .format = "this is 5 in binary: %b", .format_args = "5", .expected = "this is 5 in binary: 101", }, { .name = "UintSupportForBinaryFormatting", .format = "unsigned 64 in binary: %b", .format_args = "uint(64)", .expected = 
"unsigned 64 in binary: 1000000", }, { .name = "BoolSupportForBinaryFormatting", .format = "bit set from bool: %b", .format_args = "true", .expected = "bit set from bool: 1", }, { .name = "OctalFormattingClause", .format = "%o", .format_args = "11", .expected = "13", }, { .name = "UintSupportForOctalFormattingClause", .format = "this is an unsigned octal: %o", .format_args = "uint(65535)", .expected = "this is an unsigned octal: 177777", }, { .name = "LowercaseHexadecimalFormattingClause", .format = "%x is 20 in hexadecimal", .format_args = "30", .expected = "1e is 20 in hexadecimal", }, { .name = "UppercaseHexadecimalFormattingClause", .format = "%X is 20 in hexadecimal", .format_args = "30", .expected = "1E is 20 in hexadecimal", }, { .name = "UnsignedSupportForHexadecimalFormattingClause", .format = "%X is 6000 in hexadecimal", .format_args = "uint(6000)", .expected = "1770 is 6000 in hexadecimal", }, { .name = "StringSupportWithHexadecimalFormattingClause", .format = "%x", .format_args = R"("Hello world!")", .expected = "48656c6c6f20776f726c6421", }, { .name = "StringSupportWithUppercaseHexadecimalFormattingClause", .format = "%X", .format_args = R"("Hello world!")", .expected = "48656C6C6F20776F726C6421", }, { .name = "ByteSupportWithHexadecimalFormattingClause", .format = "%x", .format_args = R"(b"byte string")", .expected = "6279746520737472696e67", }, { .name = "ByteSupportWithUppercaseHexadecimalFormattingClause", .format = "%X", .format_args = R"(b"byte string")", .expected = "6279746520737472696E67", }, { .name = "ScientificNotationFormattingClause", .format = "%.6e", .format_args = "1052.032911275", .expected = "1.052033e+03", }, { .name = "ScientificNotationFormattingClause2", .format = "%e", .format_args = "1234.0", .expected = "1.234000e+03", }, { .name = "DefaultPrecisionForFixedPointClause", .format = "%f", .format_args = "2.71828", .expected = "2.718280", }, { .name = "DefaultPrecisionForScientificNotation", .format = "%e", .format_args = 
"2.71828", .expected = "2.718280e+00", }, { .name = "NaNSupportForFixedPoint", .format = "%f", .format_args = "\"NaN\"", .expected = "NaN", }, { .name = "PositiveInfinitySupportForFixedPoint", .format = "%f", .format_args = "\"Infinity\"", .expected = "Infinity", }, { .name = "NegativeInfinitySupportForFixedPoint", .format = "%f", .format_args = "\"-Infinity\"", .expected = "-Infinity", }, { .name = "UintSupportForDecimalClause", .format = "%d", .format_args = "uint(64)", .expected = "64", }, { .name = "NullSupportForString", .format = "null: %s", .format_args = "null", .expected = "null: null", }, { .name = "IntSupportForString", .format = "%s", .format_args = "999999999999", .expected = "999999999999", }, { .name = "BytesSupportForString", .format = "some bytes: %s", .format_args = "b\"xyz\"", .expected = "some bytes: xyz", }, { .name = "TypeSupportForString", .format = "type is %s", .format_args = "type(\"test string\")", .expected = "type is string", }, { .name = "TimestampSupportForString", .format = "%s", .format_args = "timestamp(\"2023-02-03T23:31:20+00:00\")", .expected = "2023-02-03T23:31:20Z", }, { .name = "DurationSupportForString", .format = "%s", .format_args = "duration(\"1h45m47s\")", .expected = "6347s", }, { .name = "ListSupportForString", .format = "%s", .format_args = R"(["abc", 3.14, null, [9, 8, 7, 6], timestamp("2023-02-03T23:31:20Z")])", .expected = R"([abc, 3.14, null, [9, 8, 7, 6], 2023-02-03T23:31:20Z])", }, { .name = "MapSupportForString", .format = "%s", .format_args = R"({"key1": b"xyz", "key5": null, "key2": duration("7200s"), "key4": true, "key3": 2.71828})", .expected = R"({key1: xyz, key2: 7200s, key3: 2.71828, key4: true, key5: null})", }, { .name = "MapSupportAllKeyTypes", .format = "map with multiple key types: %s", .format_args = R"({1: "value1", uint(2): "value2", true: double("NaN")})", .expected = "map with multiple key types: {1: value1, 2: value2, " "true: NaN}", }, { .name = "MapAfterDecimalFormatting", .format = "%d %s", 
.format_args = R"(42, {"key": 1})", .expected = "42 {key: 1}", }, { .name = "BooleanSupportForString", .format = "true bool: %s, false bool: %s", .format_args = "true, false", .expected = "true bool: true, false bool: false", }, FormattingTestCase{ .name = "DynTypeSupportForStringFormattingClause", .format = "Dynamic String: %s", .format_args = R"(dynStr)", .dyn_args = {{"dynStr", std::string("a string")}}, .expected = "Dynamic String: a string", }, FormattingTestCase{ .name = "DynTypeSupportForNumbersWithStringFormattingClause", .format = "Dynamic Int Str: %s Dynamic Double Str: %s", .format_args = R"(dynIntStr, dynDoubleStr)", .dyn_args = { {"dynIntStr", 32}, {"dynDoubleStr", 56.8}, }, .expected = "Dynamic Int Str: 32 Dynamic Double Str: 56.8", }, FormattingTestCase{ .name = "DynTypeSupportForIntegerFormattingClause", .format = "Dynamic Int: %d", .format_args = R"(dynInt)", .dyn_args = {{"dynInt", 128}}, .expected = "Dynamic Int: 128", }, FormattingTestCase{ .name = "DynTypeSupportForIntegerFormattingClauseUnsigned", .format = "Dynamic Unsigned Int: %d", .format_args = R"(dynUnsignedInt)", .dyn_args = {{"dynUnsignedInt", uint64_t{256}}}, .expected = "Dynamic Unsigned Int: 256", }, FormattingTestCase{ .name = "DynTypeSupportForHexFormattingClause", .format = "Dynamic Hex Int: %x", .format_args = R"(dynHexInt)", .dyn_args = {{"dynHexInt", 22}}, .expected = "Dynamic Hex Int: 16", }, FormattingTestCase{ .name = "DynTypeSupportForHexFormattingClauseUppercase", .format = "Dynamic Hex Int: %X (uppercase)", .format_args = R"(dynHexInt)", .dyn_args = {{"dynHexInt", 26}}, .expected = "Dynamic Hex Int: 1A (uppercase)", }, FormattingTestCase{ .name = "DynTypeSupportForUnsignedHexFormattingClause", .format = "Dynamic Hex Int: %x (unsigned)", .format_args = R"(dynUnsignedHexInt)", .dyn_args = {{"dynUnsignedHexInt", uint64_t{500}}}, .expected = "Dynamic Hex Int: 1f4 (unsigned)", }, FormattingTestCase{ .name = "DynTypeSupportForFixedPointFormattingClause", .format = "Dynamic 
Double: %.3f", .format_args = R"(dynDouble)", .dyn_args = {{"dynDouble", 4.5}}, .expected = "Dynamic Double: 4.500", }, FormattingTestCase{ .name = "DynTypeSupportForFixedPointFormattingClauseCommaSeparatorL" "ocale", .format = "Dynamic Double: %f", .format_args = R"(dynDouble)", .dyn_args = {{"dynDouble", 4.5}}, .expected = "Dynamic Double: 4.500000", }, FormattingTestCase{ .name = "DynTypeSupportForScientificNotation", .format = "(Dynamic Type) E: %e", .format_args = R"(dynE)", .dyn_args = {{"dynE", 2.71828}}, .expected = "(Dynamic Type) E: 2.718280e+00", }, FormattingTestCase{ .name = "DynTypeNaNInfinitySupportForFixedPoint", .format = "NaN: %f, Infinity: %f", .format_args = R"(dynNaN, dynInf)", .dyn_args = {{"dynNaN", std::nan("")}, {"dynInf", std::numeric_limits::infinity()}}, .expected = "NaN: NaN, Infinity: Infinity", }, FormattingTestCase{ .name = "DynTypeSupportForTimestamp", .format = "Dynamic Type Timestamp: %s", .format_args = R"(dynTime)", .dyn_args = {{"dynTime", absl::FromUnixSeconds(1257894000)}}, .expected = "Dynamic Type Timestamp: 2009-11-10T23:00:00Z", }, FormattingTestCase{ .name = "DynTypeSupportForDuration", .format = "Dynamic Type Duration: %s", .format_args = R"(dynDuration)", .dyn_args = {{"dynDuration", absl::Hours(2) + absl::Minutes(25) + absl::Seconds(47)}}, .expected = "Dynamic Type Duration: 8747s", }, FormattingTestCase{ .name = "DynTypeSupportForMaps", .format = "Dynamic Type Map with Duration: %s", .format_args = R"({6:dyn(duration("422s"))})", .expected = "Dynamic Type Map with Duration: {6: 422s}", }, FormattingTestCase{ .name = "DurationsWithSubseconds", .format = "Durations with subseconds: %s", .format_args = R"([duration("422s"), duration("2s123ms"), duration("1us"), duration("1ns"), duration("-1000000ns")])", .expected = "Durations with subseconds: [422s, 2.123s, 0.000001s, " "0.000000001s, -0.001s]", }, { .name = "UnrecognizedFormattingClause", .format = "%a", .format_args = "1", .error = "unrecognized formatting clause 
\"a\"", }, { .name = "OutOfBoundsArgIndex", .format = "%d %d %d", .format_args = "0, 1", .error = "index 2 out of range", }, { .name = "StringSubstitutionIsNotAllowedWithBinaryClause", .format = "string is %b", .format_args = "\"abc\"", .error = "binary clause can only be used on integers and bools, " "was given string", }, { .name = "DurationSubstitutionIsNotAllowedWithDecimalClause", .format = "%d", .format_args = "duration(\"30m2s\")", .error = "decimal clause can only be used on numbers, was given " "google.protobuf.Duration", }, { .name = "StringSubstitutionIsNotAllowedWithOctalClause", .format = "octal: %o", .format_args = "\"a string\"", .error = "octal clause can only be used on integers, was given string", }, { .name = "DoubleSubstitutionIsNotAllowedWithHexClause", .format = "double is %x", .format_args = "0.5", .error = "hex clause can only be used on integers, byte buffers, " "and strings, was given double", }, { .name = "UppercaseIsNotAllowedForScientificClause", .format = "double is %E", .format_args = "0.5", .error = "unrecognized formatting clause \"E\"", }, { .name = "ObjectIsNotAllowed", .format = "object is %s", .format_args = "cel.expr.conformance.proto3.TestAllTypes{}", .error = "could not convert argument " "cel.expr.conformance.proto3.TestAllTypes to string", }, { .name = "ObjectInsideList", .format = "%s", .format_args = "[1, 2, cel.expr.conformance.proto3.TestAllTypes{}]", .error = "could not convert argument " "cel.expr.conformance.proto3.TestAllTypes to string", }, { .name = "ObjectInsideMap", .format = "%s", .format_args = "{1: \"a\", 2: cel.expr.conformance.proto3.TestAllTypes{}}", .error = "could not convert argument " "cel.expr.conformance.proto3.TestAllTypes to string", }, { .name = "NullNotAllowedForDecimalClause", .format = "null: %d", .format_args = "null", .error = "decimal clause can only be used on numbers, was given " "null_type", }, { .name = "NullNotAllowedForScientificNotationClause", .format = "null: %e", .format_args = 
"null", .error = "expected a double but got a null_type", }, { .name = "NullNotAllowedForFixedPointClause", .format = "null: %f", .format_args = "null", .error = "expected a double but got a null_type", }, { .name = "NullNotAllowedForHexadecimalClause", .format = "null: %x", .format_args = "null", .error = "hex clause can only be used on integers, byte buffers, " "and strings, was given null_type", }, { .name = "NullNotAllowedForUppercaseHexadecimalClause", .format = "null: %X", .format_args = "null", .error = "hex clause can only be used on integers, byte buffers, " "and strings, was given null_type", }, { .name = "NullNotAllowedForBinaryClause", .format = "null: %b", .format_args = "null", .error = "binary clause can only be used on integers and bools, " "was given null_type", }, { .name = "NullNotAllowedForOctalClause", .format = "null: %o", .format_args = "null", .error = "octal clause can only be used on integers, was given " "null_type", }, { .name = "NegativeBinaryFormattingClause", .format = "this is -5 in binary: %b", .format_args = "-5", .expected = "this is -5 in binary: -101", }, { .name = "NegativeOctalFormattingClause", .format = "%o", .format_args = "-11", .expected = "-13", }, { .name = "NegativeHexadecimalFormattingClause", .format = "%x is -30 in hexadecimal", .format_args = "-30", .expected = "-1e is -30 in hexadecimal", }, { .name = "DefaultPrecisionForString", .format = "%s", .format_args = "2.71", .expected = "2.71", }, { .name = "DefaultListPrecisionForString", .format = "%s", .format_args = "[2.71]", .expected = "[2.71]", // Different from Golang (2.710000) consistent with // the precision of a double outside of a list. }, { .name = "AutomaticRoundingForString", .format = "%s", .format_args = "10002.71", .expected = "10002.7", // Different from Golang (10002.71) which // does not round. 
}, { .name = "DefaultScientificNotationForString", .format = "%s", .format_args = "0.000000002", .expected = "2e-09", }, { .name = "DefaultListScientificNotationForString", .format = "%s", .format_args = "[0.000000002]", .expected = "[2e-09]", // Different from Golang (0.000000) consistent with // the notation of a double outside of a list. }, { .name = "NaNSupportForString", .format = "%s", .format_args = R"(double("NaN"))", .expected = "NaN", }, { .name = "PositiveInfinitySupportForString", .format = "%s", .format_args = R"(double("Inf"))", .expected = "Infinity", }, { .name = "NegativeInfinitySupportForString", .format = "%s", .format_args = R"(double("-Inf"))", .expected = "-Infinity", }, { .name = "InfinityListSupportForString", .format = "%s", .format_args = R"([double("NaN"), double("+Inf"), double("-Inf")])", .expected = "[NaN, Infinity, -Infinity]", }, { .name = "SmallDurationSupportForString", .format = "%s", .format_args = R"(duration("2ns"))", .expected = "0.000000002s", }, }), [](const testing::TestParamInfo& info) { return info.param.name; }); } // namespace } // namespace cel::extensions ================================================ FILE: extensions/lists_functions.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "extensions/lists_functions.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/macros.h"
#include "absl/base/no_destructor.h"
#include "absl/base/nullability.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "checker/internal/builtins_arena.h"
#include "checker/type_checker_builder.h"
#include "common/decl.h"
#include "common/expr.h"
#include "common/operators.h"
#include "common/type.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "compiler/compiler.h"
#include "internal/status_macros.h"
#include "parser/macro.h"
#include "parser/macro_expr_factory.h"
#include "parser/macro_registry.h"
#include "parser/options.h"
#include "parser/parser_interface.h"
#include "runtime/function_adapter.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::extensions {
namespace {

using ::cel::checker_internal::BuiltinsArena;

// Element types for which sort() / @sortByAssociatedKeys() overloads are
// declared to the type checker (see RegisterListsCheckerDecls). The array is a
// function-local static, so the returned span stays valid for the program's
// lifetime.
absl::Span<const Type> SortableTypes() {
  static const Type kTypes[]{cel::IntType(),      cel::UintType(),
                             cel::DoubleType(),   cel::BoolType(),
                             cel::DurationType(), cel::TimestampType(),
                             cel::StringType(),   cel::BytesType()};
  return kTypes;
}

// Slow distinct() implementation that uses Equal() to compare values in O(n^2).
absl::Status ListDistinctHeterogeneousImpl( const ListValue& list, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValueBuilder* absl_nonnull builder, int64_t start_index = 0, std::vector seen = {}) { CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); for (int64_t i = start_index; i < size; ++i) { CEL_ASSIGN_OR_RETURN(Value value, list.Get(i, descriptor_pool, message_factory, arena)); bool is_distinct = true; for (const Value& seen_value : seen) { CEL_ASSIGN_OR_RETURN(Value equal, value.Equal(seen_value, descriptor_pool, message_factory, arena)); if (equal.IsTrue()) { is_distinct = false; break; } } if (is_distinct) { seen.push_back(value); CEL_RETURN_IF_ERROR(builder->Add(value)); } } return absl::OkStatus(); } // Fast distinct() implementation for homogeneous hashable types. Falls back to // the slow implementation if the list is not actually homogeneous. template absl::Status ListDistinctHomogeneousHashableImpl( const ListValue& list, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValueBuilder* absl_nonnull builder) { absl::flat_hash_set seen; CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); for (int64_t i = 0; i < size; ++i) { CEL_ASSIGN_OR_RETURN(Value value, list.Get(i, descriptor_pool, message_factory, arena)); if (auto typed_value = value.As(); typed_value.has_value()) { if (seen.contains(*typed_value)) { continue; } seen.insert(*typed_value); CEL_RETURN_IF_ERROR(builder->Add(value)); } else { // List is not homogeneous, fall back to the slow implementation. // Keep the existing list builder, which already constructed the list of // all the distinct values (that were homogeneous so far) up to index i. // Pass the seen values as a vector to the slow implementation. 
std::vector seen_values{seen.begin(), seen.end()}; return ListDistinctHeterogeneousImpl(list, descriptor_pool, message_factory, arena, builder, i, std::move(seen_values)); } } return absl::OkStatus(); } absl::StatusOr ListDistinct( const ListValue& list, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); // If the list is empty or has a single element, we can return it as is. if (size < 2) { return list; } // We need a set to keep track of the seen values. // // By default, for unhashable types, this set is implemented as a vector of // all the seen values, which means that we will perform O(n^2) comparisons // between the values. // // For efficiency purposes, if the first element of the list is hashable, we // will use a specialized implementation that is faster for homogeneous lists // of hashable types. // If the list is not homogeneous, we will fall back to the slow // implementation. // // The total runtime cost is O(n) for homogeneous lists of hashable types, and // O(n^2) for all other cases. 
auto builder = NewListValueBuilder(arena); CEL_ASSIGN_OR_RETURN(Value first, list.Get(0, descriptor_pool, message_factory, arena)); switch (first.kind()) { case ValueKind::kInt: { CEL_RETURN_IF_ERROR(ListDistinctHomogeneousHashableImpl( list, descriptor_pool, message_factory, arena, builder.get())); break; } case ValueKind::kUint: { CEL_RETURN_IF_ERROR(ListDistinctHomogeneousHashableImpl( list, descriptor_pool, message_factory, arena, builder.get())); break; } case ValueKind::kBool: { CEL_RETURN_IF_ERROR(ListDistinctHomogeneousHashableImpl( list, descriptor_pool, message_factory, arena, builder.get())); break; } case ValueKind::kString: { CEL_RETURN_IF_ERROR(ListDistinctHomogeneousHashableImpl( list, descriptor_pool, message_factory, arena, builder.get())); break; } default: { CEL_RETURN_IF_ERROR(ListDistinctHeterogeneousImpl( list, descriptor_pool, message_factory, arena, builder.get())); break; } } return std::move(*builder).Build(); } absl::Status ListFlattenImpl( const ListValue& list, int64_t remaining_depth, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValueBuilder* absl_nonnull builder) { CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); for (int64_t i = 0; i < size; ++i) { CEL_ASSIGN_OR_RETURN(Value value, list.Get(i, descriptor_pool, message_factory, arena)); if (absl::optional list_value = value.AsList(); list_value.has_value() && remaining_depth > 0) { CEL_RETURN_IF_ERROR(ListFlattenImpl(*list_value, remaining_depth - 1, descriptor_pool, message_factory, arena, builder)); } else { CEL_RETURN_IF_ERROR(builder->Add(std::move(value))); } } return absl::OkStatus(); } absl::StatusOr ListFlatten( const ListValue& list, int64_t depth, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { if (depth < 0) 
{ return ErrorValue( absl::InvalidArgumentError("flatten(): level must be non-negative")); } auto builder = NewListValueBuilder(arena); CEL_RETURN_IF_ERROR(ListFlattenImpl(list, depth, descriptor_pool, message_factory, arena, builder.get())); return std::move(*builder).Build(); } absl::StatusOr ListRange( int64_t end, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { auto builder = NewListValueBuilder(arena); builder->Reserve(end); for (int64_t i = 0; i < end; ++i) { CEL_RETURN_IF_ERROR(builder->Add(IntValue(i))); } return std::move(*builder).Build(); } absl::StatusOr ListReverse( const ListValue& list, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { auto builder = NewListValueBuilder(arena); CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); for (ptrdiff_t i = size - 1; i >= 0; --i) { CEL_ASSIGN_OR_RETURN(Value value, list.Get(i, descriptor_pool, message_factory, arena)); CEL_RETURN_IF_ERROR(builder->Add(value)); } return std::move(*builder).Build(); } absl::StatusOr ListSlice( const ListValue& list, int64_t start, int64_t end, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); if (start < 0 || end < 0) { return ErrorValue(absl::InvalidArgumentError(absl::StrFormat( "cannot slice(%d, %d), negative indexes not supported", start, end))); } if (start > end) { return cel::ErrorValue(absl::InvalidArgumentError( absl::StrFormat("cannot slice(%d, %d), start index must be less than " "or equal to end index", start, end))); } if (size < end) { return cel::ErrorValue(absl::InvalidArgumentError(absl::StrFormat( "cannot slice(%d, %d), list is 
length %d", start, end, size))); } auto builder = NewListValueBuilder(arena); for (int64_t i = start; i < end; ++i) { CEL_ASSIGN_OR_RETURN(Value val, list.Get(i, descriptor_pool, message_factory, arena)); CEL_RETURN_IF_ERROR(builder->Add(val)); } return std::move(*builder).Build(); } template absl::StatusOr ListSortByAssociatedKeysNative( const ListValue& list, const ListValue& keys, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); // If the list is empty or has a single element, we can return it as is. if (size < 2) { return list; } std::vector keys_vec; absl::Status status = keys.ForEach( [&keys_vec](const Value& value) -> absl::StatusOr { if (auto typed_value = value.As(); typed_value.has_value()) { keys_vec.push_back(*typed_value); } else { return absl::InvalidArgumentError( "sort(): list elements must have the same type"); } return true; }, descriptor_pool, message_factory, arena); if (!status.ok()) { return ErrorValue(status); } ABSL_ASSERT(keys_vec.size() == size); // Already checked by the caller. std::vector sorted_indices(keys_vec.size()); std::iota(sorted_indices.begin(), sorted_indices.end(), 0); std::sort( sorted_indices.begin(), sorted_indices.end(), [&](int64_t a, int64_t b) -> bool { return keys_vec[a] < keys_vec[b]; }); // Now sorted_indices contains the indices of the keys in sorted order. // We can use it to build the sorted list. auto builder = NewListValueBuilder(arena); for (const auto& index : sorted_indices) { CEL_ASSIGN_OR_RETURN( Value value, list.Get(index, descriptor_pool, message_factory, arena)); CEL_RETURN_IF_ERROR(builder->Add(value)); } return std::move(*builder).Build(); } // Internal function used for the implementation of sort() and sortBy(). 
// // Sorts a list of arbitrary elements, according to the order produced by // sorting another list of comparable elements. If the element type of the keys // is not comparable or the element types are not the same, the function will // produce an error. // // .@sortByAssociatedKeys() -> // U in {int, uint, double, bool, duration, timestamp, string, bytes} // // Example: // // ["foo", "bar", "baz"].@sortByAssociatedKeys([3, 1, 2]) // -> returns ["bar", "baz", "foo"] absl::StatusOr ListSortByAssociatedKeys( const ListValue& list, const ListValue& keys, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { CEL_ASSIGN_OR_RETURN(size_t list_size, list.Size()); CEL_ASSIGN_OR_RETURN(size_t keys_size, keys.Size()); if (list_size != keys_size) { return ErrorValue(absl::InvalidArgumentError( absl::StrFormat("@sortByAssociatedKeys() expected a list of the same " "size as the associated keys list, but got %d and %d " "elements respectively.", list_size, keys_size))); } // Empty lists are already sorted. // We don't check for size == 1 because the list could contain a single // element of a type that is not supported by this function. 
if (list_size == 0) { return list; } CEL_ASSIGN_OR_RETURN(Value first, keys.Get(0, descriptor_pool, message_factory, arena)); switch (first.kind()) { case ValueKind::kInt: return ListSortByAssociatedKeysNative( list, keys, descriptor_pool, message_factory, arena); case ValueKind::kUint: return ListSortByAssociatedKeysNative( list, keys, descriptor_pool, message_factory, arena); case ValueKind::kDouble: return ListSortByAssociatedKeysNative( list, keys, descriptor_pool, message_factory, arena); case ValueKind::kBool: return ListSortByAssociatedKeysNative( list, keys, descriptor_pool, message_factory, arena); case ValueKind::kString: return ListSortByAssociatedKeysNative( list, keys, descriptor_pool, message_factory, arena); case ValueKind::kTimestamp: return ListSortByAssociatedKeysNative( list, keys, descriptor_pool, message_factory, arena); case ValueKind::kDuration: return ListSortByAssociatedKeysNative( list, keys, descriptor_pool, message_factory, arena); case ValueKind::kBytes: return ListSortByAssociatedKeysNative( list, keys, descriptor_pool, message_factory, arena); default: return ErrorValue(absl::InvalidArgumentError( absl::StrFormat("sort(): unsupported type %s", first.GetTypeName()))); } } // Create an expression equivalent to: // target.map(varIdent, mapExpr) absl::optional MakeMapComprehension(MacroExprFactory& factory, Expr target, Expr var_ident, Expr map_expr) { auto step = factory.NewCall( google::api::expr::common::CelOperator::ADD, factory.NewAccuIdent(), factory.NewList(factory.NewListElement(std::move(map_expr)))); auto var_name = var_ident.ident_expr().name(); return factory.NewComprehension(std::move(var_name), std::move(target), factory.AccuVarName(), factory.NewList(), factory.NewBoolConst(true), std::move(step), factory.NewAccuIdent()); } // Create an expression equivalent to: // cel.bind(varIdent, varExpr, call_expr) absl::optional MakeBindComprehension(MacroExprFactory& factory, Expr var_ident, Expr var_expr, Expr call_expr) { auto 
var_name = var_ident.ident_expr().name(); return factory.NewComprehension( "#unused", factory.NewList(), std::move(var_name), std::move(var_expr), factory.NewBoolConst(false), std::move(var_ident), std::move(call_expr)); } // This macro transforms an expression like: // // mylistExpr.sortBy(e, -math.abs(e)) // // into something equivalent to: // // cel.bind( // @__sortBy_input__, // myListExpr, // @__sortBy_input__.@sortByAssociatedKeys( // @__sortBy_input__.map(e, -math.abs(e) // ) // ) Macro ListSortByMacro() { absl::StatusOr sortby_macro = Macro::Receiver( "sortBy", 2, [](MacroExprFactory& factory, Expr& target, absl::Span args) -> absl::optional { if (!target.has_ident_expr() && !target.has_select_expr() && !target.has_list_expr() && !target.has_comprehension_expr() && !target.has_call_expr()) { return factory.ReportErrorAt( target, "sortBy can only be applied to a list, identifier, " "comprehension, call or select expression"); } auto sortby_input_ident = factory.NewIdent("@__sortBy_input__"); auto sortby_input_expr = std::move(target); auto key_ident = std::move(args[0]); auto key_expr = std::move(args[1]); // Build the map expression: // map_compr := @__sortBy_input__.map(key_ident, key_expr) auto map_compr = MakeMapComprehension(factory, factory.Copy(sortby_input_ident), std::move(key_ident), std::move(key_expr)); if (!map_compr.has_value()) { return absl::nullopt; } // Build the call expression: // call_expr := @__sortBy_input__.@sortByAssociatedKeys(map_compr) std::vector call_args; call_args.push_back(std::move(*map_compr)); auto call_expr = factory.NewMemberCall("@sortByAssociatedKeys", std::move(sortby_input_ident), absl::MakeSpan(call_args)); // Build the returned bind expression: // cel.bind(@__sortBy_input__, target, call_expr) auto var_ident = factory.NewIdent("@__sortBy_input__"); auto var_expr = std::move(sortby_input_expr); auto bind_compr = MakeBindComprehension(factory, std::move(var_ident), std::move(var_expr), std::move(call_expr)); return 
bind_compr; }); return *sortby_macro; } absl::StatusOr ListSort( const ListValue& list, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { return ListSortByAssociatedKeys(list, list, descriptor_pool, message_factory, arena); } absl::Status RegisterListDistinctFunction(FunctionRegistry& registry) { return UnaryFunctionAdapter, const ListValue&>:: RegisterMemberOverload("distinct", &ListDistinct, registry); } absl::Status RegisterListFlattenFunction(FunctionRegistry& registry) { CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter, const ListValue&, int64_t>::RegisterMemberOverload("flatten", &ListFlatten, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter, const ListValue&>:: RegisterMemberOverload( "flatten", [](const ListValue& list, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { return ListFlatten(list, 1, descriptor_pool, message_factory, arena); }, registry))); return absl::OkStatus(); } absl::Status RegisterListRangeFunction(FunctionRegistry& registry) { return UnaryFunctionAdapter, int64_t>::RegisterGlobalOverload("lists.range", &ListRange, registry); } absl::Status RegisterListReverseFunction(FunctionRegistry& registry) { return UnaryFunctionAdapter, const ListValue&>:: RegisterMemberOverload("reverse", &ListReverse, registry); } absl::Status RegisterListSliceFunction(FunctionRegistry& registry) { return TernaryFunctionAdapter, const ListValue&, int64_t, int64_t>::RegisterMemberOverload("slice", &ListSlice, registry); } absl::Status RegisterListSortFunction(FunctionRegistry& registry) { CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter, const ListValue&>:: RegisterMemberOverload("sort", &ListSort, registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter< absl::StatusOr, const ListValue&, const 
ListValue&>::RegisterMemberOverload("@sortByAssociatedKeys", &ListSortByAssociatedKeys, registry))); return absl::OkStatus(); } const Type& ListIntType() { static absl::NoDestructor kInstance( ListType(BuiltinsArena(), IntType())); return *kInstance; } const Type& ListTypeParamType() { static absl::NoDestructor kInstance( ListType(BuiltinsArena(), TypeParamType("T"))); return *kInstance; } absl::Status RegisterListsCheckerDecls(TypeCheckerBuilder& builder, int version) { CEL_ASSIGN_OR_RETURN( FunctionDecl distinct_decl, MakeFunctionDecl("distinct", MakeMemberOverloadDecl( "list_distinct", ListTypeParamType(), ListTypeParamType()))); CEL_ASSIGN_OR_RETURN( FunctionDecl flatten_decl, MakeFunctionDecl( "flatten", MakeMemberOverloadDecl("list_flatten_int", ListType(), ListType(), IntType()), MakeMemberOverloadDecl("list_flatten", ListType(), ListType()))); CEL_ASSIGN_OR_RETURN( FunctionDecl range_decl, MakeFunctionDecl( "lists.range", MakeOverloadDecl("list_range", ListIntType(), IntType()))); CEL_ASSIGN_OR_RETURN( FunctionDecl reverse_decl, MakeFunctionDecl( "reverse", MakeMemberOverloadDecl("list_reverse", ListTypeParamType(), ListTypeParamType()))); CEL_ASSIGN_OR_RETURN( FunctionDecl slice_decl, MakeFunctionDecl( "slice", MakeMemberOverloadDecl("list_slice", ListTypeParamType(), ListTypeParamType(), IntType(), IntType()))); static const absl::NoDestructor> kSortableListTypes([] { std::vector instance; instance.reserve(SortableTypes().size()); for (const Type& type : SortableTypes()) { instance.push_back(ListType(BuiltinsArena(), type)); } return instance; }()); FunctionDecl sort_decl; sort_decl.set_name("sort"); FunctionDecl sort_by_key_decl; sort_by_key_decl.set_name("@sortByAssociatedKeys"); for (const Type& list_type : *kSortableListTypes) { std::string elem_type_name(list_type.AsList()->GetElement().name()); CEL_RETURN_IF_ERROR(sort_decl.AddOverload(MakeMemberOverloadDecl( absl::StrCat("list_", elem_type_name, "_sort"), list_type, list_type))); 
CEL_RETURN_IF_ERROR(sort_by_key_decl.AddOverload(MakeMemberOverloadDecl( absl::StrCat("list_", elem_type_name, "_sortByAssociatedKeys"), ListTypeParamType(), ListTypeParamType(), list_type))); } CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(slice_decl))); if (version == 0) { return absl::OkStatus(); } CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(flatten_decl))); if (version == 1) { return absl::OkStatus(); } CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(sort_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(sort_by_key_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(distinct_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(range_decl))); // MergeFunction is used to combine with the reverse function // defined in strings extension. CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(reverse_decl))); return absl::OkStatus(); } std::vector lists_macros(int version) { switch (version) { case 0: return {}; case 1: return {}; case 2: default: return {ListSortByMacro()}; }; } absl::Status ConfigureParser(ParserBuilder& builder, int version) { for (const Macro& macro : lists_macros(version)) { CEL_RETURN_IF_ERROR(builder.AddMacro(macro)); } return absl::OkStatus(); } } // namespace absl::Status RegisterListsFunctions(FunctionRegistry& registry, const RuntimeOptions& options, int version) { CEL_RETURN_IF_ERROR(RegisterListSliceFunction(registry)); if (version == 0) { return absl::OkStatus(); } // Since version 1 CEL_RETURN_IF_ERROR(RegisterListFlattenFunction(registry)); if (version == 1) { return absl::OkStatus(); } // Since version 2 CEL_RETURN_IF_ERROR(RegisterListDistinctFunction(registry)); CEL_RETURN_IF_ERROR(RegisterListRangeFunction(registry)); CEL_RETURN_IF_ERROR(RegisterListReverseFunction(registry)); CEL_RETURN_IF_ERROR(RegisterListSortFunction(registry)); return absl::OkStatus(); } absl::Status RegisterListsMacros(MacroRegistry& registry, const ParserOptions&, int version) { return 
registry.RegisterMacros(lists_macros(version)); } CheckerLibrary ListsCheckerLibrary(int version) { return {.id = "cel.lib.ext.lists", .configure = [version](TypeCheckerBuilder& builder) { return RegisterListsCheckerDecls(builder, version); }}; } CompilerLibrary ListsCompilerLibrary(int version) { auto lib = CompilerLibrary::FromCheckerLibrary(ListsCheckerLibrary(version)); lib.configure_parser = [version](ParserBuilder& builder) { return ConfigureParser(builder, version); }; return lib; } } // namespace cel::extensions ================================================ FILE: extensions/lists_functions.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_LISTS_FUNCTIONS_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_LISTS_FUNCTIONS_H_ #include "absl/status/status.h" #include "checker/type_checker_builder.h" #include "compiler/compiler.h" #include "parser/macro_registry.h" #include "parser/options.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" namespace cel::extensions { constexpr int kListsExtensionLatestVersion = 2; // Register implementations for list extension functions. 
// // === Since version 0 === // .slice(start: int, end: int) -> list(T) // // === Since version 1 === // .flatten() -> list(dyn) // .flatten(limit: int) -> list(dyn) // // === Since version 2 === // lists.range(n: int) -> list(int) // // .distinct() -> list(T) // // .reverse() -> list(T) // // .sort() -> list(T) // absl::Status RegisterListsFunctions(FunctionRegistry& registry, const RuntimeOptions& options, int version = kListsExtensionLatestVersion); // Register list macros. // // === Since version 2 === // // .sortBy(, ) absl::Status RegisterListsMacros(MacroRegistry& registry, const ParserOptions& options, int version = kListsExtensionLatestVersion); // Type check declarations for the lists extension library. // Provides decls for the following functions: // // === Since version 0 === // .slice(start: int, end: int) -> list(T) // // === Since version 1 === // .flatten() -> list(dyn) // .flatten(limit: int) -> list(dyn) // // === Since version 2 === // lists.range(n: int) -> list(int) // // .distinct() -> list(T) // // .reverse() -> list(T) // // .sort() -> list(T_) where T_ is partially orderable CheckerLibrary ListsCheckerLibrary(int version = kListsExtensionLatestVersion); // Provides decls for the following functions: // // === Since version 0 === // .slice(start: int, end: int) -> list(T) // // === Since version 1 === // .flatten() -> list(dyn) // .flatten(limit: int) -> list(dyn) // // === Since version 2 === // lists.range(n: int) -> list(int) // // .distinct() -> list(T) // // .reverse() -> list(T) // // .sort() -> list(T_) where T_ is partially orderable CompilerLibrary ListsCompilerLibrary( int version = kListsExtensionLatestVersion); } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_SETS_FUNCTIONS_H_ ================================================ FILE: extensions/lists_functions_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the 
"License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "extensions/lists_functions.h" #include #include #include #include #include "cel/expr/syntax.pb.h" #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/strings/string_view.h" #include "checker/type_check_issue.h" #include "checker/validation_result.h" #include "common/source.h" #include "common/value.h" #include "common/value_testing.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" #include "compiler/standard_library.h" #include "extensions/protobuf/runtime_adapter.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "parser/macro_registry.h" #include "parser/options.h" #include "parser/parser.h" #include "parser/standard_macros.h" #include "runtime/activation.h" #include "runtime/reference_resolver.h" #include "runtime/runtime.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "google/protobuf/arena.h" namespace cel::extensions { namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::test::ErrorValueIs; using ::cel::expr::Expr; using ::cel::expr::ParsedExpr; using ::cel::expr::SourceInfo; using ::testing::Contains; using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::ValuesIn; struct TestInfo { std::string expr; std::string err = ""; }; class ListsFunctionsTest : public testing::TestWithParam {}; 
TEST_P(ListsFunctionsTest, EndToEnd) { const TestInfo& test_info = GetParam(); RecordProperty("cel_expression", test_info.expr); if (!test_info.err.empty()) { RecordProperty("cel_expected_error", test_info.err); } ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource(test_info.expr, "")); MacroRegistry macro_registry; ParserOptions parser_options{.add_macro_calls = true}; ASSERT_THAT(RegisterStandardMacros(macro_registry, parser_options), IsOk()); ASSERT_THAT(RegisterListsMacros(macro_registry, parser_options), IsOk()); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, google::api::expr::parser::Parse(*source, macro_registry, parser_options)); Expr expr = parsed_expr.expr(); SourceInfo source_info = parsed_expr.source_info(); google::protobuf::Arena arena; const auto options = RuntimeOptions{}; ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( internal::GetTestingDescriptorPool(), options)); // Needed to resolve namespaced functions when evaluating a ParsedExpr. ASSERT_THAT(cel::EnableReferenceResolver( builder, cel::ReferenceResolverEnabled::kAlways), IsOk()); EXPECT_THAT(RegisterListsFunctions(builder.function_registry(), options), IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); ASSERT_OK_AND_ASSIGN(std::unique_ptr program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); Activation activation; ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); if (!test_info.err.empty()) { EXPECT_THAT(result, ErrorValueIs(StatusIs(testing::_, HasSubstr(test_info.err)))); return; } ASSERT_TRUE(result.IsBool()) << test_info.expr << " -> " << result.DebugString(); EXPECT_TRUE(result.GetBool().NativeValue()) << test_info.expr << " -> " << result.DebugString(); } INSTANTIATE_TEST_SUITE_P( ListsFunctionsTest, ListsFunctionsTest, testing::ValuesIn({ // lists.range() {R"cel(lists.range(4) == [0,1,2,3])cel"}, {R"cel(lists.range(0) == [])cel"}, // .reverse() {R"cel([5,1,2,3].reverse() == [3,2,1,5])cel"}, {R"cel([] == [])cel"}, 
{R"cel([1] == [1])cel"}, {R"cel( ['are', 'you', 'as', 'bored', 'as', 'I', 'am'].reverse() == ['am', 'I', 'as', 'bored', 'as', 'you', 'are'] )cel"}, {R"cel( [false, true, true].reverse().reverse() == [false, true, true] )cel"}, // .slice() {R"cel([1,2,3,4].slice(0, 4) == [1,2,3,4])cel"}, {R"cel([1,2,3,4].slice(0, 0) == [])cel"}, {R"cel([1,2,3,4].slice(1, 1) == [])cel"}, {R"cel([1,2,3,4].slice(4, 4) == [])cel"}, {R"cel([1,2,3,4].slice(1, 3) == [2, 3])cel"}, {R"cel([1,2,3,4].slice(3, 0))cel", "cannot slice(3, 0), start index must be less than or equal to end " "index"}, {R"cel([1,2,3,4].slice(0, 10))cel", "cannot slice(0, 10), list is length 4"}, {R"cel([1,2,3,4].slice(-5, 10))cel", "cannot slice(-5, 10), negative indexes not supported"}, {R"cel([1,2,3,4].slice(-5, -3))cel", "cannot slice(-5, -3), negative indexes not supported"}, // .flatten() {R"cel(dyn([]).flatten() == [])cel"}, {R"cel(dyn([1,2,3,4]).flatten() == [1,2,3,4])cel"}, {R"cel([1,[2,[3,4]]].flatten() == [1,2,[3,4]])cel"}, {R"cel([1,2,[],[],[3,4]].flatten() == [1,2,3,4])cel"}, {R"cel([1,[2,[3,4]]].flatten(2) == [1,2,3,4])cel"}, {R"cel([1,[2,[3,[4]]]].flatten(-1))cel", "level must be non-negative"}, // .sort() {R"cel([].sort() == [])cel"}, {R"cel([1].sort() == [1])cel"}, {R"cel([4, 3, 2, 1].sort() == [1, 2, 3, 4])cel"}, {R"cel(["d", "a", "b", "c"].sort() == ["a", "b", "c", "d"])cel"}, {R"cel([b"d", b"a", b"aa"].sort() == [b"a", b"aa", b"d"])cel"}, {R"cel( [1.0, -1.5, 2.0, 1.0, -1.5, -1.5].sort() == [-1.5, -1.5, -1.5, 1.0, 1.0, 2.0] )cel"}, {R"cel( [42u, 3u, 1337u, 42u, 1337u, 3u, 42u].sort() == [3u, 3u, 42u, 42u, 42u, 1337u, 1337u] )cel"}, {R"cel([false, true, false].sort() == [false, false, true])cel"}, {R"cel( [ timestamp('2024-01-03T00:00:00Z'), timestamp('2024-01-01T00:00:00Z'), timestamp('2024-01-02T00:00:00Z'), ].sort() == [ timestamp('2024-01-01T00:00:00Z'), timestamp('2024-01-02T00:00:00Z'), timestamp('2024-01-03T00:00:00Z'), ] )cel"}, {R"cel( [duration('1m'), duration('2s'), duration('3h')].sort() 
== [duration('2s'), duration('1m'), duration('3h')] )cel"}, {R"cel(["d", 3, 2, "c"].sort())cel", "list elements must have the same type"}, {R"cel([google.api.expr.runtime.TestMessage{}].sort())cel", "unsupported type google.api.expr.runtime.TestMessage"}, {R"cel([[1], [2]].sort())cel", "unsupported type list"}, // .sortBy() {R"cel([].sortBy(e, e) == [])cel"}, {R"cel(["a"].sortBy(e, e) == ["a"])cel"}, {R"cel( [-3, 1, -5, -2, 4].sortBy(e, -(e * e)) == [-5, 4, -3, -2, 1] )cel"}, {R"cel( [-3, 1, -5, -2, 4].map(e, e * 2).sortBy(e, -(e * e)) == [-10, 8, -6, -4, 2] )cel"}, {R"cel(lists.range(3).sortBy(e, -e) == [2, 1, 0])cel"}, {R"cel( ["a", "c", "b", "first"].sortBy(e, e == "first" ? "" : e) == ["first", "a", "b", "c"] )cel"}, {R"cel( [ google.api.expr.runtime.TestMessage{string_value: 'foo'}, google.api.expr.runtime.TestMessage{string_value: 'bar'}, google.api.expr.runtime.TestMessage{string_value: 'baz'} ].sortBy(e, e.string_value) == [ google.api.expr.runtime.TestMessage{string_value: 'bar'}, google.api.expr.runtime.TestMessage{string_value: 'baz'}, google.api.expr.runtime.TestMessage{string_value: 'foo'} ] )cel"}, {R"cel([[2], [1], [3]].sortBy(e, e[0]) == [[1], [2], [3]])cel"}, {R"cel([[1], ["a"]].sortBy(e, e[0]))cel", "list elements must have the same type"}, {R"cel([[1], [2]].sortBy(e, e))cel", "unsupported type list"}, {R"cel([google.api.expr.runtime.TestMessage{}].sortBy(e, e))cel", "unsupported type google.api.expr.runtime.TestMessage"}, // .distinct() {R"cel([].distinct() == [])cel"}, {R"cel([1].distinct() == [1])cel"}, {R"cel([-2, 5, -2, 1, 1, 5, -2, 1].distinct() == [-2, 5, 1])cel"}, {R"cel( [2u, 5u, 100u, 1u, 1u, 5u, 2u, 1u].distinct() == [2u, 5u, 100u, 1u] )cel"}, {R"cel([false, true, true, false].distinct() == [false, true])cel"}, {R"cel( ['c', 'a', 'a', 'b', 'a', 'b', 'c', 'c'].distinct() == ['c', 'a', 'b'] )cel"}, {R"cel([1, 2.0, "c", 3, "c", 1].distinct() == [1, 2.0, "c", 3])cel"}, {R"cel([1, 1.0, 2].distinct() == [1, 2])cel"}, {R"cel([1, 1u].distinct() 
== [1])cel"}, {R"cel([[1], [1], [2]].distinct() == [[1], [2]])cel"}, {R"cel( [ google.api.expr.runtime.TestMessage{string_value: 'a'}, google.api.expr.runtime.TestMessage{string_value: 'b'}, google.api.expr.runtime.TestMessage{string_value: 'a'} ].distinct() == [ google.api.expr.runtime.TestMessage{string_value: 'a'}, google.api.expr.runtime.TestMessage{string_value: 'b'} ] )cel"}, {R"cel( [ google.api.expr.runtime.TestMessage{string_value: 'a'}, 1, 42.0, [1, 2, 3], false, ].distinct() == [ google.api.expr.runtime.TestMessage{string_value: 'a'}, 1, 42.0, [1, 2, 3], false, ] )cel"}, })); TEST(ListsFunctionsTest, ListSortByMacroParseError) { ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("100.sortBy(e, e)", "")); MacroRegistry macro_registry; ParserOptions parser_options{.add_macro_calls = true}; ASSERT_THAT(RegisterListsMacros(macro_registry, parser_options), IsOk()); EXPECT_THAT( google::api::expr::parser::Parse(*source, macro_registry, parser_options), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("sortBy can only be applied to"))); } struct ListCheckerTestCase { std::string expr; std::string error_substr; }; class ListsCheckerLibraryTest : public ::testing::TestWithParam { public: void SetUp() override { // Arrange: Configure the compiler. // Add the lists checker library to the compiler builder. ASSERT_OK_AND_ASSIGN( std::unique_ptr compiler_builder, NewCompilerBuilder(internal::GetTestingDescriptorPool())); ASSERT_THAT(compiler_builder->AddLibrary(StandardCompilerLibrary()), IsOk()); ASSERT_THAT(compiler_builder->AddLibrary(ListsCompilerLibrary()), IsOk()); compiler_builder->GetCheckerBuilder().set_container( "cel.expr.conformance.proto3"); ASSERT_OK_AND_ASSIGN(compiler_, std::move(*compiler_builder).Build()); } std::unique_ptr compiler_; }; TEST_P(ListsCheckerLibraryTest, ListsFunctionsTypeCheckerSuccess) { // Act & Assert: Compile the expression and validate the result. 
ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler_->Compile(GetParam().expr)); absl::string_view error_substr = GetParam().error_substr; EXPECT_EQ(result.IsValid(), error_substr.empty()); if (!error_substr.empty()) { EXPECT_THAT(result.FormatError(), HasSubstr(error_substr)); } } // Returns a vector of test cases for the ListsCheckerLibraryTest. // Returns both positive and negative test cases for the lists functions. std::vector createListsCheckerParams() { return { // lists.distinct() {R"([1,2,3,4,4].distinct() == [1,2,3,4])"}, {R"('abc'.distinct() == [1,2,3,4])", "no matching overload for 'distinct'"}, {R"([1,2,3,4,4].distinct() == 'abc')", "no matching overload for '_==_'"}, {R"([1,2,3,4,4].distinct(1) == [1,2,3,4])", "undeclared reference"}, // lists.flatten() {R"([1,2,3,4].flatten() == [1,2,3,4])"}, {R"([1,2,3,4].flatten(1) == [1,2,3,4])"}, {R"('abc'.flatten() == [1,2,3,4])", "no matching overload for 'flatten'"}, {R"([1,2,3,4].flatten() == 'abc')", "no matching overload for '_==_'"}, {R"('abc'.flatten(1) == [1,2,3,4])", "no matching overload for 'flatten'"}, {R"([1,2,3,4].flatten('abc') == [1,2,3,4])", "no matching overload for 'flatten'"}, {R"([1,2,3,4].flatten(1) == 'abc')", "no matching overload"}, // lists.range() {R"(lists.range(4) == [0,1,2,3])"}, {R"(lists.range('abc') == [])", "no matching overload for 'lists.range'"}, {R"(lists.range(4) == 'abc')", "no matching overload for '_==_'"}, {R"(lists.range(4, 4) == [0,1,2,3])", "undeclared reference"}, // lists.reverse() {R"([1,2,3,4].reverse() == [4,3,2,1])"}, {R"('abc'.reverse() == [])", "no matching overload for 'reverse'"}, {R"([1,2,3,4].reverse() == 'abc')", "no matching overload for '_==_'"}, {R"([1,2,3,4].reverse(1) == [4,3,2,1])", "undeclared reference"}, // lists.slice() {R"([1,2,3,4].slice(0, 4) == [1,2,3,4])"}, {R"('abc'.slice(0, 4) == [1,2,3,4])", "no matching overload for 'slice'"}, {R"([1,2,3,4].slice('abc', 4) == [1,2,3,4])", "no matching overload for 'slice'"}, {R"([1,2,3,4].slice(0, 
'abc') == [1,2,3,4])", "no matching overload for 'slice'"}, {R"([1,2,3,4].slice(0, 4) == 'abc')", "no matching overload for '_==_'"}, {R"([1,2,3,4].slice(0, 2, 3) == [1,2,3,4])", "undeclared reference"}, // lists.sort() {R"([1,2,3,4].sort() == [1,2,3,4])"}, {R"([TestAllTypes{}, TestAllTypes{}].sort() == [])", "no matching overload for 'sort'"}, {R"('abc'.sort() == [])", "no matching overload for 'sort'"}, {R"([1,2,3,4].sort() == 'abc')", "no matching overload for '_==_'"}, {R"([1,2,3,4].sort(2) == [1,2,3,4])", "undeclared reference"}, // sortBy macro {R"([1,2,3,4].sortBy(x, -x) == [4,3,2,1])"}, {R"([TestAllTypes{}, TestAllTypes{}].sortBy(x, x) == [])", "no matching overload for '@sortByAssociatedKeys'"}, {R"( [TestAllTypes{single_int64: 2}, TestAllTypes{single_int64: 1}] .sortBy(x, x.single_int64) == [TestAllTypes{single_int64: 1}, TestAllTypes{single_int64: 2}])"}, }; } INSTANTIATE_TEST_SUITE_P(ListsCheckerLibraryTest, ListsCheckerLibraryTest, ValuesIn(createListsCheckerParams())); struct ListsExtensionVersionTestCase { std::string expr; std::vector expected_supported_versions; }; class ListsExtensionVersionTest : public ::testing::TestWithParam {}; TEST_P(ListsExtensionVersionTest, ListsExtensionVersions) { const ListsExtensionVersionTestCase& test_case = GetParam(); for (int version = 0; version <= cel::extensions::kListsExtensionLatestVersion; ++version) { CompilerLibrary compiler_library = ListsCompilerLibrary(version); ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, cel::NewCompilerBuilder(internal::GetTestingDescriptorPool(), CompilerOptions())); ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); ASSERT_THAT(builder->AddLibrary(std::move(compiler_library)), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, builder->Build()); ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(test_case.expr)); if (absl::c_contains(test_case.expected_supported_versions, version)) { EXPECT_THAT(result.GetIssues(), IsEmpty()) << "Expected no 
issues for expr: " << test_case.expr << " at version: " << version << " but got: " << result.FormatError(); } else { EXPECT_THAT(result.GetIssues(), Contains(Property(&TypeCheckIssue::message, HasSubstr("undeclared reference")))); } } }; std::vector CreateListsExtensionVersionParams() { return { ListsExtensionVersionTestCase{ .expr = "[0,1,2,3].slice(0, 2)", .expected_supported_versions = {0, 1, 2}, }, ListsExtensionVersionTestCase{ .expr = "[[0]].flatten()", .expected_supported_versions = {1, 2}, }, ListsExtensionVersionTestCase{ .expr = "[[0]].flatten(1)", .expected_supported_versions = {1, 2}, }, ListsExtensionVersionTestCase{ .expr = "[1,2,3,4].sort()", .expected_supported_versions = {2}, }, ListsExtensionVersionTestCase{ .expr = "[1,2,3,4].sortBy(x, x)", .expected_supported_versions = {2}, }, ListsExtensionVersionTestCase{ .expr = "[1,2,3,4].distinct()", .expected_supported_versions = {2}, }, ListsExtensionVersionTestCase{ .expr = "lists.range(4)", .expected_supported_versions = {2}, }, ListsExtensionVersionTestCase{ .expr = "[1,2,3,4].reverse()", .expected_supported_versions = {2}, }, }; } INSTANTIATE_TEST_SUITE_P(ListsExtensionVersionTest, ListsExtensionVersionTest, ValuesIn(CreateListsExtensionVersionParams())); } // namespace } // namespace cel::extensions ================================================ FILE: extensions/math_ext.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "extensions/math_ext.h"

#include <cmath>
#include <cstdint>
#include <limits>

#include "absl/base/casts.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "common/casting.h"
#include "common/value.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_number.h"
#include "eval/public/cel_options.h"
#include "internal/status_macros.h"
#include "runtime/function_adapter.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::extensions {

namespace {

using ::google::api::expr::runtime::CelFunctionRegistry;
using ::google::api::expr::runtime::CelNumber;
using ::google::api::expr::runtime::InterpreterOptions;

static constexpr char kMathMin[] = "math.@min";
static constexpr char kMathMax[] = "math.@max";

// Maps each numeric alternative held by a CelNumber to the matching Value.
struct ToValueVisitor {
  Value operator()(uint64_t v) const { return UintValue{v}; }
  Value operator()(int64_t v) const { return IntValue{v}; }
  Value operator()(double v) const { return DoubleValue{v}; }
};

Value NumberToValue(CelNumber number) {
  return number.visit<Value>(ToValueVisitor{});
}

// Converts an int, uint or double Value into a CelNumber. Any other kind of
// value yields InvalidArgument("<function> arguments must be numeric").
absl::StatusOr<CelNumber> ValueToNumber(const Value& value,
                                        absl::string_view function) {
  if (auto int_value = As<IntValue>(value); int_value) {
    return CelNumber::FromInt64(int_value->NativeValue());
  }
  if (auto uint_value = As<UintValue>(value); uint_value) {
    return CelNumber::FromUint64(uint_value->NativeValue());
  }
  if (auto double_value = As<DoubleValue>(value); double_value) {
    return CelNumber::FromDouble(double_value->NativeValue());
  }
  return absl::InvalidArgumentError(
      absl::StrCat(function, " arguments must be numeric"));
}

// Returns the smaller number; on ties (or unordered values) returns `v1`.
CelNumber MinNumber(CelNumber v1, CelNumber v2) {
  if (v2 < v1) {
    return v2;
  }
  return v1;
}

Value MinValue(CelNumber v1, CelNumber v2) {
  return NumberToValue(MinNumber(v1, v2));
}

// Single-argument math.@min / math.@max: returns its argument as a Value.
template <typename T>
Value Identity(T v1) {
  return NumberToValue(CelNumber(v1));
}

template <typename T, typename U>
Value Min(T v1, U v2) {
  return MinValue(CelNumber(v1), CelNumber(v2));
}

// math.@min over a list. An empty list or a non-numeric element produces an
// ErrorValue; failures from the list iterator are returned as a status.
absl::StatusOr<Value> MinList(
    const ListValue& values,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  CEL_ASSIGN_OR_RETURN(auto iterator, values.NewIterator());
  if (!iterator->HasNext()) {
    return ErrorValue(
        absl::InvalidArgumentError("math.@min argument must not be empty"));
  }
  Value value;
  CEL_RETURN_IF_ERROR(
      iterator->Next(descriptor_pool, message_factory, arena, &value));
  absl::StatusOr<CelNumber> current = ValueToNumber(value, kMathMin);
  if (!current.ok()) {
    return ErrorValue{current.status()};
  }
  CelNumber min = *current;
  while (iterator->HasNext()) {
    CEL_RETURN_IF_ERROR(
        iterator->Next(descriptor_pool, message_factory, arena, &value));
    absl::StatusOr<CelNumber> other = ValueToNumber(value, kMathMin);
    if (!other.ok()) {
      return ErrorValue{other.status()};
    }
    min = MinNumber(min, *other);
  }
  return NumberToValue(min);
}

// Returns the larger number; on ties (or unordered values) returns `v1`.
CelNumber MaxNumber(CelNumber v1, CelNumber v2) {
  if (v2 > v1) {
    return v2;
  }
  return v1;
}

Value MaxValue(CelNumber v1, CelNumber v2) {
  return NumberToValue(MaxNumber(v1, v2));
}

template <typename T, typename U>
Value Max(T v1, U v2) {
  return MaxValue(CelNumber(v1), CelNumber(v2));
}

// math.@max over a list. An empty list or a non-numeric element produces an
// ErrorValue; failures from the list iterator are returned as a status.
absl::StatusOr<Value> MaxList(
    const ListValue& values,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  CEL_ASSIGN_OR_RETURN(auto iterator, values.NewIterator());
  if (!iterator->HasNext()) {
    return ErrorValue(
        absl::InvalidArgumentError("math.@max argument must not be empty"));
  }
  Value value;
  CEL_RETURN_IF_ERROR(
      iterator->Next(descriptor_pool, message_factory, arena, &value));
  absl::StatusOr<CelNumber> current = ValueToNumber(value, kMathMax);
  if (!current.ok()) {
    return ErrorValue{current.status()};
  }
  CelNumber max = *current;
  while (iterator->HasNext()) {
    CEL_RETURN_IF_ERROR(
        iterator->Next(descriptor_pool, message_factory, arena, &value));
    absl::StatusOr<CelNumber> other = ValueToNumber(value, kMathMax);
    if (!other.ok()) {
      return ErrorValue{other.status()};
    }
    max = MaxNumber(max, *other);
  }
  return NumberToValue(max);
}

// Registers math.@min for the mixed pair (T, U) in both argument orders.
template <typename T, typename U>
absl::Status RegisterCrossNumericMin(FunctionRegistry& registry) {
  CEL_RETURN_IF_ERROR((
      BinaryFunctionAdapter<Value, T, U>::RegisterGlobalOverload(
          kMathMin, Min<T, U>, registry)));

  CEL_RETURN_IF_ERROR((
      BinaryFunctionAdapter<Value, U, T>::RegisterGlobalOverload(
          kMathMin, Min<U, T>, registry)));

  return absl::OkStatus();
}

// Registers math.@max for the mixed pair (T, U) in both argument orders.
template <typename T, typename U>
absl::Status RegisterCrossNumericMax(FunctionRegistry& registry) {
  CEL_RETURN_IF_ERROR((
      BinaryFunctionAdapter<Value, T, U>::RegisterGlobalOverload(
          kMathMax, Max<T, U>, registry)));

  CEL_RETURN_IF_ERROR((
      BinaryFunctionAdapter<Value, U, T>::RegisterGlobalOverload(
          kMathMax, Max<U, T>, registry)));

  return absl::OkStatus();
}

double CeilDouble(double value) { return std::ceil(value); }

double FloorDouble(double value) { return std::floor(value); }

double RoundDouble(double value) { return std::round(value); }

double TruncDouble(double value) { return std::trunc(value); }

double SqrtDouble(double value) { return std::sqrt(value); }

double SqrtInt(int64_t value) { return std::sqrt(value); }

double SqrtUint(uint64_t value) { return std::sqrt(value); }

bool IsInfDouble(double value) { return std::isinf(value); }

bool IsNaNDouble(double value) { return std::isnan(value); }

bool IsFiniteDouble(double value) { return std::isfinite(value); }

double AbsDouble(double value) { return std::fabs(value); }

// |INT64_MIN| is not representable as int64_t, so it is reported as an
// overflow error instead of being negated.
Value AbsInt(int64_t value) {
  if (ABSL_PREDICT_FALSE(value == std::numeric_limits<int64_t>::min())) {
    return ErrorValue(absl::InvalidArgumentError("integer overflow"));
  }
  return IntValue(value < 0 ? -value : value);
}

uint64_t AbsUint(uint64_t value) { return value; }

// Returns NaN for NaN, 0.0 for either zero, and +/-1.0 otherwise.
double SignDouble(double value) {
  if (std::isnan(value)) {
    return value;
  }
  if (value == 0.0) {
    return 0.0;
  }
  return std::signbit(value) ? -1.0 : 1.0;
}

int64_t SignInt(int64_t value) {
  return value < 0 ? -1 : value > 0 ? 1 : 0;
}

uint64_t SignUint(uint64_t value) { return value == 0 ? 0 : 1; }

int64_t BitAndInt(int64_t lhs, int64_t rhs) { return lhs & rhs; }

uint64_t BitAndUint(uint64_t lhs, uint64_t rhs) { return lhs & rhs; }

int64_t BitOrInt(int64_t lhs, int64_t rhs) { return lhs | rhs; }

uint64_t BitOrUint(uint64_t lhs, uint64_t rhs) { return lhs | rhs; }

int64_t BitXorInt(int64_t lhs, int64_t rhs) { return lhs ^ rhs; }

uint64_t BitXorUint(uint64_t lhs, uint64_t rhs) { return lhs ^ rhs; }

int64_t BitNotInt(int64_t value) { return ~value; }

uint64_t BitNotUint(uint64_t value) { return ~value; }

// Negative shift amounts are errors; shifts wider than 63 bits yield 0.
Value BitShiftLeftInt(int64_t lhs, int64_t rhs) {
  if (ABSL_PREDICT_FALSE(rhs < 0)) {
    return ErrorValue(absl::InvalidArgumentError(
        absl::StrCat("math.bitShiftLeft() invalid negative shift: ", rhs)));
  }
  if (rhs > 63) {
    return IntValue(0);
  }
  // Shift in the unsigned domain: before C++20, left-shifting a negative
  // signed value, or producing a result not representable in the unsigned
  // type, is undefined behavior. Reinterpreting the bits gives the
  // two's-complement result for every input.
  return IntValue(absl::bit_cast<int64_t>(absl::bit_cast<uint64_t>(lhs)
                                          << static_cast<int>(rhs)));
}

// Negative shift amounts are errors; shifts wider than 63 bits yield 0.
Value BitShiftLeftUint(uint64_t lhs, int64_t rhs) {
  if (ABSL_PREDICT_FALSE(rhs < 0)) {
    return ErrorValue(absl::InvalidArgumentError(
        absl::StrCat("math.bitShiftLeft() invalid negative shift: ", rhs)));
  }
  if (rhs > 63) {
    return UintValue(0);
  }
  return UintValue(lhs << static_cast<int>(rhs));
}

// Negative shift amounts are errors; shifts wider than 63 bits yield 0.
Value BitShiftRightInt(int64_t lhs, int64_t rhs) {
  if (ABSL_PREDICT_FALSE(rhs < 0)) {
    return ErrorValue(absl::InvalidArgumentError(
        absl::StrCat("math.bitShiftRight() invalid negative shift: ", rhs)));
  }
  if (rhs > 63) {
    return IntValue(0);
  }

  // We do not perform a sign extension shift, per the spec we just do the same
  // thing as uint.
  return IntValue(absl::bit_cast<int64_t>(absl::bit_cast<uint64_t>(lhs) >>
                                          static_cast<int>(rhs)));
}

// Negative shift amounts are errors; shifts wider than 63 bits yield 0.
Value BitShiftRightUint(uint64_t lhs, int64_t rhs) {
  if (ABSL_PREDICT_FALSE(rhs < 0)) {
    return ErrorValue(absl::InvalidArgumentError(
        absl::StrCat("math.bitShiftRight() invalid negative shift: ", rhs)));
  }
  if (rhs > 63) {
    return UintValue(0);
  }
  return UintValue(lhs >> static_cast<int>(rhs));
}

}  // namespace

// Registration is cumulative by version:
//   version 0: math.@min / math.@max (unary, same-type, mixed-type and list);
//   version 1: adds ceil, floor, round, trunc, isInf, isNaN, isFinite, abs,
//              sign and the bitwise/shift functions;
//   other:     additionally adds math.sqrt.
absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry,
                                            const RuntimeOptions& options,
                                            int version) {
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<Value, int64_t>::RegisterGlobalOverload(
          kMathMin, Identity<int64_t>, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<Value, double>::RegisterGlobalOverload(
          kMathMin, Identity<double>, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<Value, uint64_t>::RegisterGlobalOverload(
          kMathMin, Identity<uint64_t>, registry)));
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<Value, int64_t, int64_t>::RegisterGlobalOverload(
          kMathMin, Min<int64_t, int64_t>, registry)));
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<Value, double, double>::RegisterGlobalOverload(
          kMathMin, Min<double, double>, registry)));
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<Value, uint64_t, uint64_t>::RegisterGlobalOverload(
          kMathMin, Min<uint64_t, uint64_t>, registry)));
  CEL_RETURN_IF_ERROR((RegisterCrossNumericMin<int64_t, uint64_t>(registry)));
  CEL_RETURN_IF_ERROR((RegisterCrossNumericMin<int64_t, double>(registry)));
  CEL_RETURN_IF_ERROR((RegisterCrossNumericMin<double, uint64_t>(registry)));
  CEL_RETURN_IF_ERROR((
      UnaryFunctionAdapter<absl::StatusOr<Value>, ListValue>::
          RegisterGlobalOverload(kMathMin, MinList, registry)));

  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<Value, int64_t>::RegisterGlobalOverload(
          kMathMax, Identity<int64_t>, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<Value, double>::RegisterGlobalOverload(
          kMathMax, Identity<double>, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<Value, uint64_t>::RegisterGlobalOverload(
          kMathMax, Identity<uint64_t>, registry)));
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<Value, int64_t, int64_t>::RegisterGlobalOverload(
          kMathMax, Max<int64_t, int64_t>, registry)));
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<Value, double, double>::RegisterGlobalOverload(
          kMathMax, Max<double, double>, registry)));
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<Value, uint64_t, uint64_t>::RegisterGlobalOverload(
          kMathMax, Max<uint64_t, uint64_t>, registry)));
  CEL_RETURN_IF_ERROR((RegisterCrossNumericMax<int64_t, uint64_t>(registry)));
  CEL_RETURN_IF_ERROR((RegisterCrossNumericMax<int64_t, double>(registry)));
  CEL_RETURN_IF_ERROR((RegisterCrossNumericMax<double, uint64_t>(registry)));
  CEL_RETURN_IF_ERROR((
      UnaryFunctionAdapter<absl::StatusOr<Value>, ListValue>::
          RegisterGlobalOverload(kMathMax, MaxList, registry)));
  if (version == 0) {
    return absl::OkStatus();
  }

  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<double, double>::RegisterGlobalOverload(
          "math.ceil", CeilDouble, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<double, double>::RegisterGlobalOverload(
          "math.floor", FloorDouble, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<double, double>::RegisterGlobalOverload(
          "math.round", RoundDouble, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<double, double>::RegisterGlobalOverload(
          "math.trunc", TruncDouble, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<bool, double>::RegisterGlobalOverload(
          "math.isInf", IsInfDouble, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<bool, double>::RegisterGlobalOverload(
          "math.isNaN", IsNaNDouble, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<bool, double>::RegisterGlobalOverload(
          "math.isFinite", IsFiniteDouble, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<double, double>::RegisterGlobalOverload(
          "math.abs", AbsDouble, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<Value, int64_t>::RegisterGlobalOverload(
          "math.abs", AbsInt, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<uint64_t, uint64_t>::RegisterGlobalOverload(
          "math.abs", AbsUint, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<double, double>::RegisterGlobalOverload(
          "math.sign", SignDouble, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<int64_t, int64_t>::RegisterGlobalOverload(
          "math.sign", SignInt, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<uint64_t, uint64_t>::RegisterGlobalOverload(
          "math.sign", SignUint, registry)));

  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<int64_t, int64_t, int64_t>::RegisterGlobalOverload(
          "math.bitAnd", BitAndInt, registry)));
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<uint64_t, uint64_t,
                             uint64_t>::RegisterGlobalOverload("math.bitAnd",
                                                               BitAndUint,
                                                               registry)));
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<int64_t, int64_t, int64_t>::RegisterGlobalOverload(
          "math.bitOr", BitOrInt, registry)));
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<uint64_t, uint64_t,
                             uint64_t>::RegisterGlobalOverload("math.bitOr",
                                                               BitOrUint,
                                                               registry)));
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<int64_t, int64_t, int64_t>::RegisterGlobalOverload(
          "math.bitXor", BitXorInt, registry)));
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<uint64_t, uint64_t,
                             uint64_t>::RegisterGlobalOverload("math.bitXor",
                                                               BitXorUint,
                                                               registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<int64_t, int64_t>::RegisterGlobalOverload(
          "math.bitNot", BitNotInt, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<uint64_t, uint64_t>::RegisterGlobalOverload(
          "math.bitNot", BitNotUint, registry)));
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<Value, int64_t, int64_t>::RegisterGlobalOverload(
          "math.bitShiftLeft", BitShiftLeftInt, registry)));
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<Value, uint64_t, int64_t>::RegisterGlobalOverload(
          "math.bitShiftLeft", BitShiftLeftUint, registry)));
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<Value, int64_t, int64_t>::RegisterGlobalOverload(
          "math.bitShiftRight", BitShiftRightInt, registry)));
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<Value, uint64_t, int64_t>::RegisterGlobalOverload(
          "math.bitShiftRight", BitShiftRightUint, registry)));
  if (version == 1) {
    return absl::OkStatus();
  }

  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<double, double>::RegisterGlobalOverload(
          "math.sqrt", SqrtDouble, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<double, int64_t>::RegisterGlobalOverload(
          "math.sqrt", SqrtInt, registry)));
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter<double, uint64_t>::RegisterGlobalOverload(
          "math.sqrt", SqrtUint, registry)));

  return absl::OkStatus();
}

// Legacy entry point: forwards to the FunctionRegistry overload with the
// converted runtime options and the default (latest) version.
absl::Status RegisterMathExtensionFunctions(CelFunctionRegistry* registry,
                                            const InterpreterOptions& options) {
  return RegisterMathExtensionFunctions(
      registry->InternalGetRegistry(),
      google::api::expr::runtime::ConvertToRuntimeOptions(options));
}

}  // namespace cel::extensions
================================================
FILE: extensions/math_ext.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_H_
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_H_

#include "absl/status/status.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"
#include "extensions/math_ext_decls.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"

namespace cel::extensions {

// Register extension functions for supporting mathematical operations above
// and beyond the set defined in the CEL standard environment.
//
// `version` limits which functions are registered; functions introduced in
// later versions are omitted.
absl::Status RegisterMathExtensionFunctions(
    FunctionRegistry& registry, const RuntimeOptions& options,
    int version = kMathExtensionLatestVersion);

// Legacy overload for CelFunctionRegistry; registers the latest version.
absl::Status RegisterMathExtensionFunctions(
    google::api::expr::runtime::CelFunctionRegistry* registry,
    const google::api::expr::runtime::InterpreterOptions& options);

}  // namespace cel::extensions

#endif  // THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_H_
================================================
FILE: extensions/math_ext_decls.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include "extensions/math_ext_decls.h" #include #include "absl/base/no_destructor.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "checker/internal/builtins_arena.h" #include "checker/type_checker_builder.h" #include "common/decl.h" #include "common/type.h" #include "common/type_kind.h" #include "compiler/compiler.h" #include "extensions/math_ext_macros.h" #include "internal/status_macros.h" #include "parser/parser_interface.h" namespace cel::extensions { namespace { constexpr char kMathExtensionName[] = "cel.lib.ext.math"; const Type& ListIntType() { static absl::NoDestructor kInstance( ListType(checker_internal::BuiltinsArena(), IntType())); return *kInstance; } const Type& ListDoubleType() { static absl::NoDestructor kInstance( ListType(checker_internal::BuiltinsArena(), DoubleType())); return *kInstance; } const Type& ListUintType() { static absl::NoDestructor kInstance( ListType(checker_internal::BuiltinsArena(), UintType())); return *kInstance; } std::string OverloadTypeName(const Type& type) { switch (type.kind()) { case cel::TypeKind::kInt: return "int"; case TypeKind::kDouble: return "double"; case TypeKind::kUint: return "uint"; case TypeKind::kList: return absl::StrCat("list_", OverloadTypeName(type.AsList()->GetElement())); default: return "unsupported"; } } absl::Status AddMinMaxDecls(TypeCheckerBuilder& builder) { const Type kNumerics[] = {IntType(), DoubleType(), UintType()}; const Type kListNumerics[] = {ListIntType(), ListDoubleType(), ListUintType()}; constexpr char kMinOverloadPrefix[] = "math_@min_"; constexpr char kMaxOverloadPrefix[] = "math_@max_"; FunctionDecl min_decl; min_decl.set_name("math.@min"); FunctionDecl max_decl; max_decl.set_name("math.@max"); for (const Type& type : kNumerics) { // Unary overloads CEL_RETURN_IF_ERROR(min_decl.AddOverload(MakeOverloadDecl( absl::StrCat(kMinOverloadPrefix, 
OverloadTypeName(type)), type, type))); CEL_RETURN_IF_ERROR(max_decl.AddOverload(MakeOverloadDecl( absl::StrCat(kMaxOverloadPrefix, OverloadTypeName(type)), type, type))); // Pairwise overloads for (const Type& other_type : kNumerics) { Type out_type = DynType(); if (type.kind() == other_type.kind()) { out_type = type; } CEL_RETURN_IF_ERROR(min_decl.AddOverload(MakeOverloadDecl( absl::StrCat(kMinOverloadPrefix, OverloadTypeName(type), "_", OverloadTypeName(other_type)), out_type, type, other_type))); CEL_RETURN_IF_ERROR(max_decl.AddOverload(MakeOverloadDecl( absl::StrCat(kMaxOverloadPrefix, OverloadTypeName(type), "_", OverloadTypeName(other_type)), out_type, type, other_type))); } } // List overloads for (const Type& type : kListNumerics) { CEL_RETURN_IF_ERROR(min_decl.AddOverload(MakeOverloadDecl( absl::StrCat(kMinOverloadPrefix, OverloadTypeName(type)), type.AsList()->GetElement(), type))); CEL_RETURN_IF_ERROR(max_decl.AddOverload(MakeOverloadDecl( absl::StrCat(kMaxOverloadPrefix, OverloadTypeName(type)), type.AsList()->GetElement(), type))); } CEL_RETURN_IF_ERROR(builder.AddFunction(min_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(max_decl)); return absl::OkStatus(); } absl::Status AddSignednessDecls(TypeCheckerBuilder& builder) { const Type kNumerics[] = {IntType(), DoubleType(), UintType()}; FunctionDecl sign_decl; sign_decl.set_name("math.sign"); FunctionDecl abs_decl; abs_decl.set_name("math.abs"); for (const Type& type : kNumerics) { CEL_RETURN_IF_ERROR(sign_decl.AddOverload(MakeOverloadDecl( absl::StrCat("math_sign_", OverloadTypeName(type)), type, type))); CEL_RETURN_IF_ERROR(abs_decl.AddOverload(MakeOverloadDecl( absl::StrCat("math_abs_", OverloadTypeName(type)), type, type))); } CEL_RETURN_IF_ERROR(builder.AddFunction(sign_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(abs_decl)); return absl::OkStatus(); } absl::Status AddSqrtDecls(TypeCheckerBuilder& builder) { const Type kNumerics[] = {IntType(), DoubleType(), UintType()}; FunctionDecl 
sqrt_decl; sqrt_decl.set_name("math.sqrt"); for (const Type& type : kNumerics) { CEL_RETURN_IF_ERROR(sqrt_decl.AddOverload( MakeOverloadDecl(absl::StrCat("math_sqrt_", OverloadTypeName(type)), DoubleType(), type))); } CEL_RETURN_IF_ERROR(builder.AddFunction(sqrt_decl)); return absl::OkStatus(); } absl::Status AddFloatingPointDecls(TypeCheckerBuilder& builder) { // Rounding CEL_ASSIGN_OR_RETURN( auto ceil_decl, MakeFunctionDecl( "math.ceil", MakeOverloadDecl("math_ceil_double", DoubleType(), DoubleType()))); CEL_ASSIGN_OR_RETURN( auto floor_decl, MakeFunctionDecl( "math.floor", MakeOverloadDecl("math_floor_double", DoubleType(), DoubleType()))); CEL_ASSIGN_OR_RETURN( auto round_decl, MakeFunctionDecl( "math.round", MakeOverloadDecl("math_round_double", DoubleType(), DoubleType()))); CEL_ASSIGN_OR_RETURN( auto trunc_decl, MakeFunctionDecl( "math.trunc", MakeOverloadDecl("math_trunc_double", DoubleType(), DoubleType()))); // FP helpers CEL_ASSIGN_OR_RETURN( auto is_inf_decl, MakeFunctionDecl( "math.isInf", MakeOverloadDecl("math_isInf_double", BoolType(), DoubleType()))); CEL_ASSIGN_OR_RETURN( auto is_nan_decl, MakeFunctionDecl( "math.isNaN", MakeOverloadDecl("math_isNaN_double", BoolType(), DoubleType()))); CEL_ASSIGN_OR_RETURN( auto is_finite_decl, MakeFunctionDecl( "math.isFinite", MakeOverloadDecl("math_isFinite_double", BoolType(), DoubleType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(ceil_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(floor_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(round_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(trunc_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(is_inf_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(is_nan_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(is_finite_decl)); return absl::OkStatus(); } absl::Status AddBitwiseDecls(TypeCheckerBuilder& builder) { const Type kBitwiseTypes[] = {IntType(), UintType()}; FunctionDecl bit_and_decl; bit_and_decl.set_name("math.bitAnd"); FunctionDecl bit_or_decl; 
bit_or_decl.set_name("math.bitOr"); FunctionDecl bit_xor_decl; bit_xor_decl.set_name("math.bitXor"); FunctionDecl bit_not_decl; bit_not_decl.set_name("math.bitNot"); FunctionDecl bit_lshift_decl; bit_lshift_decl.set_name("math.bitShiftLeft"); FunctionDecl bit_rshift_decl; bit_rshift_decl.set_name("math.bitShiftRight"); for (const Type& type : kBitwiseTypes) { CEL_RETURN_IF_ERROR(bit_and_decl.AddOverload( MakeOverloadDecl(absl::StrCat("math_bitAnd_", OverloadTypeName(type), "_", OverloadTypeName(type)), type, type, type))); CEL_RETURN_IF_ERROR(bit_or_decl.AddOverload( MakeOverloadDecl(absl::StrCat("math_bitOr_", OverloadTypeName(type), "_", OverloadTypeName(type)), type, type, type))); CEL_RETURN_IF_ERROR(bit_xor_decl.AddOverload( MakeOverloadDecl(absl::StrCat("math_bitXor_", OverloadTypeName(type), "_", OverloadTypeName(type)), type, type, type))); CEL_RETURN_IF_ERROR(bit_not_decl.AddOverload( MakeOverloadDecl(absl::StrCat("math_bitNot_", OverloadTypeName(type), "_", OverloadTypeName(type)), type, type))); CEL_RETURN_IF_ERROR(bit_lshift_decl.AddOverload(MakeOverloadDecl( absl::StrCat("math_bitShiftLeft_", OverloadTypeName(type), "_int"), type, type, IntType()))); CEL_RETURN_IF_ERROR(bit_rshift_decl.AddOverload(MakeOverloadDecl( absl::StrCat("math_bitShiftRight_", OverloadTypeName(type), "_int"), type, type, IntType()))); } CEL_RETURN_IF_ERROR(builder.AddFunction(bit_and_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(bit_or_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(bit_xor_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(bit_not_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(bit_lshift_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(bit_rshift_decl)); return absl::OkStatus(); } absl::Status AddMathExtensionDeclarations(TypeCheckerBuilder& builder, int version) { CEL_RETURN_IF_ERROR(AddMinMaxDecls(builder)); if (version == 0) { return absl::OkStatus(); } CEL_RETURN_IF_ERROR(AddSignednessDecls(builder)); CEL_RETURN_IF_ERROR(AddFloatingPointDecls(builder)); 
CEL_RETURN_IF_ERROR(AddBitwiseDecls(builder)); if (version == 1) { return absl::OkStatus(); } CEL_RETURN_IF_ERROR(AddSqrtDecls(builder)); return absl::OkStatus(); } absl::Status AddMathExtensionMacros(ParserBuilder& builder, int version) { for (const auto& m : math_macros()) { // At the moment, all macros are supported in all versions. When we add a // new macro, we must add a version check here. CEL_RETURN_IF_ERROR(builder.AddMacro(m)); } return absl::OkStatus(); } } // namespace // Configuration for cel::Compiler to enable the math extension declarations. CompilerLibrary MathCompilerLibrary(int version) { return CompilerLibrary( kMathExtensionName, [version](ParserBuilder& builder) { return AddMathExtensionMacros(builder, version); }, [version](TypeCheckerBuilder& builder) { return AddMathExtensionDeclarations(builder, version); }); } // Configuration for cel::TypeChecker to enable the math extension declarations. CheckerLibrary MathCheckerLibrary(int version) { return { .id = kMathExtensionName, .configure = [version](TypeCheckerBuilder& builder) { return AddMathExtensionDeclarations(builder, version); }, }; } } // namespace cel::extensions ================================================ FILE: extensions/math_ext_decls.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_DECLS_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_DECLS_H_ #include "checker/type_checker_builder.h" #include "compiler/compiler.h" namespace cel::extensions { constexpr int kMathExtensionLatestVersion = 2; // Configuration for cel::Compiler to enable the math extension declarations. CompilerLibrary MathCompilerLibrary(int version = kMathExtensionLatestVersion); // Configuration for cel::TypeChecker to enable the math extension declarations. CheckerLibrary MathCheckerLibrary(int version = kMathExtensionLatestVersion); } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_DECLS_H_ ================================================ FILE: extensions/math_ext_macros.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "extensions/math_ext_macros.h"

#include <utility>
#include <vector>

#include "absl/functional/overload.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "common/ast.h"
#include "common/constant.h"
#include "parser/macro.h"
#include "parser/macro_expr_factory.h"

namespace cel::extensions {

namespace {

// Receiver identifier required for expansion, i.e. the `math` in
// `math.least(...)`. (No `static` needed: the anonymous namespace already
// gives these internal linkage.)
constexpr absl::string_view kMathNamespace = "math";
constexpr absl::string_view kLeast = "least";
constexpr absl::string_view kGreatest = "greatest";

// Functions the macros expand into.
constexpr char kMathMin[] = "math.@min";
constexpr char kMathMax[] = "math.@max";

// True when `target` is exactly the bare identifier `math`.
bool IsTargetNamespace(const Expr &target) {
  return target.has_ident_expr() &&
         target.ident_expr().name() == kMathNamespace;
}

// Parse-time plausibility check for a macro argument. Constants must be
// double, int or uint. List, struct, map and unspecified expressions are
// rejected. Every other kind (identifiers, selects, calls, comprehensions)
// is accepted because its value is only known at runtime.
bool IsValidArgType(const Expr &arg) {
  return absl::visit(
      absl::Overload([](const UnspecifiedExpr &) -> bool { return false; },
                     [](const Constant &const_expr) -> bool {
                       return const_expr.has_double_value() ||
                              const_expr.has_int_value() ||
                              const_expr.has_uint_value();
                     },
                     [](const ListExpr &) -> bool { return false; },
                     [](const StructExpr &) -> bool { return false; },
                     [](const MapExpr &) -> bool { return false; },
                     // This is intended for call and select expressions.
                     [](const auto &) -> bool { return true; }),
      arg.kind());
}

// Reports an error on the first argument that fails IsValidArgType, or
// returns absl::nullopt when all arguments pass.
absl::optional<Expr> CheckInvalidArgs(MacroExprFactory &factory,
                                      absl::string_view macro,
                                      absl::Span<const Expr> arguments) {
  for (const auto &argument : arguments) {
    if (!IsValidArgType(argument)) {
      return factory.ReportErrorAt(
          argument,
          absl::StrCat(macro, " simple literal arguments must be numeric"));
    }
  }

  return absl::nullopt;
}

// True when `arg` is a non-empty list literal whose elements all pass
// IsValidArgType.
bool IsListLiteralWithValidArgs(const Expr &arg) {
  if (!arg.has_list_expr()) {
    return false;
  }
  const auto &elements = arg.list_expr().elements();
  if (elements.empty()) {
    return false;
  }
  for (const auto &element : elements) {
    if (!IsValidArgType(element.expr())) {
      return false;
    }
  }
  return true;
}

// Shared expansion for math.least() and math.greatest().
//
// `display_name` is the user-facing spelling used in error messages (for
// example "math.least()"). `function` is the call the macro expands into
// (kMathMin or kMathMax). Returns absl::nullopt when the receiver is not the
// bare `math` identifier, so ordinary member calls named least/greatest are
// left untouched.
absl::optional<Expr> ExpandMathMinMax(MacroExprFactory &factory, Expr &target,
                                      absl::Span<Expr> arguments,
                                      absl::string_view display_name,
                                      const char *function) {
  if (!IsTargetNamespace(target)) {
    return absl::nullopt;
  }
  switch (arguments.size()) {
    case 0: {
      return factory.ReportErrorAt(
          target,
          absl::StrCat(display_name, " requires at least one argument."));
    }
    case 1: {
      // A single argument must be a non-empty list literal of plausible
      // numbers, or a value that may be numeric (or a list) at runtime.
      if (!IsListLiteralWithValidArgs(arguments[0]) &&
          !IsValidArgType(arguments[0])) {
        return factory.ReportErrorAt(
            arguments[0],
            absl::StrCat(display_name, " invalid single argument value."));
      }
      return factory.NewCall(function, arguments);
    }
    case 2: {
      if (auto error = CheckInvalidArgs(factory, display_name, arguments);
          error) {
        return std::move(*error);
      }
      return factory.NewCall(function, arguments);
    }
    default: {
      if (auto error = CheckInvalidArgs(factory, display_name, arguments);
          error) {
        return std::move(*error);
      }
      // Three or more arguments are packed into a list literal so the call
      // resolves to the list overload of `function`.
      std::vector<ListExprElement> elements;
      elements.reserve(arguments.size());
      for (auto &argument : arguments) {
        elements.push_back(factory.NewListElement(std::move(argument)));
      }
      return factory.NewCall(function, factory.NewList(std::move(elements)));
    }
  }
}

}  // namespace

std::vector<Macro> math_macros() {
  absl::StatusOr<Macro> least = Macro::ReceiverVarArg(
      kLeast,
      [](MacroExprFactory &factory, Expr &target,
         absl::Span<Expr> arguments) -> absl::optional<Expr> {
        return ExpandMathMinMax(factory, target, arguments, "math.least()",
                                kMathMin);
      });

  absl::StatusOr<Macro> greatest = Macro::ReceiverVarArg(
      kGreatest,
      [](MacroExprFactory &factory, Expr &target,
         absl::Span<Expr> arguments) -> absl::optional<Expr> {
        return ExpandMathMinMax(factory, target, arguments, "math.greatest()",
                                kMathMax);
      });

  // The macro names are compile-time constants, so construction is not
  // expected to fail. value(), unlike operator*, turns an unexpected failure
  // into a BadStatusOrAccess (or a status-reporting crash when exceptions are
  // disabled) instead of undefined behavior.
  return {least.value(), greatest.value()};
}

}  // namespace cel::extensions
================================================
FILE: extensions/math_ext_macros.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_MACROS_H_
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_MACROS_H_

#include <vector>

#include "absl/status/status.h"
#include "parser/macro.h"
#include "parser/macro_registry.h"
#include "parser/options.h"

namespace cel::extensions {

// math_macros() returns the namespaced helper macros for math.least() and
// math.greatest().
std::vector math_macros(); inline absl::Status RegisterMathMacros(MacroRegistry& registry, const ParserOptions&) { return registry.RegisterMacros(math_macros()); } } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_MACROS_H_ ================================================ FILE: extensions/math_ext_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "extensions/math_ext.h" #include #include #include #include #include "cel/expr/syntax.pb.h" #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "checker/standard_library.h" #include "checker/type_check_issue.h" #include "checker/validation_result.h" #include "common/decl.h" #include "common/function_descriptor.h" #include "common/type.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" #include "compiler/standard_library.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_function.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include 
"eval/public/testing/matchers.h" #include "extensions/math_ext_decls.h" #include "extensions/math_ext_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "parser/parser.h" #include "runtime/activation.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "google/protobuf/arena.h" namespace cel::extensions { namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::expr::Expr; using ::cel::expr::ParsedExpr; using ::cel::expr::SourceInfo; using ::google::api::expr::parser::ParseWithMacros; using ::google::api::expr::runtime::Activation; using ::google::api::expr::runtime::CelExpressionBuilder; using ::google::api::expr::runtime::CelFunction; using ::google::api::expr::runtime::CelFunctionDescriptor; using ::google::api::expr::runtime::CelValue; using ::google::api::expr::runtime::ContainerBackedListImpl; using ::google::api::expr::runtime::CreateCelExpressionBuilder; using ::google::api::expr::runtime::InterpreterOptions; using ::google::api::expr::runtime::RegisterBuiltinFunctions; using ::google::api::expr::runtime::test::EqualsCelValue; using ::google::protobuf::Arena; using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::ValuesIn; constexpr absl::string_view kMathMin = "math.@min"; constexpr absl::string_view kMathMax = "math.@max"; struct TestCase { absl::string_view operation; CelValue arg1; absl::optional arg2; CelValue result; }; TestCase MinCase(CelValue v1, CelValue v2, CelValue result) { return TestCase{kMathMin, v1, v2, result}; } TestCase MinCase(CelValue list, CelValue result) { return TestCase{kMathMin, list, absl::nullopt, result}; } TestCase MaxCase(CelValue v1, CelValue v2, CelValue result) { return TestCase{kMathMax, v1, v2, result}; } TestCase MaxCase(CelValue list, CelValue result) { return TestCase{kMathMax, list, absl::nullopt, result}; } struct MacroTestCase { absl::string_view expr; absl::string_view err = 
""; }; std::string FormatIssues(const cel::ValidationResult& result) { std::string issues; for (const auto& issue : result.GetIssues()) { if (!issues.empty()) { absl::StrAppend(&issues, "\n", issue.ToDisplayString(*result.GetSource())); } else { issues = issue.ToDisplayString(*result.GetSource()); } } return issues; } class TestFunction : public CelFunction { public: explicit TestFunction(absl::string_view name) : CelFunction(MakeDescriptor(name)) {} static FunctionDescriptor MakeDescriptor(absl::string_view name) { return FunctionDescriptor(name, true, {CelValue::Type::kBool, CelValue::Type::kInt64, CelValue::Type::kInt64}); } absl::Status Evaluate(absl::Span args, CelValue* result, Arena* arena) const override { *result = CelValue::CreateBool(true); return absl::OkStatus(); } }; // Test function used to test macro collision and non-expansion. constexpr absl::string_view kGreatest = "greatest"; std::unique_ptr CreateGreatestFunction() { return std::make_unique(kGreatest); } constexpr absl::string_view kLeast = "least"; std::unique_ptr CreateLeastFunction() { return std::make_unique(kLeast); } Expr CallExprOneArg(absl::string_view operation) { Expr expr; auto call = expr.mutable_call_expr(); call->set_function(operation); auto arg = call->add_args(); auto ident = arg->mutable_ident_expr(); ident->set_name("a"); return expr; } Expr CallExprTwoArgs(absl::string_view operation) { Expr expr; auto call = expr.mutable_call_expr(); call->set_function(operation); auto arg = call->add_args(); auto ident = arg->mutable_ident_expr(); ident->set_name("a"); arg = call->add_args(); ident = arg->mutable_ident_expr(); ident->set_name("b"); return expr; } void ExpectResult(const TestCase& test_case) { Expr expr; Activation activation; activation.InsertValue("a", test_case.arg1); if (test_case.arg2.has_value()) { activation.InsertValue("b", *test_case.arg2); expr = CallExprTwoArgs(test_case.operation); } else { expr = CallExprOneArg(test_case.operation); } SourceInfo source_info; 
InterpreterOptions options; std::unique_ptr builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterMathExtensionFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN(auto cel_expression, builder->CreateExpression(&expr, &source_info)); google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(auto value, cel_expression->Evaluate(activation, &arena)); if (!test_case.result.IsError()) { EXPECT_THAT(value, EqualsCelValue(test_case.result)); } else { auto expected = test_case.result.ErrorOrDie(); EXPECT_THAT(*value.ErrorOrDie(), StatusIs(expected->code(), HasSubstr(expected->message()))); } } using MathExtParamsTest = testing::TestWithParam; TEST_P(MathExtParamsTest, MinMaxTests) { ExpectResult(GetParam()); } INSTANTIATE_TEST_SUITE_P( MathExtParamsTest, MathExtParamsTest, testing::ValuesIn({ MinCase(CelValue::CreateInt64(3L), CelValue::CreateInt64(2L), CelValue::CreateInt64(2L)), MinCase(CelValue::CreateInt64(-1L), CelValue::CreateUint64(2u), CelValue::CreateInt64(-1L)), MinCase(CelValue::CreateInt64(-1L), CelValue::CreateDouble(-1.1), CelValue::CreateDouble(-1.1)), MinCase(CelValue::CreateDouble(-2.0), CelValue::CreateDouble(-1.1), CelValue::CreateDouble(-2.0)), MinCase(CelValue::CreateDouble(3.1), CelValue::CreateInt64(2), CelValue::CreateInt64(2)), MinCase(CelValue::CreateDouble(2.5), CelValue::CreateUint64(2u), CelValue::CreateUint64(2u)), MinCase(CelValue::CreateUint64(2u), CelValue::CreateDouble(-1.1), CelValue::CreateDouble(-1.1)), MinCase(CelValue::CreateUint64(3u), CelValue::CreateInt64(20), CelValue::CreateUint64(3u)), MinCase(CelValue::CreateUint64(4u), CelValue::CreateUint64(2u), CelValue::CreateUint64(2u)), MinCase(CelValue::CreateInt64(2L), CelValue::CreateUint64(2u), CelValue::CreateInt64(2L)), MinCase(CelValue::CreateInt64(-1L), CelValue::CreateDouble(-1.0), CelValue::CreateInt64(-1L)), MinCase(CelValue::CreateDouble(2.0), CelValue::CreateInt64(2), CelValue::CreateDouble(2.0)), MinCase(CelValue::CreateDouble(2.0), 
CelValue::CreateUint64(2u), CelValue::CreateDouble(2.0)), MinCase(CelValue::CreateUint64(2u), CelValue::CreateDouble(2.0), CelValue::CreateUint64(2u)), MinCase(CelValue::CreateUint64(3u), CelValue::CreateInt64(3), CelValue::CreateUint64(3u)), MaxCase(CelValue::CreateInt64(3L), CelValue::CreateInt64(2L), CelValue::CreateInt64(3L)), MaxCase(CelValue::CreateInt64(-1L), CelValue::CreateUint64(2u), CelValue::CreateUint64(2u)), MaxCase(CelValue::CreateInt64(-1L), CelValue::CreateDouble(-1.1), CelValue::CreateInt64(-1L)), MaxCase(CelValue::CreateDouble(-2.0), CelValue::CreateDouble(-1.1), CelValue::CreateDouble(-1.1)), MaxCase(CelValue::CreateDouble(3.1), CelValue::CreateInt64(2), CelValue::CreateDouble(3.1)), MaxCase(CelValue::CreateDouble(2.5), CelValue::CreateUint64(2u), CelValue::CreateDouble(2.5)), MaxCase(CelValue::CreateUint64(2u), CelValue::CreateDouble(-1.1), CelValue::CreateUint64(2u)), MaxCase(CelValue::CreateUint64(3u), CelValue::CreateInt64(20), CelValue::CreateInt64(20)), MaxCase(CelValue::CreateUint64(4u), CelValue::CreateUint64(2u), CelValue::CreateUint64(4u)), MaxCase(CelValue::CreateInt64(2L), CelValue::CreateUint64(2u), CelValue::CreateInt64(2L)), MaxCase(CelValue::CreateInt64(-1L), CelValue::CreateDouble(-1.0), CelValue::CreateInt64(-1L)), MaxCase(CelValue::CreateDouble(2.0), CelValue::CreateInt64(2), CelValue::CreateDouble(2.0)), MaxCase(CelValue::CreateDouble(2.0), CelValue::CreateUint64(2u), CelValue::CreateDouble(2.0)), MaxCase(CelValue::CreateUint64(2u), CelValue::CreateDouble(2.0), CelValue::CreateUint64(2u)), MaxCase(CelValue::CreateUint64(3u), CelValue::CreateInt64(3), CelValue::CreateUint64(3u)), })); TEST(MathExtTest, MinMaxList) { ContainerBackedListImpl single_item_list({CelValue::CreateInt64(1)}); ExpectResult(MinCase(CelValue::CreateList(&single_item_list), CelValue::CreateInt64(1))); ExpectResult(MaxCase(CelValue::CreateList(&single_item_list), CelValue::CreateInt64(1))); ContainerBackedListImpl list({CelValue::CreateInt64(1), 
CelValue::CreateUint64(2u), CelValue::CreateDouble(-1.1)}); ExpectResult( MinCase(CelValue::CreateList(&list), CelValue::CreateDouble(-1.1))); ExpectResult( MaxCase(CelValue::CreateList(&list), CelValue::CreateUint64(2u))); absl::Status empty_list_err = absl::InvalidArgumentError("argument must not be empty"); CelValue err_value = CelValue::CreateError(&empty_list_err); ContainerBackedListImpl empty_list({}); ExpectResult(MinCase(CelValue::CreateList(&empty_list), err_value)); ExpectResult(MaxCase(CelValue::CreateList(&empty_list), err_value)); absl::Status bad_arg_err = absl::InvalidArgumentError("arguments must be numeric"); err_value = CelValue::CreateError(&bad_arg_err); ContainerBackedListImpl bad_single_item({CelValue::CreateBool(true)}); ExpectResult(MinCase(CelValue::CreateList(&bad_single_item), err_value)); ExpectResult(MaxCase(CelValue::CreateList(&bad_single_item), err_value)); ContainerBackedListImpl bad_middle_item({CelValue::CreateInt64(1), CelValue::CreateBool(false), CelValue::CreateDouble(-1.1)}); ExpectResult(MinCase(CelValue::CreateList(&bad_middle_item), err_value)); ExpectResult(MaxCase(CelValue::CreateList(&bad_middle_item), err_value)); } using MathExtMacroParamsTest = testing::TestWithParam; TEST_P(MathExtMacroParamsTest, ParserTests) { const MacroTestCase& test_case = GetParam(); auto result = ParseWithMacros(test_case.expr, cel::extensions::math_macros(), ""); if (!test_case.err.empty()) { EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(test_case.err))); return; } ASSERT_OK(result); ParsedExpr parsed_expr = *result; Expr expr = parsed_expr.expr(); SourceInfo source_info = parsed_expr.source_info(); InterpreterOptions options; options.enable_qualified_identifier_rewrites = true; std::unique_ptr builder = CreateCelExpressionBuilder(options); ASSERT_OK(builder->GetRegistry()->Register(CreateGreatestFunction())); ASSERT_OK(builder->GetRegistry()->Register(CreateLeastFunction())); 
ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK(RegisterMathExtensionFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN(auto cel_expression, builder->CreateExpression(&expr, &source_info)); google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(auto value, cel_expression->Evaluate(activation, &arena)); ASSERT_TRUE(value.IsBool()); EXPECT_EQ(value.BoolOrDie(), true); } // Compiles each case with the standard checker and math compiler libraries, then evaluates it with the standard runtime plus the math extension functions and expects `true`. Cases with a non-empty `err` must instead fail compilation with kInvalidArgument containing that text. TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { const MacroTestCase& test_case = GetParam(); ASSERT_OK_AND_ASSIGN( auto compiler_builder, cel::NewCompilerBuilder(internal::GetTestingDescriptorPool())); ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_THAT(compiler_builder->AddLibrary(MathCompilerLibrary()), IsOk()); // Add test functions that check macro (non-)expansion. ASSERT_OK_AND_ASSIGN( auto least_decl, MakeFunctionDecl("least", MakeMemberOverloadDecl("bool_least_int_int", /*result*/ BoolType(), /*receiver*/ BoolType(), IntType(), IntType()))); ASSERT_OK_AND_ASSIGN(auto greatest_decl, MakeFunctionDecl("greatest", MakeMemberOverloadDecl( "bool_greatest_int_int", /*result*/ BoolType(), /*receiver*/ BoolType(), IntType(), IntType()))); ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddFunction(least_decl), IsOk()); ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddFunction(greatest_decl), IsOk()); ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*compiler_builder).Build()); auto result = compiler->Compile(test_case.expr, ""); if (!test_case.err.empty()) { EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(test_case.err))); return; } ASSERT_THAT(result, IsOk()); ASSERT_TRUE(result->IsValid()) << FormatIssues(*result); RuntimeOptions opts; ASSERT_OK_AND_ASSIGN( auto runtime_builder, CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); ASSERT_THAT( RegisterMathExtensionFunctions(runtime_builder.function_registry(), opts), IsOk());
ASSERT_THAT( runtime_builder.function_registry().Register( TestFunction::MakeDescriptor(kGreatest), CreateGreatestFunction()), IsOk()); // NOTE(review): kLeast is registered with CreateGreatestFunction(); the non-expanded member-call cases only need a stub that returns true, but confirm this reuse is intentional. ASSERT_THAT( runtime_builder.function_registry().Register( TestFunction::MakeDescriptor(kLeast), CreateGreatestFunction()), IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(*result->ReleaseAst())); google::protobuf::Arena arena; cel::Activation activation; ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation)); ASSERT_TRUE(value.IsBool()); EXPECT_EQ(value.GetBool(), true); } INSTANTIATE_TEST_SUITE_P( MathExtMacrosParamsTest, MathExtMacroParamsTest, testing::ValuesIn( {// Tests for math.least {"math.least(-0.5) == -0.5"}, {"math.least(-1) == -1"}, {"math.least(1u) == 1u"}, {"math.least(42.0, -0.5) == -0.5"}, {"math.least(-1, 0) == -1"}, {"math.least(-1, -1) == -1"}, {"math.least(1u, 42u) == 1u"}, {"math.least(42.0, -0.5, -0.25) == -0.5"}, {"math.least(-1, 0, 1) == -1"}, {"math.least(-1, -1, -1) == -1"}, {"math.least(1u, 42u, 0u) == 0u"}, // math.least two arg overloads across type. {"math.least(1, 1.0) == 1"}, {"math.least(1, -2.0) == -2.0"}, {"math.least(2, 1u) == 1u"}, {"math.least(1.5, 2) == 1.5"}, {"math.least(1.5, -2) == -2"}, {"math.least(2.5, 1u) == 1u"}, {"math.least(1u, 2) == 1u"}, {"math.least(1u, -2) == -2"}, {"math.least(2u, 2.5) == 2u"}, // math.least with dynamic values across type. {"math.least(1u, dyn(42)) == 1"}, {"math.least(1u, dyn(42), dyn(0.0)) == 0u"}, // math.least with a list literal.
{"math.least([1u, 42u, 0u]) == 0u"}, // math.least errors { "math.least()", "math.least() requires at least one argument.", }, { "math.least('hello')", "math.least() invalid single argument value.", }, { "math.least({})", "math.least() invalid single argument value", }, { "math.least([])", "math.least() invalid single argument value", }, { "math.least([1, true])", "math.least() invalid single argument value", }, { "math.least(1, true)", "math.least() simple literal arguments must be numeric", }, { "math.least(1, 2, true)", "math.least() simple literal arguments must be numeric", }, // Tests for math.greatest {"math.greatest(-0.5) == -0.5"}, {"math.greatest(-1) == -1"}, {"math.greatest(1u) == 1u"}, {"math.greatest(42.0, -0.5) == 42.0"}, {"math.greatest(-1, 0) == 0"}, {"math.greatest(-1, -1) == -1"}, {"math.greatest(1u, 42u) == 42u"}, {"math.greatest(42.0, -0.5, -0.25) == 42.0"}, {"math.greatest(-1, 0, 1) == 1"}, {"math.greatest(-1, -1, -1) == -1"}, {"math.greatest(1u, 42u, 0u) == 42u"}, // math.greatest two arg overloads across type. {"math.greatest(1, 1.0) == 1"}, {"math.greatest(1, -2.0) == 1"}, {"math.greatest(2, 1u) == 2"}, {"math.greatest(1.5, 2) == 2"}, {"math.greatest(1.5, -2) == 1.5"}, {"math.greatest(2.5, 1u) == 2.5"}, {"math.greatest(1u, 2) == 2"}, {"math.greatest(1u, -2) == 1u"}, {"math.greatest(2u, 2.5) == 2.5"}, // math.greatest with dynamic values across type.
{"math.greatest(1u, dyn(42)) == 42.0"}, {"math.greatest(1u, dyn(0.0), 0u) == 1"}, // math.greatest with a list literal. {"math.greatest([1u, dyn(0.0), 0u]) == 1"}, // math.greatest errors { "math.greatest()", "math.greatest() requires at least one argument.", }, { "math.greatest('hello')", "math.greatest() invalid single argument value.", }, { "math.greatest({})", "math.greatest() invalid single argument value", }, { "math.greatest([])", "math.greatest() invalid single argument value", }, { "math.greatest([1, true])", "math.greatest() invalid single argument value", }, { "math.greatest(1, true)", "math.greatest() simple literal arguments must be numeric", }, { "math.greatest(1, 2, true)", "math.greatest() simple literal arguments must be numeric", }, // Call signatures which trigger macro expansion, but which do not // get expanded. The function just returns true. { "false.greatest(1,2)", }, { "true.least(1,2)", }, // Basic coverage for function definitions. Behavior is tested in the // conformance tests.
{"math.sign(-12) == -1"}, {"math.sign(0u) == 0u"}, {"math.sign(42.01) == 1.0"}, {"math.abs(-12) == 12"}, {"math.abs(0u) == 0u"}, {"math.abs(42.01) == 42.01"}, {"math.ceil(42.01) == 43.0"}, {"math.floor(42.01) == 42.0"}, {"math.round(42.5) == 43.0"}, {"math.sqrt(49.0) == 7.0"}, {"math.sqrt(0) == 0.0"}, {"math.sqrt(1) == 1.0"}, {"math.sqrt(25u) == 5.0"}, {"math.sqrt(38.44) == 6.2"}, {"math.isNaN(math.sqrt(-15)) == true"}, {"math.trunc(42.0) == 42.0"}, {"math.isInf(42.0 / 0.0) == true"}, {"math.isNaN(double('nan')) == true"}, {"math.isFinite(42.1) == true"}, {"math.bitAnd(3, 1) == 1"}, {"math.bitAnd(3u, 1u) == 1u"}, {"math.bitOr(2, 1) == 3"}, {"math.bitOr(2u, 1u) == 3u"}, {"math.bitXor(3, 1) == 2"}, {"math.bitXor(3u, 1u) == 2u"}, {"math.bitNot(2) == -3"}, {"math.bitAnd(math.bitNot(0x3u), 0xFFu) == 0xFCu"}, {"math.bitShiftLeft(1, 1) == 2"}, {"math.bitShiftLeft(1u, 1) == 2u"}, {"math.bitShiftRight(4, 1) == 2"}, {"math.bitShiftRight(4u, 1) == 2u"}})); struct MathExtensionVersionTestCase { std::string expr; std::vector expected_supported_versions; }; class MathExtensionVersionTest : public ::testing::TestWithParam {}; // For every math extension version from 0 through kMathExtensionLatestVersion, the expression must compile with no issues when the version is listed in expected_supported_versions, and must report an "undeclared reference" issue otherwise. TEST_P(MathExtensionVersionTest, MathExtensionVersions) { const MathExtensionVersionTestCase& test_case = GetParam(); for (int version = 0; version <= cel::extensions::kMathExtensionLatestVersion; ++version) { CompilerLibrary compiler_library = MathCompilerLibrary(version); ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, cel::NewCompilerBuilder(internal::GetTestingDescriptorPool(), CompilerOptions())); ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); ASSERT_THAT(builder->AddLibrary(std::move(compiler_library)), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, builder->Build()); ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(test_case.expr)); if (absl::c_contains(test_case.expected_supported_versions, version)) { EXPECT_THAT(result.GetIssues(), IsEmpty()) << "Expected no issues for expr: " << test_case.expr << " at
version: " << version << " but got: " << result.FormatError(); } else { EXPECT_THAT(result.GetIssues(), Contains(Property(&TypeCheckIssue::message, HasSubstr("undeclared reference")))) << "Expected undeclared reference for expr: " << test_case.expr << " at version: " << version; } } } std::vector CreateMathExtensionVersionParams() { return { MathExtensionVersionTestCase{ .expr = "math.least([0,1,2,3])", .expected_supported_versions = {0, 1, 2}, }, MathExtensionVersionTestCase{ .expr = "math.greatest([0,1,2,3])", .expected_supported_versions = {0, 1, 2}, }, MathExtensionVersionTestCase{ .expr = "math.ceil(1.5)", .expected_supported_versions = {1, 2}, }, MathExtensionVersionTestCase{ .expr = "math.floor(1.5)", .expected_supported_versions = {1, 2}, }, MathExtensionVersionTestCase{ .expr = "math.round(1.5)", .expected_supported_versions = {1, 2}, }, MathExtensionVersionTestCase{ .expr = "math.trunc(1.5)", .expected_supported_versions = {1, 2}, }, MathExtensionVersionTestCase{ .expr = "math.isInf(1.5)", .expected_supported_versions = {1, 2}, }, MathExtensionVersionTestCase{ .expr = "math.isNaN(1.5)", .expected_supported_versions = {1, 2}, }, MathExtensionVersionTestCase{ .expr = "math.isFinite(1.5)", .expected_supported_versions = {1, 2}, }, MathExtensionVersionTestCase{ .expr = "math.abs(1.5)", .expected_supported_versions = {1, 2}, }, MathExtensionVersionTestCase{ .expr = "math.sign(1.5)", .expected_supported_versions = {1, 2}, }, MathExtensionVersionTestCase{ .expr = "math.bitAnd(1, 1)", .expected_supported_versions = {1, 2}, }, MathExtensionVersionTestCase{ .expr = "math.bitOr(1, 1)", .expected_supported_versions = {1, 2}, }, MathExtensionVersionTestCase{ .expr = "math.bitXor(1, 1)", .expected_supported_versions = {1, 2}, }, MathExtensionVersionTestCase{ .expr = "math.bitNot(1)", .expected_supported_versions = {1, 2}, }, MathExtensionVersionTestCase{ .expr = "math.bitShiftLeft(1, 1)", .expected_supported_versions = {1, 2}, }, MathExtensionVersionTestCase{ .expr =
"math.bitShiftRight(1, 1)", .expected_supported_versions = {1, 2}, }, MathExtensionVersionTestCase{ .expr = "math.sqrt(1.5)", .expected_supported_versions = {2}, }, }; } INSTANTIATE_TEST_SUITE_P(MathExtensionVersionTest, MathExtensionVersionTest, ValuesIn(CreateMathExtensionVersionParams())); } // namespace } // namespace cel::extensions ================================================ FILE: extensions/proto_ext.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "extensions/proto_ext.h" #include #include #include #include "absl/functional/overload.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "absl/types/variant.h" #include "common/expr.h" #include "compiler/compiler.h" #include "internal/status_macros.h" #include "parser/macro.h" #include "parser/macro_expr_factory.h" #include "parser/parser_interface.h" namespace cel::extensions { namespace { constexpr char kProtoNamespace[] = "proto"; constexpr char kGetExt[] = "getExt"; constexpr char kHasExt[] = "hasExt"; // Returns the dotted name spelled by an identifier, or by a chain of non-test-only field selections rooted at an identifier (e.g. `a.b.c`); returns nullopt for any other expression shape. absl::optional ValidateExtensionIdentifier(const Expr& expr) { return absl::visit( absl::Overload( [](const SelectExpr& select_expr) -> absl::optional { if (select_expr.test_only()) { return absl::nullopt; } auto op_name = ValidateExtensionIdentifier(select_expr.operand()); if (!op_name.has_value()) { return absl::nullopt; } return absl::StrCat(*op_name, ".", select_expr.field()); }, [](const IdentExpr& ident_expr) -> absl::optional { return ident_expr.name(); }, [](const auto&) -> absl::optional { return absl::nullopt; }), expr.kind()); } // Returns the qualified extension field name when `expr` is a select expression (e.g. `pkg.ext_field`); a bare identifier or any other shape yields nullopt. absl::optional GetExtensionFieldName(const Expr& expr) { if (expr.has_select_expr()) { return ValidateExtensionIdentifier(expr); } return absl::nullopt; } // Returns true when the macro receiver is exactly the identifier `proto`. bool IsExtensionCall(const Expr& target) { return target.has_ident_expr() && target.ident_expr().name() == kProtoNamespace; }
// Registers every macro returned by proto_macros() with the parser builder, stopping at the first error. absl::Status ConfigureParser(ParserBuilder& builder) { for (const auto& macro : proto_macros()) { CEL_RETURN_IF_ERROR(builder.AddMacro(macro)); } return absl::OkStatus(); } } // namespace std::vector proto_macros() { // proto.getExt(msg, pkg.ext) expands to a field selection of the qualified extension name on `msg`. absl::StatusOr getExt = Macro::Receiver( kGetExt, 2, [](MacroExprFactory& factory, Expr& target, absl::Span arguments) -> absl::optional { if (!IsExtensionCall(target)) { return absl::nullopt; } auto extFieldName = GetExtensionFieldName(arguments[1]); if (!extFieldName.has_value()) { return factory.ReportErrorAt(arguments[1], "invalid extension field"); } return factory.NewSelect(std::move(arguments[0]), std::move(*extFieldName)); }); // proto.hasExt(msg, pkg.ext) expands to a presence test of the qualified extension name on `msg`. absl::StatusOr hasExt = Macro::Receiver( kHasExt, 2, [](MacroExprFactory& factory, Expr& target, absl::Span arguments) -> absl::optional { if (!IsExtensionCall(target)) { return absl::nullopt; } auto extFieldName = GetExtensionFieldName(arguments[1]); if (!extFieldName.has_value()) { return factory.ReportErrorAt(arguments[1], "invalid extension field"); } return factory.NewPresenceTest(std::move(arguments[0]), std::move(*extFieldName)); }); // NOTE(review): Macro::Receiver is presumably infallible for these fixed names/arities; dereferencing a non-OK StatusOr is a programming error. return {*hasExt, *getExt}; } CompilerLibrary ProtoExtCompilerLibrary() { return CompilerLibrary("cel.lib.ext.proto", ConfigureParser); } } // namespace cel::extensions ================================================ FILE: extensions/proto_ext.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTO_EXT_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTO_EXT_H_ #include #include "absl/status/status.h" #include "compiler/compiler.h" #include "parser/macro.h" #include "parser/macro_registry.h" #include "parser/options.h" namespace cel::extensions { // proto_macros returns the macros which are useful for working with protobuf // objects in CEL. Specifically, the proto.getExt() and proto.hasExt() macros. std::vector proto_macros(); // Library for the proto extensions. CompilerLibrary ProtoExtCompilerLibrary(); // Registers proto_macros() with `registry`; the ParserOptions argument is unused. inline absl::Status RegisterProtoMacros(MacroRegistry& registry, const ParserOptions&) { return registry.RegisterMacros(proto_macros()); } } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTO_EXT_H_ ================================================ FILE: extensions/protobuf/BUILD ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") package( # Under active development, not yet being released.
default_visibility = ["//visibility:public"], ) licenses(["notice"]) cc_library( name = "memory_manager", srcs = ["memory_manager.cc"], hdrs = ["memory_manager.h"], deps = [ "//common:memory", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "memory_manager_test", srcs = ["memory_manager_test.cc"], deps = [ ":memory_manager", "//common:memory", "//internal:testing", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "ast_converters", hdrs = ["ast_converters.h"], deps = [ "//common:ast", "//common:ast_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status:statusor", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_library( name = "runtime_adapter", srcs = ["runtime_adapter.cc"], hdrs = ["runtime_adapter.h"], deps = [ ":ast_converters", "//internal:status_macros", "//runtime", "//runtime:runtime_builder", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "enum_adapter", srcs = ["enum_adapter.cc"], hdrs = ["enum_adapter.h"], deps = [ "//runtime:type_registry", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "value", hdrs = [ "value.h", ], deps = [ "//common:memory", "//common:type", "//common:value", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_protobuf//:duration_cc_proto", "@com_google_protobuf//:protobuf", "@com_google_protobuf//:struct_cc_proto", "@com_google_protobuf//:timestamp_cc_proto",
"@com_google_protobuf//:wrappers_cc_proto", ], ) cc_test( name = "value_test", srcs = [ "value_test.cc", ], deps = [ ":value", "//base:attributes", "//common:casting", "//common:value", "//common:value_kind", "//common:value_testing", "//internal:testing", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/time", "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_protobuf//:any_cc_proto", "@com_google_protobuf//:duration_cc_proto", "@com_google_protobuf//:protobuf", "@com_google_protobuf//:struct_cc_proto", "@com_google_protobuf//:timestamp_cc_proto", "@com_google_protobuf//:wrappers_cc_proto", ], ) cc_test( name = "value_end_to_end_test", srcs = ["value_end_to_end_test.cc"], deps = [ ":runtime_adapter", "//common:value", "//common:value_testing", "//internal:testing", "//parser", "//runtime", "//runtime:activation", "//runtime:runtime_options", "//runtime:standard_runtime_builder_factory", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/time", "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "bind_proto_to_activation", srcs = ["bind_proto_to_activation.cc"], hdrs = ["bind_proto_to_activation.h"], deps = [ ":value", "//common:casting", "//common:value", "//internal:status_macros", "//runtime:activation", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "bind_proto_to_activation_test", srcs = ["bind_proto_to_activation_test.cc"], deps = [ ":bind_proto_to_activation", "//common:casting", "//common:value", "//common:value_testing", "//internal:testing",
"//runtime:activation", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/types:optional", "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", "@com_google_protobuf//:wrappers_cc_proto", ], ) cc_library( name = "value_testing", testonly = True, hdrs = ["value_testing.h"], deps = [ ":value", "//common:value", "//internal:testing", "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", ], ) cc_test( name = "value_testing_test", srcs = ["value_testing_test.cc"], deps = [ ":value", ":value_testing", "//common:value", "//common:value_testing", "//internal:proto_matchers", "//internal:testing", "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", ], ) ================================================ FILE: extensions/protobuf/ast_converters.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_AST_CONVERTERS_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_AST_CONVERTERS_H_ #include #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" #include "absl/base/attributes.h" #include "absl/status/statusor.h" #include "common/ast.h" #include "common/ast_proto.h" namespace cel::extensions { // Creates a runtime AST from a parsed-only protobuf AST.
// May return a non-ok Status if the AST is malformed (e.g. unset required // fields). Deprecated shim: forwards directly to cel::CreateAstFromParsedExpr(expr, source_info). ABSL_DEPRECATED("Use cel::CreateAstFromParsedExpr instead.") inline absl::StatusOr> CreateAstFromParsedExpr( const cel::expr::Expr& expr, const cel::expr::SourceInfo* source_info = nullptr) { return cel::CreateAstFromParsedExpr(expr, source_info); } // Overload taking a cel::expr::ParsedExpr; deprecated shim that forwards directly to cel::CreateAstFromParsedExpr(parsed_expr). ABSL_DEPRECATED("Use cel::CreateAstFromParsedExpr instead.") inline absl::StatusOr> CreateAstFromParsedExpr( const cel::expr::ParsedExpr& parsed_expr) { return cel::CreateAstFromParsedExpr(parsed_expr); } // Creates a runtime AST from a checked protobuf AST. // May return a non-ok Status if the AST is malformed (e.g. unset required // fields). Deprecated shim: forwards directly to cel::CreateAstFromCheckedExpr(checked_expr). ABSL_DEPRECATED("Use cel::CreateAstFromCheckedExpr instead.") inline absl::StatusOr> CreateAstFromCheckedExpr( const cel::expr::CheckedExpr& checked_expr) { return cel::CreateAstFromCheckedExpr(checked_expr); } } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_AST_CONVERTERS_H_ ================================================ FILE: extensions/protobuf/bind_proto_to_activation.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "extensions/protobuf/bind_proto_to_activation.h" #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "common/value.h" #include "internal/status_macros.h" #include "runtime/activation.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel::extensions::protobuf_internal { namespace { using ::google::protobuf::Descriptor; // Decides whether a field gets a binding: always under kBindDefaultValue or for repeated fields; otherwise only when struct_value reports the field as present. Presence-check errors are propagated. absl::StatusOr ShouldBindField( const google::protobuf::FieldDescriptor* field_desc, const StructValue& struct_value, BindProtoUnsetFieldBehavior unset_field_behavior) { if (unset_field_behavior == BindProtoUnsetFieldBehavior::kBindDefaultValue || field_desc->is_repeated()) { return true; } return struct_value.HasFieldByNumber(field_desc->number()); } // Reads the value of field_desc from struct_value. An unset google.protobuf.Any field yields NullValue() instead of the value returned by GetFieldByNumber. absl::StatusOr GetFieldValue( const google::protobuf::FieldDescriptor* field_desc, const StructValue& struct_value, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { // Special case unset any.
if (field_desc->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE && field_desc->message_type()->well_known_type() == Descriptor::WELLKNOWNTYPE_ANY) { CEL_ASSIGN_OR_RETURN(bool present, struct_value.HasFieldByNumber(field_desc->number())); if (!present) { return NullValue(); } } return struct_value.GetFieldByNumber(field_desc->number(), descriptor_pool, message_factory, arena); } } // namespace // Iterates over the fields declared by `descriptor` (descriptor.field(i)) and binds each selected field into `activation` under the field's name. The first error from a presence check or field read is returned. absl::Status BindProtoToActivation( const Descriptor& descriptor, const StructValue& struct_value, BindProtoUnsetFieldBehavior unset_field_behavior, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Activation* absl_nonnull activation) { for (int i = 0; i < descriptor.field_count(); i++) { const google::protobuf::FieldDescriptor* field_desc = descriptor.field(i); CEL_ASSIGN_OR_RETURN( bool should_bind, ShouldBindField(field_desc, struct_value, unset_field_behavior)); if (!should_bind) { continue; } CEL_ASSIGN_OR_RETURN( Value field, GetFieldValue(field_desc, struct_value, descriptor_pool, message_factory, arena)); activation->InsertOrAssignValue(field_desc->name(), std::move(field)); } return absl::OkStatus(); } } // namespace cel::extensions::protobuf_internal ================================================ FILE: extensions/protobuf/bind_proto_to_activation.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_BIND_PROTO_TO_ACTIVATION_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_BIND_PROTO_TO_ACTIVATION_H_ #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "common/casting.h" #include "common/value.h" #include "extensions/protobuf/value.h" #include "internal/status_macros.h" #include "runtime/activation.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel::extensions { // Option for handling unset fields on the context proto. enum class BindProtoUnsetFieldBehavior { // Bind the message-defined default or zero value. kBindDefaultValue, // Skip binding unset fields; no value is bound for the corresponding // variable. kSkip }; namespace protobuf_internal { // Implements binding provided the context message has already // been adapted to a suitable struct value. Each field declared by // `descriptor` is read from `struct_value` and bound into `activation` // under its field name, subject to `unset_field_behavior`. absl::Status BindProtoToActivation( const google::protobuf::Descriptor& descriptor, const StructValue& struct_value, BindProtoUnsetFieldBehavior unset_field_behavior, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Activation* absl_nonnull activation); } // namespace protobuf_internal // Utility method that takes a protobuf Message and interprets it as a // namespace, binding its fields to Activation. This is often referred to as a // context message. // // Field names and values become respective names and values of parameters // bound to the Activation object.
// Example: // Assume we have a protobuf message of type: // message Person { // int32 age = 1; // string name = 2; // } // // The sample code snippet will look as follows: // // Person person; // person.set_name("John Doe"); // person.set_age(42); // // Activation activation; // CEL_RETURN_IF_ERROR(BindProtoToActivation(person, descriptor_pool, // message_factory, arena, // &activation)); // // After this snippet, activation will have two parameters bound: // "name", with string value of "John Doe" // "age", with int value of 42. // // The default behavior for unset fields is to skip them. E.g. if the name field // is not set on the Person message, it will not be bound into the activation. // BindProtoUnsetFieldBehavior::kBindDefaultValue will bind the cc proto api // default for the field (either an explicit default value or a type specific // default). // // For repeated fields, an unset field is bound as an empty list. // // Returns InvalidArgument if `context` converts to a non-struct value (e.g. a // well-known type such as google.protobuf.Int64Value) or has no descriptor. template absl::Status BindProtoToActivation( const T& context, BindProtoUnsetFieldBehavior unset_field_behavior, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Activation* absl_nonnull activation) { static_assert(std::is_base_of_v); // TODO(uncreated-issue/68): for simplicity, just convert the whole message to a // struct value. For performance, may be better to convert members as needed.
CEL_ASSIGN_OR_RETURN( Value parent, ProtoMessageToValue(context, descriptor_pool, message_factory, arena)); if (!InstanceOf(parent)) { return absl::InvalidArgumentError( absl::StrCat("context is a well-known type: ", context.GetTypeName())); } const StructValue& struct_value = Cast(parent); const google::protobuf::Descriptor* descriptor = context.GetDescriptor(); if (descriptor == nullptr) { return absl::InvalidArgumentError( absl::StrCat("context missing descriptor: ", context.GetTypeName())); } return protobuf_internal::BindProtoToActivation( *descriptor, struct_value, unset_field_behavior, descriptor_pool, message_factory, arena, activation); } // Same as above, using BindProtoUnsetFieldBehavior::kSkip so unset // non-repeated fields are not bound. template absl::Status BindProtoToActivation( const T& context, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Activation* absl_nonnull activation) { return BindProtoToActivation(context, BindProtoUnsetFieldBehavior::kSkip, descriptor_pool, message_factory, arena, activation); } } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_BIND_PROTO_TO_ACTIVATION_H_ ================================================ FILE: extensions/protobuf/bind_proto_to_activation_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "extensions/protobuf/bind_proto_to_activation.h"

#include "google/protobuf/wrappers.pb.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/types/optional.h"
#include "common/casting.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "internal/testing.h"
#include "runtime/activation.h"
#include "cel/expr/conformance/proto2/test_all_types.pb.h"
#include "google/protobuf/arena.h"

namespace cel::extensions {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::cel::expr::conformance::proto2::TestAllTypes;
using ::cel::test::IntValueIs;
using ::testing::Eq;
using ::testing::HasSubstr;
using ::testing::Optional;

using BindProtoToActivationTest = common_internal::ValueTest<>;

// Succeeds when the argument is a list value whose Size() is OK and equals
// `size`.
MATCHER_P(IsListValueOfSize, size, "") {
  const Value& candidate = arg;
  if (auto list = As<ListValue>(candidate)) {
    auto actual_size = list->Size();
    return actual_size.ok() && *actual_size == size;
  }
  return false;
}

// Succeeds when the argument is a map value whose Size() is OK and equals
// `size`.
MATCHER_P(IsMapValueOfSize, size, "") {
  const Value& candidate = arg;
  if (auto map = As<MapValue>(candidate)) {
    auto actual_size = map->Size();
    return actual_size.ok() && *actual_size == size;
  }
  return false;
}

TEST_F(BindProtoToActivationTest, BindProtoToActivation) {
  TestAllTypes proto;
  proto.set_single_int64(123);
  Activation activation;

  ASSERT_THAT(BindProtoToActivation(proto, descriptor_pool(),
                                    message_factory(), arena(), &activation),
              IsOk());
  EXPECT_THAT(activation.FindVariable("single_int64", descriptor_pool(),
                                      message_factory(), arena()),
              IsOkAndHolds(Optional(IntValueIs(123))));
}

TEST_F(BindProtoToActivationTest, BindProtoToActivationWktUnsupported) {
  google::protobuf::Int64Value wrapper;
  wrapper.set_value(123);
  Activation activation;

  EXPECT_THAT(BindProtoToActivation(wrapper, descriptor_pool(),
                                    message_factory(), arena(), &activation),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("google.protobuf.Int64Value")));
}

TEST_F(BindProtoToActivationTest, BindProtoToActivationSkip) {
  TestAllTypes proto;
  proto.set_single_int64(123);
  Activation activation;

  ASSERT_THAT(BindProtoToActivation(proto, descriptor_pool(),
                                    message_factory(), arena(), &activation),
              IsOk());
  // Unset singular fields are skipped entirely under the default behavior.
  for (const char* unset_field : {"single_int32", "single_sint32"}) {
    EXPECT_THAT(activation.FindVariable(unset_field, descriptor_pool(),
                                        message_factory(), arena()),
                IsOkAndHolds(Eq(absl::nullopt)));
  }
}

TEST_F(BindProtoToActivationTest, BindProtoToActivationDefault) {
  TestAllTypes proto;
  proto.set_single_int64(123);
  Activation activation;

  ASSERT_THAT(
      BindProtoToActivation(proto,
                            BindProtoUnsetFieldBehavior::kBindDefaultValue,
                            descriptor_pool(), message_factory(), arena(),
                            &activation),
      IsOk());

  // from test_all_types.proto
  // optional int32 single_int32 = 1 [default = -32];
  EXPECT_THAT(activation.FindVariable("single_int32", descriptor_pool(),
                                      message_factory(), arena()),
              IsOkAndHolds(Optional(IntValueIs(-32))));
  EXPECT_THAT(activation.FindVariable("single_sint32", descriptor_pool(),
                                      message_factory(), arena()),
              IsOkAndHolds(Optional(IntValueIs(0))));
}

// Special case any fields. Mirrors go evaluator behavior.
TEST_F(BindProtoToActivationTest, BindProtoToActivationDefaultAny) {
  TestAllTypes proto;
  proto.set_single_int64(123);
  Activation activation;

  ASSERT_THAT(
      BindProtoToActivation(proto,
                            BindProtoUnsetFieldBehavior::kBindDefaultValue,
                            descriptor_pool(), message_factory(), arena(),
                            &activation),
      IsOk());

  EXPECT_THAT(activation.FindVariable("single_any", descriptor_pool(),
                                      message_factory(), arena()),
              IsOkAndHolds(Optional(test::IsNullValue())));
}

TEST_F(BindProtoToActivationTest, BindProtoToActivationRepeated) {
  TestAllTypes proto;
  for (int element : {123, 456, 789}) {
    proto.add_repeated_int64(element);
  }
  Activation activation;

  ASSERT_THAT(BindProtoToActivation(proto, descriptor_pool(),
                                    message_factory(), arena(), &activation),
              IsOk());
  EXPECT_THAT(activation.FindVariable("repeated_int64", descriptor_pool(),
                                      message_factory(), arena()),
              IsOkAndHolds(Optional(IsListValueOfSize(3))));
}

TEST_F(BindProtoToActivationTest, BindProtoToActivationRepeatedEmpty) {
  TestAllTypes proto;
  proto.set_single_int64(123);
  Activation activation;

  ASSERT_THAT(BindProtoToActivation(proto, descriptor_pool(),
                                    message_factory(), arena(), &activation),
              IsOk());
  EXPECT_THAT(activation.FindVariable("repeated_int32", descriptor_pool(),
                                      message_factory(), arena()),
              IsOkAndHolds(Optional(IsListValueOfSize(0))));
}

TEST_F(BindProtoToActivationTest, BindProtoToActivationRepeatedComplex) {
  TestAllTypes proto;
  for (int bb : {123, 456, 789}) {
    proto.add_repeated_nested_message()->set_bb(bb);
  }
  Activation activation;

  ASSERT_THAT(BindProtoToActivation(proto, descriptor_pool(),
                                    message_factory(), arena(), &activation),
              IsOk());
  EXPECT_THAT(activation.FindVariable("repeated_nested_message",
                                      descriptor_pool(), message_factory(),
                                      arena()),
              IsOkAndHolds(Optional(IsListValueOfSize(3))));
}

TEST_F(BindProtoToActivationTest, BindProtoToActivationMap) {
  TestAllTypes proto;
  auto& entries = *proto.mutable_map_int64_int64();
  entries[1] = 2;
  entries[2] = 4;
  Activation activation;

  ASSERT_THAT(BindProtoToActivation(proto, descriptor_pool(),
                                    message_factory(), arena(), &activation),
              IsOk());
  EXPECT_THAT(activation.FindVariable("map_int64_int64", descriptor_pool(),
                                      message_factory(), arena()),
              IsOkAndHolds(Optional(IsMapValueOfSize(2))));
}

TEST_F(BindProtoToActivationTest, BindProtoToActivationMapEmpty) {
  TestAllTypes proto;
  proto.set_single_int64(123);
  Activation activation;

  ASSERT_THAT(BindProtoToActivation(proto, descriptor_pool(),
                                    message_factory(), arena(), &activation),
              IsOk());
  EXPECT_THAT(activation.FindVariable("map_int32_int32", descriptor_pool(),
                                      message_factory(), arena()),
              IsOkAndHolds(Optional(IsMapValueOfSize(0))));
}

TEST_F(BindProtoToActivationTest, BindProtoToActivationMapComplex) {
  TestAllTypes proto;
  TestAllTypes::NestedMessage nested;
  nested.set_bb(42);
  auto& entries = *proto.mutable_map_int64_message();
  entries[1] = nested;
  entries[2] = nested;
  Activation activation;

  ASSERT_THAT(BindProtoToActivation(proto, descriptor_pool(),
                                    message_factory(), arena(), &activation),
              IsOk());
  EXPECT_THAT(activation.FindVariable("map_int64_message", descriptor_pool(),
                                      message_factory(), arena()),
              IsOkAndHolds(Optional(IsMapValueOfSize(2))));
}

}  // namespace
}  // namespace cel::extensions
================================================
FILE: extensions/protobuf/enum_adapter.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/protobuf/enum_adapter.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "runtime/type_registry.h"
#include "google/protobuf/descriptor.h"

namespace cel::extensions {

// Registers every value of `enum_descriptor` (name -> number) with `registry`
// under the enum's fully qualified name.
//
// Returns InvalidArgumentError for a null descriptor (e.g. a failed
// DescriptorPool lookup passed straight through), AlreadyExistsError if an
// enum with the same full name was registered before, and OK otherwise.
absl::Status RegisterProtobufEnum(
    TypeRegistry& registry,
    const google::protobuf::EnumDescriptor* enum_descriptor) {
  if (enum_descriptor == nullptr) {
    return absl::InvalidArgumentError(
        "RegisterProtobufEnum: enum_descriptor must not be null");
  }

  if (registry.resolveable_enums().contains(enum_descriptor->full_name())) {
    return absl::AlreadyExistsError(
        absl::StrCat(enum_descriptor->full_name(), " already registered."));
  }

  // TODO(uncreated-issue/42): the registry enum implementation runs linear lookups for
  // constants since this isn't expected to happen at runtime. Consider updating
  // if / when strong enum typing is implemented.
  const int value_count = enum_descriptor->value_count();
  std::vector<TypeRegistry::Enumerator> enumerators;
  enumerators.reserve(value_count);
  for (int i = 0; i < value_count; ++i) {
    const google::protobuf::EnumValueDescriptor* value =
        enum_descriptor->value(i);
    enumerators.push_back({std::string(value->name()), value->number()});
  }

  registry.RegisterEnum(enum_descriptor->full_name(), std::move(enumerators));
  return absl::OkStatus();
}

}  // namespace cel::extensions
================================================
FILE: extensions/protobuf/enum_adapter.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_ENUM_ADAPTER_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_ENUM_ADAPTER_H_ #include "absl/status/status.h" #include "runtime/type_registry.h" #include "google/protobuf/descriptor.h" namespace cel::extensions { // Register a resolveable enum for the given runtime builder. absl::Status RegisterProtobufEnum( TypeRegistry& registry, const google::protobuf::EnumDescriptor* enum_descriptor); } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_ENUM_ADAPTER_H_ ================================================ FILE: extensions/protobuf/internal/BUILD ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. load("@rules_cc//cc:cc_library.bzl", "cc_library") package( # Under active development, not yet being released. 
default_visibility = ["//visibility:public"], ) licenses(["notice"]) cc_library( name = "map_reflection", srcs = ["map_reflection.cc"], hdrs = ["map_reflection.h"], deps = [ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "qualify", srcs = ["qualify.cc"], hdrs = ["qualify.h"], deps = [ ":map_reflection", "//base:attributes", "//base:builtins", "//common:kind", "//common:memory", "//internal:status_macros", "//runtime:runtime_options", "//runtime/internal:errors", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:overload", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", "@com_google_protobuf//:protobuf", ], ) ================================================ FILE: extensions/protobuf/internal/map_reflection.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "extensions/protobuf/internal/map_reflection.h"

#include "absl/base/nullability.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/map_field.h"
#include "google/protobuf/message.h"

namespace google::protobuf::expr {

// Thin forwarders onto google::protobuf::Reflection's map accessors. The class
// is presumably the one protobuf's Reflection befriends (map_reflection.h
// requires GOOGLE_PROTOBUF_HAS_CEL_MAP_REFLECTION_FRIEND), which is why these
// calls are made from inside google::protobuf::expr rather than from cel::.
class CelMapReflectionFriend final {
 public:
  static bool LookupMapValue(const Reflection& reflection,
                             const Message& message,
                             const FieldDescriptor& field, const MapKey& key,
                             MapValueConstRef* value) {
    return reflection.LookupMapValue(message, &field, key, value);
  }

  static bool ContainsMapKey(const Reflection& reflection,
                             const Message& message,
                             const FieldDescriptor& field, const MapKey& key) {
    return reflection.ContainsMapKey(message, &field, key);
  }

  static int MapSize(const Reflection& reflection, const Message& message,
                     const FieldDescriptor& field) {
    return reflection.MapSize(message, &field);
  }

  static ConstMapIterator ConstMapBegin(const Reflection& reflection,
                                        const Message& message,
                                        const FieldDescriptor& field) {
    return reflection.ConstMapBegin(&message, &field);
  }

  static ConstMapIterator ConstMapEnd(const Reflection& reflection,
                                      const Message& message,
                                      const FieldDescriptor& field) {
    return reflection.ConstMapEnd(&message, &field);
  }

  static bool InsertOrLookupMapValue(const Reflection& reflection,
                                     Message* message,
                                     const FieldDescriptor& field,
                                     const MapKey& key, MapValueRef* value) {
    return reflection.InsertOrLookupMapValue(message, &field, key, value);
  }

  static bool DeleteMapValue(const Reflection* absl_nonnull reflection,
                             Message* absl_nonnull message,
                             const FieldDescriptor* absl_nonnull field,
                             const MapKey& key) {
    return reflection->DeleteMapValue(message, field, key);
  }
};

}  // namespace google::protobuf::expr

namespace cel::extensions::protobuf_internal {

namespace {

// Local shorthand for the forwarding class defined above.
using ReflectionFriend = ::google::protobuf::expr::CelMapReflectionFriend;

}  // namespace

bool LookupMapValue(const google::protobuf::Reflection& reflection,
                    const google::protobuf::Message& message,
                    const google::protobuf::FieldDescriptor& field,
                    const google::protobuf::MapKey& key,
                    google::protobuf::MapValueConstRef* value) {
  return ReflectionFriend::LookupMapValue(reflection, message, field, key,
                                          value);
}

bool ContainsMapKey(const google::protobuf::Reflection& reflection,
                    const google::protobuf::Message& message,
                    const google::protobuf::FieldDescriptor& field,
                    const google::protobuf::MapKey& key) {
  return ReflectionFriend::ContainsMapKey(reflection, message, field, key);
}

int MapSize(const google::protobuf::Reflection& reflection,
            const google::protobuf::Message& message,
            const google::protobuf::FieldDescriptor& field) {
  return ReflectionFriend::MapSize(reflection, message, field);
}

google::protobuf::ConstMapIterator ConstMapBegin(
    const google::protobuf::Reflection& reflection,
    const google::protobuf::Message& message,
    const google::protobuf::FieldDescriptor& field) {
  return ReflectionFriend::ConstMapBegin(reflection, message, field);
}

google::protobuf::ConstMapIterator ConstMapEnd(
    const google::protobuf::Reflection& reflection,
    const google::protobuf::Message& message,
    const google::protobuf::FieldDescriptor& field) {
  return ReflectionFriend::ConstMapEnd(reflection, message, field);
}

bool InsertOrLookupMapValue(const google::protobuf::Reflection& reflection,
                            google::protobuf::Message* message,
                            const google::protobuf::FieldDescriptor& field,
                            const google::protobuf::MapKey& key,
                            google::protobuf::MapValueRef* value) {
  return ReflectionFriend::InsertOrLookupMapValue(reflection, message, field,
                                                  key, value);
}

bool DeleteMapValue(
    const google::protobuf::Reflection* absl_nonnull reflection,
    google::protobuf::Message* absl_nonnull message,
    const google::protobuf::FieldDescriptor* absl_nonnull field,
    const google::protobuf::MapKey& key) {
  return ReflectionFriend::DeleteMapValue(reflection, message, field, key);
}

}  // namespace cel::extensions::protobuf_internal
================================================
FILE: extensions/protobuf/internal/map_reflection.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_MAP_REFLECTION_H_
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_MAP_REFLECTION_H_

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/map_field.h"
#include "google/protobuf/message.h"

// The implementations forward through a helper class in google::protobuf::expr.
// This macro presumably signals that the protobuf release in use grants that
// helper access to Reflection's map API; without it the build fails here.
#ifndef GOOGLE_PROTOBUF_HAS_CEL_MAP_REFLECTION_FRIEND
#error "protobuf library is too old, please update to version 3.15.0 or newer"
#endif

// Free-function wrappers around the map accessors of
// google::protobuf::Reflection. Each one forwards to the Reflection method of
// the same name for the map field `field` of `message`; `reflection` is
// presumably expected to be `message`'s own reflection (not checked here).
namespace cel::extensions::protobuf_internal {

// Looks up `key` in the map field. Returns true and fills `*value` when the
// key is present, false otherwise. All pointer arguments must be non-null.
bool LookupMapValue(const google::protobuf::Reflection& reflection,
                    const google::protobuf::Message& message,
                    const google::protobuf::FieldDescriptor& field,
                    const google::protobuf::MapKey& key,
                    google::protobuf::MapValueConstRef* value)
    ABSL_ATTRIBUTE_NONNULL();

// Forwards to Reflection::ContainsMapKey.
bool ContainsMapKey(const google::protobuf::Reflection& reflection,
                    const google::protobuf::Message& message,
                    const google::protobuf::FieldDescriptor& field,
                    const google::protobuf::MapKey& key);

// Forwards to Reflection::MapSize.
int MapSize(const google::protobuf::Reflection& reflection,
            const google::protobuf::Message& message,
            const google::protobuf::FieldDescriptor& field);

// Forwards to Reflection::ConstMapBegin / ConstMapEnd; together they give a
// const iteration range over the map field's entries.
google::protobuf::ConstMapIterator ConstMapBegin(const google::protobuf::Reflection& reflection,
                                       const google::protobuf::Message& message,
                                       const google::protobuf::FieldDescriptor& field);
google::protobuf::ConstMapIterator ConstMapEnd(const google::protobuf::Reflection& reflection,
                                     const google::protobuf::Message& message,
                                     const google::protobuf::FieldDescriptor& field);

// Forwards to Reflection::InsertOrLookupMapValue on a mutable `message`.
// All pointer arguments must be non-null.
bool InsertOrLookupMapValue(const google::protobuf::Reflection& reflection,
                            google::protobuf::Message* message,
                            const google::protobuf::FieldDescriptor& field,
                            const google::protobuf::MapKey& key,
                            google::protobuf::MapValueRef* value)
    ABSL_ATTRIBUTE_NONNULL();

// Forwards to Reflection::DeleteMapValue. Unlike the functions above, this one
// takes the reflection and field descriptor by pointer.
bool DeleteMapValue(const google::protobuf::Reflection* absl_nonnull reflection,
                    google::protobuf::Message* absl_nonnull message,
                    const google::protobuf::FieldDescriptor* absl_nonnull field,
                    const google::protobuf::MapKey& key);

}
// namespace cel::extensions::protobuf_internal #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_MAP_REFLECTION_H_ ================================================ FILE: extensions/protobuf/internal/qualify.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "extensions/protobuf/internal/qualify.h" #include #include #include "absl/functional/overload.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "absl/types/variant.h" #include "base/attribute.h" #include "base/builtins.h" #include "common/kind.h" #include "common/memory.h" #include "extensions/protobuf/internal/map_reflection.h" #include "internal/status_macros.h" #include "runtime/internal/errors.h" #include "runtime/runtime_options.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/map_field.h" #include "google/protobuf/message.h" #include "google/protobuf/reflection.h" #undef GetMessage namespace cel::extensions::protobuf_internal { namespace { const google::protobuf::FieldDescriptor* GetNormalizedFieldByNumber( const google::protobuf::Descriptor* descriptor, const google::protobuf::Reflection* reflection, int field_number) { const google::protobuf::FieldDescriptor* field_desc = descriptor->FindFieldByNumber(field_number); if (field_desc == nullptr && reflection != nullptr) 
{ field_desc = reflection->FindKnownExtensionByNumber(field_number); } return field_desc; } // JSON container types and Any have special unpacking rules. // // Not considered for qualify traversal for simplicity, but // could be supported in a follow-up if needed. bool IsUnsupportedQualifyType(const google::protobuf::Descriptor& desc) { switch (desc.well_known_type()) { case google::protobuf::Descriptor::WELLKNOWNTYPE_ANY: case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: return true; default: return false; } } constexpr int kKeyTag = 1; constexpr int kValueTag = 2; bool MatchesMapKeyType(const google::protobuf::FieldDescriptor* key_desc, const cel::AttributeQualifier& key) { switch (key_desc->cpp_type()) { case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: return key.kind() == cel::Kind::kBool; case google::protobuf::FieldDescriptor::CPPTYPE_INT32: // fall through case google::protobuf::FieldDescriptor::CPPTYPE_INT64: return key.kind() == cel::Kind::kInt64; case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: // fall through case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: return key.kind() == cel::Kind::kUint64; case google::protobuf::FieldDescriptor::CPPTYPE_STRING: return key.kind() == cel::Kind::kString; default: return false; } } absl::StatusOr> LookupMapValue( const google::protobuf::Message* message, const google::protobuf::Reflection* reflection, const google::protobuf::FieldDescriptor* field_desc, const google::protobuf::FieldDescriptor* key_desc, const cel::AttributeQualifier& key) { if (!MatchesMapKeyType(key_desc, key)) { return runtime_internal::CreateInvalidMapKeyTypeError( key_desc->cpp_type_name()); } std::string proto_key_string; google::protobuf::MapKey proto_key; switch (key_desc->cpp_type()) { case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: proto_key.SetBoolValue(*key.GetBoolKey()); break; case 
google::protobuf::FieldDescriptor::CPPTYPE_INT32: { int64_t key_value = *key.GetInt64Key(); if (key_value > std::numeric_limits::max() || key_value < std::numeric_limits::lowest()) { return absl::OutOfRangeError("integer overflow"); } proto_key.SetInt32Value(key_value); } break; case google::protobuf::FieldDescriptor::CPPTYPE_INT64: proto_key.SetInt64Value(*key.GetInt64Key()); break; case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { proto_key_string = std::string(*key.GetStringKey()); proto_key.SetStringValue(proto_key_string); } break; case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { uint64_t key_value = *key.GetUint64Key(); if (key_value > std::numeric_limits::max()) { return absl::OutOfRangeError("unsigned integer overflow"); } proto_key.SetUInt32Value(key_value); } break; case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { proto_key.SetUInt64Value(*key.GetUint64Key()); } break; default: return runtime_internal::CreateInvalidMapKeyTypeError( key_desc->cpp_type_name()); } // Look the value up google::protobuf::MapValueConstRef value_ref; bool found = cel::extensions::protobuf_internal::LookupMapValue( *reflection, *message, *field_desc, proto_key, &value_ref); if (!found) { return absl::nullopt; } return value_ref; } bool FieldIsPresent(const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field_desc, const google::protobuf::Reflection* reflection) { if (field_desc->is_map()) { // When the map field appears in a has(msg.map_field) expression, the map // is considered 'present' when it is non-empty. Since maps are repeated // fields they don't participate with standard proto presence testing // since the repeated field is always at least empty. return reflection->FieldSize(*message, field_desc) != 0; } if (field_desc->is_repeated()) { // When the list field appears in a has(msg.list_field) expression, the // list is considered 'present' when it is non-empty. 
return reflection->FieldSize(*message, field_desc) != 0; } // Standard proto presence test for non-repeated fields. return reflection->HasField(*message, field_desc); } } // namespace absl::Status ProtoQualifyState::ApplySelectQualifier( const cel::SelectQualifier& qualifier, MemoryManagerRef memory_manager) { return absl::visit( absl::Overload( [&](const cel::AttributeQualifier& qualifier) -> absl::Status { if (repeated_field_desc_ == nullptr) { return absl::UnimplementedError( "dynamic field access on message not supported"); } return ApplyAttributeQualifer(qualifier, memory_manager); }, [&](const cel::FieldSpecifier& field_specifier) -> absl::Status { if (repeated_field_desc_ != nullptr) { return absl::UnimplementedError( "strong field access on container not supported"); } return ApplyFieldSpecifier(field_specifier, memory_manager); }), qualifier); } absl::Status ProtoQualifyState::ApplyLastQualifierHas( const cel::SelectQualifier& qualifier, MemoryManagerRef memory_manager) { const cel::FieldSpecifier* specifier = absl::get_if(&qualifier); return absl::visit( absl::Overload( [&](const cel::AttributeQualifier& qualifier) mutable -> absl::Status { if (qualifier.kind() != cel::Kind::kString || repeated_field_desc_ == nullptr || !repeated_field_desc_->is_map()) { SetResultFromError( runtime_internal::CreateNoMatchingOverloadError("has"), memory_manager); return absl::OkStatus(); } return MapHas(qualifier, memory_manager); }, [&](const cel::FieldSpecifier& field_specifier) mutable -> absl::Status { const auto* field_desc = GetNormalizedFieldByNumber( descriptor_, reflection_, specifier->number); if (field_desc == nullptr) { SetResultFromError( runtime_internal::CreateNoSuchFieldError(specifier->name), memory_manager); return absl::OkStatus(); } SetResultFromBool( FieldIsPresent(message_, field_desc, reflection_)); return absl::OkStatus(); }), qualifier); } absl::Status ProtoQualifyState::ApplyLastQualifierGet( const cel::SelectQualifier& qualifier, MemoryManagerRef 
memory_manager) { return absl::visit( absl::Overload( [&](const cel::AttributeQualifier& attr_qualifier) mutable -> absl::Status { if (repeated_field_desc_ == nullptr) { return absl::UnimplementedError( "dynamic field access on message not supported"); } if (repeated_field_desc_->is_map()) { return ApplyLastQualifierGetMap(attr_qualifier, memory_manager); } return ApplyLastQualifierGetList(attr_qualifier, memory_manager); }, [&](const cel::FieldSpecifier& specifier) mutable -> absl::Status { if (repeated_field_desc_ != nullptr) { return absl::UnimplementedError( "strong field access on container not supported"); } return ApplyLastQualifierMessageGet(specifier, memory_manager); }), qualifier); } absl::Status ProtoQualifyState::ApplyFieldSpecifier( const cel::FieldSpecifier& field_specifier, MemoryManagerRef memory_manager) { const google::protobuf::FieldDescriptor* field_desc = GetNormalizedFieldByNumber( descriptor_, reflection_, field_specifier.number); if (field_desc == nullptr) { SetResultFromError( runtime_internal::CreateNoSuchFieldError(field_specifier.name), memory_manager); return absl::OkStatus(); } if (field_desc->is_repeated()) { repeated_field_desc_ = field_desc; return absl::OkStatus(); } if (field_desc->cpp_type() != google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE || IsUnsupportedQualifyType(*field_desc->message_type())) { CEL_RETURN_IF_ERROR(SetResultFromField(message_, field_desc, ProtoWrapperTypeOptions::kUnsetNull, memory_manager)); return absl::OkStatus(); } message_ = &reflection_->GetMessage(*message_, field_desc); descriptor_ = message_->GetDescriptor(); reflection_ = message_->GetReflection(); return absl::OkStatus(); } absl::StatusOr ProtoQualifyState::CheckListIndex( const cel::AttributeQualifier& qualifier) const { if (qualifier.kind() != cel::Kind::kInt64) { return runtime_internal::CreateNoMatchingOverloadError( cel::builtin::kIndex); } int index = *qualifier.GetInt64Key(); int size = reflection_->FieldSize(*message_, 
repeated_field_desc_); if (index < 0 || index >= size) { return absl::InvalidArgumentError( absl::StrCat("index out of bounds: index=", index, " size=", size)); } return index; } absl::Status ProtoQualifyState::ApplyAttributeQualifierList( const cel::AttributeQualifier& qualifier, MemoryManagerRef memory_manager) { ABSL_DCHECK_NE(repeated_field_desc_, nullptr); ABSL_DCHECK(!repeated_field_desc_->is_map()); ABSL_DCHECK_EQ(repeated_field_desc_->cpp_type(), google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE); auto index_or = CheckListIndex(qualifier); if (!index_or.ok()) { SetResultFromError(std::move(index_or).status(), memory_manager); return absl::OkStatus(); } if (IsUnsupportedQualifyType(*repeated_field_desc_->message_type())) { CEL_RETURN_IF_ERROR(SetResultFromRepeatedField( message_, repeated_field_desc_, *index_or, memory_manager)); return absl::OkStatus(); } message_ = &reflection_->GetRepeatedMessage(*message_, repeated_field_desc_, *index_or); descriptor_ = message_->GetDescriptor(); reflection_ = message_->GetReflection(); repeated_field_desc_ = nullptr; return absl::OkStatus(); } absl::StatusOr ProtoQualifyState::CheckMapIndex( const cel::AttributeQualifier& qualifier) const { const auto* key_desc = repeated_field_desc_->message_type()->FindFieldByNumber(kKeyTag); CEL_ASSIGN_OR_RETURN( absl::optional value_ref, LookupMapValue(message_, reflection_, repeated_field_desc_, key_desc, qualifier)); if (!value_ref.has_value()) { std::string key_string; absl::StatusOr key_string_or = qualifier.AsString(); if (key_string_or.ok()) { key_string = *key_string_or; } return runtime_internal::CreateNoSuchKeyError(key_string); } return std::move(value_ref).value(); } absl::Status ProtoQualifyState::ApplyAttributeQualifierMap( const cel::AttributeQualifier& qualifier, MemoryManagerRef memory_manager) { ABSL_DCHECK_NE(repeated_field_desc_, nullptr); ABSL_DCHECK(repeated_field_desc_->is_map()); ABSL_DCHECK_EQ(repeated_field_desc_->cpp_type(), 
google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE); absl::StatusOr value_ref = CheckMapIndex(qualifier); if (!value_ref.ok()) { SetResultFromError(std::move(value_ref).status(), memory_manager); return absl::OkStatus(); } const auto* value_desc = repeated_field_desc_->message_type()->FindFieldByNumber(kValueTag); if (value_desc->cpp_type() != google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE || IsUnsupportedQualifyType(*value_desc->message_type())) { CEL_RETURN_IF_ERROR(SetResultFromMapField(message_, value_desc, *value_ref, memory_manager)); return absl::OkStatus(); } message_ = &(value_ref->GetMessageValue()); descriptor_ = message_->GetDescriptor(); reflection_ = message_->GetReflection(); repeated_field_desc_ = nullptr; return absl::OkStatus(); } absl::Status ProtoQualifyState::ApplyAttributeQualifer( const cel::AttributeQualifier& qualifier, MemoryManagerRef memory_manager) { ABSL_DCHECK_NE(repeated_field_desc_, nullptr); if (repeated_field_desc_->cpp_type() != google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { return absl::InternalError("Unexpected qualify intermediate type"); } if (repeated_field_desc_->is_map()) { return ApplyAttributeQualifierMap(qualifier, memory_manager); } // else simple repeated return ApplyAttributeQualifierList(qualifier, memory_manager); } absl::Status ProtoQualifyState::MapHas(const cel::AttributeQualifier& key, MemoryManagerRef memory_manager) { const auto* key_desc = repeated_field_desc_->message_type()->FindFieldByNumber(kKeyTag); absl::StatusOr> value_ref = LookupMapValue(message_, reflection_, repeated_field_desc_, key_desc, key); if (!value_ref.ok()) { SetResultFromError(std::move(value_ref).status(), memory_manager); return absl::OkStatus(); } SetResultFromBool(value_ref->has_value()); return absl::OkStatus(); } absl::Status ProtoQualifyState::ApplyLastQualifierMessageGet( const cel::FieldSpecifier& specifier, MemoryManagerRef memory_manager) { const auto* field_desc = GetNormalizedFieldByNumber(descriptor_, 
reflection_, specifier.number); if (field_desc == nullptr) { SetResultFromError(runtime_internal::CreateNoSuchFieldError(specifier.name), memory_manager); return absl::OkStatus(); } return SetResultFromField(message_, field_desc, ProtoWrapperTypeOptions::kUnsetNull, memory_manager); } absl::Status ProtoQualifyState::ApplyLastQualifierGetList( const cel::AttributeQualifier& qualifier, MemoryManagerRef memory_manager) { ABSL_DCHECK(!repeated_field_desc_->is_map()); absl::StatusOr index = CheckListIndex(qualifier); if (!index.ok()) { SetResultFromError(std::move(index).status(), memory_manager); return absl::OkStatus(); } return SetResultFromRepeatedField(message_, repeated_field_desc_, *index, memory_manager); } absl::Status ProtoQualifyState::ApplyLastQualifierGetMap( const cel::AttributeQualifier& qualifier, MemoryManagerRef memory_manager) { ABSL_DCHECK(repeated_field_desc_->is_map()); absl::StatusOr value_ref = CheckMapIndex(qualifier); if (!value_ref.ok()) { SetResultFromError(std::move(value_ref).status(), memory_manager); return absl::OkStatus(); } const auto* value_desc = repeated_field_desc_->message_type()->FindFieldByNumber(kValueTag); return SetResultFromMapField(message_, value_desc, *value_ref, memory_manager); } } // namespace cel::extensions::protobuf_internal ================================================ FILE: extensions/protobuf/internal/qualify.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_QUALIFY_H_
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_QUALIFY_H_

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "base/attribute.h"
#include "common/memory.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/map_field.h"
#include "google/protobuf/message.h"
#include "google/protobuf/reflection.h"

namespace cel::extensions::protobuf_internal {

// Walks a protobuf message with reflection, applying a chain of CEL attribute
// qualifiers (field selections, list indices, map keys).
//
// Walk state: the current message (message_, descriptor_, reflection_). When
// the previous qualifier selected a repeated or map field, that field is held
// in repeated_field_desc_. Subclasses decide how the outcome is represented
// by implementing the pure-virtual SetResultFrom* hooks.
//
// Per-qualifier problems (bad index, missing key, unknown field) are recorded
// through SetResultFromError while the method returns OkStatus. Non-OK
// absl::Status returns carry internal failures (e.g. "Unexpected qualify
// intermediate type") and statuses propagated from the hooks.
//
// Not copyable. The pointers are non-owning, so the qualified message
// presumably must outlive this object.
class ProtoQualifyState {
 public:
  ProtoQualifyState(const google::protobuf::Message* absl_nonnull message,
                    const google::protobuf::Descriptor* absl_nonnull descriptor,
                    const google::protobuf::Reflection* absl_nonnull reflection)
      : message_(message),
        descriptor_(descriptor),
        reflection_(reflection),
        repeated_field_desc_(nullptr) {}

  virtual ~ProtoQualifyState() = default;

  ProtoQualifyState(const ProtoQualifyState&) = delete;
  ProtoQualifyState& operator=(const ProtoQualifyState&) = delete;

  // Entry points for a qualifier chain. ApplySelectQualifier presumably
  // handles each non-final qualifier. ApplyLastQualifierHas /
  // ApplyLastQualifierGet handle the final one, as a presence test or a value
  // fetch respectively. Confirm against qualify.cc.
  absl::Status ApplySelectQualifier(const cel::SelectQualifier& qualifier,
                                    MemoryManagerRef memory_manager);

  absl::Status ApplyLastQualifierHas(const cel::SelectQualifier& qualifier,
                                     MemoryManagerRef memory_manager);

  absl::Status ApplyLastQualifierGet(const cel::SelectQualifier& qualifier,
                                     MemoryManagerRef memory_manager);

 private:
  // Records `status` as the qualification result (used for per-qualifier
  // failures such as invalid indices, missing keys, or unknown fields).
  virtual void SetResultFromError(absl::Status status,
                                  MemoryManagerRef memory_manager) = 0;

  // Records a boolean result (e.g. map key presence from MapHas).
  virtual void SetResultFromBool(bool value) = 0;

  // Records the value of `field` on `message`; `unboxing_option` selects how
  // wrapper types are surfaced.
  virtual absl::Status SetResultFromField(
      const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field,
      ProtoWrapperTypeOptions unboxing_option,
      MemoryManagerRef memory_manager) = 0;

  // Records element `index` of the repeated `field` on `message`.
  virtual absl::Status SetResultFromRepeatedField(
      const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field,
      int index, MemoryManagerRef memory_manager) = 0;

  // Records the map value `value`. `field` is the map entry's value field
  // descriptor.
  virtual absl::Status SetResultFromMapField(
      const google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field,
      const google::protobuf::MapValueConstRef& value,
      MemoryManagerRef memory_manager) = 0;

  absl::Status ApplyFieldSpecifier(const cel::FieldSpecifier& field_specifier,
                                   MemoryManagerRef memory_manager);

  // Resolve `qualifier` as an index into / key of repeated_field_desc_.
  absl::StatusOr CheckListIndex(
      const cel::AttributeQualifier& qualifier) const;

  absl::Status ApplyAttributeQualifierList(
      const cel::AttributeQualifier& qualifier,
      MemoryManagerRef memory_manager);

  absl::StatusOr CheckMapIndex(
      const cel::AttributeQualifier& qualifier) const;

  // Intermediate map step: records non-qualifiable values as the result,
  // otherwise descends into the message-typed map value.
  absl::Status ApplyAttributeQualifierMap(
      const cel::AttributeQualifier& qualifier,
      MemoryManagerRef memory_manager);

  // Dispatches an intermediate index/key qualifier to the list or map variant.
  // (Spelling "Qualifer" is part of the existing name.)
  absl::Status ApplyAttributeQualifer(const cel::AttributeQualifier& qualifier,
                                      MemoryManagerRef memory_manager);

  // Records whether `key` is present in the current map field.
  absl::Status MapHas(const cel::AttributeQualifier& key,
                      MemoryManagerRef memory_manager);

  // Final-qualifier fetches on a message field, a list element, or a map value.
  absl::Status ApplyLastQualifierMessageGet(
      const cel::FieldSpecifier& specifier, MemoryManagerRef memory_manager);

  absl::Status ApplyLastQualifierGetList(
      const cel::AttributeQualifier& qualifier,
      MemoryManagerRef memory_manager);

  absl::Status ApplyLastQualifierGetMap(
      const cel::AttributeQualifier& qualifier,
      MemoryManagerRef memory_manager);

  // Current message being qualified (non-owning).
  const google::protobuf::Message* absl_nonnull message_;
  const google::protobuf::Descriptor* absl_nonnull descriptor_;
  const google::protobuf::Reflection* absl_nonnull reflection_;
  // Set while positioned on a repeated/map field awaiting an index or key.
  // Null at construction, and reset to null after descending into a map value.
  const google::protobuf::FieldDescriptor* absl_nullable repeated_field_desc_;
};

}  // namespace cel::extensions::protobuf_internal

#endif  // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_QUALIFY_H_
================================================
FILE: extensions/protobuf/memory_manager.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "extensions/protobuf/memory_manager.h" #include "absl/base/nullability.h" #include "common/memory.h" #include "google/protobuf/arena.h" namespace cel { namespace extensions { MemoryManagerRef ProtoMemoryManager(google::protobuf::Arena* arena) { return arena != nullptr ? MemoryManagerRef::Pooling(arena) : MemoryManagerRef::ReferenceCounting(); } google::protobuf::Arena* absl_nullable ProtoMemoryManagerArena( MemoryManager memory_manager) { return memory_manager.arena(); } } // namespace extensions } // namespace cel ================================================ FILE: extensions/protobuf/memory_manager.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_MEMORY_MANAGER_H_
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_MEMORY_MANAGER_H_

#include
#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "common/memory.h"
#include "google/protobuf/arena.h"

namespace cel::extensions {

// Returns an appropriate `MemoryManagerRef` wrapping `google::protobuf::Arena`. The
// lifetime of objects created using the resulting `MemoryManagerRef` is tied
// to that of `google::protobuf::Arena`.
//
// IMPORTANT: Passing `nullptr` here will result in getting
// `MemoryManagerRef::ReferenceCounting()`.
MemoryManager ProtoMemoryManager(google::protobuf::Arena* arena);

// Same as `ProtoMemoryManager(arena)`, including the `nullptr` behavior.
inline MemoryManager ProtoMemoryManagerRef(google::protobuf::Arena* arena) {
  return ProtoMemoryManager(arena);
}

// Returns `memory_manager.arena()`. For a manager from `ProtoMemoryManager` or
// `ProtoMemoryManagerRef` this is the arena passed in, or `nullptr` when
// `nullptr` was passed (reference counting). For other managers it is
// presumably whatever arena they pool on, or `nullptr` if none.
google::protobuf::Arena* absl_nullable ProtoMemoryManagerArena(
    MemoryManager memory_manager);

// Allocate and construct `T` using the `ProtoMemoryManager` provided as
// `memory_manager`. `memory_manager` must be `ProtoMemoryManager` or behavior
// is undefined. Unlike `MemoryManager::New`, this method supports arena-enabled
// messages.
//
// `args` are forwarded to `google::protobuf::Arena::Create` on the arena returned by
// `ProtoMemoryManagerArena(memory_manager)`.
// NOTE(review): that arena is nullptr for a reference-counting manager. Per
// protobuf's Arena::Create documentation, a null arena yields a heap
// allocation that the caller must delete. Confirm whether that case is
// intended to be supported.
template
ABSL_MUST_USE_RESULT T* NewInProtoArena(MemoryManager memory_manager,
                                        Args&&... args) {
  return google::protobuf::Arena::Create(ProtoMemoryManagerArena(memory_manager),
                                  std::forward(args)...);
}

}  // namespace cel::extensions

#endif  // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_MEMORY_MANAGER_H_
================================================
FILE: extensions/protobuf/memory_manager_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "extensions/protobuf/memory_manager.h" #include "common/memory.h" #include "internal/testing.h" #include "google/protobuf/arena.h" namespace cel::extensions { namespace { using ::testing::Eq; using ::testing::IsNull; using ::testing::NotNull; TEST(ProtoMemoryManager, MemoryManagement) { google::protobuf::Arena arena; auto memory_manager = ProtoMemoryManager(&arena); EXPECT_EQ(memory_manager.memory_management(), MemoryManagement::kPooling); } TEST(ProtoMemoryManager, Arena) { google::protobuf::Arena arena; auto memory_manager = ProtoMemoryManager(&arena); EXPECT_THAT(ProtoMemoryManagerArena(memory_manager), NotNull()); } TEST(ProtoMemoryManagerRef, MemoryManagement) { google::protobuf::Arena arena; auto memory_manager = ProtoMemoryManagerRef(&arena); EXPECT_EQ(memory_manager.memory_management(), MemoryManagement::kPooling); memory_manager = ProtoMemoryManagerRef(nullptr); EXPECT_EQ(memory_manager.memory_management(), MemoryManagement::kReferenceCounting); } TEST(ProtoMemoryManagerRef, Arena) { google::protobuf::Arena arena; auto memory_manager = ProtoMemoryManagerRef(&arena); EXPECT_THAT(ProtoMemoryManagerArena(memory_manager), Eq(&arena)); memory_manager = ProtoMemoryManagerRef(nullptr); EXPECT_THAT(ProtoMemoryManagerArena(memory_manager), IsNull()); } } // namespace } // namespace cel::extensions ================================================ FILE: extensions/protobuf/runtime_adapter.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may 
not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "extensions/protobuf/runtime_adapter.h" #include #include #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" #include "absl/status/statusor.h" #include "extensions/protobuf/ast_converters.h" #include "internal/status_macros.h" #include "runtime/runtime.h" namespace cel::extensions { absl::StatusOr> ProtobufRuntimeAdapter::CreateProgram( const Runtime& runtime, const cel::expr::CheckedExpr& expr, const Runtime::CreateProgramOptions options) { CEL_ASSIGN_OR_RETURN(auto ast, CreateAstFromCheckedExpr(expr)); return runtime.CreateTraceableProgram(std::move(ast), options); } absl::StatusOr> ProtobufRuntimeAdapter::CreateProgram( const Runtime& runtime, const cel::expr::ParsedExpr& expr, const Runtime::CreateProgramOptions options) { CEL_ASSIGN_OR_RETURN(auto ast, CreateAstFromParsedExpr(expr)); return runtime.CreateTraceableProgram(std::move(ast), options); } absl::StatusOr> ProtobufRuntimeAdapter::CreateProgram( const Runtime& runtime, const cel::expr::Expr& expr, const cel::expr::SourceInfo* source_info, const Runtime::CreateProgramOptions options) { CEL_ASSIGN_OR_RETURN(auto ast, CreateAstFromParsedExpr(expr, source_info)); return runtime.CreateTraceableProgram(std::move(ast), options); } } // namespace cel::extensions ================================================ FILE: extensions/protobuf/runtime_adapter.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this 
file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_RUNTIME_ADAPTER_H_
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_RUNTIME_ADAPTER_H_

#include

#include "cel/expr/checked.pb.h"
#include "cel/expr/syntax.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"
#include "google/protobuf/descriptor.h"

namespace cel::extensions {

// Helper class for cel::Runtime that converts the pb serialization format for
// expressions to the internal AST format.
//
// Each overload converts its input to an AST and returns the result of
// `runtime.CreateTraceableProgram(ast, options)`. Conversion failures are
// returned as the error status.
class ProtobufRuntimeAdapter {
 public:
  // Only to be used for static member functions.
  ProtobufRuntimeAdapter() = delete;

  // Builds a program from a type-checked `CheckedExpr`.
  static absl::StatusOr> CreateProgram(
      const Runtime& runtime, const cel::expr::CheckedExpr& expr,
      const Runtime::CreateProgramOptions options = {});

  // Builds a program from a parsed (not type-checked) `ParsedExpr`.
  static absl::StatusOr> CreateProgram(
      const Runtime& runtime, const cel::expr::ParsedExpr& expr,
      const Runtime::CreateProgramOptions options = {});

  // Builds a program from a bare `Expr`. `source_info` is optional and
  // defaults to nullptr.
  static absl::StatusOr> CreateProgram(
      const Runtime& runtime, const cel::expr::Expr& expr,
      const cel::expr::SourceInfo* source_info = nullptr,
      const Runtime::CreateProgramOptions options = {});
};

}  // namespace cel::extensions

#endif  // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_RUNTIME_ADAPTER_H_
================================================
FILE: extensions/protobuf/value.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Utilities for wrapping and unwrapping cel::Values representing protobuf
// message types.
#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_H_ #include #include #include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/protobuf/wrappers.pb.h" #include "absl/base/nullability.h" #include "absl/meta/type_traits.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "common/memory.h" #include "common/type.h" #include "common/value.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel::extensions { // Adapt a protobuf message to a cel::Value. // // Handles unwrapping message types with special meanings in CEL (WKTs). // // T value must be a protobuf message class. template std::enable_if_t>, absl::StatusOr> ProtoMessageToValue(T&& value, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { return Value::FromMessage(std::forward(value), descriptor_pool, message_factory, arena); } inline absl::Status ProtoMessageFromValue(const Value& value, google::protobuf::Message& dest_message) { const auto* dest_descriptor = dest_message.GetDescriptor(); const google::protobuf::Message* src_message = nullptr; if (auto legacy_struct_value = cel::common_internal::AsLegacyStructValue(value); legacy_struct_value) { src_message = legacy_struct_value->message_ptr(); } if (auto parsed_message_value = value.AsParsedMessage(); parsed_message_value) { src_message = cel::to_address(*parsed_message_value); } if (src_message != nullptr) { const auto* src_descriptor = src_message->GetDescriptor(); if (dest_descriptor == src_descriptor) { dest_message.CopyFrom(*src_message); return absl::OkStatus(); } if (dest_descriptor->full_name() == 
src_descriptor->full_name()) { absl::Cord serialized; if (!src_message->SerializePartialToCord(&serialized)) { return absl::UnknownError(absl::StrCat("failed to serialize message: ", src_descriptor->full_name())); } if (!dest_message.ParsePartialFromCord(serialized)) { return absl::UnknownError(absl::StrCat("failed to parse message: ", dest_descriptor->full_name())); } return absl::OkStatus(); } } return TypeConversionError(value.GetRuntimeType(), MessageType(dest_descriptor)) .NativeValue(); } } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_H_ ================================================ FILE: extensions/protobuf/value_end_to_end_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Functional tests for protobuf backed CEL structs in the default runtime. 
#include #include #include #include #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "common/value.h" #include "common/value_testing.h" #include "extensions/protobuf/runtime_adapter.h" #include "internal/testing.h" #include "parser/parser.h" #include "runtime/activation.h" #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" namespace cel::extensions { namespace { using ::absl_testing::StatusIs; using ::cel::expr::conformance::proto3::TestAllTypes; using ::cel::test::BoolValueIs; using ::cel::test::BytesValueIs; using ::cel::test::DoubleValueIs; using ::cel::test::DurationValueIs; using ::cel::test::ErrorValueIs; using ::cel::test::IntValueIs; using ::cel::test::IsNullValue; using ::cel::test::ListValueIs; using ::cel::test::MapValueIs; using ::cel::test::StringValueIs; using ::cel::test::StructValueIs; using ::cel::test::TimestampValueIs; using ::cel::test::UintValueIs; using ::cel::test::ValueMatcher; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::testing::_; using ::testing::AnyOf; using ::testing::HasSubstr; using ::testing::TestWithParam; struct TestCase { std::string name; std::string expr; std::string msg_textproto; ValueMatcher matcher; template friend void AbslStringify(S& sink, const TestCase& tc) { sink.Append(tc.name); } }; class ProtobufValueEndToEndTest : public TestWithParam { public: ProtobufValueEndToEndTest() = default; protected: const TestCase& test_case() const { return GetParam(); } google::protobuf::Arena arena_; }; TEST_P(ProtobufValueEndToEndTest, Runner) { TestAllTypes message; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(test_case().msg_textproto, &message)); 
Activation activation; activation.InsertOrAssignValue( "msg", Value::FromMessage(message, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), &arena_)); RuntimeOptions opts; opts.enable_empty_wrapper_null_unboxing = true; ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( google::protobuf::DescriptorPool::generated_pool(), opts)); ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, std::move(builder).Build()); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse(test_case().expr)); ASSERT_OK_AND_ASSIGN(std::unique_ptr program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena_, activation)); EXPECT_THAT(result, test_case().matcher); } INSTANTIATE_TEST_SUITE_P( Singular, ProtobufValueEndToEndTest, testing::ValuesIn(std::vector{ {"single_int64", "msg.single_int64", R"pb( single_int64: 42 )pb", IntValueIs(42)}, {"single_int64_has", "has(msg.single_int64)", R"pb( single_int64: 42 )pb", BoolValueIs(true)}, {"single_int64_has_false", "has(msg.single_int64)", "", BoolValueIs(false)}, {"single_int32", "msg.single_int32", R"pb( single_int32: 42 )pb", IntValueIs(42)}, {"single_uint64", "msg.single_uint64", R"pb( single_uint64: 42 )pb", UintValueIs(42)}, {"single_uint32", "msg.single_uint32", R"pb( single_uint32: 42 )pb", UintValueIs(42)}, {"single_sint64", "msg.single_sint64", R"pb( single_sint64: 42 )pb", IntValueIs(42)}, {"single_sint32", "msg.single_sint32", R"pb( single_sint32: 42 )pb", IntValueIs(42)}, {"single_fixed64", "msg.single_fixed64", R"pb( single_fixed64: 42 )pb", UintValueIs(42)}, {"single_fixed32", "msg.single_fixed32", R"pb( single_fixed32: 42 )pb", UintValueIs(42)}, {"single_sfixed64", "msg.single_sfixed64", R"pb( single_sfixed64: 42 )pb", IntValueIs(42)}, {"single_sfixed32", "msg.single_sfixed32", R"pb( single_sfixed32: 42 )pb", IntValueIs(42)}, {"single_float", "msg.single_float", R"pb( single_float: 4.25 )pb", DoubleValueIs(4.25)}, 
{"single_double", "msg.single_double", R"pb( single_double: 4.25 )pb", DoubleValueIs(4.25)}, {"single_bool", "msg.single_bool", R"pb( single_bool: true )pb", BoolValueIs(true)}, {"single_string", "msg.single_string", R"pb( single_string: "Hello 😀" )pb", StringValueIs("Hello 😀")}, {"single_bytes", "msg.single_bytes", R"pb( single_bytes: "Hello" )pb", BytesValueIs("Hello")}, {"wkt_duration", "msg.single_duration", R"pb( single_duration { seconds: 10 } )pb", DurationValueIs(absl::Seconds(10))}, {"wkt_duration_default", "msg.single_duration", "", DurationValueIs(absl::Seconds(0))}, {"wkt_timestamp", "msg.single_timestamp", R"pb( single_timestamp { seconds: 10 } )pb", TimestampValueIs(absl::FromUnixSeconds(10))}, {"wkt_timestamp_default", "msg.single_timestamp", "", TimestampValueIs(absl::UnixEpoch())}, {"wkt_int64", "msg.single_int64_wrapper", R"pb( single_int64_wrapper { value: -20 } )pb", IntValueIs(-20)}, {"wkt_int64_default", "msg.single_int64_wrapper", "", IsNullValue()}, {"wkt_int32", "msg.single_int32_wrapper", R"pb( single_int32_wrapper { value: -10 } )pb", IntValueIs(-10)}, {"wkt_int32_default", "msg.single_int32_wrapper", "", IsNullValue()}, {"wkt_uint64", "msg.single_uint64_wrapper", R"pb( single_uint64_wrapper { value: 10 } )pb", UintValueIs(10)}, {"wkt_uint64_default", "msg.single_uint64_wrapper", "", IsNullValue()}, {"wkt_uint32", "msg.single_uint32_wrapper", R"pb( single_uint32_wrapper { value: 11 } )pb", UintValueIs(11)}, {"wkt_uint32_default", "msg.single_uint32_wrapper", "", IsNullValue()}, {"wkt_float", "msg.single_float_wrapper", R"pb( single_float_wrapper { value: 10.25 } )pb", DoubleValueIs(10.25)}, {"wkt_float_default", "msg.single_float_wrapper", "", IsNullValue()}, {"wkt_double", "msg.single_double_wrapper", R"pb( single_double_wrapper { value: 10.25 } )pb", DoubleValueIs(10.25)}, {"wkt_double_default", "msg.single_double_wrapper", "", IsNullValue()}, {"wkt_bool", "msg.single_bool_wrapper", R"pb( single_bool_wrapper { value: false } )pb", 
BoolValueIs(false)}, {"wkt_bool_default", "msg.single_bool_wrapper", "", IsNullValue()}, {"wkt_string", "msg.single_string_wrapper", R"pb( single_string_wrapper { value: "abcd" } )pb", StringValueIs("abcd")}, {"wkt_string_default", "msg.single_string_wrapper", "", IsNullValue()}, {"wkt_bytes", "msg.single_bytes_wrapper", R"pb( single_bytes_wrapper { value: "abcd" } )pb", BytesValueIs("abcd")}, {"wkt_bytes_default", "msg.single_bytes_wrapper", "", IsNullValue()}, {"wkt_null", "msg.null_value", R"pb( null_value: NULL_VALUE )pb", IsNullValue()}, {"message_field", "msg.standalone_message", R"pb( standalone_message { bb: 2 } )pb", StructValueIs(_)}, {"message_field_has", "has(msg.standalone_message)", R"pb( standalone_message { bb: 2 } )pb", BoolValueIs(true)}, {"message_field_has_false", "has(msg.standalone_message)", "", BoolValueIs(false)}, {"single_enum", "msg.standalone_enum", R"pb( standalone_enum: BAR )pb", // BAR IntValueIs(1)}})); INSTANTIATE_TEST_SUITE_P( Repeated, ProtobufValueEndToEndTest, testing::ValuesIn(std::vector{ {"repeated_int64", "msg.repeated_int64[0]", R"pb( repeated_int64: 42 )pb", IntValueIs(42)}, {"repeated_int64_has", "has(msg.repeated_int64)", R"pb( repeated_int64: 42 )pb", BoolValueIs(true)}, {"repeated_int64_has_false", "has(msg.repeated_int64)", "", BoolValueIs(false)}, {"repeated_int32", "msg.repeated_int32[0]", R"pb( repeated_int32: 42 )pb", IntValueIs(42)}, {"repeated_uint64", "msg.repeated_uint64[0]", R"pb( repeated_uint64: 42 )pb", UintValueIs(42)}, {"repeated_uint32", "msg.repeated_uint32[0]", R"pb( repeated_uint32: 42 )pb", UintValueIs(42)}, {"repeated_sint64", "msg.repeated_sint64[0]", R"pb( repeated_sint64: 42 )pb", IntValueIs(42)}, {"repeated_sint32", "msg.repeated_sint32[0]", R"pb( repeated_sint32: 42 )pb", IntValueIs(42)}, {"repeated_fixed64", "msg.repeated_fixed64[0]", R"pb( repeated_fixed64: 42 )pb", UintValueIs(42)}, {"repeated_fixed32", "msg.repeated_fixed32[0]", R"pb( repeated_fixed32: 42 )pb", UintValueIs(42)}, 
{"repeated_sfixed64", "msg.repeated_sfixed64[0]", R"pb( repeated_sfixed64: 42 )pb", IntValueIs(42)}, {"repeated_sfixed32", "msg.repeated_sfixed32[0]", R"pb( repeated_sfixed32: 42 )pb", IntValueIs(42)}, {"repeated_float", "msg.repeated_float[0]", R"pb( repeated_float: 4.25 )pb", DoubleValueIs(4.25)}, {"repeated_double", "msg.repeated_double[0]", R"pb( repeated_double: 4.25 )pb", DoubleValueIs(4.25)}, {"repeated_bool", "msg.repeated_bool[0]", R"pb( repeated_bool: true )pb", BoolValueIs(true)}, {"repeated_string", "msg.repeated_string[0]", R"pb( repeated_string: "Hello 😀" )pb", StringValueIs("Hello 😀")}, {"repeated_bytes", "msg.repeated_bytes[0]", R"pb( repeated_bytes: "Hello" )pb", BytesValueIs("Hello")}, {"wkt_duration", "msg.repeated_duration[0]", R"pb( repeated_duration { seconds: 10 } )pb", DurationValueIs(absl::Seconds(10))}, {"wkt_timestamp", "msg.repeated_timestamp[0]", R"pb( repeated_timestamp { seconds: 10 } )pb", TimestampValueIs(absl::FromUnixSeconds(10))}, {"wkt_int64", "msg.repeated_int64_wrapper[0]", R"pb( repeated_int64_wrapper { value: -20 } )pb", IntValueIs(-20)}, {"wkt_int32", "msg.repeated_int32_wrapper[0]", R"pb( repeated_int32_wrapper { value: -10 } )pb", IntValueIs(-10)}, {"wkt_uint64", "msg.repeated_uint64_wrapper[0]", R"pb( repeated_uint64_wrapper { value: 10 } )pb", UintValueIs(10)}, {"wkt_uint32", "msg.repeated_uint32_wrapper[0]", R"pb( repeated_uint32_wrapper { value: 11 } )pb", UintValueIs(11)}, {"wkt_float", "msg.repeated_float_wrapper[0]", R"pb( repeated_float_wrapper { value: 10.25 } )pb", DoubleValueIs(10.25)}, {"wkt_double", "msg.repeated_double_wrapper[0]", R"pb( repeated_double_wrapper { value: 10.25 } )pb", DoubleValueIs(10.25)}, {"wkt_bool", "msg.repeated_bool_wrapper[0]", R"pb( repeated_bool_wrapper { value: false } )pb", BoolValueIs(false)}, {"wkt_string", "msg.repeated_string_wrapper[0]", R"pb( repeated_string_wrapper { value: "abcd" } )pb", StringValueIs("abcd")}, {"wkt_bytes", "msg.repeated_bytes_wrapper[0]", R"pb( 
repeated_bytes_wrapper { value: "abcd" } )pb", BytesValueIs("abcd")}, {"wkt_null", "msg.repeated_null_value[0]", R"pb( repeated_null_value: NULL_VALUE )pb", IsNullValue()}, {"message_field", "msg.repeated_nested_message[0]", R"pb( repeated_nested_message { bb: 42 } )pb", StructValueIs(_)}, {"repeated_enum", "msg.repeated_nested_enum[0]", R"pb( repeated_nested_enum: BAR )pb", // BAR IntValueIs(1)}, // Implements CEL list interface {"repeated_size", "msg.repeated_int64.size()", R"pb( repeated_int64: 42 repeated_int64: 43 )pb", IntValueIs(2)}, {"in_repeated", "42 in msg.repeated_int64", R"pb( repeated_int64: 42 repeated_int64: 43 )pb", BoolValueIs(true)}, {"in_repeated_false", "44 in msg.repeated_int64", R"pb( repeated_int64: 42 repeated_int64: 43 )pb", BoolValueIs(false)}, {"repeated_compre_exists", "msg.repeated_int64.exists(x, x > 42)", R"pb( repeated_int64: 42 repeated_int64: 43 )pb", BoolValueIs(true)}, {"repeated_compre_map", "msg.repeated_int64.map(x, x * 2)[0]", R"pb( repeated_int64: 42 repeated_int64: 43 )pb", IntValueIs(84)}, })); INSTANTIATE_TEST_SUITE_P( Maps, ProtobufValueEndToEndTest, testing::ValuesIn(std::vector{ {"map_bool_int64", "msg.map_bool_int64[false]", R"pb( map_bool_int64 { key: false value: 42 } )pb", IntValueIs(42)}, {"map_bool_int64_has", "has(msg.map_bool_int64)", R"pb( map_bool_int64 { key: false value: 42 } )pb", BoolValueIs(true)}, {"map_bool_int64_has_false", "has(msg.map_bool_int64)", "", BoolValueIs(false)}, {"map_bool_int32", "msg.map_bool_int32[false]", R"pb( map_bool_int32 { key: false value: 42 } )pb", IntValueIs(42)}, {"map_bool_uint64", "msg.map_bool_uint64[false]", R"pb( map_bool_uint64 { key: false value: 42 } )pb", UintValueIs(42)}, {"map_bool_uint32", "msg.map_bool_uint32[false]", R"pb( map_bool_uint32 { key: false, value: 42 } )pb", UintValueIs(42)}, {"map_bool_float", "msg.map_bool_float[false]", R"pb( map_bool_float { key: false value: 4.25 } )pb", DoubleValueIs(4.25)}, {"map_bool_double", "msg.map_bool_double[false]", 
R"pb( map_bool_double { key: false value: 4.25 } )pb", DoubleValueIs(4.25)}, {"map_bool_bool", "msg.map_bool_bool[false]", R"pb( map_bool_bool { key: false value: true } )pb", BoolValueIs(true)}, {"map_bool_string", "msg.map_bool_string[false]", R"pb( map_bool_string { key: false value: "Hello 😀" } )pb", StringValueIs("Hello 😀")}, {"map_bool_bytes", "msg.map_bool_bytes[false]", R"pb( map_bool_bytes { key: false value: "Hello" } )pb", BytesValueIs("Hello")}, {"wkt_duration", "msg.map_bool_duration[false]", R"pb( map_bool_duration { key: false value { seconds: 10 } } )pb", DurationValueIs(absl::Seconds(10))}, {"wkt_timestamp", "msg.map_bool_timestamp[false]", R"pb( map_bool_timestamp { key: false value { seconds: 10 } } )pb", TimestampValueIs(absl::FromUnixSeconds(10))}, {"wkt_int64", "msg.map_bool_int64_wrapper[false]", R"pb( map_bool_int64_wrapper { key: false value { value: -20 } } )pb", IntValueIs(-20)}, {"wkt_int32", "msg.map_bool_int32_wrapper[false]", R"pb( map_bool_int32_wrapper { key: false value { value: -10 } } )pb", IntValueIs(-10)}, {"wkt_uint64", "msg.map_bool_uint64_wrapper[false]", R"pb( map_bool_uint64_wrapper { key: false value { value: 10 } } )pb", UintValueIs(10)}, {"wkt_uint32", "msg.map_bool_uint32_wrapper[false]", R"pb( map_bool_uint32_wrapper { key: false value { value: 11 } } )pb", UintValueIs(11)}, {"wkt_float", "msg.map_bool_float_wrapper[false]", R"pb( map_bool_float_wrapper { key: false value { value: 10.25 } } )pb", DoubleValueIs(10.25)}, {"wkt_double", "msg.map_bool_double_wrapper[false]", R"pb( map_bool_double_wrapper { key: false value { value: 10.25 } } )pb", DoubleValueIs(10.25)}, {"wkt_bool", "msg.map_bool_bool_wrapper[false]", R"pb( map_bool_bool_wrapper { key: false value { value: false } } )pb", BoolValueIs(false)}, {"wkt_string", "msg.map_bool_string_wrapper[false]", R"pb( map_bool_string_wrapper { key: false value { value: "abcd" } } )pb", StringValueIs("abcd")}, {"wkt_bytes", "msg.map_bool_bytes_wrapper[false]", R"pb( 
map_bool_bytes_wrapper { key: false value { value: "abcd" } } )pb", BytesValueIs("abcd")}, {"wkt_null", "msg.map_bool_null_value[false]", R"pb( map_bool_null_value { key: false value: NULL_VALUE } )pb", IsNullValue()}, {"message_field", "msg.map_bool_message[false]", R"pb( map_bool_message { key: false value { bb: 42 } } )pb", StructValueIs(_)}, {"map_bool_enum", "msg.map_bool_enum[false]", R"pb( map_bool_enum { key: false value: BAR } )pb", // BAR IntValueIs(1)}, // Simplified for remaining key types. {"map_int32_int64", "msg.map_int32_int64[42]", R"pb( map_int32_int64 { key: 42 value: -42 } )pb", IntValueIs(-42)}, {"map_int64_int64", "msg.map_int64_int64[42]", R"pb( map_int64_int64 { key: 42 value: -42 } )pb", IntValueIs(-42)}, {"map_uint32_int64", "msg.map_uint32_int64[42u]", R"pb( map_uint32_int64 { key: 42 value: -42 } )pb", IntValueIs(-42)}, {"map_uint64_int64", "msg.map_uint64_int64[42u]", R"pb( map_uint64_int64 { key: 42 value: -42 } )pb", IntValueIs(-42)}, {"map_string_int64", "msg.map_string_int64['key1']", R"pb( map_string_int64 { key: "key1" value: -42 } )pb", IntValueIs(-42)}, // Implements CEL map {"in_map_int64_true", "42 in msg.map_int64_int64", R"pb( map_int64_int64 { key: 42 value: -42 } map_int64_int64 { key: 43 value: -43 } )pb", BoolValueIs(true)}, {"in_map_int64_false", "44 in msg.map_int64_int64", R"pb( map_int64_int64 { key: 42 value: -42 } map_int64_int64 { key: 43 value: -43 } )pb", BoolValueIs(false)}, {"int_map_int64_compre_exists", "msg.map_int64_int64.exists(key, key > 42)", R"pb( map_int64_int64 { key: 42 value: -42 } map_int64_int64 { key: 43 value: -43 } )pb", BoolValueIs(true)}, {"int_map_int64_compre_map", "msg.map_int64_int64.map(key, key + 20)[0]", R"pb( map_int64_int64 { key: 42 value: -42 } map_int64_int64 { key: 43 value: -43 } )pb", IntValueIs(AnyOf(62, 63))}, {"map_string_key_not_found", "msg.map_string_int64['key2']", R"pb( map_string_int64 { key: "key1" value: -42 } )pb", ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, 
HasSubstr("Key not found in map")))}, {"map_string_select_key", "msg.map_string_int64.key1", R"pb( map_string_int64 { key: "key1" value: -42 } )pb", IntValueIs(-42)}, {"map_string_has_key", "has(msg.map_string_int64.key1)", R"pb( map_string_int64 { key: "key1" value: -42 } )pb", BoolValueIs(true)}, {"map_string_has_key_false", "has(msg.map_string_int64.key2)", R"pb( map_string_int64 { key: "key1" value: -42 } )pb", BoolValueIs(false)}, {"map_int32_out_of_range", "msg.map_int32_int64[0x1FFFFFFFF]", R"pb( map_int32_int64 { key: 10 value: -42 } )pb", ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, HasSubstr("Key not found in map")))}, {"map_uint32_out_of_range", "msg.map_uint32_int64[0x1FFFFFFFFu]", R"pb( map_uint32_int64 { key: 10 value: -42 } )pb", ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, HasSubstr("Key not found in map")))}})); /* CelSizeIs(size): matches a value whose Size() call succeeds and returns `size`; used below inside MapValueIs(...) and ListValueIs(...). */ MATCHER_P(CelSizeIs, size, "") { auto s = arg.Size(); return s.ok() && *s == size; } /* End-to-end cases for the JSON-shaped TestAllTypes fields (single_struct, list_value, single_value): each entry is {test name, CEL expression, text-format message bound as `msg`, expected value matcher}. */ INSTANTIATE_TEST_SUITE_P( JsonWrappers, ProtobufValueEndToEndTest, testing::ValuesIn(std::vector{ {"single_struct", "msg.single_struct", R"pb( single_struct { fields { key: "field1" value { null_value: NULL_VALUE } } } )pb", MapValueIs(CelSizeIs(1))}, {"single_struct_null_value_field", "msg.single_struct['field1']", R"pb( single_struct { fields { key: "field1" value { null_value: NULL_VALUE } } } )pb", IsNullValue()}, {"single_struct_number_value_field", "msg.single_struct['field1']", R"pb( single_struct { fields { key: "field1" value { number_value: 10.25 } } } )pb", DoubleValueIs(10.25)}, {"single_struct_bool_value_field", "msg.single_struct['field1']", R"pb( single_struct { fields { key: "field1" value { bool_value: true } } } )pb", BoolValueIs(true)}, {"single_struct_string_value_field", "msg.single_struct['field1']", R"pb( single_struct { fields { key: "field1" value { string_value: "abcd" } } } )pb", StringValueIs("abcd")}, {"single_struct_struct_value_field", "msg.single_struct['field1']", R"pb( single_struct { fields { key: "field1" 
value { struct_value { fields { key: "field2", value: { null_value: NULL_VALUE } } } } } } )pb", MapValueIs(CelSizeIs(1))}, {"single_struct_list_value_field", "msg.single_struct['field1']", R"pb( single_struct { fields { key: "field1" value { list_value { values { null_value: NULL_VALUE } } } } } )pb", ListValueIs(CelSizeIs(1))}, {"single_struct_select_field", "msg.single_struct.field1", R"pb( single_struct { fields { key: "field1" value { bool_value: true } } } )pb", BoolValueIs(true)}, {"single_struct_has_field", "has(msg.single_struct.field1)", R"pb( single_struct { fields { key: "field1" value { bool_value: true } } } )pb", BoolValueIs(true)}, {"single_struct_has_field_false", "has(msg.single_struct.field2)", R"pb( single_struct { fields { key: "field1" value { bool_value: true } } } )pb", BoolValueIs(false)}, {"single_struct_map_size", "msg.single_struct.size()", R"pb( single_struct { fields { key: "field1" value { bool_value: true } } fields { key: "field2" value { bool_value: false } } } )pb", IntValueIs(2)}, {"single_struct_map_in", "'field2' in msg.single_struct", R"pb( single_struct { fields { key: "field1" value { bool_value: true } } fields { key: "field2" value { bool_value: false } } } )pb", BoolValueIs(true)}, {"single_struct_map_compre_exists", "msg.single_struct.exists(key, key == 'field2')", R"pb( single_struct { fields { key: "field1" value { bool_value: true } } fields { key: "field2" value { bool_value: false } } } )pb", BoolValueIs(true)}, {"single_struct_map_compre_map", "'__field1' in msg.single_struct.map(key, '__' + key)", R"pb( single_struct { fields { key: "field1" value { bool_value: true } } fields { key: "field2" value { bool_value: false } } } )pb", BoolValueIs(true)}, {"single_list_value", "msg.list_value", R"pb( list_value { values { null_value: NULL_VALUE } } )pb", ListValueIs(CelSizeIs(1))}, {"single_list_value_index_null", "msg.list_value[0]", R"pb( list_value { values { null_value: NULL_VALUE } } )pb", IsNullValue()}, 
{"single_list_value_index_number", "msg.list_value[0]", R"pb( list_value { values { number_value: 10.25 } } )pb", DoubleValueIs(10.25)}, {"single_list_value_index_string", "msg.list_value[0]", R"pb( list_value { values { string_value: "abc" } } )pb", StringValueIs("abc")}, {"single_list_value_index_bool", "msg.list_value[0]", R"pb( list_value { values { bool_value: false } } )pb", BoolValueIs(false)}, {"single_list_value_list_size", "msg.list_value.size()", R"pb( list_value { values { bool_value: false } values { bool_value: false } } )pb", IntValueIs(2)}, {"single_list_value_list_in", "10.25 in msg.list_value", R"pb( list_value { values { number_value: 10.0 } values { number_value: 10.25 } } )pb", BoolValueIs(true)}, {"single_list_value_list_compre_exists", "msg.list_value.exists(x, x == 10.25)", R"pb( list_value { values { number_value: 10.0 } values { number_value: 10.25 } } )pb", BoolValueIs(true)}, {"single_list_value_list_compre_map", "msg.list_value.map(x, x + 0.5)[1]", R"pb( list_value { values { number_value: 10.0 } values { number_value: 10.25 } } )pb", DoubleValueIs(10.75)}, {"single_list_value_index_struct", "msg.list_value[0]", R"pb( list_value { values { struct_value { fields { key: "field1" value { null_value: NULL_VALUE } } } } } )pb", MapValueIs(CelSizeIs(1))}, {"single_list_value_index_list", "msg.list_value[0]", R"pb( list_value { values { list_value { values { null_value: NULL_VALUE } } } } )pb", ListValueIs(CelSizeIs(1))}, {"single_json_value_null", "msg.single_value", R"pb( single_value { null_value: NULL_VALUE } )pb", IsNullValue()}, {"single_json_value_number", "msg.single_value", R"pb( single_value { number_value: 13.25 } )pb", DoubleValueIs(13.25)}, {"single_json_value_string", "msg.single_value", R"pb( single_value { string_value: "abcd" } )pb", StringValueIs("abcd")}, {"single_json_value_bool", "msg.single_value", R"pb( single_value { bool_value: false } )pb", BoolValueIs(false)}, {"single_json_value_struct", "msg.single_value", R"pb( 
single_value { struct_value {} } )pb", MapValueIs(CelSizeIs(0))}, {"single_json_value_list", "msg.single_value", R"pb( single_value { list_value {} } )pb", ListValueIs(CelSizeIs(0))}, })); /* The Any suite below reads wrapper messages packed into google.protobuf.Any fields (single_any, repeated_any, map_int64_any) and expects the unwrapped CEL primitive. */ // TODO(uncreated-issue/66): any support needs the reflection impl for looking up the // type name and corresponding deserializer (outside of the WKTs which are // special cased). INSTANTIATE_TEST_SUITE_P( Any, ProtobufValueEndToEndTest, testing::ValuesIn(std::vector{ {"single_any_wkt_int64", "msg.single_any", R"pb( single_any { [type.googleapis.com/google.protobuf.Int64Value] { value: 42 } } )pb", IntValueIs(42)}, {"single_any_wkt_int32", "msg.single_any", R"pb( single_any { [type.googleapis.com/google.protobuf.Int32Value] { value: 42 } } )pb", IntValueIs(42)}, {"single_any_wkt_uint64", "msg.single_any", R"pb( single_any { [type.googleapis.com/google.protobuf.UInt64Value] { value: 42 } } )pb", UintValueIs(42)}, {"single_any_wkt_uint32", "msg.single_any", R"pb( single_any { [type.googleapis.com/google.protobuf.UInt32Value] { value: 42 } } )pb", UintValueIs(42)}, {"single_any_wkt_double", "msg.single_any", R"pb( single_any { [type.googleapis.com/google.protobuf.DoubleValue] { value: 30.5 } } )pb", DoubleValueIs(30.5)}, {"single_any_wkt_string", "msg.single_any", R"pb( single_any { [type.googleapis.com/google.protobuf.StringValue] { value: "abcd" } } )pb", StringValueIs("abcd")}, {"repeated_any_wkt_string", "msg.repeated_any[0]", R"pb( repeated_any { [type.googleapis.com/google.protobuf.StringValue] { value: "abcd" } } )pb", StringValueIs("abcd")}, {"map_int64_any_wkt_string", "msg.map_int64_any[0]", R"pb( map_int64_any { key: 0 value { [type.googleapis.com/google.protobuf.StringValue] { value: "abcd" } } } )pb", StringValueIs("abcd")}, })); } // namespace } // namespace cel::extensions 
================================================ FILE: extensions/protobuf/value_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache 
License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "extensions/protobuf/value.h" #include #include #include #include #include #include "google/protobuf/any.pb.h" #include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/protobuf/wrappers.pb.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "base/attribute.h" #include "common/casting.h" #include "common/value.h" #include "common/value_kind.h" #include "common/value_testing.h" #include "internal/testing.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "google/protobuf/text_format.h" namespace cel::extensions { namespace { using ::absl_testing::IsOk; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::expr::conformance::proto2::TestAllTypes; using ::cel::test::BoolValueIs; using ::cel::test::BytesValueIs; using ::cel::test::DoubleValueIs; using ::cel::test::DurationValueIs; using ::cel::test::ErrorValueIs; using ::cel::test::IntValueIs; using ::cel::test::ListValueIs; using ::cel::test::MapValueIs; using ::cel::test::StringValueIs; using ::cel::test::StructValueFieldHas; using ::cel::test::StructValueFieldIs; using ::cel::test::StructValueIs; using ::cel::test::TimestampValueIs; using ::cel::test::UintValueIs; using ::cel::test::ValueKindIs; using 
::testing::_; using ::testing::ElementsAre; using ::testing::Eq; using ::testing::HasSubstr; using ::testing::IsTrue; using ::testing::Pair; using ::testing::UnorderedElementsAre; /* Parses `text` as a text-format message of type T; aborts via ABSL_CHECK if parsing fails, so test inputs are guaranteed well-formed. */ template T ParseTextOrDie(absl::string_view text) { T proto; ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(text, &proto)); return proto; } using ProtoValueTest = common_internal::ValueTest<>; class ProtoValueWrapTest : public ProtoValueTest {}; TEST_F(ProtoValueWrapTest, ProtoBoolValueToValue) { google::protobuf::BoolValue message; message.set_value(true); EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(Eq(true)))); EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(Eq(true)))); } TEST_F(ProtoValueWrapTest, ProtoInt32ValueToValue) { google::protobuf::Int32Value message; message.set_value(1); EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(IntValueIs(Eq(1)))); EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(IntValueIs(Eq(1)))); } TEST_F(ProtoValueWrapTest, ProtoInt64ValueToValue) { google::protobuf::Int64Value message; message.set_value(1); EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(IntValueIs(Eq(1)))); EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(IntValueIs(Eq(1)))); } TEST_F(ProtoValueWrapTest, ProtoUInt32ValueToValue) { google::protobuf::UInt32Value message; message.set_value(1); EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(UintValueIs(Eq(1)))); EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), 
message_factory(), arena()), IsOkAndHolds(UintValueIs(Eq(1)))); } TEST_F(ProtoValueWrapTest, ProtoUInt64ValueToValue) { 
google::protobuf::UInt64Value message; message.set_value(1); EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(UintValueIs(Eq(1)))); EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(UintValueIs(Eq(1)))); } TEST_F(ProtoValueWrapTest, ProtoFloatValueToValue) { google::protobuf::FloatValue message; message.set_value(1); EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(DoubleValueIs(Eq(1)))); EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(DoubleValueIs(Eq(1)))); } TEST_F(ProtoValueWrapTest, ProtoDoubleValueToValue) { google::protobuf::DoubleValue message; message.set_value(1); EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(DoubleValueIs(Eq(1)))); EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(DoubleValueIs(Eq(1)))); } TEST_F(ProtoValueWrapTest, ProtoBytesValueToValue) { google::protobuf::BytesValue message; message.set_value("foo"); EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BytesValueIs(Eq("foo")))); EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BytesValueIs(Eq("foo")))); } TEST_F(ProtoValueWrapTest, ProtoStringValueToValue) { google::protobuf::StringValue message; message.set_value("foo"); EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(StringValueIs(Eq("foo")))); EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(StringValueIs(Eq("foo")))); } TEST_F(ProtoValueWrapTest, ProtoDurationToValue) { google::protobuf::Duration message; message.set_seconds(1); message.set_nanos(1); 
EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(DurationValueIs( Eq(absl::Seconds(1) + absl::Nanoseconds(1))))); EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(DurationValueIs( Eq(absl::Seconds(1) + absl::Nanoseconds(1))))); } TEST_F(ProtoValueWrapTest, ProtoTimestampToValue) { google::protobuf::Timestamp message; message.set_seconds(1); message.set_nanos(1); EXPECT_THAT( ProtoMessageToValue(message, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(TimestampValueIs( Eq(absl::UnixEpoch() + absl::Seconds(1) + absl::Nanoseconds(1))))); EXPECT_THAT( ProtoMessageToValue(std::move(message), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(TimestampValueIs( Eq(absl::UnixEpoch() + absl::Seconds(1) + absl::Nanoseconds(1))))); } TEST_F(ProtoValueWrapTest, ProtoMessageToValue) { TestAllTypes message; EXPECT_THAT(ProtoMessageToValue(message, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(ValueKindIs(Eq(ValueKind::kStruct)))); EXPECT_THAT(ProtoMessageToValue(std::move(message), descriptor_pool(), message_factory(), arena()), IsOkAndHolds(ValueKindIs(Eq(ValueKind::kStruct)))); } TEST_F(ProtoValueWrapTest, GetFieldByName) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb(single_int32: 1, single_int64: 1 single_uint32: 1 single_uint64: 1 single_float: 1 single_double: 1 single_bool: true single_string: "foo" single_bytes: "foo")pb"), descriptor_pool(), message_factory(), arena())); EXPECT_THAT(value, StructValueIs(StructValueFieldIs( "single_int32", IntValueIs(Eq(1)), descriptor_pool(), message_factory(), arena()))); EXPECT_THAT(value, StructValueIs(StructValueFieldHas("single_int32", IsTrue()))); EXPECT_THAT(value, StructValueIs(StructValueFieldIs( "single_int64", IntValueIs(Eq(1)), descriptor_pool(), message_factory(), arena()))); EXPECT_THAT(value, 
StructValueIs(StructValueFieldHas("single_int64", IsTrue()))); EXPECT_THAT(value, StructValueIs(StructValueFieldIs( "single_uint32", UintValueIs(Eq(1)), descriptor_pool(), message_factory(), arena()))); EXPECT_THAT(value, StructValueIs(StructValueFieldHas("single_uint32", IsTrue()))); EXPECT_THAT(value, StructValueIs(StructValueFieldIs( "single_uint64", UintValueIs(Eq(1)), descriptor_pool(), message_factory(), arena()))); EXPECT_THAT(value, StructValueIs(StructValueFieldHas("single_uint64", IsTrue()))); /* The text proto above also sets the float, double, bool, string and bytes fields; check them by name, mirroring what GetFieldByNumber checks by number. */ EXPECT_THAT(value, StructValueIs(StructValueFieldIs( "single_float", DoubleValueIs(Eq(1)), descriptor_pool(), message_factory(), arena()))); EXPECT_THAT(value, StructValueIs(StructValueFieldIs( "single_double", DoubleValueIs(Eq(1)), descriptor_pool(), message_factory(), arena()))); EXPECT_THAT(value, StructValueIs(StructValueFieldIs( "single_bool", BoolValueIs(Eq(true)), descriptor_pool(), message_factory(), arena()))); EXPECT_THAT(value, StructValueIs(StructValueFieldIs( "single_string", StringValueIs(Eq("foo")), descriptor_pool(), message_factory(), arena()))); EXPECT_THAT(value, StructValueIs(StructValueFieldIs( "single_bytes", BytesValueIs(Eq("foo")), descriptor_pool(), message_factory(), arena()))); } TEST_F(ProtoValueWrapTest, GetFieldNoSuchField) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue( ParseTextOrDie(R"pb(single_int32: 1)pb"), descriptor_pool(), message_factory(), arena())); ASSERT_THAT(value, StructValueIs(_)); StructValue struct_value = Cast(value); EXPECT_THAT(struct_value.GetFieldByName("does_not_exist", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, HasSubstr("no_such_field"))))); } TEST_F(ProtoValueWrapTest, GetFieldByNumber) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb(single_int32: 1, single_int64: 2 
single_uint32: 3 single_uint64: 4 single_float: 1.25 single_double: 1.5 single_bool: true single_string: "foo" single_bytes: "foo")pb"), descriptor_pool(), message_factory(), arena())); EXPECT_THAT(value, StructValueIs(_)); StructValue struct_value = Cast(value); EXPECT_THAT(struct_value.GetFieldByNumber( TestAllTypes::kSingleInt32FieldNumber, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(IntValueIs(1))); EXPECT_THAT(struct_value.GetFieldByNumber( TestAllTypes::kSingleInt64FieldNumber, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(IntValueIs(2))); EXPECT_THAT(struct_value.GetFieldByNumber( TestAllTypes::kSingleUint32FieldNumber, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(UintValueIs(3))); EXPECT_THAT(struct_value.GetFieldByNumber( TestAllTypes::kSingleUint64FieldNumber, descriptor_pool(), 
message_factory(), arena()), IsOkAndHolds(UintValueIs(4))); EXPECT_THAT(struct_value.GetFieldByNumber( TestAllTypes::kSingleFloatFieldNumber, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(DoubleValueIs(1.25))); EXPECT_THAT(struct_value.GetFieldByNumber( TestAllTypes::kSingleDoubleFieldNumber, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(DoubleValueIs(1.5))); EXPECT_THAT(struct_value.GetFieldByNumber( TestAllTypes::kSingleBoolFieldNumber, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT(struct_value.GetFieldByNumber( TestAllTypes::kSingleStringFieldNumber, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(StringValueIs("foo"))); EXPECT_THAT(struct_value.GetFieldByNumber( TestAllTypes::kSingleBytesFieldNumber, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BytesValueIs("foo"))); } TEST_F(ProtoValueWrapTest, GetFieldByNumberNoSuchField) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb(single_int32: 1, single_int64: 2 single_uint32: 3 single_uint64: 4 single_float: 1.25 single_double: 1.5 single_bool: true single_string: "foo" single_bytes: "foo")pb"), descriptor_pool(), message_factory(), arena())); EXPECT_THAT(value, StructValueIs(_)); StructValue struct_value = Cast(value); EXPECT_THAT(struct_value.GetFieldByNumber(999, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, HasSubstr("no_such_field"))))); // Out of range. 
EXPECT_THAT(struct_value.GetFieldByNumber(0x1ffffffff, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(ErrorValueIs(StatusIs(absl::StatusCode::kNotFound, HasSubstr("no_such_field"))))); } TEST_F(ProtoValueWrapTest, HasFieldByNumber) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue( ParseTextOrDie(R"pb(single_int32: 1, single_int64: 2)pb"), descriptor_pool(), message_factory(), arena())); EXPECT_THAT(value, StructValueIs(_)); StructValue struct_value = Cast(value); EXPECT_THAT( struct_value.HasFieldByNumber(TestAllTypes::kSingleInt32FieldNumber), IsOkAndHolds(BoolValue(true))); EXPECT_THAT( struct_value.HasFieldByNumber(TestAllTypes::kSingleInt64FieldNumber), IsOkAndHolds(BoolValue(true))); EXPECT_THAT( struct_value.HasFieldByNumber(TestAllTypes::kSingleStringFieldNumber), IsOkAndHolds(BoolValue(false))); EXPECT_THAT( struct_value.HasFieldByNumber(TestAllTypes::kSingleBytesFieldNumber), IsOkAndHolds(BoolValue(false))); } TEST_F(ProtoValueWrapTest, HasFieldByNumberNoSuchField) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue( ParseTextOrDie(R"pb(single_int32: 1, single_int64: 2)pb"), descriptor_pool(), message_factory(), arena())); EXPECT_THAT(value, StructValueIs(_)); StructValue struct_value = Cast(value); // Has returns a status directly instead of a CEL error as in Get. 
EXPECT_THAT( struct_value.HasFieldByNumber(999), StatusIs(absl::StatusCode::kNotFound, HasSubstr("no_such_field"))); EXPECT_THAT( struct_value.HasFieldByNumber(0x1ffffffff), StatusIs(absl::StatusCode::kNotFound, HasSubstr("no_such_field"))); } /* value and value2 are parsed from identical text, so Equal() must report true for each against itself and in both directions between them. */ TEST_F(ProtoValueWrapTest, ProtoMessageEqual) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb(single_int32: 1, single_int64: 2 )pb"), descriptor_pool(), message_factory(), arena())); ASSERT_OK_AND_ASSIGN( auto value2, ProtoMessageToValue(ParseTextOrDie( R"pb(single_int32: 1, single_int64: 2 )pb"), descriptor_pool(), message_factory(), arena())); EXPECT_THAT(value.Equal(value, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT( value.Equal(value2, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); EXPECT_THAT( value2.Equal(value, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(true))); } /* Messages with swapped field values must compare unequal in both directions. 
*/ TEST_F(ProtoValueWrapTest, ProtoMessageEqualFalse) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb(single_int32: 1, single_int64: 2 )pb"), descriptor_pool(), message_factory(), arena())); ASSERT_OK_AND_ASSIGN( auto value2, ProtoMessageToValue(ParseTextOrDie( R"pb(single_int32: 2, single_int64: 1 )pb"), descriptor_pool(), message_factory(), arena())); EXPECT_THAT( value.Equal(value2, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(false))); EXPECT_THAT( value2.Equal(value, descriptor_pool(), message_factory(), arena()), IsOkAndHolds(BoolValueIs(false))); } TEST_F(ProtoValueWrapTest, ProtoMessageForEachField) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb(single_int32: 1, single_int64: 2 )pb"), descriptor_pool(), message_factory(), arena())); EXPECT_THAT(value, StructValueIs(_)); StructValue struct_value = Cast(value); std::vector fields; auto cb = [&fields](absl::string_view field, const Value&) -> absl::StatusOr { fields.push_back(std::string(field)); return true; }; ASSERT_THAT(struct_value.ForEachField(cb, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(fields, UnorderedElementsAre("single_int32", "single_int64")); } 
/* Qualify() along standalone_message -> bb is expected to store bb's value (42) in scratch. NOTE(review): count is left uninitialized and never checked here. */ TEST_F(ProtoValueWrapTest, ProtoMessageQualify) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb( standalone_message { bb: 42 } )pb"), descriptor_pool(), message_factory(), arena())); EXPECT_THAT(value, StructValueIs(_)); StructValue struct_value = Cast(value); std::vector qualifiers{ FieldSpecifier{TestAllTypes::kStandaloneMessageFieldNumber, "standalone_message"}, FieldSpecifier{TestAllTypes::NestedMessage::kBbFieldNumber, "bb"}}; Value scratch; int count; EXPECT_THAT( struct_value.Qualify(qualifiers, /*presence_test=*/false, descriptor_pool(), message_factory(), arena(), &scratch, &count), IsOk()); EXPECT_THAT(scratch, IntValueIs(42)); } /* Same qualifier path with presence_test=true: scratch is expected to hold a bool presence result (true) instead of the field value. */ TEST_F(ProtoValueWrapTest, ProtoMessageQualifyHas) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb( standalone_message { bb: 42 } )pb"), descriptor_pool(), message_factory(), arena())); EXPECT_THAT(value, StructValueIs(_)); StructValue struct_value = Cast(value); std::vector qualifiers{ FieldSpecifier{TestAllTypes::kStandaloneMessageFieldNumber, "standalone_message"}, FieldSpecifier{TestAllTypes::NestedMessage::kBbFieldNumber, "bb"}}; Value scratch; int count; EXPECT_THAT( struct_value.Qualify(qualifiers, /*presence_test=*/true, descriptor_pool(), message_factory(), arena(), &scratch, &count), IsOk()); EXPECT_THAT(scratch, BoolValueIs(true)); } /* The *MapListKeys tests: ListKeys() on a single-entry map field must return a one-element list holding that key, covered for int64, int32, bool, uint32, uint64 and string keys. 
*/ TEST_F(ProtoValueWrapTest, ProtoInt64MapListKeys) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb( map_int64_int64 { key: 10 value: 20 })pb"), descriptor_pool(), message_factory(), arena())); ASSERT_OK_AND_ASSIGN(auto map_value, Cast(value).GetFieldByName( "map_int64_int64", descriptor_pool(), message_factory(), arena())); ASSERT_THAT(map_value, MapValueIs(_)); ASSERT_OK_AND_ASSIGN(ListValue key_set, Cast(map_value).ListKeys( descriptor_pool(), message_factory(), arena())); EXPECT_THAT(key_set.Size(), IsOkAndHolds(1)); ASSERT_OK_AND_ASSIGN(Value key0, key_set.Get(0, descriptor_pool(), message_factory(), 
arena())); EXPECT_THAT(key0, IntValueIs(10)); } TEST_F(ProtoValueWrapTest, ProtoInt32MapListKeys) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb( map_int32_int64 { key: 10 value: 20 })pb"), descriptor_pool(), message_factory(), arena())); ASSERT_OK_AND_ASSIGN(auto map_value, Cast(value).GetFieldByName( "map_int32_int64", descriptor_pool(), message_factory(), arena())); ASSERT_THAT(map_value, MapValueIs(_)); ASSERT_OK_AND_ASSIGN(ListValue key_set, Cast(map_value).ListKeys( descriptor_pool(), message_factory(), arena())); EXPECT_THAT(key_set.Size(), IsOkAndHolds(1)); ASSERT_OK_AND_ASSIGN(Value key0, key_set.Get(0, descriptor_pool(), message_factory(), arena())); EXPECT_THAT(key0, IntValueIs(10)); } TEST_F(ProtoValueWrapTest, ProtoBoolMapListKeys) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb( map_bool_int64 { key: false value: 20 })pb"), descriptor_pool(), message_factory(), arena())); ASSERT_OK_AND_ASSIGN(auto map_value, Cast(value).GetFieldByName( "map_bool_int64", descriptor_pool(), message_factory(), arena())); ASSERT_THAT(map_value, MapValueIs(_)); ASSERT_OK_AND_ASSIGN(ListValue key_set, Cast(map_value).ListKeys( descriptor_pool(), message_factory(), arena())); EXPECT_THAT(key_set.Size(), IsOkAndHolds(1)); ASSERT_OK_AND_ASSIGN(Value key0, key_set.Get(0, descriptor_pool(), message_factory(), arena())); EXPECT_THAT(key0, BoolValueIs(false)); } TEST_F(ProtoValueWrapTest, ProtoUint32MapListKeys) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb( map_uint32_int64 { key: 11 value: 20 })pb"), descriptor_pool(), message_factory(), arena())); ASSERT_OK_AND_ASSIGN( auto map_value, Cast(value).GetFieldByName( "map_uint32_int64", descriptor_pool(), message_factory(), arena())); ASSERT_THAT(map_value, MapValueIs(_)); ASSERT_OK_AND_ASSIGN(ListValue key_set, Cast(map_value).ListKeys( descriptor_pool(), message_factory(), arena())); EXPECT_THAT(key_set.Size(), IsOkAndHolds(1)); 
ASSERT_OK_AND_ASSIGN(Value key0, key_set.Get(0, descriptor_pool(), message_factory(), arena())); EXPECT_THAT(key0, UintValueIs(11)); } TEST_F(ProtoValueWrapTest, ProtoUint64MapListKeys) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb( map_uint64_int64 { key: 11 value: 20 })pb"), descriptor_pool(), message_factory(), arena())); ASSERT_OK_AND_ASSIGN( auto map_value, Cast(value).GetFieldByName( "map_uint64_int64", descriptor_pool(), message_factory(), arena())); ASSERT_THAT(map_value, MapValueIs(_)); ASSERT_OK_AND_ASSIGN(ListValue key_set, Cast(map_value).ListKeys( descriptor_pool(), message_factory(), arena())); EXPECT_THAT(key_set.Size(), IsOkAndHolds(1)); ASSERT_OK_AND_ASSIGN(Value key0, key_set.Get(0, descriptor_pool(), message_factory(), arena())); EXPECT_THAT(key0, UintValueIs(11)); } TEST_F(ProtoValueWrapTest, ProtoStringMapListKeys) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue( ParseTextOrDie( R"pb( map_string_int64 { key: "key1" value: 20 })pb"), descriptor_pool(), message_factory(), arena())); ASSERT_OK_AND_ASSIGN( auto map_value, Cast(value).GetFieldByName( "map_string_int64", descriptor_pool(), message_factory(), arena())); ASSERT_THAT(map_value, MapValueIs(_)); ASSERT_OK_AND_ASSIGN(ListValue key_set, Cast(map_value).ListKeys( descriptor_pool(), message_factory(), arena())); EXPECT_THAT(key_set.Size(), IsOkAndHolds(1)); ASSERT_OK_AND_ASSIGN(Value key0, key_set.Get(0, descriptor_pool(), message_factory(), arena())); EXPECT_THAT(key0, StringValueIs("key1")); } /* NewIterator() over a map field is expected to yield the map's keys; iteration order is not asserted. 
*/ TEST_F(ProtoValueWrapTest, ProtoMapIterator) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb( map_int64_int64 { key: 10 value: 20 } map_int64_int64 { key: 12 value: 24 } )pb"), descriptor_pool(), message_factory(), arena())); ASSERT_OK_AND_ASSIGN( auto field_value, Cast(value).GetFieldByName( "map_int64_int64", descriptor_pool(), message_factory(), arena())); ASSERT_THAT(field_value, MapValueIs(_)); MapValue map_value = 
Cast(field_value); std::vector keys; ASSERT_OK_AND_ASSIGN(auto iter, map_value.NewIterator()); while (iter->HasNext()) { ASSERT_OK_AND_ASSIGN( keys.emplace_back(), iter->Next(descriptor_pool(), message_factory(), arena())); } EXPECT_THAT(keys, UnorderedElementsAre(IntValueIs(10), IntValueIs(12))); } /* ForEach() over a map field visits every (key, value) entry; order is not asserted. */ TEST_F(ProtoValueWrapTest, ProtoMapForEach) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb( map_int64_int64 { key: 10 value: 20 } map_int64_int64 { key: 12 value: 24 } )pb"), descriptor_pool(), message_factory(), arena())); ASSERT_OK_AND_ASSIGN( auto field_value, Cast(value).GetFieldByName( "map_int64_int64", descriptor_pool(), message_factory(), arena())); ASSERT_THAT(field_value, MapValueIs(_)); MapValue map_value = Cast(field_value); std::vector> pairs; auto cb = [&pairs](const Value& key, const Value& value) -> absl::StatusOr { pairs.push_back(std::pair(key, value)); return true; }; ASSERT_THAT( map_value.ForEach(cb, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(pairs, UnorderedElementsAre(Pair(IntValueIs(10), IntValueIs(20)), Pair(IntValueIs(12), IntValueIs(24)))); } /* NewIterator() over a repeated field yields its elements in declaration order. */ TEST_F(ProtoValueWrapTest, ProtoListIterator) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb( repeated_int64: 1 repeated_int64: 2 )pb"), descriptor_pool(), message_factory(), arena())); ASSERT_OK_AND_ASSIGN( auto field_value, Cast(value).GetFieldByName( "repeated_int64", descriptor_pool(), message_factory(), arena())); ASSERT_THAT(field_value, ListValueIs(_)); ListValue list_value = Cast(field_value); std::vector elements; ASSERT_OK_AND_ASSIGN(auto iter, list_value.NewIterator()); while (iter->HasNext()) { ASSERT_OK_AND_ASSIGN( elements.emplace_back(), iter->Next(descriptor_pool(), message_factory(), arena())); } EXPECT_THAT(elements, ElementsAre(IntValueIs(1), 
IntValueIs(2))); } /* ForEach() with an (index, value) callback visits repeated-field elements in order together with their 0-based indices. */ TEST_F(ProtoValueWrapTest, ProtoListForEachWithIndex) { ASSERT_OK_AND_ASSIGN( auto value, ProtoMessageToValue(ParseTextOrDie( R"pb( 
repeated_int64: 1 repeated_int64: 2 )pb"), descriptor_pool(), message_factory(), arena())); ASSERT_OK_AND_ASSIGN( auto field_value, Cast(value).GetFieldByName( "repeated_int64", descriptor_pool(), message_factory(), arena())); ASSERT_THAT(field_value, ListValueIs(_)); ListValue list_value = Cast(field_value); std::vector> elements; auto cb = [&elements](size_t index, const Value& value) -> absl::StatusOr { elements.push_back(std::pair(index, value)); return true; }; ASSERT_THAT( list_value.ForEach(cb, descriptor_pool(), message_factory(), arena()), IsOk()); EXPECT_THAT(elements, ElementsAre(Pair(0, IntValueIs(1)), Pair(1, IntValueIs(2)))); } } // namespace } // namespace cel::extensions ================================================ FILE: extensions/protobuf/value_testing.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_TESTING_H_
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_TESTING_H_

// Test-only gMock helpers for asserting on cel::Value instances that wrap
// protobuf messages.

#include
#include

#include "absl/status/status.h"
#include "common/value.h"
#include "extensions/protobuf/value.h"
#include "internal/testing.h"
#include "google/protobuf/message.h"

namespace cel::extensions::test {

// gMock matcher that converts a cel::Value into a `MessageType` protobuf via
// ProtoMessageFromValue() and then applies the wrapped message matcher to the
// converted message. Normally created through StructValueAsProto().
template class StructValueAsProtoMatcher {
 public:
  // Opts this type into gMock's matcher protocol without deriving from
  // MatcherInterface; gMock calls MatchAndExplain/DescribeTo/
  // DescribeNegationTo directly.
  using is_gtest_matcher = void;

  // Stores (by move) the matcher that is applied to the converted message.
  explicit StructValueAsProtoMatcher(testing::Matcher&& m)
      : m_(std::move(m)) {}

  // Returns true when `v` converts to a MessageType that satisfies the
  // wrapped matcher. A failed conversion is reported as a mismatch, and the
  // listener is told which message type was targeted and the failing status.
  bool MatchAndExplain(cel::Value v,
                       testing::MatchResultListener* result_listener) const {
    MessageType msg;
    absl::Status s = ProtoMessageFromValue(v, msg);
    if (!s.ok()) {
      *result_listener << "cannot convert to "
                       << MessageType::descriptor()->full_name() << ": " << s;
      return false;
    }
    // Let the wrapped matcher explain any mismatch on the message itself.
    return m_.MatchAndExplain(msg, result_listener);
  }

  // Describes the positive expectation in terms of the wrapped matcher.
  void DescribeTo(std::ostream* os) const {
    *os << "matches proto message " << m_;
  }

  // Describes the negated expectation in terms of the wrapped matcher.
  void DescribeNegationTo(std::ostream* os) const {
    *os << "does not match proto message " << m_;
  }

 private:
  // Matcher applied to the message produced from the cel::Value.
  testing::Matcher m_;
};

// Returns a matcher that matches a cel::Value against a proto message.
//
// The value is first converted to `MessageType`; a conversion failure counts
// as a mismatch. When passing a polymorphic matcher such as EqualsProto,
// `MessageType` cannot be deduced and must be spelled out explicitly.
//
// Example usage:
//
//   EXPECT_THAT(value, StructValueAsProto<TestAllTypes>(EqualsProto(R"pb(
//                 single_int32: 1
//                 single_string: "foo"
//               )pb")));
template inline StructValueAsProtoMatcher StructValueAsProto(
    testing::Matcher&& m) {
  // Compile-time guard: MessageType must be a protobuf message type.
  static_assert(std::is_base_of_v);
  return StructValueAsProtoMatcher(std::move(m));
}

}  // namespace cel::extensions::test

#endif  // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_VALUE_TESTING_H_
================================================
FILE: extensions/protobuf/value_testing_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "extensions/protobuf/value_testing.h" #include "common/value.h" #include "common/value_testing.h" #include "extensions/protobuf/value.h" #include "internal/proto_matchers.h" #include "internal/testing.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" namespace cel::extensions::test { namespace { using ::cel::expr::conformance::proto2::TestAllTypes; using ::cel::extensions::ProtoMessageToValue; using ::cel::internal::test::EqualsProto; using ProtoValueTestingTest = common_internal::ValueTest<>; TEST_F(ProtoValueTestingTest, StructValueAsProtoSimple) { TestAllTypes test_proto; test_proto.set_single_int32(42); test_proto.set_single_string("foo"); ASSERT_OK_AND_ASSIGN(cel::Value v, ProtoMessageToValue(test_proto, descriptor_pool(), message_factory(), arena())); EXPECT_THAT(v, StructValueAsProto(EqualsProto(R"pb( single_int32: 42 single_string: "foo" )pb"))); } } // namespace } // namespace cel::extensions::test ================================================ FILE: extensions/regex_ext.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "extensions/regex_ext.h" #include #include #include #include #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/functional/bind_front.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "checker/internal/builtins_arena.h" #include "checker/type_checker_builder.h" #include "common/decl.h" #include "common/type.h" #include "common/value.h" #include "compiler/compiler.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "internal/casts.h" #include "internal/re2_options.h" #include "internal/status_macros.h" #include "runtime/function_adapter.h" #include "runtime/function_registry.h" #include "runtime/internal/runtime_friend_access.h" #include "runtime/internal/runtime_impl.h" #include "runtime/runtime_builder.h" #include "validator/regex_validator.h" #include "validator/validator.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "re2/re2.h" namespace cel::extensions { namespace { using ::cel::checker_internal::BuiltinsArena; Value Extract(int regex_max_program_size, const StringValue& target, const StringValue& regex, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { std::string target_scratch; std::string regex_scratch; absl::string_view target_view = target.ToStringView(&target_scratch); absl::string_view regex_view = regex.ToStringView(®ex_scratch); RE2 re2(regex_view, cel::internal::MakeRE2Options()); CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) .With(ErrorValueReturn()); const int group_count = re2.NumberOfCapturingGroups(); if (group_count > 1) 
{ return ErrorValue(absl::InvalidArgumentError(absl::StrFormat( "regular expression has more than one capturing group: %s", regex_view))); } // Space for the full match (\0) and the first capture group (\1). absl::string_view submatches[2]; if (re2.Match(target_view, 0, target_view.length(), RE2::UNANCHORED, submatches, 2)) { // Return the capture group if it exists else return the full match. const absl::string_view result_view = (group_count == 1) ? submatches[1] : submatches[0]; return OptionalValue::Of(StringValue::From(result_view, arena), arena); } return OptionalValue::None(); } Value ExtractAll(int regex_max_program_size, const StringValue& target, const StringValue& regex, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { std::string target_scratch; std::string regex_scratch; absl::string_view target_view = target.ToStringView(&target_scratch); absl::string_view regex_view = regex.ToStringView(®ex_scratch); RE2 re2(regex_view, cel::internal::MakeRE2Options()); CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) .With(ErrorValueReturn()); const int group_count = re2.NumberOfCapturingGroups(); if (group_count > 1) { return ErrorValue(absl::InvalidArgumentError(absl::StrFormat( "regular expression has more than one capturing group: %s", regex_view))); } auto builder = NewListValueBuilder(arena); absl::string_view temp_target = target_view; // Space for the full match (\0) and the first capture group (\1). absl::string_view submatches[2]; const int group_to_extract = (group_count == 1) ? 
1 : 0; while (re2.Match(temp_target, 0, temp_target.length(), RE2::UNANCHORED, submatches, group_count + 1)) { const absl::string_view& full_match = submatches[0]; const absl::string_view& desired_capture = submatches[group_to_extract]; // Avoid infinite loops on zero-length matches if (full_match.empty()) { if (temp_target.empty()) { break; } temp_target.remove_prefix(1); continue; } if (group_count == 1 && desired_capture.empty()) { temp_target.remove_prefix(full_match.data() - temp_target.data() + full_match.length()); continue; } absl::Status status = builder->Add(StringValue::From(desired_capture, arena)); if (!status.ok()) { return ErrorValue(status); } temp_target.remove_prefix(full_match.data() - temp_target.data() + full_match.length()); } return std::move(*builder).Build(); } Value ReplaceAll(int regex_max_program_size, const StringValue& target, const StringValue& regex, const StringValue& replacement, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { std::string target_scratch; std::string regex_scratch; std::string replacement_scratch; absl::string_view target_view = target.ToStringView(&target_scratch); absl::string_view regex_view = regex.ToStringView(®ex_scratch); absl::string_view replacement_view = replacement.ToStringView(&replacement_scratch); RE2 re2(regex_view, cel::internal::MakeRE2Options()); CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) .With(ErrorValueReturn()); std::string error_string; if (!re2.CheckRewriteString(replacement_view, &error_string)) { return ErrorValue(absl::InvalidArgumentError( absl::StrFormat("invalid replacement string: %s", error_string))); } std::string output(target_view); RE2::GlobalReplace(&output, re2, replacement_view); return StringValue::From(std::move(output), arena); } Value ReplaceN(int regex_max_program_size, const StringValue& target, const 
StringValue& regex, const StringValue& replacement, int64_t count, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { if (count == 0) { return target; } if (count < 0) { return ReplaceAll(regex_max_program_size, target, regex, replacement, descriptor_pool, message_factory, arena); } std::string target_scratch; std::string regex_scratch; std::string replacement_scratch; absl::string_view target_view = target.ToStringView(&target_scratch); absl::string_view regex_view = regex.ToStringView(®ex_scratch); absl::string_view replacement_view = replacement.ToStringView(&replacement_scratch); RE2 re2(regex_view, cel::internal::MakeRE2Options()); CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size)) .With(ErrorValueReturn()); std::string error_string; if (!re2.CheckRewriteString(replacement_view, &error_string)) { return ErrorValue(absl::InvalidArgumentError( absl::StrFormat("invalid replacement string: %s", error_string))); } std::string output; absl::string_view temp_target = target_view; int replaced_count = 0; // RE2's Rewrite only supports substitutions for groups \0 through \9. 
absl::string_view match[10]; int nmatch = std::min(9, re2.NumberOfCapturingGroups()) + 1; while (replaced_count < count && re2.Match(temp_target, 0, temp_target.length(), RE2::UNANCHORED, match, nmatch)) { absl::string_view full_match = match[0]; output.append(temp_target.data(), full_match.data() - temp_target.data()); if (!re2.Rewrite(&output, replacement_view, match, nmatch)) { // This should ideally not happen given CheckRewriteString passed return ErrorValue(absl::InternalError("rewrite failed unexpectedly")); } temp_target.remove_prefix(full_match.data() - temp_target.data() + full_match.length()); replaced_count++; } output.append(temp_target.data(), temp_target.length()); return StringValue::From(std::move(output), arena); } absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry, bool disable_extract, int regex_max_program_size) { if (!disable_extract) { CEL_RETURN_IF_ERROR(( BinaryFunctionAdapter, StringValue, StringValue>:: RegisterGlobalOverload( "regex.extract", absl::bind_front(&Extract, regex_max_program_size), registry))); } CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter, StringValue, StringValue>:: RegisterGlobalOverload( "regex.extractAll", absl::bind_front(&ExtractAll, regex_max_program_size), registry))); CEL_RETURN_IF_ERROR( (TernaryFunctionAdapter< absl::StatusOr, StringValue, StringValue, StringValue>::RegisterGlobalOverload("regex.replace", absl::bind_front( &ReplaceAll, regex_max_program_size), registry))); CEL_RETURN_IF_ERROR( (QuaternaryFunctionAdapter, StringValue, StringValue, StringValue, int64_t>:: RegisterGlobalOverload( "regex.replace", absl::bind_front(&ReplaceN, regex_max_program_size), registry))); return absl::OkStatus(); } const Type& OptionalStringType() { static absl::NoDestructor kInstance( OptionalType(BuiltinsArena(), StringType())); return *kInstance; } const Type& ListStringType() { static absl::NoDestructor kInstance( ListType(BuiltinsArena(), StringType())); return *kInstance; } absl::Status 
RegisterRegexCheckerDecls(TypeCheckerBuilder& builder) { CEL_ASSIGN_OR_RETURN( FunctionDecl extract_decl, MakeFunctionDecl( "regex.extract", MakeOverloadDecl("regex_extract_string_string", OptionalStringType(), StringType(), StringType()))); CEL_ASSIGN_OR_RETURN( FunctionDecl extract_all_decl, MakeFunctionDecl( "regex.extractAll", MakeOverloadDecl("regex_extractAll_string_string", ListStringType(), StringType(), StringType()))); CEL_ASSIGN_OR_RETURN( FunctionDecl replace_decl, MakeFunctionDecl( "regex.replace", MakeOverloadDecl("regex_replace_string_string_string", StringType(), StringType(), StringType(), StringType()), MakeOverloadDecl("regex_replace_string_string_string_int", StringType(), StringType(), StringType(), StringType(), IntType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(extract_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(extract_all_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(replace_decl)); return absl::OkStatus(); } } // namespace absl::Status RegisterRegexExtensionFunctions(RuntimeBuilder& builder) { auto& runtime = cel::internal::down_cast( runtime_internal::RuntimeFriendAccess::GetMutableRuntime(builder)); if (!runtime.expr_builder().optional_types_enabled()) { return absl::InvalidArgumentError( "regex extensions requires the optional types to be enabled"); } if (runtime.expr_builder().options().enable_regex) { CEL_RETURN_IF_ERROR(RegisterRegexExtensionFunctions( builder.function_registry(), /*disable_extract=*/false, runtime.expr_builder().options().regex_max_program_size)); } return absl::OkStatus(); } absl::Status RegisterRegexExtensionFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, const google::api::expr::runtime::InterpreterOptions& options) { if (options.enable_regex) { return RegisterRegexExtensionFunctions(registry->InternalGetRegistry(), /*disable_extract=*/true, options.regex_max_program_size); } return absl::OkStatus(); } CheckerLibrary RegexExtCheckerLibrary() { return {.id = "cel.lib.ext.regex", 
.configure = RegisterRegexCheckerDecls}; } CompilerLibrary RegexExtCompilerLibrary() { return CompilerLibrary::FromCheckerLibrary(RegexExtCheckerLibrary()); } Validation RegexExtValidator() { return RegexPatternValidator( /*id=*/"", {{"regex.extract", 1}, {"regex.extractAll", 1}, {"regex.replace", 1}}); } } // namespace cel::extensions ================================================ FILE: extensions/regex_ext.h ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // This extension depends on the CEL optional type. Please ensure that the // EnableOptionalTypes is called when using regex extensions. // // # Replace // // The `regex.replace` function replaces all non-overlapping substring of a // regex pattern in the target string with the given replacement string. // Optionally, you can limit the number of replacements by providing a count // argument. When the count is a negative number, the function acts as replace // all. Only numeric (\N) capture group references are supported in the // replacement string, with validation for correctness. Backslashed-escaped // digits (\1 to \9) within the replacement argument can be used to insert text // matching the corresponding parenthesized group in the regexp pattern. An // error will be thrown for invalid regex or replace string. 
//
//   regex.replace(target: string, pattern: string,
//                 replacement: string) -> string
//   regex.replace(target: string, pattern: string,
//                 replacement: string, count: int) -> string
//
// Examples:
//
//   regex.replace('hello world hello', 'hello', 'hi') == 'hi world hi'
//   regex.replace('banana', 'a', 'x', 0) == 'banana'
//   regex.replace('banana', 'a', 'x', 1) == 'bxnana'
//   regex.replace('banana', 'a', 'x', -12) == 'bxnxnx'
//   regex.replace('foo bar', '(fo)o (ba)r', r'\2 \1') == 'ba fo'
//   regex.replace('test', '(.)', r'\2')      // Runtime Error: invalid
//                                            // replacement string
//   regex.replace('foo bar', '(', '$2 $1')   // Runtime Error: invalid regex
//
// # Extract
//
// The `regex.extract` function returns the first match of a regex pattern in a
// string. If the pattern contains a capture group, the text matched by that
// group is returned instead of the whole match. If no match is found, it
// returns an optional none value. An error will be thrown for an invalid regex
// or for multiple capture groups.
//
//   regex.extract(target: string, pattern: string) -> optional<string>
//
// Examples:
//
//   regex.extract('item-A, item-B', 'item-(\\w+)') == optional.of('A')
//   regex.extract('HELLO', 'hello') == optional.none()
//   regex.extract('testuser@testdomain', '(.*)@([^.]*)')  // Runtime Error:
//                                                         // multiple capture
//                                                         // groups
//
// # Extract All
//
// The `regex.extractAll` function returns a list of all matches of a regex
// pattern in a target string. If the pattern contains a capture group, each
// element is the text captured by that group, and matches whose group captured
// an empty string are omitted. If no matches are found, it returns an empty
// list. An error will be thrown for an invalid regex or for multiple capture
// groups.
// // regex.extractAll(target: string, pattern: string) -> list // // Examples: // // regex.extractAll('id:123, id:456', 'id:\\d+') == ['id:123', 'id:456'] // regex.extractAll('testuser@testdomain', '(.*)@([^.]*)') // Runtime Error // multiple capture group #ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_ #include "absl/status/status.h" #include "checker/type_checker_builder.h" #include "compiler/compiler.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "runtime/runtime_builder.h" #include "validator/validator.h" namespace cel::extensions { // Register extension functions for regular expressions for // google::api::expr::runtime::CelValue runtime. // // Note: CelValue does not support optional types, so regex.extract is // unsupported. absl::Status RegisterRegexExtensionFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, const google::api::expr::runtime::InterpreterOptions& options); // Register extension functions for regular expressions. absl::Status RegisterRegexExtensionFunctions(RuntimeBuilder& builder); // Type check declarations for the regex extension library. 
// Provides decls for the following functions: // // regex.replace(target: str, pattern: str, replacement: str) -> str // // regex.replace(target: str, pattern: str, replacement: str, count: int) -> str // // regex.extract(target: str, pattern: str) -> optional // // regex.extractAll(target: str, pattern: str) -> list CheckerLibrary RegexExtCheckerLibrary(); // Provides decls for the following functions: // // regex.replace(target: str, pattern: str, replacement: str) -> str // // regex.replace(target: str, pattern: str, replacement: str, count: int) -> str // // regex.extract(target: str, pattern: str) -> optional // // regex.extractAll(target: str, pattern: str) -> list CompilerLibrary RegexExtCompilerLibrary(); // Returns a `Validation` that checks all calls to the CEL regex extension // functions. // // It validates that if the pattern is a literal string, it is a valid regular // expression. Validation RegexExtValidator(); } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_ ================================================ FILE: extensions/regex_ext_test.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "extensions/regex_ext.h" #include #include #include #include #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "checker/standard_library.h" #include "checker/validation_result.h" #include "common/kind.h" #include "common/value.h" #include "common/value_testing.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" #include "eval/public/activation.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "extensions/protobuf/runtime_adapter.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "parser/parser.h" #include "runtime/activation.h" #include "runtime/optional_types.h" #include "runtime/reference_resolver.h" #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "validator/validator.h" #include "google/protobuf/arena.h" #include "google/protobuf/extension_set.h" namespace cel::extensions { namespace { using ::absl_testing::IsOk; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::test::BoolValueIs; using ::cel::test::ErrorValueIs; using ::cel::test::OptionalValueIs; using ::cel::test::OptionalValueIsEmpty; using ::cel::test::StringValueIs; using ::google::api::expr::parser::Parse; using ::google::api::expr::runtime::CelExpressionBuilder; using ::google::api::expr::runtime::CelFunctionRegistry; using ::google::api::expr::runtime::CreateCelExpressionBuilder; using ::google::api::expr::runtime::InterpreterOptions; using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::SizeIs; using ::testing::TestWithParam; using ::testing::ValuesIn; using LegacyActivation = google::api::expr::runtime::Activation; TEST(RegexExtTest, 
BuildFailsWithoutOptionalSupport) { RuntimeOptions options; options.enable_regex = true; options.enable_qualified_type_identifiers = true; ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( internal::GetTestingDescriptorPool(), options)); ASSERT_THAT( EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways), IsOk()); // Optional types are NOT enabled. ASSERT_THAT(RegisterRegexExtensionFunctions(builder), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("regex extensions requires the optional types " "to be enabled"))); } TEST(RegexExtTest, LegacyRuntimeSmokeTest) { InterpreterOptions options; options.enable_regex = true; options.enable_qualified_type_identifiers = true; options.enable_qualified_identifier_rewrites = true; std::unique_ptr builder = CreateCelExpressionBuilder( internal::GetTestingDescriptorPool(), nullptr, options); // Optional types are NOT enabled. ASSERT_THAT(RegisterRegexExtensionFunctions(builder->GetRegistry(), options), IsOk()); ASSERT_OK_AND_ASSIGN(auto expr, Parse("regex.extractAll('hello world', 'hello (.*)')")); LegacyActivation activation; google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(auto program, builder->CreateExpression( &expr.expr(), &expr.source_info())); ASSERT_OK_AND_ASSIGN(auto result, program->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsList()); ASSERT_EQ(result.ListOrDie()->size(), 1); ASSERT_TRUE(result.ListOrDie()->Get(&arena, 0).IsString()); EXPECT_EQ(result.ListOrDie()->Get(&arena, 0).StringOrDie().value(), "world"); } TEST(RegexExtTest, DoesNotRegisterExtractForLegacy) { InterpreterOptions options; options.enable_regex = true; CelFunctionRegistry registry; // Optional types are not usable in legacy runtime, so extract should not be // registered. 
ASSERT_THAT(RegisterRegexExtensionFunctions(®istry, options), IsOk()); EXPECT_THAT( registry.FindStaticOverloads("regex.extract", false, {cel::Kind::kString, cel::Kind::kString}), IsEmpty()); EXPECT_THAT( registry.FindStaticOverloads("regex.extractAll", false, {cel::Kind::kString, cel::Kind::kString}), SizeIs(1)); EXPECT_THAT(registry.FindStaticOverloads( "regex.replace", false, {cel::Kind::kString, cel::Kind::kString, cel::Kind::kString}), SizeIs(1)); EXPECT_THAT( registry.FindStaticOverloads("regex.replace", false, {cel::Kind::kString, cel::Kind::kString, cel::Kind::kString, cel::Kind::kInt64}), SizeIs(1)); } TEST(RegexExtTest, FollowsRegexOption) { InterpreterOptions options; options.enable_regex = false; CelFunctionRegistry registry; ASSERT_THAT(RegisterRegexExtensionFunctions(®istry, options), IsOk()); EXPECT_THAT( registry.FindStaticOverloads("regex.extract", false, {cel::Kind::kString, cel::Kind::kString}), IsEmpty()); EXPECT_THAT( registry.FindStaticOverloads("regex.extractAll", false, {cel::Kind::kString, cel::Kind::kString}), IsEmpty()); EXPECT_THAT(registry.FindStaticOverloads( "regex.replace", false, {cel::Kind::kString, cel::Kind::kString, cel::Kind::kString}), IsEmpty()); EXPECT_THAT( registry.FindStaticOverloads("regex.replace", false, {cel::Kind::kString, cel::Kind::kString, cel::Kind::kString, cel::Kind::kInt64}), IsEmpty()); } enum class EvaluationType { kBoolTrue, kOptionalValue, kOptionalNone, kRuntimeError, kUnknownStaticError, kInvalidArgStaticError }; struct RegexExtTestCase { EvaluationType evaluation_type; std::string expr; std::string expected_result = ""; }; class RegexExtTest : public TestWithParam { public: void SetUp() override { RuntimeOptions options; options.enable_regex = true; options.enable_qualified_type_identifiers = true; ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( internal::GetTestingDescriptorPool(), options)); ASSERT_THAT( EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways), IsOk()); 
ASSERT_THAT(EnableOptionalTypes(builder), IsOk()); ASSERT_THAT(RegisterRegexExtensionFunctions(builder), IsOk()); ASSERT_OK_AND_ASSIGN(runtime_, std::move(builder).Build()); } absl::StatusOr TestEvaluate(const std::string& expr_string) { CEL_ASSIGN_OR_RETURN(auto parsed_expr, Parse(expr_string)); CEL_ASSIGN_OR_RETURN(std::unique_ptr program, cel::extensions::ProtobufRuntimeAdapter::CreateProgram( *runtime_, parsed_expr)); Activation activation; return program->Evaluate(&arena_, activation); } google::protobuf::Arena arena_; std::unique_ptr runtime_; }; std::vector regexTestCases() { return { // Tests for extract Function {EvaluationType::kOptionalValue, R"(regex.extract('hello world', 'hello (.*)'))", "world"}, {EvaluationType::kOptionalValue, R"(regex.extract('item-A, item-B', r'item-(\w+)'))", "A"}, {EvaluationType::kOptionalValue, R"(regex.extract('The color is red', r'The color is (\w+)'))", "red"}, {EvaluationType::kOptionalValue, R"(regex.extract('The color is red', r'The color is \w+'))", "The color is red"}, {EvaluationType::kOptionalValue, "regex.extract('brand', 'brand')", "brand"}, {EvaluationType::kOptionalNone, "regex.extract('hello world', 'goodbye (.*)')"}, {EvaluationType::kOptionalNone, "regex.extract('HELLO', 'hello')"}, {EvaluationType::kOptionalNone, R"(regex.extract('', r'\w+'))"}, {EvaluationType::kBoolTrue, "regex.extract('4122345432', '22').orValue('777') == '22'"}, {EvaluationType::kBoolTrue, "regex.extract('4122345432', '22').or(optional.of('777')) == " "optional.of('22')"}, // Tests for extractAll Function {EvaluationType::kBoolTrue, "regex.extractAll('id:123, id:456', 'assa') == []"}, {EvaluationType::kBoolTrue, R"(regex.extractAll('id:123, id:456', r'id:\d+') == ['id:123','id:456'])"}, {EvaluationType::kBoolTrue, R"(regex.extractAll('Files: f_1.txt, f_2.csv', r'f_(\d+)')==['1','2'])"}, {EvaluationType::kBoolTrue, R"(regex.extractAll('testuser@', '(?P.*)@') == ['testuser'])"}, {EvaluationType::kBoolTrue, 
R"cel(regex.extractAll('t@gmail.com, a@y.com, 22@sdad.com', '(?P.*)@') == ['t@gmail.com, a@y.com, 22'])cel"}, {EvaluationType::kBoolTrue, R"cel(regex.extractAll('t@gmail.com, a@y.com, 22@sdad.com', r'(?P\w+)@') == ['t','a', '22'])cel"}, {EvaluationType::kBoolTrue, "regex.extractAll('banananana', '(ana)') == ['ana', 'ana']"}, {EvaluationType::kBoolTrue, R"(regex.extractAll('item:a1, topic:b2', r'(?:item:|topic:)([a-z]\d)') == ['a1', 'b2'])"}, {EvaluationType::kBoolTrue, R"(regex.extractAll('val=a, val=, val=c', 'val=([^,]*)')==['a','c'])"}, {EvaluationType::kBoolTrue, "regex.extractAll('key=, key=, key=', 'key=([^,]*)') == []"}, {EvaluationType::kBoolTrue, R"(regex.extractAll('a b c', r'(\S*)\s*') == ['a', 'b', 'c'])"}, {EvaluationType::kBoolTrue, "regex.extractAll('abc', 'a|b*') == ['a','b']"}, {EvaluationType::kBoolTrue, "regex.extractAll('abc', 'a|(b)|c*') == ['b']"}, // Tests for replace Function {EvaluationType::kBoolTrue, "regex.replace('abc', '$', '_end') == 'abc_end'"}, {EvaluationType::kBoolTrue, R"(regex.replace('a-b', r'\b', '|') == '|a|-|b|')"}, {EvaluationType::kBoolTrue, R"(regex.replace('foo bar', '(fo)o (ba)r', r'\2 \1') == 'ba fo')"}, {EvaluationType::kBoolTrue, R"(regex.replace('foo bar', 'foo', r'\\') == '\\ bar')"}, {EvaluationType::kBoolTrue, "regex.replace('banana', 'ana', 'x') == 'bxna'"}, {EvaluationType::kBoolTrue, R"(regex.replace('abc', 'b(.)', r'x\1') == 'axc')"}, {EvaluationType::kBoolTrue, "regex.replace('hello world hello', 'hello', 'hi') == 'hi world hi'"}, {EvaluationType::kBoolTrue, R"(regex.replace('ac', 'a(b)?c', r'[\1]') == '[]')"}, {EvaluationType::kBoolTrue, "regex.replace('apple pie', 'p', 'X') == 'aXXle Xie'"}, {EvaluationType::kBoolTrue, R"(regex.replace('remove all spaces', r'\s', '') == 'removeallspaces')"}, {EvaluationType::kBoolTrue, R"(regex.replace('digit:99919291992', r'\d+', '3') == 'digit:3')"}, {EvaluationType::kBoolTrue, R"cel(regex.replace('foo bar baz', r'\w+', r'(\0)') == '(foo) (bar) (baz)')cel"}, 
{EvaluationType::kBoolTrue, "regex.replace('', 'a', 'b') == ''"}, {EvaluationType::kBoolTrue, R"cel(regex.replace('User: Alice, Age: 30', r'User: (?P\w+), Age: (?P\d+)', '${name} is ${age} years old') == '${name} is ${age} years old')cel"}, {EvaluationType::kBoolTrue, R"cel(regex.replace('User: Alice, Age: 30', r'User: (?P\w+), Age: (?P\d+)', r'\1 is \2 years old') == 'Alice is 30 years old')cel"}, {EvaluationType::kBoolTrue, "regex.replace('hello ☃', '☃', '❄') == 'hello ❄'"}, {EvaluationType::kBoolTrue, R"(regex.replace('id=123', r'id=(?P\d+)', r'value: \1') == 'value: 123')"}, {EvaluationType::kBoolTrue, "regex.replace('banana', 'a', 'x') == 'bxnxnx'"}, {EvaluationType::kBoolTrue, R"(regex.replace(regex.replace('%(foo) %(bar) %2', r'%\((\w+)\)', r'${\1}'),r'%(\d+)', r'$\1') == '${foo} ${bar} $2')"}, {EvaluationType::kBoolTrue, R"(regex.replace('abc def', r'(abc)', r'\\1') == r'\1 def')"}, {EvaluationType::kBoolTrue, R"(regex.replace('abc def', r'(abc)', r'\\2') == r'\2 def')"}, {EvaluationType::kBoolTrue, R"(regex.replace('abc def', r'(abc)', r'\\{word}') == '\\{word} def')"}, {EvaluationType::kBoolTrue, R"(regex.replace('abc def', r'(abc)', r'\\word') == '\\word def')"}, {EvaluationType::kBoolTrue, "regex.replace('abc', '^', 'start_') == 'start_abc'"}, // Tests for replace Function with count variable {EvaluationType::kBoolTrue, R"(regex.replace('foofoo', 'foo', 'bar', 9223372036854775807) == 'barbar')"}, {EvaluationType::kBoolTrue, "regex.replace('banana', 'a', 'x', 0) == 'banana'"}, {EvaluationType::kBoolTrue, "regex.replace('banana', 'a', 'x', 1) == 'bxnana'"}, {EvaluationType::kBoolTrue, "regex.replace('banana', 'a', 'x', 2) == 'bxnxna'"}, {EvaluationType::kBoolTrue, "regex.replace('banana', 'a', 'x', 100) == 'bxnxnx'"}, {EvaluationType::kBoolTrue, "regex.replace('banana', 'a', 'x', -1) == 'bxnxnx'"}, {EvaluationType::kBoolTrue, "regex.replace('banana', 'a', 'x', -100) == 'bxnxnx'"}, {EvaluationType::kBoolTrue, R"cel(regex.replace('cat-dog dog-cat cat-dog 
dog-cat', '(cat)-(dog)', r'\2-\1', 1) == 'dog-cat dog-cat cat-dog dog-cat')cel"}, {EvaluationType::kBoolTrue, R"cel(regex.replace('cat-dog dog-cat cat-dog dog-cat', '(cat)-(dog)', r'\2-\1', 2) == 'dog-cat dog-cat dog-cat dog-cat')cel"}, {EvaluationType::kBoolTrue, R"(regex.replace('a.b.c', r'\.', '-', 1) == 'a-b.c')"}, {EvaluationType::kBoolTrue, R"(regex.replace('a.b.c', r'\.', '-', -1) == 'a-b-c')"}, {EvaluationType::kBoolTrue, R"(regex.replace('123456789ABC', '(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\w)(\\w)(\\w)','X', 1) == 'X')"}, {EvaluationType::kBoolTrue, R"(regex.replace('123456789ABC', '(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\d)(\\w)(\\w)(\\w)', r'\1-\9-X', 1) == '1-9-X')"}, // Static Errors {EvaluationType::kUnknownStaticError, "regex.replace('abc', '^', 1)", "No matching overloads found : regex.replace(string, string, int64)"}, {EvaluationType::kUnknownStaticError, "regex.replace('abc', '^', '1','')", "No matching overloads found : regex.replace(string, string, string, " "string)"}, {EvaluationType::kUnknownStaticError, "regex.extract('foo bar', 1)", "No matching overloads found : regex.extract(string, int64)"}, {EvaluationType::kInvalidArgStaticError, "regex.extract('foo bar', 1, 'bar')", "No overload found in reference resolve step for extract"}, {EvaluationType::kInvalidArgStaticError, "regex.extractAll()", "No overload found in reference resolve step for extractAll"}, // Runtime Errors {EvaluationType::kRuntimeError, R"(regex.extract('foo', 'fo(o+)(abc'))", "invalid regular expression: missing ): fo(o+)(abc"}, {EvaluationType::kRuntimeError, R"(regex.extractAll('foo bar', '[a-z'))", "invalid regular expression: missing ]: [a-z"}, {EvaluationType::kRuntimeError, R"(regex.replace('foo bar', '[a-z', 'a'))", "invalid regular expression: missing ]: [a-z"}, {EvaluationType::kRuntimeError, R"(regex.replace('foo bar', '[a-z', 'a', 1))", "invalid regular expression: missing ]: [a-z"}, {EvaluationType::kRuntimeError, R"(regex.replace('id=123', 
r'id=(?P\d+)', r'value: \values'))",
       R"(invalid replacement string: Rewrite schema error: '\' must be followed by a digit or '\'.)"},
      {EvaluationType::kRuntimeError, R"(regex.replace('test', '(t)', '\\2'))",
       "invalid replacement string: Rewrite schema requests 2 matches, but "
       "the regexp only has 1 parenthesized subexpressions"},
      {EvaluationType::kRuntimeError,
       R"(regex.replace('id=123', r'id=(?P\d+)', '\\', 1))",
       R"(invalid replacement string: Rewrite schema error: '\' not allowed at end.)"},
      // extract / extractAll reject patterns with more than one capturing
      // group.
      {EvaluationType::kRuntimeError,
       R"(regex.extract('phone: 415-5551212', r'phone: ((\d{3})-)?'))",
       R"(regular expression has more than one capturing group: phone: ((\d{3})-)?)"},
      {EvaluationType::kRuntimeError,
       R"(regex.extractAll('testuser@testdomain', '(.*)@([^.]*)'))",
       R"(regular expression has more than one capturing group: (.*)@([^.]*))"},
  };
}

// Evaluates one RegexExtTestCase and checks the result against the outcome
// selected by its `evaluation_type`:
//   kRuntimeError          - evaluation succeeds and yields an ErrorValue whose
//                            status code is kInvalidArgument;
//   kUnknownStaticError    - evaluation succeeds and yields an ErrorValue whose
//                            status code is kUnknown;
//   kInvalidArgStaticError - the evaluation result itself is a non-OK status
//                            with code kInvalidArgument;
//   kOptionalNone          - the value is an empty optional;
//   kOptionalValue         - the value is an optional holding the string
//                            `expected_result`;
//   kBoolTrue              - the value is the bool `true`.
// For the three error kinds the message must contain `expected_result`.
TEST_P(RegexExtTest, RegexExtTests) {
  const RegexExtTestCase& test_case = GetParam();
  auto result = TestEvaluate(test_case.expr);
  switch (test_case.evaluation_type) {
    case EvaluationType::kRuntimeError:
      EXPECT_THAT(result, IsOkAndHolds(ErrorValueIs(
                              StatusIs(absl::StatusCode::kInvalidArgument,
                                       HasSubstr(test_case.expected_result)))))
          << "Expression: " << test_case.expr;
      break;
    case EvaluationType::kUnknownStaticError:
      EXPECT_THAT(result, IsOkAndHolds(ErrorValueIs(
                              StatusIs(absl::StatusCode::kUnknown,
                                       HasSubstr(test_case.expected_result)))))
          << "Expression: " << test_case.expr;
      break;
    case EvaluationType::kInvalidArgStaticError:
      EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument,
                                   HasSubstr(test_case.expected_result)))
          << "Expression: " << test_case.expr;
      break;
    case EvaluationType::kOptionalNone:
      EXPECT_THAT(result, IsOkAndHolds(OptionalValueIsEmpty()))
          << "Expression: " << test_case.expr;
      break;
    case EvaluationType::kOptionalValue:
      EXPECT_THAT(result, IsOkAndHolds(OptionalValueIs(
                              StringValueIs(test_case.expected_result))))
          << "Expression: " << test_case.expr;
      break;
    case EvaluationType::kBoolTrue:
      EXPECT_THAT(result, IsOkAndHolds(BoolValueIs(true)))
          << "Expression: " << test_case.expr;
      break;
  }
}

INSTANTIATE_TEST_SUITE_P(RegexExtTest, RegexExtTest,
                         ValuesIn(regexTestCases()));

// A compile-time test case. `expr_string` is compiled; an empty
// `error_substr` means the expression must type-check cleanly, otherwise the
// compilation must be invalid and its formatted error must contain
// `error_substr`.
struct RegexCheckerTestCase {
  std::string expr_string;
  std::string error_substr;
};

// Fixture that builds a compiler with the standard checker library plus the
// regex extension compiler library.
class RegexExtCheckerLibraryTest : public TestWithParam {
 public:
  void SetUp() override {
    // Arrange: Configure the compiler.
    // Add the regex checker library to the compiler builder.
    ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler_builder,
                         NewCompilerBuilder(descriptor_pool_));
    ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk());
    ASSERT_THAT(compiler_builder->AddLibrary(RegexExtCompilerLibrary()),
                IsOk());
    ASSERT_OK_AND_ASSIGN(compiler_, std::move(*compiler_builder).Build());
  }

  const google::protobuf::DescriptorPool* descriptor_pool_ =
      internal::GetTestingDescriptorPool();
  std::unique_ptr compiler_;
};

// Type-checks each expression; validity must match whether an error
// substring was supplied.
TEST_P(RegexExtCheckerLibraryTest, RegexExtTypeCheckerTests) {
  // Act & Assert: Compile the expression and validate the result.
  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       compiler_->Compile(GetParam().expr_string));
  absl::string_view error_substr = GetParam().error_substr;
  EXPECT_EQ(result.IsValid(), error_substr.empty());

  if (!error_substr.empty()) {
    EXPECT_THAT(result.FormatError(), HasSubstr(error_substr));
  }
}

// Positive cases carry only an expression; negative cases compare a regex
// function result against a mismatched type and expect an overload error.
std::vector createRegexCheckerParams() {
  return {
      {"regex.replace('abc', 'a', 's') == 'sbc'"},
      {"regex.replace('abc', 'a', 's') == 121",
       "found no matching overload for '_==_' applied to '(string, int)"},
      {"regex.replace('abc', 'j', '1', 2) == 9.0",
       "found no matching overload for '_==_' applied to '(string, double)"},
      {"regex.extractAll('banananana', '(ana)') == ['ana', 'ana']"},
      {"regex.extract('foo bar', 'f') == 121",
       "found no matching overload for '_==_' applied to "
       "'(optional_type(string), int)'"},
  };
}

INSTANTIATE_TEST_SUITE_P(RegexExtCheckerLibraryTest, RegexExtCheckerLibraryTest,
                         ValuesIn(createRegexCheckerParams()));

// Builds a compiler with the standard checker library and the regex extension
// compiler library, propagating any builder error.
absl::StatusOr> CreateRegexExtCompiler() {
  CEL_ASSIGN_OR_RETURN(
      auto builder, NewCompilerBuilder(internal::GetTestingDescriptorPool()));
  CEL_RETURN_IF_ERROR(builder->AddLibrary(StandardCheckerLibrary()));
  CEL_RETURN_IF_ERROR(builder->AddLibrary(RegexExtCompilerLibrary()));
  return std::move(*builder).Build();
}

// Runs RegexExtValidator over compiled expressions. Parameters carry an
// `expr_string` and an `error_substr` (empty when the expression is valid).
class RegexExtValidatorTest : public TestWithParam {};

TEST_P(RegexExtValidatorTest, Basic) {
  ASSERT_OK_AND_ASSIGN(auto compiler, CreateRegexExtCompiler());
  Validator validator;
  validator.AddValidation(RegexExtValidator());

  ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(GetParam().expr_string));
  // The validator may append issues to `result` (e.g. for a literal pattern
  // that does not parse), which flips IsValid().
  validator.UpdateValidationResult(result);

  EXPECT_EQ(result.IsValid(), GetParam().error_substr.empty())
      << "Expression: " << GetParam().expr_string;
  if (!GetParam().error_substr.empty()) {
    EXPECT_THAT(result.FormatError(), HasSubstr(GetParam().error_substr));
  }
}

INSTANTIATE_TEST_SUITE_P(
    RegexExtValidatorTest, RegexExtValidatorTest,
    testing::ValuesIn(std::vector{
        {"regex.extract('hello world', 'hello (.*)')"},
        {"regex.extract('hello world', 'hello ([') ",
         "invalid regular expression"},
        {"regex.extractAll('hello world', 'hello (.*)')"},
        {"regex.extractAll('hello world', 'hello ([') ",
         "invalid regular expression"},
        {"regex.replace('hello world', 'hello', 'hi')"},
        {"regex.replace('hello world', 'he([', 'hi') ",
         "invalid regular expression"},
    }));

}  // namespace
}  // namespace cel::extensions
================================================
FILE: extensions/regex_functions.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/regex_functions.h"

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/no_destructor.h"
#include "absl/base/nullability.h"
#include "absl/functional/bind_front.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "checker/internal/builtins_arena.h"
#include "checker/type_checker_builder.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/value.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"
#include "internal/re2_options.h"
#include "internal/status_macros.h"
#include "runtime/function_adapter.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"
#include "re2/re2.h"

namespace cel::extensions {
namespace {

using ::cel::checker_internal::BuiltinsArena;
using ::google::api::expr::runtime::CelFunctionRegistry;
using ::google::api::expr::runtime::InterpreterOptions;

// Implements re.extract(target, regex, rewrite).
//
// Matches `regex` against `target` and substitutes the matched groups into
// `rewrite` using RE2 rewrite syntax (`\0` is the whole match, `\1`..`\9` are
// the capturing groups). Returns the rewritten string.
//
// Returns an ErrorValue when the pattern is rejected by
// cel::internal::CheckRE2 (an invalid pattern, or presumably one whose
// program exceeds `regex_max_program_size`), or when RE2::Extract fails
// (no match, or a rewrite referencing a group the pattern does not have).
Value ExtractString(int regex_max_program_size, const StringValue& target,
                    const StringValue& regex, const StringValue& rewrite,
                    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                    google::protobuf::MessageFactory* absl_nonnull message_factory,
                    google::protobuf::Arena* absl_nonnull arena) {
  std::string regex_scratch;
  std::string target_scratch;
  std::string rewrite_scratch;
  absl::string_view regex_view = regex.ToStringView(&regex_scratch);
  absl::string_view target_view = target.ToStringView(&target_scratch);
  absl::string_view rewrite_view = rewrite.ToStringView(&rewrite_scratch);

  RE2 re2(regex_view, cel::internal::MakeRE2Options());
  CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size))
      .With(ErrorValueReturn());

  std::string output;
  if (!RE2::Extract(target_view, re2, rewrite_view, &output)) {
    return ErrorValue(absl::InvalidArgumentError(
        "Unable to extract string for the given regex"));
  }
  return StringValue::From(std::move(output), arena);
}

// Implements re.capture(target, regex).
//
// Requires `regex` to fully match `target` and returns the text captured by
// the first capturing group (named or unnamed).
// NOTE: For capturing all the groups, use CaptureStringN instead.
//
// Returns an ErrorValue when the pattern is rejected by CheckRE2, has no
// capturing group, or does not fully match `target`.
Value CaptureString(int regex_max_program_size, const StringValue& target,
                    const StringValue& regex,
                    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
                    google::protobuf::MessageFactory* absl_nonnull message_factory,
                    google::protobuf::Arena* absl_nonnull arena) {
  std::string regex_scratch;
  std::string target_scratch;
  absl::string_view regex_view = regex.ToStringView(&regex_scratch);
  absl::string_view target_view = target.ToStringView(&target_scratch);

  RE2 re2(regex_view, cel::internal::MakeRE2Options());
  CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size))
      .With(ErrorValueReturn());

  std::string output;
  if (!RE2::FullMatch(target_view, re2, &output)) {
    return ErrorValue(absl::InvalidArgumentError(
        "Unable to capture groups for the given regex"));
  }
  return StringValue::From(std::move(output), arena);
}

// Implements re.captureN(target, regex).
//
// Runs RE2::FullMatchN of `regex` against `target` and returns a
// map(string, string) with one entry per capturing group:
//   a. named group:   the group name maps to the captured text;
//   b. unnamed group: the 1-based group index, as a decimal string, maps to
//      the captured text.
//
// Every failure (rejected pattern, pattern without capturing groups, no full
// match, map construction failure) is reported as an ErrorValue.
absl::StatusOr<Value> CaptureStringN(
    int regex_max_program_size, const StringValue& target,
    const StringValue& regex,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  std::string target_scratch;
  std::string regex_scratch;
  absl::string_view target_view = target.ToStringView(&target_scratch);
  absl::string_view regex_view = regex.ToStringView(&regex_scratch);

  RE2 re2(regex_view, cel::internal::MakeRE2Options());
  CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, regex_max_program_size))
      .With(ErrorValueReturn());

  const int capturing_groups_count = re2.NumberOfCapturingGroups();
  // Maps 1-based group index to name; only named groups have entries.
  const auto& named_capturing_groups_map = re2.CapturingGroupNames();
  if (capturing_groups_count <= 0) {
    return ErrorValue(absl::InvalidArgumentError(
        "Capturing groups were not found in the given regex."));
  }

  // RE2::FullMatchN takes an array of pointers to RE2::Arg, each of which
  // wraps the output string for one group.
  std::vector<std::string> captured_strings(capturing_groups_count);
  std::vector<RE2::Arg> captured_string_addresses(capturing_groups_count);
  std::vector<RE2::Arg*> argv(capturing_groups_count);
  for (int j = 0; j < capturing_groups_count; j++) {
    captured_string_addresses[j] = &captured_strings[j];
    argv[j] = &captured_string_addresses[j];
  }
  if (!RE2::FullMatchN(target_view, re2, argv.data(),
                       capturing_groups_count)) {
    return ErrorValue(absl::InvalidArgumentError(
        "Unable to capture groups for the given regex"));
  }

  auto builder = cel::NewMapValueBuilder(arena);
  builder->Reserve(capturing_groups_count);
  for (int index = 1; index <= capturing_groups_count; index++) {
    auto it = named_capturing_groups_map.find(index);
    std::string name = it != named_capturing_groups_map.end()
                           ? it->second
                           : std::to_string(index);
    // A group whose name equals another group's numeric key (e.g. a group
    // named "1" alongside unnamed group 1, if RE2 accepts such a name) makes
    // Put fail with a duplicate key. Surface that as a CEL error value, like
    // every other failure above, rather than aborting evaluation.
    CEL_RETURN_IF_ERROR(builder->Put(
        StringValue::From(std::move(name), arena),
        StringValue::From(std::move(captured_strings[index - 1]), arena)))
        .With(ErrorValueReturn());
  }
  return std::move(*builder).Build();
}

// Registers the re.extract, re.capture and re.captureN overloads. Each
// implementation is bound to `max_regex_program_size` via absl::bind_front.
absl::Status RegisterRegexFunctions(FunctionRegistry& registry,
                                    int max_regex_program_size) {
  // Register Regex Extract Function
  CEL_RETURN_IF_ERROR(
      (TernaryFunctionAdapter<
          absl::StatusOr<Value>, StringValue, StringValue,
          StringValue>::RegisterGlobalOverload(kRegexExtract,
                                               absl::bind_front(
                                                   &ExtractString,
                                                   max_regex_program_size),
                                               registry)));

  // Register Regex Captures Function
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<absl::StatusOr<Value>, StringValue, StringValue>::
           RegisterGlobalOverload(
               kRegexCapture,
               absl::bind_front(&CaptureString, max_regex_program_size),
               registry)));

  // Register Regex CaptureN Function
  CEL_RETURN_IF_ERROR(
      (BinaryFunctionAdapter<absl::StatusOr<Value>, StringValue, StringValue>::
           RegisterGlobalOverload(
               kRegexCaptureN,
               absl::bind_front(&CaptureStringN, max_regex_program_size),
               registry)));
  return absl::OkStatus();
}

// The map(string, string) result type of re.captureN, allocated once on the
// builtins arena and never destroyed.
const Type& CaptureNMapType() {
  static absl::NoDestructor<Type> kInstance(
      MapType(BuiltinsArena(), StringType(), StringType()));
  return *kInstance;
}

// Adds type-checker declarations for re.extract, re.capture and re.captureN.
absl::Status RegisterRegexDecls(TypeCheckerBuilder& builder) {
  CEL_ASSIGN_OR_RETURN(
      FunctionDecl regex_extract_decl,
      MakeFunctionDecl(
          std::string(kRegexExtract),
          MakeOverloadDecl("re_extract_string_string_string", StringType(),
                           StringType(), StringType(), StringType())));
  CEL_RETURN_IF_ERROR(builder.AddFunction(regex_extract_decl));

  CEL_ASSIGN_OR_RETURN(
      FunctionDecl regex_capture_decl,
      MakeFunctionDecl(
          std::string(kRegexCapture),
          MakeOverloadDecl("re_capture_string_string", StringType(),
                           StringType(), StringType())));
  CEL_RETURN_IF_ERROR(builder.AddFunction(regex_capture_decl));

  CEL_ASSIGN_OR_RETURN(
      FunctionDecl regex_capture_n_decl,
      MakeFunctionDecl(
          std::string(kRegexCaptureN),
          MakeOverloadDecl("re_captureN_string_string", CaptureNMapType(),
                           StringType(), StringType())));
  return builder.AddFunction(regex_capture_n_decl);
}

}  // namespace

// Registers the RE2 functions only when `options.enable_regex` is set;
// otherwise this is a no-op returning OK.
absl::Status RegisterRegexFunctions(FunctionRegistry& registry,
                                    const RuntimeOptions& options) {
  if (options.enable_regex) {
    CEL_RETURN_IF_ERROR(
        RegisterRegexFunctions(registry, options.regex_max_program_size));
  }
  return absl::OkStatus();
}

// Legacy entry point: converts the InterpreterOptions and delegates to the
// FunctionRegistry overload.
absl::Status RegisterRegexFunctions(CelFunctionRegistry* registry,
                                    const InterpreterOptions& options) {
  return RegisterRegexFunctions(
      registry->InternalGetRegistry(),
      google::api::expr::runtime::ConvertToRuntimeOptions(options));
}

CheckerLibrary RegexCheckerLibrary() {
  return {.id = "cpp_regex", .configure = RegisterRegexDecls};
}

}  // namespace cel::extensions
================================================
FILE: extensions/regex_functions.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Definitions for extension functions wrapping C++ RE2 APIs. These are
// only defined for the C++ CEL library and distinct from the regex
// extension library (supported by other implementations).
#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_FUNCTIONS_H_
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_FUNCTIONS_H_

#include "absl/base/attributes.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "checker/type_checker_builder.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"

namespace cel::extensions {

// CEL function names used both for runtime registration and for the
// type-checker declarations of the RE2 helpers:
//   re.extract(string target, string regex, string rewrite) -> string
//   re.capture(string target, string regex) -> string
//   re.captureN(string target, string regex) -> map(string, string)
inline constexpr absl::string_view kRegexExtract = "re.extract";
inline constexpr absl::string_view kRegexCapture = "re.capture";
inline constexpr absl::string_view kRegexCaptureN = "re.captureN";

// Register Extract and Capture Functions for RE2.
// Functions are only registered when options.enable_regex is true; otherwise
// the call is a no-op that returns OK. The InterpreterOptions overload
// converts its options with ConvertToRuntimeOptions and delegates to the
// FunctionRegistry overload.
// The canonical regex extensions supported by the CEL team are registered
// via the `RegisterRegexExtensionsFunctions`. This extension is deprecated.
ABSL_DEPRECATED("Use RegisterRegexExtensionsFunctions instead.")
absl::Status RegisterRegexFunctions(
    google::api::expr::runtime::CelFunctionRegistry* registry,
    const google::api::expr::runtime::InterpreterOptions& options);

absl::Status RegisterRegexFunctions(FunctionRegistry& registry,
                                    const RuntimeOptions& options);

// Type-checker declarations (checker library id "cpp_regex") for the RE2
// helpers re.extract, re.capture and re.captureN. This is not the canonical
// cross-implementation regex extension library.
CheckerLibrary RegexCheckerLibrary();

}  // namespace cel::extensions

#endif  // THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_FUNCTIONS_H_
================================================
FILE: extensions/regex_functions_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "extensions/regex_functions.h"

#include
#include
#include
#include

#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "checker/standard_library.h"
#include "checker/validation_result.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "extensions/protobuf/runtime_adapter.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "parser/parser.h"
#include "runtime/activation.h"
#include "runtime/reference_resolver.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/extension_set.h"
#include "google/protobuf/message.h"

namespace cel::extensions {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::cel::test::ErrorValueIs;
using ::cel::test::MapValueElements;
using ::cel::test::MapValueIs;
using ::cel::test::StringValueIs;
using ::google::api::expr::parser::Parse;
using ::testing::HasSubstr;
using ::testing::UnorderedElementsAre;
using ::testing::ValuesIn;

// An expression expected to evaluate to an InvalidArgument ErrorValue whose
// message contains `expected_result`.
struct TestCase {
  const std::string expr_string;
  const std::string expected_result;
};

// Fixture that builds a standard runtime with regex support enabled, the
// reference resolver always on, and the RE2 functions registered.
class RegexFunctionsTest : public ::testing::TestWithParam {
 public:
  void SetUp() override {
    RuntimeOptions options;
    options.enable_regex = true;
    // Lets dotted names such as `re.extract` resolve to the registered
    // functions.
    options.enable_qualified_type_identifiers = true;

    ASSERT_OK_AND_ASSIGN(
        RuntimeBuilder builder,
        CreateStandardRuntimeBuilder(descriptor_pool_, options));
    ASSERT_THAT(
        EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways),
        IsOk());
    ASSERT_THAT(RegisterRegexFunctions(builder.function_registry(), options),
                IsOk());

    ASSERT_OK_AND_ASSIGN(runtime_, std::move(builder).Build());
  }

  // Parses `expr_string`, plans it against `runtime_`, and evaluates it with
  // an empty activation on `arena_`.
  absl::StatusOr TestEvaluate(const std::string& expr_string) {
    CEL_ASSIGN_OR_RETURN(auto parsed_expr, Parse(expr_string));
    CEL_ASSIGN_OR_RETURN(std::unique_ptr program,
                         cel::extensions::ProtobufRuntimeAdapter::CreateProgram(
                             *runtime_, parsed_expr));
    Activation activation;
    return program->Evaluate(&arena_, activation);
  }

  const google::protobuf::DescriptorPool* descriptor_pool_ =
      internal::GetTestingDescriptorPool();
  google::protobuf::MessageFactory* message_factory_ =
      google::protobuf::MessageFactory::generated_factory();
  google::protobuf::Arena arena_;
  std::unique_ptr runtime_;
};

TEST_F(RegexFunctionsTest, CaptureStringSuccessWithCombinationOfGroups) {
  // combination of named and unnamed groups should return a celmap
  EXPECT_THAT(
      TestEvaluate((R"cel( re.captureN( 'The user testuser belongs to testdomain', 'The (user|domain) (?P.*) belongs to (?P.*)' ) )cel")),
      IsOkAndHolds(MapValueIs(MapValueElements(
          UnorderedElementsAre(
              Pair(StringValueIs("1"), StringValueIs("user")),
              Pair(StringValueIs("Username"), StringValueIs("testuser")),
              Pair(StringValueIs("Domain"), StringValueIs("testdomain"))),
          descriptor_pool_, message_factory_, &arena_))));
}

TEST_F(RegexFunctionsTest, CaptureStringSuccessWithSingleNamedGroup) {
  // Regex containing one named group should return a map
  EXPECT_THAT(
      TestEvaluate(R"cel(re.captureN('testuser@', '(?P.*)@'))cel"),
      IsOkAndHolds(MapValueIs(MapValueElements(
          UnorderedElementsAre(
              Pair(StringValueIs("username"), StringValueIs("testuser"))),
          descriptor_pool_, message_factory_, &arena_))));
}

TEST_F(RegexFunctionsTest, CaptureStringSuccessWithMultipleUnamedGroups) {
  // Regex containing all unnamed groups should return a map keyed by the
  // 1-based group index.
  EXPECT_THAT(
      TestEvaluate(
          R"cel(re.captureN('testuser@testdomain', '(.*)@([^.]*)'))cel"),
      IsOkAndHolds(MapValueIs(MapValueElements(
          UnorderedElementsAre(
              Pair(StringValueIs("1"), StringValueIs("testuser")),
              Pair(StringValueIs("2"), StringValueIs("testdomain"))),
          descriptor_pool_, message_factory_, &arena_))));
}

// Extract String: Extract named and unnamed strings
TEST_F(RegexFunctionsTest, ExtractStringWithNamedAndUnnamedGroups) {
  EXPECT_THAT(TestEvaluate(R"cel( re.extract( 'The user testuser belongs to testdomain', 'The (user|domain) (?P.*) belongs to (?P.*)', '\\3 contains \\1 \\2') )cel"),
              IsOkAndHolds(StringValueIs("testdomain contains user testuser")));
}

// Extract String: Extract with empty strings
TEST_F(RegexFunctionsTest, ExtractStringWithEmptyStrings) {
  EXPECT_THAT(TestEvaluate(R"cel(re.extract('', '', ''))cel"),
              IsOkAndHolds(StringValueIs("")));
}

// Extract String: Extract unnamed strings
TEST_F(RegexFunctionsTest, ExtractStringWithUnnamedGroups) {
  EXPECT_THAT(TestEvaluate(R"cel( re.extract('testuser@google.com', '(.*)@([^.]*)', '\\2!\\1') )cel"),
              IsOkAndHolds(StringValueIs("google!testuser")));
}

// Extract String: Extract string with no captured groups (`\0` refers to the
// whole match).
TEST_F(RegexFunctionsTest, ExtractStringWithNoGroups) {
  EXPECT_THAT(TestEvaluate(R"cel(re.extract('foo', '.*', '\'\\0\''))cel"),
              IsOkAndHolds(StringValueIs("'foo'")));
}

// Capture String: Success with matching unnamed group
TEST_F(RegexFunctionsTest, CaptureStringWithUnnamedGroups) {
  EXPECT_THAT(TestEvaluate(R"cel(re.capture('foo', 'fo(o)'))cel"),
              IsOkAndHolds(StringValueIs("o")));
}

// Failure cases: each expression must evaluate to an InvalidArgument
// ErrorValue containing the given message.
std::vector createParams() {
  return {
      {// Extract String: Fails for mismatched regex
       (R"(re.extract('foo', 'f(o+)(s)', '\\1\\2'))"),
       "Unable to extract string for the given regex"},
      {// Extract String: Fails when rewritten string has too many placeholders
       (R"(re.extract('foo', 'f(o+)', '\\1\\2'))"),
       "Unable to extract string for the given regex"},
      {// Extract String: Fails when invalid regular expression
       (R"(re.extract('foo', 'f(o+)(abc', '\\1\\2'))"),
       "invalid regular expression"},
      {// Capture String: Empty regex
       (R"(re.capture('foo', ''))"),
       "Unable to capture groups for the given regex"},
      {// Capture String: No Capturing groups
       (R"(re.capture('foo', '.*'))"),
       "Unable to capture groups for the given regex"},
      {// Capture String: Mismatched String
       (R"(re.capture('', 'bar'))"),
       "Unable to capture groups for the given regex"},
      {// Capture String: Mismatched groups
       (R"(re.capture('foo', 'fo(o+)(s)'))"),
       "Unable to capture groups for the given regex"},
      {// Capture String: invalid regular expression
       (R"(re.capture('foo', 'fo(o+)(abc'))"),
       "invalid regular expression"},
      {// Capture String N: Empty regex
       (R"(re.captureN('foo', ''))"),
       "Capturing groups were not found in the given regex."},
      {// Capture String N: No Capturing groups
       (R"(re.captureN('foo', '.*'))"),
       "Capturing groups were not found in the given regex."},
      {// Capture String N: Mismatched String. Note: 'bar' has no capturing
       // groups, so this hits the no-groups check before any match attempt.
       (R"(re.captureN('', 'bar'))"),
       "Capturing groups were not found in the given regex."},
      {// Capture String N: Mismatched groups
       (R"(re.captureN('foo', 'fo(o+)(s)'))"),
       "Unable to capture groups for the given regex"},
      {// Capture String N: invalid regular expression
       (R"(re.captureN('foo', 'fo(o+)(abc'))"),
       "invalid regular expression"},
  };
}

TEST_P(RegexFunctionsTest, RegexFunctionsTests) {
  const TestCase& test_case = GetParam();
  ABSL_LOG(INFO) << "Testing Cel Expression: " << test_case.expr_string;

  EXPECT_THAT(TestEvaluate(test_case.expr_string),
              IsOkAndHolds(ErrorValueIs(
                  StatusIs(absl::StatusCode::kInvalidArgument,
                           HasSubstr(test_case.expected_result)))));
}

INSTANTIATE_TEST_SUITE_P(RegexFunctionsTest, RegexFunctionsTest,
                         ValuesIn(createParams()));

// A type-checking case: `is_valid` is the expected ValidationResult::IsValid().
struct RegexCheckerTestCase {
  const std::string expr_string;
  bool is_valid;
};

// Fixture that builds a compiler with the standard checker library plus
// RegexCheckerLibrary().
class RegexCheckerLibraryTest : public ::testing::TestWithParam {
 public:
  void SetUp() override {
    // Arrange: Configure the compiler.
    // Add the regex checker library to the compiler builder.
    ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler_builder,
                         NewCompilerBuilder(descriptor_pool_));
    ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk());
    ASSERT_THAT(compiler_builder->AddLibrary(RegexCheckerLibrary()), IsOk());
    ASSERT_OK_AND_ASSIGN(compiler_, std::move(*compiler_builder).Build());
  }

  const google::protobuf::DescriptorPool* descriptor_pool_ =
      internal::GetTestingDescriptorPool();
  std::unique_ptr compiler_;
};

TEST_P(RegexCheckerLibraryTest, RegexFunctionsTypeCheckerSuccess) {
  // Act & Assert: Compile the expression and validate the result.
  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       compiler_->Compile(GetParam().expr_string));
  EXPECT_EQ(result.IsValid(), GetParam().is_valid);
}

// Returns a vector of test cases for the RegexCheckerLibraryTest.
// Returns both positive and negative test cases for the regex functions:
// negative cases pass a wrong argument type or compare the result against a
// value of the wrong type.
std::vector createRegexCheckerParams() {
  return {
      {R"(re.extract('testuser@google.com', '(.*)@([^.]*)', '\\2!\\1') == 'google!testuser')",
       true},
      {R"(re.extract(1, '(.*)@([^.]*)', '\\2!\\1') == 'google!testuser')",
       false},
      {R"(re.extract('testuser@google.com', ['1', '2'], '\\2!\\1') == 'google!testuser')",
       false},
      {R"(re.extract('testuser@google.com', '(.*)@([^.]*)', false) == 'google!testuser')",
       false},
      {R"(re.extract('testuser@google.com', '(.*)@([^.]*)', '\\2!\\1') == 2.2)",
       false},
      {R"(re.captureN('testuser@', '(?P.*)@') == {'username': 'testuser'})",
       true},
      {R"(re.captureN(['foo', 'bar'], '(?P.*)@') == {'username': 'testuser'})",
       false},
      {R"(re.captureN('testuser@', 2) == {'username': 'testuser'})", false},
      {R"(re.captureN('testuser@', '(?P.*)@') == true)", false},
      {R"(re.capture('foo', 'fo(o)') == 'o')", true},
      {R"(re.capture('foo', 2) == 'o')", false},
      {R"(re.capture(true, 'fo(o)') == 'o')", false},
      {R"(re.capture('foo', 'fo(o)') == ['o'])", false},
  };
}

INSTANTIATE_TEST_SUITE_P(RegexCheckerLibraryTest, RegexCheckerLibraryTest,
                         ValuesIn(createRegexCheckerParams()));

}  // namespace
}  // namespace cel::extensions
================================================
FILE: extensions/select_optimization.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "extensions/select_optimization.h"

#include
#include
#include
#include
#include
#include
#include

#include "absl/algorithm/container.h"
#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/functional/overload.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "base/attribute.h"
#include "base/builtins.h"
#include "common/ast.h"
#include "common/ast_rewrite.h"
#include "common/casting.h"
#include "common/constant.h"
#include "common/expr.h"
#include "common/function_descriptor.h"
#include "common/kind.h"
#include "common/native_type.h"
#include "common/type.h"
#include "common/value.h"
#include "eval/compiler/flat_expr_builder.h"
#include "eval/compiler/flat_expr_builder_extensions.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/expression_step_base.h"
#include "internal/casts.h"
#include "internal/number.h"
#include "internal/status_macros.h"
#include "runtime/internal/errors.h"
#include "runtime/internal/runtime_friend_access.h"
#include "runtime/internal/runtime_impl.h"
#include "runtime/runtime_builder.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::extensions {
namespace {

using ::cel::Ast;
using ::cel::AstRewriterBase;
using ::cel::CallExpr;
using ::cel::ConstantKind;
using ::cel::Expr;
using ::cel::ExprKind;
using ::cel::SelectExpr;
using ::google::api::expr::runtime::AttributeTrail;
using ::google::api::expr::runtime::DirectExpressionStep;
using ::google::api::expr::runtime::ExecutionFrame;
using ::google::api::expr::runtime::ExecutionFrameBase;
using ::google::api::expr::runtime::ExpressionStepBase;
using ::google::api::expr::runtime::PlannerContext;
using ::google::api::expr::runtime::ProgramOptimizer;

// Represents a single select operation (field access or indexing).
// For struct-typed field accesses, includes the field name and the field
// number.
struct SelectInstruction {
  int64_t number;
  std::string name;
};

// Represents a single qualifier in a traversal path.
// TODO(uncreated-issue/51): support variable indexes.
using QualifierInstruction = absl::variant;

// A chain of qualifiers applied to one operand expression.
struct SelectPath {
  // The expression the qualifiers are applied to (not owned).
  Expr* operand;
  // Qualifiers in application order.
  std::vector select_instructions;
  // NOTE(review): presumably true when the path ends in a has()-style
  // presence test rather than a value read — confirm against the rewriter.
  bool test_only;
  // TODO(uncreated-issue/54): support for optionals.
};

// Generates the AST representation of the qualification path for the optimized
// select branch. I.e., the list-typed second argument of the cel.@attribute
// call.
Expr MakeSelectPathExpr(
    const std::vector<QualifierInstruction>& select_instructions) {
  // Maps one qualifier to the AST node that encodes it: a constant for
  // key/index qualifiers, or a two-element [number, name] list for a struct
  // field selection.
  auto to_expr = absl::Overload(
      [](const SelectInstruction& field) {
        Expr number_expr;
        number_expr.mutable_const_expr().set_int64_value(field.number);
        Expr name_expr;
        name_expr.mutable_const_expr().set_string_value(field.name);

        Expr pair_expr;
        auto& pair = pair_expr.mutable_list_expr().mutable_elements();
        pair.emplace_back().set_expr(std::move(number_expr));
        pair.emplace_back().set_expr(std::move(name_expr));
        return pair_expr;
      },
      [](absl::string_view key) {
        Expr key_expr;
        key_expr.mutable_const_expr().set_string_value(key);
        return key_expr;
      },
      [](int64_t index) {
        Expr index_expr;
        index_expr.mutable_const_expr().set_int64_value(index);
        return index_expr;
      },
      [](uint64_t index) {
        Expr index_expr;
        index_expr.mutable_const_expr().set_uint64_value(index);
        return index_expr;
      },
      [](bool key) {
        Expr key_expr;
        key_expr.mutable_const_expr().set_bool_value(key);
        return key_expr;
      });

  Expr result;
  auto& elements = result.mutable_list_expr().mutable_elements();
  elements.reserve(select_instructions.size());
  for (const QualifierInstruction& qualifier : select_instructions) {
    elements.emplace_back().set_expr(absl::visit(to_expr, qualifier));
  }
  return result;
}

// Returns a single select operation based on the inferred type of the operand
// and the field name. If the operand type doesn't define the field, returns
// nullopt.
absl::optional<SelectInstruction> GetSelectInstruction(
    const StructType& runtime_type, PlannerContext& planner_context,
    absl::string_view field_name) {
  // A failed lookup is deliberately treated like a missing field: the select
  // is then simply left unoptimized rather than failing planning.
  auto field = planner_context.type_reflector()
                   .FindStructTypeFieldByName(runtime_type, field_name)
                   .value_or(absl::nullopt);
  if (!field.has_value()) {
    return absl::nullopt;
  }
  return SelectInstruction{field->number(), std::string(field->name())};
}

// Decodes an optimized field-select element of a cel.@attribute path: a
// two-element list `[int64 field_number, string field_name]` (the shape
// emitted by MakeSelectPathExpr). Any other shape yields InvalidArgument.
absl::StatusOr<SelectQualifier> SelectQualifierFromList(const ListExpr& list) {
  if (list.elements().size() != 2) {
    return absl::InvalidArgumentError("Invalid cel.attribute select list");
  }

  const Expr& field_number = list.elements()[0].expr();
  const Expr& field_name = list.elements()[1].expr();

  if (!field_number.has_const_expr() ||
      !field_number.const_expr().has_int64_value()) {
    return absl::InvalidArgumentError(
        "Invalid cel.attribute field select number");
  }

  if (!field_name.has_const_expr() ||
      !field_name.const_expr().has_string_value()) {
    return absl::InvalidArgumentError(
        "Invalid cel.attribute field select name");
  }

  return FieldSpecifier{field_number.const_expr().int64_value(),
                        field_name.const_expr().string_value()};
}

// Returns a qualifier instruction derived from a constant in the unoptimized
// AST (the constant index operand of an index call).
// NOTE(review): this region of the dump lost text during extraction.
// Template argument lists are missing (e.g. `absl::StatusOr` below has no
// value type). Some string literals may also have been emptied (e.g.
// `CreateNoMatchingOverloadError("")`). Text between the start of
// `ApplyQualifier` and a presence-test branch is gone (see the note there).
// Tokens are reproduced as found; restore from the upstream
// extensions/select_optimization.cc before relying on this text.
//
// A double index is accepted only if it converts losslessly to int (tried
// first) or uint; anything else is InvalidArgument.
absl::StatusOr SelectInstructionFromConstant(
    const Constant& constant) {
  if (constant.has_int_value()) {
    return QualifierInstruction(constant.int_value());
  } else if (constant.has_uint_value()) {
    return QualifierInstruction(constant.uint_value());
  } else if (constant.has_bool_value()) {
    return QualifierInstruction(constant.bool_value());
  } else if (constant.has_string_value()) {
    return QualifierInstruction(constant.string_value());
  } else if (constant.has_double_value()) {
    cel::internal::Number number(constant.double_value());
    if (number.LosslessConvertibleToInt()) {
      return QualifierInstruction(number.AsInt());
    } else if (number.LosslessConvertibleToUint()) {
      return QualifierInstruction(number.AsUint());
    }
  }
  return absl::InvalidArgumentError("invalid index constant for cel.attribute");
}

// Converts a constant element of a cel.@attribute path into an attribute
// qualifier. Only int, uint, bool and string constants are accepted.
absl::StatusOr SelectQualifierFromConstant(
    const Constant& constant) {
  if (constant.has_int_value()) {
    return AttributeQualifier::OfInt(constant.int_value());
  } else if (constant.has_uint_value()) {
    return AttributeQualifier::OfUint(constant.uint_value());
  } else if (constant.has_bool_value()) {
    return AttributeQualifier::OfBool(constant.bool_value());
  } else if (constant.has_string_value()) {
    return AttributeQualifier::OfString(constant.string_value());
  }
  // TODO(uncreated-issue/51): double keys could possibly be valid selectors, but
  // the other stacks don't implement the optimization yet and we normalize the
  // key to a uint or int if we do the late AST rewrite during planning.
  return absl::InvalidArgumentError("invalid cel.attribute constant");
}

// Converts a qualifier into a list index. Only int qualifiers are accepted;
// other kinds return the no-matching-overload status for the index operator,
// and negative values return InvalidArgument.
absl::StatusOr ListIndexFromQualifier(const AttributeQualifier& qual) {
  int64_t value = -1;
  switch (qual.kind()) {
    case Kind::kInt:
      value = *qual.GetInt64Key();
      break;
    default:
      // TODO(uncreated-issue/51): type-checker will reject an unsigned literal,
      // but should be supported as a dyn / variable.
      return runtime_internal::CreateNoMatchingOverloadError(
          cel::builtin::kIndex);
  }
  if (value < 0) {
    return absl::InvalidArgumentError("list index less than 0");
  }
  return static_cast(value);
}

// Converts a qualifier into a CEL map key value (int, uint, bool or string;
// string keys are built with `arena`). Other kinds return the
// no-matching-overload status for the index operator.
absl::StatusOr MapKeyFromQualifier(const AttributeQualifier& qual,
                                   google::protobuf::Arena* absl_nonnull arena) {
  switch (qual.kind()) {
    case Kind::kInt:
      return cel::IntValue(*qual.GetInt64Key());
    case Kind::kUint:
      return cel::UintValue(*qual.GetUint64Key());
    case Kind::kBool:
      return cel::BoolValue(*qual.GetBoolKey());
    case Kind::kString:
      return StringValue::From(*qual.GetStringKey(), arena);
    default:
      return runtime_internal::CreateNoMatchingOverloadError(
          cel::builtin::kIndex);
  }
}

absl::StatusOr ApplyQualifier(
    const Value& operand, const SelectQualifier& qualifier,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  // NOTE(review): text is missing from this point in the dump. The first
  // lambda starts as ApplyQualifier's field-select case but continues with a
  // presence test (`HasFieldByName`). Both lambdas use `elem` and
  // `last_instruction`, which are not declared anywhere in this chunk.
  // This looks like the presence-test branch of the `FallbackSelect` helper
  // that OptimizedSelectImpl::ApplySelect calls; that helper is not defined
  // here either. The braces below do not balance as written.
  return absl::visit(
      absl::Overload(
          [&](const FieldSpecifier& field_specifier) -> absl::StatusOr {
            if (!operand.Is()) {
              return cel::ErrorValue(
                  cel::runtime_internal::CreateNoMatchingOverloadError(
                      ""));
            }
            CEL_ASSIGN_OR_RETURN(
                bool present,
                elem->GetStruct().HasFieldByName(field_specifier.name));
            return cel::BoolValue(present);
          },
          [&](const AttributeQualifier& qualifier) -> absl::StatusOr {
            if (!elem->Is() || qualifier.kind() != Kind::kString) {
              return cel::ErrorValue(
                  cel::runtime_internal::CreateNoMatchingOverloadError(
                      "has"));
            }
            return elem->GetMap().Has(
                StringValue(arena, *qualifier.GetStringKey()),
                descriptor_pool, message_factory, arena);
          }),
      last_instruction);
}
return ApplyQualifier(*elem, last_instruction, descriptor_pool,
                      message_factory, arena);
}

// Decodes the list-typed second argument of a cel.@attribute /
// cel.@hasField call back into select qualifiers:
//   - a nested list is a struct field select;
//   - a constant is a map key or list index.
// Any other element shape yields InvalidArgument.
absl::StatusOr> SelectInstructionsFromCall(
    const CallExpr& call) {
  if (call.args().size() < 2 || !call.args()[1].has_list_expr()) {
    return absl::InvalidArgumentError("Invalid cel.attribute call");
  }
  std::vector instructions;
  const auto& ast_path =
      call.args()[1].list_expr().elements();
  instructions.reserve(ast_path.size());
  for (const ListExprElement& element : ast_path) {
    // Optimized field select.
    if (element.has_expr()) {
      const auto& element_expr = element.expr();
      if (element_expr.has_list_expr()) {
        CEL_ASSIGN_OR_RETURN(instructions.emplace_back(),
                             SelectQualifierFromList(element_expr.list_expr()));
      } else if (element_expr.has_const_expr()) {
        CEL_ASSIGN_OR_RETURN(
            instructions.emplace_back(),
            SelectQualifierFromConstant(element_expr.const_expr()));
      } else {
        return absl::InvalidArgumentError("Invalid cel.attribute call");
      }
    } else {
      return absl::InvalidArgumentError("Invalid cel.attribute call");
    }
  }
  // TODO(uncreated-issue/54): support for optionals.
  return instructions;
}

// AST rewriter that collapses chains of struct field selects, map selects and
// constant index calls into a single cel.@attribute call. When the chain
// contains a has() test it becomes a cel.@hasField call instead. The call's
// second argument encodes the qualifier path (see MakeSelectPathExpr).
class RewriterImpl : public AstRewriterBase {
 public:
  RewriterImpl(const Ast& ast, PlannerContext& planner_context)
      : ast_(ast), planner_context_(planner_context) {}

  // Tracks the ancestors of the node being visited; popped in
  // PostVisitRewrite.
  void PreVisitExpr(const Expr& expr) override { path_.push_back(&expr); }

  // Records a select as a candidate in two cases. If the operand's checked
  // type is a message type that defines the field, it becomes a struct field
  // select. If the operand is checked as a map, the field name becomes a
  // string key.
  void PreVisitSelect(const Expr& expr, const SelectExpr& select) override {
    const Expr& operand = select.operand();
    const std::string& field_name = select.field();

    // For now only message field selects and map selects are recorded; other
    // operand types (e.g. dyn) are left unoptimized.
    const TypeSpec checker_type = ast_.GetTypeOrDyn(operand.id());
    absl::optional rt_type =
        (checker_type.has_message_type())
            ? GetRuntimeType(checker_type.message_type().type())
            : absl::nullopt;
    if (rt_type.has_value() && (*rt_type).Is()) {
      const StructType& runtime_type = rt_type->GetStruct();
      absl::optional field_or =
          GetSelectInstruction(runtime_type, planner_context_, field_name);
      if (field_or.has_value()) {
        candidates_[&expr] = std::move(field_or).value();
      }
    } else if (checker_type.has_map_type()) {
      candidates_[&expr] = QualifierInstruction(field_name);
    }
    // else
    // TODO(uncreated-issue/54): add support for either dyn or any. Excluded to
    // simplify program plan.
  }

  // Records a two-argument index call (cel::builtin::kIndex) as a candidate
  // when its index is a constant accepted by SelectInstructionFromConstant.
  void PreVisitCall(const Expr& expr, const CallExpr& call) override {
    if (call.args().size() != 2 || call.function() != ::cel::builtin::kIndex) {
      return;
    }
    const auto& qualifier_expr = call.args()[1];
    if (qualifier_expr.has_const_expr()) {
      auto qualifier_or =
          SelectInstructionFromConstant(qualifier_expr.const_expr());
      if (!qualifier_or.ok()) {
        // TODO(uncreated-issue/54): should warn, but by default warnings fail
        // overall program planning.
        return;
      }
      candidates_[&expr] = std::move(qualifier_or).value();
    }
    // TODO(uncreated-issue/54): support variable indexes
  }

  // Replaces `expr` with a cel.@attribute / cel.@hasField call when it is the
  // outermost node of a candidate chain. Returns true iff `expr` was replaced.
  bool PostVisitRewrite(Expr& expr) override {
    // NOTE(review): this early return skips path_.pop_back(). SetProgressStatus
    // has no callers in this chunk, so progress_status_ appears to stay OK
    // here; confirm before relying on it.
    if (!progress_status_.ok()) {
      return false;
    }
    path_.pop_back();
    auto candidate_iter = candidates_.find(&expr);
    if (candidate_iter == candidates_.end()) {
      return false;
    }
    // On post visit, filter candidates that don't start or extend a chain
    // (see HasOptimizeableRoot).
    const QualifierInstruction& candidate = candidate_iter->second;
    if (!HasOptimizeableRoot(&expr, candidate)) {
      candidates_.erase(candidate_iter);
      return false;
    }

    if (!path_.empty() && candidates_.find(path_.back()) != candidates_.end()) {
      // parent is optimizeable, defer rewriting until we consider the parent.
      return false;
    }

    SelectPath path = GetSelectPath(&expr);

    // generate the new cel.attribute call.
    absl::string_view fn = path.test_only ? kCelHasField : kCelAttribute;

    // Move the chain root out first: the assignment to `expr` below destroys
    // the old subtree that contains it.
    Expr operand(std::move(*path.operand));
    Expr call;
    call.set_id(expr.id());
    call.mutable_call_expr().set_function(std::string(fn));
    call.mutable_call_expr().mutable_args().reserve(2);
    call.mutable_call_expr().mutable_args().push_back(std::move(operand));
    call.mutable_call_expr().mutable_args().push_back(
        MakeSelectPathExpr(path.select_instructions));

    // TODO(uncreated-issue/54): support for optionals.
    expr = std::move(call);
    return true;
  }

  absl::Status GetProgressStatus() const { return progress_status_; }

 private:
  // Walks from `expr` down through candidate nodes (select operand, or first
  // argument of an index call) and collects their qualifiers. The result is
  // ordered root-first. `operand` is set to the first node that is not a
  // candidate. A has() test anywhere in the chain sets `test_only`.
  SelectPath GetSelectPath(Expr* expr) {
    SelectPath result;
    result.test_only = false;
    Expr* operand = expr;
    auto candidate_iter = candidates_.find(operand);
    while (candidate_iter != candidates_.end()) {
      result.select_instructions.push_back(candidate_iter->second);
      if (operand->has_select_expr()) {
        if (operand->select_expr().test_only()) {
          result.test_only = true;
        }
        operand = &(operand->mutable_select_expr().mutable_operand());
      } else {
        ABSL_DCHECK(operand->has_call_expr());
        operand = &(operand->mutable_call_expr().mutable_args()[0]);
      }
      candidate_iter = candidates_.find(operand);
    }
    // Collected outermost-first; flip to root-first.
    absl::c_reverse(result.select_instructions);
    result.operand = operand;
    return result;
  }

  // Checks whether the candidate can start or extend an optimized chain.
  // A struct field select qualifies on its own. Map-key and index candidates
  // qualify only when their operand is itself a candidate.
  // NOTE(review): the holds_alternative template argument is missing in this
  // dump; presumably it is SelectInstruction.
  // Called on post visit.
  bool HasOptimizeableRoot(const Expr* expr,
                           const QualifierInstruction& candidate) {
    if (absl::holds_alternative(candidate)) {
      return true;
    }
    const Expr* operand = nullptr;
    if (expr->has_call_expr() && expr->call_expr().args().size() == 2 &&
        expr->call_expr().function() == ::cel::builtin::kIndex) {
      operand = &expr->call_expr().args()[0];
    } else if (expr->has_select_expr()) {
      operand = &expr->select_expr().operand();
    }

    if (operand == nullptr) {
      return false;
    }

    return candidates_.find(operand) != candidates_.end();
  }

  // Resolves a checked message type name through the planner's type
  // reflector; lookup errors are treated as "not found".
  absl::optional GetRuntimeType(absl::string_view type_name) {
    return planner_context_.type_reflector().FindType(type_name).value_or(
        absl::nullopt);
  }

  // Keeps the first non-OK status; later errors are ignored.
  void SetProgressStatus(const absl::Status& status) {
    if (progress_status_.ok() && !status.ok()) {
      progress_status_ = status;
    }
  }

  const Ast& ast_;
  PlannerContext& planner_context_;
  // Potentially optimizeable expr nodes, keyed by node address, mapped to the
  // qualifier each node contributes to a chain.
  absl::flat_hash_map candidates_;
  // Ancestors of the node currently being visited (see PreVisitExpr).
  std::vector path_;
  absl::Status progress_status_;
};

// Evaluation logic shared by the stack-machine and recursive plan steps. It
// applies the qualifier path to a struct value and extends attribute trails.
class OptimizedSelectImpl {
 public:
  OptimizedSelectImpl(std::vector select_path,
                      std::vector qualifiers,
                      bool presence_test,
                      SelectOptimizationOptions options)
      : select_path_(std::move(select_path)),
        qualifiers_(std::move(qualifiers)),
        presence_test_(presence_test),
        options_(options) {
    ABSL_DCHECK(!select_path_.empty());
  }

  // Move constructible.
  OptimizedSelectImpl(const OptimizedSelectImpl&) = delete;
  OptimizedSelectImpl& operator=(const OptimizedSelectImpl&) = delete;
  OptimizedSelectImpl(OptimizedSelectImpl&&) = default;
  OptimizedSelectImpl& operator=(OptimizedSelectImpl&&) = delete;

  absl::StatusOr ApplySelect(ExecutionFrameBase& frame,
                             const StructValue& struct_value) const;

  AttributeTrail GetAttributeTrail(const AttributeTrail& operand_trail) const;

  absl::optional attribute() const { return attribute_; }

  const std::vector& qualifiers() const {
    return qualifiers_;
  }

 private:
  // NOTE(review): never assigned in this chunk, so attribute() returns
  // nullopt as far as visible code goes; confirm whether it is still needed.
  absl::optional attribute_;
  std::vector select_path_;
  std::vector qualifiers_;
  bool presence_test_;
  SelectOptimizationOptions options_;
};

// Check for unknowns or missing attributes.
// Returns nullopt when the trail is empty or nothing is marked. Otherwise it
// returns the replacement result built by CreateUnknownSet or
// CreateMissingAttributeError.
absl::StatusOr> CheckForMarkedAttributes(
    ExecutionFrameBase& frame, const AttributeTrail& attribute_trail) {
  if (attribute_trail.empty()) {
    return absl::nullopt;
  }

  if (frame.unknown_processing_enabled() &&
      frame.attribute_utility().CheckForUnknownExact(attribute_trail)) {
    // Check if the inferred attribute is marked. Only matches if this attribute
    // or a parent is marked unknown (use_partial = false).
    // Partial matches (i.e. descendant of this attribute is marked) aren't
    // considered yet in case another operation would select an unmarked
    // descended attribute.
    //
    // TODO(uncreated-issue/51): this may return a more specific attribute than
    // the declared pattern. Follow up will truncate the returned attribute to
    // match the pattern.
    return frame.attribute_utility().CreateUnknownSet(
        attribute_trail.attribute());
  }

  if (frame.missing_attribute_errors_enabled() &&
      frame.attribute_utility().CheckForMissingAttribute(attribute_trail)) {
    return frame.attribute_utility().CreateMissingAttributeError(
        attribute_trail.attribute());
  }

  return absl::nullopt;
}

// Applies the whole select path to `struct_value`, using one of three routes:
//   - StructValue::Qualify normally does the work.
//   - If Qualify is unimplemented, or force_fallback_implementation is set,
//     FallbackSelect handles the entire path.
//   - Otherwise the second element of Qualify's result is an offset. When it
//     lies in [0, size), the qualifiers from that offset onward are applied
//     to the partial result with FallbackSelect; any other value means the
//     result is final.
absl::StatusOr OptimizedSelectImpl::ApplySelect(
    ExecutionFrameBase& frame, const StructValue& struct_value) const {
  auto value_or =
      (options_.force_fallback_implementation)
          ? absl::UnimplementedError("Forced fallback impl")
          : struct_value.Qualify(select_path_, presence_test_,
                                 frame.descriptor_pool(),
                                 frame.message_factory(), frame.arena());

  if (!value_or.ok()) {
    if (value_or.status().code() == absl::StatusCode::kUnimplemented) {
      return FallbackSelect(struct_value, select_path_, presence_test_,
                            frame.descriptor_pool(), frame.message_factory(),
                            frame.arena());
    }

    return value_or.status();
  }

  if (value_or->second < 0 || value_or->second >= select_path_.size()) {
    return std::move(value_or->first);
  }

  return FallbackSelect(
      value_or->first,
      absl::MakeConstSpan(select_path_).subspan(value_or->second),
      presence_test_, frame.descriptor_pool(), frame.message_factory(),
      frame.arena());
}

// Extends the operand's attribute trail with this select's qualifiers. An
// empty operand trail yields an empty trail.
AttributeTrail OptimizedSelectImpl::GetAttributeTrail(
    const AttributeTrail& operand_trail) const {
  if (operand_trail.empty()) {
    return AttributeTrail();
  }

  std::vector qualifiers =
      std::vector(
          operand_trail.attribute().qualifier_path().begin(),
          operand_trail.attribute().qualifier_path().end());
  qualifiers.reserve(qualifiers_.size() + qualifiers.size());
  absl::c_copy(qualifiers_, std::back_inserter(qualifiers));
  return AttributeTrail(
      Attribute(std::string(operand_trail.attribute().variable_name()),
                std::move(qualifiers)));
}

// Stack-machine plan step. It expects the select operand on top of the value
// stack and replaces it with the select result.
class StackMachineImpl : public ExpressionStepBase {
 public:
  StackMachineImpl(int expr_id, OptimizedSelectImpl impl)
      : ExpressionStepBase(expr_id), impl_(std::move(impl)) {}

  absl::Status Evaluate(ExecutionFrame* frame) const override;

 private:
  // Get the effective attribute for the optimized select expression.
  // Assumes the operand is the top of stack if the attribute wasn't known at
  // plan time.
  AttributeTrail GetAttributeTrail(ExecutionFrame* frame) const;

  OptimizedSelectImpl impl_;
};

AttributeTrail StackMachineImpl::GetAttributeTrail(
    ExecutionFrame* frame) const {
  const auto& attr = frame->value_stack().PeekAttribute();
  return impl_.GetAttributeTrail(attr);
}

absl::Status StackMachineImpl::Evaluate(ExecutionFrame* frame) const {
  // Default empty.
  AttributeTrail attribute_trail;
  // TODO(uncreated-issue/51): add support for variable qualifiers and string
  // literal variable names.
  constexpr size_t kStackInputs = 1;

  // For now, we expect the operand to be top of stack.
  // `operand` refers into the stack, so everything that reads it runs before
  // the Pop below.
  const Value& operand = frame->value_stack().Peek();

  if (operand->Is() || operand->Is()) {
    // Just forward the error which is already top of stack.
    return absl::OkStatus();
  }

  if (frame->enable_attribute_tracking()) {
    // Compute the attribute trail then check for any marked values.
    // When possible, this is computed at plan time based on the optimized
    // select arguments.
    // TODO(uncreated-issue/51): add support variable qualifiers
    attribute_trail = GetAttributeTrail(frame);
    CEL_ASSIGN_OR_RETURN(absl::optional value,
                         CheckForMarkedAttributes(*frame, attribute_trail));
    if (value.has_value()) {
      frame->value_stack().Pop(kStackInputs);
      frame->value_stack().Push(std::move(value).value(),
                                std::move(attribute_trail));
      return absl::OkStatus();
    }
  }

  // NOTE(review): this message ends with a period; RecursiveImpl's equivalent
  // does not.
  if (!operand->Is()) {
    return absl::InvalidArgumentError(
        "Expected struct type for select optimization.");
  }

  CEL_ASSIGN_OR_RETURN(Value result,
                       impl_.ApplySelect(*frame, operand.GetStruct()));

  frame->value_stack().Pop(kStackInputs);
  frame->value_stack().Push(std::move(result), std::move(attribute_trail));
  return absl::OkStatus();
}

// Recursive (direct) plan step. It evaluates the operand subexpression, then
// applies the optimized select to its result.
class RecursiveImpl : public DirectExpressionStep {
 public:
  RecursiveImpl(int64_t expr_id, std::unique_ptr operand,
                OptimizedSelectImpl impl)
      : DirectExpressionStep(expr_id),
        operand_(std::move(operand)),
        impl_(std::move(impl)) {}

  absl::Status Evaluate(ExecutionFrameBase& frame, Value& result,
                        AttributeTrail& attribute) const override;

 private:
  // Extends `operand_trail` (the operand's attribute) with the select's
  // qualifiers. Evaluate below calls impl_.GetAttributeTrail directly.
  AttributeTrail GetAttributeTrail(const AttributeTrail& operand_trail) const;
  std::unique_ptr operand_;
  OptimizedSelectImpl impl_;
};

AttributeTrail RecursiveImpl::GetAttributeTrail(
    const AttributeTrail& operand_trail) const {
  return impl_.GetAttributeTrail(operand_trail);
}

absl::Status RecursiveImpl::Evaluate(ExecutionFrameBase& frame, Value& result,
                                     AttributeTrail& attribute) const {
  CEL_RETURN_IF_ERROR(operand_->Evaluate(frame, result, attribute));

  if (InstanceOf(result) || InstanceOf(result)) {
    // Just forward.
    return absl::OkStatus();
  }

  if (frame.attribute_tracking_enabled()) {
    attribute = impl_.GetAttributeTrail(attribute);
    CEL_ASSIGN_OR_RETURN(auto value,
                         CheckForMarkedAttributes(frame, attribute));
    if (value.has_value()) {
      result = std::move(value).value();
      return absl::OkStatus();
    }
  }

  if (!InstanceOf(result)) {
    return absl::InvalidArgumentError(
        "Expected struct type for select optimization");
  }
  CEL_ASSIGN_OR_RETURN(result, impl_.ApplySelect(frame, Cast(result)));
  return absl::OkStatus();
}

// Program optimizer that plans cel.@attribute / cel.@hasField calls (as
// produced by RewriterImpl) as a single optimized select step.
class SelectOptimizer : public ProgramOptimizer {
 public:
  explicit SelectOptimizer(const SelectOptimizationOptions& options)
      : options_(options) {}

  absl::Status OnPreVisit(PlannerContext& context, const Expr& node) override {
    return absl::OkStatus();
  }

  absl::Status OnPostVisit(PlannerContext& context, const Expr& node) override;

 private:
  SelectOptimizationOptions options_;
};

absl::Status SelectOptimizer::OnPostVisit(PlannerContext& context,
                                          const Expr& node) {
  if (!node.has_call_expr()) {
    return absl::OkStatus();
  }

  absl::string_view fn = node.call_expr().function();
  if (fn != kCelHasField && fn != kCelAttribute) {
    return absl::OkStatus();
  }

  if (node.call_expr().args().size() < 2 ||
      node.call_expr().args().size() > 3) {
    return absl::InvalidArgumentError("Invalid cel.attribute call");
  }

  if (node.call_expr().args().size() == 3) {
    return absl::UnimplementedError("Optionals not yet supported");
  }

  CEL_ASSIGN_OR_RETURN(std::vector instructions,
                       SelectInstructionsFromCall(node.call_expr()));

  if (instructions.empty()) {
    return absl::InvalidArgumentError("Invalid cel.attribute no select steps.");
  }

  bool presence_test = false;

  if (fn == kCelHasField) {
    presence_test = true;
  }

  // Only a bare identifier operand is inspected; `identifier` stays empty for
  // any other operand kind.
  const Expr& operand = node.call_expr().args()[0];
  absl::string_view identifier;
  if (operand.has_ident_expr()) {
    identifier = operand.ident_expr().name();
  }

  if (absl::StrContains(identifier, ".")) {
    return absl::UnimplementedError("qualified identifiers not supported.");
  }

  // Attribute-trail qualifiers: struct field selects are tracked by field
  // name as string qualifiers.
  std::vector qualifiers;
  qualifiers.reserve(instructions.size());
  for (const auto& instruction : instructions) {
    qualifiers.push_back(
        absl::visit(absl::Overload(
                        [](const FieldSpecifier& field) {
                          return AttributeQualifier::OfString(field.name);
                        },
                        [](const AttributeQualifier& q) { return q; }),
                    instruction));
  }

  // TODO(uncreated-issue/51): If the first argument is a string literal, the
  // custom step needs to handle variable lookup.
  auto* subexpression = context.program_builder().GetSubexpression(&node);
  if (subexpression == nullptr || subexpression->IsFlattened()) {
    // No information on the subprogram, can't optimize.
    return absl::OkStatus();
  }

  OptimizedSelectImpl impl(std::move(instructions), std::move(qualifiers),
                           presence_test, options_);

  if (subexpression->IsRecursive()) {
    // The first dependency of the extracted program is reused as the new
    // step's operand (presumably the plan for args()[0] -- confirm).
    // NOTE(review): on the error path below the recursive program has already
    // been extracted and is not put back.
    auto program = subexpression->ExtractRecursiveProgram();
    auto deps = program.step->ExtractDependencies();
    if (!deps.has_value() || deps->empty()) {
      return absl::InvalidArgumentError("Unexpected cel.@attribute call");
    }
    subexpression->set_recursive_program(
        std::make_unique(node.id(), std::move(deps->at(0)),
                         std::move(impl)),
        program.depth);
    return absl::OkStatus();
  }

  google::api::expr::runtime::ExecutionPath path;

  // else, we need to preserve the original plan for the first argument.
  if (context.GetSubplan(operand).empty()) {
    // Indicates another extension modified the step. Nothing to do here.
    return absl::OkStatus();
  }
  CEL_ASSIGN_OR_RETURN(auto operand_subplan, context.ExtractSubplan(operand));
  absl::c_move(operand_subplan, std::back_inserter(path));

  path.push_back(
      std::make_unique(node.id(), std::move(impl)));

  return context.ReplaceSubplan(node, std::move(path));
}

// Returns the FlatExprBuilder behind `builder` when the runtime's type id
// matches the default implementation; otherwise nullptr.
// NOTE(review): the NativeTypeId::For / down_cast template arguments are
// missing in this dump; presumably they name the runtime_internal runtime
// implementation type.
google::api::expr::runtime::FlatExprBuilder* GetFlatExprBuilder(
    RuntimeBuilder& builder) {
  auto& runtime =
      runtime_internal::RuntimeFriendAccess::GetMutableRuntime(builder);
  if (runtime_internal::RuntimeFriendAccess::RuntimeTypeId(runtime) ==
      NativeTypeId::For()) {
    auto& runtime_impl =
        cel::internal::down_cast(runtime);
    return &runtime_impl.expr_builder();
  }
  return nullptr;
}

}  // namespace

// Runs RewriterImpl over the whole AST and returns its progress status.
absl::Status SelectOptimizationAstUpdater::UpdateAst(PlannerContext& context,
                                                     Ast& ast) const {
  RewriterImpl rewriter(ast, context);
  AstRewrite(ast.mutable_root_expr(), rewriter);
  return rewriter.GetProgressStatus();
}

// Each invocation of the returned factory creates a SelectOptimizer holding a
// copy of `options`.
google::api::expr::runtime::ProgramOptimizerFactory
CreateSelectOptimizationProgramOptimizer(
    const SelectOptimizationOptions& options) {
  return [=](PlannerContext& context, const Ast& ast) {
    return std::make_unique(options);
  };
}

absl::Status EnableSelectOptimization(
    cel::RuntimeBuilder& builder, const SelectOptimizationOptions& options) {
  auto* flat_expr_builder = GetFlatExprBuilder(builder);
  if (flat_expr_builder == nullptr) {
    return absl::InvalidArgumentError(
        "SelectOptimization requires default runtime implementation");
  }

  flat_expr_builder->AddAstTransform(
      std::make_unique());

  // Add overloads for select optimization signature.
  // These are never bound, only used to prevent the builder from failing on
  // the overloads check.
  CEL_RETURN_IF_ERROR(builder.function_registry().RegisterLazyFunction(
      FunctionDescriptor(kCelAttribute, false, {Kind::kAny, Kind::kList})));

  CEL_RETURN_IF_ERROR(builder.function_registry().RegisterLazyFunction(
      FunctionDescriptor(kCelHasField, false, {Kind::kAny, Kind::kList})));

  // Add runtime implementation.
  flat_expr_builder->AddProgramOptimizer(
      CreateSelectOptimizationProgramOptimizer(options));

  return absl::OkStatus();
}

}  // namespace cel::extensions
================================================ FILE: extensions/select_optimization.h ================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_SELECT_OPTIMIZATION_H_
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_SELECT_OPTIMIZATION_H_

#include "absl/status/status.h"
#include "common/ast.h"
#include "eval/compiler/flat_expr_builder_extensions.h"
#include "runtime/runtime_builder.h"

namespace cel::extensions {

// Function name used for a collapsed select chain.
constexpr char kCelAttribute[] = "cel.@attribute";
// Function name used for a collapsed select chain ending in a has() test.
constexpr char kCelHasField[] = "cel.@hasField";

// Configuration options for the select optimization.
struct SelectOptimizationOptions {
  // Force the program to use the fallback implementation for the select.
  // This implementation simply collapses the select operation into one program
  // step and calls the normal field accessors on the Struct value.
  //
  // Normally, the fallback implementation is used when the Qualify operation is
  // unimplemented for a given StructType. This option is exposed for testing or
  // to more closely match behavior of unoptimized expressions.
  bool force_fallback_implementation = false;
};

// Enable select optimization on the given RuntimeBuilder, replacing long
// select chains with a single operation.
// // This assumes that the type information at check time agrees with the // configured types at runtime. // // Important: The select optimization follows spec behavior for traversals. // - `enable_empty_wrapper_null_unboxing` is ignored and optimized traversals // always operates as though it is `true`. // - `enable_heterogeneous_equality` is ignored and optimized traversals // always operate as though it is `true`. // // This should only be called *once* on a given runtime builder. // // Assumes the default runtime implementation, an error with code // InvalidArgument is returned if it is not. // // Note: implementation does not support optional field traversal, and will // instead revert to the normal implementation instead of trying to optimize. absl::Status EnableSelectOptimization( cel::RuntimeBuilder& builder, const SelectOptimizationOptions& options = {}); // =============================================================== // Implementation details -- CEL users should not depend on these. // Exposed here for enabling on Legacy APIs. They expose internal details // which are not guaranteed to be stable. // =============================================================== // Scans ast for optimizable select branches. // // In general, this should be done by a type checker but may be deferred to // runtime. // // This assumes the runtime type registry has the same definitions as the one // used by the type checker. 
class SelectOptimizationAstUpdater
    : public google::api::expr::runtime::AstTransform {
 public:
  SelectOptimizationAstUpdater() = default;

  // Rewrites eligible select / constant-index chains in `ast` in place into
  // cel.@attribute and cel.@hasField calls, and returns the rewriter's
  // progress status.
  absl::Status UpdateAst(google::api::expr::runtime::PlannerContext& context,
                         cel::Ast& ast) const override;
};

// Returns a factory for the program optimizer that plans cel.@attribute /
// cel.@hasField calls as single optimized select steps, configured with
// `options`.
google::api::expr::runtime::ProgramOptimizerFactory
CreateSelectOptimizationProgramOptimizer(
    const SelectOptimizationOptions& options = {});

}  // namespace cel::extensions

#endif  // THIRD_PARTY_CEL_CPP_EXTENSIONS_SELECT_OPTIMIZATION_H_
================================================ FILE: extensions/select_optimization_test.cc ================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/select_optimization.h" #include #include #include #include #include #include #include "cel/expr/checked.pb.h" #include "cel/expr/syntax.pb.h" #include "google/protobuf/empty.pb.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "base/ast.h" #include "base/attribute.h" #include "base/builtins.h" #include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" #include "common/decl.h" #include "common/decl_proto.h" #include "common/expr.h" #include "common/kind.h" #include "common/memory.h" #include "common/value.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" #include "compiler/optional.h" #include "compiler/standard_library.h" #include "eval/compiler/flat_expr_builder.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" #include "eval/eval/evaluator_core.h" #include "eval/internal/interop.h" #include "eval/public/cel_type_registry.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/structs/legacy_type_adapter.h" #include "eval/public/structs/legacy_type_info_apis.h" #include "extensions/protobuf/ast_converters.h" #include "internal/number.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/parser.h" #include "runtime/activation.h" #include "runtime/function_adapter.h" #include "runtime/function_registry.h" #include "runtime/internal/issue_collector.h" #include "runtime/internal/runtime_env.h" #include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_issue.h" #include "runtime/runtime_options.h" 
#include "runtime/type_registry.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/extension_set.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" namespace cel::extensions { namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::expr::conformance::proto2::NestedTestAllTypes; using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::runtime_internal::RuntimeEnv; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::runtime::CelProtoWrapper; using ::google::api::expr::runtime::CelValue; using ::google::api::expr::runtime::FlatExprBuilder; using ::google::api::expr::runtime::FlatExpression; using ::google::api::expr::runtime::LegacyTypeAccessApis; using ::google::api::expr::runtime::LegacyTypeInfoApis; using ::google::api::expr::runtime::LegacyTypeMutationApis; using ::google::protobuf::Empty; using ::testing::_; using ::testing::AllOf; using ::testing::AnyOf; using ::testing::ElementsAre; using ::testing::Eq; using ::testing::HasSubstr; using ::testing::NiceMock; using ::testing::Return; using ::testing::SizeIs; using ::testing::Truly; namespace conformancepb = ::cel::expr::conformance; using MessageWrapper = CelValue::MessageWrapper; absl::Status ApplyDecl(absl::string_view decl, TypeCheckerBuilder& builder) { cel::expr::Decl decl_proto; if (!google::protobuf::TextFormat::ParseFromString(decl, &decl_proto)) { return absl::InvalidArgumentError("failed to parse decl"); } if (decl_proto.has_ident()) { CEL_ASSIGN_OR_RETURN( cel::VariableDecl d, cel::VariableDeclFromProto(decl_proto.name(), decl_proto.ident(), builder.descriptor_pool(), builder.arena())); CEL_RETURN_IF_ERROR(builder.AddVariable(std::move(d))); } else if (decl_proto.has_function()) { CEL_ASSIGN_OR_RETURN( cel::FunctionDecl d, cel::FunctionDeclFromProto(decl_proto.name(), 
decl_proto.function(), builder.descriptor_pool(), builder.arena())); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(d))); } else { return absl::InvalidArgumentError("decl has no ident or function"); } return absl::OkStatus(); } absl::StatusOr> NewTestCompiler() { CompilerOptions options; options.parser_options.enable_quoted_identifiers = true; CEL_ASSIGN_OR_RETURN(std::unique_ptr builder, cel::NewCompilerBuilder( google::protobuf::DescriptorPool::generated_pool(), options)); CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::OptionalCompilerLibrary())); auto& checker_builder = builder->GetCheckerBuilder(); google::protobuf::LinkMessageReflection(); checker_builder.set_container("cel.expr.conformance"); CEL_RETURN_IF_ERROR(ApplyDecl( R"pb( name: "nested_test_all_types" ident { type { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } } )pb", checker_builder)); CEL_RETURN_IF_ERROR(ApplyDecl( R"pb( name: "test_all_types" ident { type { message_type: "cel.expr.conformance.proto2.TestAllTypes" } } )pb", checker_builder)); CEL_RETURN_IF_ERROR(ApplyDecl( R"pb( name: "a" ident { type { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } } )pb", checker_builder)); CEL_RETURN_IF_ERROR(ApplyDecl( R"pb( name: "b" ident { type { message_type: "cel.expr.conformance.proto2.NestedTestAllTypes" } } )pb", checker_builder)); CEL_RETURN_IF_ERROR(ApplyDecl( R"pb( name: "custom_predicate" function { overloads { doc: "An example predicate function for checking attribute tracking for " "the result of the optimized select chain." 
overload_id: "custom_predicate_TestAllTypesNestedType" params { message_type: "cel.expr.conformance.proto2.TestAllTypes.NestedMessage" } result_type { primitive: BOOL } } } )pb", checker_builder)); return builder->Build(); } const cel::Compiler& TestCaseCompiler() { static const Compiler* compiler = []() { auto compiler = NewTestCompiler(); ABSL_CHECK_OK(compiler); return compiler->release(); }(); return *compiler; } absl::StatusOr> CompileForTestCase( absl::string_view expr) { CEL_ASSIGN_OR_RETURN(cel::ValidationResult r, TestCaseCompiler().Compile(expr)); if (!r.IsValid()) { return absl::InvalidArgumentError(r.FormatError()); } return r.ReleaseAst(); } class MockAccessApis : public LegacyTypeInfoApis, public LegacyTypeAccessApis { public: std::string DebugString( const MessageWrapper& wrapped_message) const override { return "MockAccessApis"; } absl::string_view GetTypename( const MessageWrapper& wrapped_message) const override { return "MockAccessApis"; } const LegacyTypeAccessApis* GetAccessApis( const MessageWrapper& wrapped_message) const override { return this; } const LegacyTypeMutationApis* GetMutationApis( const MessageWrapper& wrapped_message) const override { return nullptr; } absl::optional FindFieldByName( absl::string_view field_name) const override { return absl::nullopt; } MOCK_METHOD(absl::StatusOr, GetField, (absl::string_view field_name, const CelValue::MessageWrapper& instance, ProtoWrapperTypeOptions unboxing_option, cel::MemoryManagerRef memory_manager), (const, override)); MOCK_METHOD(absl::StatusOr, HasField, (absl::string_view field_name, const CelValue::MessageWrapper& value), (const, override)); MOCK_METHOD(absl::StatusOr, Qualify, (absl::Span qualifiers, const CelValue::MessageWrapper& instance, bool presence_test, MemoryManagerRef memory_manager), (const, override)); bool IsEqualTo( const CelValue::MessageWrapper& instance, const CelValue::MessageWrapper& other_instance) const override { return false; } std::vector ListFields( const 
CelValue::MessageWrapper& instance) const override { return {}; } }; std::pair MakeMockLegacyMessage( google::protobuf::Arena* arena) { auto* mock_access_apis = google::protobuf::Arena::Create>(arena); auto* message = google::protobuf::Arena::Create(arena); CelValue::MessageWrapper::Builder wrapper(message); return {mock_access_apis, CelValue::CreateMessageWrapper(wrapper.Build(mock_access_apis))}; } absl::Status TestBindLegacyValue(absl::string_view variable, CelValue legacy_value, google::protobuf::Arena* arena, Activation& act) { CEL_ASSIGN_OR_RETURN(Value value, interop_internal::FromLegacyValue(arena, legacy_value)); act.InsertOrAssignValue(variable, std::move(value)); return absl::OkStatus(); } absl::Status TestBindLegacyMessage(absl::string_view variable, const google::protobuf::Message& message, google::protobuf::Arena* arena, cel::Activation& act) { CelValue legacy_value = CelProtoWrapper::CreateMessage(&message, arena); return TestBindLegacyValue(variable, legacy_value, arena, act); } class SelectOptimizationTest : public testing::Test { public: SelectOptimizationTest() : env_(NewTestingRuntimeEnv()), legacy_registry_(env_->legacy_type_registry), type_registry_(env_->type_registry), function_registry_(env_->function_registry), resolver_("", function_registry_, type_registry_, type_registry_.GetComposedTypeProvider()), issue_collector_(RuntimeIssue::Severity::kError), context_(env_, resolver_, runtime_options_, type_registry_.GetComposedTypeProvider(), issue_collector_, program_builder_, shared_arena_) { runtime_options_.fail_on_warnings = false; } void SetUp() override { google::protobuf::LinkMessageReflection(); ASSERT_THAT( function_registry_.Register( UnaryFunctionAdapter::CreateDescriptor( "custom_predicate", false), UnaryFunctionAdapter::WrapFunction( [](const StructValue&) { return true; })), IsOk()); } protected: absl_nonnull std::shared_ptr env_; google::api::expr::runtime::CelTypeRegistry& legacy_registry_; TypeRegistry& type_registry_; 
FunctionRegistry& function_registry_; google::protobuf::Arena arena_; RuntimeOptions runtime_options_; google::api::expr::runtime::Resolver resolver_; cel::runtime_internal::IssueCollector issue_collector_; google::api::expr::runtime::ProgramBuilder program_builder_; std::shared_ptr shared_arena_; google::api::expr::runtime::PlannerContext context_; }; MATCHER_P2(SelectFieldEntry, id, name, "") { const cel::Expr& entry = arg.expr(); if (entry.list_expr().elements().size() != 2) { *result_listener << "want 2-tuple entry, got " << entry.list_expr().elements().size(); return false; } int64_t got_id = entry.list_expr().elements()[0].expr().const_expr().int64_value(); absl::string_view got_name = entry.list_expr().elements()[1].expr().const_expr().string_value(); *result_listener << "want " << id << ": '" << name << "'" << " got " << got_id << ": '" << got_name << "'"; return entry.list_expr().elements()[0].expr().const_expr().int64_value() == id && entry.list_expr().elements()[1].expr().const_expr().string_value() == name; } std::string ToString(const AttributeQualifier& qualifier) { switch (qualifier.kind()) { case Kind::kInt: return absl::StrCat(*qualifier.GetInt64Key()); case Kind::kString: return absl::StrCat("'", *qualifier.GetStringKey(), "'"); case Kind::kUint: return absl::StrCat(*qualifier.GetUint64Key()); case Kind::kBool: return absl::StrCat(*qualifier.GetBoolKey()); default: return ""; } } MATCHER_P(SelectQualifier, qualifier, absl::StrCat("SelectQualifier: ", ToString(qualifier))) { const cel::Expr& entry = arg.expr(); if (!entry.has_const_expr()) { *result_listener << "wanted const_expr"; return false; } cel::AttributeQualifier got_qualifier; if (entry.const_expr().has_int64_value()) { got_qualifier = AttributeQualifier::OfInt(entry.const_expr().int64_value()); } else if (entry.const_expr().has_string_value()) { got_qualifier = AttributeQualifier::OfString(entry.const_expr().string_value()); } else if (entry.const_expr().has_bool_value()) { got_qualifier 
= AttributeQualifier::OfBool(entry.const_expr().bool_value()); } else if (entry.const_expr().has_uint64_value()) { got_qualifier = AttributeQualifier::OfUint(entry.const_expr().uint64_value()); } *result_listener << "want " << ToString(qualifier) << " got " << ToString(got_qualifier); return qualifier == got_qualifier; } TEST_F(SelectOptimizationTest, AstTransformSelect) { ASSERT_OK_AND_ASSIGN( std::unique_ptr ast, CompileForTestCase( "nested_test_all_types.child.payload.standalone_message.bb")); SelectOptimizationAstUpdater updater; EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); const auto& attr_call = ast->root_expr().call_expr(); EXPECT_EQ(attr_call.function(), "cel.@attribute"); ASSERT_THAT(attr_call.args(), SizeIs(2)); EXPECT_EQ(attr_call.args()[0].ident_expr().name(), "nested_test_all_types"); EXPECT_THAT( attr_call.args()[1].list_expr().elements(), ElementsAre(SelectFieldEntry(1, "child"), SelectFieldEntry(2, "payload"), SelectFieldEntry(23, "standalone_message"), SelectFieldEntry(1, "bb"))); } TEST_F(SelectOptimizationTest, AstTransformSelectPresence) { ASSERT_OK_AND_ASSIGN( std::unique_ptr ast, CompileForTestCase( "has(nested_test_all_types.child.payload.standalone_message.bb)")); SelectOptimizationAstUpdater updater; EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); const auto& attr_call = ast->root_expr().call_expr(); EXPECT_EQ(attr_call.function(), "cel.@hasField"); ASSERT_THAT(attr_call.args(), SizeIs(2)); EXPECT_EQ(attr_call.args()[0].ident_expr().name(), "nested_test_all_types"); EXPECT_THAT( attr_call.args()[1].list_expr().elements(), ElementsAre(SelectFieldEntry(1, "child"), SelectFieldEntry(2, "payload"), SelectFieldEntry(23, "standalone_message"), SelectFieldEntry(1, "bb"))); } TEST_F(SelectOptimizationTest, AstTransformComplexSelect) { ASSERT_OK_AND_ASSIGN( std::unique_ptr ast, CompileForTestCase( "((false)? 
a.child.child : b.child).child.payload.single_int64")); SelectOptimizationAstUpdater updater; EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); const auto& attr_call = ast->root_expr().call_expr(); EXPECT_EQ(attr_call.function(), "cel.@attribute"); ASSERT_THAT(attr_call.args(), SizeIs(2)); EXPECT_THAT( attr_call.args()[1].list_expr().elements(), ElementsAre(SelectFieldEntry(1, "child"), SelectFieldEntry(2, "payload"), SelectFieldEntry(2, "single_int64"))); const auto& operand = attr_call.args()[0]; EXPECT_EQ(operand.call_expr().function(), cel::builtin::kTernary); ASSERT_THAT(operand.call_expr().args(), SizeIs(3)); const auto& true_branch = operand.call_expr().args()[1]; EXPECT_EQ(true_branch.call_expr().function(), "cel.@attribute"); ASSERT_THAT(true_branch.call_expr().args(), SizeIs(2)); EXPECT_THAT( true_branch.call_expr().args()[1].list_expr().elements(), ElementsAre(SelectFieldEntry(1, "child"), SelectFieldEntry(1, "child"))); } TEST_F(SelectOptimizationTest, AstTransformMapIndexTraversal) { // nested_test_all_types.payload.map_string_message['$not_a_field'].bb ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, CompileForTestCase("nested_test_all_types.payload.map_" "string_message['$not_a_field'].bb")); SelectOptimizationAstUpdater updater; EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); const auto& attr_call = ast->root_expr().call_expr(); EXPECT_EQ(attr_call.function(), "cel.@attribute"); ASSERT_THAT(attr_call.args(), SizeIs(2)); EXPECT_THAT( attr_call.args()[1].list_expr().elements(), ElementsAre(SelectFieldEntry(2, "payload"), SelectFieldEntry(227, "map_string_message"), SelectQualifier(AttributeQualifier::OfString("$not_a_field")), SelectFieldEntry(1, "bb"))); const auto& operand = attr_call.args()[0]; EXPECT_EQ(operand.ident_expr().name(), "nested_test_all_types"); } TEST_F(SelectOptimizationTest, AstTransformMapIndexUnsupportedConstant) { // nested_test_all_types.payload.map_string_message['$not_a_field'].bb ASSERT_OK_AND_ASSIGN(std::unique_ptr 
ast, CompileForTestCase("nested_test_all_types.payload.map_" "string_message['$not_a_field'].bb")); // Type-checker shouldn't allow a bytes key, so simulating here for // coverage. ast->mutable_root_expr() .mutable_select_expr() .mutable_operand() .mutable_call_expr() .mutable_args()[1] .mutable_const_expr() .set_bytes_value("$not_a_field"); // We don't fail here, but we also don't optimize past the map lookup with // an unsupported constant key. SelectOptimizationAstUpdater updater; EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); EXPECT_EQ(ast->root_expr().call_expr().function(), "cel.@attribute"); ASSERT_THAT(ast->root_expr().call_expr().args(), SizeIs(2)); EXPECT_EQ(ast->root_expr().call_expr().args()[0].call_expr().function(), "_[_]"); // cel.@attribute( // cel.@attribute( // nested_test_all_types, // [payload, map_string_message])[b'$not_a_field'], // [bb]) EXPECT_THAT(ast->root_expr().call_expr().args()[1].list_expr().elements(), SizeIs(1)); } TEST_F(SelectOptimizationTest, AstTransformMapIndexHeterogeneousDoubleKey) { ASSERT_OK_AND_ASSIGN( std::unique_ptr ast, CompileForTestCase("nested_test_all_types.payload.single_any[1.0].bb")); SelectOptimizationAstUpdater updater; EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); EXPECT_EQ(ast->root_expr().select_expr().field(), "bb"); // TODO(uncreated-issue/51): Right now we don't optimize past a dyn/any field // and discard the select optimization if the root isn't a message, so we will // consider the double as a candidate then discard it. 
EXPECT_THAT(ast->root_expr().select_expr().operand().call_expr().function(), "cel.@attribute"); ASSERT_THAT(ast->root_expr().select_expr().operand().call_expr().args(), SizeIs(2)); EXPECT_THAT(ast->root_expr() .select_expr() .operand() .call_expr() .args()[1] .list_expr() .elements(), SizeIs(3)); } TEST_F(SelectOptimizationTest, AstTransformMapIndexHeterogeneousDoubleKeyUint) { constexpr uint64_t kBigUint = static_cast(internal::kMaxDoubleRepresentableAsUint); ASSERT_OK_AND_ASSIGN( std::unique_ptr ast, CompileForTestCase(absl::StrCat( "nested_test_all_types.payload.single_any[", kBigUint, ".0].bb"))); SelectOptimizationAstUpdater updater; EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); EXPECT_EQ(ast->root_expr().select_expr().field(), "bb"); // TODO(uncreated-issue/51): Right now we don't optimize past a dyn/any field // and discard additional select steps. EXPECT_THAT(ast->root_expr().select_expr().operand().call_expr().function(), "cel.@attribute"); ASSERT_THAT(ast->root_expr().select_expr().operand().call_expr().args(), SizeIs(2)); EXPECT_THAT(ast->root_expr() .select_expr() .operand() .call_expr() .args()[1] .list_expr() .elements(), SizeIs(3)); } TEST_F(SelectOptimizationTest, AstTransformFilterToMessageRoot) { // {'field_like_key': // nested_test_all_types}.field_like_key.payload.single_int64 ASSERT_OK_AND_ASSIGN( std::unique_ptr ast, CompileForTestCase( "{'field_like_key': " "nested_test_all_types}.field_like_key.payload.single_int64")); SelectOptimizationAstUpdater updater; EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); const auto& attr_call = ast->root_expr().call_expr(); EXPECT_EQ(attr_call.function(), "cel.@attribute"); ASSERT_THAT(attr_call.args(), SizeIs(2)); EXPECT_THAT(attr_call.args()[1].list_expr().elements(), ElementsAre(SelectFieldEntry(2, "payload"), SelectFieldEntry(2, "single_int64"))); const auto& operand = attr_call.args()[0]; EXPECT_EQ(operand.select_expr().field(), "field_like_key"); } TEST_F(SelectOptimizationTest, 
AstTransformMapDotTraversal) { // nested_test_all_types.payload.map_string_message.field_like_key.bb ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, CompileForTestCase("nested_test_all_types.payload.map_" "string_message.field_like_key.bb")); SelectOptimizationAstUpdater updater; EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); const auto& attr_call = ast->root_expr().call_expr(); EXPECT_EQ(attr_call.function(), "cel.@attribute"); ASSERT_THAT(attr_call.args(), SizeIs(2)); EXPECT_THAT(attr_call.args()[1].list_expr().elements(), ElementsAre(SelectFieldEntry(2, "payload"), SelectFieldEntry(227, "map_string_message"), SelectQualifier( AttributeQualifier::OfString("field_like_key")), SelectFieldEntry(1, "bb"))); const auto& operand = attr_call.args()[0]; EXPECT_EQ(operand.ident_expr().name(), "nested_test_all_types"); } TEST_F(SelectOptimizationTest, AstTransformAnyDotTraversal) { ASSERT_OK_AND_ASSIGN( std::unique_ptr ast, CompileForTestCase( "nested_test_all_types.payload.single_any.single_int64")); SelectOptimizationAstUpdater updater; EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); // When fully supported, we'd expect this to collapse to one attribute call. 
const auto& attr_call = ast->root_expr().select_expr().operand().call_expr(); EXPECT_EQ(attr_call.function(), "cel.@attribute"); ASSERT_THAT(attr_call.args(), SizeIs(2)); EXPECT_THAT(attr_call.args()[1].list_expr().elements(), ElementsAre(SelectFieldEntry(2, "payload"), SelectFieldEntry(100, "single_any"))); const auto& operand = attr_call.args()[0]; EXPECT_EQ(operand.ident_expr().name(), "nested_test_all_types"); } TEST_F(SelectOptimizationTest, AstTransformRepeated) { // nested_test_all_types.payload.repeated_nested_message[1].bb ASSERT_OK_AND_ASSIGN( std::unique_ptr ast, CompileForTestCase( "nested_test_all_types.payload.repeated_nested_message[1].bb")); SelectOptimizationAstUpdater updater; EXPECT_THAT(updater.UpdateAst(context_, *ast), IsOk()); // When fully supported, we'd expect this to collapse to one attribute call. const auto& attr_call = ast->root_expr().call_expr(); EXPECT_EQ(attr_call.function(), "cel.@attribute"); ASSERT_THAT(attr_call.args(), SizeIs(2)); EXPECT_THAT(attr_call.args()[1].list_expr().elements(), ElementsAre(SelectFieldEntry(2, "payload"), SelectFieldEntry(51, "repeated_nested_message"), SelectQualifier(AttributeQualifier::OfInt(1)), SelectFieldEntry(1, "bb"))); const auto& operand = attr_call.args()[0]; EXPECT_EQ(operand.ident_expr().name(), "nested_test_all_types"); } TEST_F(SelectOptimizationTest, AstTransformParseOnlyNotUpdated) { google::protobuf::LinkMessageReflection(); FlatExprBuilder builder(env_, runtime_options_); builder.AddAstTransform(std::make_unique()); // nested_test_all_types.payload.repeated_nested_message[1].bb ASSERT_OK_AND_ASSIGN( ParsedExpr expr, Parse("nested_test_all_types.payload.repeated_nested_message[1].bb")); ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, CreateAstFromParsedExpr(expr)); ASSERT_OK_AND_ASSIGN(FlatExpression plan, builder.CreateExpressionImpl(std::move(ast), nullptr)); NestedTestAllTypes var; var.mutable_payload()->add_repeated_nested_message(); 
var.mutable_payload()->add_repeated_nested_message()->set_bb(42); cel::Activation act; ASSERT_THAT(TestBindLegacyMessage("nested_test_all_types", var, &arena_, act), IsOk()); auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), env_->MutableMessageFactory(), &arena_); ASSERT_OK_AND_ASSIGN( Value result, plan.EvaluateWithCallback( act, /*embedder_context=*/nullptr, google::api::expr::runtime::EvaluationListener(), state)); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), 42); } TEST_F(SelectOptimizationTest, ProgramOptimizerUnoptimizedAst) { google::protobuf::LinkMessageReflection(); FlatExprBuilder builder(env_, runtime_options_); builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer()); // nested_test_all_types.child.payload.standalone_message.bb ASSERT_OK_AND_ASSIGN( std::unique_ptr ast, CompileForTestCase( "nested_test_all_types.child.payload.standalone_message.bb")); ASSERT_OK_AND_ASSIGN(FlatExpression plan, builder.CreateExpressionImpl(std::move(ast), nullptr)); NestedTestAllTypes var; var.mutable_child()->mutable_payload()->mutable_standalone_message()->set_bb( 42); cel::Activation act; ASSERT_THAT(TestBindLegacyMessage("nested_test_all_types", var, &arena_, act), IsOk()); auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), env_->MutableMessageFactory(), &arena_); ASSERT_OK_AND_ASSIGN( Value result, plan.EvaluateWithCallback( act, /*embedder_context=*/nullptr, google::api::expr::runtime::EvaluationListener(), state)); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), 42); } TEST_F(SelectOptimizationTest, MissingAttributeIndependentOfUnknown) { google::protobuf::LinkMessageReflection(); RuntimeOptions options = runtime_options_; options.unknown_processing = UnknownProcessingOptions::kDisabled; options.enable_missing_attribute_errors = true; FlatExprBuilder builder(env_, options); builder.AddAstTransform(std::make_unique()); 
builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer()); ASSERT_OK_AND_ASSIGN( std::unique_ptr ast, CompileForTestCase("custom_predicate(nested_test_all_types.child.payload." "standalone_message)")); ASSERT_OK_AND_ASSIGN(FlatExpression plan, builder.CreateExpressionImpl(std::move(ast), nullptr)); cel::Activation act; // activation only uses a ptr to the underlying message, persist them. NestedTestAllTypes var; act.SetMissingPatterns( {AttributePattern("nested_test_all_types", { AttributeQualifierPattern::OfString("child"), AttributeQualifierPattern::OfString("payload"), })}); ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( child { payload { standalone_message { bb: 20 } } } )pb", &var)); ASSERT_THAT(TestBindLegacyMessage("nested_test_all_types", var, &arena_, act), IsOk()); auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), env_->MutableMessageFactory(), &arena_); ASSERT_OK_AND_ASSIGN( Value result, plan.EvaluateWithCallback( act, /*embedder_context=*/nullptr, google::api::expr::runtime::EvaluationListener(), state)); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_THAT(result.GetError().NativeValue(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("nested_test_all_types.child.payload"))); } TEST_F(SelectOptimizationTest, NullUnboxingOptionHonored) { google::protobuf::LinkMessageReflection(); RuntimeOptions options = runtime_options_; options.enable_empty_wrapper_null_unboxing = true; FlatExprBuilder builder(env_, options); builder.AddAstTransform(std::make_unique()); builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer()); // nested_test_all_types.payload.single_int64_wrapper ASSERT_OK_AND_ASSIGN( std::unique_ptr ast, CompileForTestCase("nested_test_all_types.payload.single_int64_wrapper")); ASSERT_OK_AND_ASSIGN(FlatExpression plan, builder.CreateExpressionImpl(std::move(ast), nullptr)); cel::Activation act; // activation only uses a ptr to the underlying message, persist them. 
NestedTestAllTypes var; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( payload {} )pb", &var)); ASSERT_THAT(TestBindLegacyMessage("nested_test_all_types", var, &arena_, act), IsOk()); auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), env_->MutableMessageFactory(), &arena_); ASSERT_OK_AND_ASSIGN( Value result, plan.EvaluateWithCallback( act, /*embedder_context=*/nullptr, google::api::expr::runtime::EvaluationListener(), state)); ASSERT_TRUE(result->Is()) << result->DebugString(); } using ActivationSetupFn = std::function; struct ProgramOptimizerTestCase { std::string case_name; std::string expr; // identifier -> NestedTestAllTypes textproto absl::flat_hash_map vars; ActivationSetupFn setup_activation; std::function&)> expectations; }; class SelectOptimizationProgramOptimizerTest : public SelectOptimizationTest, public testing::WithParamInterface {}; TEST_P(SelectOptimizationProgramOptimizerTest, Default) { const ProgramOptimizerTestCase& test_case = GetParam(); google::protobuf::LinkMessageReflection(); RuntimeOptions options = runtime_options_; options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; options.enable_missing_attribute_errors = true; FlatExprBuilder builder(env_, options); builder.AddAstTransform(std::make_unique()); builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer()); ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, CompileForTestCase(test_case.expr)); ASSERT_OK_AND_ASSIGN(FlatExpression plan, builder.CreateExpressionImpl(std::move(ast), nullptr)); cel::Activation act; // activation only uses a ptr to the underlying message, persist them. 
std::vector> vars; for (const auto& entry : test_case.vars) { vars.push_back(std::make_unique()); ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(entry.second, vars.back().get())); ASSERT_THAT(TestBindLegacyMessage(entry.first, *vars.back(), &arena_, act), IsOk()); } if (test_case.setup_activation != nullptr) { ASSERT_THAT(test_case.setup_activation(&arena_, act), IsOk()); } auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), env_->MutableMessageFactory(), &arena_); absl::StatusOr result = plan.EvaluateWithCallback( act, /*embedder_context=*/nullptr, google::api::expr::runtime::EvaluationListener(), state); ASSERT_NO_FATAL_FAILURE(test_case.expectations(result)); } TEST_P(SelectOptimizationProgramOptimizerTest, ForceFallbackImpl) { const ProgramOptimizerTestCase& test_case = GetParam(); google::protobuf::LinkMessageReflection(); RuntimeOptions options = runtime_options_; options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; options.enable_missing_attribute_errors = true; FlatExprBuilder builder(env_, options); SelectOptimizationOptions select_options; select_options.force_fallback_implementation = true; builder.AddAstTransform(std::make_unique()); builder.AddProgramOptimizer( CreateSelectOptimizationProgramOptimizer(select_options)); ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, CompileForTestCase(test_case.expr)); ASSERT_OK_AND_ASSIGN(FlatExpression plan, builder.CreateExpressionImpl(std::move(ast), nullptr)); cel::Activation act; // activation only uses a ptr to the underlying message, persist them. 
std::vector> vars; for (const auto& entry : test_case.vars) { vars.push_back(std::make_unique()); ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(entry.second, vars.back().get())); ASSERT_THAT(TestBindLegacyMessage(entry.first, *vars.back(), &arena_, act), IsOk()); } if (test_case.setup_activation != nullptr) { ASSERT_THAT(test_case.setup_activation(&arena_, act), IsOk()); } auto state = plan.MakeEvaluatorState(env_->descriptor_pool.get(), env_->MutableMessageFactory(), &arena_); absl::StatusOr result = plan.EvaluateWithCallback( act, /*embedder_context=*/nullptr, google::api::expr::runtime::EvaluationListener(), state); ASSERT_NO_FATAL_FAILURE(test_case.expectations(result)); } INSTANTIATE_TEST_SUITE_P( TestCases, SelectOptimizationProgramOptimizerTest, testing::ValuesIn({ { "chained_select_success", "nested_test_all_types.child.payload.standalone_message.bb", {{"nested_test_all_types", R"pb( child { payload { standalone_message { bb: 42 } } } )pb"}}, ActivationSetupFn(), [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), 42); }, }, { "chained_select_defaults_success", "nested_test_all_types.child.payload.standalone_message.bb", {{"nested_test_all_types", R"pb()pb"}}, ActivationSetupFn(), [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), 0); }, }, { "chained_select_partial_success", "nested_test_all_types.child.payload.standalone_message.bb", {}, [](google::protobuf::Arena* arena, Activation& act) { auto mock_pair = MakeMockLegacyMessage(arena); MockAccessApis* mock = mock_pair.first; CelValue mocked_value = mock_pair.second; ON_CALL(*mock, Qualify(SizeIs(4), _, /*presence_test=*/false, _)) .WillByDefault( Return(LegacyTypeAccessApis::LegacyQualifyResult{ mocked_value, 3})); ON_CALL(*mock, GetField("bb", _, _, _)) 
.WillByDefault(Return(CelValue::CreateInt64(42))); // Support the forced-fallback case. ON_CALL(*mock, GetField(AnyOf(Eq("child"), Eq("payload"), Eq("standalone_message")), _, _, _)) .WillByDefault(Return(mocked_value)); return TestBindLegacyValue("nested_test_all_types", mocked_value, arena, act); }, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), 42); }, }, { "chained_select_presence_partial_present", "has(nested_test_all_types.child.payload.standalone_message.bb)", {}, [](google::protobuf::Arena* arena, Activation& act) { auto mock_pair = MakeMockLegacyMessage(arena); MockAccessApis* mock = mock_pair.first; CelValue mocked_value = mock_pair.second; ON_CALL(*mock, Qualify(SizeIs(4), _, /*presence_test=*/true, _)) .WillByDefault( Return(LegacyTypeAccessApis::LegacyQualifyResult{ mocked_value, 3})); ON_CALL(*mock, HasField("bb", _)).WillByDefault(Return(true)); ON_CALL(*mock, GetField("bb", _, _, _)) .WillByDefault(Return(CelValue::CreateInt64(42))); // Support the forced-fallback case. 
ON_CALL(*mock, GetField(AnyOf(Eq("child"), Eq("payload"), Eq("standalone_message")), _, _, _)) .WillByDefault(Return(mocked_value)); ON_CALL(*mock, HasField(AnyOf(Eq("child"), Eq("payload"), Eq("standalone_message")), _)) .WillByDefault(Return(true)); return TestBindLegacyValue("nested_test_all_types", mocked_value, arena, act); }, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_TRUE(result.GetBool().NativeValue()); }, }, { "chained_select_not_bound", "nested_test_all_types.child.payload.standalone_message.bb", {}, // not set ActivationSetupFn(), [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_THAT(result.GetError().NativeValue(), StatusIs(absl::StatusCode::kUnknown, HasSubstr("nested_test_all_types"))); }, }, { // Some clients will use maps to represent a protobuf message at // runtime. This is not yet supported. "chained_select_map_as_root_unsupported", "nested_test_all_types.child.payload.standalone_message.bb", {}, // not set [](google::protobuf::Arena* arena, Activation& act) -> absl::Status { auto builder = cel::NewMapValueBuilder(arena); CEL_RETURN_IF_ERROR( builder->Put(cel::StringValue("child"), cel::NullValue())); auto value = std::move(*builder).Build(); act.InsertOrAssignValue("nested_test_all_types", std::move(value)); return absl::OkStatus(); }, [](const absl::StatusOr& got) { EXPECT_THAT(got.status(), StatusIs(absl::StatusCode::kInvalidArgument)); }, }, { // Some clients will use maps to represent a protobuf at runtime, // this is not yet supported. 
"chained_select_noncontainer_as_root_unsupported", "nested_test_all_types.child.payload.standalone_message.bb", {}, // not set [](google::protobuf::Arena* arena, Activation& act) { act.InsertOrAssignValue("nested_test_all_types", cel::DurationValue(absl::Seconds(1))); return absl::OkStatus(); }, [](const absl::StatusOr& got) { EXPECT_THAT(got.status(), StatusIs(absl::StatusCode::kInvalidArgument)); }, }, { "complex_select_success", "((false)? a.child.child : b.child).child.payload.single_int64", {{"a", ""}, {"b", R"pb( child { child { payload { single_int64: -42 } } } )pb"}}, ActivationSetupFn(), [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), -42); }, }, { "chained_select_presence_present", "has(nested_test_all_types.child.payload.standalone_message.bb)", {{"nested_test_all_types", R"pb( child { payload { standalone_message { bb: 2 } } } )pb"}}, ActivationSetupFn(), [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_TRUE(result.GetBool().NativeValue()); }, }, { "chained_select_presence_not_present", "has(nested_test_all_types.child.payload.standalone_message.bb)", {{"nested_test_all_types", R"pb( child { payload { standalone_message {} } } )pb"}}, ActivationSetupFn(), [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_FALSE(result.GetBool().NativeValue()); }, }, { "select_with_map_supported", "nested_test_all_types.payload.map_string_message['$not_a_field']." 
"bb", {{"nested_test_all_types", R"pb( payload { map_string_message { key: "$not_a_field", value { bb: 5 } } } )pb"}}, ActivationSetupFn(), [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), 5); }, }, { "select_with_map_no_such_key", "nested_test_all_types.payload.map_string_message['$not_a_field']." "bb", {{"nested_test_all_types", R"pb( payload { map_string_message { key: "a_different_field", value { bb: 5 } } } )pb"}}, ActivationSetupFn(), [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_THAT(result.GetError().NativeValue(), StatusIs(absl::StatusCode::kNotFound, AllOf(HasSubstr("Key not found"), HasSubstr("$not_a_field")))); }, }, { "select_with_repeated_supported", "nested_test_all_types.payload.repeated_nested_message[1].bb", {{"nested_test_all_types", R"pb( payload { repeated_nested_message {} repeated_nested_message { bb: 7 } } )pb"}}, ActivationSetupFn(), [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), 7); }, }, { "select_with_repeated_index_out_of_bounds", "nested_test_all_types.payload.repeated_nested_message[1].bb", {{"nested_test_all_types", R"pb( payload { repeated_nested_message {} } )pb"}}, ActivationSetupFn(), [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_THAT(result.GetError().NativeValue(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("index out of bounds"))); }, }, { "unknown_field", "((false)? 
a.child.child : b.child).child.payload.single_int64", {{"a", ""}, {"b", R"pb( child { child { payload { single_int64: -42 } } } )pb"}}, [](google::protobuf::Arena*, Activation& act) { act.SetUnknownPatterns({AttributePattern( "b", {AttributeQualifierPattern::OfString("child"), AttributeQualifierPattern::OfString("child")})}); return absl::OkStatus(); }, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_THAT( result.GetUnknown().attribute_set(), ElementsAre(Eq(Attribute( "b", { AttributeQualifier::OfString("child"), AttributeQualifier::OfString("child"), AttributeQualifier::OfString("payload"), AttributeQualifier::OfString("single_int64"), })))); }, }, { "unknown_field_partial", "((false)? a.child.child : b.child).child.payload.single_int64", {{"a", ""}, {"b", R"pb( child { child { payload { single_int64: -42 } } } )pb"}}, [](google::protobuf::Arena*, Activation& act) { act.SetUnknownPatterns({AttributePattern( "b", {AttributeQualifierPattern::OfString("child"), AttributeQualifierPattern::OfString("child"), AttributeQualifierPattern::OfString("child")})}); return absl::OkStatus(); }, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), -42); }, }, { "unknown_ident", "((false)? a.child.child : b.child).child.payload.single_int64", {{"a", ""}, {"b", R"pb( child { child { payload { single_int64: -42 } } } )pb"}}, [](google::protobuf::Arena*, Activation& act) { act.SetUnknownPatterns({ AttributePattern("b", {}), }); return absl::OkStatus(); }, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_THAT(result.GetUnknown().attribute_set(), ElementsAre(Truly([](const Attribute& attr) { return attr.variable_name() == "b"; }))); }, }, { "unknown_pruned", "((false)? 
a.child.child : b.child).child.payload.single_int64", {{"a", ""}, {"b", R"pb( child { child { payload { single_int64: -42 } } } )pb"}}, [](google::protobuf::Arena*, Activation& act) { act.SetUnknownPatterns({ AttributePattern("a", {}), }); return absl::OkStatus(); }, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), -42); }, }, { "missing_field", "custom_predicate(nested_test_all_types.child.payload.standalone_" "message)", {{"nested_test_all_types", R"pb( child { payload { standalone_message { bb: 20 } } } )pb"}}, [](google::protobuf::Arena*, Activation& act) { act.SetMissingPatterns({AttributePattern( "nested_test_all_types", { AttributeQualifierPattern::OfString("child"), })}); return absl::OkStatus(); }, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_THAT(result.GetError().NativeValue().message(), HasSubstr("nested_test_all_types.child.payload." 
"standalone_message")); }, }, { "missing_field_partial", "custom_predicate(nested_test_all_types.child.payload.standalone_" "message)", {{"nested_test_all_types", R"pb( child { payload { standalone_message { bb: 20 } } } )pb"}}, [](google::protobuf::Arena*, Activation& act) { act.SetMissingPatterns({AttributePattern( "b", {AttributeQualifierPattern::OfString("child"), AttributeQualifierPattern::OfString("child")})}); return absl::OkStatus(); }, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_TRUE(result.GetBool().NativeValue()); }, }, { "select_wrapper_int_leaf", "nested_test_all_types.payload.single_int64_wrapper", {{"nested_test_all_types", R"pb( payload { single_int64_wrapper { value: 10 } } )pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), 10); }, }, { "select_repeated_leaf", "nested_test_all_types.payload.repeated_int64", {{"nested_test_all_types", R"pb( payload { repeated_int64: 10 } )pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); }, }, { "select_map_leaf", "nested_test_all_types.payload.map_string_int64", {{"nested_test_all_types", R"pb( payload { map_string_int64 { key: "key", value: 12 } } )pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); }, }, { "select_with_map_dot", "nested_test_all_types.payload.map_string_message.field_like_key." 
"bb", {{"nested_test_all_types", R"pb( payload { map_string_message { key: "field_like_key", value { bb: 42 } } } )pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), 42); }, }, { "select_with_map_bool", "nested_test_all_types.payload.map_bool_message[false].bb", {{"nested_test_all_types", R"pb( payload { map_bool_message { key: false, value { bb: 42 } } } )pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), 42); }, }, { "select_with_map_int", "nested_test_all_types.payload.map_int64_message[-1].bb", {{"nested_test_all_types", R"pb( payload { map_int64_message { key: -1, value { bb: 42 } } } )pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), 42); }, }, { "select_with_map_uint", "nested_test_all_types.payload.map_uint64_message[1u].bb", {{"nested_test_all_types", R"pb( payload { map_uint64_message { key: 1, value { bb: 42 } } } )pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), 42); }, }, { "select_with_repeated", "nested_test_all_types.payload.repeated_nested_message[1].bb", {{"nested_test_all_types", R"pb( payload { repeated_nested_message {} repeated_nested_message { bb: 42 } } )pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), 42); }, }, { "select_with_any", "nested_test_all_types.payload.single_any.single_int64", {{"nested_test_all_types", R"pb( payload { single_any { [type.googleapis.com/cel.expr.conformance.proto2 
.TestAllTypes] { single_int64: 42 } } } )pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_EQ(result.GetInt().NativeValue(), 42); }, }, { "has_repeated_leaf_true", "has(nested_test_all_types.payload.repeated_int64)", {{"nested_test_all_types", R"pb( payload { repeated_int64: 42 } )pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_TRUE(result.GetBool().NativeValue()); }, }, { "has_repeated_leaf_false", "has(nested_test_all_types.payload.repeated_int64)", {{"nested_test_all_types", R"pb( payload {} )pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_FALSE(result.GetBool().NativeValue()); }, }, { "has_map_leaf_true", "has(nested_test_all_types.payload.map_string_int64)", {{"nested_test_all_types", R"pb( payload { map_string_int64 { key: "string" value: 12 } } )pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_TRUE(result.GetBool().NativeValue()); }, }, { "has_map_leaf_false", "has(nested_test_all_types.payload.map_string_int64)", {{"nested_test_all_types", R"pb( payload {} )pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_FALSE(result.GetBool().NativeValue()); }, }, { "has_map_field_like_key", "has(nested_test_all_types.payload.map_string_int64.field_like_" "key)", {{"nested_test_all_types", R"pb( payload { map_string_int64 { key: "field_like_key" value: 12 } } )pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_TRUE(result.GetBool().NativeValue()); }, }, { "has_map_field_like_key_false", 
"has(nested_test_all_types.payload.map_string_int64.field_like_" "key)", {{"nested_test_all_types", R"pb( payload { map_string_int64 { key: "wrong_key" value: 12 } } )pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_FALSE(result.GetBool().NativeValue()); }, }, { "select_wrong_runtime_type", "test_all_types.single_int64", {{}}, [](google::protobuf::Arena* arena, Activation& activation) { activation.InsertOrAssignValue("test_all_types", cel::IntValue(42)); return absl::OkStatus(); }, [](const absl::StatusOr& got) { EXPECT_THAT(got, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Expected struct type"))); }, }, { "select_with_struct", "nested_test_all_types.payload.single_struct['key']['subkey']", {{"nested_test_all_types", R"pb(payload { single_struct { fields { key: "key" value { struct_value { fields { key: "subkey" value { bool_value: true } } } } } } })pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_TRUE(result.GetBool().NativeValue()); }, }, { "select_with_list_value", "nested_test_all_types.payload.list_value[0]['subkey']", {{"nested_test_all_types", R"pb(payload { list_value { values { struct_value { fields { key: "subkey" value { bool_value: true } } } } } })pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); EXPECT_TRUE(result.GetBool().NativeValue()); }, }, { "select_with_value", "nested_test_all_types.payload.single_value['key']['subkey']", {{"nested_test_all_types", R"pb(payload { single_value { struct_value { fields { key: "key" value { struct_value { fields { key: "subkey" value { bool_value: true } } } } } } } })pb"}}, nullptr, [](const absl::StatusOr& got) { ASSERT_OK_AND_ASSIGN(Value result, got); ASSERT_TRUE(result->Is()) << result->DebugString(); 
EXPECT_TRUE(result.GetBool().NativeValue());
        },
    },
}),
    [](const testing::TestParamInfo& info) {
      return info.param.case_name;
    });

// Tests covering unexpected / malformed ASTs.
//
// These cases shouldn't be possible under normal usage, but are possible if
// there's a bug in the optimizer implementation or if a hand-rolled AST is
// used. Each test hand-builds a root call expression whose function is
// `kCelAttribute`, installs CreateSelectOptimizationProgramOptimizer() on a
// FlatExprBuilder, and checks the status returned by CreateExpressionImpl().

// Fixture adding helpers for building AST nodes by hand. Every node created
// through NextExpr() or NextListExprElement() receives a distinct id,
// counting up from 1.
class SelectOptimizationUnexpectedAstTest : public SelectOptimizationTest {
 public:
  SelectOptimizationUnexpectedAstTest()
      : SelectOptimizationTest(), next_id_(1) {}

  // Returns a default-constructed Expr carrying the next unused id.
  Expr NextExpr() {
    Expr result;
    result.set_id(next_id_++);
    return result;
  }

  // Returns a list element whose expression is a fresh NextExpr().
  cel::ListExprElement NextListExprElement() {
    cel::ListExprElement element;
    element.set_expr(NextExpr());
    return element;
  }

 protected:
  // Id handed out by the next call to NextExpr().
  int64_t next_id_;
};

// The call has a single argument (the operand identifier) and no select-path
// list; CreateExpressionImpl must fail with InvalidArgument.
TEST_F(SelectOptimizationUnexpectedAstTest, WrongArgumentCount) {
  std::unique_ptr ast =
      std::make_unique(NextExpr(), SourceInfo());
  ast->mutable_root_expr().mutable_call_expr().set_function(kCelAttribute);
  // Only argument: the operand identifier.
  ast->mutable_root_expr()
      .mutable_call_expr()
      .mutable_args()
      .emplace_back(NextExpr())
      .mutable_ident_expr()
      .set_name("ident");

  FlatExprBuilder builder(env_, runtime_options_);
  builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer());
  EXPECT_THAT(builder.CreateExpressionImpl(std::move(ast), nullptr),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

// The second argument (the select path) is an empty list;
// CreateExpressionImpl must fail with InvalidArgument.
TEST_F(SelectOptimizationUnexpectedAstTest, EmptySelectPath) {
  std::unique_ptr ast =
      std::make_unique(NextExpr(), SourceInfo());
  ast->mutable_root_expr().mutable_call_expr().set_function(kCelAttribute);
  // First argument: the operand identifier.
  ast->mutable_root_expr()
      .mutable_call_expr()
      .mutable_args()
      .emplace_back(NextExpr())
      .mutable_ident_expr()
      .set_name("ident");
  // Second argument: a list with no elements.
  ast->mutable_root_expr()
      .mutable_call_expr()
      .mutable_args()
      .emplace_back(NextExpr())
      .mutable_list_expr();

  FlatExprBuilder builder(env_, runtime_options_);
  builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer());
  EXPECT_THAT(builder.CreateExpressionImpl(std::move(ast), nullptr),
              StatusIs(absl::StatusCode::kInvalidArgument));
}
// The single select-path entry is the one-element list ["field"] rather than
// a two-element entry; CreateExpressionImpl must fail with InvalidArgument.
TEST_F(SelectOptimizationUnexpectedAstTest, MalformedSelectPathNotPair) {
  std::unique_ptr ast =
      std::make_unique(NextExpr(), SourceInfo());
  ast->mutable_root_expr().mutable_call_expr().set_function(kCelAttribute);
  // First argument: the operand identifier.
  ast->mutable_root_expr()
      .mutable_call_expr()
      .mutable_args()
      .emplace_back(NextExpr())
      .mutable_ident_expr()
      .set_name("ident");
  // Second argument: the select-path list.
  auto& select_step_list = ast->mutable_root_expr()
                               .mutable_call_expr()
                               .mutable_args()
                               .emplace_back(NextExpr())
                               .mutable_list_expr();
  // One path entry, itself a list...
  auto& select_step_element = select_step_list.mutable_elements()
                                  .emplace_back(NextListExprElement())
                                  .mutable_expr()
                                  .mutable_list_expr();
  // ...holding only a single string component.
  select_step_element.mutable_elements()
      .emplace_back(NextListExprElement())
      .mutable_expr()
      .mutable_const_expr()
      .set_string_value("field");

  FlatExprBuilder builder(env_, runtime_options_);
  builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer());
  EXPECT_THAT(builder.CreateExpressionImpl(std::move(ast), nullptr),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

// The select-path entry is the two-element list ["field", 1] (string first,
// int second); CreateExpressionImpl must fail with InvalidArgument.
// NOTE(review): the accepted order is presumably [int, string], as built in
// OptionalNotYetSupported; confirm against the optimizer.
TEST_F(SelectOptimizationUnexpectedAstTest,
       MalformedSelectPathWrongPairTypes) {
  std::unique_ptr ast =
      std::make_unique(NextExpr(), SourceInfo());
  ast->mutable_root_expr().mutable_call_expr().set_function(kCelAttribute);
  // First argument: the operand identifier.
  ast->mutable_root_expr()
      .mutable_call_expr()
      .mutable_args()
      .emplace_back(NextExpr())
      .mutable_ident_expr()
      .set_name("ident");
  // Second argument: the select-path list.
  auto& select_step_list = ast->mutable_root_expr()
                               .mutable_call_expr()
                               .mutable_args()
                               .emplace_back(NextExpr())
                               .mutable_list_expr();
  auto& select_step_element = select_step_list.mutable_elements()
                                  .emplace_back(NextListExprElement())
                                  .mutable_expr()
                                  .mutable_list_expr();
  // Components in the order string, int.
  select_step_element.mutable_elements()
      .emplace_back(NextListExprElement())
      .mutable_expr()
      .mutable_const_expr()
      .set_string_value("field");
  select_step_element.mutable_elements()
      .emplace_back(NextListExprElement())
      .mutable_expr()
      .mutable_const_expr()
      .set_int64_value(1);

  FlatExprBuilder builder(env_, runtime_options_);
  builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer());
  EXPECT_THAT(builder.CreateExpressionImpl(std::move(ast), nullptr),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

// The select-path entry is a bytes constant instead of a list;
// CreateExpressionImpl must fail with InvalidArgument.
TEST_F(SelectOptimizationUnexpectedAstTest,
       MalformedSelectPathUnsupportedConstant) {
  std::unique_ptr ast =
      std::make_unique(NextExpr(), SourceInfo());
  ast->mutable_root_expr().mutable_call_expr().set_function(kCelAttribute);
  // First argument: the operand identifier.
  ast->mutable_root_expr()
      .mutable_call_expr()
      .mutable_args()
      .emplace_back(NextExpr())
      .mutable_ident_expr()
      .set_name("ident");
  // Second argument: the select-path list.
  auto& select_step_list = ast->mutable_root_expr()
                               .mutable_call_expr()
                               .mutable_args()
                               .emplace_back(NextExpr())
                               .mutable_list_expr();
  auto& select_step_element = select_step_list.mutable_elements()
                                  .emplace_back(NextListExprElement())
                                  .mutable_expr();
  // The entry itself is a bytes literal.
  select_step_element.mutable_const_expr().set_bytes_value("bytes_key");

  FlatExprBuilder builder(env_, runtime_options_);
  builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer());
  EXPECT_THAT(builder.CreateExpressionImpl(std::move(ast), nullptr),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

// Operand, a select path with the entry [1, "field"], and a third int
// argument (0). CreateExpressionImpl must fail with Unimplemented.
// NOTE(review): that the third argument requests an optional select is
// inferred from the test name; confirm against the optimizer.
TEST_F(SelectOptimizationUnexpectedAstTest, OptionalNotYetSupported) {
  std::unique_ptr ast =
      std::make_unique(NextExpr(), SourceInfo());
  ast->mutable_root_expr().mutable_call_expr().set_function(kCelAttribute);
  auto& call_args = ast->mutable_root_expr().mutable_call_expr().mutable_args();
  call_args.emplace_back(NextExpr()).mutable_ident_expr().set_name("ident");
  auto& list_expr = call_args.emplace_back(NextExpr()).mutable_list_expr();
  auto& fields = list_expr.mutable_elements()
                     .emplace_back(NextListExprElement())
                     .mutable_expr()
                     .mutable_list_expr()
                     .mutable_elements();
  fields.emplace_back(NextListExprElement())
      .mutable_expr()
      .mutable_const_expr()
      .set_int64_value(1);
  fields.emplace_back(NextListExprElement())
      .mutable_expr()
      .mutable_const_expr()
      .set_string_value("field");
  // Third argument.
  call_args.emplace_back(NextExpr()).mutable_const_expr().set_int64_value(0);

  FlatExprBuilder builder(env_, runtime_options_);
  builder.AddProgramOptimizer(CreateSelectOptimizationProgramOptimizer());
  EXPECT_THAT(builder.CreateExpressionImpl(std::move(ast), nullptr),
              StatusIs(absl::StatusCode::kUnimplemented));
}

}  // namespace
}  // namespace cel::extensions
================================================
FILE: extensions/sets_functions.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "extensions/sets_functions.h"

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "base/function_adapter.h"
#include "checker/type_checker_builder.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/value.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"
#include "internal/status_macros.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::extensions {

using google::api::expr::runtime::CelFunctionRegistry;
using google::api::expr::runtime::ConvertToRuntimeOptions;
using google::api::expr::runtime::InterpreterOptions;

namespace {

// Implements `sets.contains(list, sublist)`.
//
// Returns true when every element of `sublist` is reported as contained by
// `list.Contains()`. An empty `sublist` therefore yields true. A non-bool
// result from Contains() (e.g. a CEL error value) is treated as "not
// contained" rather than surfaced to the caller. Iteration stops at the first
// missing element. Non-OK statuses from Contains()/ForEach() are propagated.
absl::StatusOr SetsContains(
    const ListValue& list, const ListValue& sublist,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  bool any_missing = false;
  CEL_RETURN_IF_ERROR(sublist.ForEach(
      [&](const Value& sublist_element) -> absl::StatusOr {
        CEL_ASSIGN_OR_RETURN(auto contains,
                             list.Contains(sublist_element, descriptor_pool,
                                           message_factory, arena));
        // Treat CEL error as missing: any non-bool result counts as "not
        // contained".
        any_missing = !contains->Is() || !contains.GetBool().NativeValue();
        // The first false result will terminate the loop.
        return !any_missing;
      },
      descriptor_pool, message_factory, arena));
  return BoolValue(!any_missing);
}

// Implements `sets.intersects(list, sublist)`.
//
// Returns true when at least one element of `list` is reported as contained
// by `sublist.Contains()`, stopping at the first hit. An empty `list` yields
// false. Non-bool Contains() results count as "not contained". Non-OK
// statuses are propagated.
absl::StatusOr SetsIntersects(
    const ListValue& list, const ListValue& sublist,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  bool exists = false;
  CEL_RETURN_IF_ERROR(list.ForEach(
      [&](const Value& list_element) -> absl::StatusOr {
        CEL_ASSIGN_OR_RETURN(auto contains,
                             sublist.Contains(list_element, descriptor_pool,
                                              message_factory, arena));
        // Treat contains return CEL error as false for the sake of
        // intersecting.
        exists = contains->Is() && contains.GetBool().NativeValue();
        // Returning false (an element was found) stops the iteration.
        return !exists;
      },
      descriptor_pool, message_factory, arena));
  return BoolValue(exists);
}

// Implements `sets.equivalent(list, sublist)`.
//
// True when each list contains every element of the other, i.e. they are
// equal as sets; order and duplicates do not matter because only membership
// is checked. The reverse containment check is skipped when the forward check
// already returned false.
absl::StatusOr SetsEquivalent(
    const ListValue& list, const ListValue& sublist,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  CEL_ASSIGN_OR_RETURN(
      auto contains_sublist,
      SetsContains(list, sublist, descriptor_pool, message_factory, arena));
  if (contains_sublist.Is() && !contains_sublist.GetBool().NativeValue()) {
    return contains_sublist;
  }
  return SetsContains(sublist, list, descriptor_pool, message_factory, arena);
}

// Registers the global (non-receiver-style) `sets.contains(list, list)`
// overload, implemented by SetsContains.
absl::Status RegisterSetsContainsFunction(FunctionRegistry& registry) {
  return registry.Register(
      BinaryFunctionAdapter<
          absl::StatusOr, const ListValue&,
          const ListValue&>::CreateDescriptor("sets.contains",
                                              /*receiver_style=*/false),
      BinaryFunctionAdapter, const ListValue&,
                            const ListValue&>::WrapFunction(SetsContains));
}

// Registers the global (non-receiver-style) `sets.intersects(list, list)`
// overload, implemented by SetsIntersects.
absl::Status RegisterSetsIntersectsFunction(FunctionRegistry& registry) {
  return registry.Register(
      BinaryFunctionAdapter<
          absl::StatusOr, const ListValue&,
          const ListValue&>::CreateDescriptor("sets.intersects",
                                              /*receiver_style=*/false),
      BinaryFunctionAdapter, const ListValue&,
                            const ListValue&>::WrapFunction(SetsIntersects));
}

// Registers the global (non-receiver-style) `sets.equivalent(list, list)`
// overload, implemented by SetsEquivalent.
absl::Status RegisterSetsEquivalentFunction(FunctionRegistry& registry) {
  return registry.Register(
      BinaryFunctionAdapter<
          absl::StatusOr, const ListValue&,
          const ListValue&>::CreateDescriptor("sets.equivalent",
                                              /*receiver_style=*/false),
      BinaryFunctionAdapter, const ListValue&,
                            const ListValue&>::WrapFunction(SetsEquivalent));
}

// Adds the type-checker declarations for the sets functions. Each function
// has one overload taking two arguments of type `list(T)` (sharing the type
// parameter T) and returning `bool`.
absl::Status RegisterSetsDecls(TypeCheckerBuilder& b) {
  ListType list_t(b.arena(), TypeParamType("T"));
  CEL_ASSIGN_OR_RETURN(
      auto decl,
      MakeFunctionDecl("sets.contains",
                       MakeOverloadDecl("list_sets_contains_list", BoolType(),
                                        list_t, list_t)));
  CEL_RETURN_IF_ERROR(b.AddFunction(decl));
  CEL_ASSIGN_OR_RETURN(
      decl, MakeFunctionDecl("sets.equivalent",
                             MakeOverloadDecl("list_sets_equivalent_list",
                                              BoolType(), list_t, list_t)));
  CEL_RETURN_IF_ERROR(b.AddFunction(decl));
  CEL_ASSIGN_OR_RETURN(
      decl, MakeFunctionDecl("sets.intersects",
                             MakeOverloadDecl("list_sets_intersects_list",
                                              BoolType(), list_t, list_t)));
  return b.AddFunction(decl);
}

}  // namespace

// Checker library with id "cel.lib.ext.sets"; its configure step is
// RegisterSetsDecls.
CheckerLibrary SetsCheckerLibrary() {
  return {.id = "cel.lib.ext.sets", .configure = RegisterSetsDecls};
}

// Registers all three sets functions, returning the first failure.
// `options` is not consulted here.
absl::Status RegisterSetsFunctions(FunctionRegistry& registry,
                                   const RuntimeOptions& options) {
  CEL_RETURN_IF_ERROR(RegisterSetsContainsFunction(registry));
  CEL_RETURN_IF_ERROR(RegisterSetsIntersectsFunction(registry));
  CEL_RETURN_IF_ERROR(RegisterSetsEquivalentFunction(registry));
  return absl::OkStatus();
}

// Legacy-API overload: converts `options` with ConvertToRuntimeOptions and
// forwards to the FunctionRegistry overload using the registry's internal
// FunctionRegistry.
absl::Status RegisterSetsFunctions(CelFunctionRegistry* registry,
                                   const InterpreterOptions& options) {
  return RegisterSetsFunctions(registry->InternalGetRegistry(),
                               ConvertToRuntimeOptions(options));
}

}  // namespace cel::extensions
================================================
FILE: extensions/sets_functions.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_SETS_FUNCTIONS_H_
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_SETS_FUNCTIONS_H_

#include "absl/status/status.h"
#include "checker/type_checker_builder.h"
#include "compiler/compiler.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"

namespace cel::extensions {

// Declarations for the sets functions.
//
// Returns the type-checker library (id "cel.lib.ext.sets") that declares
// `sets.contains`, `sets.equivalent` and `sets.intersects`, each with a single
// overload taking two `list(T)` arguments and returning `bool`.
CheckerLibrary SetsCheckerLibrary();

// SetsCheckerLibrary() wrapped as a CompilerLibrary.
inline CompilerLibrary SetsCompilerLibrary() {
  return CompilerLibrary::FromCheckerLibrary(SetsCheckerLibrary());
}

// Register set functions.
//
// Adds the runtime implementations of `sets.contains`, `sets.intersects` and
// `sets.equivalent` (global, non-receiver-style functions over two lists) to
// `registry`. Returns the first non-OK status encountered while registering.
// `options` is currently not consulted by the implementation.
absl::Status RegisterSetsFunctions(FunctionRegistry& registry,
                                   const RuntimeOptions& options);

// Legacy-API overload: converts `options` to RuntimeOptions and registers the
// same functions into the FunctionRegistry underlying `registry`.
absl::Status RegisterSetsFunctions(
    google::api::expr::runtime::CelFunctionRegistry* registry,
    const google::api::expr::runtime::InterpreterOptions& options);

}  // namespace cel::extensions

#endif  // THIRD_PARTY_CEL_CPP_EXTENSIONS_SETS_FUNCTIONS_H_
================================================
FILE: extensions/sets_functions_benchmark_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include #include #include #include "cel/expr/syntax.pb.h" #include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "common/value.h" #include "eval/internal/interop.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "extensions/sets_functions.h" #include "internal/benchmark.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/parser.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" namespace cel::extensions { namespace { using ::cel::Value; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::runtime::Activation; using ::google::api::expr::runtime::CelValue; using ::google::api::expr::runtime::ContainerBackedListImpl; using ::google::api::expr::runtime::CreateCelExpressionBuilder; using ::google::api::expr::runtime::InterpreterOptions; using ::google::api::expr::runtime::RegisterBuiltinFunctions; enum class ListImpl : int { kLegacy = 0, kWrappedModern = 1, kRhsConstant = 2 }; int ToNumber(ListImpl impl) { return static_cast(impl); } ListImpl FromNumber(int number) { switch (number) { case 0: return ListImpl::kLegacy; case 1: return ListImpl::kWrappedModern; case 2: return ListImpl::kRhsConstant; default: return ListImpl::kLegacy; } } struct TestCase { std::string test_name; std::string expr; ListImpl list_impl; int size; CelValue result; std::string MakeLabel(int len) const { std::string list_impl; switch (this->list_impl) { case ListImpl::kRhsConstant: list_impl = "rhs_constant"; break; case ListImpl::kWrappedModern: list_impl = "wrapped_modern"; break; case 
ListImpl::kLegacy: list_impl = "legacy"; break; } return absl::StrCat(test_name, "/", list_impl, "/", len); } }; class ListStorage { public: virtual ~ListStorage() = default; }; class LegacyListStorage : public ListStorage { public: LegacyListStorage(ContainerBackedListImpl x, ContainerBackedListImpl y) : x_(std::move(x)), y_(std::move(y)) {} CelValue x() { return CelValue::CreateList(&x_); } CelValue y() { return CelValue::CreateList(&y_); } private: ContainerBackedListImpl x_; ContainerBackedListImpl y_; }; class ModernListStorage : public ListStorage { public: ModernListStorage(Value x, Value y) : x_(std::move(x)), y_(std::move(y)) {} CelValue x() { return interop_internal::ModernValueToLegacyValueOrDie(&arena_, x_); } CelValue y() { return interop_internal::ModernValueToLegacyValueOrDie(&arena_, y_); } private: google::protobuf::Arena arena_; Value x_; Value y_; }; absl::StatusOr> RegisterLegacyLists( bool overlap, int len, Activation& activation) { std::vector x; std::vector y; x.reserve(len + 1); y.reserve(len + 1); if (overlap) { x.push_back(CelValue::CreateInt64(2)); y.push_back(CelValue::CreateInt64(1)); } for (int i = 0; i < len; i++) { x.push_back(CelValue::CreateInt64(1)); y.push_back(CelValue::CreateInt64(2)); } auto result = std::make_unique( ContainerBackedListImpl(std::move(x)), ContainerBackedListImpl(std::move(y))); activation.InsertValue("x", result->x()); activation.InsertValue("y", result->y()); return result; } // Constant list literal that has the same elements as the bound test cases. std::string ConstantList(bool overlap, int len) { std::string list_body; for (int i = 0; i < len; i++) { } return absl::StrCat("[", overlap ? 
"1, " : "", absl::StrJoin(std::vector(len, "2"), ", "), "]"); } absl::StatusOr> RegisterModernLists( bool overlap, int len, google::protobuf::Arena* absl_nonnull arena, Activation& activation) { auto x_builder = cel::NewListValueBuilder(arena); auto y_builder = cel::NewListValueBuilder(arena); x_builder->Reserve(len + 1); y_builder->Reserve(len + 1); if (overlap) { CEL_RETURN_IF_ERROR(x_builder->Add(cel::IntValue(2))); CEL_RETURN_IF_ERROR(y_builder->Add(cel::IntValue(1))); } for (int i = 0; i < len; i++) { CEL_RETURN_IF_ERROR(x_builder->Add(cel::IntValue(1))); CEL_RETURN_IF_ERROR(y_builder->Add(cel::IntValue(2))); } auto x = std::move(*x_builder).Build(); auto y = std::move(*y_builder).Build(); auto result = std::make_unique(std::move(x), std::move(y)); activation.InsertValue("x", result->x()); activation.InsertValue("y", result->y()); return result; } absl::StatusOr> RegisterLists( bool overlap, int len, bool use_modern, google::protobuf::Arena* absl_nonnull arena, Activation& activation) { if (use_modern) { return RegisterModernLists(overlap, len, arena, activation); } else { return RegisterLegacyLists(overlap, len, activation); } } void RunBenchmark(const TestCase& test_case, benchmark::State& state) { bool lists_overlap = test_case.result.BoolOrDie(); std::string expr = test_case.expr; if (test_case.list_impl == ListImpl::kRhsConstant) { expr = absl::StrReplaceAll( expr, {{"y", ConstantList(lists_overlap, test_case.size)}}); } ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(expr)); google::protobuf::Arena arena; InterpreterOptions options; options.constant_folding = true; options.constant_arena = &arena; options.enable_qualified_identifier_rewrites = true; auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK(RegisterSetsFunctions(builder->GetRegistry()->InternalGetRegistry(), cel::RuntimeOptions{})); ASSERT_OK_AND_ASSIGN( auto cel_expr, 
builder->CreateExpression(&(parsed_expr.expr()), nullptr)); Activation activation; ASSERT_OK_AND_ASSIGN( auto storage, RegisterLists(test_case.result.BoolOrDie(), test_case.size, test_case.list_impl == ListImpl::kWrappedModern, &arena, activation)); state.SetLabel(test_case.MakeLabel(test_case.size)); for (auto _ : state) { ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(result.IsBool()); ASSERT_EQ(result.BoolOrDie(), test_case.result.BoolOrDie()) << test_case.test_name; } } void BM_SetsIntersectsTrue(benchmark::State& state) { ListImpl impl = FromNumber(state.range(0)); int size = state.range(1); RunBenchmark({"sets.intersects_true", "sets.intersects(x, y)", impl, size, CelValue::CreateBool(true)}, state); } void BM_SetsIntersectsFalse(benchmark::State& state) { ListImpl impl = FromNumber(state.range(0)); int size = state.range(1); RunBenchmark({"sets.intersects_false", "sets.intersects(x, y)", impl, size, CelValue::CreateBool(false)}, state); } void BM_SetsIntersectsComprehensionTrue(benchmark::State& state) { ListImpl impl = FromNumber(state.range(0)); int size = state.range(1); RunBenchmark({"comprehension_intersects_true", "x.exists(i, i in y)", impl, size, CelValue::CreateBool(true)}, state); } void BM_SetsIntersectsComprehensionFalse(benchmark::State& state) { ListImpl impl = FromNumber(state.range(0)); int size = state.range(1); RunBenchmark({"comprehension_intersects_false", "x.exists(i, i in y)", impl, size, CelValue::CreateBool(false)}, state); } void BM_SetsEquivalentTrue(benchmark::State& state) { ListImpl impl = FromNumber(state.range(0)); int size = state.range(1); RunBenchmark({"sets.equivalent_true", "sets.equivalent(x, y)", impl, size, CelValue::CreateBool(true)}, state); } void BM_SetsEquivalentFalse(benchmark::State& state) { ListImpl impl = FromNumber(state.range(0)); int size = state.range(1); RunBenchmark({"sets.equivalent_false", "sets.equivalent(x, y)", impl, size, CelValue::CreateBool(false)}, 
state); } void BM_SetsEquivalentComprehensionTrue(benchmark::State& state) { ListImpl impl = FromNumber(state.range(0)); int size = state.range(1); RunBenchmark( {"comprehension_equivalent_true", "x.all(i, i in y) && y.all(j, j in x)", impl, size, CelValue::CreateBool(true)}, state); } void BM_SetsEquivalentComprehensionFalse(benchmark::State& state) { ListImpl impl = FromNumber(state.range(0)); int size = state.range(1); RunBenchmark( {"comprehension_equivalent_false", "x.all(i, i in y) && y.all(j, j in x)", impl, size, CelValue::CreateBool(false)}, state); } template void BenchArgs(Benchmark* bench) { for (ListImpl impl : {ListImpl::kLegacy, ListImpl::kWrappedModern, ListImpl::kRhsConstant}) { for (int size : {1, 8, 32, 64, 256}) { bench->ArgPair(ToNumber(impl), size); } } } BENCHMARK(BM_SetsIntersectsComprehensionTrue)->Apply(BenchArgs); BENCHMARK(BM_SetsIntersectsComprehensionFalse)->Apply(BenchArgs); BENCHMARK(BM_SetsIntersectsTrue)->Apply(BenchArgs); BENCHMARK(BM_SetsIntersectsFalse)->Apply(BenchArgs); BENCHMARK(BM_SetsEquivalentComprehensionTrue)->Apply(BenchArgs); BENCHMARK(BM_SetsEquivalentComprehensionFalse)->Apply(BenchArgs); BENCHMARK(BM_SetsEquivalentTrue)->Apply(BenchArgs); BENCHMARK(BM_SetsEquivalentFalse)->Apply(BenchArgs); } // namespace } // namespace cel::extensions ================================================ FILE: extensions/sets_functions_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "extensions/sets_functions.h" #include #include #include "cel/expr/syntax.pb.h" #include "absl/status/status_matchers.h" #include "checker/standard_library.h" #include "checker/validation_result.h" #include "common/ast_proto.h" #include "common/minimal_descriptor_pool.h" #include "compiler/compiler_factory.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_function_adapter.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "internal/testing.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" namespace cel::extensions { namespace { using ::google::api::expr::runtime::Activation; using ::google::api::expr::runtime::CelExpressionBuilder; using ::google::api::expr::runtime::CelValue; using ::google::api::expr::runtime::CreateCelExpressionBuilder; using ::google::api::expr::runtime::FunctionAdapter; using ::google::api::expr::runtime::InterpreterOptions; using ::absl_testing::IsOk; using ::google::protobuf::Arena; struct TestInfo { std::string expr; }; class CelSetsFunctionsTest : public testing::TestWithParam {}; TEST_P(CelSetsFunctionsTest, EndToEnd) { const TestInfo& test_info = GetParam(); ASSERT_OK_AND_ASSIGN(auto compiler_builder, NewCompilerBuilder(cel::GetMinimalDescriptorPool())); ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_THAT(compiler_builder->AddLibrary(SetsCompilerLibrary()), IsOk()); ASSERT_OK_AND_ASSIGN(auto compiler, compiler_builder->Build()); ASSERT_OK_AND_ASSIGN(ValidationResult compiled, compiler->Compile(test_info.expr)); ASSERT_TRUE(compiled.IsValid()) << compiled.FormatError(); cel::expr::CheckedExpr checked_expr; ASSERT_THAT(AstToCheckedExpr(*compiled.GetAst(), &checked_expr), IsOk()); // 
Obtain CEL Expression builder. InterpreterOptions options; options.enable_heterogeneous_equality = true; options.enable_empty_wrapper_null_unboxing = true; options.enable_qualified_identifier_rewrites = true; std::unique_ptr builder = CreateCelExpressionBuilder(options); ASSERT_THAT(RegisterSetsFunctions(builder->GetRegistry(), options), IsOk()); ASSERT_THAT(google::api::expr::runtime::RegisterBuiltinFunctions( builder->GetRegistry(), options), IsOk()); // Create CelExpression from AST (Expr object). ASSERT_OK_AND_ASSIGN(auto cel_expr, builder->CreateExpression(&checked_expr)); Arena arena; Activation activation; // Run evaluation. ASSERT_OK_AND_ASSIGN(CelValue out, cel_expr->Evaluate(activation, &arena)); ASSERT_TRUE(out.IsBool()) << test_info.expr << " -> " << out.DebugString(); EXPECT_TRUE(out.BoolOrDie()) << test_info.expr << " -> " << out.DebugString(); } INSTANTIATE_TEST_SUITE_P( CelSetsFunctionsTest, CelSetsFunctionsTest, testing::ValuesIn({ {"sets.contains([], [])"}, {"sets.contains([1], [])"}, {"sets.contains([1], [1])"}, {"sets.contains([1], [1, 1])"}, {"sets.contains([1, 1], [1])"}, {"sets.contains([2, 1], [1])"}, {"sets.contains([1], [1.0, 1u])"}, {"sets.contains([1, 2], [2u, 2.0])"}, {"sets.contains([1, 2u], [2, 2.0])"}, {"!sets.contains([1], [2])"}, {"!sets.contains([1], [1, 2])"}, {"!sets.contains([1], [\"1\", 1])"}, {"!sets.contains([1], [1.1, 2])"}, {"sets.intersects([1], [1])"}, {"sets.intersects([1], [1, 1])"}, {"sets.intersects([1, 1], [1])"}, {"sets.intersects([2, 1], [1])"}, {"sets.intersects([1], [1, 2])"}, {"sets.intersects([1], [1.0, 2])"}, {"sets.intersects([1, 2], [2u, 2, 2.0])"}, {"sets.intersects([1, 2], [1u, 2, 2.3])"}, {"!sets.intersects([], [])"}, {"!sets.intersects([1], [])"}, {"!sets.intersects([1], [2])"}, {"!sets.intersects([1], [\"1\", 2])"}, {"!sets.intersects([1], [1.1, 2u])"}, {"sets.equivalent([], [])"}, {"sets.equivalent([1], [1])"}, {"sets.equivalent([1], [1, 1])"}, {"sets.equivalent([1, 1, 2], [2, 2, 1])"}, 
{"sets.equivalent([1, 1], [1])"}, {"sets.equivalent([1], [1u, 1.0])"}, {"sets.equivalent([1], [1u, 1.0])"}, {"sets.equivalent([1, 2, 3], [3u, 2.0, 1])"}, {"!sets.equivalent([2, 1], [1])"}, {"!sets.equivalent([1], [1, 2])"}, {"!sets.equivalent([1, 2], [2u, 2, 2.0])"}, {"!sets.equivalent([1, 2], [1u, 2, 2.3])"}, {"sets.equivalent([false, true], [true, false])"}, {"!sets.equivalent([true], [false])"}, {"sets.equivalent(['foo', 'bar'], ['bar', 'foo'])"}, {"!sets.equivalent(['foo'], ['bar'])"}, {"sets.equivalent([b'foo', b'bar'], [b'bar', b'foo'])"}, {"!sets.equivalent([b'foo'], [b'bar'])"}, {"sets.equivalent([null], [null])"}, {"!sets.equivalent([null], [])"}, {"sets.equivalent([type(1), type(1u)], [type(1u), type(1)])"}, {"!sets.equivalent([type(1)], [type(1u)])"}, {"sets.equivalent([duration('0s'), duration('1s')], [duration('1s'), " "duration('0s')])"}, {"!sets.equivalent([duration('0s')], [duration('1s')])"}, {"sets.equivalent([timestamp('1970-01-01T00:00:00Z'), " "timestamp('1970-01-01T00:00:01Z')], " "[timestamp('1970-01-01T00:00:01Z'), " "timestamp('1970-01-01T00:00:00Z')])"}, {"!sets.equivalent([timestamp('1970-01-01T00:00:00Z')], " "[timestamp('1970-01-01T00:00:01Z')])"}, {"sets.equivalent([[false, true]], [[false, true]])"}, {"!sets.equivalent([[false, true]], [[true, false]])"}, {"sets.equivalent([{'foo': true, 'bar': false}], [{'bar': false, " "'foo': true}])"}, })); } // namespace } // namespace cel::extensions ================================================ FILE: extensions/strings.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "extensions/strings.h" #include #include #include #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "checker/internal/builtins_arena.h" #include "checker/type_checker_builder.h" #include "common/decl.h" #include "common/type.h" #include "common/value.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "extensions/formatting.h" #include "internal/status_macros.h" #include "runtime/function_adapter.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel::extensions { namespace { using ::cel::checker_internal::BuiltinsArena; struct AppendToStringVisitor { std::string& append_to; void operator()(absl::string_view string) const { append_to.append(string); } void operator()(const absl::Cord& cord) const { append_to.append(static_cast(cord)); } }; absl::StatusOr Join2( const ListValue& value, const StringValue& separator, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { return separator.Join(value, descriptor_pool, message_factory, arena); } absl::StatusOr Join1( const ListValue& value, const google::protobuf::DescriptorPool* absl_nonnull 
descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { return StringValue().Join(value, descriptor_pool, message_factory, arena); } absl::StatusOr Split3( const StringValue& string, const StringValue& delimiter, int64_t limit, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { return string.Split(delimiter, limit, arena); } absl::StatusOr Split2( const StringValue& string, const StringValue& delimiter, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { return string.Split(delimiter, arena); } absl::StatusOr Replace2(const StringValue& string, const StringValue& old_sub, const StringValue& new_sub, int64_t limit, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull arena) { return string.Replace(old_sub, new_sub, limit, arena); } absl::StatusOr Replace1( const StringValue& string, const StringValue& old_sub, const StringValue& new_sub, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { return string.Replace(old_sub, new_sub, -1, arena); } Value CharAt(const StringValue& string, int64_t pos) { return string.CharAt(pos); } int64_t IndexOf2(const StringValue& haystack, const StringValue& needle) { return haystack.IndexOf(needle).value_or(-1); } Value IndexOf3(const StringValue& haystack, const StringValue& needle, int64_t pos) { if (pos > haystack.Size()) { return ErrorValue{ absl::InvalidArgumentError(absl::StrCat("index out of range: ", pos))}; } return IntValue(haystack.IndexOf(needle, pos).value_or(-1)); } int64_t 
LastIndexOf2(const StringValue& haystack, const StringValue& needle) { return haystack.LastIndexOf(needle).value_or(-1); } Value LastIndexOf3(const StringValue& haystack, const StringValue& needle, int64_t pos) { if (pos < 0 || pos > haystack.Size()) { return ErrorValue{ absl::InvalidArgumentError(absl::StrCat("index out of range: ", pos))}; } return IntValue(haystack.LastIndexOf(needle, pos).value_or(-1)); } Value Substring2(const StringValue& string, int64_t start) { return string.Substring(start); } Value Substring3(const StringValue& string, int64_t start, int64_t end) { return string.Substring(start, end); } StringValue Trim(const StringValue& string) { return string.Trim(); } StringValue LowerAscii(const StringValue& string, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull arena) { return string.LowerAscii(arena); } StringValue UpperAscii(const StringValue& string, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull arena) { return string.UpperAscii(arena); } StringValue Quote(const StringValue& string, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull arena) { return string.Quote(arena); } StringValue Reverse(const StringValue& string, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull arena) { return string.Reverse(arena); } const Type& ListStringType() { static absl::NoDestructor kInstance( ListType(BuiltinsArena(), StringType())); return *kInstance; } absl::Status RegisterStringsDecls(TypeCheckerBuilder& builder, int version) { // Runtime Supported functions. 
CEL_ASSIGN_OR_RETURN( auto join_decl, MakeFunctionDecl( "join", MakeMemberOverloadDecl("list_join", StringType(), ListStringType()), MakeMemberOverloadDecl("list_join_string", StringType(), ListStringType(), StringType()))); CEL_ASSIGN_OR_RETURN( auto split_decl, MakeFunctionDecl( "split", MakeMemberOverloadDecl("string_split_string", ListStringType(), StringType(), StringType()), MakeMemberOverloadDecl("string_split_string_int", ListStringType(), StringType(), StringType(), IntType()))); CEL_ASSIGN_OR_RETURN( auto lower_decl, MakeFunctionDecl("lowerAscii", MakeMemberOverloadDecl("string_lower_ascii", StringType(), StringType()))); CEL_ASSIGN_OR_RETURN( auto replace_decl, MakeFunctionDecl( "replace", MakeMemberOverloadDecl("string_replace_string_string", StringType(), StringType(), StringType(), StringType()), MakeMemberOverloadDecl("string_replace_string_string_int", StringType(), StringType(), StringType(), StringType(), IntType()))); // Additional functions described in the spec. CEL_ASSIGN_OR_RETURN( auto char_at_decl, MakeFunctionDecl( "charAt", MakeMemberOverloadDecl("string_char_at_int", StringType(), StringType(), IntType()))); CEL_ASSIGN_OR_RETURN( auto index_of_decl, MakeFunctionDecl( "indexOf", MakeMemberOverloadDecl("string_index_of_string", IntType(), StringType(), StringType()), MakeMemberOverloadDecl("string_index_of_string_int", IntType(), StringType(), StringType(), IntType()))); CEL_ASSIGN_OR_RETURN( auto last_index_of_decl, MakeFunctionDecl( "lastIndexOf", MakeMemberOverloadDecl("string_last_index_of_string", IntType(), StringType(), StringType()), MakeMemberOverloadDecl("string_last_index_of_string_int", IntType(), StringType(), StringType(), IntType()))); CEL_ASSIGN_OR_RETURN( auto substring_decl, MakeFunctionDecl( "substring", MakeMemberOverloadDecl("string_substring_int", StringType(), StringType(), IntType()), MakeMemberOverloadDecl("string_substring_int_int", StringType(), StringType(), IntType(), IntType()))); CEL_ASSIGN_OR_RETURN( auto 
upper_ascii_decl, MakeFunctionDecl("upperAscii", MakeMemberOverloadDecl("string_upper_ascii", StringType(), StringType()))); CEL_ASSIGN_OR_RETURN( auto format_decl, MakeFunctionDecl("format", MakeMemberOverloadDecl("string_format", StringType(), StringType(), ListType()))); CEL_ASSIGN_OR_RETURN( auto quote_decl, MakeFunctionDecl( "strings.quote", MakeOverloadDecl("strings_quote", StringType(), StringType()))); CEL_ASSIGN_OR_RETURN( auto reverse_decl, MakeFunctionDecl("reverse", MakeMemberOverloadDecl("string_reverse", StringType(), StringType()))); CEL_ASSIGN_OR_RETURN( auto trim_decl, MakeFunctionDecl("trim", MakeMemberOverloadDecl( "string_trim", StringType(), StringType()))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(split_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(lower_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(replace_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(char_at_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(index_of_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(last_index_of_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(substring_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(upper_ascii_decl))); CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(trim_decl))); if (version == 0) { return absl::OkStatus(); } CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(format_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(quote_decl))); if (version == 1) { return absl::OkStatus(); } CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(join_decl))); if (version == 2) { return absl::OkStatus(); } // MergeFunction is used to combine with the reverse function // defined in cel.lib.ext.lists extension. 
CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(reverse_decl))); return absl::OkStatus(); } } // namespace absl::Status RegisterStringsFunctions( FunctionRegistry& registry, const RuntimeOptions& options, const StringsExtensionOptions& extension_options) { const int version = extension_options.version; CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StringValue, StringValue>:: CreateDescriptor("split", /*receiver_style=*/true), BinaryFunctionAdapter, StringValue, StringValue>::WrapFunction(Split2))); CEL_RETURN_IF_ERROR(registry.Register( TernaryFunctionAdapter< absl::StatusOr, StringValue, StringValue, int64_t>::CreateDescriptor("split", /*receiver_style=*/true), TernaryFunctionAdapter, StringValue, StringValue, int64_t>::WrapFunction(Split3))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, StringValue>:: CreateDescriptor("lowerAscii", /*receiver_style=*/true), UnaryFunctionAdapter, StringValue>::WrapFunction( LowerAscii))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, StringValue>:: CreateDescriptor("upperAscii", /*receiver_style=*/true), UnaryFunctionAdapter, StringValue>::WrapFunction( UpperAscii))); CEL_RETURN_IF_ERROR(registry.Register( TernaryFunctionAdapter< absl::StatusOr, StringValue, StringValue, StringValue>::CreateDescriptor("replace", /*receiver_style=*/true), TernaryFunctionAdapter, StringValue, StringValue, StringValue>::WrapFunction(Replace1))); CEL_RETURN_IF_ERROR(registry.Register( QuaternaryFunctionAdapter< absl::StatusOr, StringValue, StringValue, StringValue, int64_t>::CreateDescriptor("replace", /*receiver_style=*/true), QuaternaryFunctionAdapter, StringValue, StringValue, StringValue, int64_t>::WrapFunction(Replace2))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterMemberOverload("charAt", &CharAt, registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterMemberOverload("indexOf", &IndexOf2, registry))); CEL_RETURN_IF_ERROR( 
(TernaryFunctionAdapter::RegisterMemberOverload("indexOf", &IndexOf3, registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterMemberOverload("lastIndexOf", &LastIndexOf2, registry))); CEL_RETURN_IF_ERROR( (TernaryFunctionAdapter::RegisterMemberOverload("lastIndexOf", &LastIndexOf3, registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterMemberOverload("substring", &Substring2, registry))); CEL_RETURN_IF_ERROR( (TernaryFunctionAdapter::RegisterMemberOverload("substring", &Substring3, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterMemberOverload( "trim", &Trim, registry))); if (version == 0) { return absl::OkStatus(); } CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions( registry, options, {extension_options.max_precision})); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( "strings.quote", &Quote, registry))); if (version == 1) { return absl::OkStatus(); } CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, ListValue>::CreateDescriptor( "join", /*receiver_style=*/true), UnaryFunctionAdapter, ListValue>::WrapFunction( Join1))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, ListValue, StringValue>:: CreateDescriptor("join", /*receiver_style=*/true), BinaryFunctionAdapter, ListValue, StringValue>::WrapFunction(Join2))); if (version == 2) { return absl::OkStatus(); } CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterMemberOverload( "reverse", &Reverse, registry))); return absl::OkStatus(); } absl::Status RegisterStringsFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, const google::api::expr::runtime::InterpreterOptions& options, const StringsExtensionOptions& extension_options) { return RegisterStringsFunctions( registry->InternalGetRegistry(), google::api::expr::runtime::ConvertToRuntimeOptions(options), extension_options); } CheckerLibrary StringsCheckerLibrary(const StringsExtensionOptions& options) { const int version = options.version; return {"strings", 
[version](TypeCheckerBuilder& builder) { return RegisterStringsDecls(builder, version); }}; } } // namespace cel::extensions ================================================ FILE: extensions/strings.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_STRINGS_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_STRINGS_H_ #include "absl/status/status.h" #include "checker/type_checker_builder.h" #include "compiler/compiler.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" namespace cel::extensions { constexpr int kStringsExtensionLatestVersion = 4; struct StringsExtensionOptions { int version = kStringsExtensionLatestVersion; // Maximum precision allowed for floating point format specifiers in // format() function. This is used for both fixed and scientific notations. // Value must be in the range [0, 1000], otherwise clamped. // // Does not affect default precisions for %e and %f format specifiers. int max_precision = 1000; }; // Register extension functions for strings. 
// Registers the runtime implementations of the strings extension with
// `registry`. The set registered grows with `extension_options.version`:
//   0:  split, lowerAscii, upperAscii, replace, charAt, indexOf, lastIndexOf,
//       substring, trim
//   1+: additionally format (limited by `max_precision`) and strings.quote
//   2+: additionally join
//   3+: additionally reverse
absl::Status RegisterStringsFunctions(
    FunctionRegistry& registry, const RuntimeOptions& options,
    const StringsExtensionOptions& extension_options = {});

// Legacy-API overload: forwards to the FunctionRegistry overload using
// `registry->InternalGetRegistry()` and ConvertToRuntimeOptions(options).
absl::Status RegisterStringsFunctions(
    google::api::expr::runtime::CelFunctionRegistry* registry,
    const google::api::expr::runtime::InterpreterOptions& options,
    const StringsExtensionOptions& extension_options = {});

// Type-checker declarations for the strings extension (library id
// "strings"), gated by the same version tiers as RegisterStringsFunctions.
// Only `extension_options.version` affects the declarations.
CheckerLibrary StringsCheckerLibrary(
    const StringsExtensionOptions& extension_options = {});

// Convenience overload: selects `version`; other options keep defaults.
inline CheckerLibrary StringsCheckerLibrary(int version) {
  StringsExtensionOptions options;
  options.version = version;
  return StringsCheckerLibrary(options);
}

// Compiler-library wrapper around StringsCheckerLibrary(options).
inline CompilerLibrary StringsCompilerLibrary(
    const StringsExtensionOptions& options = {}) {
  return CompilerLibrary::FromCheckerLibrary(StringsCheckerLibrary(options));
}

// Convenience overload: selects `version`; other options keep defaults.
inline CompilerLibrary StringsCompilerLibrary(int version) {
  StringsExtensionOptions options;
  options.version = version;
  return StringsCompilerLibrary(options);
}

}  // namespace cel::extensions

#endif  // THIRD_PARTY_CEL_CPP_EXTENSIONS_STRINGS_H_
================================================
FILE: extensions/strings_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/strings.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "cel/expr/syntax.pb.h"
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "checker/standard_library.h"
#include "checker/type_check_issue.h"
#include "checker/type_checker_builder.h"
#include "checker/validation_result.h"
#include "common/ast.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/value.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "extensions/protobuf/runtime_adapter.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "parser/options.h"
#include "parser/parser.h"
#include "runtime/activation.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "testutil/baseline_tests.h"
#include "google/protobuf/arena.h"

namespace cel::extensions {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::expr::ParsedExpr;
using ::google::api::expr::parser::Parse;
using ::google::api::expr::parser::ParserOptions;
// Contains/Property are used by StringsExtensionVersionTest. Declare them
// explicitly instead of relying on argument-dependent lookup.
using ::testing::Contains;
using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::Property;
using ::testing::Values;
using ::testing::ValuesIn;

// Checks that a strings-extension call type-checks to the expected overload.
TEST(StringsCheckerLibrary, SmokeTest) {
  ASSERT_OK_AND_ASSIGN(
      auto builder, NewCompilerBuilder(internal::GetTestingDescriptorPool()));
  ASSERT_THAT(builder->AddLibrary(StringsCheckerLibrary()), IsOk());
  ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk());
  ASSERT_THAT(builder->GetCheckerBuilder().AddVariable(
                  MakeVariableDecl("foo", StringType())),
              IsOk());

  ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build());

  ASSERT_OK_AND_ASSIGN(
      ValidationResult result,
      compiler->Compile("foo.replace('he', 'we', 1) == 'wello hello'"));
  ASSERT_TRUE(result.IsValid());

  EXPECT_EQ(test::FormatBaselineAst(*result.GetAst()),
            R"(_==_(
  foo~string^foo.replace(
    "he"~string,
    "we"~string,
    1~int
  )~string^string_replace_string_string_int,
  "wello hello"~string
)~bool^equals)");
}

// A format() precision above the configured max_precision must produce an
// evaluation error rather than a value.
TEST(StringsExtTest, MaxPrecisionOption) {
  StringsExtensionOptions extension_options;
  extension_options.max_precision = 99;
  ASSERT_OK_AND_ASSIGN(
      auto compiler_builder,
      NewCompilerBuilder(internal::GetTestingDescriptorPool()));
  ASSERT_THAT(compiler_builder->AddLibrary(StandardCompilerLibrary()), IsOk());
  ASSERT_THAT(compiler_builder->AddLibrary(StringsCompilerLibrary()), IsOk());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<Compiler> compiler,
                       compiler_builder->Build());
  ASSERT_OK_AND_ASSIGN(
      ValidationResult result,
      compiler->Compile("'abc %.100f'.format([2.0])", "<input>"));
  ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> ast, result.ReleaseAst());

  RuntimeOptions opts;
  ASSERT_OK_AND_ASSIGN(
      auto runtime_builder,
      CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts));
  ASSERT_THAT(RegisterStringsFunctions(runtime_builder.function_registry(),
                                       opts, extension_options),
              IsOk());
  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build());
  ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast)));
  google::protobuf::Arena arena;
  cel::Activation activation;
  ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation));
  ASSERT_TRUE(value.Is<ErrorValue>());
  EXPECT_THAT(value.GetError().ToStatus(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("precision specifier exceeds maximum of 99")));
}

// Each parameter is a boolean CEL expression that must compile, type-check and
// evaluate to true.
using StringsExtFunctionsTest = testing::TestWithParam<std::string>;

TEST_P(StringsExtFunctionsTest, ParserAndCheckerTests) {
  const std::string& expr = GetParam();
  ASSERT_OK_AND_ASSIGN(
      auto compiler_builder,
      NewCompilerBuilder(internal::GetTestingDescriptorPool()));
  ASSERT_THAT(compiler_builder->AddLibrary(StandardCompilerLibrary()), IsOk());
  ASSERT_THAT(compiler_builder->AddLibrary(StringsCompilerLibrary()), IsOk());
  ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*compiler_builder).Build());
  auto result = compiler->Compile(expr, "<input>");
  ASSERT_THAT(result, IsOk());
  ASSERT_TRUE(result->IsValid());

  RuntimeOptions opts;
  ASSERT_OK_AND_ASSIGN(
      auto runtime_builder,
      CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts));
  ASSERT_THAT(
      RegisterStringsFunctions(runtime_builder.function_registry(), opts),
      IsOk());
  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build());
  ASSERT_OK_AND_ASSIGN(auto program,
                       runtime->CreateProgram(*result->ReleaseAst()));
  google::protobuf::Arena arena;
  cel::Activation activation;
  ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation));
  ASSERT_TRUE(value.Is<BoolValue>());
  EXPECT_TRUE(value.GetBool().NativeValue());
}

INSTANTIATE_TEST_SUITE_P(
    StringsExtMacrosParamsTest, StringsExtFunctionsTest,
    testing::Values(
        // Tests for split()
        "'hello world!'.split('') == ['h', 'e', 'l', 'l', 'o', ' ', "
        "'w', 'o', 'r', 'l', 'd', '!']",

        // Tests for replace()
        "'hello hello'.replace('he', 'we') == 'wello wello'",
        "'hello hello'.replace('he', 'we', -1) == 'wello wello'",
        "'hello hello'.replace('he', 'we', 1) == 'wello hello'",
        "'hello hello'.replace('he', 'we', 0) == 'hello hello'",

        // Tests for lowerAscii()
        "'UPPER lower'.lowerAscii() == 'upper lower'",

        // Tests for upperAscii()
        "'UPPER lower'.upperAscii() == 'UPPER LOWER'",

        // Tests for format()
        "'abc %.3f'.format([2.0]) == 'abc 2.000'",

        // Tests for charAt()
        "'tacocat'.charAt(3) == 'o'", "'tacocat'.charAt(7) == ''",
        "'©αT'.charAt(0) == '©' && '©αT'.charAt(1) == 'α' && '©αT'.charAt(2) "
        "== 'T'",

        // Tests for indexOf()
        "'tacocat'.indexOf('') == 0", "'tacocat'.indexOf('ac') == 1",
        "'tacocat'.indexOf('none') == -1", "'tacocat'.indexOf('', 3) == 3",
        "'tacocat'.indexOf('a', 3) == 5", "'tacocat'.indexOf('at', 3) == 5",
        "'ta©o©αT'.indexOf('©') == 2", "'ta©o©αT'.indexOf('©', 3) == 4",
        "'ta©o©αT'.indexOf('©αT', 3) == 4", "'ta©o©αT'.indexOf('©α', 5) == -1",
        "'ijk'.indexOf('k') == 2", "'hello wello'.indexOf('hello wello') == 0",
        "'hello wello'.indexOf('ello', 6) == 7",
        "'hello wello'.indexOf('elbo room!!') == -1",
        "'hello wello'.indexOf('elbo room!!!') == -1",
        "''.lastIndexOf('@@') == -1", "'tacocat'.lastIndexOf('') == 7",
        "'tacocat'.lastIndexOf('at') == 5",
        "'tacocat'.lastIndexOf('none') == -1",
        "'tacocat'.lastIndexOf('', 3) == 3",
        "'tacocat'.lastIndexOf('a', 3) == 1", "'ta©o©αT'.lastIndexOf('©') == 4",
        "'ta©o©αT'.lastIndexOf('©', 3) == 2",
        "'ta©o©αT'.lastIndexOf('©α', 4) == 4",
        "'hello wello'.lastIndexOf('ello', 6) == 1",
        "'hello wello'.lastIndexOf('low') == -1",
        "'hello wello'.lastIndexOf('elbo room!!') == -1",
        "'hello wello'.lastIndexOf('elbo room!!!') == -1",
        "'hello wello'.lastIndexOf('hello wello') == 0",
        "'bananananana'.lastIndexOf('nana', 7) == 6",

        // Tests for substring()
        "'tacocat'.substring(4) == 'cat'", "'tacocat'.substring(7) == ''",
        "'tacocat'.substring(0, 4) == 'taco'",
        "'tacocat'.substring(4, 4) == ''",
        "'ta©o©αT'.substring(2, 6) == '©o©α'",
        "'ta©o©αT'.substring(7, 7) == ''",

        // Tests for reverse()
        "''.reverse() == ''", "'hello'.reverse() == 'olleh'",
        "'©αT'.reverse() == 'Tα©'", "'gums'.reverse() == 'smug'",
        "'palindromes'.reverse() == 'semordnilap'",
        "'John Smith'.reverse() == 'htimS nhoJ'",
        "'u180etext'.reverse() == 'txete081u'",
        "'2600+U'.reverse() == 'U+0062'",
        "'\u180e\u200b\u200c\u200d\u2060\ufeff'.reverse() == "
        "'\ufeff\u2060\u200d\u200c\u200b\u180e'",

        // Tests for strings.quote()
        R"(strings.quote("first\nsecond") == "\"first\\nsecond\"")",
        R"(strings.quote("bell\a") == "\"bell\\a\"")",
        R"(strings.quote("\bbackspace") == "\"\\bbackspace\"")",
        R"(strings.quote("\fform feed") == "\"\\fform feed\"")",
        R"(strings.quote("carriage \r return") == "\"carriage \\r return\"")",
        R"(strings.quote("vertical \v tab") == "\"vertical \\v tab\"")",
        R"(strings.quote("verbatim") == "\"verbatim\"")",
        R"(strings.quote("ends with \\") == "\"ends with \\\\\"")",
        R"(strings.quote("\\ starts with") == "\"\\\\ starts with\"")",

        // Tests for trim()
        R"(' \f\n\r\t\vtext '.trim() == 'text')",
        R"('\u0085\u00a0\u1680text'.trim() == 'text')",
        R"('text\u2000\u2001\u2002\u2003\u2004\u2004\u2006\u2007\u2008\u2009'.trim() == 'text')",
        R"('\u200atext\u2028\u2029\u202F\u205F\u3000'.trim() == 'text')",
        R"(' hello world '.trim() == 'hello world')"));

// Basic test for the included declarations.
// Additional coverage for behavior in the spec tests.
class StringsCheckerLibraryTest : public ::testing::TestWithParam<std::string> {
};

TEST_P(StringsCheckerLibraryTest, TypeChecks) {
  const std::string& expr = GetParam();
  ASSERT_OK_AND_ASSIGN(
      auto builder, NewCompilerBuilder(internal::GetTestingDescriptorPool()));
  ASSERT_THAT(builder->AddLibrary(StringsCompilerLibrary()), IsOk());
  ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());

  ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*builder).Build());

  ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(expr));
  EXPECT_TRUE(result.IsValid()) << "Failed to compile: " << expr;
}

INSTANTIATE_TEST_SUITE_P(
    Expressions, StringsCheckerLibraryTest,
    Values("['a', 'b', 'c'].join() == 'abc'",
           "['a', 'b', 'c'].join('|') == 'a|b|c'",
           "'a|b|c'.split('|') == ['a', 'b', 'c']",
           "'a|b|c'.split('|', 1) == ['a', 'b|c']",
           "'AbC'.lowerAscii() == 'abc'",
           "'tacocat'.replace('cat', 'dog') == 'tacodog'",
           "'tacocat'.replace('aco', 'an', 2) == 'tacocat'",
           "'tacocat'.charAt(2) == 'c'", "'tacocat'.indexOf('c') == 2",
           "'tacocat'.indexOf('c', 3) == 4", "'tacocat'.lastIndexOf('c') == 4",
           "'tacocat'.lastIndexOf('c', 5) == -1",
           "'tacocat'.substring(1) == 'acocat'",
           "'tacocat'.substring(1, 3) == 'aco'", "'aBc'.upperAscii() == 'ABC'",
           "'abc %d'.format([2]) == 'abc 2'",
           "strings.quote('abc') == '\"abc\"'", "'abc'.reverse() == 'cba'",
           "'ta©o©αT'.substring(7, 7) == ''"));

// Calls with an unsupported arity must fail at program planning time because
// no runtime overload matches.
class StringsOverloadNotFoundTest
    : public ::testing::TestWithParam<std::string> {};

TEST_P(StringsOverloadNotFoundTest, PlannerTests) {
  const std::string& expr_string = GetParam();
  const auto options = RuntimeOptions{};
  ASSERT_OK_AND_ASSIGN(auto builder,
                       CreateStandardRuntimeBuilder(
                           internal::GetTestingDescriptorPool(), options));
  EXPECT_THAT(RegisterStringsFunctions(builder.function_registry(), options),
              IsOk());

  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr,
                       Parse(expr_string, "<input>", ParserOptions{}));

  EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("No overloads provided")));
}

INSTANTIATE_TEST_SUITE_P(
    OverloadNotFound, StringsOverloadNotFoundTest,
    Values(
        // string_ext.type_errors/indexof_ternary_invalid_arguments
        "'42'.indexOf('4', 0, 1) == 0",
        // string_ext.type_errors/replace_quaternary_invalid_argument
        "'42'.replace('2', '1', 1, false) == '41'",
        // string_ext.type_errors/split_ternary_invalid_argument
        "'42'.split('2', 1, 1) == ['4']",
        // string_ext.type_errors/substring_ternary_invalid_argument
        "'hello'.substring(1, 2, 3) == ''"));

// Out-of-range substring() arguments must evaluate to an InvalidArgument
// error value.
class StringsRuntimeErrorTest : public ::testing::TestWithParam<std::string> {};

TEST_P(StringsRuntimeErrorTest, EvaluationErrors) {
  const std::string& expr = GetParam();
  ASSERT_OK_AND_ASSIGN(
      auto compiler_builder,
      NewCompilerBuilder(internal::GetTestingDescriptorPool()));
  ASSERT_THAT(compiler_builder->AddLibrary(StandardCompilerLibrary()), IsOk());
  ASSERT_THAT(compiler_builder->AddLibrary(StringsCompilerLibrary()), IsOk());
  ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*compiler_builder).Build());
  auto result = compiler->Compile(expr, "<input>");
  ASSERT_THAT(result, IsOk());
  ASSERT_TRUE(result->IsValid());

  RuntimeOptions opts;
  ASSERT_OK_AND_ASSIGN(
      auto runtime_builder,
      CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts));
  ASSERT_THAT(
      RegisterStringsFunctions(runtime_builder.function_registry(), opts),
      IsOk());
  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build());
  ASSERT_OK_AND_ASSIGN(auto program,
                       runtime->CreateProgram(*result->ReleaseAst()));
  google::protobuf::Arena arena;
  cel::Activation activation;
  ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation));
  ASSERT_TRUE(value.Is<ErrorValue>());
  EXPECT_THAT(value.As<ErrorValue>()->NativeValue().code(),
              absl::StatusCode::kInvalidArgument);
}

INSTANTIATE_TEST_SUITE_P(EvaluationErrors, StringsRuntimeErrorTest,
                         Values("'a'.substring(-1)", "'a'.substring(2)",
                                "'a'.substring(0, -1)", "'a'.substring(0, 2)",
                                "'a'.substring(1, 0)"));

// One expression plus the extension versions at which it must type-check.
struct StringsExtensionVersionTestCase {
  std::string expr;
  std::vector<int> expected_supported_versions;
};

class StringsExtensionVersionTest
    : public ::testing::TestWithParam<StringsExtensionVersionTestCase> {};

// For every version in [0, latest], the expression must compile cleanly when
// the version is listed as supported. Otherwise it must report an undeclared
// reference.
TEST_P(StringsExtensionVersionTest, StringsExtensionVersions) {
  const StringsExtensionVersionTestCase& test_case = GetParam();
  for (int version = 0;
       version <= cel::extensions::kStringsExtensionLatestVersion; ++version) {
    CompilerLibrary compiler_library = StringsCompilerLibrary(version);
    ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<CompilerBuilder> builder,
        cel::NewCompilerBuilder(internal::GetTestingDescriptorPool(),
                                CompilerOptions()));
    ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
    ASSERT_THAT(builder->AddLibrary(std::move(compiler_library)), IsOk());
    ASSERT_OK_AND_ASSIGN(std::unique_ptr<Compiler> compiler, builder->Build());
    ASSERT_OK_AND_ASSIGN(ValidationResult result,
                         compiler->Compile(test_case.expr));
    if (absl::c_contains(test_case.expected_supported_versions, version)) {
      EXPECT_THAT(result.GetIssues(), IsEmpty())
          << "Expected no issues for expr: " << test_case.expr
          << " at version: " << version << " but got: " << result.FormatError();
    } else {
      EXPECT_THAT(result.GetIssues(),
                  Contains(Property(&TypeCheckIssue::message,
                                    HasSubstr("undeclared reference"))))
          << "Expected an undeclared reference for expr: " << test_case.expr
          << " at version: " << version;
    }
  }
}

std::vector<StringsExtensionVersionTestCase>
CreateStringsExtensionVersionParams() {
  return {
      StringsExtensionVersionTestCase{
          .expr = "'foo'.charAt(0)",
          .expected_supported_versions = {0, 1, 2, 3, 4},
      },
      StringsExtensionVersionTestCase{
          .expr = "'foo'.indexOf('f')",
          .expected_supported_versions = {0, 1, 2, 3, 4},
      },
      StringsExtensionVersionTestCase{
          .expr = "'foo'.lastIndexOf('f')",
          .expected_supported_versions = {0, 1, 2, 3, 4},
      },
      StringsExtensionVersionTestCase{
          .expr = "'foo'.lowerAscii()",
          .expected_supported_versions = {0, 1, 2, 3, 4},
      },
      StringsExtensionVersionTestCase{
          .expr = "'foo'.replace('f', 'b')",
          .expected_supported_versions = {0, 1, 2, 3, 4},
      },
      StringsExtensionVersionTestCase{
          .expr = "'foo'.split('o')",
          .expected_supported_versions = {0, 1, 2, 3, 4},
      },
      StringsExtensionVersionTestCase{
          .expr = "'foo'.substring(0, 1)",
          .expected_supported_versions = {0, 1, 2, 3, 4},
      },
      StringsExtensionVersionTestCase{
          .expr = "'foo'.trim()",
          .expected_supported_versions = {0, 1, 2, 3, 4},
      },
      StringsExtensionVersionTestCase{
          .expr = "'foo'.upperAscii()",
          .expected_supported_versions = {0, 1, 2, 3, 4},
      },
      StringsExtensionVersionTestCase{
          .expr = "'%d'.format([1])",
          .expected_supported_versions = {1, 2, 3, 4},
      },
      StringsExtensionVersionTestCase{
          .expr = "strings.quote('foo')",
          .expected_supported_versions = {1, 2, 3, 4},
      },
      StringsExtensionVersionTestCase{
          .expr = "['a', 'b', 'c'].join(',')",
          .expected_supported_versions = {2, 3, 4},
      },
      StringsExtensionVersionTestCase{
          .expr = "'foo'.reverse()",
          .expected_supported_versions = {3, 4},
      },
  };
}

INSTANTIATE_TEST_SUITE_P(StringsExtensionVersionTest,
                         StringsExtensionVersionTest,
                         ValuesIn(CreateStringsExtensionVersionParams()));

}  // namespace
}  // namespace cel::extensions
================================================
FILE: internal/BUILD
================================================
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Build targets for CEL's internal utility libraries: alignment/allocation
# helpers, string and UTF-8 handling, time and overflow arithmetic, protobuf
# helpers, embedded descriptor sets and pools, and shared test helpers. Most
# libraries are paired with a matching cc_test.
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")
load("//bazel:cel_cc_embed.bzl", "cel_cc_embed")
load("//bazel:cel_proto_transitive_descriptor_set.bzl", "cel_proto_transitive_descriptor_set")

package(default_visibility = ["//visibility:public"])

licenses(["notice"])

# Header-only alignment helpers (AlignmentMask/AlignDown/AlignUp/IsAligned).
cc_library(
    name = "align",
    hdrs = ["align.h"],
    deps = [
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/base:config",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/numeric:bits",
    ],
)

# NOTE(review): the no_test_msvc tag presumably excludes this test from MSVC
# test runs; confirm against the CI configuration.
cc_test(
    name = "align_test",
    srcs = ["align_test.cc"],
    tags = ["no_test_msvc"],
    deps = [
        ":align",
        ":testing",
    ],
)

cc_library(
    name = "new",
    srcs = ["new.cc"],
    hdrs = ["new.h"],
    deps = [
        ":align",
        "@com_google_absl//absl/base:config",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/numeric:bits",
    ],
)

cc_test(
    name = "new_test",
    srcs = ["new_test.cc"],
    deps = [
        ":new",
        ":testing",
    ],
)

# Test-only wrapper that pulls in Google Benchmark (with its main()).
cc_library(
    name = "benchmark",
    testonly = True,
    hdrs = ["benchmark.h"],
    deps = ["@com_github_google_benchmark//:benchmark_main"],
)

cc_library(
    name = "casts",
    hdrs = ["casts.h"],
)

cc_library(
    name = "re2_options",
    hdrs = ["re2_options.h"],
    deps = [
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_googlesource_code_re2//:re2",
    ],
)

cc_library(
    name = "status_builder",
    hdrs = ["status_builder.h"],
    deps = [
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/status",
    ],
)

cc_library(
    name = "overflow",
    srcs = ["overflow.cc"],
    hdrs = ["overflow.h"],
    deps = [
        ":status_macros",
        ":time",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/time",
    ],
)

cc_test(
    name = "overflow_test",
    srcs = ["overflow_test.cc"],
    deps = [
        ":overflow",
        ":testing",
        "@com_google_absl//absl/functional:function_ref",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/time",
    ],
)

cc_library(
    name = "number",
    hdrs = ["number.h"],
    deps = ["@com_google_absl//absl/types:variant"],
)

cc_test(
    name = "number_test",
    srcs = ["number_test.cc"],
    deps = [
        ":number",
        ":testing",
    ],
)

cc_library(
    name = "exceptions",
    hdrs = ["exceptions.h"],
    deps = ["@com_google_absl//absl/base:config"],
)

cc_library(
    name = "status_macros",
    hdrs = ["status_macros.h"],
    deps = [
        ":status_builder",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/status",
    ],
)

cc_library(
    name = "string_pool",
    srcs = ["string_pool.cc"],
    hdrs = ["string_pool.h"],
    deps = [
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log:die_if_null",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "string_pool_test",
    srcs = ["string_pool_test.cc"],
    deps = [
        ":string_pool",
        ":testing",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "strings",
    srcs = ["strings.cc"],
    hdrs = ["strings.h"],
    deps = [
        ":lexis",
        ":unicode",
        ":utf8",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
    ],
)

cc_test(
    name = "strings_test",
    srcs = ["strings_test.cc"],
    deps = [
        ":strings",
        ":testing",
        ":utf8",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:cord_test_helpers",
        "@com_google_absl//absl/strings:str_format",
    ],
)

cc_library(
    name = "lexis",
    srcs = ["lexis.cc"],
    hdrs = ["lexis.h"],
    deps = [
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
    ],
)

cc_test(
    name = "lexis_test",
    srcs = ["lexis_test.cc"],
    deps = [
        ":lexis",
        ":testing",
    ],
)

cc_library(
    name = "proto_util",
    hdrs = ["proto_util.h"],
    deps = [
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_protobuf//:differencer",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "proto_util_test",
    srcs = ["proto_util_test.cc"],
    deps = [
        ":proto_util",
        ":testing",
        "//eval/public/structs:cel_proto_descriptor_pool_builder",
        "@com_google_absl//absl/status",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "proto_time_encoding",
    srcs = ["proto_time_encoding.cc"],
    hdrs = ["proto_time_encoding.h"],
    deps = [
        ":status_macros",
        ":time",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/time",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:time_util",
        "@com_google_protobuf//:timestamp_cc_proto",
    ],
)

cc_test(
    name = "proto_time_encoding_test",
    srcs = ["proto_time_encoding_test.cc"],
    deps = [
        ":proto_time_encoding",
        ":testing",
        "//testutil:util",
        "@com_google_absl//absl/time",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:timestamp_cc_proto",
    ],
)

# Shared test helpers. Both targets compile testing.cc/testing.h.
# `testing` links gtest_main; `testing_no_main` links plain gtest, for test
# binaries that supply their own main().
cc_library(
    name = "testing",
    testonly = True,
    srcs = [
        "testing.cc",
    ],
    hdrs = [
        "testing.h",
    ],
    deps = [
        ":status_macros",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "testing_no_main",
    testonly = True,
    srcs = [
        "testing.cc",
    ],
    hdrs = [
        "testing.h",
    ],
    deps = [
        ":status_macros",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest",
    ],
)

cc_library(
    name = "time",
    srcs = ["time.cc"],
    hdrs = ["time.h"],
    deps = [
        ":status_macros",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/time",
        "@com_google_protobuf//:time_util",
    ],
)

cc_test(
    name = "time_test",
    srcs = ["time_test.cc"],
    deps = [
        ":testing",
        ":time",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/time",
        "@com_google_protobuf//:time_util",
    ],
)

cc_library(
    name = "unicode",
    hdrs = ["unicode.h"],
)

cc_library(
    name = "utf8",
    srcs = ["utf8.cc"],
    hdrs = ["utf8.h"],
    deps = [
        ":unicode",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
    ],
)

cc_test(
    name = "utf8_test",
    srcs = ["utf8_test.cc"],
    deps = [
        ":benchmark",
        ":testing",
        ":utf8",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:cord_test_helpers",
    ],
)

cc_library(
    name = "proto_matchers",
    testonly = True,
    hdrs = ["proto_matchers.h"],
    deps = [
        ":casts",
        ":testing",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/memory",
        "@com_google_protobuf//:differencer",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "proto_file_util",
    testonly = True,
    hdrs = ["proto_file_util.h"],
    deps = [
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//src/google/protobuf/io",
    ],
)

cc_library(
    name = "names",
    srcs = ["names.cc"],
    hdrs = ["names.h"],
    deps = [
        ":lexis",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
    ],
)

cc_test(
    name = "names_test",
    srcs = ["names_test.cc"],
    deps = [
        ":names",
        ":testing",
    ],
)

cc_library(
    name = "to_address",
    hdrs = ["to_address.h"],
    deps = [
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/meta:type_traits",
    ],
)

cc_test(
    name = "to_address_test",
    srcs = ["to_address_test.cc"],
    deps = [
        ":testing",
        ":to_address",
    ],
)

# Descriptor set for google/protobuf/empty.proto. It is embedded into C++ by
# cel_cc_embed and consumed as a textual header by :empty_descriptors.
cel_proto_transitive_descriptor_set(
    name = "empty_descriptor_set",
    deps = [
        "@com_google_protobuf//:empty_proto",
    ],
)

cel_cc_embed(
    name = "empty_descriptor_set_embed",
    src = ":empty_descriptor_set",
)

cc_library(
    name = "empty_descriptors",
    srcs = ["empty_descriptors.cc"],
    hdrs = ["empty_descriptors.h"],
    textual_hdrs = [":empty_descriptor_set_embed"],
    deps = [
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/log:die_if_null",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "empty_descriptors_test",
    srcs = ["empty_descriptors_test.cc"],
    deps = [
        ":empty_descriptors",
        ":testing",
    ],
)

# Descriptor set for the well-known types (any, duration, struct, timestamp,
# wrappers). It is embedded and consumed by :minimal_descriptors.
cel_proto_transitive_descriptor_set(
    name = "minimal_descriptor_set",
    deps = [
        "@com_google_protobuf//:any_proto",
        "@com_google_protobuf//:duration_proto",
        "@com_google_protobuf//:struct_proto",
        "@com_google_protobuf//:timestamp_proto",
        "@com_google_protobuf//:wrappers_proto",
    ],
)

cel_cc_embed(
    name = "minimal_descriptor_set_embed",
    src = ":minimal_descriptor_set",
)

# :minimal_descriptor_pool is an alternate label for :minimal_descriptors.
alias(
    name = "minimal_descriptor_pool",
    actual = ":minimal_descriptors",
)

cc_library(
    name = "minimal_descriptors",
    srcs = ["minimal_descriptors.cc"],
    hdrs = [
        "minimal_descriptor_database.h",
        "minimal_descriptor_pool.h",
    ],
    textual_hdrs = [":minimal_descriptor_set_embed"],
    deps = [
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

# Test-only descriptor set covering the CEL spec protos, the conformance
# TestAllTypes messages, project test messages and the protobuf well-known
# types. It is embedded and consumed by :testing_descriptor_pool.
cel_proto_transitive_descriptor_set(
    name = "testing_descriptor_set",
    testonly = True,
    deps = [
        "//eval/testutil:test_extensions_proto",
        "//eval/testutil:test_message_proto",
        "//testutil:test_json_names_proto",
        "@com_google_cel_spec//proto/cel/expr:checked_proto",
        "@com_google_cel_spec//proto/cel/expr:expr_proto",
        "@com_google_cel_spec//proto/cel/expr:syntax_proto",
        "@com_google_cel_spec//proto/cel/expr:value_proto",
        "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_proto",
        "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_proto",
        "@com_google_protobuf//:any_proto",
        "@com_google_protobuf//:duration_proto",
        "@com_google_protobuf//:empty_proto",
        "@com_google_protobuf//:field_mask_proto",
        "@com_google_protobuf//:struct_proto",
        "@com_google_protobuf//:timestamp_proto",
        "@com_google_protobuf//:wrappers_proto",
    ],
)

cel_cc_embed(
    name = "testing_descriptor_set_embed",
    testonly = True,
    src = ":testing_descriptor_set",
)

cc_library(
    name = "testing_descriptor_pool",
    testonly = True,
    srcs = ["testing_descriptor_pool.cc"],
    hdrs = ["testing_descriptor_pool.h"],
    textual_hdrs = [":testing_descriptor_set_embed"],
    deps = [
        ":noop_delete",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "testing_descriptor_pool_test",
    srcs = ["testing_descriptor_pool_test.cc"],
    deps = [
        ":testing",
        ":testing_descriptor_pool",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "message_type_name",
    hdrs = ["message_type_name.h"],
    deps = [
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "message_type_name_test",
    srcs = ["message_type_name_test.cc"],
    deps = [
        ":message_type_name",
        ":testing",
        "@com_google_protobuf//:any_cc_proto",
    ],
)

# Test-only text-proto helpers built on the testing descriptor pool and
# message factory.
cc_library(
    name = "parse_text_proto",
    testonly = True,
    hdrs = ["parse_text_proto.h"],
    deps = [
        ":message_type_name",
        ":testing_descriptor_pool",
        ":testing_message_factory",
        "//common:memory",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/log:die_if_null",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "equals_text_proto",
    testonly = True,
    srcs = ["equals_text_proto.cc"],
    hdrs = ["equals_text_proto.h"],
    deps = [
        ":parse_text_proto",
        ":testing",
        ":testing_descriptor_pool",
        ":testing_message_factory",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_protobuf//:differencer",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "testing_message_factory",
    testonly = True,
    srcs = ["testing_message_factory.cc"],
    hdrs = ["testing_message_factory.h"],
    deps = [
        ":testing_descriptor_pool",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "well_known_types",
    srcs = ["well_known_types.cc"],
    hdrs = ["well_known_types.h"],
    deps = [
        ":protobuf_runtime_version",
        ":status_macros",
        "//common:any",
        "//common:json",
        "//common:memory",
        "//extensions/protobuf/internal:map_reflection",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:variant",
        "@com_google_protobuf//:any_cc_proto",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:field_mask_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
        "@com_google_protobuf//:time_util",
        "@com_google_protobuf//:timestamp_cc_proto",
        "@com_google_protobuf//:wrappers_cc_proto",
    ],
)

cc_test(
    name = "well_known_types_test",
    srcs = ["well_known_types_test.cc"],
    deps = [
        ":message_type_name",
        ":minimal_descriptor_pool",
        ":parse_text_proto",
        ":testing",
        ":testing_descriptor_pool",
        ":testing_message_factory",
        ":well_known_types",
        "//common:memory",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:die_if_null",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:variant",
        "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto",
        "@com_google_protobuf//:any_cc_proto",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:field_mask_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
        "@com_google_protobuf//:timestamp_cc_proto",
        "@com_google_protobuf//:wrappers_cc_proto",
    ],
)

cc_library(
    name = "json",
    srcs = ["json.cc"],
    hdrs = ["json.h"],
    deps = [
        ":status_macros",
        ":strings",
        ":well_known_types",
        "//extensions/protobuf/internal:map_reflection",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:variant",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
        "@com_google_protobuf//:time_util",
        "@com_google_protobuf//:timestamp_cc_proto",
    ],
)

cc_test(
    name = "json_test",
    srcs = ["json_test.cc"],
    deps = [
        ":equals_text_proto",
        ":json",
        ":message_type_name",
        ":parse_text_proto",
        ":testing",
        ":testing_descriptor_pool",
        ":testing_message_factory",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:die_if_null",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto",
        "@com_google_protobuf//:any_cc_proto",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:field_mask_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
        "@com_google_protobuf//:timestamp_cc_proto",
        "@com_google_protobuf//:wrappers_cc_proto",
    ],
)

cc_library(
    name = "message_equality",
    srcs = ["message_equality.cc"],
    hdrs = ["message_equality.h"],
    deps = [
        ":json",
        ":number",
        ":status_macros",
        ":well_known_types",
        "//common:memory",
        "//extensions/protobuf/internal:map_reflection",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:variant",
        "@com_google_protobuf//:differencer",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "message_equality_test",
    srcs = ["message_equality_test.cc"],
    tags = ["no_test_msvc"],
    deps = [
        ":message_equality",
        ":message_type_name",
        ":parse_text_proto",
        ":testing",
        ":testing_descriptor_pool",
        ":testing_message_factory",
        ":well_known_types",
        "//common:allocator",
        "//common:memory",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/log:die_if_null",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:optional",
        "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto",
        "@com_google_protobuf//:any_cc_proto",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:protobuf",
        "@com_google_protobuf//:struct_cc_proto",
        "@com_google_protobuf//:timestamp_cc_proto",
        "@com_google_protobuf//:wrappers_cc_proto",
    ],
)

cc_library(
    name = "protobuf_runtime_version",
    hdrs = ["protobuf_runtime_version.h"],
    deps = ["@com_google_protobuf//:protobuf"],
)

cc_library(
    name = "noop_delete",
    hdrs = ["noop_delete.h"],
    deps = ["@com_google_absl//absl/base:nullability"],
)

cc_library(
    name = "manual",
    hdrs = ["manual.h"],
    deps = [
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
    ],
)
================================================
FILE: internal/align.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_ALIGN_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_ALIGN_H_ #include #include #include #include "absl/base/casts.h" #include "absl/base/config.h" #include "absl/base/macros.h" #include "absl/numeric/bits.h" namespace cel::internal { template constexpr std::enable_if_t< std::conjunction_v, std::is_unsigned>, T> AlignmentMask(T alignment) { ABSL_ASSERT(absl::has_single_bit(alignment)); return alignment - T{1}; } template std::enable_if_t, std::is_unsigned>, T> AlignDown(T x, size_t alignment) { ABSL_ASSERT(absl::has_single_bit(alignment)); #if ABSL_HAVE_BUILTIN(__builtin_align_up) return __builtin_align_down(x, alignment); #else using C = std::common_type_t; return static_cast(static_cast(x) & ~AlignmentMask(static_cast(alignment))); #endif } template std::enable_if_t, T> AlignDown(T x, size_t alignment) { return absl::bit_cast(AlignDown(absl::bit_cast(x), alignment)); } template std::enable_if_t, std::is_unsigned>, T> AlignUp(T x, size_t alignment) { ABSL_ASSERT(absl::has_single_bit(alignment)); #if ABSL_HAVE_BUILTIN(__builtin_align_up) return __builtin_align_up(x, alignment); #else using C = std::common_type_t; return static_cast(AlignDown( static_cast(x) + AlignmentMask(static_cast(alignment)), alignment)); #endif } template std::enable_if_t, T> AlignUp(T x, size_t alignment) { return absl::bit_cast(AlignUp(absl::bit_cast(x), alignment)); } template constexpr std::enable_if_t< std::conjunction_v, std::is_unsigned>, bool> IsAligned(T x, size_t alignment) { ABSL_ASSERT(absl::has_single_bit(alignment)); #if ABSL_HAVE_BUILTIN(__builtin_is_aligned) return __builtin_is_aligned(x, alignment); #else using C = std::common_type_t; return (static_cast(x) & AlignmentMask(static_cast(alignment))) == C{0}; #endif } template std::enable_if_t, bool> IsAligned(T x, size_t alignment) { return IsAligned(absl::bit_cast(x), alignment); } } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_ALIGN_H_ 
================================================ FILE: internal/align_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "internal/align.h" #include #include #include "internal/testing.h" namespace cel::internal { namespace { TEST(AlignmentMask, Masks) { EXPECT_EQ(AlignmentMask(size_t{1}), size_t{0}); EXPECT_EQ(AlignmentMask(size_t{2}), size_t{1}); EXPECT_EQ(AlignmentMask(size_t{4}), size_t{3}); } TEST(AlignDown, Aligns) { EXPECT_EQ(AlignDown(uintptr_t{3}, 4), 0); EXPECT_EQ(AlignDown(uintptr_t{0}, 4), 0); EXPECT_EQ(AlignDown(uintptr_t{5}, 4), 4); EXPECT_EQ(AlignDown(uintptr_t{4}, 4), 4); uint64_t val = 0; EXPECT_EQ(AlignDown(&val, alignof(val)), &val); } TEST(AlignUp, Aligns) { EXPECT_EQ(AlignUp(uintptr_t{0}, 4), 0); EXPECT_EQ(AlignUp(uintptr_t{3}, 4), 4); EXPECT_EQ(AlignUp(uintptr_t{5}, 4), 8); uint64_t val = 0; EXPECT_EQ(AlignUp(&val, alignof(val)), &val); } TEST(IsAligned, Aligned) { EXPECT_TRUE(IsAligned(uintptr_t{0}, 4)); EXPECT_TRUE(IsAligned(uintptr_t{4}, 4)); EXPECT_FALSE(IsAligned(uintptr_t{3}, 4)); EXPECT_FALSE(IsAligned(uintptr_t{5}, 4)); uint64_t val = 0; EXPECT_TRUE(IsAligned(&val, alignof(val))); } } // namespace } // namespace cel::internal ================================================ FILE: internal/benchmark.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // 
you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_BENCHMARK_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_BENCHMARK_H_ #include "benchmark/benchmark.h" // IWYU pragma: export #endif // THIRD_PARTY_CEL_CPP_INTERNAL_BENCHMARK_H_ ================================================ FILE: internal/casts.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_CASTS_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_CASTS_H_ #include #include #include namespace cel::internal { template To down_cast(From* from) { static_assert(std::is_pointer_v, "Target type not a pointer."); static_assert((std::is_base_of_v>), "Target type not derived from source type."); #if !defined(__GNUC__) || defined(__GXX_RTTI) assert(from == nullptr || dynamic_cast(from) != nullptr); #endif return static_cast(from); } template To down_cast(From& from) { static_assert(std::is_lvalue_reference_v, "Target type not a lvalue reference."); static_assert((std::is_base_of_v>), "Target type not derived from source type."); #if !defined(__GNUC__) || defined(__GXX_RTTI) assert(dynamic_cast>>( std::addressof(from)) != nullptr); #endif return static_cast(from); } } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_CASTS_H_ ================================================ FILE: internal/empty_descriptors.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "internal/empty_descriptors.h"

#include <cstdint>

#include "google/protobuf/descriptor.pb.h"
#include "absl/base/attributes.h"
#include "absl/base/macros.h"
#include "absl/base/no_destructor.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/log/die_if_null.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/message.h"

namespace cel::internal {

namespace {

// Serialized `google.protobuf.FileDescriptorSet` embedded at build time.
ABSL_CONST_INIT const uint8_t kEmptyDescriptorSet[] = {
#include "internal/empty_descriptor_set_embed.inc"
};

// Returns a process-lifetime DescriptorPool built from `kEmptyDescriptorSet`.
// The pool is intentionally leaked. Any parse or build failure is fatal.
const google::protobuf::DescriptorPool* absl_nonnull GetEmptyDescriptorPool() {
  static const google::protobuf::DescriptorPool* absl_nonnull const pool = []() {
    google::protobuf::FileDescriptorSet file_desc_set;
    ABSL_CHECK(file_desc_set.ParseFromArray(  // Crash OK
        kEmptyDescriptorSet, ABSL_ARRAYSIZE(kEmptyDescriptorSet)));
    // Named distinctly so it does not shadow the static being initialized.
    auto* built_pool = new google::protobuf::DescriptorPool();
    for (const auto& file_desc : file_desc_set.file()) {
      ABSL_CHECK(built_pool->BuildFile(file_desc) != nullptr);  // Crash OK
    }
    return built_pool;
  }();
  return pool;
}

// Returns a process-lifetime dynamic message factory. It is used to
// instantiate messages whose descriptors live in `GetEmptyDescriptorPool()`.
google::protobuf::MessageFactory* absl_nonnull GetEmptyMessageFactory() {
  static absl::NoDestructor<google::protobuf::DynamicMessageFactory> factory;
  return &*factory;
}

}  // namespace

const google::protobuf::Message* absl_nonnull GetEmptyDefaultInstance() {
  static const google::protobuf::Message* absl_nonnull const instance = []() {
    const google::protobuf::Descriptor* descriptor = ABSL_DIE_IF_NULL(  // Crash OK
        GetEmptyDescriptorPool()->FindMessageTypeByName(
            "google.protobuf.Empty"));
    // A single null check on the prototype suffices; the previous nested
    // ABSL_DIE_IF_NULL checked the very same pointer twice.
    const google::protobuf::Message* prototype = ABSL_DIE_IF_NULL(  // Crash OK
        GetEmptyMessageFactory()->GetPrototype(descriptor));
    return prototype->New();
  }();
  return instance;
}

}  // namespace cel::internal
================================================
FILE: internal/empty_descriptors.h
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_EMPTY_DESCRIPTORS_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_EMPTY_DESCRIPTORS_H_ #include "absl/base/nullability.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel::internal { // GetEmptyDefaultInstance returns a pointer to a `google::protobuf::Message` which is an // instance of `google.protobuf.Empty`. The returned `google::protobuf::Message` is valid // for the lifetime of the process. const google::protobuf::Message* absl_nonnull GetEmptyDefaultInstance(); } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_EMPTY_DESCRIPTORS_H_ ================================================ FILE: internal/empty_descriptors_test.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "internal/empty_descriptors.h"

#include "internal/testing.h"

namespace cel::internal {
namespace {

// The singleton must exist, must describe google.protobuf.Empty, and must be
// the very same object on every call.
TEST(GetEmptyDefaultInstance, Empty) {
  const auto* const first = GetEmptyDefaultInstance();
  ASSERT_THAT(first, ::testing::NotNull());

  const auto* const descriptor = first->GetDescriptor();
  EXPECT_EQ(descriptor->full_name(), "google.protobuf.Empty");

  const auto* const second = GetEmptyDefaultInstance();
  EXPECT_EQ(first, second);
}

}  // namespace
}  // namespace cel::internal
================================================
FILE: internal/equals_text_proto.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "internal/equals_text_proto.h" #include #include #include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/strings/cord.h" #include "internal/testing.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" #include "google/protobuf/text_format.h" #include "google/protobuf/util/message_differencer.h" namespace cel::internal { void TextProtoMatcher::DescribeTo(std::ostream* os) const { std::string text; ABSL_CHECK( // Crash OK google::protobuf::TextFormat::PrintToString(*message_, &text)); *os << "is equal to <" << text << ">"; } void TextProtoMatcher::DescribeNegationTo(std::ostream* os) const { std::string text; ABSL_CHECK( // Crash OK google::protobuf::TextFormat::PrintToString(*message_, &text)); *os << "is not equal to <" << text << ">"; } bool TextProtoMatcher::MatchAndExplain( const google::protobuf::MessageLite& other, ::testing::MatchResultListener* listener) const { if (other.GetTypeName() != message_->GetTypeName()) { if (listener->IsInterested()) { *listener << "whose type should be " << message_->GetTypeName() << " but actually is " << other.GetTypeName(); } return false; } google::protobuf::util::MessageDifferencer differencer; std::string diff; if (listener->IsInterested()) { differencer.ReportDifferencesToString(&diff); } bool match; if (const auto* other_full_message = google::protobuf::DynamicCastMessage(&other); other_full_message != nullptr && other_full_message->GetDescriptor() == message_->GetDescriptor()) { match = differencer.Compare(*other_full_message, *message_); } else { auto other_message = absl::WrapUnique(message_->New()); absl::Cord serialized; ABSL_CHECK(other.SerializeToString(&serialized)); // Crash OK ABSL_CHECK(other_message->ParseFromString(serialized)); // Crash OK match = differencer.Compare(*other_message, *message_); } if (!match && listener->IsInterested()) { if (!diff.empty() && diff.back() == '\n') { diff.erase(diff.end() - 1); } *listener << "with the difference:\n" 
<< diff; } return match; } } // namespace cel::internal ================================================ FILE: internal/equals_text_proto.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_EQUALS_PROTO_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_EQUALS_PROTO_H_ #include #include "absl/base/nullability.h" #include "absl/strings/string_view.h" #include "internal/parse_text_proto.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" namespace cel::internal { class TextProtoMatcher { public: TextProtoMatcher(const google::protobuf::Message* absl_nonnull message, const google::protobuf::DescriptorPool* absl_nonnull pool, google::protobuf::MessageFactory* absl_nonnull factory) : message_(message), pool_(pool), factory_(factory) {} void DescribeTo(std::ostream* os) const; void DescribeNegationTo(std::ostream* os) const; bool MatchAndExplain(const google::protobuf::MessageLite& other, ::testing::MatchResultListener* listener) const; private: const google::protobuf::Message* absl_nonnull message_; const google::protobuf::DescriptorPool* absl_nonnull pool_; google::protobuf::MessageFactory* absl_nonnull factory_; }; template 
::testing::PolymorphicMatcher EqualsTextProto( google::protobuf::Arena* absl_nonnull arena, absl::string_view text, const google::protobuf::DescriptorPool* absl_nonnull pool = GetTestingDescriptorPool(), google::protobuf::MessageFactory* absl_nonnull factory = GetTestingMessageFactory()) { return ::testing::MakePolymorphicMatcher(TextProtoMatcher( DynamicParseTextProto(arena, text, pool, factory), pool, factory)); } } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_EQUALS_PROTO_H_ ================================================ FILE: internal/exceptions.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_EXCEPTIONS_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_EXCEPTIONS_H_ #include "absl/base/config.h" // IWYU pragma: keep #ifdef ABSL_HAVE_EXCEPTIONS #define CEL_INTERNAL_TRY try #define CEL_INTERNAL_CATCH_ANY catch (...) 
#define CEL_INTERNAL_RETHROW \ do { \ throw; \ } while (false) #else #define CEL_INTERNAL_TRY if (true) #define CEL_INTERNAL_CATCH_ANY else if (false) #define CEL_INTERNAL_RETHROW \ do { \ } while (false) #endif #endif // THIRD_PARTY_CEL_CPP_INTERNAL_EXCEPTIONS_H_ ================================================ FILE: internal/json.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "internal/json.h" #include #include #include #include #include #include #include #include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "absl/base/attributes.h" #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/base/optimization.h" #include "absl/functional/overload.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/cord.h" #include "absl/strings/escaping.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "extensions/protobuf/internal/map_reflection.h" #include "internal/status_macros.h" #include "internal/strings.h" #include "internal/well_known_types.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/map_field.h" #include 
"google/protobuf/message.h" #include "google/protobuf/message_lite.h" #include "google/protobuf/util/time_util.h" #undef GetMessage namespace cel::internal { namespace { using ::cel::well_known_types::AsVariant; using ::cel::well_known_types::GetListValueReflection; using ::cel::well_known_types::GetRepeatedBytesField; using ::cel::well_known_types::GetRepeatedStringField; using ::cel::well_known_types::GetStructReflection; using ::cel::well_known_types::GetValueReflection; using ::cel::well_known_types::JsonReflection; using ::cel::well_known_types::ListValueReflection; using ::cel::well_known_types::Reflection; using ::cel::well_known_types::StructReflection; using ::cel::well_known_types::ValueReflection; using ::google::protobuf::Descriptor; using ::google::protobuf::FieldDescriptor; using ::google::protobuf::util::TimeUtil; // Yanked from the implementation `google::protobuf::util::TimeUtil`. template absl::Status SnakeCaseToCamelCaseImpl(Chars input, std::string* absl_nonnull output) { output->clear(); bool after_underscore = false; for (char input_char : input) { if (absl::ascii_isupper(input_char)) { // The field name must not contain uppercase letters. return absl::InvalidArgumentError( "field mask path name contains uppercase letters"); } if (after_underscore) { if (absl::ascii_islower(input_char)) { output->push_back(absl::ascii_toupper(input_char)); after_underscore = false; } else { // The character after a "_" must be a lowercase letter. return absl::InvalidArgumentError( "field mask path contains '_' not followed by a lowercase letter"); } } else if (input_char == '_') { after_underscore = true; } else { output->push_back(input_char); } } if (after_underscore) { // Trailing "_". 
return absl::InvalidArgumentError("field mask path contains trailing '_'"); } return absl::OkStatus(); } absl::Status SnakeCaseToCamelCase(const well_known_types::StringValue& input, std::string* absl_nonnull output) { return absl::visit(absl::Overload( [&](absl::string_view string) -> absl::Status { return SnakeCaseToCamelCaseImpl(string, output); }, [&](const absl::Cord& cord) -> absl::Status { return SnakeCaseToCamelCaseImpl(cord.Chars(), output); }), AsVariant(input)); } class MessageToJsonState; using MapFieldKeyToString = std::string (*)(const google::protobuf::MapKey&); std::string BoolMapFieldKeyToString(const google::protobuf::MapKey& key) { return key.GetBoolValue() ? "true" : "false"; } std::string Int32MapFieldKeyToString(const google::protobuf::MapKey& key) { return absl::StrCat(key.GetInt32Value()); } std::string Int64MapFieldKeyToString(const google::protobuf::MapKey& key) { return absl::StrCat(key.GetInt64Value()); } std::string UInt32MapFieldKeyToString(const google::protobuf::MapKey& key) { return absl::StrCat(key.GetUInt32Value()); } std::string UInt64MapFieldKeyToString(const google::protobuf::MapKey& key) { return absl::StrCat(key.GetUInt64Value()); } std::string StringMapFieldKeyToString(const google::protobuf::MapKey& key) { return std::string(key.GetStringValue()); } MapFieldKeyToString GetMapFieldKeyToString( const google::protobuf::FieldDescriptor* absl_nonnull field) { switch (field->cpp_type()) { case FieldDescriptor::CPPTYPE_BOOL: return &BoolMapFieldKeyToString; case FieldDescriptor::CPPTYPE_INT32: return &Int32MapFieldKeyToString; case FieldDescriptor::CPPTYPE_INT64: return &Int64MapFieldKeyToString; case FieldDescriptor::CPPTYPE_UINT32: return &UInt32MapFieldKeyToString; case FieldDescriptor::CPPTYPE_UINT64: return &UInt64MapFieldKeyToString; case FieldDescriptor::CPPTYPE_STRING: return &StringMapFieldKeyToString; default: ABSL_UNREACHABLE(); } } using MapFieldValueToValue = absl::Status (MessageToJsonState::*)( const 
google::protobuf::MapValueConstRef& value, const google::protobuf::FieldDescriptor* absl_nonnull field, google::protobuf::MessageLite* absl_nonnull result); using RepeatedFieldToValue = absl::Status (MessageToJsonState::*)( const google::protobuf::Reflection* absl_nonnull reflection, const google::protobuf::Message& message, const google::protobuf::FieldDescriptor* absl_nonnull field, int index, google::protobuf::MessageLite* absl_nonnull result); class MessageToJsonState { public: MessageToJsonState(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory) : descriptor_pool_(descriptor_pool), message_factory_(message_factory) {} virtual ~MessageToJsonState() = default; absl::Status ToJson(const google::protobuf::Message& message, google::protobuf::MessageLite* absl_nonnull result) { const auto* descriptor = message.GetDescriptor(); switch (descriptor->well_known_type()) { case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: { CEL_RETURN_IF_ERROR(reflection_.DoubleValue().Initialize(descriptor)); SetNumberValue(result, reflection_.DoubleValue().GetValue(message)); } break; case Descriptor::WELLKNOWNTYPE_FLOATVALUE: { CEL_RETURN_IF_ERROR(reflection_.FloatValue().Initialize(descriptor)); SetNumberValue(result, reflection_.FloatValue().GetValue(message)); } break; case Descriptor::WELLKNOWNTYPE_INT64VALUE: { CEL_RETURN_IF_ERROR(reflection_.Int64Value().Initialize(descriptor)); SetNumberValue(result, reflection_.Int64Value().GetValue(message)); } break; case Descriptor::WELLKNOWNTYPE_UINT64VALUE: { CEL_RETURN_IF_ERROR(reflection_.UInt64Value().Initialize(descriptor)); SetNumberValue(result, reflection_.UInt64Value().GetValue(message)); } break; case Descriptor::WELLKNOWNTYPE_INT32VALUE: { CEL_RETURN_IF_ERROR(reflection_.Int32Value().Initialize(descriptor)); SetNumberValue(result, reflection_.Int32Value().GetValue(message)); } break; case Descriptor::WELLKNOWNTYPE_UINT32VALUE: { 
CEL_RETURN_IF_ERROR(reflection_.UInt32Value().Initialize(descriptor)); SetNumberValue(result, reflection_.UInt32Value().GetValue(message)); } break; case Descriptor::WELLKNOWNTYPE_STRINGVALUE: { CEL_RETURN_IF_ERROR(reflection_.StringValue().Initialize(descriptor)); StringValueToJson(reflection_.StringValue().GetValue(message, scratch_), result); } break; case Descriptor::WELLKNOWNTYPE_BYTESVALUE: { CEL_RETURN_IF_ERROR(reflection_.BytesValue().Initialize(descriptor)); BytesValueToJson(reflection_.BytesValue().GetValue(message, scratch_), result); } break; case Descriptor::WELLKNOWNTYPE_BOOLVALUE: { CEL_RETURN_IF_ERROR(reflection_.BoolValue().Initialize(descriptor)); SetBoolValue(result, reflection_.BoolValue().GetValue(message)); } break; case Descriptor::WELLKNOWNTYPE_ANY: { CEL_ASSIGN_OR_RETURN(auto unpacked, well_known_types::UnpackAnyFrom( result->GetArena(), reflection_.Any(), message, descriptor_pool_, message_factory_)); auto* struct_result = MutableStructValue(result); const auto* unpacked_descriptor = unpacked->GetDescriptor(); SetStringValue(InsertField(struct_result, "@type"), absl::StrCat("type.googleapis.com/", unpacked_descriptor->full_name())); switch (unpacked_descriptor->well_known_type()) { case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_FLOATVALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_INT64VALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_UINT64VALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_INT32VALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_UINT32VALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_STRINGVALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_BYTESVALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_BOOLVALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_FIELDMASK: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_DURATION: ABSL_FALLTHROUGH_INTENDED; 
case Descriptor::WELLKNOWNTYPE_TIMESTAMP: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_VALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_LISTVALUE: ABSL_FALLTHROUGH_INTENDED; case Descriptor::WELLKNOWNTYPE_STRUCT: return ToJson(*unpacked, InsertField(struct_result, "value")); default: if (unpacked_descriptor->full_name() == "google.protobuf.Empty") { MutableStructValue(InsertField(struct_result, "value")); return absl::OkStatus(); } else { return MessageToJson(*unpacked, struct_result); } } } case Descriptor::WELLKNOWNTYPE_FIELDMASK: { CEL_RETURN_IF_ERROR(reflection_.FieldMask().Initialize(descriptor)); std::vector paths; const int paths_size = reflection_.FieldMask().PathsSize(message); for (int i = 0; i < paths_size; ++i) { CEL_RETURN_IF_ERROR(SnakeCaseToCamelCase( reflection_.FieldMask().Paths(message, i, scratch_), &paths.emplace_back())); } SetStringValue(result, absl::StrJoin(paths, ",")); } break; case Descriptor::WELLKNOWNTYPE_DURATION: { CEL_RETURN_IF_ERROR(reflection_.Duration().Initialize(descriptor)); google::protobuf::Duration duration; duration.set_seconds(reflection_.Duration().GetSeconds(message)); duration.set_nanos(reflection_.Duration().GetNanos(message)); SetStringValue(result, TimeUtil::ToString(duration)); } break; case Descriptor::WELLKNOWNTYPE_TIMESTAMP: { CEL_RETURN_IF_ERROR(reflection_.Timestamp().Initialize(descriptor)); google::protobuf::Timestamp timestamp; timestamp.set_seconds(reflection_.Timestamp().GetSeconds(message)); timestamp.set_nanos(reflection_.Timestamp().GetNanos(message)); SetStringValue(result, TimeUtil::ToString(timestamp)); } break; case Descriptor::WELLKNOWNTYPE_VALUE: { absl::Cord serialized; if (!message.SerializePartialToString(&serialized)) { return absl::UnknownError( "failed to serialize message google.protobuf.Value"); } if (!result->ParsePartialFromString(serialized)) { return absl::UnknownError( "failed to parsed message: google.protobuf.Value"); } } break; case 
Descriptor::WELLKNOWNTYPE_LISTVALUE: {
        // Copy `message` into the list value obtained from
        // MutableListValue(result) by serializing it and parsing the bytes
        // back in.
        // NOTE(review): `serialized` is an absl::Cord but is passed to the
        // *ToString / *FromString calls; confirm the protobuf version in use
        // provides Cord-accepting overloads (SerializePartialToCord and
        // ParsePartialFromCord are the usual Cord APIs).
        // NOTE(review): "failed to parsed message" is a typo in the error
        // text; it is left unchanged here because it is runtime output.
        absl::Cord serialized;
        if (!message.SerializePartialToString(&serialized)) {
          return absl::UnknownError(
              "failed to serialize message google.protobuf.ListValue");
        }
        if (!MutableListValue(result)->ParsePartialFromString(serialized)) {
          return absl::UnknownError(
              "failed to parsed message: google.protobuf.ListValue");
        }
      } break;
      case Descriptor::WELLKNOWNTYPE_STRUCT: {
        // Same serialize-then-parse copy, into MutableStructValue(result).
        absl::Cord serialized;
        if (!message.SerializePartialToString(&serialized)) {
          return absl::UnknownError(
              "failed to serialize message google.protobuf.Struct");
        }
        if (!MutableStructValue(result)->ParsePartialFromString(serialized)) {
          return absl::UnknownError(
              "failed to parsed message: google.protobuf.Struct");
        }
      } break;
      default:
        // Any other message is emitted as a JSON object with one entry per
        // field that reflection lists as present.
        return MessageToJson(message, MutableStructValue(result));
    }
    return absl::OkStatus();
  }

  // Converts `message` into the JSON object `result`: one entry per present
  // field, keyed by the field's json_name().
  absl::Status ToJsonObject(const google::protobuf::Message& message,
                            google::protobuf::MessageLite* absl_nonnull result) {
    return MessageToJson(message, result);
  }

  // Converts the single field `field` of `message` into the JSON value
  // `result`. Map fields become objects, repeated fields become arrays, and
  // singular fields become the corresponding scalar or object value.
  absl::Status FieldToJson(const google::protobuf::Message& message,
                           const google::protobuf::FieldDescriptor* absl_nonnull field,
                           google::protobuf::MessageLite* absl_nonnull result) {
    return MessageFieldToJson(message, field, result);
  }

  // Appends every element of `field` to the JSON array `result`.
  // Assumes `field` is repeated; this is not checked here.
  absl::Status FieldToJsonArray(
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    return MessageRepeatedFieldToJson(message, field, result);
  }

  // Inserts every entry of the map field `field` into the JSON object
  // `result`, using the stringified map key as the entry name.
  // Assumes `field` is a map field; this is not checked here.
  absl::Status FieldToJsonObject(
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    return MessageMapFieldToJson(message, field, result);
  }

  // Prepares this state to write into `message` (the caller's result). The
  // public MessageToJson/MessageFieldToJson entry points call this once before
  // performing any conversion.
  virtual absl::Status Initialize(
      google::protobuf::MessageLite* absl_nonnull message) = 0;

 private:
  // Returns the member function that converts a single map value described by
  // `field` (the map entry's value field) into a JSON value.
  // - Fixed/sfixed/sint wire variants share the handler of the matching
  //   int/uint type.
  // - Groups are handled like messages.
  // - Enums of type google.protobuf.NullValue map to JSON null.
  // - Unhandled types yield InvalidArgumentError.
  absl::StatusOr GetMapFieldValueToValue(
      const google::protobuf::FieldDescriptor* absl_nonnull field) {
    switch (field->type()) {
      case FieldDescriptor::TYPE_DOUBLE:
        return &MessageToJsonState::MapDoubleFieldToValue;
      case FieldDescriptor::TYPE_FLOAT:
        return &MessageToJsonState::MapFloatFieldToValue;
      case FieldDescriptor::TYPE_FIXED64:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_UINT64:
        return &MessageToJsonState::MapUInt64FieldToValue;
      case FieldDescriptor::TYPE_BOOL:
        return &MessageToJsonState::MapBoolFieldToValue;
      case FieldDescriptor::TYPE_STRING:
        return &MessageToJsonState::MapStringFieldToValue;
      case FieldDescriptor::TYPE_GROUP:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_MESSAGE:
        return &MessageToJsonState::MapMessageFieldToValue;
      case FieldDescriptor::TYPE_BYTES:
        return &MessageToJsonState::MapBytesFieldToValue;
      case FieldDescriptor::TYPE_FIXED32:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_UINT32:
        return &MessageToJsonState::MapUInt32FieldToValue;
      case FieldDescriptor::TYPE_ENUM: {
        const auto* enum_descriptor = field->enum_type();
        if (enum_descriptor->full_name() == "google.protobuf.NullValue") {
          return &MessageToJsonState::MapNullFieldToValue;
        } else {
          return &MessageToJsonState::MapEnumFieldToValue;
        }
      }
      case FieldDescriptor::TYPE_SFIXED32:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_SINT32:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_INT32:
        return &MessageToJsonState::MapInt32FieldToValue;
      case FieldDescriptor::TYPE_SFIXED64:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_SINT64:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_INT64:
        return &MessageToJsonState::MapInt64FieldToValue;
      default:
        return absl::InvalidArgumentError(absl::StrCat(
            "unexpected message field type: ", field->type_name()));
    }
  }

  // Per-type converters for a single map value.
  // - `value` is the map entry's value.
  // - `field` is the descriptor of the map entry's value field.
  // The DCHECKs assert that the two agree. Each converter writes into
  // `result` through the virtual Set* hooks.
  absl::Status MapBoolFieldToValue(
      const google::protobuf::MapValueConstRef& value,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(value.type(), field->cpp_type());
    ABSL_DCHECK(!field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_BOOL);
    SetBoolValue(result, value.GetBoolValue());
    return absl::OkStatus();
  }

  absl::Status MapInt32FieldToValue(
      const google::protobuf::MapValueConstRef& value,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(value.type(), field->cpp_type());
    ABSL_DCHECK(!field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_INT32);
    SetNumberValue(result, value.GetInt32Value());
    return absl::OkStatus();
  }

  absl::Status MapInt64FieldToValue(
      const google::protobuf::MapValueConstRef& value,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(value.type(), field->cpp_type());
    ABSL_DCHECK(!field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_INT64);
    SetNumberValue(result, value.GetInt64Value());
    return absl::OkStatus();
  }

  absl::Status MapUInt32FieldToValue(
      const google::protobuf::MapValueConstRef& value,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(value.type(), field->cpp_type());
    ABSL_DCHECK(!field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_UINT32);
    SetNumberValue(result, value.GetUInt32Value());
    return absl::OkStatus();
  }

  absl::Status MapUInt64FieldToValue(
      const google::protobuf::MapValueConstRef& value,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(value.type(), field->cpp_type());
    ABSL_DCHECK(!field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_UINT64);
    SetNumberValue(result, value.GetUInt64Value());
    return absl::OkStatus();
  }

  absl::Status MapFloatFieldToValue(
      const google::protobuf::MapValueConstRef& value,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(value.type(), field->cpp_type());
    ABSL_DCHECK(!field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_FLOAT);
    SetNumberValue(result, value.GetFloatValue());
    return absl::OkStatus();
  }

  absl::Status MapDoubleFieldToValue(
      const google::protobuf::MapValueConstRef& value,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(value.type(), field->cpp_type());
    ABSL_DCHECK(!field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_DOUBLE);
    SetNumberValue(result, value.GetDoubleValue());
    return absl::OkStatus();
  }

  // Bytes values are written as base64 strings via SetStringValueFromBytes.
  absl::Status MapBytesFieldToValue(
      const google::protobuf::MapValueConstRef& value,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(value.type(), field->cpp_type());
    ABSL_DCHECK(!field->is_repeated());
    ABSL_DCHECK_EQ(field->type(), FieldDescriptor::TYPE_BYTES);
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_STRING);
    SetStringValueFromBytes(result, value.GetStringValue());
    return absl::OkStatus();
  }

  absl::Status MapStringFieldToValue(
      const google::protobuf::MapValueConstRef& value,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(value.type(), field->cpp_type());
    ABSL_DCHECK(!field->is_repeated());
    ABSL_DCHECK_EQ(field->type(), FieldDescriptor::TYPE_STRING);
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_STRING);
    SetStringValue(result, value.GetStringValue());
    return absl::OkStatus();
  }

  // Message values are converted recursively with ToJson.
  absl::Status MapMessageFieldToValue(
      const google::protobuf::MapValueConstRef& value,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(value.type(), field->cpp_type());
    ABSL_DCHECK(!field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_MESSAGE);
    return ToJson(value.GetMessageValue(), result);
  }

  // Enum numbers with a matching EnumValueDescriptor are written as the
  // value's name; unknown numbers are written as a JSON number.
  absl::Status MapEnumFieldToValue(
      const google::protobuf::MapValueConstRef& value,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(value.type(), field->cpp_type());
    ABSL_DCHECK(!field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_ENUM);
    ABSL_DCHECK_NE(field->enum_type()->full_name(),
                   "google.protobuf.NullValue");
    if (const auto* value_descriptor =
            field->enum_type()->FindValueByNumber(value.GetEnumValue());
        value_descriptor != nullptr) {
      SetStringValue(result, value_descriptor->name());
    } else {
      SetNumberValue(result, value.GetEnumValue());
    }
    return absl::OkStatus();
  }

  // google.protobuf.NullValue values are always written as JSON null.
  absl::Status MapNullFieldToValue(
      const google::protobuf::MapValueConstRef& value,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(value.type(), field->cpp_type());
    ABSL_DCHECK(!field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_ENUM);
    ABSL_DCHECK_EQ(field->enum_type()->full_name(),
                   "google.protobuf.NullValue");
    SetNullValue(result);
    return absl::OkStatus();
  }

  // Same type dispatch as GetMapFieldValueToValue. The returned converter
  // instead writes one element (selected by index through reflection) of the
  // repeated field `field`.
  absl::StatusOr GetRepeatedFieldToValue(
      const google::protobuf::FieldDescriptor* absl_nonnull field) {
    switch (field->type()) {
      case FieldDescriptor::TYPE_DOUBLE:
        return &MessageToJsonState::RepeatedDoubleFieldToValue;
      case FieldDescriptor::TYPE_FLOAT:
        return &MessageToJsonState::RepeatedFloatFieldToValue;
      case FieldDescriptor::TYPE_FIXED64:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_UINT64:
        return &MessageToJsonState::RepeatedUInt64FieldToValue;
      case FieldDescriptor::TYPE_BOOL:
        return &MessageToJsonState::RepeatedBoolFieldToValue;
      case FieldDescriptor::TYPE_STRING:
        return &MessageToJsonState::RepeatedStringFieldToValue;
      case FieldDescriptor::TYPE_GROUP:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_MESSAGE:
        return
&MessageToJsonState::RepeatedMessageFieldToValue;
      case FieldDescriptor::TYPE_BYTES:
        return &MessageToJsonState::RepeatedBytesFieldToValue;
      case FieldDescriptor::TYPE_FIXED32:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_UINT32:
        return &MessageToJsonState::RepeatedUInt32FieldToValue;
      case FieldDescriptor::TYPE_ENUM: {
        const auto* enum_descriptor = field->enum_type();
        if (enum_descriptor->full_name() == "google.protobuf.NullValue") {
          return &MessageToJsonState::RepeatedNullFieldToValue;
        } else {
          return &MessageToJsonState::RepeatedEnumFieldToValue;
        }
      }
      case FieldDescriptor::TYPE_SFIXED32:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_SINT32:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_INT32:
        return &MessageToJsonState::RepeatedInt32FieldToValue;
      case FieldDescriptor::TYPE_SFIXED64:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_SINT64:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_INT64:
        return &MessageToJsonState::RepeatedInt64FieldToValue;
      default:
        return absl::InvalidArgumentError(absl::StrCat(
            "unexpected message field type: ", field->type_name()));
    }
  }

  // Per-type converters for element `index` of the repeated, non-map field
  // `field` of `message`. `reflection` must be message.GetReflection()
  // (DCHECKed). Each converter writes into `result` through the virtual Set*
  // hooks.
  absl::Status RepeatedBoolFieldToValue(
      const google::protobuf::Reflection* absl_nonnull reflection,
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field, int index,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(reflection, message.GetReflection());
    ABSL_DCHECK(!field->is_map() && field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_BOOL);
    SetBoolValue(result, reflection->GetRepeatedBool(message, field, index));
    return absl::OkStatus();
  }

  absl::Status RepeatedInt32FieldToValue(
      const google::protobuf::Reflection* absl_nonnull reflection,
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field, int index,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(reflection, message.GetReflection());
    ABSL_DCHECK(!field->is_map() && field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_INT32);
    SetNumberValue(result, reflection->GetRepeatedInt32(message, field, index));
    return absl::OkStatus();
  }

  absl::Status RepeatedInt64FieldToValue(
      const google::protobuf::Reflection* absl_nonnull reflection,
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field, int index,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(reflection, message.GetReflection());
    ABSL_DCHECK(!field->is_map() && field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_INT64);
    SetNumberValue(result, reflection->GetRepeatedInt64(message, field, index));
    return absl::OkStatus();
  }

  absl::Status RepeatedUInt32FieldToValue(
      const google::protobuf::Reflection* absl_nonnull reflection,
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field, int index,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(reflection, message.GetReflection());
    ABSL_DCHECK(!field->is_map() && field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_UINT32);
    SetNumberValue(result,
                   reflection->GetRepeatedUInt32(message, field, index));
    return absl::OkStatus();
  }

  absl::Status RepeatedUInt64FieldToValue(
      const google::protobuf::Reflection* absl_nonnull reflection,
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field, int index,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(reflection, message.GetReflection());
    ABSL_DCHECK(!field->is_map() && field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_UINT64);
    SetNumberValue(result,
                   reflection->GetRepeatedUInt64(message, field, index));
    return absl::OkStatus();
  }

  absl::Status RepeatedFloatFieldToValue(
      const google::protobuf::Reflection* absl_nonnull reflection,
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field, int index,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(reflection, message.GetReflection());
    ABSL_DCHECK(!field->is_map() && field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_FLOAT);
    SetNumberValue(result, reflection->GetRepeatedFloat(message, field, index));
    return absl::OkStatus();
  }

  absl::Status RepeatedDoubleFieldToValue(
      const google::protobuf::Reflection* absl_nonnull reflection,
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field, int index,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(reflection, message.GetReflection());
    ABSL_DCHECK(!field->is_map() && field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_DOUBLE);
    SetNumberValue(result,
                   reflection->GetRepeatedDouble(message, field, index));
    return absl::OkStatus();
  }

  // Bytes elements are written as base64 strings. The accessor may yield
  // either a string_view or a Cord, and both alternatives are handled.
  // `scratch_` is passed to the accessor, presumably as backing storage for
  // the returned view; confirm.
  absl::Status RepeatedBytesFieldToValue(
      const google::protobuf::Reflection* absl_nonnull reflection,
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field, int index,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(reflection, message.GetReflection());
    ABSL_DCHECK(!field->is_map() && field->is_repeated());
    ABSL_DCHECK_EQ(field->type(), FieldDescriptor::TYPE_BYTES);
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_STRING);
    absl::visit(absl::Overload(
                    [&](absl::string_view string) -> void {
                      SetStringValueFromBytes(result, string);
                    },
                    [&](absl::Cord&& cord) -> void {
                      SetStringValueFromBytes(result, cord);
                    }),
                AsVariant(GetRepeatedBytesField(reflection, message, field,
                                                index, scratch_)));
    return absl::OkStatus();
  }

  // String elements are written unchanged as JSON strings. The accessor may
  // yield either a string_view or a Cord.
  absl::Status RepeatedStringFieldToValue(
      const google::protobuf::Reflection* absl_nonnull reflection,
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      int index, google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(reflection, message.GetReflection());
    ABSL_DCHECK(!field->is_map() && field->is_repeated());
    ABSL_DCHECK_EQ(field->type(), FieldDescriptor::TYPE_STRING);
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_STRING);
    absl::visit(
        absl::Overload(
            [&](absl::string_view string) -> void {
              SetStringValue(result, string);
            },
            [&](absl::Cord&& cord) -> void { SetStringValue(result, cord); }),
        AsVariant(GetRepeatedStringField(reflection, message, field, index,
                                         scratch_)));
    return absl::OkStatus();
  }

  // Message elements are converted recursively with ToJson.
  absl::Status RepeatedMessageFieldToValue(
      const google::protobuf::Reflection* absl_nonnull reflection,
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field, int index,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(reflection, message.GetReflection());
    ABSL_DCHECK(!field->is_map() && field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_MESSAGE);
    return ToJson(reflection->GetRepeatedMessage(message, field, index),
                  result);
  }

  // Enum elements with a known descriptor are written by name; others are
  // written as their raw number.
  absl::Status RepeatedEnumFieldToValue(
      const google::protobuf::Reflection* absl_nonnull reflection,
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field, int index,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(reflection, message.GetReflection());
    ABSL_DCHECK(!field->is_map() && field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_ENUM);
    ABSL_DCHECK_NE(field->enum_type()->full_name(),
                   "google.protobuf.NullValue");
    if (const auto* value = reflection->GetRepeatedEnum(message, field, index);
        value != nullptr) {
      SetStringValue(result, value->name());
    } else {
      SetNumberValue(result,
                     reflection->GetRepeatedEnumValue(message, field, index));
    }
    return absl::OkStatus();
  }

  // google.protobuf.NullValue elements are always written as JSON null.
  absl::Status RepeatedNullFieldToValue(
      const google::protobuf::Reflection* absl_nonnull reflection,
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field, int index,
      google::protobuf::MessageLite* absl_nonnull result) {
    ABSL_DCHECK_EQ(reflection, message.GetReflection());
    ABSL_DCHECK(!field->is_map() && field->is_repeated());
    ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_ENUM);
    ABSL_DCHECK_EQ(field->enum_type()->full_name(),
                   "google.protobuf.NullValue");
    SetNullValue(result);
    return absl::OkStatus();
  }

  // Converts the map field `field` of `message` into entries of the JSON
  // object `result`.
  // - Keys are stringified with the function from GetMapFieldKeyToString.
  // - Values are converted with the converter from GetMapFieldValueToValue.
  // An empty map leaves `result` untouched. The first error is returned.
  absl::Status MessageMapFieldToJson(
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    const auto* reflection = message.GetReflection();
    if (reflection->FieldSize(message, field) == 0) {
      return absl::OkStatus();
    }
    const auto key_to_string =
        GetMapFieldKeyToString(field->message_type()->map_key());
    const auto* value_descriptor = field->message_type()->map_value();
    CEL_ASSIGN_OR_RETURN(const auto value_to_value,
                         GetMapFieldValueToValue(value_descriptor));
    auto begin = extensions::protobuf_internal::ConstMapBegin(*reflection,
                                                               message, *field);
    const auto end = extensions::protobuf_internal::ConstMapEnd(
        *reflection, message, *field);
    for (; begin != end; ++begin) {
      auto key = (*key_to_string)(begin.GetKey());
      CEL_RETURN_IF_ERROR((this->*value_to_value)(
          begin.GetValueRef(), value_descriptor, InsertField(result, key)));
    }
    return absl::OkStatus();
  }

  // Appends each element of the repeated field `field` of `message` to the
  // JSON array `result`, one AddValues() slot per element. An empty field
  // leaves `result` untouched. The first error is returned.
  absl::Status MessageRepeatedFieldToJson(
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    const auto* reflection = message.GetReflection();
    const int size = reflection->FieldSize(message, field);
    if (size == 0) {
      return absl::OkStatus();
    }
    CEL_ASSIGN_OR_RETURN(const auto to_value, GetRepeatedFieldToValue(field));
    for (int index = 0; index < size; ++index) {
      CEL_RETURN_IF_ERROR((this->*to_value)(reflection, message, field, index,
                                            AddValues(result)));
    }
    return
absl::OkStatus();
  }

  // Converts the field `field` of `message` into the JSON value `result`.
  // - Map fields are written into MutableStructValue(result).
  // - Repeated fields are written into MutableListValue(result).
  // - Singular fields are written directly into `result`.
  // - Singular enums of type google.protobuf.NullValue become JSON null.
  // - Other enums become the value's name when a descriptor exists, else the
  //   raw number.
  // Unhandled field types yield InvalidArgumentError.
  absl::Status MessageFieldToJson(
      const google::protobuf::Message& message,
      const google::protobuf::FieldDescriptor* absl_nonnull field,
      google::protobuf::MessageLite* absl_nonnull result) {
    if (field->is_map()) {
      return MessageMapFieldToJson(message, field, MutableStructValue(result));
    }
    if (field->is_repeated()) {
      return MessageRepeatedFieldToJson(message, field,
                                        MutableListValue(result));
    }
    const auto* reflection = message.GetReflection();
    switch (field->type()) {
      case FieldDescriptor::TYPE_DOUBLE:
        SetNumberValue(result, reflection->GetDouble(message, field));
        break;
      case FieldDescriptor::TYPE_FLOAT:
        SetNumberValue(result, reflection->GetFloat(message, field));
        break;
      case FieldDescriptor::TYPE_FIXED64:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_UINT64:
        SetNumberValue(result, reflection->GetUInt64(message, field));
        break;
      case FieldDescriptor::TYPE_BOOL:
        SetBoolValue(result, reflection->GetBool(message, field));
        break;
      case FieldDescriptor::TYPE_STRING:
        StringValueToJson(
            well_known_types::GetStringField(message, field, scratch_), result);
        break;
      case FieldDescriptor::TYPE_GROUP:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_MESSAGE:
        // The parentheses around reflection->GetMessage presumably guard
        // against a function-like `GetMessage` macro (e.g. from
        // <windows.h>).
        return ToJson((reflection->GetMessage)(message, field), result);
      case FieldDescriptor::TYPE_BYTES:
        BytesValueToJson(
            well_known_types::GetBytesField(message, field, scratch_), result);
        break;
      case FieldDescriptor::TYPE_FIXED32:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_UINT32:
        SetNumberValue(result, reflection->GetUInt32(message, field));
        break;
      case FieldDescriptor::TYPE_ENUM: {
        const auto* enum_descriptor = field->enum_type();
        if (enum_descriptor->full_name() == "google.protobuf.NullValue") {
          SetNullValue(result);
        } else {
          const auto* enum_value_descriptor =
              reflection->GetEnum(message, field);
          if (enum_value_descriptor != nullptr) {
            SetStringValue(result, enum_value_descriptor->name());
          } else {
            SetNumberValue(result, reflection->GetEnumValue(message, field));
          }
        }
      } break;
      case FieldDescriptor::TYPE_SFIXED32:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_SINT32:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_INT32:
        SetNumberValue(result, reflection->GetInt32(message, field));
        break;
      case FieldDescriptor::TYPE_SFIXED64:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_SINT64:
        ABSL_FALLTHROUGH_INTENDED;
      case FieldDescriptor::TYPE_INT64:
        SetNumberValue(result, reflection->GetInt64(message, field));
        break;
      default:
        return absl::InvalidArgumentError(absl::StrCat(
            "unexpected message field type: ", field->type_name()));
    }
    return absl::OkStatus();
  }

  // Converts `message` into the JSON object `result`. Every field returned by
  // Reflection::ListFields (the fields reflection reports as present) becomes
  // one entry, keyed by FieldDescriptor::json_name(). Stops at the first
  // error.
  absl::Status MessageToJson(const google::protobuf::Message& message,
                             google::protobuf::MessageLite* absl_nonnull result) {
    std::vector fields;
    const auto* reflection = message.GetReflection();
    reflection->ListFields(message, &fields);
    if (!fields.empty()) {
      for (const auto* field : fields) {
        CEL_RETURN_IF_ERROR(MessageFieldToJson(
            message, field, InsertField(result, field->json_name())));
      }
    }
    return absl::OkStatus();
  }

  // Writes a string field value (held as either a string_view or a Cord) as a
  // JSON string.
  void StringValueToJson(const well_known_types::StringValue& value,
                         google::protobuf::MessageLite* absl_nonnull result) const {
    absl::visit(absl::Overload([&](absl::string_view string)
                                   -> void { SetStringValue(result, string); },
                               [&](const absl::Cord& cord) -> void {
                                 SetStringValue(result, cord);
                               }),
                AsVariant(value));
  }

  // Like StringValueToJson, but base64-encodes the bytes via
  // SetStringValueFromBytes.
  void BytesValueToJson(const well_known_types::BytesValue& value,
                        google::protobuf::MessageLite* absl_nonnull result) const {
    absl::visit(absl::Overload(
                    [&](absl::string_view string) -> void {
                      SetStringValueFromBytes(result, string);
                    },
                    [&](const absl::Cord& cord) -> void {
                      SetStringValueFromBytes(result, cord);
                    }),
                AsVariant(value));
  }

  // Output hooks implemented by the concrete states. The Set* hooks write a
  // scalar or null into the JSON value `message`.
  virtual void SetNullValue(
      google::protobuf::MessageLite* absl_nonnull message) const = 0;

  virtual void SetBoolValue(google::protobuf::MessageLite* absl_nonnull message,
                            bool value) const = 0;

  virtual void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message,
                              double value) const = 0;

  // The float, int32_t and uint32_t overloads forward to a virtual overload
  // through static_cast. Presumably they widen to double, int64_t and
  // uint64_t respectively; confirm.
  void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message,
                      float value) const {
    SetNumberValue(message, static_cast(value));
  }

  virtual void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message,
                              int64_t value) const = 0;

  void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message,
                      int32_t value) const {
    SetNumberValue(message, static_cast(value));
  }

  virtual void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message,
                              uint64_t value) const = 0;

  void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message,
                      uint32_t value) const {
    SetNumberValue(message, static_cast(value));
  }

  virtual void SetStringValue(google::protobuf::MessageLite* absl_nonnull message,
                              absl::string_view value) const = 0;

  virtual void SetStringValue(google::protobuf::MessageLite* absl_nonnull message,
                              const absl::Cord& value) const = 0;

  // Writes `value` as a base64-encoded (absl::Base64Escape) JSON string.
  // Empty input is written unchanged, i.e. as an empty string.
  void SetStringValueFromBytes(google::protobuf::MessageLite* absl_nonnull message,
                               absl::string_view value) const {
    if (value.empty()) {
      SetStringValue(message, value);
      return;
    }
    SetStringValue(message, absl::Base64Escape(value));
  }

  // Cord overload. When TryFlat() succeeds, the flat view is encoded
  // directly. Otherwise the whole Cord is converted first and then encoded.
  void SetStringValueFromBytes(google::protobuf::MessageLite* absl_nonnull message,
                               const absl::Cord& value) const {
    if (value.empty()) {
      SetStringValue(message, value);
      return;
    }
    if (auto flat = value.TryFlat(); flat) {
      SetStringValue(message, absl::Base64Escape(*flat));
      return;
    }
    SetStringValue(message, absl::Base64Escape(static_cast(value)));
  }

  // Sub-message accessors used to reach the slot being written. Per the
  // ValueReflection/ListValueReflection/StructReflection helpers the
  // implementations delegate to:
  // - MutableListValue/MutableStructValue target a Value.
  // - AddValues targets a ListValue.
  // - InsertField targets a Struct entry named `name`.
  virtual google::protobuf::MessageLite* absl_nonnull MutableListValue(
      google::protobuf::MessageLite* absl_nonnull message) const = 0;

  virtual google::protobuf::MessageLite* absl_nonnull MutableStructValue(
      google::protobuf::MessageLite* absl_nonnull message) const = 0;

  virtual google::protobuf::MessageLite* absl_nonnull AddValues(
      google::protobuf::MessageLite* absl_nonnull message) const = 0;

  virtual google::protobuf::MessageLite* absl_nonnull InsertField(
      google::protobuf::MessageLite* absl_nonnull message,
      absl::string_view name) const = 0;

  // Supplied by the public entry points when the state is constructed.
  const google::protobuf::DescriptorPool* absl_nonnull const descriptor_pool_;
  google::protobuf::MessageFactory* absl_nonnull const message_factory_;
  // Reusable buffer handed to the string/bytes field accessors.
  std::string scratch_;
  // NOTE(review): DynamicMessageToJsonState declares its own, separate
  // `reflection_` member (a JsonReflection) with the same name.
  Reflection reflection_;
};

// MessageToJsonState that writes through the static ValueReflection,
// ListValueReflection and StructReflection helpers, so Initialize() has
// nothing to set up. Presumably used with the generated C++ JSON message
// classes; confirm.
class GeneratedMessageToJsonState final : public MessageToJsonState {
 public:
  using MessageToJsonState::MessageToJsonState;

  absl::Status Initialize(google::protobuf::MessageLite* absl_nonnull message) override {
    // Nothing to do.
    return absl::OkStatus();
  }

 private:
  void SetNullValue(google::protobuf::MessageLite* absl_nonnull message) const override {
    ValueReflection::SetNullValue(
        google::protobuf::DownCastMessage(message));
  }

  void SetBoolValue(google::protobuf::MessageLite* absl_nonnull message,
                    bool value) const override {
    ValueReflection::SetBoolValue(
        google::protobuf::DownCastMessage(message), value);
  }

  void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message,
                      double value) const override {
    ValueReflection::SetNumberValue(
        google::protobuf::DownCastMessage(message), value);
  }

  void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message,
                      int64_t value) const override {
    ValueReflection::SetNumberValue(
        google::protobuf::DownCastMessage(message), value);
  }

  void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message,
                      uint64_t value) const override {
    ValueReflection::SetNumberValue(
        google::protobuf::DownCastMessage(message), value);
  }

  void SetStringValue(google::protobuf::MessageLite* absl_nonnull message,
                      absl::string_view value) const override {
    ValueReflection::SetStringValue(
        google::protobuf::DownCastMessage(message), value);
  }

  void SetStringValue(google::protobuf::MessageLite* absl_nonnull message,
                      const absl::Cord& value) const override {
    ValueReflection::SetStringValue(
        google::protobuf::DownCastMessage(message), value);
  }

  google::protobuf::MessageLite* absl_nonnull MutableListValue(
      google::protobuf::MessageLite* absl_nonnull message) const override {
    return ValueReflection::MutableListValue(
google::protobuf::DownCastMessage(message));
  }

  google::protobuf::MessageLite* absl_nonnull MutableStructValue(
      google::protobuf::MessageLite* absl_nonnull message) const override {
    return ValueReflection::MutableStructValue(
        google::protobuf::DownCastMessage(message));
  }

  google::protobuf::MessageLite* absl_nonnull AddValues(
      google::protobuf::MessageLite* absl_nonnull message) const override {
    return ListValueReflection::AddValues(
        google::protobuf::DownCastMessage(message));
  }

  google::protobuf::MessageLite* absl_nonnull InsertField(
      google::protobuf::MessageLite* absl_nonnull message,
      absl::string_view name) const override {
    return StructReflection::InsertField(
        google::protobuf::DownCastMessage(message), name);
  }
};

// MessageToJsonState that resolves a JsonReflection from the result message's
// descriptor in Initialize() and routes every write through it.
class DynamicMessageToJsonState final : public MessageToJsonState {
 public:
  using MessageToJsonState::MessageToJsonState;

  // Fails with the error from JsonReflection::Initialize if the descriptor of
  // `message` is rejected.
  absl::Status Initialize(google::protobuf::MessageLite* absl_nonnull message) override {
    CEL_RETURN_IF_ERROR(reflection_.Initialize(
        google::protobuf::DownCastMessage(message)->GetDescriptor()));
    return absl::OkStatus();
  }

 private:
  void SetNullValue(google::protobuf::MessageLite* absl_nonnull message) const override {
    reflection_.Value().SetNullValue(
        google::protobuf::DownCastMessage(message));
  }

  void SetBoolValue(google::protobuf::MessageLite* absl_nonnull message,
                    bool value) const override {
    reflection_.Value().SetBoolValue(
        google::protobuf::DownCastMessage(message), value);
  }

  void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message,
                      double value) const override {
    reflection_.Value().SetNumberValue(
        google::protobuf::DownCastMessage(message), value);
  }

  void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message,
                      int64_t value) const override {
    reflection_.Value().SetNumberValue(
        google::protobuf::DownCastMessage(message), value);
  }

  void SetNumberValue(google::protobuf::MessageLite* absl_nonnull message,
                      uint64_t value) const override {
    reflection_.Value().SetNumberValue(
        google::protobuf::DownCastMessage(message), value);
  }

  void SetStringValue(google::protobuf::MessageLite* absl_nonnull message,
                      absl::string_view value) const override {
    reflection_.Value().SetStringValue(
        google::protobuf::DownCastMessage(message), value);
  }

  void SetStringValue(google::protobuf::MessageLite* absl_nonnull message,
                      const absl::Cord& value) const override {
    reflection_.Value().SetStringValue(
        google::protobuf::DownCastMessage(message), value);
  }

  google::protobuf::MessageLite* absl_nonnull MutableListValue(
      google::protobuf::MessageLite* absl_nonnull message) const override {
    return reflection_.Value().MutableListValue(
        google::protobuf::DownCastMessage(message));
  }

  google::protobuf::MessageLite* absl_nonnull MutableStructValue(
      google::protobuf::MessageLite* absl_nonnull message) const override {
    return reflection_.Value().MutableStructValue(
        google::protobuf::DownCastMessage(message));
  }

  google::protobuf::MessageLite* absl_nonnull AddValues(
      google::protobuf::MessageLite* absl_nonnull message) const override {
    return reflection_.ListValue().AddValues(
        google::protobuf::DownCastMessage(message));
  }

  google::protobuf::MessageLite* absl_nonnull InsertField(
      google::protobuf::MessageLite* absl_nonnull message,
      absl::string_view name) const override {
    return reflection_.Struct().InsertField(
        google::protobuf::DownCastMessage(message), name);
  }

  // Populated by Initialize(). This is distinct from the base class's private
  // `reflection_` member of the same name.
  JsonReflection reflection_;
};

}  // namespace

// Converts `message` into the google.protobuf.Value `result`. Returns the
// error from Initialize() or from the conversion, if any.
absl::Status MessageToJson(
    const google::protobuf::Message& message,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Value* absl_nonnull result) {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(result != nullptr);
  auto state = std::make_unique(descriptor_pool,
                                                           message_factory);
  CEL_RETURN_IF_ERROR(state->Initialize(result));
  return state->ToJson(message, result);
}

// Converts `message` into the JSON object `result` (google.protobuf.Struct),
// one entry per present field.
absl::Status MessageToJson(
    const
google::protobuf::Message& message,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Struct* absl_nonnull result) {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(result != nullptr);
  auto state = std::make_unique(descriptor_pool,
                                                           message_factory);
  CEL_RETURN_IF_ERROR(state->Initialize(result));
  return state->ToJsonObject(message, result);
}

// Converts `message` into the JSON message `result`, dispatching on the
// well-known type of `result`:
// - Value receives the JSON value for `message`.
// - Struct receives a JSON object.
// Any other type (e.g. ListValue, since a message is never converted to an
// array) yields InvalidArgumentError, unless Initialize() has already failed.
absl::Status MessageToJson(
    const google::protobuf::Message& message,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull result) {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(result != nullptr);
  auto state = std::make_unique(descriptor_pool,
                                                         message_factory);
  CEL_RETURN_IF_ERROR(state->Initialize(result));
  switch (result->GetDescriptor()->well_known_type()) {
    case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE:
      return state->ToJson(message, result);
    case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT:
      return state->ToJsonObject(message, result);
    default:
      return absl::InvalidArgumentError("cannot convert message to JSON array");
  }
}

// Converts the field `field` of `message` into the JSON value `result`.
// `field` must belong to `message`'s descriptor (DCHECKed).
absl::Status MessageFieldToJson(
    const google::protobuf::Message& message,
    const google::protobuf::FieldDescriptor* absl_nonnull field,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Value* absl_nonnull result) {
  ABSL_DCHECK_EQ(field->containing_type(), message.GetDescriptor());
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(result != nullptr);
  auto state = std::make_unique(descriptor_pool,
                                                           message_factory);
  CEL_RETURN_IF_ERROR(state->Initialize(result));
  return state->FieldToJson(message, field, result);
}

// Appends every element of the repeated field `field` of `message` to the
// JSON array `result`.
absl::Status MessageFieldToJson(
    const google::protobuf::Message& message,
    const google::protobuf::FieldDescriptor* absl_nonnull field,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::ListValue* absl_nonnull result) {
  ABSL_DCHECK_EQ(field->containing_type(), message.GetDescriptor());
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(result != nullptr);
  auto state = std::make_unique(descriptor_pool,
                                                           message_factory);
  CEL_RETURN_IF_ERROR(state->Initialize(result));
  return state->FieldToJsonArray(message, field, result);
}

// Inserts every entry of the map field `field` of `message` into the JSON
// object `result`.
absl::Status MessageFieldToJson(
    const google::protobuf::Message& message,
    const google::protobuf::FieldDescriptor* absl_nonnull field,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Struct* absl_nonnull result) {
  ABSL_DCHECK_EQ(field->containing_type(), message.GetDescriptor());
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(result != nullptr);
  auto state = std::make_unique(descriptor_pool,
                                                           message_factory);
  CEL_RETURN_IF_ERROR(state->Initialize(result));
  return state->FieldToJsonObject(message, field, result);
}

// Converts the field `field` of `message` into `result`, dispatching on the
// well-known type of `result`: Value, ListValue or Struct.
// NOTE(review): the default branch returns InternalError("unreachable") on
// the assumption that Initialize(result) has already rejected every other
// message type; confirm that the state's Initialize() guarantees this.
absl::Status MessageFieldToJson(
    const google::protobuf::Message& message,
    const google::protobuf::FieldDescriptor* absl_nonnull field,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull result) {
  ABSL_DCHECK_EQ(field->containing_type(), message.GetDescriptor());
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(result != nullptr);
  auto state = std::make_unique(descriptor_pool,
                                                         message_factory);
  CEL_RETURN_IF_ERROR(state->Initialize(result));
  switch (result->GetDescriptor()->well_known_type()) {
    case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE:
      return state->FieldToJson(message, field, result);
    case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE:
      return state->FieldToJsonArray(message, field, result);
    case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT:
      return state->FieldToJsonObject(message, field, result);
    default:
      return absl::InternalError("unreachable");
  }
}

// Returns OK when `message` is usable as a google.protobuf.Value, in either of
// two ways:
// - The first DynamicCastMessage succeeds (presumably a cast to the generated
//   google::protobuf::Value; confirm).
// - It is a reflective message whose descriptor yields Value, ListValue and
//   Struct reflections.
// Otherwise returns the reflection error or InvalidArgumentError.
absl::Status CheckJson(const google::protobuf::MessageLite& message) {
  if (const auto* generated_message =
          google::protobuf::DynamicCastMessage(&message);
      generated_message) {
    return absl::OkStatus();
  }
  if (const auto* dynamic_message =
          google::protobuf::DynamicCastMessage(&message);
      dynamic_message) {
    CEL_ASSIGN_OR_RETURN(auto reflection,
                         GetValueReflection(dynamic_message->GetDescriptor()));
    CEL_RETURN_IF_ERROR(
        GetListValueReflection(reflection.GetListValueDescriptor()).status());
    CEL_RETURN_IF_ERROR(
        GetStructReflection(reflection.GetStructDescriptor()).status());
    return absl::OkStatus();
  }
  return absl::InvalidArgumentError(
      absl::StrCat("message must be an instance of `google.protobuf.Value`: ",
                   message.GetTypeName()));
}

// Same as CheckJson, for google.protobuf.ListValue. The reflective path
// validates ListValue, then its Value, then that Value's Struct.
absl::Status CheckJsonList(const google::protobuf::MessageLite& message) {
  if (const auto* generated_message =
          google::protobuf::DynamicCastMessage(&message);
      generated_message) {
    return absl::OkStatus();
  }
  if (const auto* dynamic_message =
          google::protobuf::DynamicCastMessage(&message);
      dynamic_message) {
    CEL_ASSIGN_OR_RETURN(
        auto reflection,
        GetListValueReflection(dynamic_message->GetDescriptor()));
    CEL_ASSIGN_OR_RETURN(auto value_reflection,
                         GetValueReflection(reflection.GetValueDescriptor()));
    CEL_RETURN_IF_ERROR(
        GetStructReflection(value_reflection.GetStructDescriptor()).status());
    return absl::OkStatus();
  }
  return absl::InvalidArgumentError(absl::StrCat(
      "message must be an instance of `google.protobuf.ListValue`: ",
      message.GetTypeName()));
}

// Same as CheckJson, for google.protobuf.Struct. The reflective path
// validates Struct, then its Value, then that Value's ListValue.
absl::Status CheckJsonMap(const google::protobuf::MessageLite& message) {
  if (const auto*
generated_message = google::protobuf::DynamicCastMessage(&message); generated_message) { return absl::OkStatus(); } if (const auto* dynamic_message = google::protobuf::DynamicCastMessage(&message); dynamic_message) { CEL_ASSIGN_OR_RETURN(auto reflection, GetStructReflection(dynamic_message->GetDescriptor())); CEL_ASSIGN_OR_RETURN(auto value_reflection, GetValueReflection(reflection.GetValueDescriptor())); CEL_RETURN_IF_ERROR( GetListValueReflection(value_reflection.GetListValueDescriptor()) .status()); return absl::OkStatus(); } return absl::InvalidArgumentError( absl::StrCat("message must be an instance of `google.protobuf.Struct`: ", message.GetTypeName())); } namespace { class JsonMapIterator final { public: using Generated = typename google::protobuf::Map::const_iterator; using Dynamic = google::protobuf::ConstMapIterator; using Value = std::pair; // NOLINTNEXTLINE(google-explicit-constructor) JsonMapIterator(Generated generated) : variant_(std::move(generated)) {} // NOLINTNEXTLINE(google-explicit-constructor) JsonMapIterator(Dynamic dynamic) : variant_(std::move(dynamic)) {} JsonMapIterator(const JsonMapIterator&) = default; JsonMapIterator(JsonMapIterator&&) = default; JsonMapIterator& operator=(const JsonMapIterator&) = default; JsonMapIterator& operator=(JsonMapIterator&&) = default; Value Next(std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { Value result; absl::visit(absl::Overload( [&](Generated& generated) -> void { result = std::pair{absl::string_view(generated->first), &generated->second}; ++generated; }, [&](Dynamic& dynamic) -> void { const auto& key = dynamic.GetKey().GetStringValue(); scratch.assign(key.data(), key.size()); result = std::pair{absl::string_view(scratch), &dynamic.GetValueRef().GetMessageValue()}; ++dynamic; }), variant_); return result; } private: absl::variant variant_; }; class JsonAccessor { public: virtual ~JsonAccessor() = default; virtual google::protobuf::Value::KindCase GetKindCase( const google::protobuf::MessageLite& 
message) const = 0; virtual bool GetBoolValue(const google::protobuf::MessageLite& message) const = 0; virtual double GetNumberValue(const google::protobuf::MessageLite& message) const = 0; virtual well_known_types::StringValue GetStringValue( const google::protobuf::MessageLite& message, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const = 0; virtual const google::protobuf::MessageLite& GetListValue( const google::protobuf::MessageLite& message) const = 0; virtual int ValuesSize(const google::protobuf::MessageLite& message) const = 0; virtual const google::protobuf::MessageLite& Values(const google::protobuf::MessageLite& message, int index) const = 0; virtual const google::protobuf::MessageLite& GetStructValue( const google::protobuf::MessageLite& message) const = 0; virtual int FieldsSize(const google::protobuf::MessageLite& message) const = 0; virtual const google::protobuf::MessageLite* absl_nullable FindField( const google::protobuf::MessageLite& message, absl::string_view name) const = 0; virtual JsonMapIterator IterateFields( const google::protobuf::MessageLite& message) const = 0; }; class GeneratedJsonAccessor final : public JsonAccessor { public: static const GeneratedJsonAccessor* absl_nonnull Singleton() { static const absl::NoDestructor singleton; return &*singleton; } google::protobuf::Value::KindCase GetKindCase( const google::protobuf::MessageLite& message) const override { return ValueReflection::GetKindCase( google::protobuf::DownCastMessage(message)); } bool GetBoolValue(const google::protobuf::MessageLite& message) const override { return ValueReflection::GetBoolValue( google::protobuf::DownCastMessage(message)); } double GetNumberValue(const google::protobuf::MessageLite& message) const override { return ValueReflection::GetNumberValue( google::protobuf::DownCastMessage(message)); } well_known_types::StringValue GetStringValue( const google::protobuf::MessageLite& message, std::string&) const override { return 
ValueReflection::GetStringValue( google::protobuf::DownCastMessage(message)); } const google::protobuf::MessageLite& GetListValue( const google::protobuf::MessageLite& message) const override { return ValueReflection::GetListValue( google::protobuf::DownCastMessage(message)); } int ValuesSize(const google::protobuf::MessageLite& message) const override { return ListValueReflection::ValuesSize( google::protobuf::DownCastMessage(message)); } const google::protobuf::MessageLite& Values(const google::protobuf::MessageLite& message, int index) const override { return ListValueReflection::Values( google::protobuf::DownCastMessage(message), index); } const google::protobuf::MessageLite& GetStructValue( const google::protobuf::MessageLite& message) const override { return ValueReflection::GetStructValue( google::protobuf::DownCastMessage(message)); } int FieldsSize(const google::protobuf::MessageLite& message) const override { return StructReflection::FieldsSize( google::protobuf::DownCastMessage(message)); } const google::protobuf::MessageLite* absl_nullable FindField( const google::protobuf::MessageLite& message, absl::string_view name) const override { return StructReflection::FindField( google::protobuf::DownCastMessage(message), name); } JsonMapIterator IterateFields( const google::protobuf::MessageLite& message) const override { return StructReflection::BeginFields( google::protobuf::DownCastMessage(message)); } }; class DynamicJsonAccessor final : public JsonAccessor { public: void InitializeValue(const google::protobuf::Message& message) { ABSL_CHECK_OK(reflection_.Initialize(message.GetDescriptor())); // Crash OK } void InitializeListValue(const google::protobuf::Message& message) { ABSL_CHECK_OK(reflection_.Initialize(message.GetDescriptor())); // Crash OK } void InitializeStruct(const google::protobuf::Message& message) { ABSL_CHECK_OK(reflection_.Initialize(message.GetDescriptor())); // Crash OK } google::protobuf::Value::KindCase GetKindCase( const 
google::protobuf::MessageLite& message) const override { return reflection_.Value().GetKindCase( google::protobuf::DownCastMessage(message)); } bool GetBoolValue(const google::protobuf::MessageLite& message) const override { return reflection_.Value().GetBoolValue( google::protobuf::DownCastMessage(message)); } double GetNumberValue(const google::protobuf::MessageLite& message) const override { return reflection_.Value().GetNumberValue( google::protobuf::DownCastMessage(message)); } well_known_types::StringValue GetStringValue( const google::protobuf::MessageLite& message, std::string& scratch) const override { return reflection_.Value().GetStringValue( google::protobuf::DownCastMessage(message), scratch); } const google::protobuf::MessageLite& GetListValue( const google::protobuf::MessageLite& message) const override { return reflection_.Value().GetListValue( google::protobuf::DownCastMessage(message)); } int ValuesSize(const google::protobuf::MessageLite& message) const override { return reflection_.ListValue().ValuesSize( google::protobuf::DownCastMessage(message)); } const google::protobuf::MessageLite& Values(const google::protobuf::MessageLite& message, int index) const override { return reflection_.ListValue().Values( google::protobuf::DownCastMessage(message), index); } const google::protobuf::MessageLite& GetStructValue( const google::protobuf::MessageLite& message) const override { return reflection_.Value().GetStructValue( google::protobuf::DownCastMessage(message)); } int FieldsSize(const google::protobuf::MessageLite& message) const override { return reflection_.Struct().FieldsSize( google::protobuf::DownCastMessage(message)); } const google::protobuf::MessageLite* absl_nullable FindField( const google::protobuf::MessageLite& message, absl::string_view name) const override { return reflection_.Struct().FindField( google::protobuf::DownCastMessage(message), name); } JsonMapIterator IterateFields( const google::protobuf::MessageLite& message) const 
override { return reflection_.Struct().BeginFields( google::protobuf::DownCastMessage(message)); } private: JsonReflection reflection_; }; std::string JsonStringDebugString(const well_known_types::StringValue& value) { return absl::visit(absl::Overload( [&](absl::string_view string) -> std::string { return FormatStringLiteral(string); }, [&](const absl::Cord& cord) -> std::string { return FormatStringLiteral(cord); }), well_known_types::AsVariant(value)); } std::string JsonNumberDebugString(double value) { if (std::isfinite(value)) { if (std::floor(value) != value) { // The double is not representable as a whole number, so use // absl::StrCat which will add decimal places. return absl::StrCat(value); } // absl::StrCat historically would represent 0.0 as 0, and we want the // decimal places so ZetaSQL correctly assumes the type as double // instead of int64. std::string stringified = absl::StrCat(value); if (!absl::StrContains(stringified, '.')) { absl::StrAppend(&stringified, ".0"); } else { // absl::StrCat has a decimal now? Use it directly. 
} return stringified; } if (std::isnan(value)) { return "nan"; } if (std::signbit(value)) { return "-infinity"; } return "+infinity"; } class JsonDebugStringState final { public: JsonDebugStringState(const JsonAccessor* absl_nonnull accessor, std::string* absl_nonnull output) : accessor_(accessor), output_(output) {} void ValueDebugString(const google::protobuf::MessageLite& message) { const auto kind_case = accessor_->GetKindCase(message); switch (kind_case) { case google::protobuf::Value::KIND_NOT_SET: ABSL_FALLTHROUGH_INTENDED; case google::protobuf::Value::kNullValue: output_->append("null"); break; case google::protobuf::Value::kBoolValue: if (accessor_->GetBoolValue(message)) { output_->append("true"); } else { output_->append("false"); } break; case google::protobuf::Value::kNumberValue: output_->append( JsonNumberDebugString(accessor_->GetNumberValue(message))); break; case google::protobuf::Value::kStringValue: output_->append(JsonStringDebugString( accessor_->GetStringValue(message, scratch_))); break; case google::protobuf::Value::kListValue: ListValueDebugString(accessor_->GetListValue(message)); break; case google::protobuf::Value::kStructValue: StructDebugString(accessor_->GetStructValue(message)); break; default: // Should not get here, but if for some terrible reason // `google.protobuf.Value` is expanded, just skip. 
break; } } void ListValueDebugString(const google::protobuf::MessageLite& message) { const int size = accessor_->ValuesSize(message); output_->push_back('['); for (int i = 0; i < size; ++i) { if (i > 0) { output_->append(", "); } ValueDebugString(accessor_->Values(message, i)); } output_->push_back(']'); } void StructDebugString(const google::protobuf::MessageLite& message) { const int size = accessor_->FieldsSize(message); std::string key_scratch; well_known_types::StringValue key; const google::protobuf::MessageLite* absl_nonnull value; auto iterator = accessor_->IterateFields(message); output_->push_back('{'); for (int i = 0; i < size; ++i) { if (i > 0) { output_->append(", "); } std::tie(key, value) = iterator.Next(key_scratch); output_->append(JsonStringDebugString(key)); output_->append(": "); ValueDebugString(*value); } output_->push_back('}'); } private: const JsonAccessor* absl_nonnull const accessor_; std::string* absl_nonnull const output_; std::string scratch_; }; } // namespace std::string JsonDebugString(const google::protobuf::Value& message) { std::string output; JsonDebugStringState(GeneratedJsonAccessor::Singleton(), &output) .ValueDebugString(message); return output; } std::string JsonDebugString(const google::protobuf::Message& message) { DynamicJsonAccessor accessor; accessor.InitializeValue(message); std::string output; JsonDebugStringState(&accessor, &output).ValueDebugString(message); return output; } std::string JsonListDebugString(const google::protobuf::ListValue& message) { std::string output; JsonDebugStringState(GeneratedJsonAccessor::Singleton(), &output) .ListValueDebugString(message); return output; } std::string JsonListDebugString(const google::protobuf::Message& message) { DynamicJsonAccessor accessor; accessor.InitializeListValue(message); std::string output; JsonDebugStringState(&accessor, &output).ListValueDebugString(message); return output; } std::string JsonMapDebugString(const google::protobuf::Struct& message) { 
std::string output; JsonDebugStringState(GeneratedJsonAccessor::Singleton(), &output) .StructDebugString(message); return output; } std::string JsonMapDebugString(const google::protobuf::Message& message) { DynamicJsonAccessor accessor; accessor.InitializeStruct(message); std::string output; JsonDebugStringState(&accessor, &output).StructDebugString(message); return output; } namespace { class JsonEqualsState final { public: explicit JsonEqualsState(const JsonAccessor* absl_nonnull lhs_accessor, const JsonAccessor* absl_nonnull rhs_accessor) : lhs_accessor_(lhs_accessor), rhs_accessor_(rhs_accessor) {} bool ValueEqual(const google::protobuf::MessageLite& lhs, const google::protobuf::MessageLite& rhs) { auto lhs_kind_case = lhs_accessor_->GetKindCase(lhs); if (lhs_kind_case == google::protobuf::Value::KIND_NOT_SET) { lhs_kind_case = google::protobuf::Value::kNullValue; } auto rhs_kind_case = rhs_accessor_->GetKindCase(rhs); if (rhs_kind_case == google::protobuf::Value::KIND_NOT_SET) { rhs_kind_case = google::protobuf::Value::kNullValue; } if (lhs_kind_case != rhs_kind_case) { return false; } switch (lhs_kind_case) { case google::protobuf::Value::KIND_NOT_SET: ABSL_UNREACHABLE(); case google::protobuf::Value::kNullValue: return true; case google::protobuf::Value::kBoolValue: return lhs_accessor_->GetBoolValue(lhs) == rhs_accessor_->GetBoolValue(rhs); case google::protobuf::Value::kNumberValue: return lhs_accessor_->GetNumberValue(lhs) == rhs_accessor_->GetNumberValue(rhs); case google::protobuf::Value::kStringValue: return lhs_accessor_->GetStringValue(lhs, lhs_scratch_) == rhs_accessor_->GetStringValue(rhs, rhs_scratch_); case google::protobuf::Value::kListValue: return ListValueEqual(lhs_accessor_->GetListValue(lhs), rhs_accessor_->GetListValue(rhs)); case google::protobuf::Value::kStructValue: return StructEqual(lhs_accessor_->GetStructValue(lhs), rhs_accessor_->GetStructValue(rhs)); default: // Should not get here, but if for some terrible reason // 
`google.protobuf.Value` is expanded, default to false. return false; } } bool ListValueEqual(const google::protobuf::MessageLite& lhs, const google::protobuf::MessageLite& rhs) { const int lhs_size = lhs_accessor_->ValuesSize(lhs); const int rhs_size = rhs_accessor_->ValuesSize(rhs); if (lhs_size != rhs_size) { return false; } for (int i = 0; i < lhs_size; ++i) { if (!ValueEqual(lhs_accessor_->Values(lhs, i), rhs_accessor_->Values(rhs, i))) { return false; } } return true; } bool StructEqual(const google::protobuf::MessageLite& lhs, const google::protobuf::MessageLite& rhs) { const int lhs_size = lhs_accessor_->FieldsSize(lhs); const int rhs_size = rhs_accessor_->FieldsSize(rhs); if (lhs_size != rhs_size) { return false; } if (lhs_size == 0) { return true; } std::string lhs_key_scratch; well_known_types::StringValue lhs_key; const google::protobuf::MessageLite* absl_nonnull lhs_value; auto lhs_iterator = lhs_accessor_->IterateFields(lhs); for (int i = 0; i < lhs_size; ++i) { std::tie(lhs_key, lhs_value) = lhs_iterator.Next(lhs_key_scratch); if (const auto* rhs_value = rhs_accessor_->FindField( rhs, absl::visit( absl::Overload( [](absl::string_view string) -> absl::string_view { return string; }, [&lhs_key_scratch]( const absl::Cord& cord) -> absl::string_view { if (auto flat = cord.TryFlat(); flat) { return *flat; } absl::CopyCordToString(cord, &lhs_key_scratch); return absl::string_view(lhs_key_scratch); }), AsVariant(lhs_key))); rhs_value == nullptr || !ValueEqual(*lhs_value, *rhs_value)) { return false; } } return true; } private: const JsonAccessor* absl_nonnull const lhs_accessor_; const JsonAccessor* absl_nonnull const rhs_accessor_; std::string lhs_scratch_; std::string rhs_scratch_; }; } // namespace bool JsonEquals(const google::protobuf::Value& lhs, const google::protobuf::Value& rhs) { return JsonEqualsState(GeneratedJsonAccessor::Singleton(), GeneratedJsonAccessor::Singleton()) .ValueEqual(lhs, rhs); } bool JsonEquals(const google::protobuf::Value& lhs, 
const google::protobuf::Message& rhs) { DynamicJsonAccessor rhs_accessor; rhs_accessor.InitializeValue(rhs); return JsonEqualsState(GeneratedJsonAccessor::Singleton(), &rhs_accessor) .ValueEqual(lhs, rhs); } bool JsonEquals(const google::protobuf::Message& lhs, const google::protobuf::Value& rhs) { DynamicJsonAccessor lhs_accessor; lhs_accessor.InitializeValue(lhs); return JsonEqualsState(&lhs_accessor, GeneratedJsonAccessor::Singleton()) .ValueEqual(lhs, rhs); } bool JsonEquals(const google::protobuf::Message& lhs, const google::protobuf::Message& rhs) { DynamicJsonAccessor lhs_accessor; lhs_accessor.InitializeValue(lhs); DynamicJsonAccessor rhs_accessor; rhs_accessor.InitializeValue(rhs); return JsonEqualsState(&lhs_accessor, &rhs_accessor).ValueEqual(lhs, rhs); } bool JsonEquals(const google::protobuf::MessageLite& lhs, const google::protobuf::MessageLite& rhs) { const auto* lhs_generated = google::protobuf::DynamicCastMessage(&lhs); const auto* rhs_generated = google::protobuf::DynamicCastMessage(&rhs); if (lhs_generated && rhs_generated) { return JsonEquals(*lhs_generated, *rhs_generated); } if (lhs_generated) { return JsonEquals(*lhs_generated, google::protobuf::DownCastMessage(rhs)); } if (rhs_generated) { return JsonEquals(google::protobuf::DownCastMessage(lhs), *rhs_generated); } return JsonEquals(google::protobuf::DownCastMessage(lhs), google::protobuf::DownCastMessage(rhs)); } bool JsonListEquals(const google::protobuf::ListValue& lhs, const google::protobuf::ListValue& rhs) { return JsonEqualsState(GeneratedJsonAccessor::Singleton(), GeneratedJsonAccessor::Singleton()) .ListValueEqual(lhs, rhs); } bool JsonListEquals(const google::protobuf::ListValue& lhs, const google::protobuf::Message& rhs) { DynamicJsonAccessor rhs_accessor; rhs_accessor.InitializeListValue(rhs); return JsonEqualsState(GeneratedJsonAccessor::Singleton(), &rhs_accessor) .ListValueEqual(lhs, rhs); } bool JsonListEquals(const google::protobuf::Message& lhs, const 
google::protobuf::ListValue& rhs) { DynamicJsonAccessor lhs_accessor; lhs_accessor.InitializeListValue(lhs); return JsonEqualsState(&lhs_accessor, GeneratedJsonAccessor::Singleton()) .ListValueEqual(lhs, rhs); } bool JsonListEquals(const google::protobuf::Message& lhs, const google::protobuf::Message& rhs) { DynamicJsonAccessor lhs_accessor; lhs_accessor.InitializeListValue(lhs); DynamicJsonAccessor rhs_accessor; rhs_accessor.InitializeListValue(rhs); return JsonEqualsState(&lhs_accessor, &rhs_accessor).ListValueEqual(lhs, rhs); } bool JsonListEquals(const google::protobuf::MessageLite& lhs, const google::protobuf::MessageLite& rhs) { const auto* lhs_generated = google::protobuf::DynamicCastMessage(&lhs); const auto* rhs_generated = google::protobuf::DynamicCastMessage(&rhs); if (lhs_generated && rhs_generated) { return JsonListEquals(*lhs_generated, *rhs_generated); } if (lhs_generated) { return JsonListEquals(*lhs_generated, google::protobuf::DownCastMessage(rhs)); } if (rhs_generated) { return JsonListEquals(google::protobuf::DownCastMessage(lhs), *rhs_generated); } return JsonListEquals(google::protobuf::DownCastMessage(lhs), google::protobuf::DownCastMessage(rhs)); } bool JsonMapEquals(const google::protobuf::Struct& lhs, const google::protobuf::Struct& rhs) { return JsonEqualsState(GeneratedJsonAccessor::Singleton(), GeneratedJsonAccessor::Singleton()) .StructEqual(lhs, rhs); } bool JsonMapEquals(const google::protobuf::Struct& lhs, const google::protobuf::Message& rhs) { DynamicJsonAccessor rhs_accessor; rhs_accessor.InitializeStruct(rhs); return JsonEqualsState(GeneratedJsonAccessor::Singleton(), &rhs_accessor) .StructEqual(lhs, rhs); } bool JsonMapEquals(const google::protobuf::Message& lhs, const google::protobuf::Struct& rhs) { DynamicJsonAccessor lhs_accessor; lhs_accessor.InitializeStruct(lhs); return JsonEqualsState(&lhs_accessor, GeneratedJsonAccessor::Singleton()) .StructEqual(lhs, rhs); } bool JsonMapEquals(const google::protobuf::Message& lhs, 
const google::protobuf::Message& rhs) { DynamicJsonAccessor lhs_accessor; lhs_accessor.InitializeStruct(lhs); DynamicJsonAccessor rhs_accessor; rhs_accessor.InitializeStruct(rhs); return JsonEqualsState(&lhs_accessor, &rhs_accessor).StructEqual(lhs, rhs); } bool JsonMapEquals(const google::protobuf::MessageLite& lhs, const google::protobuf::MessageLite& rhs) { const auto* lhs_generated = google::protobuf::DynamicCastMessage(&lhs); const auto* rhs_generated = google::protobuf::DynamicCastMessage(&rhs); if (lhs_generated && rhs_generated) { return JsonMapEquals(*lhs_generated, *rhs_generated); } if (lhs_generated) { return JsonMapEquals(*lhs_generated, google::protobuf::DownCastMessage(rhs)); } if (rhs_generated) { return JsonMapEquals(google::protobuf::DownCastMessage(lhs), *rhs_generated); } return JsonMapEquals(google::protobuf::DownCastMessage(lhs), google::protobuf::DownCastMessage(rhs)); } } // namespace cel::internal ================================================ FILE: internal/json.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_JSON_H_
#define THIRD_PARTY_CEL_CPP_INTERNAL_JSON_H_

// NOTE(review): the header name of the following include is missing in this
// copy of the file -- restore it (the declarations below use `std::string`).
#include 

#include "google/protobuf/struct.pb.h"
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::internal {

// Converts the given message to its `google.protobuf.Value` equivalent
// representation. This is similar to `proto2::json::MessageToJsonString()`,
// except that this results in structured serialization.
//
// `descriptor_pool`, `message_factory` and `result` must be non-null.
// NOTE(review): which message types the `google::protobuf::Message*` result
// overload accepts is defined in json.cc, outside this view. It presumably
// mirrors `MessageFieldToJson` (Value, ListValue or Struct) -- confirm.
absl::Status MessageToJson(
    const google::protobuf::Message& message,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Value* absl_nonnull result);
absl::Status MessageToJson(
    const google::protobuf::Message& message,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Struct* absl_nonnull result);
absl::Status MessageToJson(
    const google::protobuf::Message& message,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull result);

// Converts the given message field to its `google.protobuf.Value` equivalent
// representation. This is similar to `proto2::json::MessageToJsonString()`,
// except that this results in structured serialization.
//
// `field` must be a field of `message`'s descriptor (checked only in debug
// builds). `descriptor_pool`, `message_factory` and `result` must be
// non-null. For the `google::protobuf::Message*` overload, `result` must be
// a `google.protobuf.Value`, `ListValue` or `Struct` message. The conversion
// is chosen from its well-known type; any other type yields an error status.
absl::Status MessageFieldToJson(
    const google::protobuf::Message& message,
    const google::protobuf::FieldDescriptor* absl_nonnull field,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Value* absl_nonnull result);
absl::Status MessageFieldToJson(
    const google::protobuf::Message& message,
    const google::protobuf::FieldDescriptor* absl_nonnull field,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::ListValue* absl_nonnull result);
absl::Status MessageFieldToJson(
    const google::protobuf::Message& message,
    const google::protobuf::FieldDescriptor* absl_nonnull field,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Struct* absl_nonnull result);
absl::Status MessageFieldToJson(
    const google::protobuf::Message& message,
    const google::protobuf::FieldDescriptor* absl_nonnull field,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Message* absl_nonnull result);

// Checks that the instance of `google.protobuf.Value` has a descriptor which is
// well formed.
//
// The generated-type overload always returns OK. The `MessageLite` overload:
//   - accepts generated instances;
//   - for reflection-based messages, verifies reflection for the message and
//     the related `ListValue` / `Struct` types;
//   - returns `InvalidArgumentError` for any other message.
// The same pattern applies to `CheckJsonList` and `CheckJsonMap`.
inline absl::Status CheckJson(const google::protobuf::Value&) {
  return absl::OkStatus();
}
absl::Status CheckJson(const google::protobuf::MessageLite& message);

// Checks that the instance of `google.protobuf.ListValue` has a descriptor
// which is well formed.
inline absl::Status CheckJsonList(const google::protobuf::ListValue&) {
  return absl::OkStatus();
}
absl::Status CheckJsonList(const google::protobuf::MessageLite& message);

// Checks that the instance of `google.protobuf.Struct` has a descriptor which
// is well formed.
inline absl::Status CheckJsonMap(const google::protobuf::Struct&) {
  return absl::OkStatus();
}
absl::Status CheckJsonMap(const google::protobuf::MessageLite& message);

// Produces a debug string for the given instance of `google.protobuf.Value`.
//
// Output format:
//   - `null`, `true` / `false`;
//   - numbers, e.g. `1.0`, `nan`, `+infinity`;
//   - quoted string literals;
//   - lists as `[a, b]`;
//   - structs as `{"key": value, ...}`.
// The `google::protobuf::Message` overloads here and below crash
// (ABSL_CHECK) when reflection cannot be initialized from the message's
// descriptor. Validate with the matching `CheckJson*` function first if the
// descriptor is untrusted.
std::string JsonDebugString(const google::protobuf::Value& message);
std::string JsonDebugString(const google::protobuf::Message& message);

// Produces a debug string for the given instance of
// `google.protobuf.ListValue`.
std::string JsonListDebugString(const google::protobuf::ListValue& message);
std::string JsonListDebugString(const google::protobuf::Message& message);

// Produces a debug string for the given instance of `google.protobuf.Struct`.
std::string JsonMapDebugString(const google::protobuf::Struct& message);
std::string JsonMapDebugString(const google::protobuf::Message& message);

// Compares the given instances of `google.protobuf.Value` for equality.
//
// Comparison rules:
//   - an unset kind equals null;
//   - numbers compare with `double` `==`, so NaN is never equal to itself;
//   - lists compare element-wise in order;
//   - structs compare by size and key lookup, independent of entry order.
// The `google::protobuf::Message` overloads here and below crash
// (ABSL_CHECK) when reflection cannot be initialized from the message's
// descriptor.
bool JsonEquals(const google::protobuf::Value& lhs, const google::protobuf::Value& rhs);
bool JsonEquals(const google::protobuf::Value& lhs, const google::protobuf::Message& rhs);
bool JsonEquals(const google::protobuf::Message& lhs, const google::protobuf::Value& rhs);
bool JsonEquals(const google::protobuf::Message& lhs, const google::protobuf::Message& rhs);
bool JsonEquals(const google::protobuf::MessageLite& lhs,
                const google::protobuf::MessageLite& rhs);

// Compares the given instances of `google.protobuf.ListValue` for equality.
bool JsonListEquals(const google::protobuf::ListValue& lhs,
                    const google::protobuf::ListValue& rhs);
bool JsonListEquals(const google::protobuf::ListValue& lhs,
                    const google::protobuf::Message& rhs);
bool JsonListEquals(const google::protobuf::Message& lhs,
                    const google::protobuf::ListValue& rhs);
bool JsonListEquals(const google::protobuf::Message& lhs,
                    const google::protobuf::Message& rhs);
bool JsonListEquals(const google::protobuf::MessageLite& lhs,
                    const google::protobuf::MessageLite& rhs);

// Compares the given instances of `google.protobuf.Struct` for equality.
bool JsonMapEquals(const google::protobuf::Struct& lhs,
                   const google::protobuf::Struct& rhs);
bool JsonMapEquals(const google::protobuf::Struct& lhs,
                   const google::protobuf::Message& rhs);
bool JsonMapEquals(const google::protobuf::Message& lhs,
                   const google::protobuf::Struct& rhs);
bool JsonMapEquals(const google::protobuf::Message& lhs,
                   const google::protobuf::Message& rhs);
bool JsonMapEquals(const google::protobuf::MessageLite& lhs,
                   const google::protobuf::MessageLite& rhs);

}  // namespace cel::internal

#endif  // THIRD_PARTY_CEL_CPP_INTERNAL_JSON_H_

================================================
FILE: internal/json_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "internal/json.h"

#include "google/protobuf/any.pb.h"
#include "google/protobuf/duration.pb.h"
#include "google/protobuf/field_mask.pb.h"
#include "google/protobuf/struct.pb.h"
#include "google/protobuf/timestamp.pb.h"
#include "google/protobuf/wrappers.pb.h"
#include "absl/base/nullability.h"
#include "absl/log/die_if_null.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/strings/string_view.h"
#include "internal/equals_text_proto.h"
#include "internal/message_type_name.h"
#include "internal/parse_text_proto.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "internal/testing_message_factory.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::internal {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::testing::AnyOf;
using ::testing::HasSubstr;
using ::testing::Test;

using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes;

// Fixture for the `CheckJson*` tests. Each test runs against two
// representations of the same message type:
//   * `_Generated`: an instance of the generated C++ class `T`.
//   * `_Dynamic`:   an instance obtained by name from the testing descriptor
//                   pool and message factory, used only through the
//                   reflective `google::protobuf::Message` interface.
class CheckJsonTest : public Test {
 public:
  // Arena owned by the fixture; everything allocated on it is released when
  // the fixture is destroyed.
  google::protobuf::Arena* absl_nonnull arena() { return &arena_; }

  const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() {
    return GetTestingDescriptorPool();
  }

  google::protobuf::MessageFactory* absl_nonnull message_factory() {
    return GetTestingMessageFactory();
  }

  // Creates a default-initialized instance of the generated type `T` on the
  // fixture's arena.
  template <typename T>
  T* MakeGenerated() {
    return google::protobuf::Arena::Create<T>(arena());
  }

  // Creates a default-initialized message with `T`'s full type name, resolved
  // through the testing descriptor pool and instantiated from the testing
  // message factory's prototype on the fixture's arena. Aborts via
  // ABSL_DIE_IF_NULL if the type is unknown to the pool, the factory has no
  // prototype, or allocation yields null.
  template <typename T>
  google::protobuf::Message* MakeDynamic() {
    const auto* descriptor = ABSL_DIE_IF_NULL(
        descriptor_pool()->FindMessageTypeByName(MessageTypeNameFor<T>()));
    const auto* prototype =
        ABSL_DIE_IF_NULL(message_factory()->GetPrototype(descriptor));
    return ABSL_DIE_IF_NULL(prototype->New(arena()));
  }

 private:
  google::protobuf::Arena arena_;
};

// Default-initialized JSON containers are expected to pass the checks in both
// representations.
TEST_F(CheckJsonTest, Value_Generated) {
  EXPECT_THAT(CheckJson(*MakeGenerated<google::protobuf::Value>()), IsOk());
}

TEST_F(CheckJsonTest, Value_Dynamic) {
  EXPECT_THAT(CheckJson(*MakeDynamic<google::protobuf::Value>()), IsOk());
}

TEST_F(CheckJsonTest, ListValue_Generated) {
  EXPECT_THAT(CheckJsonList(*MakeGenerated<google::protobuf::ListValue>()),
              IsOk());
}

TEST_F(CheckJsonTest, ListValue_Dynamic) {
  EXPECT_THAT(CheckJsonList(*MakeDynamic<google::protobuf::ListValue>()),
              IsOk());
}

TEST_F(CheckJsonTest, Struct_Generated) {
  EXPECT_THAT(CheckJsonMap(*MakeGenerated<google::protobuf::Struct>()),
              IsOk());
}

TEST_F(CheckJsonTest, Struct_Dynamic) {
  EXPECT_THAT(CheckJsonMap(*MakeDynamic<google::protobuf::Struct>()), IsOk());
}

// Fixture for the `MessageToJson` tests. It provides the same arena, pool,
// factory and message-creation helpers as `CheckJsonTest`, and adds
// text-proto parsing and matching bound to those objects. The fixture is kept
// self-contained, like the other fixtures in this file.
class MessageToJsonTest : public Test {
 public:
  // Arena owned by the fixture; everything allocated on it is released when
  // the fixture is destroyed.
  google::protobuf::Arena* absl_nonnull arena() { return &arena_; }

  const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() {
    return GetTestingDescriptorPool();
  }

  google::protobuf::MessageFactory* absl_nonnull message_factory() {
    return GetTestingMessageFactory();
  }

  // Creates a default-initialized instance of the generated type `T` on the
  // fixture's arena.
  template <typename T>
  T* MakeGenerated() {
    return google::protobuf::Arena::Create<T>(arena());
  }

  // Creates a default-initialized message with `T`'s full type name through
  // the testing descriptor pool and message factory. Aborts via
  // ABSL_DIE_IF_NULL if any lookup or allocation yields null.
  template <typename T>
  google::protobuf::Message* MakeDynamic() {
    const auto* descriptor = ABSL_DIE_IF_NULL(
        descriptor_pool()->FindMessageTypeByName(MessageTypeNameFor<T>()));
    const auto* prototype =
        ABSL_DIE_IF_NULL(message_factory()->GetPrototype(descriptor));
    return ABSL_DIE_IF_NULL(prototype->New(arena()));
  }

  // Parses `text` as a `T` using the fixture's arena, pool and factory.
  // Forwards to `::cel::internal::DynamicParseTextProto<T>`.
  template <typename T>
  auto DynamicParseTextProto(absl::string_view text) {
    return ::cel::internal::DynamicParseTextProto<T>(
        arena(), text, descriptor_pool(), message_factory());
  }

  // Returns a matcher that compares a message with `text` parsed as a `T`.
  // Forwards to `::cel::internal::EqualsTextProto<T>`.
  template <typename T>
  auto EqualsTextProto(absl::string_view text) {
    return ::cel::internal::EqualsTextProto<T>(arena(), text, descriptor_pool(),
                                               message_factory());
  }

 private:
  google::protobuf::Arena arena_;
};

// Wrapper well-known types are unwrapped to the matching JSON scalar. Every
// numeric wrapper becomes `number_value`.
TEST_F(MessageToJsonTest, BoolValue_Generated) {
  auto* result = MakeGenerated<google::protobuf::Value>();
  EXPECT_THAT(MessageToJson(*DynamicParseTextProto<google::protobuf::BoolValue>(
                                R"pb(value: true)pb"),
                            descriptor_pool(), message_factory(), result),
              IsOk());
  EXPECT_THAT(*result, EqualsTextProto<google::protobuf::Value>(
                           R"pb(bool_value: true)pb"));
}

TEST_F(MessageToJsonTest, BoolValue_Dynamic) {
  auto* result = MakeDynamic<google::protobuf::Value>();
  EXPECT_THAT(MessageToJson(*DynamicParseTextProto<google::protobuf::BoolValue>(
                                R"pb(value: true)pb"),
                            descriptor_pool(), message_factory(), result),
              IsOk());
  EXPECT_THAT(*result, EqualsTextProto<google::protobuf::Value>(
                           R"pb(bool_value: true)pb"));
}

TEST_F(MessageToJsonTest, Int32Value_Generated) {
  auto* result = MakeGenerated<google::protobuf::Value>();
  EXPECT_THAT(
      MessageToJson(*DynamicParseTextProto<google::protobuf::Int32Value>(
                        R"pb(value: 1)pb"),
                    descriptor_pool(), message_factory(), result),
      IsOk());
  EXPECT_THAT(*result, EqualsTextProto<google::protobuf::Value>(
                           R"pb(number_value: 1.0)pb"));
}

TEST_F(MessageToJsonTest, Int32Value_Dynamic) {
  auto* result = MakeDynamic<google::protobuf::Value>();
  EXPECT_THAT(
      MessageToJson(*DynamicParseTextProto<google::protobuf::Int32Value>(
                        R"pb(value: 1)pb"),
                    descriptor_pool(), message_factory(), result),
      IsOk());
  EXPECT_THAT(*result, EqualsTextProto<google::protobuf::Value>(
                           R"pb(number_value: 1.0)pb"));
}

TEST_F(MessageToJsonTest, Int64Value_Generated) {
  auto* result = MakeGenerated<google::protobuf::Value>();
  EXPECT_THAT(
      MessageToJson(*DynamicParseTextProto<google::protobuf::Int64Value>(
                        R"pb(value: 1)pb"),
                    descriptor_pool(), message_factory(), result),
      IsOk());
  EXPECT_THAT(*result, EqualsTextProto<google::protobuf::Value>(
                           R"pb(number_value: 1.0)pb"));
}

TEST_F(MessageToJsonTest, Int64Value_Dynamic) {
  auto* result = MakeDynamic<google::protobuf::Value>();
  EXPECT_THAT(
      MessageToJson(*DynamicParseTextProto<google::protobuf::Int64Value>(
                        R"pb(value: 1)pb"),
                    descriptor_pool(), message_factory(), result),
      IsOk());
  EXPECT_THAT(*result, EqualsTextProto<google::protobuf::Value>(
                           R"pb(number_value: 1.0)pb"));
}

TEST_F(MessageToJsonTest, UInt32Value_Generated) {
  auto* result = MakeGenerated<google::protobuf::Value>();
  EXPECT_THAT(
      MessageToJson(*DynamicParseTextProto<google::protobuf::UInt32Value>(
                        R"pb(value: 1)pb"),
                    descriptor_pool(), message_factory(), result),
      IsOk());
  EXPECT_THAT(*result, EqualsTextProto<google::protobuf::Value>(
                           R"pb(number_value: 1.0)pb"));
}

TEST_F(MessageToJsonTest, UInt32Value_Dynamic) {
  auto* result = MakeDynamic<google::protobuf::Value>();
  EXPECT_THAT(
      MessageToJson(*DynamicParseTextProto<google::protobuf::UInt32Value>(
                        R"pb(value: 1)pb"),
                    descriptor_pool(), message_factory(), result),
      IsOk());
  EXPECT_THAT(*result, EqualsTextProto<google::protobuf::Value>(
                           R"pb(number_value: 1.0)pb"));
}

TEST_F(MessageToJsonTest, UInt64Value_Generated) {
  auto* result = MakeGenerated<google::protobuf::Value>();
  EXPECT_THAT(
      MessageToJson(*DynamicParseTextProto<google::protobuf::UInt64Value>(
                        R"pb(value: 1)pb"),
                    descriptor_pool(), message_factory(), result),
      IsOk());
  EXPECT_THAT(*result, EqualsTextProto<google::protobuf::Value>(
                           R"pb(number_value: 1.0)pb"));
}

TEST_F(MessageToJsonTest, UInt64Value_Dynamic) {
  auto* result = MakeDynamic<google::protobuf::Value>();
  EXPECT_THAT(
      MessageToJson(*DynamicParseTextProto<google::protobuf::UInt64Value>(
                        R"pb(value: 1)pb"),
                    descriptor_pool(), message_factory(), result),
      IsOk());
  EXPECT_THAT(*result, EqualsTextProto<google::protobuf::Value>(
                           R"pb(number_value: 1.0)pb"));
}

TEST_F(MessageToJsonTest, FloatValue_Generated) {
  auto* result = MakeGenerated<google::protobuf::Value>();
  EXPECT_THAT(
      MessageToJson(*DynamicParseTextProto<google::protobuf::FloatValue>(
                        R"pb(value: 1.0)pb"),
                    descriptor_pool(), message_factory(), result),
      IsOk());
  EXPECT_THAT(*result, EqualsTextProto<google::protobuf::Value>(
                           R"pb(number_value: 1.0)pb"));
}

TEST_F(MessageToJsonTest, FloatValue_Dynamic) {
  auto* result = MakeDynamic<google::protobuf::Value>();
  EXPECT_THAT(
      MessageToJson(*DynamicParseTextProto<google::protobuf::FloatValue>(
                        R"pb(value: 1.0)pb"),
                    descriptor_pool(), message_factory(), result),
      IsOk());
  EXPECT_THAT(*result, EqualsTextProto<google::protobuf::Value>(
                           R"pb(number_value: 1.0)pb"));
}

TEST_F(MessageToJsonTest, DoubleValue_Generated) {
  auto* result = MakeGenerated<google::protobuf::Value>();
  EXPECT_THAT(
      MessageToJson(*DynamicParseTextProto<google::protobuf::DoubleValue>(
                        R"pb(value: 1.0)pb"),
                    descriptor_pool(), message_factory(), result),
      IsOk());
  EXPECT_THAT(*result, EqualsTextProto<google::protobuf::Value>(
                           R"pb(number_value: 1.0)pb"));
}

TEST_F(MessageToJsonTest, DoubleValue_Dynamic) {
  auto* result = MakeDynamic<google::protobuf::Value>();
  EXPECT_THAT(
      MessageToJson(*DynamicParseTextProto<google::protobuf::DoubleValue>(
                        R"pb(value: 1.0)pb"),
                    descriptor_pool(), message_factory(), result),
      IsOk());
  EXPECT_THAT(*result, EqualsTextProto<google::protobuf::Value>(
                           R"pb(number_value: 1.0)pb"));
}

// `BytesValue` is emitted as a base64 string ("foo" -> "Zm9v").
TEST_F(MessageToJsonTest, BytesValue_Generated) {
  auto* result = MakeGenerated<google::protobuf::Value>();
  EXPECT_THAT(
      MessageToJson(*DynamicParseTextProto<google::protobuf::BytesValue>(
                        R"pb(value: "foo")pb"),
                    descriptor_pool(), message_factory(), result),
      IsOk());
  EXPECT_THAT(*result, EqualsTextProto<google::protobuf::Value>(
                           R"pb(string_value: "Zm9v")pb"));
}

TEST_F(MessageToJsonTest, BytesValue_Dynamic) {
  auto* result = MakeDynamic<google::protobuf::Value>();
  EXPECT_THAT(
      MessageToJson(*DynamicParseTextProto<google::protobuf::BytesValue>(
                        R"pb(value: "foo")pb"),
                    descriptor_pool(), message_factory(), result),
      IsOk());
  EXPECT_THAT(*result, EqualsTextProto<google::protobuf::Value>(
                           R"pb(string_value: "Zm9v")pb"));
}
TEST_F(MessageToJsonTest, StringValue_Generated) { auto* result = MakeGenerated(); EXPECT_THAT( MessageToJson(*DynamicParseTextProto( R"pb(value: "foo")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(string_value: "foo")pb")); } TEST_F(MessageToJsonTest, StringValue_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT( MessageToJson(*DynamicParseTextProto( R"pb(value: "foo")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(string_value: "foo")pb")); } TEST_F(MessageToJsonTest, Duration_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(seconds: 1 nanos: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(string_value: "1.000000001s")pb")); } TEST_F(MessageToJsonTest, Duration_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(seconds: 1 nanos: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(string_value: "1.000000001s")pb")); } TEST_F(MessageToJsonTest, Timestamp_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(seconds: 1 nanos: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(string_value: "1970-01-01T00:00:01.000000001Z")pb")); } TEST_F(MessageToJsonTest, Timestamp_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(seconds: 1 nanos: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(string_value: "1970-01-01T00:00:01.000000001Z")pb")); } TEST_F(MessageToJsonTest, Value_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(bool_value: true)pb"), descriptor_pool(), message_factory(), result), 
IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(bool_value: true)pb")); } TEST_F(MessageToJsonTest, Value_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(bool_value: true)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(bool_value: true)pb")); } TEST_F(MessageToJsonTest, ListValue_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(values { bool_value: true })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(list_value: { values { bool_value: true } })pb")); } TEST_F(MessageToJsonTest, ListValue_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(values { bool_value: true })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(list_value: { values { bool_value: true } })pb")); } TEST_F(MessageToJsonTest, Struct_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(fields { key: "foo" value: { bool_value: true } })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "foo" value: { bool_value: true } } })pb")); } TEST_F(MessageToJsonTest, Struct_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(fields { key: "foo" value: { bool_value: true } })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "foo" value: { bool_value: true } } })pb")); } TEST_F(MessageToJsonTest, FieldMask_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(paths: "foo" paths: "bar")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( 
R"pb(string_value: "foo,bar")pb")); } TEST_F(MessageToJsonTest, FieldMask_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(paths: "foo" paths: "bar")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(string_value: "foo,bar")pb")); } TEST_F(MessageToJsonTest, FieldMask_BadUpperCase) { auto* result = MakeDynamic(); EXPECT_THAT( MessageToJson(*DynamicParseTextProto( R"pb(paths: "Foo")pb"), descriptor_pool(), message_factory(), result), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("field mask path name contains uppercase letters"))); } TEST_F(MessageToJsonTest, FieldMask_BadUnderscoreUpperCase) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(paths: "foo_?")pb"), descriptor_pool(), message_factory(), result), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("field mask path contains '_' not followed by " "a lowercase letter"))); } TEST_F(MessageToJsonTest, FieldMask_BadTrailingUnderscore) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(paths: "foo_")pb"), descriptor_pool(), message_factory(), result), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("field mask path contains trailing '_'"))); } TEST_F(MessageToJsonTest, Any_WellKnownType_Generated) { auto* result = MakeGenerated(); EXPECT_THAT( MessageToJson( *DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.BoolValue" value: "\x08\x01")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "@type" value: { string_value: "type.googleapis.com/google.protobuf.BoolValue" } } fields { key: "value" value: { bool_value: true } } })pb")); } TEST_F(MessageToJsonTest, Any_WellKnownType_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT( MessageToJson( *DynamicParseTextProto( R"pb(type_url: 
"type.googleapis.com/google.protobuf.BoolValue" value: "\x08\x01")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "@type" value: { string_value: "type.googleapis.com/google.protobuf.BoolValue" } } fields { key: "value" value: { bool_value: true } } })pb")); } TEST_F(MessageToJsonTest, Any_Empty_Generated) { auto* result = MakeGenerated(); EXPECT_THAT( MessageToJson( *DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.Empty")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "@type" value: { string_value: "type.googleapis.com/google.protobuf.Empty" } } fields { key: "value" value: { struct_value: {} } } })pb")); } TEST_F(MessageToJsonTest, Any_Empty_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT( MessageToJson( *DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.Empty")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "@type" value: { string_value: "type.googleapis.com/google.protobuf.Empty" } } fields { key: "value" value: { struct_value: {} } } })pb")); } TEST_F(MessageToJsonTest, Any_Generated) { auto* result = MakeGenerated(); EXPECT_THAT( MessageToJson( *DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes" value: "\x68\x01")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "@type" value: { string_value: "type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes" } } fields { key: "singleBool" value: { bool_value: true } } })pb")); } TEST_F(MessageToJsonTest, Any_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT( MessageToJson( *DynamicParseTextProto( R"pb(type_url: 
"type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes" value: "\x68\x01")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "@type" value: { string_value: "type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes" } } fields { key: "singleBool" value: { bool_value: true } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_Bool_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_bool: true)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleBool" value: { bool_value: true } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_Bool_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_bool: true)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleBool" value: { bool_value: true } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_Int32_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_int32: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleInt32" value: { number_value: 1.0 } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_Int32_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_int32: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleInt32" value: { number_value: 1.0 } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_Int64_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_int64: 1)pb"), descriptor_pool(), 
message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleInt64" value: { number_value: 1.0 } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_Int64_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_int64: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleInt64" value: { number_value: 1.0 } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_UInt32_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_uint32: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleUint32" value: { number_value: 1.0 } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_UInt32_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_uint32: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleUint32" value: { number_value: 1.0 } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_UInt64_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_uint64: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleUint64" value: { number_value: 1.0 } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_UInt64_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_uint64: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleUint64" value: { number_value: 1.0 } } })pb")); } TEST_F(MessageToJsonTest, 
TestAllTypesProto3_Float_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_float: 1.0)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleFloat" value: { number_value: 1.0 } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_Float_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_float: 1.0)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleFloat" value: { number_value: 1.0 } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_Double_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_double: 1.0)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleDouble" value: { number_value: 1.0 } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_Double_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_double: 1.0)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleDouble" value: { number_value: 1.0 } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_Bytes_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_bytes: "foo")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleBytes" value: { string_value: "Zm9v" } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_Bytes_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_bytes: "foo")pb"), descriptor_pool(), message_factory(), result), 
IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleBytes" value: { string_value: "Zm9v" } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_String_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_string: "foo")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleString" value: { string_value: "foo" } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_String_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(single_string: "foo")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "singleString" value: { string_value: "foo" } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_Message_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(standalone_message: { bb: 1 })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "standaloneMessage" value: { struct_value: { fields { key: "bb" value: { number_value: 1.0 } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_Message_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(standalone_message: { bb: 1 })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "standaloneMessage" value: { struct_value: { fields { key: "bb" value: { number_value: 1.0 } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_Enum_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(standalone_enum: BAR)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( 
R"pb(struct_value: { fields { key: "standaloneEnum" value: { string_value: "BAR" } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_Enum_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(standalone_enum: BAR)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "standaloneEnum" value: { string_value: "BAR" } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedBool_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_bool: true)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedBool" value: { list_value: { values: { bool_value: true } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedBool_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_bool: true)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedBool" value: { list_value: { values: { bool_value: true } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedInt32_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_int32: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedInt32" value: { list_value: { values: { number_value: 1.0 } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedInt32_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_int32: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedInt32" value: { list_value: { 
values: { number_value: 1.0 } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedInt64_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_int64: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedInt64" value: { list_value: { values: { number_value: 1.0 } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedInt64_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_int64: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedInt64" value: { list_value: { values: { number_value: 1.0 } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedUInt32_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_uint32: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedUint32" value: { list_value: { values: { number_value: 1.0 } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedUInt32_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_uint32: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedUint32" value: { list_value: { values: { number_value: 1.0 } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedUInt64_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_uint64: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedUint64" value: { list_value: { 
values: { number_value: 1.0 } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedUInt64_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_uint64: 1)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedUint64" value: { list_value: { values: { number_value: 1.0 } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedFloat_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_float: 1.0)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedFloat" value: { list_value: { values: { number_value: 1.0 } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedFloat_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_float: 1.0)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedFloat" value: { list_value: { values: { number_value: 1.0 } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedDouble_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_double: 1.0)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedDouble" value: { list_value: { values: { number_value: 1.0 } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedDouble_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_double: 1.0)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedDouble" value: { list_value: { 
values: { number_value: 1.0 } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedBytes_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_bytes: "foo")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedBytes" value: { list_value: { values: { string_value: "Zm9v" } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedBytes_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_bytes: "foo")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedBytes" value: { list_value: { values: { string_value: "Zm9v" } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedString_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_string: "foo")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedString" value: { list_value: { values: { string_value: "foo" } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedString_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_string: "foo")pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedString" value: { list_value: { values: { string_value: "foo" } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedMessage_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_nested_message: { bb: 1 })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: 
"repeatedNestedMessage" value: { list_value: { values: { struct_value: { fields { key: "bb" value: { number_value: 1.0 } } } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedMessage_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_nested_message: { bb: 1 })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedNestedMessage" value: { list_value: { values: { struct_value: { fields { key: "bb" value: { number_value: 1.0 } } } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedEnum_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_nested_enum: BAR)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedNestedEnum" value: { list_value: { values: { string_value: "BAR" } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedEnum_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_nested_enum: BAR)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedNestedEnum" value: { list_value: { values: { string_value: "BAR" } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedNull_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_null_value: NULL_VALUE)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedNullValue" value: { list_value: { values: { null_value: NULL_VALUE } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_RepeatedNull_Dynamic) { auto* result = MakeDynamic(); 
EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(repeated_null_value: NULL_VALUE)pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT( *result, EqualsTextProto( R"pb(struct_value: { fields { key: "repeatedNullValue" value: { list_value: { values: { null_value: NULL_VALUE } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapBoolBool_Generated) { auto* result = MakeGenerated(); EXPECT_THAT( MessageToJson(*DynamicParseTextProto( R"pb(map_bool_bool: { key: true value: true })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapBoolBool" value: { struct_value: { fields { key: "true" value: { bool_value: true } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapBoolBool_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT( MessageToJson(*DynamicParseTextProto( R"pb(map_bool_bool: { key: true value: true })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapBoolBool" value: { struct_value: { fields { key: "true" value: { bool_value: true } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapInt32Int32_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(map_int32_int32: { key: 1 value: 1 })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapInt32Int32" value: { struct_value: { fields { key: "1" value: { number_value: 1.0 } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapInt32Int32_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(map_int32_int32: { key: 1 value: 1 })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapInt32Int32" value: { struct_value: 
{ fields { key: "1" value: { number_value: 1.0 } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapInt64Int64_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(map_int64_int64: { key: 1 value: 1 })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapInt64Int64" value: { struct_value: { fields { key: "1" value: { number_value: 1.0 } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapInt64Int64_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(map_int64_int64: { key: 1 value: 1 })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapInt64Int64" value: { struct_value: { fields { key: "1" value: { number_value: 1.0 } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapUInt32UInt32_Generated) { auto* result = MakeGenerated(); EXPECT_THAT( MessageToJson(*DynamicParseTextProto( R"pb(map_uint32_uint32: { key: 1 value: 1 })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapUint32Uint32" value: { struct_value: { fields { key: "1" value: { number_value: 1.0 } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapUInt32UInt32_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT( MessageToJson(*DynamicParseTextProto( R"pb(map_uint32_uint32: { key: 1 value: 1 })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapUint32Uint32" value: { struct_value: { fields { key: "1" value: { number_value: 1.0 } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapUInt64UInt64_Generated) { auto* result = MakeGenerated(); EXPECT_THAT( MessageToJson(*DynamicParseTextProto( 
R"pb(map_uint64_uint64: { key: 1 value: 1 })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapUint64Uint64" value: { struct_value: { fields { key: "1" value: { number_value: 1.0 } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapUInt64UInt64_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT( MessageToJson(*DynamicParseTextProto( R"pb(map_uint64_uint64: { key: 1 value: 1 })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapUint64Uint64" value: { struct_value: { fields { key: "1" value: { number_value: 1.0 } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringString_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson( *DynamicParseTextProto( R"pb(map_string_string: { key: "foo" value: "bar" })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapStringString" value: { struct_value: { fields { key: "foo" value: { string_value: "bar" } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringString_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson( *DynamicParseTextProto( R"pb(map_string_string: { key: "foo" value: "bar" })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapStringString" value: { struct_value: { fields { key: "foo" value: { string_value: "bar" } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringFloat_Generated) { auto* result = MakeGenerated(); EXPECT_THAT( MessageToJson(*DynamicParseTextProto( R"pb(map_string_float: { key: "foo" value: 1.0 })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapStringFloat" 
value: { struct_value: { fields { key: "foo" value: { number_value: 1.0 } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringFloat_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT( MessageToJson(*DynamicParseTextProto( R"pb(map_string_float: { key: "foo" value: 1.0 })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapStringFloat" value: { struct_value: { fields { key: "foo" value: { number_value: 1.0 } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringDouble_Generated) { auto* result = MakeGenerated(); EXPECT_THAT( MessageToJson(*DynamicParseTextProto( R"pb(map_string_double: { key: "foo" value: 1.0 })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapStringDouble" value: { struct_value: { fields { key: "foo" value: { number_value: 1.0 } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringDouble_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT( MessageToJson(*DynamicParseTextProto( R"pb(map_string_double: { key: "foo" value: 1.0 })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapStringDouble" value: { struct_value: { fields { key: "foo" value: { number_value: 1.0 } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringBytes_Generated) { auto* result = MakeGenerated(); EXPECT_THAT( MessageToJson(*DynamicParseTextProto( R"pb(map_string_bytes: { key: "foo" value: "bar" })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapStringBytes" value: { struct_value: { fields { key: "foo" value: { string_value: "YmFy" } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringBytes_Dynamic) { auto* result = MakeDynamic(); 
EXPECT_THAT( MessageToJson(*DynamicParseTextProto( R"pb(map_string_bytes: { key: "foo" value: "bar" })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapStringBytes" value: { struct_value: { fields { key: "foo" value: { string_value: "YmFy" } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringMessage_Generated) { auto* result = MakeGenerated(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(map_string_message: { key: "foo" value: { bb: 1 } })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapStringMessage" value: { struct_value: { fields { key: "foo" value: { struct_value: { fields { key: "bb" value: { number_value: 1.0 } } } } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringMessage_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT(MessageToJson(*DynamicParseTextProto( R"pb(map_string_message: { key: "foo" value: { bb: 1 } })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapStringMessage" value: { struct_value: { fields { key: "foo" value: { struct_value: { fields { key: "bb" value: { number_value: 1.0 } } } } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringEnum_Generated) { auto* result = MakeGenerated(); EXPECT_THAT( MessageToJson(*DynamicParseTextProto( R"pb(map_string_enum: { key: "foo" value: BAR })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapStringEnum" value: { struct_value: { fields { key: "foo" value: { string_value: "BAR" } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringEnum_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT( MessageToJson(*DynamicParseTextProto( R"pb(map_string_enum: { key: 
"foo" value: BAR })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapStringEnum" value: { struct_value: { fields { key: "foo" value: { string_value: "BAR" } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringNull_Generated) { auto* result = MakeGenerated(); EXPECT_THAT( MessageToJson( *DynamicParseTextProto( R"pb(map_string_null_value: { key: "foo" value: NULL_VALUE })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapStringNullValue" value: { struct_value: { fields { key: "foo" value: { null_value: NULL_VALUE } } } } } })pb")); } TEST_F(MessageToJsonTest, TestAllTypesProto3_MapStringNull_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT( MessageToJson( *DynamicParseTextProto( R"pb(map_string_null_value: { key: "foo" value: NULL_VALUE })pb"), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(struct_value: { fields { key: "mapStringNullValue" value: { struct_value: { fields { key: "foo" value: { null_value: NULL_VALUE } } } } } })pb")); } class MessageFieldToJsonTest : public Test { public: google::protobuf::Arena* absl_nonnull arena() { return &arena_; } const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() { return GetTestingDescriptorPool(); } google::protobuf::MessageFactory* absl_nonnull message_factory() { return GetTestingMessageFactory(); } template T* MakeGenerated() { return google::protobuf::Arena::Create(arena()); } template google::protobuf::Message* MakeDynamic() { const auto* descriptor = ABSL_DIE_IF_NULL( descriptor_pool()->FindMessageTypeByName(MessageTypeNameFor())); const auto* prototype = ABSL_DIE_IF_NULL(message_factory()->GetPrototype(descriptor)); return ABSL_DIE_IF_NULL(prototype->New(arena())); } template auto DynamicParseTextProto(absl::string_view text) { return 
::cel::internal::DynamicParseTextProto( arena(), text, descriptor_pool(), message_factory()); } template auto EqualsTextProto(absl::string_view text) { return ::cel::internal::EqualsTextProto(arena(), text, descriptor_pool(), message_factory()); } private: google::protobuf::Arena arena_; }; TEST_F(MessageFieldToJsonTest, TestAllTypesProto3_Generated) { auto* result = MakeGenerated(); EXPECT_THAT( MessageFieldToJson( *DynamicParseTextProto( R"pb(single_bool: true)pb"), ABSL_DIE_IF_NULL( ABSL_DIE_IF_NULL(descriptor_pool()->FindMessageTypeByName( "cel.expr.conformance.proto3.TestAllTypes")) ->FindFieldByName("single_bool")), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(bool_value: true)pb")); } TEST_F(MessageFieldToJsonTest, TestAllTypesProto3_Dynamic) { auto* result = MakeDynamic(); EXPECT_THAT( MessageFieldToJson( *DynamicParseTextProto( R"pb(single_bool: true)pb"), ABSL_DIE_IF_NULL( ABSL_DIE_IF_NULL(descriptor_pool()->FindMessageTypeByName( "cel.expr.conformance.proto3.TestAllTypes")) ->FindFieldByName("single_bool")), descriptor_pool(), message_factory(), result), IsOk()); EXPECT_THAT(*result, EqualsTextProto( R"pb(bool_value: true)pb")); } class JsonDebugStringTest : public Test { public: google::protobuf::Arena* absl_nonnull arena() { return &arena_; } const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() { return GetTestingDescriptorPool(); } google::protobuf::MessageFactory* absl_nonnull message_factory() { return GetTestingMessageFactory(); } template auto GeneratedParseTextProto(absl::string_view text) { return ::cel::internal::GeneratedParseTextProto( arena(), text, descriptor_pool(), message_factory()); } template auto DynamicParseTextProto(absl::string_view text) { return ::cel::internal::DynamicParseTextProto( arena(), text, descriptor_pool(), message_factory()); } private: google::protobuf::Arena arena_; }; TEST_F(JsonDebugStringTest, Null_Generated) { EXPECT_EQ(JsonDebugString( 
*GeneratedParseTextProto(R"pb()pb")), "null"); } TEST_F(JsonDebugStringTest, Null_Dynamic) { EXPECT_EQ(JsonDebugString( *DynamicParseTextProto(R"pb()pb")), "null"); } TEST_F(JsonDebugStringTest, Bool_Generated) { EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( R"pb(bool_value: false)pb")), "false"); EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( R"pb(bool_value: true)pb")), "true"); } TEST_F(JsonDebugStringTest, Bool_Dynamic) { EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( R"pb(bool_value: false)pb")), "false"); EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( R"pb(bool_value: true)pb")), "true"); } TEST_F(JsonDebugStringTest, Number_Generated) { EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( R"pb(number_value: 1.0)pb")), "1.0"); EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( R"pb(number_value: 1.1)pb")), "1.1"); EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( R"pb(number_value: infinity)pb")), "+infinity"); EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( R"pb(number_value: -infinity)pb")), "-infinity"); EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( R"pb(number_value: nan)pb")), "nan"); } TEST_F(JsonDebugStringTest, Number_Dynamic) { EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( R"pb(number_value: 1.0)pb")), "1.0"); EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( R"pb(number_value: 1.1)pb")), "1.1"); EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( R"pb(number_value: infinity)pb")), "+infinity"); EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( R"pb(number_value: -infinity)pb")), "-infinity"); EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( R"pb(number_value: nan)pb")), "nan"); } TEST_F(JsonDebugStringTest, String_Generated) { EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( R"pb(string_value: "foo")pb")), "\"foo\""); } TEST_F(JsonDebugStringTest, String_Dynamic) { EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( R"pb(string_value: "foo")pb")), "\"foo\""); } TEST_F(JsonDebugStringTest, List_Generated) { 
EXPECT_EQ(JsonDebugString(*GeneratedParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb")), "[null, true]"); EXPECT_EQ( JsonListDebugString(*GeneratedParseTextProto( R"pb( values {} values { bool_value: true })pb")), "[null, true]"); } TEST_F(JsonDebugStringTest, List_Dynamic) { EXPECT_EQ(JsonDebugString(*DynamicParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb")), "[null, true]"); EXPECT_EQ( JsonListDebugString(*DynamicParseTextProto( R"pb( values {} values { bool_value: true })pb")), "[null, true]"); } TEST_F(JsonDebugStringTest, Struct_Generated) { EXPECT_THAT(JsonDebugString(*GeneratedParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb")), AnyOf("{\"foo\": null, \"bar\": true}", "{\"bar\": true, \"foo\": null}")); EXPECT_THAT( JsonMapDebugString(*GeneratedParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } })pb")), AnyOf("{\"foo\": null, \"bar\": true}", "{\"bar\": true, \"foo\": null}")); } TEST_F(JsonDebugStringTest, Struct_Dynamic) { EXPECT_THAT(JsonDebugString(*DynamicParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb")), AnyOf("{\"foo\": null, \"bar\": true}", "{\"bar\": true, \"foo\": null}")); EXPECT_THAT( JsonMapDebugString(*DynamicParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } })pb")), AnyOf("{\"foo\": null, \"bar\": true}", "{\"bar\": true, \"foo\": null}")); } class JsonEqualsTest : public Test { public: google::protobuf::Arena* absl_nonnull arena() { return &arena_; } const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() { return GetTestingDescriptorPool(); } google::protobuf::MessageFactory* absl_nonnull message_factory() { return GetTestingMessageFactory(); } template auto GeneratedParseTextProto(absl::string_view text) { return 
::cel::internal::GeneratedParseTextProto( arena(), text, descriptor_pool(), message_factory()); } template auto DynamicParseTextProto(absl::string_view text) { return ::cel::internal::DynamicParseTextProto( arena(), text, descriptor_pool(), message_factory()); } private: google::protobuf::Arena arena_; }; TEST_F(JsonEqualsTest, Null_Null_Generated_Generated) { EXPECT_TRUE( JsonEquals(*GeneratedParseTextProto(R"pb()pb"), *GeneratedParseTextProto(R"pb()pb"))); } TEST_F(JsonEqualsTest, Null_Null_Generated_Dynamic) { EXPECT_TRUE( JsonEquals(*GeneratedParseTextProto(R"pb()pb"), *DynamicParseTextProto(R"pb()pb"))); } TEST_F(JsonEqualsTest, Null_Null_Dynamic_Generated) { EXPECT_TRUE( JsonEquals(*DynamicParseTextProto(R"pb()pb"), *GeneratedParseTextProto(R"pb()pb"))); } TEST_F(JsonEqualsTest, Null_Null_Dynamic_Dynamic) { EXPECT_TRUE( JsonEquals(*DynamicParseTextProto(R"pb()pb"), *DynamicParseTextProto(R"pb()pb"))); } TEST_F(JsonEqualsTest, Bool_Bool_Generated_Generated) { EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( R"pb(bool_value: true)pb"), *GeneratedParseTextProto( R"pb(bool_value: true)pb"))); } TEST_F(JsonEqualsTest, Bool_Bool_Generated_Dynamic) { EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( R"pb(bool_value: true)pb"), *DynamicParseTextProto( R"pb(bool_value: true)pb"))); } TEST_F(JsonEqualsTest, Bool_Bool_Dynamic_Generated) { EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( R"pb(bool_value: true)pb"), *GeneratedParseTextProto( R"pb(bool_value: true)pb"))); } TEST_F(JsonEqualsTest, Bool_Bool_Dynamic_Dynamic) { EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( R"pb(bool_value: true)pb"), *DynamicParseTextProto( R"pb(bool_value: true)pb"))); } TEST_F(JsonEqualsTest, Number_Number_Generated_Generated) { EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( R"pb(number_value: 1.0)pb"), *GeneratedParseTextProto( R"pb(number_value: 1.0)pb"))); } TEST_F(JsonEqualsTest, Number_Number_Generated_Dynamic) { EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( R"pb(number_value: 
1.0)pb"), *DynamicParseTextProto( R"pb(number_value: 1.0)pb"))); } TEST_F(JsonEqualsTest, Number_Number_Dynamic_Generated) { EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( R"pb(number_value: 1.0)pb"), *GeneratedParseTextProto( R"pb(number_value: 1.0)pb"))); } TEST_F(JsonEqualsTest, Number_Number_Dynamic_Dynamic) { EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( R"pb(number_value: 1.0)pb"), *DynamicParseTextProto( R"pb(number_value: 1.0)pb"))); } TEST_F(JsonEqualsTest, String_String_Generated_Generated) { EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( R"pb(string_value: "foo")pb"), *GeneratedParseTextProto( R"pb(string_value: "foo")pb"))); } TEST_F(JsonEqualsTest, String_String_Generated_Dynamic) { EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( R"pb(string_value: "foo")pb"), *DynamicParseTextProto( R"pb(string_value: "foo")pb"))); } TEST_F(JsonEqualsTest, String_String_Dynamic_Generated) { EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( R"pb(string_value: "foo")pb"), *GeneratedParseTextProto( R"pb(string_value: "foo")pb"))); } TEST_F(JsonEqualsTest, String_String_Dynamic_Dynamic) { EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( R"pb(string_value: "foo")pb"), *DynamicParseTextProto( R"pb(string_value: "foo")pb"))); } TEST_F(JsonEqualsTest, List_List_Generated_Generated) { EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb"), *GeneratedParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb"))); EXPECT_TRUE(JsonEquals(static_cast( *GeneratedParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb")), static_cast( *GeneratedParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb")))); EXPECT_TRUE( JsonListEquals(*GeneratedParseTextProto( R"pb( values {} values { bool_value: true } )pb"), *GeneratedParseTextProto( R"pb( values {} values { bool_value: true } )pb"))); EXPECT_TRUE( JsonListEquals(static_cast( *GeneratedParseTextProto( R"pb( values 
{} values { bool_value: true } )pb")), static_cast( *GeneratedParseTextProto( R"pb( values {} values { bool_value: true } )pb")))); } TEST_F(JsonEqualsTest, List_List_Generated_Dynamic) { EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb"), *DynamicParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb"))); EXPECT_TRUE(JsonEquals(static_cast( *GeneratedParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb")), static_cast( *DynamicParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb")))); EXPECT_TRUE( JsonListEquals(*GeneratedParseTextProto( R"pb( values {} values { bool_value: true } )pb"), *DynamicParseTextProto( R"pb( values {} values { bool_value: true } )pb"))); EXPECT_TRUE( JsonListEquals(static_cast( *GeneratedParseTextProto( R"pb( values {} values { bool_value: true } )pb")), static_cast( *DynamicParseTextProto( R"pb( values {} values { bool_value: true } )pb")))); } TEST_F(JsonEqualsTest, List_List_Dynamic_Generated) { EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb"), *GeneratedParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb"))); EXPECT_TRUE(JsonEquals(static_cast( *DynamicParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb")), static_cast( *GeneratedParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb")))); EXPECT_TRUE( JsonListEquals(*DynamicParseTextProto( R"pb( values {} values { bool_value: true } )pb"), *GeneratedParseTextProto( R"pb( values {} values { bool_value: true } )pb"))); EXPECT_TRUE( JsonListEquals(static_cast( *DynamicParseTextProto( R"pb( values {} values { bool_value: true } )pb")), static_cast( *GeneratedParseTextProto( R"pb( values {} values { bool_value: true } )pb")))); } TEST_F(JsonEqualsTest, List_List_Dynamic_Dynamic) { 
EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb"), *DynamicParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb"))); EXPECT_TRUE(JsonEquals(static_cast( *DynamicParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb")), static_cast( *DynamicParseTextProto( R"pb(list_value: { values {} values { bool_value: true } })pb")))); EXPECT_TRUE( JsonListEquals(*DynamicParseTextProto( R"pb( values {} values { bool_value: true } )pb"), *DynamicParseTextProto( R"pb( values {} values { bool_value: true } )pb"))); EXPECT_TRUE( JsonListEquals(static_cast( *DynamicParseTextProto( R"pb( values {} values { bool_value: true } )pb")), static_cast( *DynamicParseTextProto( R"pb( values {} values { bool_value: true } )pb")))); } TEST_F(JsonEqualsTest, Map_Map_Generated_Generated) { EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb"), *GeneratedParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb"))); EXPECT_TRUE(JsonEquals(static_cast( *GeneratedParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb")), static_cast( *GeneratedParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb")))); EXPECT_TRUE(JsonMapEquals(*GeneratedParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } )pb"), *GeneratedParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } )pb"))); EXPECT_TRUE( JsonMapEquals(static_cast( *GeneratedParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } )pb")), static_cast( *GeneratedParseTextProto( R"pb( fields { key: 
"foo" value: {} } fields { key: "bar" value: { bool_value: true } } )pb")))); } TEST_F(JsonEqualsTest, Map_Map_Generated_Dynamic) { EXPECT_TRUE(JsonEquals(*GeneratedParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb"), *DynamicParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb"))); EXPECT_TRUE(JsonEquals(static_cast( *GeneratedParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb")), static_cast( *DynamicParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb")))); EXPECT_TRUE(JsonMapEquals(*GeneratedParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } )pb"), *DynamicParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } )pb"))); EXPECT_TRUE( JsonMapEquals(static_cast( *GeneratedParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } )pb")), static_cast( *DynamicParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } )pb")))); } TEST_F(JsonEqualsTest, Map_Map_Dynamic_Generated) { EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb"), *GeneratedParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb"))); EXPECT_TRUE(JsonEquals(static_cast( *DynamicParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb")), static_cast( *GeneratedParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb")))); 
EXPECT_TRUE(JsonMapEquals(*DynamicParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } )pb"), *GeneratedParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } )pb"))); EXPECT_TRUE( JsonMapEquals(static_cast( *DynamicParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } )pb")), static_cast( *GeneratedParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } )pb")))); } TEST_F(JsonEqualsTest, Map_Map_Dynamic_Dynamic) { EXPECT_TRUE(JsonEquals(*DynamicParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb"), *DynamicParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb"))); EXPECT_TRUE(JsonEquals(static_cast( *DynamicParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb")), static_cast( *DynamicParseTextProto( R"pb(struct_value: { fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } })pb")))); EXPECT_TRUE(JsonMapEquals(*DynamicParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } )pb"), *DynamicParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } )pb"))); EXPECT_TRUE( JsonMapEquals(static_cast( *DynamicParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } )pb")), static_cast( *DynamicParseTextProto( R"pb( fields { key: "foo" value: {} } fields { key: "bar" value: { bool_value: true } } )pb")))); } } // namespace } // namespace cel::internal ================================================ FILE: internal/lexis.cc ================================================ // Copyright 2021 Google LLC // // 
Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "internal/lexis.h" #include "absl/base/call_once.h" #include "absl/base/macros.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/ascii.h" namespace cel::internal { namespace { ABSL_CONST_INIT absl::once_flag reserved_keywords_once_flag = {}; ABSL_CONST_INIT absl::flat_hash_set* reserved_keywords = nullptr; void InitializeReservedKeywords() { ABSL_ASSERT(reserved_keywords == nullptr); reserved_keywords = new absl::flat_hash_set(); reserved_keywords->insert("false"); reserved_keywords->insert("true"); reserved_keywords->insert("null"); reserved_keywords->insert("in"); reserved_keywords->insert("as"); reserved_keywords->insert("break"); reserved_keywords->insert("const"); reserved_keywords->insert("continue"); reserved_keywords->insert("else"); reserved_keywords->insert("for"); reserved_keywords->insert("function"); reserved_keywords->insert("if"); reserved_keywords->insert("import"); reserved_keywords->insert("let"); reserved_keywords->insert("loop"); reserved_keywords->insert("package"); reserved_keywords->insert("namespace"); reserved_keywords->insert("return"); reserved_keywords->insert("var"); reserved_keywords->insert("void"); reserved_keywords->insert("while"); } } // namespace bool LexisIsReserved(absl::string_view text) { absl::call_once(reserved_keywords_once_flag, InitializeReservedKeywords); return reserved_keywords->find(text) != reserved_keywords->end(); } bool 
LexisIsIdentifier(absl::string_view text) { if (text.empty()) { return false; } char first = text.front(); if (!absl::ascii_isalpha(first) && first != '_') { return false; } for (size_t index = 1; index < text.size(); index++) { if (!absl::ascii_isalnum(text[index]) && text[index] != '_') { return false; } } return !LexisIsReserved(text); } } // namespace cel::internal ================================================ FILE: internal/lexis.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_LEXIS_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_LEXIS_H_ #include "absl/strings/string_view.h" namespace cel::internal { // Returns true if the given text matches RESERVED per the lexis of the CEL // specification. bool LexisIsReserved(absl::string_view text); // Returns true if the given text matches IDENT per the lexis of the CEL // specification, fales otherwise. bool LexisIsIdentifier(absl::string_view text); } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_LEXIS_H_ ================================================ FILE: internal/lexis_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "internal/lexis.h" #include "internal/testing.h" namespace cel::internal { namespace { struct LexisTestCase final { absl::string_view text; bool ok; }; using LexisIsReservedTest = testing::TestWithParam; TEST_P(LexisIsReservedTest, Compliance) { const LexisTestCase& test_case = GetParam(); if (test_case.ok) { EXPECT_TRUE(LexisIsReserved(test_case.text)); } else { EXPECT_FALSE(LexisIsReserved(test_case.text)); } } INSTANTIATE_TEST_SUITE_P(LexisIsReservedTest, LexisIsReservedTest, testing::ValuesIn({{"true", true}, {"cel", false}})); using LexisIsIdentifierTest = testing::TestWithParam; TEST_P(LexisIsIdentifierTest, Compliance) { const LexisTestCase& test_case = GetParam(); if (test_case.ok) { EXPECT_TRUE(LexisIsIdentifier(test_case.text)); } else { EXPECT_FALSE(LexisIsIdentifier(test_case.text)); } } INSTANTIATE_TEST_SUITE_P( LexisIsIdentifierTest, LexisIsIdentifierTest, testing::ValuesIn( {{"true", false}, {"0abc", false}, {"-abc", false}, {".abc", false}, {"~abc", false}, {"!abc", false}, {"abc-", false}, {"abc.", false}, {"abc~", false}, {"abc!", false}, {"cel", true}, {"cel0", true}, {"_cel", true}, {"_cel0", true}, {"cel_", true}, {"cel0_", true}, {"cel_cel", true}, {"cel0_cel", true}, {"cel_cel0", true}, {"cel0_cel0", true}})); } // namespace } // namespace cel::internal ================================================ FILE: internal/manual.h ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with 
the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_MANUAL_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_MANUAL_H_ #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" namespace cel::internal { template class Manual final { public: static_assert(!std::is_reference_v, "T must not be a reference"); static_assert(!std::is_array_v, "T must not be an array"); static_assert(!std::is_const_v, "T must not be const qualified"); static_assert(!std::is_volatile_v, "T must not be volatile qualified"); using element_type = T; Manual() = default; Manual(const Manual&) = delete; Manual(Manual&&) = delete; ~Manual() = default; Manual& operator=(const Manual&) = delete; Manual& operator=(Manual&&) = delete; constexpr T* absl_nonnull get() ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::launder(reinterpret_cast(&storage_[0])); } constexpr const T* absl_nonnull get() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return std::launder(reinterpret_cast(&storage_[0])); } constexpr T& operator*() ABSL_ATTRIBUTE_LIFETIME_BOUND { return *get(); } constexpr const T& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return *get(); } constexpr T* absl_nonnull operator->() ABSL_ATTRIBUTE_LIFETIME_BOUND { return get(); } constexpr const T* absl_nonnull operator->() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return get(); } template T* absl_nonnull Construct(Args&&... 
args) ABSL_ATTRIBUTE_LIFETIME_BOUND { return ::new (static_cast(&storage_[0])) T(std::forward(args)...); } T* absl_nonnull DefaultConstruct() { return ::new (static_cast(&storage_[0])) T; } T* absl_nonnull ValueConstruct() { return ::new (static_cast(&storage_[0])) T(); } void Destruct() { get()->~T(); } private: alignas(T) char storage_[sizeof(T)]; }; } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_MANUAL_H_ ================================================ FILE: internal/message_equality.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "internal/message_equality.h"

#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/functional/overload.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/variant.h"
#include "common/memory.h"
#include "extensions/protobuf/internal/map_reflection.h"
#include "internal/json.h"
#include "internal/number.h"
#include "internal/status_macros.h"
#include "internal/well_known_types.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"
#include "google/protobuf/util/message_differencer.h"

// Guards against a platform `GetMessage` macro (for example from
// <windows.h>) clobbering `Reflection::GetMessage`, which is used below.
#undef GetMessage

namespace cel::internal {

namespace {

using ::cel::extensions::protobuf_internal::ConstMapBegin;
using ::cel::extensions::protobuf_internal::ConstMapEnd;
using ::cel::extensions::protobuf_internal::LookupMapValue;
using ::cel::extensions::protobuf_internal::MapSize;
using ::google::protobuf::Descriptor;
using ::google::protobuf::DescriptorPool;
using ::google::protobuf::FieldDescriptor;
using ::google::protobuf::Message;
using ::google::protobuf::MessageFactory;
using ::google::protobuf::util::MessageDifferencer;

// Distinct wrapper types around a borrowed message reference. They let
// `EquatableValue` route `google.protobuf.ListValue`, `google.protobuf.Struct`,
// `google.protobuf.Any` and all other messages to different equalers below.

class EquatableListValue final
    : public std::reference_wrapper<const Message> {
 public:
  using std::reference_wrapper<const Message>::reference_wrapper;
};

class EquatableStruct final : public std::reference_wrapper<const Message> {
 public:
  using std::reference_wrapper<const Message>::reference_wrapper;
};

class EquatableAny final : public std::reference_wrapper<const Message> {
 public:
  using std::reference_wrapper<const Message>::reference_wrapper;
};

class EquatableMessage final : public std::reference_wrapper<const Message> {
 public:
  using std::reference_wrapper<const Message>::reference_wrapper;
};

// A normalized view of a message, field, map value or repeated element that
// can be compared with `EquatableValueEquals`. Narrow integer and float types
// are widened to `int64_t`, `uint64_t` and `double` before being stored.
using EquatableValue =
    absl::variant<std::nullptr_t, bool, int64_t, uint64_t, double,
                  well_known_types::BytesValue, well_known_types::StringValue,
                  absl::Duration, absl::Time, EquatableListValue,
                  EquatableStruct, EquatableAny, EquatableMessage>;

// Each `*Equaler` below handles every pair whose left-hand alternative is its
// own type: a dedicated overload for comparable right-hand types and a
// constrained catch-all template that returns false for everything else.

struct NullValueEqualer {
  bool operator()(std::nullptr_t, std::nullptr_t) const { return true; }

  template <typename T>
  std::enable_if_t<std::negation_v<std::is_same<std::nullptr_t, T>>, bool>
  operator()(std::nullptr_t, const T&) const {
    return false;
  }
};

struct BoolValueEqualer {
  bool operator()(bool lhs, bool rhs) const { return lhs == rhs; }

  template <typename T>
  std::enable_if_t<std::negation_v<std::is_same<bool, T>>, bool> operator()(
      bool, const T&) const {
    return false;
  }
};

struct BytesValueEqualer {
  bool operator()(const well_known_types::BytesValue& lhs,
                  const well_known_types::BytesValue& rhs) const {
    return lhs == rhs;
  }

  template <typename T>
  std::enable_if_t<
      std::negation_v<std::is_same<well_known_types::BytesValue, T>>, bool>
  operator()(const well_known_types::BytesValue&, const T&) const {
    return false;
  }
};

// Signed integers compare numerically against every numeric alternative,
// using `Number` to avoid lossy conversions across int/uint/double.
struct IntValueEqualer {
  bool operator()(int64_t lhs, int64_t rhs) const { return lhs == rhs; }

  bool operator()(int64_t lhs, uint64_t rhs) const {
    return Number::FromInt64(lhs) == Number::FromUint64(rhs);
  }

  bool operator()(int64_t lhs, double rhs) const {
    return Number::FromInt64(lhs) == Number::FromDouble(rhs);
  }

  template <typename T>
  std::enable_if_t<std::conjunction_v<std::negation<std::is_same<int64_t, T>>,
                                      std::negation<std::is_same<uint64_t, T>>,
                                      std::negation<std::is_same<double, T>>>,
                   bool>
  operator()(int64_t, const T&) const {
    return false;
  }
};

// Unsigned integers compare numerically against every numeric alternative.
struct UintValueEqualer {
  bool operator()(uint64_t lhs, int64_t rhs) const {
    return Number::FromUint64(lhs) == Number::FromInt64(rhs);
  }

  bool operator()(uint64_t lhs, uint64_t rhs) const { return lhs == rhs; }

  bool operator()(uint64_t lhs, double rhs) const {
    return Number::FromUint64(lhs) == Number::FromDouble(rhs);
  }

  template <typename T>
  std::enable_if_t<std::conjunction_v<std::negation<std::is_same<int64_t, T>>,
                                      std::negation<std::is_same<uint64_t, T>>,
                                      std::negation<std::is_same<double, T>>>,
                   bool>
  operator()(uint64_t, const T&) const {
    return false;
  }
};

// Doubles compare numerically against every numeric alternative. Plain `==`
// is used for double/double, so NaN never equals anything.
struct DoubleValueEqualer {
  bool operator()(double lhs, int64_t rhs) const {
    return Number::FromDouble(lhs) == Number::FromInt64(rhs);
  }

  bool operator()(double lhs, uint64_t rhs) const {
    return Number::FromDouble(lhs) == Number::FromUint64(rhs);
  }

  bool operator()(double lhs, double rhs) const { return lhs == rhs; }

  template <typename T>
  std::enable_if_t<std::conjunction_v<std::negation<std::is_same<int64_t, T>>,
                                      std::negation<std::is_same<uint64_t, T>>,
                                      std::negation<std::is_same<double, T>>>,
                   bool>
  operator()(double, const T&) const {
    return false;
  }
};

struct StringValueEqualer {
  bool operator()(const well_known_types::StringValue& lhs,
                  const well_known_types::StringValue& rhs) const {
    return lhs == rhs;
  }

  template <typename T>
  std::enable_if_t<
      std::negation_v<std::is_same<well_known_types::StringValue, T>>, bool>
  operator()(const well_known_types::StringValue&, const T&) const {
    return false;
  }
};

struct DurationEqualer {
  bool operator()(absl::Duration lhs, absl::Duration rhs) const {
    return lhs == rhs;
  }

  template <typename T>
  std::enable_if_t<std::negation_v<std::is_same<absl::Duration, T>>, bool>
  operator()(absl::Duration, const T&) const {
    return false;
  }
};

struct TimestampEqualer {
  bool operator()(absl::Time lhs, absl::Time rhs) const { return lhs == rhs; }

  template <typename T>
  std::enable_if_t<std::negation_v<std::is_same<absl::Time, T>>, bool>
  operator()(absl::Time, const T&) const {
    return false;
  }
};

struct ListValueEqualer {
  bool operator()(EquatableListValue lhs, EquatableListValue rhs) const {
    return JsonListEquals(lhs, rhs);
  }

  template <typename T>
  std::enable_if_t<std::negation_v<std::is_same<EquatableListValue, T>>, bool>
  operator()(EquatableListValue, const T&) const {
    return false;
  }
};

struct StructEqualer {
  bool operator()(EquatableStruct lhs, EquatableStruct rhs) const {
    return JsonMapEquals(lhs, rhs);
  }

  template <typename T>
  std::enable_if_t<std::negation_v<std::is_same<EquatableStruct, T>>, bool>
  operator()(EquatableStruct, const T&) const {
    return false;
  }
};

// Compares two `google.protobuf.Any` messages by raw type URL and raw
// serialized value bytes, without unpacking them.
struct AnyEqualer {
  bool operator()(EquatableAny lhs, EquatableAny rhs) const {
    auto lhs_reflection =
        well_known_types::GetAnyReflectionOrDie(lhs.get().GetDescriptor());
    std::string lhs_type_url_scratch;
    std::string lhs_value_scratch;
    auto rhs_reflection =
        well_known_types::GetAnyReflectionOrDie(rhs.get().GetDescriptor());
    std::string rhs_type_url_scratch;
    std::string rhs_value_scratch;
    return lhs_reflection.GetTypeUrl(lhs.get(), lhs_type_url_scratch) ==
               rhs_reflection.GetTypeUrl(rhs.get(), rhs_type_url_scratch) &&
           lhs_reflection.GetValue(lhs.get(), lhs_value_scratch) ==
               rhs_reflection.GetValue(rhs.get(), rhs_value_scratch);
  }

  template <typename T>
  std::enable_if_t<std::negation_v<std::is_same<EquatableAny, T>>, bool>
  operator()(EquatableAny, const T&) const {
    return false;
  }
};

// Messages are equal only if they share the same descriptor and
// `MessageDifferencer::Equals` reports them equal.
struct MessageEqualer {
  bool operator()(EquatableMessage lhs, EquatableMessage rhs) const {
    return lhs.get().GetDescriptor() == rhs.get().GetDescriptor() &&
           MessageDifferencer::Equals(lhs.get(), rhs.get());
  }

  template <typename T>
  std::enable_if_t<std::negation_v<std::is_same<EquatableMessage, T>>, bool>
  operator()(EquatableMessage, const T&) const {
    return false;
  }
};

// Cached reflection accessors for the well-known types. Each member is
// initialized lazily (via `Initialize(descriptor)`) by `AsEquatableValue`.
struct EquatableValueReflection final {
  well_known_types::DoubleValueReflection double_value_reflection;
  well_known_types::FloatValueReflection float_value_reflection;
  well_known_types::Int64ValueReflection int64_value_reflection;
  well_known_types::UInt64ValueReflection uint64_value_reflection;
  well_known_types::Int32ValueReflection int32_value_reflection;
  well_known_types::UInt32ValueReflection uint32_value_reflection;
  well_known_types::StringValueReflection string_value_reflection;
  well_known_types::BytesValueReflection bytes_value_reflection;
  well_known_types::BoolValueReflection bool_value_reflection;
  well_known_types::AnyReflection any_reflection;
  well_known_types::DurationReflection duration_reflection;
  well_known_types::TimestampReflection timestamp_reflection;
  well_known_types::ValueReflection value_reflection;
  well_known_types::ListValueReflection list_value_reflection;
  well_known_types::StructReflection struct_reflection;
};

// Converts a whole message into an `EquatableValue` according to
// `well_known_type`. Wrapper types are unwrapped (and widened),
// `google.protobuf.Value` is unwrapped to its active kind, and any other
// message becomes an `EquatableMessage`. String and bytes results may refer
// to `scratch` or to `message`, both of which must outlive the result.
// Returns an error if reflection initialization fails or a `Value` has an
// unexpected kind case.
absl::StatusOr<EquatableValue> AsEquatableValue(
    EquatableValueReflection& reflection,
    const Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND,
    const Descriptor* absl_nonnull descriptor,
    Descriptor::WellKnownType well_known_type,
    std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  switch (well_known_type) {
    case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE:
      CEL_RETURN_IF_ERROR(
          reflection.double_value_reflection.Initialize(descriptor));
      return reflection.double_value_reflection.GetValue(message);
    case Descriptor::WELLKNOWNTYPE_FLOATVALUE:
      CEL_RETURN_IF_ERROR(
          reflection.float_value_reflection.Initialize(descriptor));
      return static_cast<double>(
          reflection.float_value_reflection.GetValue(message));
    case Descriptor::WELLKNOWNTYPE_INT64VALUE:
      CEL_RETURN_IF_ERROR(
          reflection.int64_value_reflection.Initialize(descriptor));
      return reflection.int64_value_reflection.GetValue(message);
    case Descriptor::WELLKNOWNTYPE_UINT64VALUE:
      CEL_RETURN_IF_ERROR(
          reflection.uint64_value_reflection.Initialize(descriptor));
      return reflection.uint64_value_reflection.GetValue(message);
    case Descriptor::WELLKNOWNTYPE_INT32VALUE:
      CEL_RETURN_IF_ERROR(
          reflection.int32_value_reflection.Initialize(descriptor));
      return static_cast<int64_t>(
          reflection.int32_value_reflection.GetValue(message));
    case Descriptor::WELLKNOWNTYPE_UINT32VALUE:
      CEL_RETURN_IF_ERROR(
          reflection.uint32_value_reflection.Initialize(descriptor));
      return static_cast<uint64_t>(
          reflection.uint32_value_reflection.GetValue(message));
    case Descriptor::WELLKNOWNTYPE_STRINGVALUE:
      CEL_RETURN_IF_ERROR(
          reflection.string_value_reflection.Initialize(descriptor));
      return reflection.string_value_reflection.GetValue(message, scratch);
    case Descriptor::WELLKNOWNTYPE_BYTESVALUE:
      CEL_RETURN_IF_ERROR(
          reflection.bytes_value_reflection.Initialize(descriptor));
      return reflection.bytes_value_reflection.GetValue(message, scratch);
    case Descriptor::WELLKNOWNTYPE_BOOLVALUE:
      CEL_RETURN_IF_ERROR(
          reflection.bool_value_reflection.Initialize(descriptor));
      return reflection.bool_value_reflection.GetValue(message);
    case Descriptor::WELLKNOWNTYPE_VALUE: {
      CEL_RETURN_IF_ERROR(reflection.value_reflection.Initialize(descriptor));
      const auto kind_case = reflection.value_reflection.GetKindCase(message);
      switch (kind_case) {
        case google::protobuf::Value::KIND_NOT_SET:
          ABSL_FALLTHROUGH_INTENDED;
        case google::protobuf::Value::kNullValue:
          return nullptr;
        case google::protobuf::Value::kBoolValue:
          return reflection.value_reflection.GetBoolValue(message);
        case google::protobuf::Value::kNumberValue:
          return reflection.value_reflection.GetNumberValue(message);
        case google::protobuf::Value::kStringValue:
          return reflection.value_reflection.GetStringValue(message, scratch);
        case google::protobuf::Value::kListValue:
          return EquatableListValue(
              reflection.value_reflection.GetListValue(message));
        case google::protobuf::Value::kStructValue:
          return EquatableStruct(
              reflection.value_reflection.GetStructValue(message));
        default:
          return absl::InternalError(
              absl::StrCat("unexpected value kind case: ", kind_case));
      }
    }
    case Descriptor::WELLKNOWNTYPE_LISTVALUE:
      return EquatableListValue(message);
    case Descriptor::WELLKNOWNTYPE_STRUCT:
      return EquatableStruct(message);
    case Descriptor::WELLKNOWNTYPE_DURATION:
      CEL_RETURN_IF_ERROR(
          reflection.duration_reflection.Initialize(descriptor));
      return reflection.duration_reflection.ToAbslDuration(message);
    case Descriptor::WELLKNOWNTYPE_TIMESTAMP:
      CEL_RETURN_IF_ERROR(
          reflection.timestamp_reflection.Initialize(descriptor));
      return reflection.timestamp_reflection.ToAbslTime(message);
    case Descriptor::WELLKNOWNTYPE_ANY:
      return EquatableAny(message);
    default:
      return EquatableMessage(message);
  }
}

// Convenience overload that takes the well-known type from `descriptor`.
absl::StatusOr<EquatableValue> AsEquatableValue(
    EquatableValueReflection& reflection,
    const Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND,
    const Descriptor* absl_nonnull descriptor,
    std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return AsEquatableValue(reflection, message, descriptor,
                          descriptor->well_known_type(), scratch);
}

// Converts a singular (non-repeated, non-map) field of `message` into an
// `EquatableValue`. Enum fields of type `google.protobuf.NullValue` become
// null; other enums become their `int64_t` number. Message fields are
// converted with the descriptor-based overload above.
absl::StatusOr<EquatableValue> AsEquatableValue(
    EquatableValueReflection& reflection,
    const Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND,
    const FieldDescriptor* absl_nonnull field,
    std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  ABSL_DCHECK(!field->is_repeated() && !field->is_map());
  switch (field->cpp_type()) {
    case FieldDescriptor::CPPTYPE_INT32:
      return static_cast<int64_t>(
          message.GetReflection()->GetInt32(message, field));
    case FieldDescriptor::CPPTYPE_INT64:
      return message.GetReflection()->GetInt64(message, field);
    case FieldDescriptor::CPPTYPE_UINT32:
      return static_cast<uint64_t>(
          message.GetReflection()->GetUInt32(message, field));
    case FieldDescriptor::CPPTYPE_UINT64:
      return message.GetReflection()->GetUInt64(message, field);
    case FieldDescriptor::CPPTYPE_DOUBLE:
      return message.GetReflection()->GetDouble(message, field);
    case FieldDescriptor::CPPTYPE_FLOAT:
      return static_cast<double>(
          message.GetReflection()->GetFloat(message, field));
    case FieldDescriptor::CPPTYPE_BOOL:
      return message.GetReflection()->GetBool(message, field);
    case FieldDescriptor::CPPTYPE_ENUM:
      if (field->enum_type()->full_name() == "google.protobuf.NullValue") {
        return nullptr;
      }
      return static_cast<int64_t>(
          message.GetReflection()->GetEnumValue(message, field));
    case FieldDescriptor::CPPTYPE_STRING:
      if (field->type() == FieldDescriptor::TYPE_BYTES) {
        return well_known_types::GetBytesField(message, field, scratch);
      }
      return well_known_types::GetStringField(message, field, scratch);
    case FieldDescriptor::CPPTYPE_MESSAGE:
      return AsEquatableValue(
          reflection, message.GetReflection()->GetMessage(message, field),
          field->message_type(), scratch);
    default:
      return absl::InternalError(
          absl::StrCat("unexpected field type: ", field->cpp_type_name()));
  }
}

// Returns true if `message` is a `google.protobuf.Any`.
bool IsAny(const Message& message) {
  return message.GetDescriptor()->well_known_type() ==
         Descriptor::WELLKNOWNTYPE_ANY;
}

// Returns true if `field` is a message field of type `google.protobuf.Any`.
bool IsAnyField(const FieldDescriptor* absl_nonnull field) {
  return field->type() == FieldDescriptor::TYPE_MESSAGE &&
         field->message_type()->well_known_type() ==
             Descriptor::WELLKNOWNTYPE_ANY;
}

// Converts a map value (as described by the map's value `field`) into an
// `EquatableValue`. `Any` values are unpacked into `unpacked` when their type
// can be resolved through `pool`/`factory`; otherwise the packed `Any` is
// compared as-is. `unpacked` must outlive the returned value.
absl::StatusOr<EquatableValue> MapValueAsEquatableValue(
    google::protobuf::Arena* absl_nonnull arena,
    const DescriptorPool* absl_nonnull pool,
    MessageFactory* absl_nonnull factory, EquatableValueReflection& reflection,
    const google::protobuf::MapValueConstRef& value,
    const FieldDescriptor* absl_nonnull field, std::string& scratch,
    Unique<Message>& unpacked) {
  if (IsAnyField(field)) {
    CEL_ASSIGN_OR_RETURN(unpacked, well_known_types::UnpackAnyIfResolveable(
                                       arena, reflection.any_reflection,
                                       value.GetMessageValue(), pool, factory));
    if (unpacked) {
      return AsEquatableValue(reflection, *unpacked, unpacked->GetDescriptor(),
                              scratch);
    }
    return AsEquatableValue(reflection, value.GetMessageValue(),
                            value.GetMessageValue().GetDescriptor(), scratch);
  }
  switch (field->cpp_type()) {
    case FieldDescriptor::CPPTYPE_INT32:
      return static_cast<int64_t>(value.GetInt32Value());
    case FieldDescriptor::CPPTYPE_INT64:
      return value.GetInt64Value();
    case FieldDescriptor::CPPTYPE_UINT32:
      return static_cast<uint64_t>(value.GetUInt32Value());
    case FieldDescriptor::CPPTYPE_UINT64:
      return value.GetUInt64Value();
    case FieldDescriptor::CPPTYPE_DOUBLE:
      return value.GetDoubleValue();
    case FieldDescriptor::CPPTYPE_FLOAT:
      return static_cast<double>(value.GetFloatValue());
    case FieldDescriptor::CPPTYPE_BOOL:
      return value.GetBoolValue();
    case FieldDescriptor::CPPTYPE_ENUM:
      if (field->enum_type()->full_name() == "google.protobuf.NullValue") {
        return nullptr;
      }
      return static_cast<int64_t>(value.GetEnumValue());
    case FieldDescriptor::CPPTYPE_STRING:
      // Borrows the map's storage; no copy into `scratch` is needed.
      if (field->type() == FieldDescriptor::TYPE_BYTES) {
        return well_known_types::BytesValue(
            absl::string_view(value.GetStringValue()));
      }
      return well_known_types::StringValue(
          absl::string_view(value.GetStringValue()));
    case FieldDescriptor::CPPTYPE_MESSAGE: {
      const auto& message = value.GetMessageValue();
      return AsEquatableValue(reflection, message, message.GetDescriptor(),
                              scratch);
    }
    default:
      return absl::InternalError(
          absl::StrCat("unexpected field type: ", field->cpp_type_name()));
  }
}

// Converts element `index` of repeated `field` in `message` into an
// `EquatableValue`. `Any` elements are handled as in
// `MapValueAsEquatableValue`; `unpacked` must outlive the returned value.
absl::StatusOr<EquatableValue> RepeatedFieldAsEquatableValue(
    google::protobuf::Arena* absl_nonnull arena,
    const DescriptorPool* absl_nonnull pool,
    MessageFactory* absl_nonnull factory, EquatableValueReflection& reflection,
    const Message& message, const FieldDescriptor* absl_nonnull field,
    int index, std::string& scratch, Unique<Message>& unpacked) {
  if (IsAnyField(field)) {
    const auto& field_value =
        message.GetReflection()->GetRepeatedMessage(message, field, index);
    CEL_ASSIGN_OR_RETURN(unpacked, well_known_types::UnpackAnyIfResolveable(
                                       arena, reflection.any_reflection,
                                       field_value, pool, factory));
    if (unpacked) {
      return AsEquatableValue(reflection, *unpacked, unpacked->GetDescriptor(),
                              scratch);
    }
    return AsEquatableValue(reflection, field_value,
                            field_value.GetDescriptor(), scratch);
  }
  switch (field->cpp_type()) {
    case FieldDescriptor::CPPTYPE_INT32:
      return static_cast<int64_t>(
          message.GetReflection()->GetRepeatedInt32(message, field, index));
    case FieldDescriptor::CPPTYPE_INT64:
      return message.GetReflection()->GetRepeatedInt64(message, field, index);
    case FieldDescriptor::CPPTYPE_UINT32:
      return static_cast<uint64_t>(
          message.GetReflection()->GetRepeatedUInt32(message, field, index));
    case FieldDescriptor::CPPTYPE_UINT64:
      return message.GetReflection()->GetRepeatedUInt64(message, field, index);
    case FieldDescriptor::CPPTYPE_DOUBLE:
      return message.GetReflection()->GetRepeatedDouble(message, field, index);
    case FieldDescriptor::CPPTYPE_FLOAT:
      return static_cast<double>(
          message.GetReflection()->GetRepeatedFloat(message, field, index));
    case FieldDescriptor::CPPTYPE_BOOL:
      return message.GetReflection()->GetRepeatedBool(message, field, index);
    case FieldDescriptor::CPPTYPE_ENUM:
      if (field->enum_type()->full_name() == "google.protobuf.NullValue") {
        return nullptr;
      }
      return static_cast<int64_t>(
          message.GetReflection()->GetRepeatedEnumValue(message, field, index));
    case FieldDescriptor::CPPTYPE_STRING:
      if (field->type() == FieldDescriptor::TYPE_BYTES) {
        return well_known_types::GetRepeatedBytesField(message, field, index,
                                                       scratch);
      }
      return well_known_types::GetRepeatedStringField(message, field, index,
                                                      scratch);
    case FieldDescriptor::CPPTYPE_MESSAGE: {
      const auto& submessage =
          message.GetReflection()->GetRepeatedMessage(message, field, index);
      return AsEquatableValue(reflection, submessage,
                              submessage.GetDescriptor(), scratch);
    }
    default:
      return absl::InternalError(
          absl::StrCat("unexpected field type: ", field->cpp_type_name()));
  }
}

// Compare two `EquatableValue` for equality.
bool EquatableValueEquals(const EquatableValue& lhs, const EquatableValue& rhs) { return absl::visit( absl::Overload(NullValueEqualer{}, BoolValueEqualer{}, BytesValueEqualer{}, IntValueEqualer{}, UintValueEqualer{}, DoubleValueEqualer{}, StringValueEqualer{}, DurationEqualer{}, TimestampEqualer{}, ListValueEqualer{}, StructEqualer{}, AnyEqualer{}, MessageEqualer{}), lhs, rhs); } // Attempts to coalesce one map key to another. Returns true if it was possible, // false otherwise. bool CoalesceMapKey(const google::protobuf::MapKey& src, FieldDescriptor::CppType dest_type, google::protobuf::MapKey* absl_nonnull dest) { switch (src.type()) { case FieldDescriptor::CPPTYPE_BOOL: if (dest_type != FieldDescriptor::CPPTYPE_BOOL) { return false; } dest->SetBoolValue(src.GetBoolValue()); return true; case FieldDescriptor::CPPTYPE_INT32: { const auto src_value = src.GetInt32Value(); switch (dest_type) { case FieldDescriptor::CPPTYPE_INT32: dest->SetInt32Value(src_value); return true; case FieldDescriptor::CPPTYPE_INT64: dest->SetInt64Value(src_value); return true; case FieldDescriptor::CPPTYPE_UINT32: if (src_value < 0) { return false; } dest->SetUInt32Value(static_cast(src_value)); return true; case FieldDescriptor::CPPTYPE_UINT64: if (src_value < 0) { return false; } dest->SetUInt64Value(static_cast(src_value)); return true; default: return false; } } case FieldDescriptor::CPPTYPE_INT64: { const auto src_value = src.GetInt64Value(); switch (dest_type) { case FieldDescriptor::CPPTYPE_INT32: if (src_value < std::numeric_limits::min() || src_value > std::numeric_limits::max()) { return false; } dest->SetInt32Value(static_cast(src_value)); return true; case FieldDescriptor::CPPTYPE_INT64: dest->SetInt64Value(src_value); return true; case FieldDescriptor::CPPTYPE_UINT32: if (src_value < 0 || src_value > std::numeric_limits::max()) { return false; } dest->SetUInt32Value(static_cast(src_value)); return true; case FieldDescriptor::CPPTYPE_UINT64: if (src_value < 0) { return false; } 
dest->SetUInt64Value(static_cast(src_value)); return true; default: return false; } } case FieldDescriptor::CPPTYPE_UINT32: { const auto src_value = src.GetUInt32Value(); switch (dest_type) { case FieldDescriptor::CPPTYPE_INT32: if (src_value > std::numeric_limits::max()) { return false; } dest->SetInt32Value(static_cast(src_value)); return true; case FieldDescriptor::CPPTYPE_INT64: dest->SetInt64Value(static_cast(src_value)); return true; case FieldDescriptor::CPPTYPE_UINT32: dest->SetUInt32Value(src_value); return true; case FieldDescriptor::CPPTYPE_UINT64: dest->SetUInt64Value(static_cast(src_value)); return true; default: return false; } } case FieldDescriptor::CPPTYPE_UINT64: { const auto src_value = src.GetUInt64Value(); switch (dest_type) { case FieldDescriptor::CPPTYPE_INT32: if (src_value > std::numeric_limits::max()) { return false; } dest->SetInt32Value(static_cast(src_value)); return true; case FieldDescriptor::CPPTYPE_INT64: if (src_value > std::numeric_limits::max()) { return false; } dest->SetInt64Value(static_cast(src_value)); return true; case FieldDescriptor::CPPTYPE_UINT32: if (src_value > std::numeric_limits::max()) { return false; } dest->SetUInt32Value(src_value); return true; case FieldDescriptor::CPPTYPE_UINT64: dest->SetUInt64Value(src_value); return true; default: return false; } } case FieldDescriptor::CPPTYPE_STRING: if (dest_type != FieldDescriptor::CPPTYPE_STRING) { return false; } dest->SetStringValue(src.GetStringValue()); return true; default: // Only bool, integrals, and string may be map keys. ABSL_UNREACHABLE(); } } // Bits used for categorizing equality. Can be used to cheaply check whether two // categories are comparable for equality by performing an AND and checking if // the result against `kNone`. 
enum class EquatableCategory {
  kNone = 0,
  kNullLike = 1 << 0,
  kBoolLike = 1 << 1,
  kNumericLike = 1 << 2,
  kBytesLike = 1 << 3,
  kStringLike = 1 << 4,
  kList = 1 << 5,
  kMap = 1 << 6,
  // NOTE(review): neither GetEquatableCategory nor GetEquatableFieldCategory
  // below returns kMessage; arbitrary messages are categorized as kAny.
  kMessage = 1 << 7,
  kDuration = 1 << 8,
  kTimestamp = 1 << 9,
  // Every bit set: overlaps with every other category, so it never
  // short-circuits a comparison.
  kAny = kNullLike | kBoolLike | kNumericLike | kBytesLike | kStringLike |
         kList | kMap | kMessage | kDuration | kTimestamp,
  // The kinds a `google.protobuf.Value` can hold: null, bool, number, string,
  // list and struct (map). Bytes, durations and timestamps are excluded.
  kValue = kNullLike | kBoolLike | kNumericLike | kStringLike | kList | kMap,
};

// Bitwise AND of two categories, computed on the underlying integer type.
constexpr EquatableCategory operator&(EquatableCategory lhs,
                                      EquatableCategory rhs) {
  return static_cast(
      static_cast>(lhs) &
      static_cast>(rhs));
}

// Equality of two categories, computed on the underlying integer type.
constexpr bool operator==(EquatableCategory lhs, EquatableCategory rhs) {
  return static_cast>(lhs) ==
         static_cast>(rhs);
}

// Returns the equality category of a whole message, based on its well-known
// type. `google.protobuf.Any` and every non-well-known message map to kAny, so
// they are always compared rather than short-circuited.
EquatableCategory GetEquatableCategory(
    const Descriptor* absl_nonnull descriptor) {
  switch (descriptor->well_known_type()) {
    case Descriptor::WELLKNOWNTYPE_BOOLVALUE:
      return EquatableCategory::kBoolLike;
    case Descriptor::WELLKNOWNTYPE_FLOATVALUE:
      ABSL_FALLTHROUGH_INTENDED;
    case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE:
      ABSL_FALLTHROUGH_INTENDED;
    case Descriptor::WELLKNOWNTYPE_INT32VALUE:
      ABSL_FALLTHROUGH_INTENDED;
    case Descriptor::WELLKNOWNTYPE_UINT32VALUE:
      ABSL_FALLTHROUGH_INTENDED;
    case Descriptor::WELLKNOWNTYPE_INT64VALUE:
      ABSL_FALLTHROUGH_INTENDED;
    case Descriptor::WELLKNOWNTYPE_UINT64VALUE:
      return EquatableCategory::kNumericLike;
    case Descriptor::WELLKNOWNTYPE_BYTESVALUE:
      return EquatableCategory::kBytesLike;
    case Descriptor::WELLKNOWNTYPE_STRINGVALUE:
      return EquatableCategory::kStringLike;
    case Descriptor::WELLKNOWNTYPE_VALUE:
      return EquatableCategory::kValue;
    case Descriptor::WELLKNOWNTYPE_LISTVALUE:
      return EquatableCategory::kList;
    case Descriptor::WELLKNOWNTYPE_STRUCT:
      return EquatableCategory::kMap;
    case Descriptor::WELLKNOWNTYPE_ANY:
      return EquatableCategory::kAny;
    case Descriptor::WELLKNOWNTYPE_DURATION:
      return EquatableCategory::kDuration;
    case Descriptor::WELLKNOWNTYPE_TIMESTAMP:
      return EquatableCategory::kTimestamp;
    default:
      return EquatableCategory::kAny;
  }
}

// Returns the equality category of the values stored in `field`.
// - `google.protobuf.NullValue` enums are null-like; all other enums are
//   numeric.
// - `bytes` fields are bytes-like and `string` fields are string-like.
// - Message fields are categorized by their message type via
//   GetEquatableCategory.
EquatableCategory GetEquatableFieldCategory(
    const FieldDescriptor* absl_nonnull field) {
  switch (field->cpp_type()) {
    case FieldDescriptor::CPPTYPE_ENUM:
      return field->enum_type()->full_name() == "google.protobuf.NullValue"
                 ? EquatableCategory::kNullLike
                 : EquatableCategory::kNumericLike;
    case FieldDescriptor::CPPTYPE_BOOL:
      return EquatableCategory::kBoolLike;
    case FieldDescriptor::CPPTYPE_FLOAT:
      ABSL_FALLTHROUGH_INTENDED;
    case FieldDescriptor::CPPTYPE_DOUBLE:
      ABSL_FALLTHROUGH_INTENDED;
    case FieldDescriptor::CPPTYPE_INT32:
      ABSL_FALLTHROUGH_INTENDED;
    case FieldDescriptor::CPPTYPE_UINT32:
      ABSL_FALLTHROUGH_INTENDED;
    case FieldDescriptor::CPPTYPE_INT64:
      ABSL_FALLTHROUGH_INTENDED;
    case FieldDescriptor::CPPTYPE_UINT64:
      return EquatableCategory::kNumericLike;
    case FieldDescriptor::CPPTYPE_STRING:
      return field->type() == FieldDescriptor::TYPE_BYTES
                 ? EquatableCategory::kBytesLike
                 : EquatableCategory::kStringLike;
    case FieldDescriptor::CPPTYPE_MESSAGE:
      return GetEquatableCategory(field->message_type());
    default:
      // Conservatively return kAny for any future C++ type so it is always
      // compared instead of being short-circuited.
      return EquatableCategory::kAny;
  }
}

// State for a single CEL message-equality computation. It holds:
// - the descriptor pool and message factory that are passed to
//   `google.protobuf.Any` unpacking;
// - an arena that is passed to `UnpackAnyIfResolveable`;
// - separate reflection caches and scratch strings for the left- and
//   right-hand operands.
//
// The public entry points below create a fresh instance on the heap for
// each call.
class MessageEqualsState final {
 public:
  MessageEqualsState(const DescriptorPool* absl_nonnull pool,
                     MessageFactory* absl_nonnull factory)
      : pool_(pool), factory_(factory) {}

  // Equality between messages. A top-level `google.protobuf.Any` on either
  // side is unpacked first when its payload type can be resolved. Both sides
  // are then converted to `EquatableValue`s and compared.
  absl::StatusOr Equals(const Message& lhs, const Message& rhs) {
    const auto* lhs_descriptor = lhs.GetDescriptor();
    const auto* rhs_descriptor = rhs.GetDescriptor();
    // Deal with well known types, starting with any.
    auto lhs_well_known_type = lhs_descriptor->well_known_type();
    auto rhs_well_known_type = rhs_descriptor->well_known_type();
    const Message* absl_nonnull lhs_ptr = &lhs;
    const Message* absl_nonnull rhs_ptr = &rhs;
    Unique lhs_unpacked;
    Unique rhs_unpacked;
    // Deal with any first. We could in theory check if we should bother
    // unpacking, but that is more complicated. We can always implement it
    // later.
    if (lhs_well_known_type == Descriptor::WELLKNOWNTYPE_ANY) {
      CEL_ASSIGN_OR_RETURN(
          lhs_unpacked,
          well_known_types::UnpackAnyIfResolveable(
              &arena_, lhs_reflection_.any_reflection, lhs, pool_, factory_));
      if (lhs_unpacked) {
        // Compare the unpacked payload instead of the Any wrapper.
        lhs_ptr = cel::to_address(lhs_unpacked);
        lhs_descriptor = lhs_ptr->GetDescriptor();
        lhs_well_known_type = lhs_descriptor->well_known_type();
      }
    }
    if (rhs_well_known_type == Descriptor::WELLKNOWNTYPE_ANY) {
      CEL_ASSIGN_OR_RETURN(
          rhs_unpacked,
          well_known_types::UnpackAnyIfResolveable(
              &arena_, rhs_reflection_.any_reflection, rhs, pool_, factory_));
      if (rhs_unpacked) {
        rhs_ptr = cel::to_address(rhs_unpacked);
        rhs_descriptor = rhs_ptr->GetDescriptor();
        rhs_well_known_type = rhs_descriptor->well_known_type();
      }
    }
    CEL_ASSIGN_OR_RETURN(
        auto lhs_value,
        AsEquatableValue(lhs_reflection_, *lhs_ptr, lhs_descriptor,
                         lhs_well_known_type, lhs_scratch_));
    CEL_ASSIGN_OR_RETURN(
        auto rhs_value,
        AsEquatableValue(rhs_reflection_, *rhs_ptr, rhs_descriptor,
                         rhs_well_known_type, rhs_scratch_));
    return EquatableValueEquals(lhs_value, rhs_value);
  }

  // Equality between map message fields. The maps are equal when:
  // - they have the same size;
  // - every lhs key can be coalesced into the rhs key type (CoalesceMapKey)
  //   and is present in rhs;
  // - the corresponding values compare equal.
  absl::StatusOr MapFieldEquals(
      const Message& lhs, const FieldDescriptor* absl_nonnull lhs_field,
      const Message& rhs, const FieldDescriptor* absl_nonnull rhs_field) {
    ABSL_DCHECK(lhs_field->is_map());
    ABSL_DCHECK_EQ(lhs_field->containing_type(), lhs.GetDescriptor());
    ABSL_DCHECK(rhs_field->is_map());
    ABSL_DCHECK_EQ(rhs_field->containing_type(), rhs.GetDescriptor());
    const auto* lhs_entry = lhs_field->message_type();
    const auto* lhs_entry_key_field = lhs_entry->map_key();
    const auto* lhs_entry_value_field = lhs_entry->map_value();
    const auto* rhs_entry = rhs_field->message_type();
    const auto* rhs_entry_key_field = rhs_entry->map_key();
    const auto* rhs_entry_value_field = rhs_entry->map_value();
    // Perform cheap test which checks whether the left and right can even be
    // compared for equality.
    if (lhs_field != rhs_field &&
        ((GetEquatableFieldCategory(lhs_entry_key_field) &
          GetEquatableFieldCategory(rhs_entry_key_field)) ==
             EquatableCategory::kNone ||
         (GetEquatableFieldCategory(lhs_entry_value_field) &
          GetEquatableFieldCategory(rhs_entry_value_field)) ==
             EquatableCategory::kNone)) {
      // Short-circuit.
      return false;
    }
    const auto* lhs_reflection = lhs.GetReflection();
    const auto* rhs_reflection = rhs.GetReflection();
    if (MapSize(*lhs_reflection, lhs, *lhs_field) !=
        MapSize(*rhs_reflection, rhs, *rhs_field)) {
      return false;
    }
    auto lhs_begin = ConstMapBegin(*lhs_reflection, lhs, *lhs_field);
    const auto lhs_end = ConstMapEnd(*lhs_reflection, lhs, *lhs_field);
    // Declared outside the loop so they are reused across iterations.
    Unique lhs_unpacked;
    EquatableValue lhs_value;
    Unique rhs_unpacked;
    EquatableValue rhs_value;
    google::protobuf::MapKey rhs_map_key;
    google::protobuf::MapValueConstRef rhs_map_value;
    for (; lhs_begin != lhs_end; ++lhs_begin) {
      // Convert the lhs key into the rhs key type; a key that cannot be
      // represented there cannot be present in rhs.
      if (!CoalesceMapKey(lhs_begin.GetKey(), rhs_entry_key_field->cpp_type(),
                          &rhs_map_key)) {
        return false;
      }
      if (!LookupMapValue(*rhs_reflection, rhs, *rhs_field, rhs_map_key,
                          &rhs_map_value)) {
        return false;
      }
      CEL_ASSIGN_OR_RETURN(lhs_value,
                           MapValueAsEquatableValue(
                               &arena_, pool_, factory_, lhs_reflection_,
                               lhs_begin.GetValueRef(), lhs_entry_value_field,
                               lhs_scratch_, lhs_unpacked));
      CEL_ASSIGN_OR_RETURN(
          rhs_value,
          MapValueAsEquatableValue(&arena_, pool_, factory_, rhs_reflection_,
                                   rhs_map_value, rhs_entry_value_field,
                                   rhs_scratch_, rhs_unpacked));
      if (!EquatableValueEquals(lhs_value, rhs_value)) {
        return false;
      }
    }
    return true;
  }

  // Equality between repeated message fields. The fields are equal when they
  // have the same size and the elements compare equal index by index.
  absl::StatusOr RepeatedFieldEquals(
      const Message& lhs, const FieldDescriptor* absl_nonnull lhs_field,
      const Message& rhs, const FieldDescriptor* absl_nonnull rhs_field) {
    ABSL_DCHECK(lhs_field->is_repeated() && !lhs_field->is_map());
    ABSL_DCHECK_EQ(lhs_field->containing_type(), lhs.GetDescriptor());
    ABSL_DCHECK(rhs_field->is_repeated() && !rhs_field->is_map());
    ABSL_DCHECK_EQ(rhs_field->containing_type(), rhs.GetDescriptor());
    // Perform cheap test which checks whether the left and right can even be
    // compared for equality.
    if (lhs_field != rhs_field &&
        (GetEquatableFieldCategory(lhs_field) &
         GetEquatableFieldCategory(rhs_field)) == EquatableCategory::kNone) {
      // Short-circuit.
      return false;
    }
    const auto* lhs_reflection = lhs.GetReflection();
    const auto* rhs_reflection = rhs.GetReflection();
    const auto size = lhs_reflection->FieldSize(lhs, lhs_field);
    if (size != rhs_reflection->FieldSize(rhs, rhs_field)) {
      return false;
    }
    Unique lhs_unpacked;
    EquatableValue lhs_value;
    Unique rhs_unpacked;
    EquatableValue rhs_value;
    for (int i = 0; i < size; ++i) {
      CEL_ASSIGN_OR_RETURN(lhs_value,
                           RepeatedFieldAsEquatableValue(
                               &arena_, pool_, factory_, lhs_reflection_, lhs,
                               lhs_field, i, lhs_scratch_, lhs_unpacked));
      CEL_ASSIGN_OR_RETURN(rhs_value,
                           RepeatedFieldAsEquatableValue(
                               &arena_, pool_, factory_, rhs_reflection_, rhs,
                               rhs_field, i, rhs_scratch_, rhs_unpacked));
      if (!EquatableValueEquals(lhs_value, rhs_value)) {
        return false;
      }
    }
    return true;
  }

  // Equality between singular message fields and/or messages. If the field is
  // `nullptr`, we are performing equality on the message itself rather than
  // the corresponding field. A `google.protobuf.Any` operand (field or whole
  // message) is unpacked first when its payload type can be resolved.
  absl::StatusOr SingularFieldEquals(
      const Message& lhs, const FieldDescriptor* absl_nullable lhs_field,
      const Message& rhs, const FieldDescriptor* absl_nullable rhs_field) {
    ABSL_DCHECK(lhs_field == nullptr ||
                (!lhs_field->is_repeated() && !lhs_field->is_map()));
    ABSL_DCHECK(lhs_field == nullptr ||
                lhs_field->containing_type() == lhs.GetDescriptor());
    ABSL_DCHECK(rhs_field == nullptr ||
                (!rhs_field->is_repeated() && !rhs_field->is_map()));
    ABSL_DCHECK(rhs_field == nullptr ||
                rhs_field->containing_type() == rhs.GetDescriptor());
    // Perform cheap test which checks whether the left and right can even be
    // compared for equality.
    if (lhs_field != rhs_field &&
        ((lhs_field != nullptr ? GetEquatableFieldCategory(lhs_field)
                               : GetEquatableCategory(lhs.GetDescriptor())) &
         (rhs_field != nullptr ? GetEquatableFieldCategory(rhs_field)
                               : GetEquatableCategory(rhs.GetDescriptor()))) ==
            EquatableCategory::kNone) {
      // Short-circuit.
      return false;
    }
    const Message* absl_nonnull lhs_ptr = &lhs;
    const Message* absl_nonnull rhs_ptr = &rhs;
    Unique lhs_unpacked;
    Unique rhs_unpacked;
    if (lhs_field != nullptr && IsAnyField(lhs_field)) {
      CEL_ASSIGN_OR_RETURN(lhs_unpacked,
                           well_known_types::UnpackAnyIfResolveable(
                               &arena_, lhs_reflection_.any_reflection,
                               lhs.GetReflection()->GetMessage(lhs, lhs_field),
                               pool_, factory_));
      if (lhs_unpacked) {
        // From here on the unpacked payload is the whole operand.
        lhs_ptr = cel::to_address(lhs_unpacked);
        lhs_field = nullptr;
      }
    } else if (lhs_field == nullptr && IsAny(lhs)) {
      CEL_ASSIGN_OR_RETURN(
          lhs_unpacked,
          well_known_types::UnpackAnyIfResolveable(
              &arena_, lhs_reflection_.any_reflection, lhs, pool_, factory_));
      if (lhs_unpacked) {
        lhs_ptr = cel::to_address(lhs_unpacked);
      }
    }
    if (rhs_field != nullptr && IsAnyField(rhs_field)) {
      CEL_ASSIGN_OR_RETURN(rhs_unpacked,
                           well_known_types::UnpackAnyIfResolveable(
                               &arena_, rhs_reflection_.any_reflection,
                               rhs.GetReflection()->GetMessage(rhs, rhs_field),
                               pool_, factory_));
      if (rhs_unpacked) {
        rhs_ptr = cel::to_address(rhs_unpacked);
        rhs_field = nullptr;
      }
    } else if (rhs_field == nullptr && IsAny(rhs)) {
      CEL_ASSIGN_OR_RETURN(
          rhs_unpacked,
          well_known_types::UnpackAnyIfResolveable(
              &arena_, rhs_reflection_.any_reflection, rhs, pool_, factory_));
      if (rhs_unpacked) {
        rhs_ptr = cel::to_address(rhs_unpacked);
      }
    }
    EquatableValue lhs_value;
    if (lhs_field != nullptr) {
      CEL_ASSIGN_OR_RETURN(
          lhs_value,
          AsEquatableValue(lhs_reflection_, *lhs_ptr, lhs_field, lhs_scratch_));
    } else {
      CEL_ASSIGN_OR_RETURN(
          lhs_value, AsEquatableValue(lhs_reflection_, *lhs_ptr,
                                      lhs_ptr->GetDescriptor(), lhs_scratch_));
    }
    EquatableValue rhs_value;
    if (rhs_field != nullptr) {
      CEL_ASSIGN_OR_RETURN(
          rhs_value,
          AsEquatableValue(rhs_reflection_, *rhs_ptr, rhs_field, rhs_scratch_));
    } else {
      CEL_ASSIGN_OR_RETURN(
          rhs_value, AsEquatableValue(rhs_reflection_, *rhs_ptr,
                                      rhs_ptr->GetDescriptor(), rhs_scratch_));
    }
    return EquatableValueEquals(lhs_value, rhs_value);
  }

  // Equality between two fields, or between a field and a whole message.
  // A null `lhs_field` / `rhs_field` means the corresponding message itself is
  // the operand; at most one side may be null.
  //
  // The comparison is chosen by the shape of the operands:
  // 1. A map field compares against a map field, or against a
  //    `google.protobuf.Value` holding a struct, a `google.protobuf.Struct`,
  //    or an Any packing either of them.
  // 2. A repeated field compares against a repeated field, or against a
  //    `google.protobuf.Value` holding a list, a `google.protobuf.ListValue`,
  //    or an Any packing either of them.
  // 3. Everything else goes to SingularFieldEquals.
  //
  // Any other pairing involving a map or repeated field is unequal.
  absl::StatusOr FieldEquals(
      const Message& lhs, const FieldDescriptor* absl_nullable lhs_field,
      const Message& rhs, const FieldDescriptor* absl_nullable rhs_field) {
    ABSL_DCHECK(lhs_field != nullptr ||
                rhs_field != nullptr);  // At most one side may be null.
    if (lhs_field != nullptr && lhs_field->is_map()) {
      // map == map
      // map == google.protobuf.Value
      // map == google.protobuf.Struct
      // map == google.protobuf.Any

      // Right hand side should be a map, `google.protobuf.Value`,
      // `google.protobuf.Struct`, or `google.protobuf.Any`.
      if (rhs_field != nullptr && rhs_field->is_map()) {
        // map == map
        return MapFieldEquals(lhs, lhs_field, rhs, rhs_field);
      }
      if (rhs_field != nullptr &&
          (rhs_field->is_repeated() ||
           rhs_field->type() != FieldDescriptor::TYPE_MESSAGE)) {
        return false;
      }
      const Message* absl_nullable rhs_packed = nullptr;
      Unique rhs_unpacked;
      if (rhs_field != nullptr && IsAnyField(rhs_field)) {
        rhs_packed = &rhs.GetReflection()->GetMessage(rhs, rhs_field);
      } else if (rhs_field == nullptr && IsAny(rhs)) {
        rhs_packed = &rhs;
      }
      if (rhs_packed != nullptr) {
        CEL_RETURN_IF_ERROR(rhs_reflection_.any_reflection.Initialize(
            rhs_packed->GetDescriptor()));
        auto rhs_type_url = rhs_reflection_.any_reflection.GetTypeUrl(
            *rhs_packed, rhs_scratch_);
        // Only these two type URL prefixes are recognized; anything else
        // compares unequal.
        if (!rhs_type_url.ConsumePrefix("type.googleapis.com/") &&
            !rhs_type_url.ConsumePrefix("type.googleprod.com/")) {
          return false;
        }
        // Only payloads that could hold a map are worth unpacking.
        if (rhs_type_url != "google.protobuf.Value" &&
            rhs_type_url != "google.protobuf.Struct" &&
            rhs_type_url != "google.protobuf.Any") {
          return false;
        }
        // NOTE(review): UnpackAnyIfResolveable is called once here. If it
        // unpacks only one level, a nested Any (type URL google.protobuf.Any)
        // reaches the `default` case of the switch below and compares
        // unequal. Confirm whether that is intended.
        CEL_ASSIGN_OR_RETURN(rhs_unpacked,
                             well_known_types::UnpackAnyIfResolveable(
                                 &arena_, rhs_reflection_.any_reflection,
                                 *rhs_packed, pool_, factory_));
        if (rhs_unpacked) {
          rhs_field = nullptr;
        }
      }
      // The message to compare against: the field's message, the unpacked Any
      // payload, or `rhs` itself.
      const Message* absl_nonnull rhs_message =
          rhs_field != nullptr
              ? &rhs.GetReflection()->GetMessage(rhs, rhs_field)
          : rhs_unpacked != nullptr
              ? cel::to_address(rhs_unpacked)
              : &rhs;
      const auto* rhs_descriptor = rhs_message->GetDescriptor();
      const auto rhs_well_known_type = rhs_descriptor->well_known_type();
      switch (rhs_well_known_type) {
        case Descriptor::WELLKNOWNTYPE_VALUE: {
          // map == google.protobuf.Value
          CEL_RETURN_IF_ERROR(
              rhs_reflection_.value_reflection.Initialize(rhs_descriptor));
          if (rhs_reflection_.value_reflection.GetKindCase(*rhs_message) !=
              google::protobuf::Value::kStructValue) {
            return false;
          }
          CEL_RETURN_IF_ERROR(rhs_reflection_.struct_reflection.Initialize(
              rhs_reflection_.value_reflection.GetStructDescriptor()));
          return MapFieldEquals(
              lhs, lhs_field,
              rhs_reflection_.value_reflection.GetStructValue(*rhs_message),
              rhs_reflection_.struct_reflection.GetFieldsDescriptor());
        }
        case Descriptor::WELLKNOWNTYPE_STRUCT: {
          // map == google.protobuf.Struct
          CEL_RETURN_IF_ERROR(
              rhs_reflection_.struct_reflection.Initialize(rhs_descriptor));
          return MapFieldEquals(
              lhs, lhs_field, *rhs_message,
              rhs_reflection_.struct_reflection.GetFieldsDescriptor());
        }
        default:
          return false;
      }
      // Explicitly unreachable, for ease of reading. Control never leaves this
      // if statement.
      ABSL_UNREACHABLE();
    }
    if (rhs_field != nullptr && rhs_field->is_map()) {
      // google.protobuf.Value == map
      // google.protobuf.Struct == map
      // google.protobuf.Any == map

      // Left hand side should be singular `google.protobuf.Value`,
      // `google.protobuf.Struct`, or `google.protobuf.Any`.
      ABSL_DCHECK(lhs_field == nullptr ||
                  !lhs_field->is_map());  // Handled above.
      if (lhs_field != nullptr &&
          (lhs_field->is_repeated() ||
           lhs_field->type() != FieldDescriptor::TYPE_MESSAGE)) {
        return false;
      }
      const Message* absl_nullable lhs_packed = nullptr;
      Unique lhs_unpacked;
      if (lhs_field != nullptr && IsAnyField(lhs_field)) {
        lhs_packed = &lhs.GetReflection()->GetMessage(lhs, lhs_field);
      } else if (lhs_field == nullptr && IsAny(lhs)) {
        lhs_packed = &lhs;
      }
      if (lhs_packed != nullptr) {
        CEL_RETURN_IF_ERROR(lhs_reflection_.any_reflection.Initialize(
            lhs_packed->GetDescriptor()));
        auto lhs_type_url = lhs_reflection_.any_reflection.GetTypeUrl(
            *lhs_packed, lhs_scratch_);
        if (!lhs_type_url.ConsumePrefix("type.googleapis.com/") &&
            !lhs_type_url.ConsumePrefix("type.googleprod.com/")) {
          return false;
        }
        if (lhs_type_url != "google.protobuf.Value" &&
            lhs_type_url != "google.protobuf.Struct" &&
            lhs_type_url != "google.protobuf.Any") {
          return false;
        }
        CEL_ASSIGN_OR_RETURN(lhs_unpacked,
                             well_known_types::UnpackAnyIfResolveable(
                                 &arena_, lhs_reflection_.any_reflection,
                                 *lhs_packed, pool_, factory_));
        if (lhs_unpacked) {
          lhs_field = nullptr;
        }
      }
      const Message* absl_nonnull lhs_message =
          lhs_field != nullptr
              ? &lhs.GetReflection()->GetMessage(lhs, lhs_field)
          : lhs_unpacked != nullptr
              ? cel::to_address(lhs_unpacked)
              : &lhs;
      const auto* lhs_descriptor = lhs_message->GetDescriptor();
      const auto lhs_well_known_type = lhs_descriptor->well_known_type();
      switch (lhs_well_known_type) {
        case Descriptor::WELLKNOWNTYPE_VALUE: {
          // google.protobuf.Value == map
          CEL_RETURN_IF_ERROR(
              lhs_reflection_.value_reflection.Initialize(lhs_descriptor));
          if (lhs_reflection_.value_reflection.GetKindCase(*lhs_message) !=
              google::protobuf::Value::kStructValue) {
            return false;
          }
          CEL_RETURN_IF_ERROR(lhs_reflection_.struct_reflection.Initialize(
              lhs_reflection_.value_reflection.GetStructDescriptor()));
          return MapFieldEquals(
              lhs_reflection_.value_reflection.GetStructValue(*lhs_message),
              lhs_reflection_.struct_reflection.GetFieldsDescriptor(), rhs,
              rhs_field);
        }
        case Descriptor::WELLKNOWNTYPE_STRUCT: {
          // google.protobuf.Struct == map
          CEL_RETURN_IF_ERROR(
              lhs_reflection_.struct_reflection.Initialize(lhs_descriptor));
          return MapFieldEquals(
              *lhs_message,
              lhs_reflection_.struct_reflection.GetFieldsDescriptor(), rhs,
              rhs_field);
        }
        default:
          return false;
      }
      // Explicitly unreachable, for ease of reading. Control never leaves this
      // if statement.
      ABSL_UNREACHABLE();
    }
    ABSL_DCHECK(lhs_field == nullptr ||
                !lhs_field->is_map());  // Handled above.
    ABSL_DCHECK(rhs_field == nullptr ||
                !rhs_field->is_map());  // Handled above.
    if (lhs_field != nullptr && lhs_field->is_repeated()) {
      // repeated == repeated
      // repeated == google.protobuf.Value
      // repeated == google.protobuf.ListValue
      // repeated == google.protobuf.Any

      // Right hand side should be a repeated, `google.protobuf.Value`,
      // `google.protobuf.ListValue`, or `google.protobuf.Any`.
      if (rhs_field != nullptr && rhs_field->is_repeated()) {
        // repeated == repeated
        return RepeatedFieldEquals(lhs, lhs_field, rhs, rhs_field);
      }
      if (rhs_field != nullptr &&
          rhs_field->type() != FieldDescriptor::TYPE_MESSAGE) {
        return false;
      }
      const Message* absl_nullable rhs_packed = nullptr;
      Unique rhs_unpacked;
      if (rhs_field != nullptr && IsAnyField(rhs_field)) {
        rhs_packed = &rhs.GetReflection()->GetMessage(rhs, rhs_field);
      } else if (rhs_field == nullptr && IsAny(rhs)) {
        rhs_packed = &rhs;
      }
      if (rhs_packed != nullptr) {
        CEL_RETURN_IF_ERROR(rhs_reflection_.any_reflection.Initialize(
            rhs_packed->GetDescriptor()));
        auto rhs_type_url = rhs_reflection_.any_reflection.GetTypeUrl(
            *rhs_packed, rhs_scratch_);
        if (!rhs_type_url.ConsumePrefix("type.googleapis.com/") &&
            !rhs_type_url.ConsumePrefix("type.googleprod.com/")) {
          return false;
        }
        // Only payloads that could hold a list are worth unpacking.
        if (rhs_type_url != "google.protobuf.Value" &&
            rhs_type_url != "google.protobuf.ListValue" &&
            rhs_type_url != "google.protobuf.Any") {
          return false;
        }
        CEL_ASSIGN_OR_RETURN(rhs_unpacked,
                             well_known_types::UnpackAnyIfResolveable(
                                 &arena_, rhs_reflection_.any_reflection,
                                 *rhs_packed, pool_, factory_));
        if (rhs_unpacked) {
          rhs_field = nullptr;
        }
      }
      const Message* absl_nonnull rhs_message =
          rhs_field != nullptr
              ? &rhs.GetReflection()->GetMessage(rhs, rhs_field)
          : rhs_unpacked != nullptr
              ? cel::to_address(rhs_unpacked)
              : &rhs;
      const auto* rhs_descriptor = rhs_message->GetDescriptor();
      const auto rhs_well_known_type = rhs_descriptor->well_known_type();
      switch (rhs_well_known_type) {
        case Descriptor::WELLKNOWNTYPE_VALUE: {
          // repeated == google.protobuf.Value
          CEL_RETURN_IF_ERROR(
              rhs_reflection_.value_reflection.Initialize(rhs_descriptor));
          if (rhs_reflection_.value_reflection.GetKindCase(*rhs_message) !=
              google::protobuf::Value::kListValue) {
            return false;
          }
          CEL_RETURN_IF_ERROR(rhs_reflection_.list_value_reflection.Initialize(
              rhs_reflection_.value_reflection.GetListValueDescriptor()));
          return RepeatedFieldEquals(
              lhs, lhs_field,
              rhs_reflection_.value_reflection.GetListValue(*rhs_message),
              rhs_reflection_.list_value_reflection.GetValuesDescriptor());
        }
        case Descriptor::WELLKNOWNTYPE_LISTVALUE: {
          // repeated == google.protobuf.ListValue
          CEL_RETURN_IF_ERROR(
              rhs_reflection_.list_value_reflection.Initialize(rhs_descriptor));
          return RepeatedFieldEquals(
              lhs, lhs_field, *rhs_message,
              rhs_reflection_.list_value_reflection.GetValuesDescriptor());
        }
        default:
          return false;
      }
      // Explicitly unreachable, for ease of reading. Control never leaves this
      // if statement.
      ABSL_UNREACHABLE();
    }
    if (rhs_field != nullptr && rhs_field->is_repeated()) {
      // google.protobuf.Value == repeated
      // google.protobuf.ListValue == repeated
      // google.protobuf.Any == repeated

      // Left hand side should be singular `google.protobuf.Value`,
      // `google.protobuf.ListValue`, or `google.protobuf.Any`.
      ABSL_DCHECK(lhs_field == nullptr ||
                  !lhs_field->is_repeated());  // Handled above.
      if (lhs_field != nullptr &&
          lhs_field->type() != FieldDescriptor::TYPE_MESSAGE) {
        return false;
      }
      const Message* absl_nullable lhs_packed = nullptr;
      Unique lhs_unpacked;
      if (lhs_field != nullptr && IsAnyField(lhs_field)) {
        lhs_packed = &lhs.GetReflection()->GetMessage(lhs, lhs_field);
      } else if (lhs_field == nullptr && IsAny(lhs)) {
        lhs_packed = &lhs;
      }
      if (lhs_packed != nullptr) {
        CEL_RETURN_IF_ERROR(lhs_reflection_.any_reflection.Initialize(
            lhs_packed->GetDescriptor()));
        auto lhs_type_url = lhs_reflection_.any_reflection.GetTypeUrl(
            *lhs_packed, lhs_scratch_);
        if (!lhs_type_url.ConsumePrefix("type.googleapis.com/") &&
            !lhs_type_url.ConsumePrefix("type.googleprod.com/")) {
          return false;
        }
        if (lhs_type_url != "google.protobuf.Value" &&
            lhs_type_url != "google.protobuf.ListValue" &&
            lhs_type_url != "google.protobuf.Any") {
          return false;
        }
        CEL_ASSIGN_OR_RETURN(lhs_unpacked,
                             well_known_types::UnpackAnyIfResolveable(
                                 &arena_, lhs_reflection_.any_reflection,
                                 *lhs_packed, pool_, factory_));
        if (lhs_unpacked) {
          lhs_field = nullptr;
        }
      }
      const Message* absl_nonnull lhs_message =
          lhs_field != nullptr
              ? &lhs.GetReflection()->GetMessage(lhs, lhs_field)
          : lhs_unpacked != nullptr
              ? cel::to_address(lhs_unpacked)
              : &lhs;
      const auto* lhs_descriptor = lhs_message->GetDescriptor();
      const auto lhs_well_known_type = lhs_descriptor->well_known_type();
      switch (lhs_well_known_type) {
        case Descriptor::WELLKNOWNTYPE_VALUE: {
          // google.protobuf.Value == repeated
          CEL_RETURN_IF_ERROR(
              lhs_reflection_.value_reflection.Initialize(lhs_descriptor));
          if (lhs_reflection_.value_reflection.GetKindCase(*lhs_message) !=
              google::protobuf::Value::kListValue) {
            return false;
          }
          CEL_RETURN_IF_ERROR(lhs_reflection_.list_value_reflection.Initialize(
              lhs_reflection_.value_reflection.GetListValueDescriptor()));
          return RepeatedFieldEquals(
              lhs_reflection_.value_reflection.GetListValue(*lhs_message),
              lhs_reflection_.list_value_reflection.GetValuesDescriptor(), rhs,
              rhs_field);
        }
        case Descriptor::WELLKNOWNTYPE_LISTVALUE: {
          // google.protobuf.ListValue == repeated
          CEL_RETURN_IF_ERROR(
              lhs_reflection_.list_value_reflection.Initialize(lhs_descriptor));
          return RepeatedFieldEquals(
              *lhs_message,
              lhs_reflection_.list_value_reflection.GetValuesDescriptor(), rhs,
              rhs_field);
        }
        default:
          return false;
      }
      // Explicitly unreachable, for ease of reading. Control never leaves this
      // if statement.
      ABSL_UNREACHABLE();
    }
    // Neither side is a map or repeated field.
    return SingularFieldEquals(lhs, lhs_field, rhs, rhs_field);
  }

 private:
  const DescriptorPool* absl_nonnull const pool_;
  MessageFactory* absl_nonnull const factory_;
  // Passed to UnpackAnyIfResolveable and the *AsEquatableValue helpers.
  google::protobuf::Arena arena_;
  // Reflection caches and scratch buffers are kept per side so that the
  // lhs and rhs conversions do not overwrite each other's state.
  EquatableValueReflection lhs_reflection_;
  EquatableValueReflection rhs_reflection_;
  std::string lhs_scratch_;
  std::string rhs_scratch_;
};

}  // namespace

// Returns true immediately when `lhs` and `rhs` are the same object; otherwise
// delegates to MessageEqualsState::Equals.
absl::StatusOr MessageEquals(const Message& lhs, const Message& rhs,
                             const DescriptorPool* absl_nonnull pool,
                             MessageFactory* absl_nonnull factory) {
  ABSL_DCHECK(pool != nullptr);
  ABSL_DCHECK(factory != nullptr);
  if (&lhs == &rhs) {
    return true;
  }
  // MessageEqualsState has quite a large size, so we allocate it on the heap.
  // Ideally we should just hold most of the state at runtime in something like
  // `FlatExpressionEvaluatorState`, so we can avoid allocating this repeatedly.
  return std::make_unique(pool, factory)->Equals(lhs, rhs);
}

// Field-to-field equality. Returns true immediately when both operands are the
// same field of the same message object.
absl::StatusOr MessageFieldEquals(
    const Message& lhs, const FieldDescriptor* absl_nonnull lhs_field,
    const Message& rhs, const FieldDescriptor* absl_nonnull rhs_field,
    const DescriptorPool* absl_nonnull pool,
    MessageFactory* absl_nonnull factory) {
  ABSL_DCHECK(lhs_field != nullptr);
  ABSL_DCHECK(rhs_field != nullptr);
  ABSL_DCHECK(pool != nullptr);
  ABSL_DCHECK(factory != nullptr);
  if (&lhs == &rhs && lhs_field == rhs_field) {
    return true;
  }
  // MessageEqualsState has quite a large size, so we allocate it on the heap.
  // Ideally we should just hold most of the state at runtime in something like
  // `FlatExpressionEvaluatorState`, so we can avoid allocating this repeatedly.
  return std::make_unique(pool, factory)
      ->FieldEquals(lhs, lhs_field, rhs, rhs_field);
}

// Whole-message `lhs` compared against field `rhs_field` of `rhs`.
absl::StatusOr MessageFieldEquals(
    const google::protobuf::Message& lhs, const google::protobuf::Message& rhs,
    const google::protobuf::FieldDescriptor* absl_nonnull rhs_field,
    const google::protobuf::DescriptorPool* absl_nonnull pool,
    google::protobuf::MessageFactory* absl_nonnull factory) {
  ABSL_DCHECK(rhs_field != nullptr);
  ABSL_DCHECK(pool != nullptr);
  ABSL_DCHECK(factory != nullptr);
  // MessageEqualsState has quite a large size, so we allocate it on the heap.
  // Ideally we should just hold most of the state at runtime in something like
  // `FlatExpressionEvaluatorState`, so we can avoid allocating this repeatedly.
  return std::make_unique(pool, factory)
      ->FieldEquals(lhs, nullptr, rhs, rhs_field);
}

// Field `lhs_field` of `lhs` compared against the whole message `rhs`.
absl::StatusOr MessageFieldEquals(
    const google::protobuf::Message& lhs,
    const google::protobuf::FieldDescriptor* absl_nonnull lhs_field,
    const google::protobuf::Message& rhs,
    const google::protobuf::DescriptorPool* absl_nonnull pool,
    google::protobuf::MessageFactory* absl_nonnull factory) {
  ABSL_DCHECK(lhs_field != nullptr);
  ABSL_DCHECK(pool != nullptr);
  ABSL_DCHECK(factory != nullptr);
  // MessageEqualsState has quite a large size, so we allocate it on the heap.
  // Ideally we should just hold most of the state at runtime in something like
  // `FlatExpressionEvaluatorState`, so we can avoid allocating this repeatedly.
  return std::make_unique(pool, factory)
      ->FieldEquals(lhs, lhs_field, rhs, nullptr);
}

}  // namespace cel::internal

================================================
FILE: internal/message_equality.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_MESSAGE_EQUALITY_H_
#define THIRD_PARTY_CEL_CPP_INTERNAL_MESSAGE_EQUALITY_H_

#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::internal {

// Tests whether one message is equal to another following CEL equality
// semantics.
// `pool` and `factory` are used when unpacking `google.protobuf.Any`
// operands. A non-OK status is returned when the comparison itself fails
// (for example while unpacking an Any), not when the messages differ.
absl::StatusOr MessageEquals(
    const google::protobuf::Message& lhs, const google::protobuf::Message& rhs,
    const google::protobuf::DescriptorPool* absl_nonnull pool,
    google::protobuf::MessageFactory* absl_nonnull factory);

// Tests whether one message field is equal to another following CEL equality
// semantics. `lhs_field` must be a field of `lhs`'s message type and
// `rhs_field` a field of `rhs`'s message type; the implementation checks this
// with debug assertions. Map, repeated and singular fields are each compared
// with their matching CEL semantics.
absl::StatusOr MessageFieldEquals(
    const google::protobuf::Message& lhs,
    const google::protobuf::FieldDescriptor* absl_nonnull lhs_field,
    const google::protobuf::Message& rhs,
    const google::protobuf::FieldDescriptor* absl_nonnull rhs_field,
    const google::protobuf::DescriptorPool* absl_nonnull pool,
    google::protobuf::MessageFactory* absl_nonnull factory);

// Tests whether the whole message `lhs` is equal to the field `rhs_field` of
// `rhs` following CEL equality semantics. For example, a
// `google.protobuf.Struct` message can be compared against a map field, or a
// `google.protobuf.ListValue` against a repeated field.
absl::StatusOr MessageFieldEquals(
    const google::protobuf::Message& lhs, const google::protobuf::Message& rhs,
    const google::protobuf::FieldDescriptor* absl_nonnull rhs_field,
    const google::protobuf::DescriptorPool* absl_nonnull pool,
    google::protobuf::MessageFactory* absl_nonnull factory);

// Tests whether the field `lhs_field` of `lhs` is equal to the whole message
// `rhs` following CEL equality semantics. This is the mirror image of the
// overload above.
absl::StatusOr MessageFieldEquals(
    const google::protobuf::Message& lhs,
    const google::protobuf::FieldDescriptor* absl_nonnull lhs_field,
    const google::protobuf::Message& rhs,
    const google::protobuf::DescriptorPool* absl_nonnull pool,
    google::protobuf::MessageFactory* absl_nonnull factory);

}  // namespace cel::internal

#endif  // THIRD_PARTY_CEL_CPP_INTERNAL_MESSAGE_EQUALITY_H_

================================================
FILE: internal/message_equality_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// NOTE(review): In this dump, the tail of internal/message_equality_test.cc has
// lost its line breaks and all angle-bracket content. Missing pieces include:
//   - the system #include targets;
//   - the `template <...>` parameter lists;
//   - template arguments such as ParseTextProto<...>, std::vector<...>,
//     TestParamInfo<...>, TestWithParam<...>, MessageTypeNameFor<...> and
//     ValuesIn<...>.
// The text below therefore does not compile as-is. Restore it from the
// upstream file rather than guessing the per-operand message types.
//
// What the visible tests check:
// - UnaryMessageEqualsTest compares every pair of operands in a case with
//   MessageEquals, in both argument orders.
//   - It repeats each comparison after packing one or both operands into
//     google.protobuf.Any (see PackMessage).
//   - With `equal = false`, an operand is not compared with itself.
// - UnaryMessageFieldEqualsTest does the same with MessageFieldEquals over the
//   listed fields of a message parsed twice from the case's text proto
//   (presumably TestAllTypesProto3).
//   - It also compares singular message fields as standalone messages.
//   - Via PackTestAllTypesProto3Field, it compares message fields as
//     `single_any` / `repeated_any` fields.
// See the License for the specific language governing permissions and // limitations under the License. #include "internal/message_equality.h" #include #include #include #include "google/protobuf/any.pb.h" #include "google/protobuf/duration.pb.h" #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/protobuf/wrappers.pb.h" #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/log/die_if_null.h" #include "absl/status/status_matchers.h" #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/allocator.h" #include "common/memory.h" #include "internal/message_type_name.h" #include "internal/parse_text_proto.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" #include "internal/well_known_types.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel::internal { namespace { using ::absl_testing::IsOkAndHolds; using ::testing::IsFalse; using ::testing::IsTrue; using ::testing::TestParamInfo; using ::testing::TestWithParam; using ::testing::ValuesIn; using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes; google::protobuf::Arena* GetTestArena() { static absl::NoDestructor arena; return &*arena; } template google::protobuf::Message* ParseTextProto(absl::string_view text) { return DynamicParseTextProto(GetTestArena(), text, GetTestingDescriptorPool(), GetTestingMessageFactory()); } struct UnaryMessageEqualsTestParam { std::string name; std::vector ops; bool equal; }; std::string UnaryMessageEqualsTestParamName( const TestParamInfo& param_info) { return param_info.param.name; } using UnaryMessageEqualsTest = TestWithParam; google::protobuf::Message*
PackMessage(const google::protobuf::Message& message) { const auto* descriptor = ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindMessageTypeByName( MessageTypeNameFor())); const auto* prototype = ABSL_DIE_IF_NULL(GetTestingMessageFactory()->GetPrototype(descriptor)); auto instance = prototype->New(GetTestArena()); auto reflection = well_known_types::GetAnyReflectionOrDie(descriptor); reflection.SetTypeUrl( cel::to_address(instance), absl::StrCat("type.googleapis.com/", message.GetTypeName())); absl::Cord value; ABSL_CHECK(message.SerializeToString(&value)); reflection.SetValue(cel::to_address(instance), value); return instance; } TEST_P(UnaryMessageEqualsTest, Equals) { const auto* pool = GetTestingDescriptorPool(); auto* factory = GetTestingMessageFactory(); const auto& test_case = GetParam(); for (const auto& lhs : test_case.ops) { for (const auto& rhs : test_case.ops) { if (!test_case.equal && &lhs == &rhs) { continue; } EXPECT_THAT(MessageEquals(*lhs, *rhs, pool, factory), IsOkAndHolds(test_case.equal)) << lhs->ShortDebugString() << " " << rhs->ShortDebugString(); EXPECT_THAT(MessageEquals(*rhs, *lhs, pool, factory), IsOkAndHolds(test_case.equal)) << lhs->ShortDebugString() << " " << rhs->ShortDebugString(); // Test any.
auto lhs_any = PackMessage(*lhs); auto rhs_any = PackMessage(*rhs); EXPECT_THAT(MessageEquals(*lhs_any, *rhs, pool, factory), IsOkAndHolds(test_case.equal)) << lhs_any->ShortDebugString() << " " << rhs->ShortDebugString(); EXPECT_THAT(MessageEquals(*lhs, *rhs_any, pool, factory), IsOkAndHolds(test_case.equal)) << lhs->ShortDebugString() << " " << rhs_any->ShortDebugString(); EXPECT_THAT(MessageEquals(*lhs_any, *rhs_any, pool, factory), IsOkAndHolds(test_case.equal)) << lhs_any->ShortDebugString() << " " << rhs_any->ShortDebugString(); } } } INSTANTIATE_TEST_SUITE_P( UnaryMessageEqualsTest, UnaryMessageEqualsTest, ValuesIn({ { .name = "NullValue_Equal", .ops = { ParseTextProto(R"pb()pb"), ParseTextProto( R"pb(null_value: NULL_VALUE)pb"), }, .equal = true, }, { .name = "BoolValue_False_Equal", .ops = { ParseTextProto(R"pb()pb"), ParseTextProto( R"pb(value: false)pb"), ParseTextProto( R"pb(bool_value: false)pb"), }, .equal = true, }, { .name = "BoolValue_True_Equal", .ops = { ParseTextProto( R"pb(value: true)pb"), ParseTextProto(R"pb(bool_value: true)pb"), }, .equal = true, }, { .name = "StringValue_Empty_Equal", .ops = { ParseTextProto(R"pb()pb"), ParseTextProto( R"pb(value: "")pb"), ParseTextProto( R"pb(string_value: "")pb"), }, .equal = true, }, { .name = "StringValue_Equal", .ops = { ParseTextProto( R"pb(value: "foo")pb"), ParseTextProto( R"pb(string_value: "foo")pb"), }, .equal = true, }, { .name = "BytesValue_Empty_Equal", .ops = { ParseTextProto(R"pb()pb"), ParseTextProto( R"pb(value: "")pb"), }, .equal = true, }, { .name = "BytesValue_Equal", .ops = { ParseTextProto( R"pb(value: "foo")pb"), ParseTextProto( R"pb(value: "foo")pb"), }, .equal = true, }, { .name = "ListValue_Equal", .ops = { ParseTextProto( R"pb(list_value: { values { bool_value: true } })pb"), ParseTextProto( R"pb(values { bool_value: true })pb"), }, .equal = true, }, { .name = "ListValue_NotEqual", .ops = { ParseTextProto( R"pb(list_value: { values { number_value: 0.0 } })pb"), ParseTextProto(
R"pb(values { number_value: 1.0 })pb"), ParseTextProto( R"pb(list_value: { values { number_value: 2.0 } })pb"), ParseTextProto( R"pb(values { number_value: 3.0 })pb"), }, .equal = false, }, { .name = "StructValue_Equal", .ops = { ParseTextProto( R"pb(struct_value: { fields { key: "foo" value: { bool_value: true } } })pb"), ParseTextProto( R"pb(fields { key: "foo" value: { bool_value: true } })pb"), }, .equal = true, }, { .name = "StructValue_NotEqual", .ops = { ParseTextProto( R"pb(struct_value: { fields { key: "foo" value: { number_value: 0.0 } } })pb"), ParseTextProto( R"pb( fields { key: "bar" value: { number_value: 0.0 } })pb"), ParseTextProto( R"pb(struct_value: { fields { key: "foo" value: { number_value: 1.0 } } })pb"), ParseTextProto( R"pb( fields { key: "bar" value: { number_value: 1.0 } })pb"), }, .equal = false, }, { .name = "Heterogeneous_Equal", .ops = { ParseTextProto(R"pb()pb"), ParseTextProto(R"pb()pb"), ParseTextProto(R"pb()pb"), ParseTextProto(R"pb()pb"), ParseTextProto(R"pb()pb"), ParseTextProto(R"pb()pb"), ParseTextProto(R"pb(number_value: 0.0)pb"), }, .equal = true, }, { .name = "Message_Equals", .ops = { ParseTextProto(R"pb()pb"), ParseTextProto(R"pb()pb"), }, .equal = true, }, { .name = "Heterogeneous_NotEqual", .ops = { ParseTextProto( R"pb(value: false)pb"), ParseTextProto( R"pb(value: 0)pb"), ParseTextProto( R"pb(value: 1)pb"), ParseTextProto( R"pb(value: 2)pb"), ParseTextProto( R"pb(value: 3)pb"), ParseTextProto( R"pb(value: 4.0)pb"), ParseTextProto( R"pb(value: 5.0)pb"), ParseTextProto(R"pb()pb"), ParseTextProto(R"pb(bool_value: true)pb"), ParseTextProto(R"pb(number_value: 6.0)pb"), ParseTextProto( R"pb(string_value: "bar")pb"), ParseTextProto( R"pb(value: "foo")pb"), ParseTextProto( R"pb(value: "")pb"), ParseTextProto( R"pb(value: "foo")pb"), ParseTextProto( R"pb(list_value: {})pb"), ParseTextProto( R"pb(values { bool_value: true })pb"), ParseTextProto(R"pb(struct_value: {})pb"), ParseTextProto( R"pb(fields { key: "foo" value: {
bool_value: false } })pb"), ParseTextProto(R"pb()pb"), ParseTextProto( R"pb(seconds: 1 nanos: 1)pb"), ParseTextProto(R"pb()pb"), ParseTextProto( R"pb(seconds: 1 nanos: 1)pb"), ParseTextProto(R"pb()pb"), ParseTextProto( R"pb(single_bool: true)pb"), }, .equal = false, }, }), UnaryMessageEqualsTestParamName); struct UnaryMessageFieldEqualsTestParam { std::string name; std::string message; std::vector fields; bool equal; }; std::string UnaryMessageFieldEqualsTestParamName( const TestParamInfo& param_info) { return param_info.param.name; } using UnaryMessageFieldEqualsTest = TestWithParam; void PackMessageTo(const google::protobuf::Message& message, google::protobuf::Message* instance) { auto reflection = *well_known_types::GetAnyReflection(instance->GetDescriptor()); reflection.SetTypeUrl( instance, absl::StrCat("type.googleapis.com/", message.GetTypeName())); absl::Cord value; ABSL_CHECK(message.SerializeToString(&value)); reflection.SetValue(instance, value); } absl::optional, const google::protobuf::FieldDescriptor* absl_nonnull>> PackTestAllTypesProto3Field(const google::protobuf::Message& message, const google::protobuf::FieldDescriptor* absl_nonnull field) { if (field->is_map()) { return absl::nullopt; } if (field->is_repeated() && field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { const auto* descriptor = message.GetDescriptor(); const auto* any_field = descriptor->FindFieldByName("repeated_any"); auto packed = WrapShared(message.New(), NewDeleteAllocator<>{}); const int size = message.GetReflection()->FieldSize(message, field); for (int i = 0; i < size; ++i) { PackMessageTo( message.GetReflection()->GetRepeatedMessage(message, field, i), packed->GetReflection()->AddMessage(cel::to_address(packed), any_field)); } return std::pair{packed, any_field}; } if (!field->is_repeated() && field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { const auto* descriptor = message.GetDescriptor(); const auto* any_field =
descriptor->FindFieldByName("single_any"); auto packed = WrapShared(message.New(), NewDeleteAllocator<>{}); PackMessageTo(message.GetReflection()->GetMessage(message, field), packed->GetReflection()->MutableMessage( cel::to_address(packed), any_field)); return std::pair{packed, any_field}; } return absl::nullopt; } TEST_P(UnaryMessageFieldEqualsTest, Equals) { // We perform exhaustive comparison by testing for equality (or inequality) // against all combinations of fields. Additionally we convert to // `google.protobuf.Any` where applicable. This is all done for coverage and // to ensure different combinations, regardless of argument order, produce the // same result. const auto* pool = GetTestingDescriptorPool(); auto* factory = GetTestingMessageFactory(); const auto& test_case = GetParam(); auto lhs_message = ParseTextProto(test_case.message); auto rhs_message = ParseTextProto(test_case.message); const auto* descriptor = ABSL_DIE_IF_NULL( pool->FindMessageTypeByName(MessageTypeNameFor())); for (const auto& lhs : test_case.fields) { for (const auto& rhs : test_case.fields) { if (!test_case.equal && lhs == rhs) { // When testing for inequality, do not compare the same field to itself.
continue; } const auto* lhs_field = ABSL_DIE_IF_NULL(descriptor->FindFieldByName(lhs)); const auto* rhs_field = ABSL_DIE_IF_NULL(descriptor->FindFieldByName(rhs)); EXPECT_THAT(MessageFieldEquals(*lhs_message, lhs_field, *rhs_message, rhs_field, pool, factory), IsOkAndHolds(test_case.equal)) << lhs_message->ShortDebugString() << " " << lhs_field->name() << " " << rhs_message->ShortDebugString() << " " << rhs_field->name(); EXPECT_THAT(MessageFieldEquals(*rhs_message, rhs_field, *lhs_message, lhs_field, pool, factory), IsOkAndHolds(test_case.equal)) << lhs_message->ShortDebugString() << " " << lhs_field->name() << " " << rhs_message->ShortDebugString() << " " << rhs_field->name(); if (!lhs_field->is_repeated() && lhs_field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { EXPECT_THAT(MessageFieldEquals(lhs_message->GetReflection()->GetMessage( *lhs_message, lhs_field), *rhs_message, rhs_field, pool, factory), IsOkAndHolds(test_case.equal)) << lhs_message->ShortDebugString() << " " << lhs_field->name() << " " << rhs_message->ShortDebugString() << " " << rhs_field->name(); EXPECT_THAT(MessageFieldEquals(*rhs_message, rhs_field, lhs_message->GetReflection()->GetMessage( *lhs_message, lhs_field), pool, factory), IsOkAndHolds(test_case.equal)) << lhs_message->ShortDebugString() << " " << lhs_field->name() << " " << rhs_message->ShortDebugString() << " " << rhs_field->name(); } if (!rhs_field->is_repeated() && rhs_field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { EXPECT_THAT(MessageFieldEquals(*lhs_message, lhs_field, rhs_message->GetReflection()->GetMessage( *rhs_message, rhs_field), pool, factory), IsOkAndHolds(test_case.equal)) << lhs_message->ShortDebugString() << " " << lhs_field->name() << " " << rhs_message->ShortDebugString() << " " << rhs_field->name(); EXPECT_THAT(MessageFieldEquals(rhs_message->GetReflection()->GetMessage( *rhs_message, rhs_field), *lhs_message, lhs_field, pool, factory), IsOkAndHolds(test_case.equal)) <<
lhs_message->ShortDebugString() << " " << lhs_field->name() << " " << rhs_message->ShortDebugString() << " " << rhs_field->name(); } // Test `google.protobuf.Any`. absl::optional, const google::protobuf::FieldDescriptor* absl_nonnull>> lhs_any = PackTestAllTypesProto3Field(*lhs_message, lhs_field); absl::optional, const google::protobuf::FieldDescriptor* absl_nonnull>> rhs_any = PackTestAllTypesProto3Field(*rhs_message, rhs_field); if (lhs_any) { EXPECT_THAT(MessageFieldEquals(*lhs_any->first, lhs_any->second, *rhs_message, rhs_field, pool, factory), IsOkAndHolds(test_case.equal)) << lhs_any->first->ShortDebugString() << " " << rhs_message->ShortDebugString(); if (!lhs_any->second->is_repeated()) { EXPECT_THAT( MessageFieldEquals(lhs_any->first->GetReflection()->GetMessage( *lhs_any->first, lhs_any->second), *rhs_message, rhs_field, pool, factory), IsOkAndHolds(test_case.equal)) << lhs_any->first->ShortDebugString() << " " << rhs_message->ShortDebugString(); } } if (rhs_any) { EXPECT_THAT(MessageFieldEquals(*lhs_message, lhs_field, *rhs_any->first, rhs_any->second, pool, factory), IsOkAndHolds(test_case.equal)) << lhs_message->ShortDebugString() << " " << rhs_any->first->ShortDebugString(); if (!rhs_any->second->is_repeated()) { EXPECT_THAT( MessageFieldEquals(*lhs_message, lhs_field, rhs_any->first->GetReflection()->GetMessage( *rhs_any->first, rhs_any->second), pool, factory), IsOkAndHolds(test_case.equal)) << lhs_message->ShortDebugString() << " " << rhs_any->first->ShortDebugString(); } } if (lhs_any && rhs_any) { EXPECT_THAT( MessageFieldEquals(*lhs_any->first, lhs_any->second, *rhs_any->first, rhs_any->second, pool, factory), IsOkAndHolds(test_case.equal)) << lhs_any->first->ShortDebugString() << " " << rhs_any->first->ShortDebugString(); } } } } INSTANTIATE_TEST_SUITE_P( UnaryMessageFieldEqualsTest, UnaryMessageFieldEqualsTest, ValuesIn({ { .name = "Heterogeneous_Single_Equal", .message = R"pb( single_int32: 1 single_int64: 1 single_uint32: 1 single_uint64:
1 single_float: 1 single_double: 1 single_value: { number_value: 1 } single_int32_wrapper: { value: 1 } single_int64_wrapper: { value: 1 } single_uint32_wrapper: { value: 1 } single_uint64_wrapper: { value: 1 } single_float_wrapper: { value: 1 } single_double_wrapper: { value: 1 } standalone_enum: BAR )pb", .fields = { "single_int32", "single_int64", "single_uint32", "single_uint64", "single_float", "single_double", "single_value", "single_int32_wrapper", "single_int64_wrapper", "single_uint32_wrapper", "single_uint64_wrapper", "single_float_wrapper", "single_double_wrapper", "standalone_enum", }, .equal = true, }, { .name = "Heterogeneous_Single_NotEqual", .message = R"pb( null_value: NULL_VALUE single_bool: false single_int32: 2 single_int64: 3 single_uint32: 4 single_uint64: 5 single_float: NaN single_double: NaN single_string: "foo" single_bytes: "foo" single_value: { number_value: 8 } single_int32_wrapper: { value: 9 } single_int64_wrapper: { value: 10 } single_uint32_wrapper: { value: 11 } single_uint64_wrapper: { value: 12 } single_float_wrapper: { value: 13 } single_double_wrapper: { value: 14 } single_string_wrapper: { value: "bar" } single_bytes_wrapper: { value: "bar" } standalone_enum: BAR )pb", .fields = { "null_value", "single_bool", "single_int32", "single_int64", "single_uint32", "single_uint64", "single_float", "single_double", "single_string", "single_bytes", "single_value", "single_int32_wrapper", "single_int64_wrapper", "single_uint32_wrapper", "single_uint64_wrapper", "single_float_wrapper", "single_double_wrapper", "standalone_enum", }, .equal = false, }, { .name = "Heterogeneous_Repeated_Equal", .message = R"pb( repeated_int32: 1 repeated_int64: 1 repeated_uint32: 1 repeated_uint64: 1 repeated_float: 1 repeated_double: 1 repeated_value: { number_value: 1 } repeated_int32_wrapper: { value: 1 } repeated_int64_wrapper: { value: 1 } repeated_uint32_wrapper: { value: 1 } repeated_uint64_wrapper: { value: 1 } repeated_float_wrapper: { value: 1 }
repeated_double_wrapper: { value: 1 } repeated_nested_enum: BAR single_value: { list_value: { values { number_value: 1 } } } list_value: { values { number_value: 1 } } )pb", .fields = { "repeated_int32", "repeated_int64", "repeated_uint32", "repeated_uint64", "repeated_float", "repeated_double", "repeated_value", "repeated_int32_wrapper", "repeated_int64_wrapper", "repeated_uint32_wrapper", "repeated_uint64_wrapper", "repeated_float_wrapper", "repeated_double_wrapper", "repeated_nested_enum", "single_value", "list_value", }, .equal = true, }, { .name = "Heterogeneous_Repeated_NotEqual", .message = R"pb( repeated_null_value: NULL_VALUE repeated_bool: false repeated_int32: 2 repeated_int64: 3 repeated_uint32: 4 repeated_uint64: 5 repeated_float: 6 repeated_double: 7 repeated_string: "foo" repeated_bytes: "foo" repeated_value: { number_value: 8 } repeated_int32_wrapper: { value: 9 } repeated_int64_wrapper: { value: 10 } repeated_uint32_wrapper: { value: 11 } repeated_uint64_wrapper: { value: 12 } repeated_float_wrapper: { value: 13 } repeated_double_wrapper: { value: 14 } repeated_string_wrapper: { value: "bar" } repeated_bytes_wrapper: { value: "bar" } repeated_nested_enum: BAR )pb", .fields = { "repeated_null_value", "repeated_bool", "repeated_int32", "repeated_int64", "repeated_uint32", "repeated_uint64", "repeated_float", "repeated_double", "repeated_string", "repeated_bytes", "repeated_value", "repeated_int32_wrapper", "repeated_int64_wrapper", "repeated_uint32_wrapper", "repeated_uint64_wrapper", "repeated_float_wrapper", "repeated_double_wrapper", "repeated_nested_enum", }, .equal = false, }, { .name = "Heterogeneous_Map_Equal", .message = R"pb( map_int32_int32 { key: 1 value: 1 } map_int32_uint32 { key: 1 value: 1 } map_int32_int64 { key: 1 value: 1 } map_int32_uint64 { key: 1 value: 1 } map_int32_float { key: 1 value: 1 } map_int32_double { key: 1 value: 1 } map_int32_enum { key: 1 value: BAR } map_int32_value { key: 1 value: { number_value: 1 } }
map_int32_int32_wrapper { key: 1 value: { value: 1 } } map_int32_uint32_wrapper { key: 1 value: { value: 1 } } map_int32_int64_wrapper { key: 1 value: { value: 1 } } map_int32_uint64_wrapper { key: 1 value: { value: 1 } } map_int32_float_wrapper { key: 1 value: { value: 1 } } map_int32_double_wrapper { key: 1 value: { value: 1 } } map_int64_int32 { key: 1 value: 1 } map_int64_uint32 { key: 1 value: 1 } map_int64_int64 { key: 1 value: 1 } map_int64_uint64 { key: 1 value: 1 } map_int64_float { key: 1 value: 1 } map_int64_double { key: 1 value: 1 } map_int64_enum { key: 1 value: BAR } map_int64_value { key: 1 value: { number_value: 1 } } map_int64_int32_wrapper { key: 1 value: { value: 1 } } map_int64_uint32_wrapper { key: 1 value: { value: 1 } } map_int64_int64_wrapper { key: 1 value: { value: 1 } } map_int64_uint64_wrapper { key: 1 value: { value: 1 } } map_int64_float_wrapper { key: 1 value: { value: 1 } } map_int64_double_wrapper { key: 1 value: { value: 1 } } map_uint32_int32 { key: 1 value: 1 } map_uint32_uint32 { key: 1 value: 1 } map_uint32_int64 { key: 1 value: 1 } map_uint32_uint64 { key: 1 value: 1 } map_uint32_float { key: 1 value: 1 } map_uint32_double { key: 1 value: 1 } map_uint32_enum { key: 1 value: BAR } map_uint32_value { key: 1 value: { number_value: 1 } } map_uint32_int32_wrapper { key: 1 value: { value: 1 } } map_uint32_uint32_wrapper { key: 1 value: { value: 1 } } map_uint32_int64_wrapper { key: 1 value: { value: 1 } } map_uint32_uint64_wrapper { key: 1 value: { value: 1 } } map_uint32_float_wrapper { key: 1 value: { value: 1 } } map_uint32_double_wrapper { key: 1 value: { value: 1 } } map_uint64_int32 { key: 1 value: 1 } map_uint64_uint32 { key: 1 value: 1 } map_uint64_int64 { key: 1 value: 1 } map_uint64_uint64 { key: 1 value: 1 } map_uint64_float { key: 1 value: 1 } map_uint64_double { key: 1 value: 1 } map_uint64_enum { key: 1 value: BAR } map_uint64_value { key: 1 value: { number_value: 1 } } map_uint64_int32_wrapper { key: 1 value: {
value: 1 } } map_uint64_uint32_wrapper { key: 1 value: { value: 1 } } map_uint64_int64_wrapper { key: 1 value: { value: 1 } } map_uint64_uint64_wrapper { key: 1 value: { value: 1 } } map_uint64_float_wrapper { key: 1 value: { value: 1 } } map_uint64_double_wrapper { key: 1 value: { value: 1 } } )pb", .fields = { "map_int32_int32", "map_int32_uint32", "map_int32_int64", "map_int32_uint64", "map_int32_float", "map_int32_double", "map_int32_enum", "map_int32_value", "map_int32_int32_wrapper", "map_int32_uint32_wrapper", "map_int32_int64_wrapper", "map_int32_uint64_wrapper", "map_int32_float_wrapper", "map_int32_double_wrapper", "map_int64_int32", "map_int64_uint32", "map_int64_int64", "map_int64_uint64", "map_int64_float", "map_int64_double", "map_int64_enum", "map_int64_value", "map_int64_int32_wrapper", "map_int64_uint32_wrapper", "map_int64_int64_wrapper", "map_int64_uint64_wrapper", "map_int64_float_wrapper", "map_int64_double_wrapper", "map_uint32_int32", "map_uint32_uint32", "map_uint32_int64", "map_uint32_uint64", "map_uint32_float", "map_uint32_double", "map_uint32_enum", "map_uint32_value", "map_uint32_int32_wrapper", "map_uint32_uint32_wrapper", "map_uint32_int64_wrapper", "map_uint32_uint64_wrapper", "map_uint32_float_wrapper", "map_uint32_double_wrapper", "map_uint64_int32", "map_uint64_uint32", "map_uint64_int64", "map_uint64_uint64", "map_uint64_float", "map_uint64_double", "map_uint64_enum", "map_uint64_value", "map_uint64_int32_wrapper", "map_uint64_uint32_wrapper", "map_uint64_int64_wrapper", "map_uint64_uint64_wrapper", "map_uint64_float_wrapper", "map_uint64_double_wrapper", }, .equal = true, }, { .name = "Heterogeneous_Map_NotEqual", .message = R"pb( map_bool_bool { key: false value: false } map_bool_int32 { key: false value: 1 } map_bool_uint32 { key: false value: 0 } map_int32_int32 { key: 0x7FFFFFFF value: 1 } map_int64_int64 { key: 0x7FFFFFFFFFFFFFFF value: 1 } map_uint32_uint32 { key: 0xFFFFFFFF value: 1 } map_uint64_uint64 { key:
0xFFFFFFFFFFFFFFFF value: 1 } map_string_string { key: "foo" value: "bar" } map_string_bytes { key: "foo" value: "bar" } map_int32_bytes { key: -2147483648 value: "bar" } map_int64_bytes { key: -9223372036854775808 value: "bar" } map_int32_float { key: -2147483648 value: 1 } map_int64_double { key: -9223372036854775808 value: 1 } map_uint32_string { key: 0xFFFFFFFF value: "bar" } map_uint64_string { key: 0xFFFFFFFF value: "foo" } map_uint32_bytes { key: 0xFFFFFFFF value: "bar" } map_uint64_bytes { key: 0xFFFFFFFF value: "foo" } map_uint32_bool { key: 0xFFFFFFFF value: false } map_uint64_bool { key: 0xFFFFFFFF value: true } single_value: { struct_value: { fields { key: "bar" value: { string_value: "foo" } } } } single_struct: { fields { key: "baz" value: { string_value: "foo" } } } standalone_message: {} )pb", .fields = { "map_bool_bool", "map_bool_int32", "map_bool_uint32", "map_int32_int32", "map_int64_int64", "map_uint32_uint32", "map_uint64_uint64", "map_string_string", "map_string_bytes", "map_int32_bytes", "map_int64_bytes", "map_int32_float", "map_int64_double", "map_uint32_string", "map_uint64_string", "map_uint32_bytes", "map_uint64_bytes", "map_uint32_bool", "map_uint64_bool", "single_value", "single_struct", "standalone_message", }, .equal = false, }, }), UnaryMessageFieldEqualsTestParamName); TEST(MessageEquals, AnyFallback) { const auto* pool = GetTestingDescriptorPool(); auto* factory = GetTestingMessageFactory(); google::protobuf::Arena arena; auto message1 = DynamicParseTextProto( &arena, R"pb(single_any: { type_url: "type.googleapis.com/message.that.does.not.Exist" value: "foo" })pb", pool, factory); auto message2 = DynamicParseTextProto( &arena, R"pb(single_any: { type_url: "type.googleapis.com/message.that.does.not.Exist" value: "foo" })pb", pool, factory); auto message3 = DynamicParseTextProto( &arena, R"pb(single_any: { type_url: "type.googleapis.com/message.that.does.not.Exist" value: "bar" })pb", pool, factory);
// AnyFallback (both the MessageEquals and MessageFieldEquals variants):
// - Each message holds an Any whose type_url names a type presumably absent
//   from the testing pool.
// - Identical type_url/value pairs are expected to compare equal.
// - A differing `value` is expected to compare unequal.
// - Each check is made in both argument orders.
EXPECT_THAT(MessageEquals(*message1, *message2, pool, factory), IsOkAndHolds(IsTrue())); EXPECT_THAT(MessageEquals(*message2, *message1, pool, factory), IsOkAndHolds(IsTrue())); EXPECT_THAT(MessageEquals(*message1, *message3, pool, factory), IsOkAndHolds(IsFalse())); EXPECT_THAT(MessageEquals(*message3, *message1, pool, factory), IsOkAndHolds(IsFalse())); } TEST(MessageFieldEquals, AnyFallback) { const auto* pool = GetTestingDescriptorPool(); auto* factory = GetTestingMessageFactory(); google::protobuf::Arena arena; auto message1 = DynamicParseTextProto( &arena, R"pb(single_any: { type_url: "type.googleapis.com/message.that.does.not.Exist" value: "foo" })pb", pool, factory); auto message2 = DynamicParseTextProto( &arena, R"pb(single_any: { type_url: "type.googleapis.com/message.that.does.not.Exist" value: "foo" })pb", pool, factory); auto message3 = DynamicParseTextProto( &arena, R"pb(single_any: { type_url: "type.googleapis.com/message.that.does.not.Exist" value: "bar" })pb", pool, factory); EXPECT_THAT(MessageFieldEquals( *message1, ABSL_DIE_IF_NULL( message1->GetDescriptor()->FindFieldByName("single_any")), *message2, ABSL_DIE_IF_NULL( message2->GetDescriptor()->FindFieldByName("single_any")), pool, factory), IsOkAndHolds(IsTrue())); EXPECT_THAT(MessageFieldEquals( *message2, ABSL_DIE_IF_NULL( message2->GetDescriptor()->FindFieldByName("single_any")), *message1, ABSL_DIE_IF_NULL( message1->GetDescriptor()->FindFieldByName("single_any")), pool, factory), IsOkAndHolds(IsTrue())); EXPECT_THAT(MessageFieldEquals( *message1, ABSL_DIE_IF_NULL( message1->GetDescriptor()->FindFieldByName("single_any")), *message3, ABSL_DIE_IF_NULL( message3->GetDescriptor()->FindFieldByName("single_any")), pool, factory), IsOkAndHolds(IsFalse())); EXPECT_THAT(MessageFieldEquals( *message3, ABSL_DIE_IF_NULL( message3->GetDescriptor()->FindFieldByName("single_any")), *message1, ABSL_DIE_IF_NULL( message1->GetDescriptor()->FindFieldByName("single_any")), pool, factory),
IsOkAndHolds(IsFalse())); } } // namespace } // namespace cel::internal ================================================ FILE: internal/message_type_name.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_MESSAGE_TYPE_NAME_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_MESSAGE_TYPE_NAME_H_ #include #include #include "absl/base/no_destructor.h" #include "absl/strings/string_view.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" namespace cel::internal { // MessageTypeNameFor returns the fully qualified message type name of a // generated message. This is a portable version which works with the lite // runtime as well.
template std::enable_if_t< std::conjunction_v, std::negation>>, absl::string_view> MessageTypeNameFor() { static_assert(!std::is_const_v, "T must not be const qualified"); static_assert(!std::is_volatile_v, "T must not be volatile qualified"); static_assert(!std::is_reference_v, "T must not be a reference"); static const absl::NoDestructor kTypeName(T().GetTypeName()); return *kTypeName; } template std::enable_if_t, absl::string_view> MessageTypeNameFor() { static_assert(!std::is_const_v, "T must not be const qualified"); static_assert(!std::is_volatile_v, "T must not be volatile qualified"); static_assert(!std::is_reference_v, "T must not be a reference"); return T::descriptor()->full_name(); } } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_MESSAGE_TYPE_NAME_H_ ================================================ FILE: internal/message_type_name_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "internal/message_type_name.h"

#include "google/protobuf/any.pb.h"
#include "internal/testing.h"

namespace cel::internal {
namespace {

// `google::protobuf::Any` is a generated message; in the full protobuf runtime
// it derives from `google::protobuf::Message`, so this exercises the
// descriptor-based overload. The template argument is required: `T` cannot be
// deduced from a call with no arguments.
TEST(MessageTypeNameFor, Generated) {
  EXPECT_EQ(MessageTypeNameFor<google::protobuf::Any>(), "google.protobuf.Any");
}

}  // namespace
}  // namespace cel::internal

================================================
FILE: internal/minimal_descriptor_database.h
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_MINIMAL_DESCRIPTOR_DATABASE_H_
#define THIRD_PARTY_CEL_CPP_INTERNAL_MINIMAL_DESCRIPTOR_DATABASE_H_

#include "absl/base/nullability.h"
#include "google/protobuf/descriptor_database.h"

namespace cel::internal {

// GetMinimalDescriptorDatabase returns a pointer to a
// `google::protobuf::DescriptorDatabase` which has the minimally necessary
// descriptors required by the Common Expression Language. The returned
// `google::protobuf::DescriptorDatabase` is valid for the lifetime of the
// process.
google::protobuf::DescriptorDatabase* absl_nonnull GetMinimalDescriptorDatabase(); } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_MINIMAL_DESCRIPTOR_DATABASE_H_ ================================================ FILE: internal/minimal_descriptor_pool.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_MINIMAL_DESCRIPTOR_POOL_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_MINIMAL_DESCRIPTOR_POOL_H_ #include "absl/base/nullability.h" #include "absl/status/status.h" #include "google/protobuf/descriptor.h" namespace cel::internal { // GetMinimalDescriptorPool returns a pointer to a `google::protobuf::DescriptorPool` // which includes has the minimally necessary descriptors required by the Common // Expression Language. The returning `google::protobuf::DescriptorPool` is valid for the // lifetime of the process. // // This descriptor pool can be used as an underlay for another descriptor pool: // // google::protobuf::DescriptorPool my_descriptor_pool(GetMinimalDescriptorPool()); const google::protobuf::DescriptorPool* absl_nonnull GetMinimalDescriptorPool(); // If required, adds the minimally required descriptors to the pool. 
absl::Status AddMinimumRequiredDescriptorsToPool( google::protobuf::DescriptorPool* absl_nonnull pool); } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_MINIMAL_DESCRIPTOR_POOL_H_ ================================================ FILE: internal/minimal_descriptors.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "google/protobuf/descriptor.pb.h" #include "absl/base/attributes.h" #include "absl/base/macros.h" #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "internal/minimal_descriptor_database.h" #include "internal/minimal_descriptor_pool.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/descriptor_database.h" namespace cel::internal { namespace { ABSL_CONST_INIT const uint8_t kMinimalDescriptorSet[] = { #include "internal/minimal_descriptor_set_embed.inc" }; const google::protobuf::FileDescriptorSet* GetMinimumFileDescriptorSet() { static google::protobuf::FileDescriptorSet* const file_desc_set = []() { google::protobuf::FileDescriptorSet* file_desc_set = new google::protobuf::FileDescriptorSet(); ABSL_CHECK(file_desc_set->ParseFromArray( // Crash OK kMinimalDescriptorSet, ABSL_ARRAYSIZE(kMinimalDescriptorSet))); 
return file_desc_set; }(); return file_desc_set; } } // namespace const google::protobuf::DescriptorPool* absl_nonnull GetMinimalDescriptorPool() { static const google::protobuf::DescriptorPool* absl_nonnull const pool = []() { const google::protobuf::FileDescriptorSet* file_desc_set = GetMinimumFileDescriptorSet(); auto* pool = new google::protobuf::DescriptorPool(); for (const auto& file_desc : file_desc_set->file()) { ABSL_CHECK(pool->BuildFile(file_desc) != nullptr); // Crash OK } return pool; }(); return pool; } google::protobuf::DescriptorDatabase* absl_nonnull GetMinimalDescriptorDatabase() { static absl::NoDestructor database( *GetMinimalDescriptorPool()); return &*database; } namespace { class DescriptorErrorCollector final : public google::protobuf::DescriptorPool::ErrorCollector { public: void RecordError(absl::string_view, absl::string_view element_name, const google::protobuf::Message*, ErrorLocation, absl::string_view message) override { errors_.push_back(absl::StrCat(element_name, ": ", message)); } bool FoundErrors() const { return !errors_.empty(); } std::string FormatErrors() const { return absl::StrJoin(errors_, "\n\t"); } private: std::vector errors_; }; } // namespace absl::Status AddMinimumRequiredDescriptorsToPool( google::protobuf::DescriptorPool* absl_nonnull pool) { const google::protobuf::FileDescriptorSet* file_desc_set = GetMinimumFileDescriptorSet(); for (const auto& file_desc : file_desc_set->file()) { if (pool->FindFileByName(file_desc.name()) != nullptr) { continue; } DescriptorErrorCollector error_collector; if (pool->BuildFileCollectingErrors(file_desc, &error_collector) == nullptr) { ABSL_DCHECK(error_collector.FoundErrors()); return absl::UnknownError( absl::StrCat("Failed to build file descriptor for ", file_desc.name(), ":\n\t", error_collector.FormatErrors())); } } return absl::OkStatus(); } } // namespace cel::internal ================================================ FILE: internal/names.cc 
================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "internal/names.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "internal/lexis.h" namespace cel::internal { bool IsValidRelativeName(absl::string_view name) { if (name.empty()) { return false; } for (const auto& id : absl::StrSplit(name, '.')) { if (!LexisIsIdentifier(id)) { return false; } } return true; } } // namespace cel::internal ================================================ FILE: internal/names.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_NAMES_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_NAMES_H_ #include "absl/strings/string_view.h" namespace cel::internal { bool IsValidRelativeName(absl::string_view name); } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_NAMES_H_ ================================================ FILE: internal/names_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "internal/names.h" #include "internal/testing.h" namespace cel::internal { namespace { struct NamesTestCase final { absl::string_view text; bool ok; }; using IsValidRelativeNameTest = testing::TestWithParam; TEST_P(IsValidRelativeNameTest, Compliance) { const NamesTestCase& test_case = GetParam(); if (test_case.ok) { EXPECT_TRUE(IsValidRelativeName(test_case.text)); } else { EXPECT_FALSE(IsValidRelativeName(test_case.text)); } } INSTANTIATE_TEST_SUITE_P(IsValidRelativeNameTest, IsValidRelativeNameTest, testing::ValuesIn({{"foo", true}, {"foo.Bar", true}, {"", false}, {".", false}, {".foo", false}, {".foo.Bar", false}, {"foo..Bar", false}, {"foo.Bar.", false}})); } // namespace } // namespace cel::internal ================================================ FILE: internal/new.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the 
License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "internal/new.h" #include #include #include #include #ifdef _MSC_VER #include #endif #include "absl/base/config.h" #include "absl/base/optimization.h" #include "absl/log/absl_check.h" #include "absl/numeric/bits.h" #include "internal/align.h" #if defined(__cpp_aligned_new) && __cpp_aligned_new >= 201606L #define CEL_INTERNAL_HAVE_ALIGNED_NEW 1 #endif #if defined(__cpp_sized_deallocation) && __cpp_sized_deallocation >= 201309L #define CEL_INTERNAL_HAVE_SIZED_DELETE 1 #endif namespace cel::internal { namespace { [[noreturn, maybe_unused]] void ThrowStdBadAlloc() { #ifdef ABSL_HAVE_EXCEPTIONS throw std::bad_alloc(); #else std::abort(); #endif } } // namespace void* New(size_t size) { return ::operator new(size); } void* AlignedNew(size_t size, std::align_val_t alignment) { ABSL_DCHECK(absl::has_single_bit(static_cast(alignment))); #ifdef CEL_INTERNAL_HAVE_ALIGNED_NEW return ::operator new(size, alignment); #else if (static_cast(alignment) <= kDefaultNewAlignment) { return New(size); } #if defined(_MSC_VER) void* ptr = _aligned_malloc(size, static_cast(alignment)); if (ABSL_PREDICT_FALSE(size != 0 && ptr == nullptr)) { ThrowStdBadAlloc(); } return ptr; #elif defined(__APPLE__) void* ptr; if (ABSL_PREDICT_FALSE( posix_memalign(&ptr, static_cast(alignment), size) != 0)) { ThrowStdBadAlloc(); } return ptr; #else void* ptr = std::aligned_alloc(static_cast(alignment), size); if (ABSL_PREDICT_FALSE(size != 0 && ptr == nullptr)) { ThrowStdBadAlloc(); } return ptr; #endif #endif } std::pair SizeReturningNew(size_t size) { return std::pair{::operator 
new(size), size}; } std::pair SizeReturningAlignedNew(size_t size, std::align_val_t alignment) { ABSL_DCHECK(absl::has_single_bit(static_cast(alignment))); #ifdef CEL_INTERNAL_HAVE_ALIGNED_NEW return std::pair{::operator new(size, alignment), size}; #else return std::pair{AlignedNew(size, alignment), size}; #endif } void Delete(void* ptr) noexcept { ::operator delete(ptr); } void SizedDelete(void* ptr, size_t size) noexcept { #ifdef CEL_INTERNAL_HAVE_SIZED_DELETE ::operator delete(ptr, size); #else ::operator delete(ptr); #endif } void AlignedDelete(void* ptr, std::align_val_t alignment) noexcept { ABSL_DCHECK(absl::has_single_bit(static_cast(alignment))); #ifdef CEL_INTERNAL_HAVE_ALIGNED_NEW ::operator delete(ptr, alignment); #else if (static_cast(alignment) <= kDefaultNewAlignment) { ::operator delete(ptr); } else { #if defined(_MSC_VER) _aligned_free(ptr); #else std::free(ptr); #endif } #endif } void SizedAlignedDelete(void* ptr, size_t size, std::align_val_t alignment) noexcept { ABSL_DCHECK(absl::has_single_bit(static_cast(alignment))); #ifdef CEL_INTERNAL_HAVE_ALIGNED_NEW #ifdef CEL_INTERNAL_HAVE_SIZED_DELETE ::operator delete(ptr, size, alignment); #else ::operator delete(ptr, alignment); #endif #else AlignedDelete(ptr, alignment); #endif } } // namespace cel::internal ================================================ FILE: internal/new.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_NEW_H_
#define THIRD_PARTY_CEL_CPP_INTERNAL_NEW_H_

#include
#include
#include

namespace cel::internal {

// Alignment guaranteed by the plain (non-aligned) allocation functions below:
// __STDCPP_DEFAULT_NEW_ALIGNMENT__ when the compiler defines it, otherwise
// alignof(std::max_align_t).
inline constexpr size_t kDefaultNewAlignment =
#ifdef __STDCPP_DEFAULT_NEW_ALIGNMENT__
    __STDCPP_DEFAULT_NEW_ALIGNMENT__
#else
    alignof(std::max_align_t)
#endif
    ;  // NOLINT(whitespace/semicolon)

// Allocates memory which has a size of at least `size` and a minimum alignment
// of `kDefaultNewAlignment`.
void* New(size_t size);

// Allocates memory which has a size of at least `size` and a minimum alignment
// of `alignment`. To deallocate, the caller must use `AlignedDelete` or
// `SizedAlignedDelete`.
void* AlignedNew(size_t size, std::align_val_t alignment);

// Allocates memory which has a size of at least `size` and a minimum alignment
// of `kDefaultNewAlignment`, returning the pointer together with the usable
// allocation size (in the current implementation, exactly `size`). To
// deallocate, pass the returned size to `SizedDelete`, or use `Delete`.
std::pair SizeReturningNew(size_t size);

// Allocates memory which has a size of at least `size` and a minimum alignment
// of `alignment`, returns a pointer to the allocated memory and the actual
// usable allocation size. To deallocate, the caller must use `AlignedDelete` or
// `SizedAlignedDelete`.
std::pair SizeReturningAlignedNew(size_t size,
                                                std::align_val_t alignment);

// Releases memory obtained from `New` or `SizeReturningNew`.
void Delete(void* ptr) noexcept;

// Like `Delete`, but also receives the allocation size, which is forwarded to
// sized deallocation when the compiler supports it.
void SizedDelete(void* ptr, size_t size) noexcept;

// Releases memory obtained from `AlignedNew` or `SizeReturningAlignedNew`.
// `alignment` must be the value used for the allocation.
void AlignedDelete(void* ptr, std::align_val_t alignment) noexcept;

// Like `AlignedDelete`, but also receives the allocation size, which is
// forwarded to sized deallocation when the compiler supports it.
void SizedAlignedDelete(void* ptr, size_t size,
                        std::align_val_t alignment) noexcept;

}  // namespace cel::internal

#endif  // THIRD_PARTY_CEL_CPP_INTERNAL_NEW_H_
================================================
FILE: internal/new_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "internal/new.h" #include #include #include #include #include "internal/testing.h" namespace cel::internal { namespace { using ::testing::Ge; using ::testing::NotNull; TEST(New, Basic) { void* p = New(sizeof(uint64_t)); EXPECT_THAT(p, NotNull()); Delete(p); } TEST(AlignedNew, Basic) { void* p = AlignedNew(alignof(std::max_align_t) * 2, static_cast(alignof(std::max_align_t) * 2)); EXPECT_THAT(p, NotNull()); AlignedDelete(p, static_cast(alignof(std::max_align_t) * 2)); } TEST(SizeReturningNew, Basic) { void* p; size_t n; std::tie(p, n) = SizeReturningNew(sizeof(uint64_t)); EXPECT_THAT(p, NotNull()); EXPECT_THAT(n, Ge(sizeof(uint64_t))); SizedDelete(p, n); } TEST(SizeReturningAlignedNew, Basic) { void* p; size_t n; std::tie(p, n) = SizeReturningAlignedNew( alignof(std::max_align_t) * 2, static_cast(alignof(std::max_align_t) * 2)); EXPECT_THAT(p, NotNull()); EXPECT_THAT(n, Ge(alignof(std::max_align_t) * 2)); SizedAlignedDelete( p, n, static_cast(alignof(std::max_align_t) * 2)); } } // namespace } // namespace cel::internal ================================================ FILE: internal/noop_delete.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_NOOP_DELETE_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_NOOP_DELETE_H_ #include #include "absl/base/nullability.h" namespace cel::internal { // Like `std::default_delete`, except it does nothing. template struct NoopDelete { static_assert(!std::is_function::value, "NoopDelete cannot be instantiated for function types"); constexpr NoopDelete() noexcept = default; constexpr NoopDelete(const NoopDelete&) noexcept = default; template < typename U, typename = std::enable_if_t>, std::is_convertible>>> // NOLINTNEXTLINE(google-explicit-constructor) constexpr NoopDelete(const NoopDelete&) noexcept {} constexpr void operator()(T* absl_nullable) const noexcept { static_assert(sizeof(T) >= 0, "cannot delete an incomplete type"); static_assert(!std::is_void::value, "cannot delete an incomplete type"); } }; template inline constexpr NoopDelete NoopDeleteFor() noexcept { return NoopDelete{}; } } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_NOOP_DELETE_H_ ================================================ FILE: internal/number.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_NUMBER_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_NUMBER_H_ #include #include #include "absl/types/variant.h" namespace cel::internal { constexpr int64_t kInt64Max = std::numeric_limits::max(); constexpr int64_t kInt64Min = std::numeric_limits::lowest(); constexpr uint64_t kUint64Max = std::numeric_limits::max(); constexpr uint64_t kUintToIntMax = static_cast(kInt64Max); constexpr double kDoubleToIntMax = static_cast(kInt64Max); constexpr double kDoubleToIntMin = static_cast(kInt64Min); constexpr double kDoubleToUintMax = static_cast(kUint64Max); // The highest integer values that are round-trippable after rounding and // casting to double. template constexpr int RoundingError() { return 1 << (std::numeric_limits::digits - std::numeric_limits::digits - 1); } constexpr double kMaxDoubleRepresentableAsInt = static_cast(kInt64Max - RoundingError()); constexpr double kMaxDoubleRepresentableAsUint = static_cast(kUint64Max - RoundingError()); #define CEL_ABSL_VISIT_CONSTEXPR using NumberVariant = absl::variant; enum class ComparisonResult { kLesser, kEqual, kGreater, // Special case for nan. kNanInequal }; // Return the inverse relation (i.e. Invert(cmp(b, a)) is the same as cmp(a, b). 
constexpr ComparisonResult Invert(ComparisonResult result) { switch (result) { case ComparisonResult::kLesser: return ComparisonResult::kGreater; case ComparisonResult::kGreater: return ComparisonResult::kLesser; case ComparisonResult::kEqual: return ComparisonResult::kEqual; case ComparisonResult::kNanInequal: return ComparisonResult::kNanInequal; } } template struct ConversionVisitor { template constexpr OutType operator()(InType v) { return static_cast(v); } }; template constexpr ComparisonResult Compare(T a, T b) { return (a > b) ? ComparisonResult::kGreater : (a == b) ? ComparisonResult::kEqual : ComparisonResult::kLesser; } constexpr ComparisonResult DoubleCompare(double a, double b) { // constexpr friendly isnan check. if (!(a == a) || !(b == b)) { return ComparisonResult::kNanInequal; } return Compare(a, b); } // Implement generic numeric comparison against double value. struct DoubleCompareVisitor { constexpr explicit DoubleCompareVisitor(double v) : v(v) {} constexpr ComparisonResult operator()(double other) const { return DoubleCompare(v, other); } constexpr ComparisonResult operator()(uint64_t other) const { if (v > kDoubleToUintMax) { return ComparisonResult::kGreater; } else if (v < 0) { return ComparisonResult::kLesser; } else { return DoubleCompare(v, static_cast(other)); } } constexpr ComparisonResult operator()(int64_t other) const { if (v > kDoubleToIntMax) { return ComparisonResult::kGreater; } else if (v < kDoubleToIntMin) { return ComparisonResult::kLesser; } else { return DoubleCompare(v, static_cast(other)); } } double v; }; // Implement generic numeric comparison against uint value. // Delegates to double comparison if either variable is double. 
struct UintCompareVisitor {
  constexpr explicit UintCompareVisitor(uint64_t v) : v(v) {}

  constexpr ComparisonResult operator()(double other) const {
    return Invert(DoubleCompareVisitor(other)(v));
  }

  constexpr ComparisonResult operator()(uint64_t other) const {
    return Compare(v, other);
  }

  // A uint64 above the int64 range, or any uint64 compared with a negative
  // int64, is greater; otherwise both values fit in uint64 and compare exactly.
  constexpr ComparisonResult operator()(int64_t other) const {
    if (v > kUintToIntMax || other < 0) {
      return ComparisonResult::kGreater;
    } else {
      return Compare(v, static_cast(other));
    }
  }

  uint64_t v;
};

// Implement generic numeric comparison against int value.
// Delegates to uint / double if either value is uint / double.
struct IntCompareVisitor {
  constexpr explicit IntCompareVisitor(int64_t v) : v(v) {}

  constexpr ComparisonResult operator()(double other) {
    return Invert(DoubleCompareVisitor(other)(v));
  }

  constexpr ComparisonResult operator()(uint64_t other) {
    return Invert(UintCompareVisitor(other)(v));
  }

  constexpr ComparisonResult operator()(int64_t other) {
    return Compare(v, other);
  }

  int64_t v;
};

// Visited with the left-hand NumberVariant. For each left-hand alternative it
// visits `rhs` with the matching typed visitor above.
struct CompareVisitor {
  explicit constexpr CompareVisitor(NumberVariant rhs) : rhs(rhs) {}

  CEL_ABSL_VISIT_CONSTEXPR ComparisonResult operator()(double v) {
    return absl::visit(DoubleCompareVisitor(v), rhs);
  }
  CEL_ABSL_VISIT_CONSTEXPR ComparisonResult operator()(uint64_t v) {
    return absl::visit(UintCompareVisitor(v), rhs);
  }
  CEL_ABSL_VISIT_CONSTEXPR ComparisonResult operator()(int64_t v) {
    return absl::visit(IntCompareVisitor(v), rhs);
  }

  NumberVariant rhs;
};

// Reports whether a value converts to int64_t without loss. A double must lie
// within [kDoubleToIntMin, kMaxDoubleRepresentableAsInt] and survive a
// round-trip cast unchanged (so NaN and fractional values are rejected). A
// uint64 must not exceed the int64 maximum.
struct LosslessConvertibleToIntVisitor {
  constexpr bool operator()(double value) const {
    return value >= kDoubleToIntMin && value <= kMaxDoubleRepresentableAsInt &&
           value == static_cast(static_cast(value));
  }
  constexpr bool operator()(uint64_t value) const {
    return value <= kUintToIntMax;
  }
  constexpr bool operator()(int64_t value) const { return true; }
};

// Reports whether a value converts to uint64_t without loss. A double must lie
// within [0, kMaxDoubleRepresentableAsUint] and survive a round-trip cast
// unchanged. An int64 must be non-negative.
struct LosslessConvertibleToUintVisitor {
  constexpr bool operator()(double value) const {
    return value >= 0 && value <= kMaxDoubleRepresentableAsUint &&
           value == static_cast(static_cast(value));
  }
  constexpr bool operator()(uint64_t value) const { return true; }
  constexpr bool operator()(int64_t value) const { return value >= 0; }
};

// Utility class for CEL number operations.
//
// In CEL expressions, comparisons between different numeric types are treated
// as all happening on the same continuous number line. This generally means
// that integers and doubles in convertible range are compared after converting
// to doubles (tolerating some loss of precision).
//
// This extends to key lookups -- {1: 'abc'}[1.0f] is expected to work since
// 1.0 == 1 in CEL.
//
// NaN is unordered: when either operand is NaN, ==, <, <=, > and >= all return
// false and != returns true.
class Number {
 public:
  // Factories to resolve ambiguous overload resolution against literals.
  static constexpr Number FromInt64(int64_t value) { return Number(value); }
  static constexpr Number FromUint64(uint64_t value) { return Number(value); }
  static constexpr Number FromDouble(double value) { return Number(value); }

  constexpr explicit Number(double double_value) : value_(double_value) {}
  constexpr explicit Number(int64_t int_value) : value_(int_value) {}
  constexpr explicit Number(uint64_t uint_value) : value_(uint_value) {}

  // Return a double representation of the value.
  CEL_ABSL_VISIT_CONSTEXPR double AsDouble() const {
    return absl::visit(internal::ConversionVisitor(), value_);
  }

  // Return signed int64 representation for the value.
  // Caller must guarantee the underlying value is representable as an
  // int.
  CEL_ABSL_VISIT_CONSTEXPR int64_t AsInt() const {
    return absl::visit(internal::ConversionVisitor(), value_);
  }

  // Return unsigned int64 representation for the value.
  // Caller must guarantee the underlying value is representable as an
  // uint.
  CEL_ABSL_VISIT_CONSTEXPR uint64_t AsUint() const {
    return absl::visit(internal::ConversionVisitor(), value_);
  }

  // For key lookups, check if the conversion to signed int is lossless.
  CEL_ABSL_VISIT_CONSTEXPR bool LosslessConvertibleToInt() const {
    return absl::visit(internal::LosslessConvertibleToIntVisitor(), value_);
  }

  // For key lookups, check if the conversion to unsigned int is lossless.
  CEL_ABSL_VISIT_CONSTEXPR bool LosslessConvertibleToUint() const {
    return absl::visit(internal::LosslessConvertibleToUintVisitor(), value_);
  }

  CEL_ABSL_VISIT_CONSTEXPR bool operator<(Number other) const {
    return Compare(other) == internal::ComparisonResult::kLesser;
  }

  // kNanInequal is excluded explicitly so that NaN operands yield false.
  CEL_ABSL_VISIT_CONSTEXPR bool operator<=(Number other) const {
    internal::ComparisonResult cmp = Compare(other);
    return cmp != internal::ComparisonResult::kGreater &&
           cmp != internal::ComparisonResult::kNanInequal;
  }

  CEL_ABSL_VISIT_CONSTEXPR bool operator>(Number other) const {
    return Compare(other) == internal::ComparisonResult::kGreater;
  }

  // kNanInequal is excluded explicitly so that NaN operands yield false.
  CEL_ABSL_VISIT_CONSTEXPR bool operator>=(Number other) const {
    internal::ComparisonResult cmp = Compare(other);
    return cmp != internal::ComparisonResult::kLesser &&
           cmp != internal::ComparisonResult::kNanInequal;
  }

  CEL_ABSL_VISIT_CONSTEXPR bool operator==(Number other) const {
    return Compare(other) == internal::ComparisonResult::kEqual;
  }

  CEL_ABSL_VISIT_CONSTEXPR bool operator!=(Number other) const {
    return Compare(other) != internal::ComparisonResult::kEqual;
  }

  // Visit the underlying number representation, a variant of double, uint64_t,
  // or int64_t.
  template
  T visit(Op&& op) const {
    return absl::visit(std::forward(op), value_);
  }

 private:
  internal::NumberVariant value_;

  // Three-way comparison of this value (left-hand side) against `other`.
  CEL_ABSL_VISIT_CONSTEXPR internal::ComparisonResult Compare(
      Number other) const {
    return absl::visit(internal::CompareVisitor(other.value_), value_);
  }
};

}  // namespace cel::internal

#endif  // THIRD_PARTY_CEL_CPP_INTERNAL_NUMBER_H_
================================================
FILE: internal/number_test.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "internal/number.h"

#include
#include

#include "internal/testing.h"

namespace cel::internal {
namespace {

// Mixed-type comparisons behave as if all values lie on one number line.
TEST(Number, Basic) {
  EXPECT_GT(Number(1.1), Number::FromInt64(1));
  EXPECT_LT(Number::FromUint64(1), Number(1.1));
  EXPECT_EQ(Number(1.1), Number(1.1));
  EXPECT_EQ(Number::FromUint64(1), Number::FromUint64(1));
  EXPECT_EQ(Number::FromInt64(1), Number::FromUint64(1));
  EXPECT_GT(Number::FromUint64(1), Number::FromInt64(-1));
  EXPECT_EQ(Number::FromInt64(-1), Number::FromInt64(-1));
}

TEST(Number, Conversions) {
  EXPECT_TRUE(Number::FromDouble(1.0).LosslessConvertibleToInt());
  EXPECT_TRUE(Number::FromDouble(1.0).LosslessConvertibleToUint());
  EXPECT_FALSE(Number::FromDouble(1.1).LosslessConvertibleToInt());
  EXPECT_FALSE(Number::FromDouble(1.1).LosslessConvertibleToUint());
  EXPECT_TRUE(Number::FromDouble(-1.0).LosslessConvertibleToInt());
  EXPECT_FALSE(Number::FromDouble(-1.0).LosslessConvertibleToUint());
  EXPECT_TRUE(Number::FromDouble(kDoubleToIntMin).LosslessConvertibleToInt());

  // Need to add/subtract a large number since double resolution is low at this
  // range. (Below -2^63 adjacent doubles are 2048 apart, so subtracting 1025
  // rounds to the next lower double rather than back to kDoubleToIntMin.)
  EXPECT_FALSE(Number::FromDouble(kMaxDoubleRepresentableAsUint +
                                  RoundingError())
                   .LosslessConvertibleToUint());
  EXPECT_FALSE(Number::FromDouble(kMaxDoubleRepresentableAsInt +
                                  RoundingError())
                   .LosslessConvertibleToInt());
  EXPECT_FALSE(
      Number::FromDouble(kDoubleToIntMin - 1025).LosslessConvertibleToInt());

  EXPECT_EQ(Number::FromInt64(1).AsUint(), 1u);
  EXPECT_EQ(Number::FromUint64(1).AsInt(), 1);
  EXPECT_EQ(Number::FromDouble(1.0).AsUint(), 1);
  EXPECT_EQ(Number::FromDouble(1.0).AsInt(), 1);
}

}  // namespace
}  // namespace cel::internal
================================================
FILE: internal/overflow.cc
================================================
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "internal/overflow.h"

#include
#include
#include

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/time/time.h"
#include "internal/status_macros.h"
#include "internal/time.h"

namespace cel::internal {
namespace {

constexpr int64_t kInt32Max = std::numeric_limits::max();
constexpr int64_t kInt32Min = std::numeric_limits::lowest();
constexpr int64_t kInt64Max = std::numeric_limits::max();
constexpr int64_t kInt64Min = std::numeric_limits::lowest();
constexpr uint64_t kUint32Max = std::numeric_limits::max();
// Only referenced by the non-builtin (#else) overflow checks below, so it is
// unused when the compiler provides __builtin_*_overflow.
ABSL_ATTRIBUTE_UNUSED constexpr uint64_t kUint64Max =
    std::numeric_limits::max();
constexpr uint64_t kUintToIntMax = static_cast(kInt64Max);
constexpr double kDoubleToIntMax = static_cast(kInt64Max);
constexpr double kDoubleToIntMin = static_cast(kInt64Min);
const double kDoubleTwoTo64 = std::ldexp(1.0, 64);  // 1.0 * 2^64

const absl::Duration kOneSecondDuration = absl::Seconds(1);
const int64_t kOneSecondNanos = absl::ToInt64Nanoseconds(kOneSecondDuration);
// Number of seconds between `0001-01-01T00:00:00Z` and Unix epoch.
const int64_t kMinUnixTime =
    absl::ToInt64Seconds(MinTimestamp() - absl::UnixEpoch());

// Number of seconds between `9999-12-31T23:59:59.999999999Z` and Unix epoch.
const int64_t kMaxUnixTime =
    absl::ToInt64Seconds(MaxTimestamp() - absl::UnixEpoch());

// Returns OkStatus if `valid_expression` holds, otherwise an OutOfRange error
// carrying `error_message`.
absl::Status CheckRange(bool valid_expression,
                        absl::string_view error_message) {
  return valid_expression ? absl::OkStatus()
                          : absl::OutOfRangeError(error_message);
}

// Returns OkStatus if `valid_expression` holds, otherwise an InvalidArgument
// error carrying `error_message`.
absl::Status CheckArgument(bool valid_expression,
                           absl::string_view error_message) {
  return valid_expression ? absl::OkStatus()
                          : absl::InvalidArgumentError(error_message);
}

// Determine whether the duration is finite.
bool IsFinite(absl::Duration d) {
  return d != absl::InfiniteDuration() && d != -absl::InfiniteDuration();
}

// Determine whether the time is finite.
bool IsFinite(absl::Time t) {
  return t != absl::InfiniteFuture() && t != absl::InfinitePast();
}

}  // namespace

// The integer operations below prefer the compiler's overflow-checking
// builtins. The #else branches instead pre-check the operands with expressions
// that cannot themselves overflow, and only then perform the operation.

absl::StatusOr CheckedAdd(int64_t x, int64_t y) {
#if ABSL_HAVE_BUILTIN(__builtin_add_overflow)
  int64_t sum;
  if (!__builtin_add_overflow(x, y, &sum)) {
    return sum;
  }
  return absl::OutOfRangeError("integer overflow");
#else
  CEL_RETURN_IF_ERROR(CheckRange(
      y > 0 ? x <= kInt64Max - y : x >= kInt64Min - y, "integer overflow"));
  return x + y;
#endif
}

absl::StatusOr CheckedSub(int64_t x, int64_t y) {
#if ABSL_HAVE_BUILTIN(__builtin_sub_overflow)
  int64_t diff;
  if (!__builtin_sub_overflow(x, y, &diff)) {
    return diff;
  }
  return absl::OutOfRangeError("integer overflow");
#else
  CEL_RETURN_IF_ERROR(CheckRange(
      y < 0 ? x <= kInt64Max + y : x >= kInt64Min + y, "integer overflow"));
  return x - y;
#endif
}

// Negation overflows only for kInt64Min.
absl::StatusOr CheckedNegation(int64_t v) {
#if ABSL_HAVE_BUILTIN(__builtin_mul_overflow)
  int64_t prod;
  if (!__builtin_mul_overflow(v, -1, &prod)) {
    return prod;
  }
  return absl::OutOfRangeError("integer overflow");
#else
  CEL_RETURN_IF_ERROR(CheckRange(v != kInt64Min, "integer overflow"));
  return -v;
#endif
}

absl::StatusOr CheckedMul(int64_t x, int64_t y) {
#if ABSL_HAVE_BUILTIN(__builtin_mul_overflow)
  int64_t prod;
  if (!__builtin_mul_overflow(x, y, &prod)) {
    return prod;
  }
  return absl::OutOfRangeError("integer overflow");
#else
  CEL_RETURN_IF_ERROR(
      CheckRange(!((x == -1 && y == kInt64Min) || (y == -1 && x == kInt64Min) ||
                   (x > 0 && y > 0 && x > kInt64Max / y) ||
                   (x < 0 && y < 0 && x < kInt64Max / y) ||
                   // Avoid dividing kInt64Min by -1, use whichever value of x
                   // or y is positive as the divisor.
                   (x > 0 && y < 0 && y < kInt64Min / x) ||
                   (x < 0 && y > 0 && x < kInt64Min / y)),
                 "integer overflow"));
  return x * y;
#endif
}

// kInt64Min / -1 overflows; a zero divisor is reported as InvalidArgument.
absl::StatusOr CheckedDiv(int64_t x, int64_t y) {
  CEL_RETURN_IF_ERROR(
      CheckRange(x != kInt64Min || y != -1, "integer overflow"));
  CEL_RETURN_IF_ERROR(CheckArgument(y != 0, "divide by zero"));
  return x / y;
}

// kInt64Min % -1 is rejected as overflow (the corresponding quotient overflows);
// a zero divisor is reported as InvalidArgument.
absl::StatusOr CheckedMod(int64_t x, int64_t y) {
  CEL_RETURN_IF_ERROR(
      CheckRange(x != kInt64Min || y != -1, "integer overflow"));
  CEL_RETURN_IF_ERROR(CheckArgument(y != 0, "modulus by zero"));
  return x % y;
}

absl::StatusOr CheckedAdd(uint64_t x, uint64_t y) {
#if ABSL_HAVE_BUILTIN(__builtin_add_overflow)
  uint64_t sum;
  if (!__builtin_add_overflow(x, y, &sum)) {
    return sum;
  }
  return absl::OutOfRangeError("unsigned integer overflow");
#else
  CEL_RETURN_IF_ERROR(
      CheckRange(x <= kUint64Max - y, "unsigned integer overflow"));
  return x + y;
#endif
}

absl::StatusOr CheckedSub(uint64_t x, uint64_t y) {
#if ABSL_HAVE_BUILTIN(__builtin_sub_overflow)
  uint64_t diff;
  if (!__builtin_sub_overflow(x, y, &diff)) {
    return diff;
  }
  return absl::OutOfRangeError("unsigned integer overflow");
#else
  CEL_RETURN_IF_ERROR(CheckRange(y <= x, "unsigned integer overflow"));
  return x - y;
#endif
}

absl::StatusOr CheckedMul(uint64_t x, uint64_t y) {
#if ABSL_HAVE_BUILTIN(__builtin_mul_overflow)
  uint64_t prod;
  if (!__builtin_mul_overflow(x, y, &prod)) {
    return prod;
  }
  return absl::OutOfRangeError("unsigned integer overflow");
#else
  CEL_RETURN_IF_ERROR(
      CheckRange(y == 0 || x <= kUint64Max / y, "unsigned integer overflow"));
  return x * y;
#endif
}

absl::StatusOr CheckedDiv(uint64_t x, uint64_t y) {
  CEL_RETURN_IF_ERROR(CheckArgument(y != 0, "divide by zero"));
  return x / y;
}

absl::StatusOr CheckedMod(uint64_t x, uint64_t y) {
  CEL_RETURN_IF_ERROR(CheckArgument(y != 0, "modulus by zero"));
  return x % y;
}

absl::StatusOr CheckedAdd(absl::Duration x, absl::Duration y) {
  CEL_RETURN_IF_ERROR(
      CheckRange(IsFinite(x) && IsFinite(y), "integer overflow"));
  // absl::Duration can handle +- infinite durations, but the Go time.Duration
  // implementation caps the durations to those expressible within a single
  // int64 rather than (seconds int64, nanos int32).
  //
  // The absl implementation mirrors the protobuf implementation which supports
  // durations on the order of +- 10,000 years, but Go only supports +- 290 year
  // durations.
  //
  // Since Go is the more conservative of the implementations and 290 year
  // durations seem quite reasonable, this code mirrors the conservative
  // overflow behavior which would be observed in Go.
  CEL_ASSIGN_OR_RETURN(int64_t nanos,
                       CheckedAdd(absl::ToInt64Nanoseconds(x),
                                  absl::ToInt64Nanoseconds(y)));
  return absl::Nanoseconds(nanos);
}

// Same int64 nanosecond-based semantics as CheckedAdd(Duration, Duration).
absl::StatusOr CheckedSub(absl::Duration x, absl::Duration y) {
  CEL_RETURN_IF_ERROR(
      CheckRange(IsFinite(x) && IsFinite(y), "integer overflow"));
  CEL_ASSIGN_OR_RETURN(int64_t nanos,
                       CheckedSub(absl::ToInt64Nanoseconds(x),
                                  absl::ToInt64Nanoseconds(y)));
  return absl::Nanoseconds(nanos);
}

absl::StatusOr CheckedNegation(absl::Duration v) {
  CEL_RETURN_IF_ERROR(CheckRange(IsFinite(v), "integer overflow"));
  CEL_ASSIGN_OR_RETURN(int64_t nanos,
                       CheckedNegation(absl::ToInt64Nanoseconds(v)));
  return absl::Nanoseconds(nanos);
}

absl::StatusOr CheckedAdd(absl::Time t, absl::Duration d) {
  CEL_RETURN_IF_ERROR(
      CheckRange(IsFinite(t) && IsFinite(d), "timestamp overflow"));
  // First we break time into its components by truncating and subtracting.
  const int64_t s1 = absl::ToUnixSeconds(t);
  const int64_t ns1 = (t - absl::FromUnixSeconds(s1)) / absl::Nanoseconds(1);

  // Second we break duration into its components by dividing and modulo.
  // Truncate to seconds.
  const int64_t s2 = d / kOneSecondDuration;
  // Get remainder.
  const int64_t ns2 = absl::ToInt64Nanoseconds(d % kOneSecondDuration);

  // Add seconds first, detecting any overflow.
  CEL_ASSIGN_OR_RETURN(int64_t s, CheckedAdd(s1, s2));
  // Nanoseconds cannot overflow as nanos are normalized to [0, 999999999].
  absl::Duration ns = absl::Nanoseconds(ns2 + ns1);

  // Normalize nanoseconds to be positive and carry extra nanos to seconds.
  if (ns < absl::ZeroDuration() || ns >= kOneSecondDuration) {
    // Add seconds, or no-op if nanoseconds negative (ns never < -999_999_999ns)
    CEL_ASSIGN_OR_RETURN(s, CheckedAdd(s, ns / kOneSecondDuration));
    ns -= (ns / kOneSecondDuration) * kOneSecondDuration;
    // Subtract a second to make the nanos positive.
    if (ns < absl::ZeroDuration()) {
      CEL_ASSIGN_OR_RETURN(s, CheckedAdd(s, -1));
      ns += kOneSecondDuration;
    }
  }
  // Check if the number of seconds from Unix epoch is within our acceptable
  // range.
  CEL_RETURN_IF_ERROR(
      CheckRange(s >= kMinUnixTime && s <= kMaxUnixTime, "timestamp overflow"));

  // Return resulting time.
  return absl::FromUnixSeconds(s) + ns;
}

// Implemented as t + (-d); negating d applies the Duration negation checks.
absl::StatusOr CheckedSub(absl::Time t, absl::Duration d) {
  CEL_ASSIGN_OR_RETURN(auto neg_duration, CheckedNegation(d));
  return CheckedAdd(t, neg_duration);
}

absl::StatusOr CheckedSub(absl::Time t1, absl::Time t2) {
  CEL_RETURN_IF_ERROR(
      CheckRange(IsFinite(t1) && IsFinite(t2), "integer overflow"));
  // First we break time into its components by truncating and subtracting.
  const int64_t s1 = absl::ToUnixSeconds(t1);
  const int64_t ns1 = (t1 - absl::FromUnixSeconds(s1)) / absl::Nanoseconds(1);
  const int64_t s2 = absl::ToUnixSeconds(t2);
  const int64_t ns2 = (t2 - absl::FromUnixSeconds(s2)) / absl::Nanoseconds(1);

  // Subtract seconds first, detecting any overflow.
  CEL_ASSIGN_OR_RETURN(int64_t s, CheckedSub(s1, s2));
  // Nanoseconds cannot overflow as nanos are normalized to [0, 999999999].
  absl::Duration ns = absl::Nanoseconds(ns1 - ns2);

  // Scale the seconds result to nanos.
  CEL_ASSIGN_OR_RETURN(const int64_t t, CheckedMul(s, kOneSecondNanos));
  // Add the seconds (scaled to nanos) to the nanosecond value.
  CEL_ASSIGN_OR_RETURN(const int64_t v,
                       CheckedAdd(t, absl::ToInt64Nanoseconds(ns)));
  return absl::Nanoseconds(v);
}

absl::StatusOr CheckedDoubleToInt64(double v) {
  // Both bounds are strict. kDoubleToIntMax is INT64_MAX rounded to double,
  // i.e. 2^63, which does not fit in int64_t. Note that the strict lower bound
  // also rejects exactly -2^63.
  CEL_RETURN_IF_ERROR(
      CheckRange(std::isfinite(v) && v < kDoubleToIntMax && v > kDoubleToIntMin,
                 "double out of int64 range"));
  return static_cast(v);
}

// Accepts finite values in [0, 2^64).
absl::StatusOr CheckedDoubleToUint64(double v) {
  CEL_RETURN_IF_ERROR(
      CheckRange(std::isfinite(v) && v >= 0 && v < kDoubleTwoTo64,
                 "double out of uint64 range"));
  return static_cast(v);
}

absl::StatusOr CheckedInt64ToUint64(int64_t v) {
  CEL_RETURN_IF_ERROR(CheckRange(v >= 0, "int64 out of uint64 range"));
  return static_cast(v);
}

absl::StatusOr CheckedInt64ToInt32(int64_t v) {
  CEL_RETURN_IF_ERROR(
      CheckRange(v >= kInt32Min && v <= kInt32Max, "int64 out of int32 range"));
  return static_cast(v);
}

absl::StatusOr CheckedUint64ToInt64(uint64_t v) {
  CEL_RETURN_IF_ERROR(
      CheckRange(v <= kUintToIntMax, "uint64 out of int64 range"));
  return static_cast(v);
}

absl::StatusOr CheckedUint64ToUint32(uint64_t v) {
  CEL_RETURN_IF_ERROR(
      CheckRange(v <= kUint32Max, "uint64 out of uint32 range"));
  return static_cast(v);
}

}  // namespace cel::internal
================================================
FILE: internal/overflow.h
================================================
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_OVERFLOW_H_ #define THIRD_PARTY_CEL_CPP_COMMON_OVERFLOW_H_ #include #include "absl/status/statusor.h" #include "absl/time/time.h" namespace cel::internal { // Add two int64_t values together. // If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. // int64_t_max + 1 absl::StatusOr CheckedAdd(int64_t x, int64_t y); // Subtract two int64_t values from each other. // If overflow is detected, return an absl::StatusCode::kOutOfRangeError. e.g. // int64_t_min - 1 absl::StatusOr CheckedSub(int64_t x, int64_t y); // Negate an int64_t value. // If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. // negate(int64_t_min) absl::StatusOr CheckedNegation(int64_t v); // Multiply two int64_t values together. // If overflow is detected, return an absl::StatusCode::kOutOfRangeError. e.g. // 2 * int64_t_max absl::StatusOr CheckedMul(int64_t x, int64_t y); // Divide one int64_t value into another. // If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. // int64_t_min / -1 absl::StatusOr CheckedDiv(int64_t x, int64_t y); // Compute the modulus of x into y. // If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. // int64_t_min % -1 absl::StatusOr CheckedMod(int64_t x, int64_t y); // Add two uint64_t values together. // If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. // uint64_t_max + 1 absl::StatusOr CheckedAdd(uint64_t x, uint64_t y); // Subtract two uint64_t values from each other. // If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. // 1 - uint64_t_max absl::StatusOr CheckedSub(uint64_t x, uint64_t y); // Multiply two uint64_t values together. // If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. // 2 * uint64_t_max absl::StatusOr CheckedMul(uint64_t x, uint64_t y); // Divide one uint64_t value into another. 
absl::StatusOr CheckedDiv(uint64_t x, uint64_t y); // Compute the modulus of x into y. // If 'y' is zero, the function will return an // absl::StatusCode::kInvalidArgumentError, e.g. 1 / 0. absl::StatusOr CheckedMod(uint64_t x, uint64_t y); // Add two durations together. // If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. // duration(int64_t_max, "ns") + duration(int64_t_max, "ns") // // Note, absl::Duration is effectively an int64_t under the covers, which means // the same cases that would result in overflow for int64_t values would hold // true for absl::Duration values. absl::StatusOr CheckedAdd(absl::Duration x, absl::Duration y); // Subtract two durations from each other. // If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. // duration(int64_t_min, "ns") - duration(1, "ns") // // Note, absl::Duration is effectively an int64_t under the covers, which means // the same cases that would result in overflow for int64_t values would hold // true for absl::Duration values. absl::StatusOr CheckedSub(absl::Duration x, absl::Duration y); // Negate a duration. // If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. // negate(duration(int64_t_min, "ns")). absl::StatusOr CheckedNegation(absl::Duration v); // Add an absl::Time and absl::Duration value together. // If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. // timestamp(unix_epoch_max) + duration(1, "ns") // // Valid time values must be between `0001-01-01T00:00:00Z` (-62135596800s) and // `9999-12-31T23:59:59.999999999Z` (253402300799s). absl::StatusOr CheckedAdd(absl::Time t, absl::Duration d); // Subtract an absl::Time and absl::Duration value together. // If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. // timestamp(unix_epoch_min) - duration(1, "ns") // // Valid time values must be between `0001-01-01T00:00:00Z` (-62135596800s) and // `9999-12-31T23:59:59.999999999Z` (253402300799s). 
absl::StatusOr CheckedSub(absl::Time t, absl::Duration d); // Subtract two absl::Time values from each other to produce an absl::Duration. // If overflow is detected, return an absl::StatusCode::kOutOfRangeError, e.g. // timestamp(unix_epoch_min) - timestamp(unix_epoch_max) absl::StatusOr CheckedSub(absl::Time t1, absl::Time t2); // Convert a double value to an int64_t if possible. // If the double exceeds the values representable in an int64_t the function // will return an absl::StatusCode::kOutOfRangeError. // // Only finite double values may be converted to an int64_t. CEL may also reject // some conversions if the value falls into a range where overflow would be // ambiguous. // // The behavior of the static_cast(double) assembly instruction on // x86 (cvttsd2si) can be manipulated by the header: // https://en.cppreference.com/w/cpp/numeric/fenv/feround. This means that the // set of values which will result in a valid or invalid conversion are // environment dependent and the implementation must err on the side of caution // and reject possibly valid values which might be invalid based on environment // settings. absl::StatusOr CheckedDoubleToInt64(double v); // Convert a double value to a uint64_t if possible. // If the double exceeds the values representable in a uint64_t the function // will return an absl::StatusCode::kOutOfRangeError. // // Only finite double values may be converted to a uint64_t. CEL may also reject // some conversions if the value falls into a range where overflow would be // ambiguous. // // The behavior of the static_cast(double) assembly instruction on // x86 (cvttsd2si) can be manipulated by the header: // https://en.cppreference.com/w/cpp/numeric/fenv/feround. This means that the // set of values which will result in a valid or invalid conversion are // environment dependent and the implementation must err on the side of caution // and reject possibly valid values which might be invalid based on environment // settings. 
absl::StatusOr CheckedDoubleToUint64(double v); // Convert an int64_t value to a uint64_t value if possible. // If the int64_t exceeds the values representable in a uint64_t the function // will return an absl::StatusCode::kOutOfRangeError. absl::StatusOr CheckedInt64ToUint64(int64_t v); // Convert an int64_t value to an int32_t value if possible. // If the int64_t exceeds the values representable in an int32_t the function // will return an absl::StatusCode::kOutOfRangeError. absl::StatusOr CheckedInt64ToInt32(int64_t v); // Convert a uint64_t value to an int64_t value if possible. // If the uint64_t exceeds the values representable in an int64_t the function // will return an absl::StatusCode::kOutOfRangeError. absl::StatusOr CheckedUint64ToInt64(uint64_t v); // Convert a uint64_t value to a uint32_t value if possible. // If the uint64_t exceeds the values representable in a uint32_t the function // will return an absl::StatusCode::kOutOfRangeError. absl::StatusOr CheckedUint64ToUint32(uint64_t v); } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_COMMON_OVERFLOW_H_ ================================================ FILE: internal/overflow_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "internal/overflow.h"

#include <cstdint>
#include <functional>
#include <limits>
#include <string>
#include <vector>

#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "absl/time/time.h"
#include "internal/testing.h"

namespace cel::internal {
namespace {

using ::testing::HasSubstr;
using ::testing::ValuesIn;

// A named operation paired with its expected outcome: either a value of type
// `T`, or an error whose code must match and whose message must be contained
// in the actual error message.
template <typename T>
struct TestCase {
  std::string test_name;
  // Owning callable. The parameter lists below outlive the full-expression in
  // which their lambdas are created, so a non-owning absl::FunctionRef stored
  // here would dangle by the time the test body invokes it.
  std::function<absl::StatusOr<T>()> op;
  absl::StatusOr<T> result;
};

// Invokes `test_case.op` and compares it with `test_case.result`: status codes
// must be equal; on success the values must be equal, otherwise the actual
// message must contain the expected message.
template <typename T>
void ExpectResult(const T& test_case) {
  auto result = test_case.op();
  ASSERT_EQ(result.status().code(), test_case.result.status().code());
  if (result.ok()) {
    EXPECT_EQ(*result, *test_case.result);
  } else {
    EXPECT_THAT(result.status().message(),
                HasSubstr(test_case.result.status().message()));
  }
}

using IntTestCase = TestCase<int64_t>;
using CheckedIntResultTest = testing::TestWithParam<IntTestCase>;
TEST_P(CheckedIntResultTest, IntOperations) { ExpectResult(GetParam()); }

INSTANTIATE_TEST_SUITE_P(
    CheckedIntMathTest, CheckedIntResultTest,
    ValuesIn(std::vector<IntTestCase>{
        // Addition tests.
        {"OneAddOne", [] { return CheckedAdd(int64_t{1L}, 1L); }, 2L},
        {"ZeroAddOne", [] { return CheckedAdd(int64_t{0}, 1L); }, 1L},
        {"ZeroAddMinusOne", [] { return CheckedAdd(int64_t{0}, -1L); }, -1L},
        {"OneAddZero", [] { return CheckedAdd(int64_t{1L}, 0); }, 1L},
        {"MinusOneAddZero", [] { return CheckedAdd(int64_t{-1L}, 0); }, -1L},
        {"OneAddIntMax",
         [] {
           return CheckedAdd(int64_t{1L}, std::numeric_limits<int64_t>::max());
         },
         absl::OutOfRangeError("integer overflow")},
        {"MinusOneAddIntMin",
         [] {
           return CheckedAdd(int64_t{-1L},
                             std::numeric_limits<int64_t>::lowest());
         },
         absl::OutOfRangeError("integer overflow")},

        // Subtraction tests.
        {"TwoSubThree", [] { return CheckedSub(int64_t{2L}, 3L); }, -1L},
        {"TwoSubZero", [] { return CheckedSub(int64_t{2L}, 0); }, 2L},
        {"ZeroSubTwo", [] { return CheckedSub(int64_t{0}, 2L); }, -2L},
        {"MinusTwoSubThree", [] { return CheckedSub(int64_t{-2L}, 3L); }, -5L},
        {"MinusTwoSubZero", [] { return CheckedSub(int64_t{-2L}, 0); }, -2L},
        {"ZeroSubMinusTwo", [] { return CheckedSub(int64_t{0}, -2L); }, 2L},
        {"IntMaxSubIntMin",
         [] {
           return CheckedSub(std::numeric_limits<int64_t>::max(),
                             std::numeric_limits<int64_t>::lowest());
         },
         absl::OutOfRangeError("integer overflow")},

        // Multiplication tests.
        {"TwoMulThree", [] { return CheckedMul(int64_t{2L}, 3L); }, 6L},
        {"MinusTwoMulThree", [] { return CheckedMul(int64_t{-2L}, 3L); }, -6L},
        {"MinusTwoMulMinusThree",
         [] { return CheckedMul(int64_t{-2L}, -3L); }, 6L},
        {"TwoMulMinusThree", [] { return CheckedMul(int64_t{2L}, -3L); }, -6L},
        {"TwoMulIntMax",
         [] {
           return CheckedMul(int64_t{2L}, std::numeric_limits<int64_t>::max());
         },
         absl::OutOfRangeError("integer overflow")},
        {"MinusOneMulIntMin",
         [] {
           return CheckedMul(int64_t{-1L},
                             std::numeric_limits<int64_t>::lowest());
         },
         absl::OutOfRangeError("integer overflow")},
        {"IntMinMulMinusOne",
         [] {
           return CheckedMul(std::numeric_limits<int64_t>::lowest(),
                             int64_t{-1L});
         },
         absl::OutOfRangeError("integer overflow")},
        {"IntMinMulZero",
         [] {
           return CheckedMul(std::numeric_limits<int64_t>::lowest(),
                             int64_t{0});
         },
         0},
        {"ZeroMulIntMin",
         [] {
           return CheckedMul(int64_t{0},
                             std::numeric_limits<int64_t>::lowest());
         },
         0},
        {"IntMaxMulZero",
         [] {
           return CheckedMul(std::numeric_limits<int64_t>::max(), int64_t{0});
         },
         0},
        {"ZeroMulIntMax",
         [] {
           return CheckedMul(int64_t{0}, std::numeric_limits<int64_t>::max());
         },
         0},

        // Division cases.
        {"ZeroDivOne", [] { return CheckedDiv(int64_t{0}, 1L); }, 0},
        {"TenDivTwo", [] { return CheckedDiv(int64_t{10L}, 2L); }, 5},
        {"TenDivMinusOne", [] { return CheckedDiv(int64_t{10L}, -1L); }, -10},
        {"MinusTenDivMinusOne",
         [] { return CheckedDiv(int64_t{-10L}, -1L); }, 10},
        {"MinusTenDivTwo", [] { return CheckedDiv(int64_t{-10L}, 2L); }, -5},
        {"OneDivZero", [] { return CheckedDiv(int64_t{1L}, 0L); },
         absl::InvalidArgumentError("divide by zero")},
        {"IntMinDivMinusOne",
         [] {
           return CheckedDiv(std::numeric_limits<int64_t>::lowest(),
                             int64_t{-1L});
         },
         absl::OutOfRangeError("integer overflow")},

        // Modulus cases.
        {"ZeroModTwo", [] { return CheckedMod(int64_t{0}, 2L); }, 0},
        {"TwoModTwo", [] { return CheckedMod(int64_t{2L}, 2L); }, 0},
        {"ThreeModTwo", [] { return CheckedMod(int64_t{3L}, 2L); }, 1L},
        {"TwoModZero", [] { return CheckedMod(int64_t{2L}, 0); },
         absl::InvalidArgumentError("modulus by zero")},
        {"IntMinModTwo",
         [] {
           return CheckedMod(std::numeric_limits<int64_t>::lowest(),
                             int64_t{2L});
         },
         0},
        {"IntMaxModMinusOne",
         [] {
           return CheckedMod(std::numeric_limits<int64_t>::max(),
                             int64_t{-1L});
         },
         0},
        {"IntMinModMinusOne",
         [] {
           return CheckedMod(std::numeric_limits<int64_t>::lowest(),
                             int64_t{-1L});
         },
         absl::OutOfRangeError("integer overflow")},

        // Negation cases.
        {"NegateOne", [] { return CheckedNegation(int64_t{1L}); }, -1L},
        {"NegateMinInt64",
         [] { return CheckedNegation(std::numeric_limits<int64_t>::lowest()); },
         absl::OutOfRangeError("integer overflow")},

        // Numeric conversion cases for uint -> int, double -> int
        {"Uint64Conversion",
         [] { return CheckedUint64ToInt64(uint64_t{1UL}); }, 1L},
        {"Int64MaxConversion",
         [] {
           return CheckedUint64ToInt64(
               static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
         },
         std::numeric_limits<int64_t>::max()},
        {"Uint64MaxConversionError",
         [] {
           return CheckedUint64ToInt64(
               static_cast<uint64_t>(std::numeric_limits<uint64_t>::max()));
         },
         absl::OutOfRangeError("out of int64 range")},
        {"DoubleConversion",
         [] { return CheckedDoubleToInt64(double{100.1}); }, 100L},
        {"DoubleInt64MaxConversionError",
         [] {
           return CheckedDoubleToInt64(
               static_cast<double>(std::numeric_limits<int64_t>::max()));
         },
         absl::OutOfRangeError("out of int64 range")},
        {"DoubleInt64MaxMinus512Conversion",
         [] {
           return CheckedDoubleToInt64(
               static_cast<double>(std::numeric_limits<int64_t>::max() - 512));
         },
         std::numeric_limits<int64_t>::max() - 1023},
        {"DoubleInt64MaxMinus1024Conversion",
         [] {
           return CheckedDoubleToInt64(
               static_cast<double>(std::numeric_limits<int64_t>::max() - 1024));
         },
         std::numeric_limits<int64_t>::max() - 1023},
        {"DoubleInt64MinConversionError",
         [] {
           return CheckedDoubleToInt64(
               static_cast<double>(std::numeric_limits<int64_t>::lowest()));
         },
         absl::OutOfRangeError("out of int64 range")},
        {"DoubleInt64MinMinusOneConversionError",
         [] {
           return CheckedDoubleToInt64(
               static_cast<double>(std::numeric_limits<int64_t>::lowest()) -
               1.0);
         },
         absl::OutOfRangeError("out of int64 range")},
        {"DoubleInt64MinMinus511ConversionError",
         [] {
           return CheckedDoubleToInt64(
               static_cast<double>(std::numeric_limits<int64_t>::lowest()) -
               511.0);
         },
         absl::OutOfRangeError("out of int64 range")},
        {"InfiniteConversionError",
         [] {
           return CheckedDoubleToInt64(std::numeric_limits<double>::infinity());
         },
         absl::OutOfRangeError("out of int64 range")},
        {"NegRangeConversionError",
         [] { return CheckedDoubleToInt64(double{-1.0e99}); },
         absl::OutOfRangeError("out of int64 range")},
        {"PosRangeConversionError",
         [] { return CheckedDoubleToInt64(double{1.0e99}); },
         absl::OutOfRangeError("out of int64 range")},
    }),
    [](const testing::TestParamInfo<CheckedIntResultTest::ParamType>& info) {
      return info.param.test_name;
    });

using UintTestCase = TestCase<uint64_t>;
using CheckedUintResultTest = testing::TestWithParam<UintTestCase>;
TEST_P(CheckedUintResultTest, UnsignedOperations) { ExpectResult(GetParam()); }

INSTANTIATE_TEST_SUITE_P(
    CheckedUintMathTest, CheckedUintResultTest,
    ValuesIn(std::vector<UintTestCase>{
        // Addition tests.
        {"OneAddOne", [] { return CheckedAdd(uint64_t{1UL}, 1UL); }, 2UL},
        {"ZeroAddOne", [] { return CheckedAdd(uint64_t{0}, 1UL); }, 1UL},
        {"OneAddZero", [] { return CheckedAdd(uint64_t{1UL}, 0); }, 1UL},
        {"OneAddUintMax",
         [] {
           return CheckedAdd(uint64_t{1UL},
                             std::numeric_limits<uint64_t>::max());
         },
         absl::OutOfRangeError("unsigned integer overflow")},

        // Subtraction tests.
        {"OneSubOne", [] { return CheckedSub(uint64_t{1UL}, 1UL); }, 0},
        {"ZeroSubOne", [] { return CheckedSub(uint64_t{0}, 1UL); },
         absl::OutOfRangeError("unsigned integer overflow")},
        {"OneSubZero", [] { return CheckedSub(uint64_t{1UL}, 0); }, 1UL},

        // Multiplication tests.
        {"OneMulOne", [] { return CheckedMul(uint64_t{1UL}, 1UL); }, 1UL},
        {"ZeroMulOne", [] { return CheckedMul(uint64_t{0}, 1UL); }, 0},
        {"OneMulZero", [] { return CheckedMul(uint64_t{1UL}, 0); }, 0},
        {"TwoMulUintMax",
         [] {
           return CheckedMul(uint64_t{2UL},
                             std::numeric_limits<uint64_t>::max());
         },
         absl::OutOfRangeError("unsigned integer overflow")},

        // Division tests.
        {"TwoDivTwo", [] { return CheckedDiv(uint64_t{2UL}, 2UL); }, 1UL},
        {"TwoDivFour", [] { return CheckedDiv(uint64_t{2UL}, 4UL); }, 0},
        {"OneDivZero", [] { return CheckedDiv(uint64_t{1UL}, 0); },
         absl::InvalidArgumentError("divide by zero")},

        // Modulus tests.
        {"TwoModTwo", [] { return CheckedMod(uint64_t{2UL}, 2UL); }, 0},
        {"TwoModFour", [] { return CheckedMod(uint64_t{2UL}, 4UL); }, 2UL},
        {"OneModZero", [] { return CheckedMod(uint64_t{1UL}, 0); },
         absl::InvalidArgumentError("modulus by zero")},

        // Conversion test cases for int -> uint, double -> uint.
        {"Int64Conversion",
         [] { return CheckedInt64ToUint64(int64_t{1L}); }, 1UL},
        {"Int64MaxConversion",
         [] {
           return CheckedInt64ToUint64(std::numeric_limits<int64_t>::max());
         },
         static_cast<uint64_t>(std::numeric_limits<int64_t>::max())},
        {"NegativeInt64ConversionError",
         [] { return CheckedInt64ToUint64(int64_t{-1L}); },
         absl::OutOfRangeError("out of uint64 range")},
        {"DoubleConversion",
         [] { return CheckedDoubleToUint64(double{100.1}); }, 100UL},
        {"DoubleUint64MaxConversionError",
         [] {
           return CheckedDoubleToUint64(
               static_cast<double>(std::numeric_limits<uint64_t>::max()));
         },
         absl::OutOfRangeError("out of uint64 range")},
        {"DoubleUint64MaxMinus512Conversion",
         [] {
           return CheckedDoubleToUint64(
               static_cast<double>(std::numeric_limits<uint64_t>::max() - 512));
         },
         absl::OutOfRangeError("out of uint64 range")},
        {"DoubleUint64MaxMinus1024Conversion",
         [] {
           return CheckedDoubleToUint64(static_cast<double>(
               std::numeric_limits<uint64_t>::max() - 1024));
         },
         std::numeric_limits<uint64_t>::max() - 2047},
        {"InfiniteConversionError",
         [] {
           return CheckedDoubleToUint64(
               std::numeric_limits<double>::infinity());
         },
         absl::OutOfRangeError("out of uint64 range")},
        {"NegConversionError",
         [] { return CheckedDoubleToUint64(double{-1.1}); },
         absl::OutOfRangeError("out of uint64 range")},
        {"NegRangeConversionError",
         [] { return CheckedDoubleToUint64(double{-1.0e99}); },
         absl::OutOfRangeError("out of uint64 range")},
        {"PosRangeConversionError",
         [] { return CheckedDoubleToUint64(double{1.0e99}); },
         absl::OutOfRangeError("out of uint64 range")},
    }),
    [](const testing::TestParamInfo<CheckedUintResultTest::ParamType>& info) {
      return info.param.test_name;
    });

using DurationTestCase = TestCase<absl::Duration>;
using CheckedDurationResultTest = testing::TestWithParam<DurationTestCase>;
TEST_P(CheckedDurationResultTest, DurationOperations) {
  ExpectResult(GetParam());
}

INSTANTIATE_TEST_SUITE_P(
    CheckedDurationMathTest, CheckedDurationResultTest,
    ValuesIn(std::vector<DurationTestCase>{
        // Addition tests.
        {"OneSecondAddOneSecond",
         [] { return CheckedAdd(absl::Seconds(1), absl::Seconds(1)); },
         absl::Seconds(2)},
        {"MaxDurationAddOneNano",
         [] {
           return CheckedAdd(
               absl::Nanoseconds(std::numeric_limits<int64_t>::max()),
               absl::Nanoseconds(1));
         },
         absl::OutOfRangeError("integer overflow")},
        {"MinDurationAddMinusOneNano",
         [] {
           return CheckedAdd(
               absl::Nanoseconds(std::numeric_limits<int64_t>::lowest()),
               absl::Nanoseconds(-1));
         },
         absl::OutOfRangeError("integer overflow")},
        {"InfinityAddOneNano",
         [] {
           return CheckedAdd(absl::InfiniteDuration(), absl::Nanoseconds(1));
         },
         absl::OutOfRangeError("integer overflow")},
        {"NegInfinityAddOneNano",
         [] {
           return CheckedAdd(-absl::InfiniteDuration(), absl::Nanoseconds(1));
         },
         absl::OutOfRangeError("integer overflow")},
        {"OneNanoAddInfinity",
         [] {
           return CheckedAdd(absl::Nanoseconds(1), absl::InfiniteDuration());
         },
         absl::OutOfRangeError("integer overflow")},
        {"OneNanoAddNegInfinity",
         [] {
           return CheckedAdd(absl::Nanoseconds(1), -absl::InfiniteDuration());
         },
         absl::OutOfRangeError("integer overflow")},

        // Subtraction tests for duration - duration.
        {"OneSecondSubOneSecond",
         [] { return CheckedSub(absl::Seconds(1), absl::Seconds(1)); },
         absl::ZeroDuration()},
        {"MinDurationSubOneNano",
         [] {
           return CheckedSub(
               absl::Nanoseconds(std::numeric_limits<int64_t>::lowest()),
               absl::Nanoseconds(1));
         },
         absl::OutOfRangeError("integer overflow")},
        {"InfinitySubOneNano",
         [] {
           return CheckedSub(absl::InfiniteDuration(), absl::Nanoseconds(1));
         },
         absl::OutOfRangeError("integer overflow")},
        {"NegInfinitySubOneNano",
         [] {
           return CheckedSub(-absl::InfiniteDuration(), absl::Nanoseconds(1));
         },
         absl::OutOfRangeError("integer overflow")},
        {"OneNanoSubInfinity",
         [] {
           return CheckedSub(absl::Nanoseconds(1), absl::InfiniteDuration());
         },
         absl::OutOfRangeError("integer overflow")},
        {"OneNanoSubNegInfinity",
         [] {
           return CheckedSub(absl::Nanoseconds(1), -absl::InfiniteDuration());
         },
         absl::OutOfRangeError("integer overflow")},

        // Subtraction tests for time - time.
        {"TimeSubOneSecond",
         [] {
           return CheckedSub(absl::FromUnixSeconds(100),
                             absl::FromUnixSeconds(1));
         },
         absl::Seconds(99)},
        {"TimeWithNanosPositive",
         [] {
           return CheckedSub(absl::FromUnixSeconds(2) + absl::Nanoseconds(1),
                             absl::FromUnixSeconds(1) - absl::Nanoseconds(1));
         },
         absl::Seconds(1) + absl::Nanoseconds(2)},
        {"TimeWithNanosNegative",
         [] {
           return CheckedSub(absl::FromUnixSeconds(1) + absl::Nanoseconds(1),
                             absl::FromUnixSeconds(2) + absl::Seconds(1) -
                                 absl::Nanoseconds(1));
         },
         absl::Seconds(-2) + absl::Nanoseconds(2)},
        {"MinTimestampMinusOne",
         [] {
           return CheckedSub(
               absl::FromUnixSeconds(std::numeric_limits<int64_t>::lowest()),
               absl::FromUnixSeconds(1));
         },
         absl::OutOfRangeError("integer overflow")},
        {"InfinitePastSubOneSecond",
         [] {
           return CheckedSub(absl::InfinitePast(), absl::FromUnixSeconds(1));
         },
         absl::OutOfRangeError("integer overflow")},
        {"InfiniteFutureSubMinusOneSecond",
         [] {
           return CheckedSub(absl::InfiniteFuture(), absl::FromUnixSeconds(-1));
         },
         absl::OutOfRangeError("integer overflow")},
        {"InfiniteFutureSubInfinitePast",
         [] {
           return CheckedSub(absl::InfiniteFuture(), absl::InfinitePast());
         },
         absl::OutOfRangeError("integer overflow")},
        {"InfinitePastSubInfiniteFuture",
         [] {
           return CheckedSub(absl::InfinitePast(), absl::InfiniteFuture());
         },
         absl::OutOfRangeError("integer overflow")},

        // Negation cases.
        {"NegateOneSecond", [] { return CheckedNegation(absl::Seconds(1)); },
         absl::Seconds(-1)},
        {"NegateMinDuration",
         [] {
           return CheckedNegation(
               absl::Nanoseconds(std::numeric_limits<int64_t>::lowest()));
         },
         absl::OutOfRangeError("integer overflow")},
        {"NegateInfiniteDuration",
         [] { return CheckedNegation(absl::InfiniteDuration()); },
         absl::OutOfRangeError("integer overflow")},
        {"NegateNegInfiniteDuration",
         [] { return CheckedNegation(-absl::InfiniteDuration()); },
         absl::OutOfRangeError("integer overflow")},
    }),
    [](const testing::TestParamInfo<CheckedDurationResultTest::ParamType>&
           info) { return info.param.test_name; });

using TimeTestCase = TestCase<absl::Time>;
using CheckedTimeResultTest = testing::TestWithParam<TimeTestCase>;
TEST_P(CheckedTimeResultTest, TimeDurationOperations) {
  ExpectResult(GetParam());
}

INSTANTIATE_TEST_SUITE_P(
    CheckedTimeDurationMathTest, CheckedTimeResultTest,
    ValuesIn(std::vector<TimeTestCase>{
        // Addition tests.
        {"DateAddOneHourMinusOneMilli",
         [] {
           return CheckedAdd(absl::FromUnixSeconds(3506),
                             absl::Hours(1) + absl::Milliseconds(-1));
         },
         absl::FromUnixSeconds(7106) + absl::Milliseconds(-1)},
        {"DateAddOneHourOneNano",
         [] {
           return CheckedAdd(absl::FromUnixSeconds(3506),
                             absl::Hours(1) + absl::Nanoseconds(1));
         },
         absl::FromUnixSeconds(7106) + absl::Nanoseconds(1)},
        {"MaxIntAddOneSecond",
         [] {
           return CheckedAdd(
               absl::FromUnixSeconds(std::numeric_limits<int64_t>::max()),
               absl::Seconds(1));
         },
         absl::OutOfRangeError("integer overflow")},
        {"MaxTimestampAddOneSecond",
         [] {
           return CheckedAdd(absl::FromUnixSeconds(253402300799),
                             absl::Seconds(1));
         },
         absl::OutOfRangeError("timestamp overflow")},
        {"TimeWithNanosNegative",
         [] {
           return CheckedAdd(absl::FromUnixSeconds(1) + absl::Nanoseconds(1),
                             absl::Nanoseconds(-999999999));
         },
         absl::FromUnixNanos(2)},
        {"TimeWithNanosPositive",
         [] {
           return CheckedAdd(
               absl::FromUnixSeconds(1) + absl::Nanoseconds(999999999),
               absl::Nanoseconds(999999999));
         },
         absl::FromUnixSeconds(2) + absl::Nanoseconds(999999998)},
        {"SecondsAddInfinity",
         [] {
           return CheckedAdd(
               absl::FromUnixSeconds(1) + absl::Nanoseconds(999999999),
               absl::InfiniteDuration());
         },
         absl::OutOfRangeError("timestamp overflow")},
        {"SecondsAddNegativeInfinity",
         [] {
           return CheckedAdd(
               absl::FromUnixSeconds(1) + absl::Nanoseconds(999999999),
               -absl::InfiniteDuration());
         },
         absl::OutOfRangeError("timestamp overflow")},
        {"InfiniteFutureAddNegativeInfinity",
         [] {
           return CheckedAdd(absl::InfiniteFuture(), -absl::InfiniteDuration());
         },
         absl::OutOfRangeError("timestamp overflow")},
        {"InfinitePastAddInfinity",
         [] {
           return CheckedAdd(absl::InfinitePast(), absl::InfiniteDuration());
         },
         absl::OutOfRangeError("timestamp overflow")},

        // Subtraction tests.
        {"DateSubOneHour",
         [] { return CheckedSub(absl::FromUnixSeconds(3506), absl::Hours(1)); },
         absl::FromUnixSeconds(-94)},
        {"MinTimestampSubOneSecond",
         [] {
           return CheckedSub(absl::FromUnixSeconds(-62135596800),
                             absl::Seconds(1));
         },
         absl::OutOfRangeError("timestamp overflow")},
        {"MinIntSubOneViaNanos",
         [] {
           return CheckedSub(
               absl::FromUnixSeconds(std::numeric_limits<int64_t>::min()),
               absl::Nanoseconds(1));
         },
         absl::OutOfRangeError("integer overflow")},
        {"MinTimestampSubOneViaNanosScaleOverflow",
         [] {
           return CheckedSub(
               absl::FromUnixSeconds(-62135596800) + absl::Nanoseconds(1),
               absl::Nanoseconds(999999999));
         },
         absl::OutOfRangeError("timestamp overflow")},
        {"SecondsSubInfinity",
         [] {
           return CheckedSub(
               absl::FromUnixSeconds(1) + absl::Nanoseconds(999999999),
               absl::InfiniteDuration());
         },
         absl::OutOfRangeError("integer overflow")},
        {"SecondsSubNegInfinity",
         [] {
           return CheckedSub(
               absl::FromUnixSeconds(1) + absl::Nanoseconds(999999999),
               -absl::InfiniteDuration());
         },
         absl::OutOfRangeError("integer overflow")},
    }),
    [](const testing::TestParamInfo<CheckedTimeResultTest::ParamType>& info) {
      return info.param.test_name;
    });

using ConvertInt64Int32TestCase = TestCase<int32_t>;
using CheckedConvertInt64Int32Test =
    testing::TestWithParam<ConvertInt64Int32TestCase>;
TEST_P(CheckedConvertInt64Int32Test, Conversions) { ExpectResult(GetParam()); }

INSTANTIATE_TEST_SUITE_P(
    CheckedConvertInt64Int32Test, CheckedConvertInt64Int32Test,
    ValuesIn(std::vector<ConvertInt64Int32TestCase>{
        {"SimpleConversion", [] { return CheckedInt64ToInt32(int64_t{1L}); },
         1},
        {"Int32MaxConversion",
         [] {
           return CheckedInt64ToInt32(
               static_cast<int64_t>(std::numeric_limits<int32_t>::max()));
         },
         std::numeric_limits<int32_t>::max()},
        {"Int32MaxConversionError",
         [] {
           return CheckedInt64ToInt32(
               static_cast<int64_t>(std::numeric_limits<int64_t>::max()));
         },
         absl::OutOfRangeError("out of int32 range")},
        {"Int32MinConversion",
         [] {
           return CheckedInt64ToInt32(
               static_cast<int64_t>(std::numeric_limits<int32_t>::lowest()));
         },
         std::numeric_limits<int32_t>::lowest()},
        {"Int32MinConversionError",
         [] {
           return CheckedInt64ToInt32(
               static_cast<int64_t>(std::numeric_limits<int64_t>::lowest()));
         },
         absl::OutOfRangeError("out of int32 range")},
    }),
    [](const testing::TestParamInfo<CheckedConvertInt64Int32Test::ParamType>&
           info) { return info.param.test_name; });

using ConvertUint64Uint32TestCase = TestCase<uint32_t>;
using CheckedConvertUint64Uint32Test =
    testing::TestWithParam<ConvertUint64Uint32TestCase>;
TEST_P(CheckedConvertUint64Uint32Test, Conversions) {
  ExpectResult(GetParam());
}

INSTANTIATE_TEST_SUITE_P(
    CheckedConvertUint64Uint32Test, CheckedConvertUint64Uint32Test,
    ValuesIn(std::vector<ConvertUint64Uint32TestCase>{
        {"SimpleConversion",
         [] { return CheckedUint64ToUint32(uint64_t{1UL}); }, 1U},
        {"Uint32MaxConversion",
         [] {
           return CheckedUint64ToUint32(
               static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()));
         },
         std::numeric_limits<uint32_t>::max()},
        {"Uint32MaxConversionError",
         [] {
           return CheckedUint64ToUint32(
               static_cast<uint64_t>(std::numeric_limits<uint64_t>::max()));
         },
         absl::OutOfRangeError("out of uint32 range")},
    }),
    [](const testing::TestParamInfo<CheckedConvertUint64Uint32Test::ParamType>&
           info) { return info.param.test_name; });

}  // namespace
}  // namespace cel::internal
================================================
FILE: internal/parse_text_proto.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PARSE_TEXT_PROTO_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_PARSE_TEXT_PROTO_H_ #include #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/log/die_if_null.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "common/memory.h" #include "internal/message_type_name.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" #include "google/protobuf/text_format.h" namespace cel::internal { // `GeneratedParseTextProto` parses the text format protocol buffer message as // the message with the same name as `T`, looked up in the provided descriptor // pool, returning as the generated message. This works regardless of whether // all messages are built with the lite runtime or not. template std::enable_if_t, T* absl_nonnull> GeneratedParseTextProto( google::protobuf::Arena* absl_nonnull arena, absl::string_view text, const google::protobuf::DescriptorPool* absl_nonnull pool = GetTestingDescriptorPool(), google::protobuf::MessageFactory* absl_nonnull factory = GetTestingMessageFactory()) { // Full runtime. 
const auto* descriptor = ABSL_DIE_IF_NULL( // Crash OK pool->FindMessageTypeByName(MessageTypeNameFor())); const auto* dynamic_message_prototype = ABSL_DIE_IF_NULL(factory->GetPrototype(descriptor)); // Crash OK auto* dynamic_message = dynamic_message_prototype->New(arena); ABSL_CHECK( // Crash OK google::protobuf::TextFormat::ParseFromString(text, dynamic_message)); if (auto* generated_message = google::protobuf::DynamicCastMessage(dynamic_message); generated_message != nullptr) { // Same thing, no need to serialize and parse. return generated_message; } auto* message = google::protobuf::Arena::Create(arena); absl::Cord serialized_message; ABSL_CHECK( // Crash OK dynamic_message->SerializeToCord(&serialized_message)); ABSL_CHECK(message->ParseFromCord(serialized_message)); // Crash OK return message; } // `GeneratedParseTextProto` parses the text format protocol buffer message as // the message with the same name as `T`, looked up in the provided descriptor // pool, returning as the generated message. This works regardless of whether // all messages are built with the lite runtime or not. template std::enable_if_t< std::conjunction_v, std::negation>>, T* absl_nonnull> GeneratedParseTextProto( google::protobuf::Arena* absl_nonnull arena, absl::string_view text, const google::protobuf::DescriptorPool* absl_nonnull pool = GetTestingDescriptorPool(), google::protobuf::MessageFactory* absl_nonnull factory = GetTestingMessageFactory()) { // Lite runtime. 
const auto* descriptor = ABSL_DIE_IF_NULL( // Crash OK pool->FindMessageTypeByName(MessageTypeNameFor())); const auto* dynamic_message_prototype = ABSL_DIE_IF_NULL(factory->GetPrototype(descriptor)); // Crash OK auto* dynamic_message = dynamic_message_prototype->New(arena); ABSL_CHECK( // Crash OK google::protobuf::TextFormat::ParseFromString(text, dynamic_message)); auto* message = google::protobuf::Arena::Create(arena); absl::Cord serialized_message; ABSL_CHECK( // Crash OK dynamic_message->SerializeToCord(&serialized_message)); ABSL_CHECK(message->ParseFromCord(serialized_message)); // Crash OK return message; } // `DynamicParseTextProto` parses the text format protocol buffer message as the // dynamic message with the same name as `T`, looked up in the provided // descriptor pool, returning the dynamic message. template google::protobuf::Message* absl_nonnull DynamicParseTextProto( google::protobuf::Arena* absl_nonnull arena, absl::string_view text, const google::protobuf::DescriptorPool* absl_nonnull pool = GetTestingDescriptorPool(), google::protobuf::MessageFactory* absl_nonnull factory = GetTestingMessageFactory()) { static_assert(std::is_base_of_v); const auto* descriptor = ABSL_DIE_IF_NULL( // Crash OK pool->FindMessageTypeByName(MessageTypeNameFor())); const auto* dynamic_message_prototype = ABSL_DIE_IF_NULL(factory->GetPrototype(descriptor)); // Crash OK auto* dynamic_message = dynamic_message_prototype->New(arena); ABSL_CHECK(google::protobuf::TextFormat::ParseFromString( // Crash OK text, cel::to_address(dynamic_message))); return dynamic_message; } } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_PARSE_TEXT_PROTO_H_ ================================================ FILE: internal/proto_file_util.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_FILE_UTIL_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_FILE_UTIL_H_ #include #include #include #include #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/text_format.h" namespace cel::internal::test { // Reads a binary protobuf message of MessageType from the given path. template absl::Status ReadBinaryProtoFromFile(absl::string_view file_name, MessageType& message) { std::ifstream file; file.open(std::string(file_name), std::fstream::in | std::fstream::binary); if (!file.is_open()) { return absl::NotFoundError(absl::StrFormat("Failed to open file '%s': %s", file_name, strerror(errno))); } if (!message.ParseFromIstream(&file)) { return absl::InvalidArgumentError( absl::StrFormat("Failed to parse proto of type '%s' from file '%s'", message.GetTypeName(), file_name)); } return absl::OkStatus(); } // Reads a text protobuf message of MessageType from the given path. 
template absl::Status ReadTextProtoFromFile(absl::string_view file_name, MessageType& message) { std::ifstream file; file.open(std::string(file_name), std::fstream::in | std::fstream::binary); if (!file.is_open()) { return absl::NotFoundError(absl::StrFormat("Failed to open file '%s': %s", file_name, strerror(errno))); } google::protobuf::io::IstreamInputStream stream(&file); if (!google::protobuf::TextFormat::Parse(&stream, &message)) { return absl::InvalidArgumentError( absl::StrFormat("Failed to parse proto of type '%s' from file '%s'", message.GetTypeName(), file_name)); } return absl::OkStatus(); } } // namespace cel::internal::test #endif // THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_FILE_UTIL_H_ ================================================ FILE: internal/proto_matchers.h ================================================ // Copyright 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_MATCHERS_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_MATCHERS_H_ #include #include #include #include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "internal/casts.h" #include "internal/testing.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" #include "google/protobuf/text_format.h" #include "google/protobuf/util/message_differencer.h" namespace cel::internal::test { /** * Simple implementation of a proto matcher comparing string representations. 
* * IMPORTANT: Only use this for protos whose textual representation is * deterministic (that may not be the case for the map collection type). */ class TextProtoMatcher { public: explicit inline TextProtoMatcher(absl::string_view expected) : expected_(expected) {} bool MatchAndExplain(const google::protobuf::MessageLite& p, ::testing::MatchResultListener* listener) const { return MatchAndExplain(cel::internal::down_cast(p), listener); } bool MatchAndExplain(const google::protobuf::MessageLite* p, ::testing::MatchResultListener* listener) const { return MatchAndExplain(cel::internal::down_cast(p), listener); } bool MatchAndExplain(const google::protobuf::Message& p, ::testing::MatchResultListener* listener) const { auto message = absl::WrapUnique(p.New()); ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(expected_, message.get())); return google::protobuf::util::MessageDifferencer::Equals( *message, cel::internal::down_cast(p)); } bool MatchAndExplain(const google::protobuf::Message* p, ::testing::MatchResultListener* listener) const { auto message = absl::WrapUnique(p->New()); ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(expected_, message.get())); return google::protobuf::util::MessageDifferencer::Equals( *message, cel::internal::down_cast(*p)); } inline void DescribeTo(::std::ostream* os) const { *os << expected_; } inline void DescribeNegationTo(::std::ostream* os) const { *os << "not equal to expected message: " << expected_; } private: const std::string expected_; }; /** * Simple implementation of a proto matcher comparing string representations. * * IMPORTANT: Only use this for protos whose textual representation is * deterministic (that may not be the case for the map collection type). 
*/ class ProtoMatcher { public: explicit inline ProtoMatcher(const google::protobuf::Message& expected) : expected_(expected.New()) { expected_->CopyFrom(expected); } bool MatchAndExplain(const google::protobuf::MessageLite& p, ::testing::MatchResultListener* listener) const { return MatchAndExplain(cel::internal::down_cast(p), listener); } bool MatchAndExplain(const google::protobuf::MessageLite* p, ::testing::MatchResultListener* listener) const { return MatchAndExplain(cel::internal::down_cast(p), listener); } bool MatchAndExplain(const google::protobuf::Message& p, ::testing::MatchResultListener* /* listener */) const { return google::protobuf::util::MessageDifferencer::Equals(*expected_, p); } bool MatchAndExplain(const google::protobuf::Message* p, ::testing::MatchResultListener* /* listener */) const { return google::protobuf::util::MessageDifferencer::Equals(*expected_, *p); } inline void DescribeTo(::std::ostream* os) const { *os << expected_->DebugString(); } inline void DescribeNegationTo(::std::ostream* os) const { *os << "not equal to expected message: " << expected_->DebugString(); } private: std::shared_ptr expected_; }; // Polymorphic matcher to compare any two protos. inline ::testing::PolymorphicMatcher EqualsProto( absl::string_view x) { return ::testing::MakePolymorphicMatcher(TextProtoMatcher(x)); } // Polymorphic matcher to compare any two protos. inline ::testing::PolymorphicMatcher EqualsProto( const google::protobuf::Message& x) { return ::testing::MakePolymorphicMatcher(ProtoMatcher(x)); } } // namespace cel::internal::test #endif // THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_MATCHERS_H_ ================================================ FILE: internal/proto_time_encoding.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "internal/proto_time_encoding.h" #include #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/time/time.h" #include "internal/status_macros.h" #include "internal/time.h" #include "google/protobuf/util/time_util.h" namespace cel::internal { namespace { absl::Status Validate(absl::Time time) { if (time < cel::internal::MinTimestamp()) { return absl::InvalidArgumentError("time below min"); } if (time > cel::internal::MaxTimestamp()) { return absl::InvalidArgumentError("time above max"); } return absl::OkStatus(); } absl::Status CelValidateDuration(absl::Duration duration) { if (duration < cel::internal::MinDuration()) { return absl::InvalidArgumentError("duration below min"); } if (duration > cel::internal::MaxDuration()) { return absl::InvalidArgumentError("duration above max"); } return absl::OkStatus(); } } // namespace absl::Duration DecodeDuration(const google::protobuf::Duration& proto) { return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); } absl::Time DecodeTime(const google::protobuf::Timestamp& proto) { return absl::FromUnixSeconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); } absl::Status EncodeDuration(absl::Duration duration, google::protobuf::Duration* proto) { CEL_RETURN_IF_ERROR(CelValidateDuration(duration)); // s and n may both be negative, per the Duration proto spec. 
const int64_t s = absl::IDivDuration(duration, absl::Seconds(1), &duration); const int64_t n = absl::IDivDuration(duration, absl::Nanoseconds(1), &duration); proto->set_seconds(s); proto->set_nanos(n); return absl::OkStatus(); } absl::StatusOr EncodeDurationToString(absl::Duration duration) { google::protobuf::Duration d; auto status = EncodeDuration(duration, &d); if (!status.ok()) { return status; } return google::protobuf::util::TimeUtil::ToString(d); } absl::Status EncodeTime(absl::Time time, google::protobuf::Timestamp* proto) { CEL_RETURN_IF_ERROR(Validate(time)); const int64_t s = absl::ToUnixSeconds(time); proto->set_seconds(s); proto->set_nanos((time - absl::FromUnixSeconds(s)) / absl::Nanoseconds(1)); return absl::OkStatus(); } absl::StatusOr EncodeTimeToString(absl::Time time) { google::protobuf::Timestamp t; auto status = EncodeTime(time, &t); if (!status.ok()) { return status; } return google::protobuf::util::TimeUtil::ToString(t); } } // namespace cel::internal ================================================ FILE: internal/proto_time_encoding.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Defines basic encode/decode operations for proto time and duration formats. 
#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_TIME_ENCODING_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_TIME_ENCODING_H_ #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/time/time.h" namespace cel::internal { /** Helper function to encode a duration in a google::protobuf::Duration. */ absl::Status EncodeDuration(absl::Duration duration, google::protobuf::Duration* proto); /** Helper function to encode an absl::Duration to a JSON-formatted string. */ absl::StatusOr EncodeDurationToString(absl::Duration duration); /** Helper function to encode a time in a google::protobuf::Timestamp. */ absl::Status EncodeTime(absl::Time time, google::protobuf::Timestamp* proto); /** Helper function to encode an absl::Time to a JSON-formatted string. */ absl::StatusOr EncodeTimeToString(absl::Time time); /** Helper function to decode a duration from a google::protobuf::Duration. */ absl::Duration DecodeDuration(const google::protobuf::Duration& proto); /** Helper function to decode a time from a google::protobuf::Timestamp. */ absl::Time DecodeTime(const google::protobuf::Timestamp& proto); } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_TIME_ENCODING_H_ ================================================ FILE: internal/proto_time_encoding_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "internal/proto_time_encoding.h" #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" #include "absl/time/time.h" #include "internal/testing.h" #include "testutil/util.h" namespace cel::internal { namespace { using ::google::api::expr::testutil::EqualsProto; TEST(EncodeDuration, Basic) { google::protobuf::Duration proto_duration; ASSERT_OK( EncodeDuration(absl::Seconds(2) + absl::Nanoseconds(3), &proto_duration)); EXPECT_THAT(proto_duration, EqualsProto("seconds: 2 nanos: 3")); } TEST(EncodeDurationToString, Basic) { ASSERT_OK_AND_ASSIGN( std::string json, EncodeDurationToString(absl::Seconds(5) + absl::Nanoseconds(20))); EXPECT_EQ(json, "5.000000020s"); } TEST(EncodeTime, Basic) { google::protobuf::Timestamp proto_timestamp; ASSERT_OK(EncodeTime(absl::FromUnixMillis(300000), &proto_timestamp)); EXPECT_THAT(proto_timestamp, EqualsProto("seconds: 300")); } TEST(EncodeTimeToString, Basic) { ASSERT_OK_AND_ASSIGN(std::string json, EncodeTimeToString(absl::FromUnixMillis(80030))); EXPECT_EQ(json, "1970-01-01T00:01:20.030Z"); } TEST(DecodeDuration, Basic) { google::protobuf::Duration proto_duration; proto_duration.set_seconds(450); proto_duration.set_nanos(4); EXPECT_EQ(DecodeDuration(proto_duration), absl::Seconds(450) + absl::Nanoseconds(4)); } TEST(DecodeTime, Basic) { google::protobuf::Timestamp proto_timestamp; proto_timestamp.set_seconds(450); EXPECT_EQ(DecodeTime(proto_timestamp), absl::FromUnixSeconds(450)); } } // namespace } // namespace cel::internal ================================================ FILE: internal/proto_util.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_UTIL_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_UTIL_H_ #include #include #include "google/protobuf/descriptor.pb.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "google/protobuf/util/message_differencer.h" namespace google { namespace api { namespace expr { namespace internal { template absl::Status ValidateStandardMessageType( const google::protobuf::DescriptorPool& descriptor_pool) { if constexpr (std::is_base_of_v) { const google::protobuf::Descriptor* descriptor = MessageType::descriptor(); const google::protobuf::Descriptor* descriptor_from_pool = descriptor_pool.FindMessageTypeByName(descriptor->full_name()); if (descriptor_from_pool == nullptr) { return absl::NotFoundError( absl::StrFormat("Descriptor '%s' not found in descriptor pool", descriptor->full_name())); } if (descriptor_from_pool == descriptor) { return absl::OkStatus(); } google::protobuf::DescriptorProto descriptor_proto; google::protobuf::DescriptorProto descriptor_from_pool_proto; descriptor->CopyTo(&descriptor_proto); descriptor_from_pool->CopyTo(&descriptor_from_pool_proto); google::protobuf::util::MessageDifferencer descriptor_differencer; std::string differences; descriptor_differencer.ReportDifferencesToString(&differences); // The json_name is a compiler detail and does not change the message // content. It can differ, e.g., between C++ and Go compilers. Hence ignore. 
const google::protobuf::FieldDescriptor* json_name_field_desc = google::protobuf::FieldDescriptorProto::descriptor()->FindFieldByName( "json_name"); if (json_name_field_desc != nullptr) { descriptor_differencer.IgnoreField(json_name_field_desc); } if (!descriptor_differencer.Compare(descriptor_proto, descriptor_from_pool_proto)) { return absl::FailedPreconditionError(absl::StrFormat( "The descriptor for '%s' in the descriptor pool differs from the " "compiled-in generated version as follows: %s", descriptor->full_name(), differences)); } } else { // Lite runtime. Just verify the message exists. const auto& type_name = MessageType::default_instance().GetTypeName(); const google::protobuf::Descriptor* descriptor_from_pool = descriptor_pool.FindMessageTypeByName(type_name); if (descriptor_from_pool == nullptr) { return absl::NotFoundError(absl::StrFormat( "Descriptor '%s' not found in descriptor pool", type_name)); } } return absl::OkStatus(); } } // namespace internal } // namespace expr } // namespace api } // namespace google #endif // THIRD_PARTY_CEL_CPP_INTERNAL_PROTO_UTIL_H_ ================================================ FILE: internal/proto_util_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "internal/proto_util.h"

#include "google/protobuf/duration.pb.h"
#include "google/protobuf/descriptor.pb.h"
#include "absl/status/status.h"
#include "eval/public/structs/cel_proto_descriptor_pool_builder.h"
#include "internal/testing.h"

namespace cel::internal {
namespace {

using google::api::expr::internal::ValidateStandardMessageType;
using ::absl_testing::StatusIs;
using ::testing::HasSubstr;

TEST(ProtoUtil, ValidateStandardMessageTypesRejectsIncompatible) {
  google::protobuf::DescriptorPool descriptor_pool;

  const google::protobuf::Descriptor* descriptor =
      google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName(
          "google.protobuf.Duration");
  ASSERT_NE(descriptor, nullptr);

  google::protobuf::FileDescriptorProto file_descriptor_proto;
  descriptor->file()->CopyTo(&file_descriptor_proto);
  // We emulate a modification by external code that replaced the nanos by a
  // millis field.
  google::protobuf::FieldDescriptorProto seconds_desc_proto;
  google::protobuf::FieldDescriptorProto nanos_desc_proto;
  descriptor->FindFieldByName("seconds")->CopyTo(&seconds_desc_proto);
  descriptor->FindFieldByName("nanos")->CopyTo(&nanos_desc_proto);
  nanos_desc_proto.set_name("millis");
  file_descriptor_proto.mutable_message_type(0)->clear_field();
  *file_descriptor_proto.mutable_message_type(0)->add_field() =
      seconds_desc_proto;
  *file_descriptor_proto.mutable_message_type(0)->add_field() =
      nanos_desc_proto;

  // If the modified file failed to build, the validator would report
  // NotFound instead, and the expectation below would fail for the wrong
  // reason.
  ASSERT_NE(descriptor_pool.BuildFile(file_descriptor_proto), nullptr);

  EXPECT_THAT(
      ValidateStandardMessageType<google::protobuf::Duration>(descriptor_pool),
      StatusIs(absl::StatusCode::kFailedPrecondition, HasSubstr("differs")));
}

}  // namespace
}  // namespace cel::internal

================================================
FILE: internal/protobuf_runtime_version.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_PROTOBUF_VERSION_H_
#define THIRD_PARTY_CEL_CPP_INTERNAL_PROTOBUF_VERSION_H_

// Probe the same path that is included. The previous probe of
// "third_party/protobuf/runtime_version.h" never resolves in open-source
// builds. That left `PROTOBUF_OSS_VERSION` defined only if some earlier
// include happened to pull the header in.
#ifdef __has_include
#if __has_include("google/protobuf/runtime_version.h")
#include "google/protobuf/runtime_version.h"  // IWYU pragma: keep
#endif
#endif

// CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(major, minor, patch) is nonzero iff
// the protobuf runtime version encoded in `PROTOBUF_OSS_VERSION`
// (major * 1000000 + minor * 1000 + patch) is at least major.minor.patch.
#ifdef PROTOBUF_OSS_VERSION
#define CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(major, minor, patch) \
  ((major) * 1000000 + (minor) * 1000 + (patch) <= PROTOBUF_OSS_VERSION)
#else  // Older versions of protobuf did not have the macro.
#define CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(major, minor, patch) 0 #endif #endif // THIRD_PARTY_CEL_CPP_INTERNAL_PROTOBUF_VERSION_H_ ================================================ FILE: internal/re2_options.h ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_RE2_OPTIONS_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_RE2_OPTIONS_H_ #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "re2/re2.h" namespace cel::internal { inline RE2::Options MakeRE2Options() { RE2::Options options; options.set_log_errors(false); return options; } inline absl::Status CheckRE2(const RE2& re, int max_program_size) { if (!re.ok()) { switch (re.error_code()) { case RE2::ErrorInternal: return absl::InternalError( absl::StrCat("internal RE2 error: ", re.error())); case RE2::ErrorPatternTooLarge: return absl::InvalidArgumentError( absl::StrCat("regular expression too large: ", re.error())); default: return absl::InvalidArgumentError( absl::StrCat("invalid regular expression: ", re.error())); } } int program_size = re.ProgramSize(); if (max_program_size > 0 && program_size > 0 && program_size > max_program_size) { return absl::InvalidArgumentError( "regular expression exceeds max allowed size"); } int reverse_program_size = re.ReverseProgramSize(); if (max_program_size > 0 && reverse_program_size > 0 && reverse_program_size > max_program_size) { return 
absl::InvalidArgumentError( "regular expression exceeds max allowed size"); } return absl::OkStatus(); } } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_RE2_OPTIONS_H_ ================================================ FILE: internal/status_builder.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_STATUS_BUILDER_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_STATUS_BUILDER_H_ #include #include #include "absl/base/attributes.h" #include "absl/status/status.h" namespace cel::internal { class StatusBuilder; template inline constexpr bool StatusBuilderResultMatches = std::is_same_v>, Expected>; template using StatusBuilderPurePolicy = std::enable_if_t< StatusBuilderResultMatches, std::invoke_result_t>; template using StatusBuilderSideEffect = std::enable_if_t, std::invoke_result_t>; template using StatusBuilderConversion = std::enable_if_t< !StatusBuilderResultMatches && !StatusBuilderResultMatches, std::invoke_result_t>; class StatusBuilder final { public: StatusBuilder() = default; explicit StatusBuilder(const absl::Status& status) : status_(status) {} StatusBuilder(const StatusBuilder&) = default; StatusBuilder(StatusBuilder&&) = default; ~StatusBuilder() = default; StatusBuilder& operator=(const StatusBuilder&) = default; StatusBuilder& operator=(StatusBuilder&&) = default; bool ok() const { return status_.ok(); } absl::StatusCode code() const 
{ return status_.code(); } operator absl::Status() const& { return status_; } // NOLINT operator absl::Status() && { return std::move(status_); } // NOLINT template auto With( Adaptor&& adaptor) & -> StatusBuilderPurePolicy { return std::forward(adaptor)(*this); } template ABSL_MUST_USE_RESULT auto With( Adaptor&& adaptor) && -> StatusBuilderPurePolicy { return std::forward(adaptor)(std::move(*this)); } template auto With( Adaptor&& adaptor) & -> StatusBuilderSideEffect { return std::forward(adaptor)(*this); } template ABSL_MUST_USE_RESULT auto With( Adaptor&& adaptor) && -> StatusBuilderSideEffect { return std::forward(adaptor)(std::move(*this)); } template auto With( Adaptor&& adaptor) & -> StatusBuilderConversion { return std::forward(adaptor)(*this); } template ABSL_MUST_USE_RESULT auto With( Adaptor&& adaptor) && -> StatusBuilderConversion { return std::forward(adaptor)(std::move(*this)); } private: absl::Status status_; }; } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_STATUS_BUILDER_H_ ================================================ FILE: internal/status_macros.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_STATUS_MACROS_H_
#define THIRD_PARTY_CEL_CPP_INTERNAL_STATUS_MACROS_H_

#include <utility>

#include "absl/base/optimization.h"
#include "absl/status/status.h"
#include "internal/status_builder.h"

// CEL_RETURN_IF_ERROR(expr)
//
// Evaluates `expr`, which must be convertible to `absl::Status` (it is used to
// initialize a `StatusAdaptor`). If the status is not OK, the enclosing
// function returns it, as a `StatusBuilder&&` that converts to `absl::Status`.
// The ELSE_BLOCKER prefix makes the macro safe to use as the body of an
// unbraced `if`.
#define CEL_RETURN_IF_ERROR(expr)                                             \
  CEL_INTERNAL_STATUS_MACROS_IMPL_ELSE_BLOCKER_                               \
  if (::cel::internal::StatusAdaptor cel_internal_status_macro = {(expr)}) { \
  } else /* NOLINT */                                                         \
    return cel_internal_status_macro.Consume()

// The GNU compiler historically emitted warnings for obscure usages of
// `if (foo) if (bar) {} else`. This suppresses that.
// clang-format off
#define CEL_INTERNAL_STATUS_MACROS_IMPL_ELSE_BLOCKER_ \
  switch (0) case 0: default:  /* NOLINT */
// clang-format on

// CEL_ASSIGN_OR_RETURN(lhs, rexpr)
// CEL_ASSIGN_OR_RETURN(lhs, rexpr, error_expression)
//
// Evaluates `rexpr` into a hidden local and, if `.ok()` is false, returns from
// the enclosing function. The returned value is either the status (2-arg
// form) or `error_expression` (3-arg form), where `_` names a `StatusBuilder`
// holding the status. Otherwise `std::move(result).value()` is assigned to
// `lhs`. `rexpr` is presumably an `absl::StatusOr<T>`; any type with `ok()`,
// `status()` and `value()` fits the expansion. `lhs` may be wrapped in
// parentheses to protect top-level commas (e.g. `(std::pair<int, int> p)`).
// The hidden local's name is derived from `__LINE__`, so at most one use per
// source line is possible.
#define CEL_ASSIGN_OR_RETURN(...)                                   \
  CEL_INTERNAL_STATUS_MACROS_GET_VARIADIC_(                         \
      (__VA_ARGS__, CEL_INTERNAL_STATUS_MACROS_ASSIGN_OR_RETURN_3_, \
       CEL_INTERNAL_STATUS_MACROS_ASSIGN_OR_RETURN_2_))             \
  (__VA_ARGS__)

// The following are macro magic to select either the 2 arg variant or 3 arg
// variant of CEL_ASSIGN_OR_RETURN.
#define CEL_INTERNAL_STATUS_MACROS_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, ...) \
  NAME
#define CEL_INTERNAL_STATUS_MACROS_GET_VARIADIC_(args) \
  CEL_INTERNAL_STATUS_MACROS_GET_VARIADIC_HELPER_ args

#define CEL_INTERNAL_STATUS_MACROS_ASSIGN_OR_RETURN_2_(lhs, rexpr)          \
  CEL_INTERNAL_STATUS_MACROS_ASSIGN_OR_RETURN_(                             \
      CEL_INTERNAL_STATUS_MACROS_CONCAT(_status_or_value, __LINE__), lhs,   \
      rexpr,                                                                \
      return absl::Status(std::move(CEL_INTERNAL_STATUS_MACROS_CONCAT(      \
                                        _status_or_value, __LINE__))        \
                              .status()))
#define CEL_INTERNAL_STATUS_MACROS_ASSIGN_OR_RETURN_3_(lhs, rexpr,            \
                                                       error_expression)     \
  CEL_INTERNAL_STATUS_MACROS_ASSIGN_OR_RETURN_(                              \
      CEL_INTERNAL_STATUS_MACROS_CONCAT(_status_or_value, __LINE__), lhs,    \
      rexpr,                                                                 \
      ::cel::internal::StatusBuilder _(                                      \
          std::move(                                                         \
              CEL_INTERNAL_STATUS_MACROS_CONCAT(_status_or_value, __LINE__)) \
              .status());                                                    \
      (void)_; /* error_expression is allowed to not use this variable */    \
      return (error_expression))

// Common implementation of CEL_ASSIGN_OR_RETURN. Both the 2 arg variant and 3
// arg variant are implemented by this macro.
#define CEL_INTERNAL_STATUS_MACROS_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \
                                                     error_expression)     \
  auto statusor = (rexpr);                                                 \
  if (ABSL_PREDICT_FALSE(!statusor.ok())) {                                \
    error_expression;                                                      \
  }                                                                        \
  CEL_INTERNAL_STATUS_MACROS_UNPARENTHESIZE_IF_PARENTHESIZED(lhs) =        \
      std::move(statusor).value()

// IS_EMPTY(...) expands to 1 when given no tokens and 0 otherwise. The GNU
// `##__VA_ARGS__` extension drops the preceding comma when the argument list
// is empty, which shifts which of the trailing `0, 1` lands in `is_empty`.
#define CEL_INTERNAL_STATUS_MACROS_IS_EMPTY_INNER(...) \
  CEL_INTERNAL_STATUS_MACROS_IS_EMPTY_INNER_HELPER((__VA_ARGS__, 0, 1))
// MSVC historically expands variadic macros incorrectly, so another level of
// indirection is required.
#define CEL_INTERNAL_STATUS_MACROS_IS_EMPTY_INNER_HELPER(args) \
  CEL_INTERNAL_STATUS_MACROS_IS_EMPTY_INNER_I args
#define CEL_INTERNAL_STATUS_MACROS_IS_EMPTY_INNER_I(e0, e1, is_empty, ...) \
  is_empty

#define CEL_INTERNAL_STATUS_MACROS_IS_EMPTY(...) \
  CEL_INTERNAL_STATUS_MACROS_IS_EMPTY_I(__VA_ARGS__)
#define CEL_INTERNAL_STATUS_MACROS_IS_EMPTY_I(...) \
  CEL_INTERNAL_STATUS_MACROS_IS_EMPTY_INNER(_, ##__VA_ARGS__)

// Preprocessor if/else: IF(1, a, b) -> a, IF(0, a, b) -> b.
#define CEL_INTERNAL_STATUS_MACROS_IF_1(_Then, _Else) _Then
#define CEL_INTERNAL_STATUS_MACROS_IF_0(_Then, _Else) _Else
#define CEL_INTERNAL_STATUS_MACROS_IF(_Cond, _Then, _Else) \
  CEL_INTERNAL_STATUS_MACROS_CONCAT(CEL_INTERNAL_STATUS_MACROS_IF_, _Cond) \
  (_Then, _Else)

// EAT discards its arguments, REM re-emits them without the surrounding
// parentheses, and EMPTY() expands to nothing.
#define CEL_INTERNAL_STATUS_MACROS_EAT(...)
#define CEL_INTERNAL_STATUS_MACROS_REM(...) __VA_ARGS__
#define CEL_INTERNAL_STATUS_MACROS_EMPTY()

// Expands to 1 if the input is surrounded by parenthesis, 0 otherwise.
#define CEL_INTERNAL_STATUS_MACROS_IS_PARENTHESIZED(...) \
  CEL_INTERNAL_STATUS_MACROS_IS_EMPTY(                   \
      CEL_INTERNAL_STATUS_MACROS_EAT __VA_ARGS__)

// If the input is surrounded by parenthesis, remove them. Otherwise expand it
// unchanged.
#define CEL_INTERNAL_STATUS_MACROS_UNPARENTHESIZE_IF_PARENTHESIZED(...)   \
  CEL_INTERNAL_STATUS_MACROS_IF(                                          \
      CEL_INTERNAL_STATUS_MACROS_IS_PARENTHESIZED(__VA_ARGS__),           \
      CEL_INTERNAL_STATUS_MACROS_REM, CEL_INTERNAL_STATUS_MACROS_EMPTY()) \
  __VA_ARGS__

// Token pasting with an extra expansion step so that arguments such as
// `__LINE__` are expanded before being pasted.
#define CEL_INTERNAL_STATUS_MACROS_CONCAT_HELPER(x, y) x##y
#define CEL_INTERNAL_STATUS_MACROS_CONCAT(x, y) \
  CEL_INTERNAL_STATUS_MACROS_CONCAT_HELPER(x, y)

namespace cel::internal {

// Helper used by CEL_RETURN_IF_ERROR. It wraps a status in a StatusBuilder so
// it can be tested in an `if` condition and then returned.
class StatusAdaptor final {
 public:
  StatusAdaptor() = default;
  StatusAdaptor(const StatusAdaptor&) = delete;
  StatusAdaptor(StatusAdaptor&&) = delete;
  // Deliberately implicit: the macro copy-list-initializes the adaptor from
  // `{(expr)}`, which requires a non-explicit constructor.
  StatusAdaptor(const absl::Status& status) : builder_(status) {}  // NOLINT

  StatusAdaptor& operator=(const StatusAdaptor&) = delete;
  StatusAdaptor& operator=(StatusAdaptor&&) = delete;

  // Hands the wrapped status out as an rvalue builder for the `return`.
  StatusBuilder&& Consume() { return std::move(builder_); }

  // True when the wrapped status is OK (the expected, fast path).
  explicit operator bool() const { return ABSL_PREDICT_TRUE(builder_.ok()); }

 private:
  StatusBuilder builder_;
};

}  // namespace cel::internal

#endif  // THIRD_PARTY_CEL_CPP_INTERNAL_STATUS_MACROS_H_

================================================
FILE: internal/string_pool.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache
License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "internal/string_pool.h" #include #include #include #include "absl/base/optimization.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "google/protobuf/arena.h" namespace cel::internal { absl::string_view StringPool::InternString(absl::string_view string) { if (string.empty()) { return ""; } return *strings_.lazy_emplace(string, [&](const auto& ctor) { char* data = reinterpret_cast(arena()->AllocateAligned(string.size())); std::memcpy(data, string.data(), string.size()); ctor(absl::string_view(data, string.size())); }); } absl::string_view StringPool::InternString(std::string&& string) { if (string.empty()) { return ""; } return *strings_.lazy_emplace(string, [&](const auto& ctor) { if (string.size() <= sizeof(std::string)) { char* data = reinterpret_cast(arena()->AllocateAligned(string.size())); std::memcpy(data, string.data(), string.size()); ctor(absl::string_view(data, string.size())); } else { google::protobuf::Arena* arena = this->arena(); ABSL_ASSUME(arena != nullptr); ctor(absl::string_view( *google::protobuf::Arena::Create(arena, std::move(string)))); } }); } absl::string_view StringPool::InternString(const absl::Cord& string) { if (string.empty()) { return ""; } return *strings_.lazy_emplace(string, [&](const auto& ctor) { char* data = reinterpret_cast(arena()->AllocateAligned(string.size())); absl::Cord::CharIterator string_begin = string.char_begin(); const absl::Cord::CharIterator string_end = 
string.char_end(); char* p = data; while (string_begin != string_end) { absl::string_view chunk = absl::Cord::ChunkRemaining(string_begin); std::memcpy(p, chunk.data(), chunk.size()); p += chunk.size(); absl::Cord::Advance(&string_begin, chunk.size()); } ctor(absl::string_view(data, string.size())); }); } } // namespace cel::internal ================================================ FILE: internal/string_pool.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_STRING_POOL_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_STRING_POOL_H_ #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_set.h" #include "absl/log/die_if_null.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "google/protobuf/arena.h" namespace cel::internal { // `StringPool` efficiently performs string interning using `google::protobuf::Arena`. // // This class is thread compatible, but typically requires external // synchronization or serial usage. 
class StringPool final { public: explicit StringPool( google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND) : arena_(ABSL_DIE_IF_NULL(arena)) {} // Crash OK google::protobuf::Arena* absl_nonnull arena() const { return arena_; } absl::string_view InternString(const char* absl_nullable string) { return InternString(absl::NullSafeStringView(string)); } absl::string_view InternString(absl::string_view string); absl::string_view InternString(std::string&& string); absl::string_view InternString(const absl::Cord& string); private: google::protobuf::Arena* absl_nonnull const arena_; absl::flat_hash_set strings_; }; } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_STRING_POOL_H_ ================================================ FILE: internal/string_pool_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "internal/string_pool.h"

#include <string>
#include <utility>

#include "absl/strings/string_view.h"
#include "internal/testing.h"
#include "google/protobuf/arena.h"

namespace cel::internal {
namespace {

TEST(StringPool, EmptyString) {
  google::protobuf::Arena arena;
  StringPool string_pool(&arena);
  absl::string_view interned_string = string_pool.InternString("");
  EXPECT_EQ(interned_string.data(), string_pool.InternString("").data());
}

TEST(StringPool, InternString) {
  google::protobuf::Arena arena;
  StringPool string_pool(&arena);
  absl::string_view interned_string = string_pool.InternString("Hello, world!");
  EXPECT_EQ(interned_string.data(),
            string_pool.InternString("Hello, world!").data());
}

// Short rvalue strings take the copy-to-arena path of the `std::string&&`
// overload.
TEST(StringPool, InternShortRvalueString) {
  google::protobuf::Arena arena;
  StringPool string_pool(&arena);
  std::string string = "Hello";
  absl::string_view interned_string =
      string_pool.InternString(std::move(string));
  EXPECT_EQ(interned_string, "Hello");
  EXPECT_EQ(interned_string.data(), string_pool.InternString("Hello").data());
}

// Long rvalue strings take the move-into-arena path of the `std::string&&`
// overload.
TEST(StringPool, InternLongRvalueString) {
  google::protobuf::Arena arena;
  StringPool string_pool(&arena);
  const std::string expected(64, 'x');
  std::string string = expected;
  absl::string_view interned_string =
      string_pool.InternString(std::move(string));
  EXPECT_EQ(interned_string, expected);
  EXPECT_EQ(interned_string.data(),
            string_pool.InternString(absl::string_view(expected)).data());
}

}  // namespace
}  // namespace cel::internal

================================================
FILE: internal/strings.cc
================================================
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "internal/strings.h"

#include <cstddef>
#include <string>

#include "absl/base/attributes.h"
#include "absl/base/optimization.h"
#include "absl/status/status.h"
#include "absl/strings/ascii.h"
#include "absl/strings/cord.h"
#include "absl/strings/escaping.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "internal/lexis.h"
#include "internal/unicode.h"
#include "internal/utf8.h"

namespace cel::internal {

namespace {

// Lowercase hexadecimal digits, indexed by nibble value.
constexpr char kHexTable[] = "0123456789abcdef";

// Returns the value of the hexadecimal digit `x`. `x` must satisfy
// `absl::ascii_isxdigit`.
constexpr int HexDigitToInt(char x) {
  // 'A'-'F' and 'a'-'f' become 0xA-0xF after adding 9 and keeping the low
  // nibble; '0'-'9' already carry their value in the low nibble.
  if (x > '9') {
    x += 9;
  }
  return x & 0xf;
}

constexpr bool IsOctalDigit(char x) { return x >= '0' && x <= '7'; }

// Returns true when the following conditions are met:
//  - `closing_str` is a suffix of `source`.
//  - No other unescaped occurrence of `closing_str` appears inside `source`
//    (apart from being a suffix).
// An empty `closing_str` always yields true. Returns false otherwise; if
// `error` is non-null, a description of the problem is stored in `*error`.
bool CheckForClosingString(absl::string_view source,
                           absl::string_view closing_str, std::string* error) {
  if (closing_str.empty()) return true;

  // Index-based scan: skipping an escaped character never forms a position
  // outside `source` (pointer arithmetic past one-past-end would be UB).
  const size_t closing_len = closing_str.size();
  bool is_closed = false;
  size_t pos = 0;
  while (pos + closing_len <= source.size()) {
    if (source[pos] != '\\') {
      const bool is_closing = absl::StartsWith(source.substr(pos), closing_str);
      if (is_closing && pos + closing_len < source.size()) {
        if (error) {
          *error =
              absl::StrCat("String cannot contain unescaped ", closing_str);
        }
        return false;
      }
      is_closed = is_closing && (pos + closing_len == source.size());
    } else {
      ++pos;  // Read past the escaped character.
    }
    ++pos;
  }

  if (!is_closed) {
    if (error) {
      *error = absl::StrCat("String must end with ", closing_str);
    }
    return false;
  }

  return true;
}

// ----------------------------------------------------------------------
// UnescapeInternal()
//    Unescapes CEL escape sequences in `source` and stores the result in
//    `*dest`, replacing its previous contents.
//
//    Returns true on success. Otherwise returns false and, if `error` is
//    non-null, stores a description of the problem in `*error`.
//
//    If `closing_str` is non-empty, for `source` to be valid:
//      - It must end with `closing_str`, which is stripped before unescaping.
//      - It must not contain any other unescaped occurrence of `closing_str`.
//
//    Unescaped '\r' and "\r\n" are normalized to '\n'.
//    If `is_raw_literal` is true, a backslash and the character after it are
//    copied through verbatim instead of being interpreted.
//    If `is_bytes_literal` is true, `source` need not be valid UTF-8, octal
//    and hex escapes produce a single byte, and \u / \U escapes are rejected.
//    Otherwise octal and hex escapes denote code points that are UTF-8
//    encoded.
//
//    The unescaped string is never longer than `source`, since no unescaped
//    sequence is longer than the corresponding escape sequence.
// ----------------------------------------------------------------------
bool UnescapeInternal(absl::string_view source, absl::string_view closing_str,
                      bool is_raw_literal, bool is_bytes_literal,
                      std::string* dest, std::string* error) {
  if (!CheckForClosingString(source, closing_str, error)) {
    return false;
  }

  dest->clear();
  if (ABSL_PREDICT_FALSE(source.empty())) {
    return true;
  }

  // Strip off the closing_str from the end before unescaping.
  source = source.substr(0, source.size() - closing_str.size());
  if (!is_bytes_literal) {
    if (!Utf8IsValid(source)) {
      if (error) {
        *error = absl::StrCat("Structurally invalid UTF8 string: ",
                              EscapeBytes(source));
      }
      return false;
    }
  }

  dest->reserve(source.size());

  const char* p = source.data();
  const char* end = p + source.size();

  while (p < end) {
    if (*p != '\\') {
      if (*p != '\r') {
        dest->push_back(*p++);
      } else {
        // All types of newlines in different platforms i.e. '\r', '\n', '\r\n'
        // are replaced with '\n'.
        dest->push_back('\n');
        p++;
        if (p < end && *p == '\n') {
          p++;
        }
      }
      continue;
    }

    // A backslash must be followed by at least one more character.
    if (p + 1 >= end) {
      if (error) {
        *error = is_raw_literal ? "Raw literals cannot end with odd number of \\"
                 : is_bytes_literal ? "Bytes literal cannot end with \\"
                                    : "String literal cannot end with \\";
      }
      return false;
    }
    if (is_raw_literal) {
      // For raw literals, all escapes are valid and those characters ('\\'
      // and the escaped character) come through literally in the string.
      dest->push_back(*p++);
      dest->push_back(*p++);
      continue;
    }
    p++;  // Read past the escape character.

    switch (*p) {
      case 'a':  dest->push_back('\a'); break;
      case 'b':  dest->push_back('\b'); break;
      case 'f':  dest->push_back('\f'); break;
      case 'n':  dest->push_back('\n'); break;
      case 'r':  dest->push_back('\r'); break;
      case 't':  dest->push_back('\t'); break;
      case 'v':  dest->push_back('\v'); break;
      case '\\': dest->push_back('\\'); break;
      case '?':  dest->push_back('\?'); break;  // \? Who knew?
      case '\'': dest->push_back('\''); break;
      case '"':  dest->push_back('\"'); break;
      case '`':  dest->push_back('`'); break;
      case '0':
        ABSL_FALLTHROUGH_INTENDED;
      case '1':
        ABSL_FALLTHROUGH_INTENDED;
      case '2':
        ABSL_FALLTHROUGH_INTENDED;
      case '3': {
        // Octal escape '\ddd': requires exactly 3 octal digits. Note that
        // the highest valid escape sequence is '\377'.
        // For string literals, octal and hex escape sequences are interpreted
        // as unicode code points, and the related UTF8-encoded character is
        // added to the destination. For bytes literals, octal and hex
        // escape sequences are interpreted as a single byte value.
        const char* octal_start = p;
        if (p + 2 >= end) {
          if (error) {
            *error =
                "Illegal escape sequence: Octal escape must be followed by 3 "
                "octal digits but saw: \\" +
                std::string(octal_start, end - p);
          }
          return false;
        }
        const char* octal_end = p + 2;
        char32_t ch = 0;
        for (; p <= octal_end; ++p) {
          if (IsOctalDigit(*p)) {
            ch = ch * 8 + *p - '0';
          } else {
            if (error) {
              *error =
                  "Illegal escape sequence: Octal escape must be followed by "
                  "3 octal digits but saw: \\" +
                  std::string(octal_start, 3);
            }
            return false;
          }
        }
        p = octal_end;  // p points at last digit.
        if (is_bytes_literal) {
          dest->push_back(static_cast<char>(ch));
        } else {
          Utf8Encode(*dest, ch);
        }
        break;
      }
      case 'x':
        ABSL_FALLTHROUGH_INTENDED;
      case 'X': {
        // Hex escape '\xhh': requires exactly 2 hex digits.
        // For string literals, octal and hex escape sequences are
        // interpreted as unicode code points, and the related UTF8-encoded
        // character is added to the destination. For bytes literals, octal
        // and hex escape sequences are interpreted as a single byte value.
        const char* hex_start = p;
        if (p + 2 >= end) {
          if (error) {
            *error =
                "Illegal escape sequence: Hex escape must be followed by 2 "
                "hex digits but saw: \\" +
                std::string(hex_start, end - p);
          }
          return false;
        }
        char32_t ch = 0;
        const char* hex_end = p + 2;
        for (++p; p <= hex_end; ++p) {
          if (absl::ascii_isxdigit(*p)) {
            ch = (ch << 4) + HexDigitToInt(*p);
          } else {
            if (error) {
              *error =
                  "Illegal escape sequence: Hex escape must be followed by 2 "
                  "hex digits but saw: \\" +
                  std::string(hex_start, 3);
            }
            return false;
          }
        }
        p = hex_end;  // p points at last digit.
        if (is_bytes_literal) {
          dest->push_back(static_cast<char>(ch));
        } else {
          Utf8Encode(*dest, ch);
        }
        break;
      }
      case 'u': {
        if (is_bytes_literal) {
          if (error) {
            *error =
                std::string(
                    "Illegal escape sequence: Unicode escape sequence \\") +
                *p + " cannot be used in bytes literals";
          }
          return false;
        }
        // \uhhhh => Read 4 hex digits as a code point,
        //           then write it as UTF-8 bytes.
        char32_t cp = 0;
        const char* hex_start = p;
        if (p + 4 >= end) {
          if (error) {
            *error =
                "Illegal escape sequence: \\u must be followed by 4 hex "
                "digits but saw: \\" +
                std::string(hex_start, end - p);
          }
          return false;
        }
        for (int i = 0; i < 4; ++i) {
          // Look one char ahead.
          if (absl::ascii_isxdigit(p[1])) {
            cp = (cp << 4) + HexDigitToInt(*++p);  // Advance p.
          } else {
            if (error) {
              *error =
                  "Illegal escape sequence: \\u must be followed by 4 "
                  "hex digits but saw: \\" +
                  std::string(hex_start, 5);
            }
            return false;
          }
        }
        if (!UnicodeIsValid(cp)) {
          if (error) {
            *error = "Illegal escape sequence: Unicode value \\" +
                     std::string(hex_start, 5) + " is invalid";
          }
          return false;
        }
        Utf8Encode(*dest, cp);
        break;
      }
      case 'U': {
        if (is_bytes_literal) {
          if (error) {
            *error =
                std::string(
                    "Illegal escape sequence: Unicode escape sequence \\") +
                *p + " cannot be used in bytes literals";
          }
          return false;
        }
        // \Uhhhhhhhh => convert 8 hex digits to UTF-8. Note that the
        // first two digits must be 00: The valid range is
        // '\U00000000' to '\U0010FFFF' (excluding surrogates).
        char32_t cp = 0;
        const char* hex_start = p;
        if (p + 8 >= end) {
          if (error) {
            *error =
                "Illegal escape sequence: \\U must be followed by 8 hex "
                "digits but saw: \\" +
                std::string(hex_start, end - p);
          }
          return false;
        }
        for (int i = 0; i < 8; ++i) {
          // Look one char ahead.
          if (absl::ascii_isxdigit(p[1])) {
            cp = (cp << 4) + HexDigitToInt(*++p);
            if (cp > 0x10FFFF) {
              if (error) {
                *error = "Illegal escape sequence: Value of \\" +
                         std::string(hex_start, 9) +
                         " exceeds Unicode limit (0x0010FFFF)";
              }
              return false;
            }
          } else {
            if (error) {
              *error =
                  "Illegal escape sequence: \\U must be followed by 8 "
                  "hex digits but saw: \\" +
                  std::string(hex_start, 9);
            }
            return false;
          }
        }
        if (!UnicodeIsValid(cp)) {
          if (error) {
            *error = "Illegal escape sequence: Unicode value \\" +
                     std::string(hex_start, 9) + " is invalid";
          }
          return false;
        }
        Utf8Encode(*dest, cp);
        break;
      }
      case '\r':
        ABSL_FALLTHROUGH_INTENDED;
      case '\n': {
        if (error) {
          *error = "Illegal escaped newline";
        }
        return false;
      }
      default: {
        if (error) {
          *error = std::string("Illegal escape sequence: \\") + *p;
        }
        return false;
      }
    }
    p++;  // read past letter we escaped
  }

  dest->shrink_to_fit();
  return true;
}

// Escapes `src` with CEL escape sequences, without adding surrounding quotes.
//
// '\n', '\r', '\t' and '\\' use their short escapes; other unprintable bytes
// are hex escaped as \xNN (lowercase). If `utf8_safe` is true, bytes >= 0x80
// are copied through unchanged so UTF-8 text survives; otherwise they are hex
// escaped too. If `escape_quote_char` is 0, every quote character (', ", `) is
// escaped; otherwise only occurrences of `escape_quote_char` are.
std::string EscapeInternal(absl::string_view src, bool utf8_safe,
                           char escape_quote_char) {
  std::string dest;
  // Worst case size is every byte has to be hex escaped, so 4 char for every
  // byte.
  dest.reserve(src.size() * 4);
  bool last_hex_escape = false;  // true if last output char was \xNN.

  for (const char ch : src) {
    const unsigned char c = static_cast<unsigned char>(ch);
    bool is_hex_escape = false;
    switch (c) {
      case '\n': dest.append("\\n"); break;
      case '\r': dest.append("\\r"); break;
      case '\t': dest.append("\\t"); break;
      case '\\': dest.append("\\\\"); break;
      case '\'':
        ABSL_FALLTHROUGH_INTENDED;
      case '\"':
        ABSL_FALLTHROUGH_INTENDED;
      case '`':
        // Escape only quote chars that match escape_quote_char.
        if (escape_quote_char == 0 || c == escape_quote_char) {
          dest.push_back('\\');
        }
        dest.push_back(ch);
        break;
      default:
        // Note that if we emit \xNN and the src character after that is a hex
        // digit then that digit must be escaped too to prevent it being
        // interpreted as part of the character code by C.
        if ((!utf8_safe || c < 0x80) &&
            (!absl::ascii_isprint(c) ||
             (last_hex_escape && absl::ascii_isxdigit(c)))) {
          dest.append("\\x");
          dest.push_back(kHexTable[c / 16]);
          dest.push_back(kHexTable[c % 16]);
          is_hex_escape = true;
        } else {
          dest.push_back(ch);
        }
        break;
    }
    last_hex_escape = is_hex_escape;
  }
  dest.shrink_to_fit();
  return dest;
}

bool MayBeTripleQuotedString(absl::string_view str) {
  return (str.size() >= 6 &&
          ((absl::StartsWith(str, "\"\"\"") && absl::EndsWith(str, "\"\"\"")) ||
           (absl::StartsWith(str, "'''") && absl::EndsWith(str, "'''"))));
}

bool MayBeStringLiteral(absl::string_view str) {
  return (str.size() >= 2 && str[0] == str[str.size() - 1] &&
          (str[0] == '\'' || str[0] == '"'));
}

bool MayBeBytesLiteral(absl::string_view str) {
  return (str.size() >= 3 && absl::StartsWithIgnoreCase(str, "b") &&
          str[1] == str[str.size() - 1] && (str[1] == '\'' || str[1] == '"'));
}

bool MayBeRawStringLiteral(absl::string_view str) {
  return (str.size() >= 3 && absl::StartsWithIgnoreCase(str, "r") &&
          str[1] == str[str.size() - 1] && (str[1] == '\'' || str[1] == '"'));
}

bool MayBeRawBytesLiteral(absl::string_view str) {
  return (str.size() >= 4 &&
          (absl::StartsWithIgnoreCase(str, "rb") ||
           absl::StartsWithIgnoreCase(str, "br")) &&
          (str[2] == str[str.size() - 1]) && (str[2] == '\'' || str[2] == '"'));
}

}  // namespace

absl::StatusOr<std::string> UnescapeString(absl::string_view str) {
  std::string out;
  std::string error;
  if (!UnescapeInternal(str, "", false, false, &out, &error)) {
    return absl::InvalidArgumentError(
        absl::StrCat("Invalid escaped string: ", error));
  }
  return out;
}

absl::StatusOr<std::string> UnescapeBytes(absl::string_view str) {
  std::string out;
  std::string error;
  if (!UnescapeInternal(str, "", false, true, &out, &error)) {
    return absl::InvalidArgumentError(
        absl::StrCat("Invalid escaped bytes: ", error));
  }
  return out;
}

std::string EscapeString(absl::string_view str) {
  return EscapeInternal(str, /*utf8_safe=*/true, '\0');
}

std::string EscapeBytes(absl::string_view str, bool escape_all_bytes,
                        char escape_quote_char) {
  std::string escaped_bytes;
  escaped_bytes.reserve(str.size());
  for (const char ch : str) {
    const unsigned char c = static_cast<unsigned char>(ch);
    if (escape_all_bytes || !absl::ascii_isprint(c)) {
      // Lowercase \xNN, built from the table to avoid a temporary string per
      // byte.
      escaped_bytes += "\\x";
      escaped_bytes += kHexTable[c / 16];
      escaped_bytes += kHexTable[c % 16];
    } else {
      switch (c) {
        // Note that we only handle printable escape characters here. All
        // unprintable (\n, \r, \t, etc.) are hex escaped above.
        case '\\':
          escaped_bytes += "\\\\";
          break;
        case '\'':
        case '"':
        case '`':
          // Escape only quote chars that match escape_quote_char.
          if (escape_quote_char == 0 || c == escape_quote_char) {
            escaped_bytes += '\\';
          }
          escaped_bytes += ch;
          break;
        default:
          escaped_bytes += ch;
          break;
      }
    }
  }
  return escaped_bytes;
}

absl::StatusOr<std::string> ParseStringLiteral(absl::string_view str) {
  std::string out;
  bool is_string_literal = MayBeStringLiteral(str);
  bool is_raw_string_literal = MayBeRawStringLiteral(str);
  if (!is_string_literal && !is_raw_string_literal) {
    return absl::InvalidArgumentError("Invalid string literal");
  }

  absl::string_view copy_str = str;
  if (is_raw_string_literal) {
    // Strip off the prefix 'r' from the raw string content before parsing.
    copy_str = absl::ClippedSubstr(copy_str, 1);
  }

  bool is_triple_quoted = MayBeTripleQuotedString(copy_str);
  // Skip the opening quotes {""", '''} or {", '}. The closing quotes remain
  // and are verified and stripped by UnescapeInternal.
  int quotes_length = is_triple_quoted ? 3 : 1;
  absl::string_view quotes = copy_str.substr(0, quotes_length);
  copy_str = absl::ClippedSubstr(copy_str, quotes_length);
  std::string error;
  if (!UnescapeInternal(copy_str, quotes, is_raw_string_literal, false, &out,
                        &error)) {
    return absl::InvalidArgumentError(
        absl::StrCat("Invalid string literal: ", error));
  }
  return out;
}

absl::StatusOr<std::string> ParseBytesLiteral(absl::string_view str) {
  std::string out;
  bool is_bytes_literal = MayBeBytesLiteral(str);
  bool is_raw_bytes_literal = MayBeRawBytesLiteral(str);
  if (!is_bytes_literal && !is_raw_bytes_literal) {
    return absl::InvalidArgumentError("Invalid bytes literal");
  }

  absl::string_view copy_str = str;
  if (is_raw_bytes_literal) {
    // Strip off the prefix {"rb", "br"} from the raw bytes content before
    // parsing.
    copy_str = absl::ClippedSubstr(copy_str, 2);
  } else {
    // Strip off the prefix 'b' from the bytes content before parsing.
    copy_str = absl::ClippedSubstr(copy_str, 1);
  }

  bool is_triple_quoted = MayBeTripleQuotedString(copy_str);
  // Skip the opening quotes {""", '''} or {", '}. The closing quotes remain
  // and are verified and stripped by UnescapeInternal.
  int quotes_length = is_triple_quoted ? 3 : 1;
  absl::string_view quotes = copy_str.substr(0, quotes_length);
  copy_str = absl::ClippedSubstr(copy_str, quotes_length);
  std::string error;
  if (!UnescapeInternal(copy_str, quotes, is_raw_bytes_literal, true, &out,
                        &error)) {
    return absl::InvalidArgumentError(
        absl::StrCat("Invalid bytes literal: ", error));
  }
  return out;
}

std::string FormatStringLiteral(absl::string_view str) {
  absl::string_view quote =
      (str.find('"') != str.npos && str.find('\'') == str.npos) ? "'" : "\"";
  return absl::StrCat(quote, EscapeInternal(str, /*utf8_safe=*/true, quote[0]),
                      quote);
}

std::string FormatStringLiteral(const absl::Cord& str) {
  if (auto flat = str.TryFlat(); flat) {
    return FormatStringLiteral(*flat);
  }
  return FormatStringLiteral(static_cast<std::string>(str));
}

std::string FormatSingleQuotedStringLiteral(absl::string_view str) {
  return absl::StrCat("'", EscapeInternal(str, /*utf8_safe=*/true, '\''), "'");
}

std::string FormatDoubleQuotedStringLiteral(absl::string_view str) {
  return absl::StrCat("\"", EscapeInternal(str, /*utf8_safe=*/true, '"'),
                      "\"");
}

std::string FormatBytesLiteral(absl::string_view str) {
  absl::string_view quote =
      (str.find('"') != str.npos && str.find('\'') == str.npos) ? "'" : "\"";
  return absl::StrCat("b", quote, EscapeBytes(str, false, quote[0]), quote);
}

std::string FormatSingleQuotedBytesLiteral(absl::string_view str) {
  return absl::StrCat("b'", EscapeBytes(str, false, '\''), "'");
}

std::string FormatDoubleQuotedBytesLiteral(absl::string_view str) {
  return absl::StrCat("b\"", EscapeBytes(str, false, '"'), "\"");
}

absl::StatusOr<std::string> ParseIdentifier(absl::string_view str) {
  if (!LexisIsIdentifier(str)) {
    return absl::InvalidArgumentError("Invalid identifier");
  }
  return std::string(str);
}

}  // namespace cel::internal

================================================
FILE: internal/strings.h
================================================
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_STRINGS_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_STRINGS_H_ #include #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" namespace cel::internal { // Expand escaped characters according to CEL escaping rules. // This is for raw strings with no quoting. absl::StatusOr UnescapeString(absl::string_view str); // Expand escaped characters according to CEL escaping rules. // Rules for bytes values are slightly different than those for strings. This // is for raw literals with no quoting. absl::StatusOr UnescapeBytes(absl::string_view str); // Escape a string without quoting it. All quote characters are escaped. std::string EscapeString(absl::string_view str); // Escape a bytes value without quoting it. Escaped bytes use hex escapes. // If is true then all bytes are escaped. Otherwise only // unprintable bytes and escape/quote characters are escaped. // If is not 0, then quotes that do not match are not // escaped. std::string EscapeBytes(absl::string_view str, bool escape_all_bytes = false, char escape_quote_char = '\0'); // Unquote and unescape a quoted CEL string literal (of the form '...', // "...", r'...' or r"..."). // If an error occurs and is not NULL, then it is populated with // the relevant error message. If is not NULL, it is populated // with the offset in at which the invalid input occurred. absl::StatusOr ParseStringLiteral(absl::string_view str); // Unquote and unescape a CEL bytes literal (of the form b'...', // b"...", rb'...', rb"...", br'...' or br"..."). // If an error occurs and is not NULL, then it is populated with // the relevant error message. If is not NULL, it is populated // with the offset in at which the invalid input occurred. absl::StatusOr ParseBytesLiteral(absl::string_view str); // Return a quoted and escaped CEL string literal for . // May choose to quote with ' or " to produce nicer output. 
std::string FormatStringLiteral(absl::string_view str); std::string FormatStringLiteral(const absl::Cord& str); // Return a quoted and escaped CEL string literal for . // Always uses single quotes. std::string FormatSingleQuotedStringLiteral(absl::string_view str); // Return a quoted and escaped CEL string literal for . // Always uses double quotes. std::string FormatDoubleQuotedStringLiteral(absl::string_view str); // Return a quoted and escaped CEL bytes literal for . // Prefixes with b and may choose to quote with ' or " to produce nicer output. std::string FormatBytesLiteral(absl::string_view str); // Return a quoted and escaped CEL bytes literal for . // Prefixes with b and always uses single quotes. std::string FormatSingleQuotedBytesLiteral(absl::string_view str); // Return a quoted and escaped CEL bytes literal for . // Prefixes with b and always uses double quotes. std::string FormatDoubleQuotedBytesLiteral(absl::string_view str); // Parse a CEL identifier. absl::StatusOr ParseIdentifier(absl::string_view str); } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_STRINGS_H_ ================================================ FILE: internal/strings_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "internal/strings.h"

#include <cstddef>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/ascii.h"
#include "absl/strings/cord.h"
#include "absl/strings/cord_test_helpers.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "internal/testing.h"
#include "internal/utf8.h"

namespace cel::internal {
namespace {

using ::absl_testing::StatusIs;

constexpr char kUnicodeNotAllowedInBytes1[] =
    "Unicode escape sequence \\u cannot be used in bytes literals";
constexpr char kUnicodeNotAllowedInBytes2[] =
    "Unicode escape sequence \\U cannot be used in bytes literals";

// `quoted` takes a string literal of the form '...', r'...', "..." or r"...".
// `unquoted` is the expected parsed form of `quoted`.
void TestQuotedString(const std::string& unquoted, const std::string& quoted) {
  auto status_or_unquoted = ParseStringLiteral(quoted);
  // Must be fatal: calling value() on an error status does not produce an
  // ordinary test failure (it throws or aborts, depending on the build).
  ASSERT_THAT(status_or_unquoted, ::absl_testing::IsOk()) << unquoted;
  EXPECT_EQ(unquoted, status_or_unquoted.value()) << quoted;
}

// Round-trips `unquoted` through every string-literal formatter: plain and
// Cord-backed (flat and fragmented) FormatStringLiteral, plus triple-quoted
// forms built from EscapeString.
void TestString(const std::string& unquoted) {
  TestQuotedString(unquoted, FormatStringLiteral(unquoted));
  TestQuotedString(unquoted, FormatStringLiteral(absl::Cord(unquoted)));
  if (unquoted.size() > 1) {
    const size_t mid = unquoted.size() / 2;
    TestQuotedString(unquoted,
                     FormatStringLiteral(absl::MakeFragmentedCord(
                         {absl::string_view(unquoted).substr(0, mid),
                          absl::string_view(unquoted).substr(mid)})));
  }
  TestQuotedString(unquoted,
                   absl::StrCat("'''", EscapeString(unquoted), "'''"));
  TestQuotedString(unquoted,
                   absl::StrCat("\"\"\"", EscapeString(unquoted), "\"\"\""));
}

// Checks that `unquoted` survives being wrapped verbatim in each raw string
// literal form. The single-quoted forms assume `unquoted` does not also
// contain the chosen quote character.
void TestRawString(const std::string& unquoted) {
  const std::string quote = (!absl::StrContains(unquoted, "'")) ? "'" : "\"";
  TestQuotedString(unquoted, absl::StrCat("r", quote, unquoted, quote));
  TestQuotedString(unquoted, absl::StrCat("r\"", unquoted, "\""));
  TestQuotedString(unquoted, absl::StrCat("r'''", unquoted, "'''"));
  TestQuotedString(unquoted, absl::StrCat("r\"\"\"", unquoted, "\"\"\""));
}

// `quoted` is the quoted version of `unquoted` and represents the original
// string mentioned in the test case (it is not otherwise used here).
// This method compares the unescaped `unquoted` against its round trip
// version i.e. after carrying out escaping followed by unescaping on it.
void TestBytesEscaping(const std::string& unquoted, const std::string& quoted) {
  ASSERT_OK_AND_ASSIGN(auto unescaped, UnescapeBytes(unquoted));
  const std::string escaped = EscapeBytes(unescaped);
  ASSERT_OK_AND_ASSIGN(auto unescaped2, UnescapeBytes(escaped));
  EXPECT_EQ(unescaped, unescaped2);

  std::string escaped2 = EscapeBytes(unescaped, true);
  ASSERT_OK_AND_ASSIGN(auto unescaped3, UnescapeBytes(escaped2));
  EXPECT_EQ(unescaped, unescaped3);
}

// `quoted` takes a byte literal of the form b'...', b'''...'''
void TestBytesLiteral(const std::string& quoted) {
  // Parse the literal.
  ASSERT_OK_AND_ASSIGN(auto unquoted, ParseBytesLiteral(quoted));
  // Take the parsed literal and turn it back to a literal.
  std::string requoted = FormatBytesLiteral(unquoted);
  // Parse it again.
  ASSERT_OK_AND_ASSIGN(auto unquoted2, ParseBytesLiteral(requoted));
  // Test the parsed literal forms for equality, not the unparsed forms.
  // This is because the unparsed forms can have different representations for
  // the same data, i.e., \000 and \x00.
  EXPECT_EQ(unquoted, unquoted2) << "unquoted : " << unquoted
                                 << "\nunquoted2: " << unquoted2;

  TestBytesEscaping(unquoted, quoted);
}

// `quoted` takes a raw byte literal of the form rb'...', br'...', rb'''...'''
// or br'''...'''. `unquoted` is the expected parsed form of `quoted`.
void TestQuotedRawBytesLiteral(const std::string& unquoted, const std::string& quoted) { ASSERT_OK_AND_ASSIGN(auto actual_unquoted, ParseBytesLiteral(quoted)); EXPECT_EQ(unquoted, actual_unquoted) << "quoted: " << quoted; } // takes a string of not escaped unquoted bytes. void TestUnescapedBytes(const std::string& unquoted) { TestBytesLiteral(FormatBytesLiteral(unquoted)); } void TestRawBytes(const std::string& unquoted) { const std::string quote = (!absl::StrContains(unquoted, "'")) ? "'" : "\""; TestQuotedRawBytesLiteral(unquoted, absl::StrCat("rb", quote, unquoted, quote)); TestQuotedRawBytesLiteral(unquoted, absl::StrCat("br", quote, unquoted, quote)); TestQuotedRawBytesLiteral(unquoted, absl::StrCat("rb'''", unquoted, "'''")); TestQuotedRawBytesLiteral(unquoted, absl::StrCat("br'''", unquoted, "'''")); TestQuotedRawBytesLiteral(unquoted, absl::StrCat("rb\"\"\"", unquoted, "\"\"\"")); TestQuotedRawBytesLiteral(unquoted, absl::StrCat("br\"\"\"", unquoted, "\"\"\"")); } void TestParseString(const std::string& orig) { EXPECT_OK(ParseStringLiteral(orig)) << orig; } void TestParseBytes(const std::string& orig) { EXPECT_OK(ParseBytesLiteral(orig)) << orig; } void TestStringEscaping(const std::string& orig) { const std::string escaped = EscapeString(orig); ASSERT_OK_AND_ASSIGN(auto unescaped, UnescapeString(escaped)); EXPECT_EQ(orig, unescaped) << "escaped: " << escaped; } void TestValue(const std::string& orig) { TestStringEscaping(orig); TestString(orig); } // Test that is treated as invalid, with error offset // and an error that contains substring // . The last arguments are optional because most // flat-out bad inputs are rejected without further information. 
void TestInvalidString(const std::string& str, const std::string& expected_error_substr = "") { auto status_or_string = ParseStringLiteral(str); EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_TRUE(absl::StrContains(status_or_string.status().message(), expected_error_substr)); } // Test that is treated as invalid, with error offset // and an error that contains substring // . The last arguments are optional because most // flat-out bad inputs are rejected without further information. void TestInvalidBytes(const std::string& str, const std::string& expected_error_substr = "") { auto status_or_string = ParseBytesLiteral(str); EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_TRUE(absl::StrContains(status_or_string.status().message(), expected_error_substr)); } TEST(StringsTest, TestParsingOfAllEscapeCharacters) { // All the valid escapes. const std::set valid_escapes = {'a', 'b', 'f', 'n', 'r', 't', 'v', '\\', '?', '"', '\'', '`', 'u', 'U', 'x', 'X'}; for (int escape_char_int = 0; escape_char_int < 255; ++escape_char_int) { char escape_char = static_cast(escape_char_int); absl::string_view escape_piece(&escape_char, 1); if (valid_escapes.find(escape_char) != valid_escapes.end()) { if (escape_char == '\'') { TestParseString(absl::StrCat("\"a\\", escape_piece, "0010ffff\"")); } TestParseString(absl::StrCat("'a\\", escape_piece, "0010ffff'")); TestParseString(absl::StrCat("'''a\\", escape_piece, "0010ffff'''")); } else if (absl::ascii_isdigit(escape_char)) { // Can also escape 0-3. 
const std::string test_string = absl::StrCat("'a\\", escape_piece, "00b'"); const std::string test_triple_quoted_string = absl::StrCat("'''a\\", escape_piece, "00b'''"); if (escape_char <= '3') { TestParseString(test_string); TestParseString(test_triple_quoted_string); } else { TestInvalidString(test_string, "Illegal escape sequence: "); TestInvalidString(test_triple_quoted_string, "Illegal escape sequence: "); } } else { if (Utf8IsValid(escape_piece)) { const std::string expected_error = ((escape_char == '\n' || escape_char == '\r') ? "Illegal escaped newline" : "Illegal escape sequence: "); TestInvalidString(absl::StrCat("'a\\", escape_piece, "b'"), expected_error); TestInvalidString(absl::StrCat("'''a\\", escape_piece, "b'''"), expected_error); } else { TestInvalidString(absl::StrCat("'a\\", escape_piece, "b'"), "Structurally invalid UTF8" // " string"); TestInvalidString(absl::StrCat("'''a\\", escape_piece, "b'''"), "Structurally invalid UTF8" // " string"); } } } } TEST(StringsTest, TestParsingOfOctalEscapes) { for (int idx = 0; idx < 256; ++idx) { const char end_char = (idx % 8) + '0'; const char mid_char = ((idx / 8) % 8) + '0'; const char lead_char = (idx / 64) + '0'; absl::string_view lead_piece(&lead_char, 1); absl::string_view mid_piece(&mid_char, 1); absl::string_view end_piece(&end_char, 1); const std::string test_string = absl::StrCat(lead_piece, mid_piece, end_piece); TestParseString(absl::StrCat("'\\", test_string, "'")); TestParseString(absl::StrCat("'''\\", test_string, "'''")); TestParseBytes(absl::StrCat("b'\\", test_string, "'")); } TestInvalidString("'\\'", "String must end with '"); TestInvalidString("'abc\\'", "String must end with '"); TestInvalidString("'''\\'''", "String must end with '''"); TestInvalidString("'''abc\\'''", "String must end with '''"); TestInvalidString( "'\\0'", "Octal escape must be followed by 3 octal digits but saw: \\0"); TestInvalidString( "'''abc\\0'''", "Octal escape must be followed by 3 octal digits but saw: 
\\0"); TestInvalidString( "'\\00'", "Octal escape must be followed by 3 octal digits but saw: \\00"); TestInvalidString( "'''ab\\00'''", "Octal escape must be followed by 3 octal digits but saw: \\00"); TestInvalidString( "'a\\008'", "Octal escape must be followed by 3 octal digits but saw: \\008"); TestInvalidString( "'''\\008'''", "Octal escape must be followed by 3 octal digits but saw: \\008"); TestInvalidString("'\\400'", "Illegal escape sequence: \\4"); TestInvalidString("'''\\400'''", "Illegal escape sequence: \\4"); TestInvalidString("'\\777'", "Illegal escape sequence: \\7"); TestInvalidString("'''\\777'''", "Illegal escape sequence: \\7"); } TEST(StringsTest, TestParsingOfHexEscapes) { for (int idx = 0; idx < 256; ++idx) { char lead_char = absl::StrFormat("%X", idx / 16)[0]; char end_char = absl::StrFormat("%x", idx % 16)[0]; absl::string_view lead_piece(&lead_char, 1); absl::string_view end_piece(&end_char, 1); TestParseString(absl::StrCat("'\\x", lead_piece, end_piece, "'")); TestParseString(absl::StrCat("'''\\x", lead_piece, end_piece, "'''")); TestParseString(absl::StrCat("'\\X", lead_piece, end_piece, "'")); TestParseString(absl::StrCat("'''\\X", lead_piece, end_piece, "'''")); TestParseBytes(absl::StrCat("b'\\X", lead_piece, end_piece, "'")); } TestInvalidString("'\\x'"); TestInvalidString("'''\\x'''"); TestInvalidString("'\\x0'"); TestInvalidString("'''\\x0'''"); TestInvalidString("'\\x0G'"); TestInvalidString("'''\\x0G'''"); } TEST(StringsTest, RoundTrip) { // Empty string is valid as a string but not an identifier. 
TestStringEscaping(""); TestString(""); TestValue("abc"); TestValue("abc123"); TestValue("123abc"); TestValue("_abc123"); TestValue("_123"); TestValue("abc def"); TestValue("a`b"); TestValue("a77b"); TestValue("\"abc\""); TestValue("'abc'"); TestValue("`abc`"); TestValue("aaa'bbb\"ccc`ddd"); TestValue("\n"); TestValue("\\"); TestValue("\\n"); TestValue("\x12"); TestValue("a,g 8q483 *(YG(*$(&*98fg\\r\\n\\t\x12gb"); // Value with an embedded zero char. std::string s = "abc"; s[1] = 0; TestValue(s); // Reserved SQL keyword, which must be quoted as an identifier. TestValue("select"); TestValue("SELECT"); TestValue("SElecT"); // Non-reserved SQL keyword, which shouldn't be quoted. TestValue("options"); // Note that control characters and other odd byte values such as \0 are // allowed in string literals as long as they are utf8 structurally valid. TestValue("\x01\x31"); TestValue("abc\xb\x42\141bc"); TestValue("123\1\x31\x32\x33"); TestValue("\\\"\xe8\xb0\xb7\xe6\xad\x8c\\\" is Google\\\'s Chinese name"); } TEST(StringsTest, InvalidString) { const std::string kInvalidStringLiteral = "Invalid string literal"; TestInvalidString("A", kInvalidStringLiteral); // No quote at all TestInvalidString("'", kInvalidStringLiteral); // No closing quote TestInvalidString("\"", kInvalidStringLiteral); // No closing quote TestInvalidString("a'", kInvalidStringLiteral); // No opening quote TestInvalidString("a\"", kInvalidStringLiteral); // No opening quote TestInvalidString("'''", "String cannot contain unescaped '"); TestInvalidString("\"\"\"", "String cannot contain unescaped \""); TestInvalidString("''''", "String cannot contain unescaped '"); TestInvalidString("\"\"\"\"", "String cannot contain unescaped \""); TestInvalidString("'''''", "String cannot contain unescaped '"); TestInvalidString("\"\"\"\"\"", "String cannot contain unescaped \""); TestInvalidString("'''''''", "String cannot contain unescaped '''"); TestInvalidString("\"\"\"\"\"\"\"", "String cannot contain unescaped 
\"\"\""); TestInvalidString("'''''''''", "String cannot contain unescaped '''"); TestInvalidString("\"\"\"\"\"\"\"\"\"", "String cannot contain unescaped \"\"\""); TestInvalidString("abc"); TestInvalidString("'abc'def'", "String cannot contain unescaped '"); TestInvalidString("'abc''def'", "String cannot contain unescaped '"); TestInvalidString("\"abc\"\"def\"", "String cannot contain unescaped \""); TestInvalidString("'''abc'''def'''", "String cannot contain unescaped '''"); TestInvalidString("\"\"\"abc\"\"\"def\"\"\"", "String cannot contain unescaped \"\"\""); TestInvalidString("'abc"); TestInvalidString("\"abc"); TestInvalidString("'''abc"); TestInvalidString("\"\"\"abc"); TestInvalidString("abc'"); TestInvalidString("abc\""); TestInvalidString("abc'''"); TestInvalidString("abc\"\"\""); TestInvalidString("\"abc'"); TestInvalidString("'abc\""); TestInvalidString("'''abc'", "String cannot contain unescaped '"); TestInvalidString("'''abc\""); TestInvalidString("'''a'", "String cannot contain unescaped '"); TestInvalidString("\"\"\"a\"", "String cannot contain unescaped \""); TestInvalidString("'''a''", "String cannot contain unescaped '"); TestInvalidString("\"\"\"a\"\"", "String cannot contain unescaped \""); TestInvalidString("'''a''''", "String cannot contain unescaped '''"); TestInvalidString("\"\"\"a\"\"\"\"", "String cannot contain unescaped \"\"\""); TestInvalidString("'''abc\"\"\""); TestInvalidString("\"\"\"abc'"); TestInvalidString("\"\"\"abc\"", "String cannot contain unescaped \""); TestInvalidString("\"\"\"abc'''"); TestInvalidString("'''\\\''''''", "String cannot contain unescaped '''"); TestInvalidString("\"\"\"\\\"\"\"\"\"\"", "String cannot contain unescaped \"\"\""); TestInvalidString("''''\\\'''''", "String cannot contain unescaped '''"); TestInvalidString("\"\"\"\"\\\"\"\"\"\"", "String cannot contain unescaped \"\"\""); TestInvalidString("\"\"\"'a' \"b\"\"\"\"", "String cannot contain unescaped \"\"\""); TestInvalidString("`abc`"); 
TestInvalidString("'abc\\'", "String must end with '"); TestInvalidString("\"abc\\\"", "String must end with \""); TestInvalidString("'''abc\\'''", "String must end with '''"); TestInvalidString("\"\"\"abc\\\"\"\"", "String must end with \"\"\""); TestInvalidString("'\\U12345678'", "Value of \\U12345678 exceeds Unicode limit (0x0010FFFF)"); // All trailing escapes. TestInvalidString("'\\"); TestInvalidString("\"\\"); TestInvalidString("''''''\\"); TestInvalidString("\"\"\"\"\"\"\\"); TestInvalidString("''\\\\"); TestInvalidString("\"\"\\\\"); TestInvalidString("''''''\\\\"); TestInvalidString("\"\"\"\"\"\"\\\\"); // String with an unescaped 0 byte. std::string s = "abc"; s[1] = 0; TestInvalidString(s); // Note: These are C-escapes to define the invalid strings. TestInvalidString("'\xc1'", "Structurally invalid UTF8 string"); TestInvalidString("'\xca'", "Structurally invalid UTF8 string"); TestInvalidString("'\xcc'", "Structurally invalid UTF8 string"); TestInvalidString("'\xFA'", "Structurally invalid UTF8 string"); TestInvalidString("'\xc1\xca\x1b\x62\x19o\xcc\x04'", "Structurally invalid UTF8 string"); TestInvalidString("'\xc2\xc0'", "Structurally invalid UTF8 string"); // First byte ok utf8, // invalid together. TestValue("\xc2\xbf"); // Same first byte, good sequence. // These are all valid prefixes for utf8 characters, but the characters // are not complete. TestInvalidString( "'\xc2'", "Structurally invalid UTF8 string"); // Should be 2 byte utf8 character. TestInvalidString( "'\xc3'", "Structurally invalid UTF8 string"); // Should be 2 byte utf8 character. TestInvalidString( "'\xe0'", "Structurally invalid UTF8 string"); // Should be 3 byte utf8 character. TestInvalidString( "'\xe0\xac'", "Structurally invalid UTF8 string"); // Should be 3 byte utf8 character. TestInvalidString( "'\xf0'", "Structurally invalid UTF8 string"); // Should be 4 byte utf8 character. 
TestInvalidString( "'\xf0\x90'", "Structurally invalid UTF8 string"); // Should be 4 byte utf8 character. TestInvalidString( "'\xf0\x90\x80'", "Structurally invalid UTF8 string"); // Should be 4 byte utf8 character. } TEST(BytesTest, RoundTrip) { TestBytesLiteral("b\"\""); TestBytesLiteral("b\"\"\"\"\"\""); TestUnescapedBytes(""); TestBytesLiteral("b'\\000\\x00AAA\\xfF\\377'"); TestBytesLiteral("b'''\\000\\x00AAA\\xfF\\377'''"); TestBytesLiteral("b'\\a\\b\\f\\n\\r\\t\\v\\\\\\?\\\"\\'\\`\\x00\\Xff'"); TestBytesLiteral("b'''\\a\\b\\f\\n\\r\\t\\v\\\\\\?\\\"\\'\\`\\x00\\Xff'''"); TestBytesLiteral("b'\\n\\012\\x0A'"); // Different newline representations. TestBytesLiteral("b'''\\n\\012\\x0A'''"); // Note the C-escaping to define the bytes. These are invalid strings for // various reasons, but are valid as bytes. TestUnescapedBytes("\xc1"); TestUnescapedBytes("\xca"); TestUnescapedBytes("\xcc"); TestUnescapedBytes("\xFA"); TestUnescapedBytes("\xc1\xca\x1b\x62\x19o\xcc\x04"); } TEST(BytesTest, ToBytesLiteralTests) { // ToBytesLiteral will choose to quote with ' if it will avoid escaping. // Non-printable bytes are escaped as hex. For printable bytes, only the // escape character and quote character are escaped. EXPECT_EQ("b\"\"", FormatBytesLiteral("")); EXPECT_EQ("b\"abc\"", FormatBytesLiteral("abc")); EXPECT_EQ("b\"abc'def\"", FormatBytesLiteral("abc'def")); EXPECT_EQ("b'abc\"def'", FormatBytesLiteral("abc\"def")); EXPECT_EQ("b\"abc`def\"", FormatBytesLiteral("abc`def")); EXPECT_EQ("b\"abc'\\\"`def\"", FormatBytesLiteral("abc'\"`def")); // Override the quoting style to use single quotes. 
EXPECT_EQ("b''", FormatSingleQuotedBytesLiteral("")); EXPECT_EQ("b'abc'", FormatSingleQuotedBytesLiteral("abc")); EXPECT_EQ("b'abc\\'def'", FormatSingleQuotedBytesLiteral("abc'def")); EXPECT_EQ("b'abc\"def'", FormatSingleQuotedBytesLiteral("abc\"def")); EXPECT_EQ("b'abc`def'", FormatSingleQuotedBytesLiteral("abc`def")); EXPECT_EQ("b'abc\\'\"`def'", FormatSingleQuotedBytesLiteral("abc'\"`def")); // Override the quoting style to use double quotes. EXPECT_EQ("b\"\"", FormatDoubleQuotedBytesLiteral("")); EXPECT_EQ("b\"abc\"", FormatDoubleQuotedBytesLiteral("abc")); EXPECT_EQ("b\"abc'def\"", FormatDoubleQuotedBytesLiteral("abc'def")); EXPECT_EQ("b\"abc\\\"def\"", FormatDoubleQuotedBytesLiteral("abc\"def")); EXPECT_EQ("b\"abc`def\"", FormatDoubleQuotedBytesLiteral("abc`def")); EXPECT_EQ("b\"abc'\\\"`def\"", FormatDoubleQuotedBytesLiteral("abc'\"`def")); EXPECT_EQ("b\"\\x07-\\x08-\\x0c-\\x0a-\\x0d-\\x09-\\x0b-\\\\-?-\\\"-'-`\"", FormatBytesLiteral("\a-\b-\f-\n-\r-\t-\v-\\-?-\"-'-`")); EXPECT_EQ("b\"\\x0a\"", FormatBytesLiteral("\n")); ASSERT_OK_AND_ASSIGN(auto unquoted, ParseBytesLiteral("b'\\n\\012\\x0a\\x0A'")); EXPECT_EQ("b\"\\x0a\\x0a\\x0a\\x0a\"", FormatBytesLiteral(unquoted)); } TEST(ByesTest, InvalidBytes) { TestInvalidBytes("A", "Invalid bytes literal"); // No quotes TestInvalidBytes("b'A", "Invalid bytes literal"); // No ending quote TestInvalidBytes("'A'", "Invalid bytes literal"); // No ending quote TestInvalidBytes("'A'", "Invalid bytes literal"); // No 'b' prefix. 
TestInvalidBytes("'''A'''"); TestInvalidBytes("b'k\\u0030'", kUnicodeNotAllowedInBytes1); TestInvalidBytes("b'''\\u0030'''", kUnicodeNotAllowedInBytes1); TestInvalidBytes("b'\\U00000030'", kUnicodeNotAllowedInBytes2); TestInvalidBytes("b'''qwerty\\U00000030'''", kUnicodeNotAllowedInBytes2); EXPECT_FALSE(UnescapeBytes("abc\\u0030").ok()); EXPECT_FALSE(UnescapeBytes("abc\\U00000030").ok()); EXPECT_FALSE(UnescapeBytes("abc\\U00000030").ok()); } TEST(RawStringsTest, ValidCases) { TestRawString(""); TestRawString("1"); TestRawString("\\x53"); TestRawString("\\x123"); TestRawString("\\001"); TestRawString("a\\44'A"); TestRawString("a\\e"); TestRawString("\\ea"); TestRawString("\\U1234"); TestRawString("\\u"); TestRawString("\\xc2\\\\"); TestRawString("f\\(abc',(.*),def\\?"); TestRawString("a\\\"b"); } TEST(RawStringsTest, InvalidRawStrings) { TestInvalidString("r\"\\\"", "String must end with \""); TestInvalidString("r\"\\\\\\\"", "String must end with \""); TestInvalidString("r\""); TestInvalidString("r"); TestInvalidString("rb\"\""); TestInvalidString("b\"\""); TestInvalidString("r'''", "String cannot contain unescaped '"); } TEST(RawBytesTest, ValidCases) { TestRawBytes(""); TestRawBytes("1"); TestRawBytes("\\x53"); TestRawBytes("\\x123"); TestRawBytes("\\001"); TestRawBytes("a\\44'A"); TestRawBytes("a\\e"); TestRawBytes("\\ea"); TestRawBytes("\\U1234"); TestRawBytes("\\u"); TestRawBytes("\\xc2\\\\"); TestRawBytes("f\\(abc',(.*),def\\?"); } TEST(RawBytesTest, InvalidRawBytes) { TestInvalidBytes("r''"); TestInvalidBytes("r''''''"); TestInvalidBytes("rrb''"); TestInvalidBytes("brb''"); TestInvalidBytes("rb'a\\e"); TestInvalidBytes("rb\"\\\"", "String must end with \""); TestInvalidBytes("br\"\\\\\\\"", "String must end with \""); TestInvalidBytes("rb"); TestInvalidBytes("br"); TestInvalidBytes("rb\""); TestInvalidBytes("rb\"\"\"", "String cannot contain unescaped \""); TestInvalidBytes("rb\"xyz\"\"", "String cannot contain unescaped \""); } TEST(StringsTest, 
QuotedForms) { // EscapeString escapes all quote characters. EXPECT_EQ("", EscapeString("")); EXPECT_EQ("abc", EscapeString("abc")); EXPECT_EQ("abc\\'def", EscapeString("abc'def")); EXPECT_EQ("abc\\\"def", EscapeString("abc\"def")); EXPECT_EQ("abc\\`def", EscapeString("abc`def")); // ToStringLiteral will choose to quote with ' if it will avoid escaping. // Other quoted characters will not be escaped. EXPECT_EQ("\"\"", FormatStringLiteral("")); EXPECT_EQ("\"abc\"", FormatStringLiteral("abc")); EXPECT_EQ("\"abc'def\"", FormatStringLiteral("abc'def")); EXPECT_EQ("'abc\"def'", FormatStringLiteral("abc\"def")); EXPECT_EQ("\"abc`def\"", FormatStringLiteral("abc`def")); EXPECT_EQ("\"abc'\\\"`def\"", FormatStringLiteral("abc'\"`def")); // Override the quoting style to use single quotes. EXPECT_EQ("''", FormatSingleQuotedStringLiteral("")); EXPECT_EQ("'abc'", FormatSingleQuotedStringLiteral("abc")); EXPECT_EQ("'abc\\'def'", FormatSingleQuotedStringLiteral("abc'def")); EXPECT_EQ("'abc\"def'", FormatSingleQuotedStringLiteral("abc\"def")); EXPECT_EQ("'abc`def'", FormatSingleQuotedStringLiteral("abc`def")); EXPECT_EQ("'abc\\'\"`def'", FormatSingleQuotedStringLiteral("abc'\"`def")); // Override the quoting style to use double quotes. 
EXPECT_EQ("\"\"", FormatDoubleQuotedStringLiteral("")); EXPECT_EQ("\"abc\"", FormatDoubleQuotedStringLiteral("abc")); EXPECT_EQ("\"abc'def\"", FormatDoubleQuotedStringLiteral("abc'def")); EXPECT_EQ("\"abc\\\"def\"", FormatDoubleQuotedStringLiteral("abc\"def")); EXPECT_EQ("\"abc`def\"", FormatDoubleQuotedStringLiteral("abc`def")); EXPECT_EQ("\"abc'\\\"`def\"", FormatDoubleQuotedStringLiteral("abc'\"`def")); } void ExpectParsedString(const std::string& expected, const std::vector& quoted_strings) { for (const std::string& quoted : quoted_strings) { ASSERT_OK_AND_ASSIGN(auto parsed, ParseStringLiteral(quoted)); EXPECT_EQ(expected, parsed); } } void ExpectParsedBytes(const std::string& expected, const std::vector& quoted_strings) { for (const std::string& quoted : quoted_strings) { ASSERT_OK_AND_ASSIGN(auto parsed, ParseBytesLiteral(quoted)); EXPECT_EQ(expected, parsed); } } TEST(StringsTest, Parse) { ExpectParsedString("abc", {"'abc'", "\"abc\"", "'''abc'''", "\"\"\"abc\"\"\""}); ExpectParsedString( "abc\ndef\x12ghi", {"'abc\\ndef\\x12ghi'", "\"abc\\ndef\\x12ghi\"", "'''abc\\ndef\\x12ghi'''", "\"\"\"abc\\ndef\\x12ghi\"\"\""}); ExpectParsedString("\xF4\x8F\xBF\xBD", {"'\\U0010FFFD'", "\"\\U0010FFFD\"", "'''\\U0010FFFD'''", "\"\"\"\\U0010FFFD\"\"\""}); // Some more test cases for triple quoted content. ExpectParsedString("", {"''''''", "\"\"\"\"\"\""}); ExpectParsedString("'\"", {"''''\"'''"}); ExpectParsedString("''''''", {"'''''\\'''\\''''"}); ExpectParsedString("'", {"'''\\''''"}); ExpectParsedString("''", {"'''\\'\\''''"}); ExpectParsedString("'\"", {"''''\"'''"}); ExpectParsedString("'a", {"''''a'''"}); ExpectParsedString("\"a", {"\"\"\"\"a\"\"\""}); ExpectParsedString("''a", {"'''''a'''"}); ExpectParsedString("\"\"a", {"\"\"\"\"\"a\"\"\""}); } TEST(StringsTest, TestNewlines) { ExpectParsedString("a\nb", {"'''a\rb'''", "'''a\nb'''", "'''a\r\nb'''"}); ExpectParsedString("a\n\nb", {"'''a\n\rb'''", "'''a\r\n\r\nb'''"}); // Escaped newlines. 
ExpectParsedString("a\nb", {"'''a\\nb'''"}); ExpectParsedString("a\rb", {"'''a\\rb'''"}); ExpectParsedString("a\r\nb", {"'''a\\r\\nb'''"}); } TEST(RawStringsTest, CompareRawAndRegularStringParsing) { ExpectParsedString("\\n", {"r'\\n'", "r\"\\n\"", "r'''\\n'''", "r\"\"\"\\n\"\"\""}); ExpectParsedString("\n", {"'\\n'", "\"\\n\"", "'''\\n'''", "\"\"\"\\n\"\"\""}); ExpectParsedString("\\e", {"r'\\e'", "r\"\\e\"", "r'''\\e'''", "r\"\"\"\\e\"\"\""}); TestInvalidString("'\\e'", "Illegal escape sequence: \\e"); TestInvalidString("\"\\e\"", "Illegal escape sequence: \\e"); TestInvalidString("'''\\e'''", "Illegal escape sequence: \\e"); TestInvalidString("\"\"\"\\e\"\"\"", "Illegal escape sequence: \\e"); ExpectParsedString( "\\x0", {"r'\\x0'", "r\"\\x0\"", "r'''\\x0'''", "r\"\"\"\\x0\"\"\""}); constexpr char kHexError[] = "Hex escape must be followed by 2 hex digits but saw: \\x0"; TestInvalidString("'\\x0'", kHexError); TestInvalidString("\"\\x0\"", kHexError); TestInvalidString("'''\\x0'''", kHexError); TestInvalidString("\"\"\"\\x0\"\"\"", kHexError); ExpectParsedString("\\'", {"r'\\\''"}); ExpectParsedString("'", {"'\\\''"}); ExpectParsedString("\\\"", {"r\"\\\"\""}); ExpectParsedString("\"", {"\"\\\"\""}); ExpectParsedString("''\\'", {"r'''\'\'\\\''''"}); ExpectParsedString("'''", {"'''\'\'\\\''''"}); ExpectParsedString("\"\"\\\"", {"r\"\"\"\"\"\\\"\"\"\""}); ExpectParsedString("\"\"\"", {"\"\"\"\"\"\\\"\"\"\""}); } TEST(RawBytesTest, CompareRawAndRegularBytesParsing) { ExpectParsedBytes("\\n", {"rb'\\n'", "br'\\n'", "rb\"\\n\"", "br\"\\n\""}); ExpectParsedBytes("\n", {"b'\\n'", "b\"\\n\""}); ExpectParsedBytes("\\u0030", {"rb'\\u0030'", "br'\\u0030'", "rb\"\\u0030\"", "br\"\\u0030\""}); TestInvalidBytes("b'\\u0030'", kUnicodeNotAllowedInBytes1); TestInvalidBytes("b\"\\u0030\"", kUnicodeNotAllowedInBytes1); TestInvalidBytes("b\"abc\\u0030\"", kUnicodeNotAllowedInBytes1); ExpectParsedBytes("\\U00000030", {"rb'\\U00000030'", "br'\\U00000030'", "rb\"\\U00000030\"", 
"br\"\\U00000030\""}); TestInvalidBytes("b'\\U00000030'", kUnicodeNotAllowedInBytes2); TestInvalidBytes("b\"\\U00000030\"", kUnicodeNotAllowedInBytes2); TestInvalidBytes("b\"abc\\U00000030\"", kUnicodeNotAllowedInBytes2); ExpectParsedBytes("\\e", {"rb'\\e'", "br'\\e'", "rb\"\\e\"", "br\"\\e\""}); TestInvalidBytes("b'\\e'", "Illegal escape sequence: \\e"); TestInvalidBytes("b\"\\e\"", "Illegal escape sequence: \\e"); TestInvalidBytes("b\"abcd\\e\"", "Illegal escape sequence: \\e"); ExpectParsedBytes("\\'", {"rb'\\\''", "br'\\\''"}); ExpectParsedBytes("'", {"b'\\\''"}); ExpectParsedBytes("\\\"", {"rb\"\\\"\"", "br\"\\\"\""}); ExpectParsedBytes("\"", {"b\"\\\"\""}); ExpectParsedBytes("''\\'", {"rb'''\'\'\\\''''", "br'''\'\'\\\''''"}); ExpectParsedBytes("'''", {"b'''\'\'\\\''''"}); ExpectParsedBytes("\"\"\\\"", {"rb\"\"\"\"\"\\\"\"\"\"", "br\"\"\"\"\"\\\"\"\"\""}); ExpectParsedBytes("\"\"\"", {"b\"\"\"\"\"\\\"\"\"\""}); } struct epair { std::string escaped; std::string unescaped; }; // Copied from strings/escaping_test.cc, CEscape::BasicEscaping. TEST(StringsTest, UTF8Escape) { epair utf8_hex_values[] = { {"\x20\xe4\xbd\xa0\\t\xe5\xa5\xbd,\\r!\\n", "\x20\xe4\xbd\xa0\t\xe5\xa5\xbd,\r!\n"}, {"\xe8\xa9\xa6\xe9\xa8\x93\\\' means \\\"test\\\"", "\xe8\xa9\xa6\xe9\xa8\x93\' means \"test\""}, {"\\\\\xe6\x88\x91\\\\:\\\\\xe6\x9d\xa8\xe6\xac\xa2\\\\", "\\\xe6\x88\x91\\:\\\xe6\x9d\xa8\xe6\xac\xa2\\"}, {"\xed\x81\xac\xeb\xa1\xac\\x08\\t\\n\\x0b\\x0c\\r", "\xed\x81\xac\xeb\xa1\xac\010\011\012\013\014\015"}}; for (int i = 0; i < ABSL_ARRAYSIZE(utf8_hex_values); ++i) { std::string escaped = EscapeString(utf8_hex_values[i].unescaped); EXPECT_EQ(escaped, utf8_hex_values[i].escaped); } } // Originally copied from strings/escaping_test.cc, Unescape::BasicFunction, // but changes for '\\xABCD' which only parses 2 hex digits after the escape. 
TEST(StringsTest, UTF8Unescape) { epair tests[] = {{"\\u0030", "0"}, {"\\u00A3", "\xC2\xA3"}, {"\\u22FD", "\xE2\x8B\xBD"}, {"\\ud7FF", "\xED\x9F\xBF"}, {"\\u22FD", "\xE2\x8B\xBD"}, {"\\U00010000", "\xF0\x90\x80\x80"}, {"\\U0000E000", "\xEE\x80\x80"}, {"\\U0001DFFF", "\xF0\x9D\xBF\xBF"}, {"\\U0010FFFD", "\xF4\x8F\xBF\xBD"}, {"\\xAbCD", "\xc2\xab" "CD"}, {"\\253CD", "\xc2\xab" "CD"}, {"\\x4141", "A41"}}; for (int i = 0; i < ABSL_ARRAYSIZE(tests); ++i) { const std::string& e = tests[i].escaped; const std::string& u = tests[i].unescaped; ASSERT_OK_AND_ASSIGN(auto out, UnescapeString(e)); EXPECT_EQ(u, out) << "original escaped: '" << e << "'\nunescaped: '" << out << "'\nexpected unescaped: '" << u << "'"; } std::string bad[] = {"\\u1", // too short "\\U1", // too short "\\Uffffff", "\\777"}; // exceeds 0xff for (int i = 0; i < ABSL_ARRAYSIZE(bad); ++i) { const std::string& e = bad[i]; auto status_or_string = UnescapeString(e); EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_TRUE(absl::StrContains(status_or_string.status().message(), "Invalid escaped string")); } } TEST(StringsTest, TestUnescapeErrorMessages) { std::string error_string; std::string out; auto status_or_string = UnescapeString("\\2"); EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_TRUE(absl::StrContains( status_or_string.status().message(), "Illegal escape sequence: Octal escape must be followed by 3 octal " "digits but saw: \\2")); status_or_string = UnescapeString("\\22X0"); EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_TRUE(absl::StrContains( status_or_string.status().message(), "Illegal escape sequence: Octal escape must be followed by 3 octal " "digits but saw: \\22X")); status_or_string = UnescapeString("\\X0"); EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_TRUE(absl::StrContains( status_or_string.status().message(), "Illegal escape sequence: Hex escape 
must be followed by 2 hex digits " "but saw: \\X0")); status_or_string = UnescapeString("\\x0G0"); EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_TRUE(absl::StrContains( status_or_string.status().message(), "Illegal escape sequence: Hex escape must be followed by 2 hex digits " "but saw: \\x0G")); status_or_string = UnescapeString("\\u00"); EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_TRUE(absl::StrContains( status_or_string.status().message(), "Illegal escape sequence: \\u must be followed by 4 hex digits but saw: " "\\u00")); status_or_string = UnescapeString("\\ude8c"); EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_TRUE(absl::StrContains( status_or_string.status().message(), "Illegal escape sequence: Unicode value \\ude8c is invalid")); status_or_string = UnescapeString("\\u000G0"); EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_TRUE(absl::StrContains( status_or_string.status().message(), "Illegal escape sequence: \\u must be followed by 4 hex digits but saw: " "\\u000G")); status_or_string = UnescapeString("\\U00"); EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_TRUE(absl::StrContains( status_or_string.status().message(), "Illegal escape sequence: \\U must be followed by 8 hex digits but saw: " "\\U00")); status_or_string = UnescapeString("\\U000000G00"); EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_TRUE(absl::StrContains( status_or_string.status().message(), "Illegal escape sequence: \\U must be followed by 8 hex digits but saw: " "\\U000000G0")); status_or_string = UnescapeString("\\U0000D83D"); EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_TRUE(absl::StrContains( status_or_string.status().message(), "Illegal escape sequence: Unicode value \\U0000D83D is invalid")); status_or_string = 
UnescapeString("\\UFFFFFFFF0"); EXPECT_THAT(status_or_string, StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_TRUE(absl::StrContains( status_or_string.status().message(), "Illegal escape sequence: Value of \\UFFFFFFFF exceeds Unicode limit " "(0x0010FFFF)")); } } // namespace } // namespace cel::internal ================================================ FILE: internal/testing.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "internal/testing.h" #include "absl/strings/str_cat.h" // IWYU pragma: keep namespace cel::internal { void AddFatalFailure(const char* file, int line, absl::string_view expression, const StatusBuilder& builder) { GTEST_MESSAGE_AT_(file, line, absl::StrCat(expression, " returned error: ", absl::Status(builder).ToString( absl::StatusToStringMode::kWithEverything)) .c_str(), ::testing::TestPartResult::kFatalFailure); } } // namespace cel::internal ================================================ FILE: internal/testing.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_H_ #include "gmock/gmock.h" // IWYU pragma: export #include "gtest/gtest.h" // IWYU pragma: export #include "absl/status/status_matchers.h" #include "internal/status_macros.h" // IWYU pragma: keep #ifndef ASSERT_OK #define ASSERT_OK(expr) ASSERT_THAT(expr, ::absl_testing::IsOk()) #endif #ifndef EXPECT_OK #define EXPECT_OK(expr) EXPECT_THAT(expr, ::absl_testing::IsOk()) #endif #ifndef ASSERT_OK_AND_ASSIGN #define ASSERT_OK_AND_ASSIGN(lhs, rhs) \ CEL_ASSIGN_OR_RETURN( \ lhs, rhs, ::cel::internal::AddFatalFailure(__FILE__, __LINE__, #rhs, _)) #endif namespace cel::internal { void AddFatalFailure(const char* file, int line, absl::string_view expression, const StatusBuilder& builder); } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_H_ ================================================ FILE: internal/testing_descriptor_pool.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "internal/testing_descriptor_pool.h" #include #include #include "google/protobuf/descriptor.pb.h" #include "absl/base/attributes.h" #include "absl/base/macros.h" #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "internal/noop_delete.h" #include "google/protobuf/descriptor.h" namespace cel::internal { namespace { ABSL_CONST_INIT const uint8_t kTestingDescriptorSet[] = { #include "internal/testing_descriptor_set_embed.inc" }; } // namespace const google::protobuf::DescriptorPool* absl_nonnull GetTestingDescriptorPool() { static const google::protobuf::DescriptorPool* absl_nonnull const pool = []() { google::protobuf::FileDescriptorSet file_desc_set; ABSL_CHECK(file_desc_set.ParseFromArray( // Crash OK kTestingDescriptorSet, ABSL_ARRAYSIZE(kTestingDescriptorSet))); auto* pool = new google::protobuf::DescriptorPool(); for (const auto& file_desc : file_desc_set.file()) { ABSL_CHECK(pool->BuildFile(file_desc) != nullptr); // Crash OK } return pool; }(); return pool; } absl_nonnull std::shared_ptr GetSharedTestingDescriptorPool() { static const absl::NoDestructor< absl_nonnull std::shared_ptr> instance(GetTestingDescriptorPool(), internal::NoopDeleteFor()); return *instance; } } // namespace cel::internal ================================================ FILE: internal/testing_descriptor_pool.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_DESCRIPTOR_POOL_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_DESCRIPTOR_POOL_H_ #include #include "absl/base/nullability.h" #include "google/protobuf/descriptor.h" namespace cel::internal { // GetTestingDescriptorPool returns a pointer to a `google::protobuf::DescriptorPool` // which includes has the necessary descriptors required for the purposes of // testing. The returning `google::protobuf::DescriptorPool` is valid for the lifetime of // the process. const google::protobuf::DescriptorPool* absl_nonnull GetTestingDescriptorPool(); absl_nonnull std::shared_ptr GetSharedTestingDescriptorPool(); } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_DESCRIPTOR_POOL_H_ ================================================ FILE: internal/testing_descriptor_pool_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "internal/testing_descriptor_pool.h"

#include "internal/testing.h"
#include "google/protobuf/descriptor.h"

// Tests for GetTestingDescriptorPool(). The pool must contain the protobuf
// well-known types, each reporting the matching
// google::protobuf::Descriptor::well_known_type(), plus the CEL conformance
// TestAllTypes messages for proto2 and proto3.

namespace cel::internal {
namespace {

using ::testing::NotNull;

// google.protobuf.NullValue is looked up as an enum (FindEnumTypeByName), not
// as a message.
TEST(TestingDescriptorPool, NullValue) {
  ASSERT_THAT(GetTestingDescriptorPool()->FindEnumTypeByName(
                  "google.protobuf.NullValue"),
              NotNull());
}

// Wrapper types: each must be present and classified as its well-known type.
TEST(TestingDescriptorPool, BoolValue) {
  const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName(
      "google.protobuf.BoolValue");
  ASSERT_THAT(desc, NotNull());
  EXPECT_EQ(desc->well_known_type(),
            google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE);
}

TEST(TestingDescriptorPool, Int32Value) {
  const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName(
      "google.protobuf.Int32Value");
  ASSERT_THAT(desc, NotNull());
  EXPECT_EQ(desc->well_known_type(),
            google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE);
}

TEST(TestingDescriptorPool, Int64Value) {
  const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName(
      "google.protobuf.Int64Value");
  ASSERT_THAT(desc, NotNull());
  EXPECT_EQ(desc->well_known_type(),
            google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE);
}

TEST(TestingDescriptorPool, UInt32Value) {
  const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName(
      "google.protobuf.UInt32Value");
  ASSERT_THAT(desc, NotNull());
  EXPECT_EQ(desc->well_known_type(),
            google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE);
}

TEST(TestingDescriptorPool, UInt64Value) {
  const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName(
      "google.protobuf.UInt64Value");
  ASSERT_THAT(desc, NotNull());
  EXPECT_EQ(desc->well_known_type(),
            google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE);
}

TEST(TestingDescriptorPool, FloatValue) {
  const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName(
      "google.protobuf.FloatValue");
  ASSERT_THAT(desc, NotNull());
  EXPECT_EQ(desc->well_known_type(),
            google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE);
}

TEST(TestingDescriptorPool, DoubleValue) {
  const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName(
      "google.protobuf.DoubleValue");
  ASSERT_THAT(desc, NotNull());
  EXPECT_EQ(desc->well_known_type(),
            google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE);
}

TEST(TestingDescriptorPool, BytesValue) {
  const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName(
      "google.protobuf.BytesValue");
  ASSERT_THAT(desc, NotNull());
  EXPECT_EQ(desc->well_known_type(),
            google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE);
}

TEST(TestingDescriptorPool, StringValue) {
  const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName(
      "google.protobuf.StringValue");
  ASSERT_THAT(desc, NotNull());
  EXPECT_EQ(desc->well_known_type(),
            google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE);
}

// Remaining well-known message types.
TEST(TestingDescriptorPool, Any) {
  const auto* desc =
      GetTestingDescriptorPool()->FindMessageTypeByName("google.protobuf.Any");
  ASSERT_THAT(desc, NotNull());
  EXPECT_EQ(desc->well_known_type(),
            google::protobuf::Descriptor::WELLKNOWNTYPE_ANY);
}

TEST(TestingDescriptorPool, Duration) {
  const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName(
      "google.protobuf.Duration");
  ASSERT_THAT(desc, NotNull());
  EXPECT_EQ(desc->well_known_type(),
            google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION);
}

TEST(TestingDescriptorPool, Timestamp) {
  const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName(
      "google.protobuf.Timestamp");
  ASSERT_THAT(desc, NotNull());
  EXPECT_EQ(desc->well_known_type(),
            google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP);
}

TEST(TestingDescriptorPool, Value) {
  const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName(
      "google.protobuf.Value");
  ASSERT_THAT(desc, NotNull());
  EXPECT_EQ(desc->well_known_type(),
            google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE);
}

TEST(TestingDescriptorPool, ListValue) {
  const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName(
      "google.protobuf.ListValue");
  ASSERT_THAT(desc, NotNull());
  EXPECT_EQ(desc->well_known_type(),
            google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE);
}

TEST(TestingDescriptorPool, Struct) {
  const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName(
      "google.protobuf.Struct");
  ASSERT_THAT(desc, NotNull());
  EXPECT_EQ(desc->well_known_type(),
            google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT);
}

TEST(TestingDescriptorPool, FieldMask) {
  const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName(
      "google.protobuf.FieldMask");
  ASSERT_THAT(desc, NotNull());
  EXPECT_EQ(desc->well_known_type(),
            google::protobuf::Descriptor::WELLKNOWNTYPE_FIELDMASK);
}

// google.protobuf.Empty is only checked for presence; this test does not
// compare its well_known_type().
TEST(TestingDescriptorPool, Empty) {
  const auto* desc = GetTestingDescriptorPool()->FindMessageTypeByName(
      "google.protobuf.Empty");
  ASSERT_THAT(desc, NotNull());
}

// CEL conformance test messages (presence only).
TEST(TestingDescriptorPool, TestAllTypesProto2) {
  EXPECT_THAT(GetTestingDescriptorPool()->FindMessageTypeByName(
                  "cel.expr.conformance.proto2.TestAllTypes"),
              NotNull());
}

TEST(TestingDescriptorPool, TestAllTypesProto3) {
  EXPECT_THAT(GetTestingDescriptorPool()->FindMessageTypeByName(
                  "cel.expr.conformance.proto3.TestAllTypes"),
              NotNull());
}

}  // namespace
}  // namespace cel::internal

================================================
FILE: internal/testing_message_factory.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "internal/testing_message_factory.h"

#include "absl/base/no_destructor.h"
#include "absl/base/nullability.h"
#include "internal/testing_descriptor_pool.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/message.h"

namespace cel::internal {

google::protobuf::MessageFactory* absl_nonnull GetTestingMessageFactory() {
  // Constructed once, on first call, over the testing descriptor pool.
  // absl::NoDestructor never runs the destructor, so the factory stays usable
  // until process exit.
  static absl::NoDestructor<google::protobuf::DynamicMessageFactory> factory(
      GetTestingDescriptorPool());
  return &*factory;
}

}  // namespace cel::internal

================================================
FILE: internal/testing_message_factory.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_MESSAGE_FACTORY_H_
#define THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_MESSAGE_FACTORY_H_

#include "absl/base/nullability.h"
#include "google/protobuf/message.h"

namespace cel::internal {

// GetTestingMessageFactory returns a pointer to a `google::protobuf::MessageFactory`
// which should be used with the descriptor pool returned by
// `GetTestingDescriptorPool`. The returned `google::protobuf::MessageFactory` is valid
// for the lifetime of the process.
google::protobuf::MessageFactory* absl_nonnull GetTestingMessageFactory(); } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_TESTING_MESSAGE_FACTORY_H_ ================================================ FILE: internal/time.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "internal/time.h" #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "internal/status_macros.h" #include "google/protobuf/util/time_util.h" namespace cel::internal { namespace { std::string RawFormatTimestamp(absl::Time timestamp) { return absl::FormatTime("%Y-%m-%d%ET%H:%M:%E*SZ", timestamp, absl::UTCTimeZone()); } } // namespace absl::Duration MaxDuration() { // This currently supports a larger range then the current CEL spec. The // intent is to widen the CEL spec to support the larger range and match // google.protobuf.Duration from protocol buffer messages, which this // implementation currently supports. // TODO(google/cel-spec/issues/214): revisit return absl::Seconds(google::protobuf::util::TimeUtil::kDurationMaxSeconds) + absl::Nanoseconds(google::protobuf::util::TimeUtil::kDurationMaxNanoseconds); } absl::Duration MinDuration() { // This currently supports a larger range then the current CEL spec. 
The // intent is to widen the CEL spec to support the larger range and match // google.protobuf.Duration from protocol buffer messages, which this // implementation currently supports. // TODO(google/cel-spec/issues/214): revisit return absl::Seconds(google::protobuf::util::TimeUtil::kDurationMinSeconds) + absl::Nanoseconds(google::protobuf::util::TimeUtil::kDurationMinNanoseconds); } absl::Time MaxTimestamp() { return absl::UnixEpoch() + absl::Seconds(google::protobuf::util::TimeUtil::kTimestampMaxSeconds) + absl::Nanoseconds(google::protobuf::util::TimeUtil::kTimestampMaxNanoseconds); } absl::Time MinTimestamp() { return absl::UnixEpoch() + absl::Seconds(google::protobuf::util::TimeUtil::kTimestampMinSeconds) + absl::Nanoseconds(google::protobuf::util::TimeUtil::kTimestampMinNanoseconds); } absl::Status ValidateDuration(absl::Duration duration) { if (duration < MinDuration()) { return absl::InvalidArgumentError( absl::StrCat("Duration \"", absl::FormatDuration(duration), "\" below minimum allowed duration \"", absl::FormatDuration(MinDuration()), "\"")); } if (duration > MaxDuration()) { return absl::InvalidArgumentError( absl::StrCat("Duration \"", absl::FormatDuration(duration), "\" above maximum allowed duration \"", absl::FormatDuration(MaxDuration()), "\"")); } return absl::OkStatus(); } absl::StatusOr ParseDuration(absl::string_view input) { absl::Duration duration; if (!absl::ParseDuration(input, &duration)) { return absl::InvalidArgumentError("Failed to parse duration from string"); } return duration; } absl::StatusOr FormatDuration(absl::Duration duration) { CEL_RETURN_IF_ERROR(ValidateDuration(duration)); return absl::FormatDuration(duration); } std::string DebugStringDuration(absl::Duration duration) { return absl::FormatDuration(duration); } absl::Status ValidateTimestamp(absl::Time timestamp) { if (timestamp < MinTimestamp()) { return absl::InvalidArgumentError( absl::StrCat("Timestamp \"", RawFormatTimestamp(timestamp), "\" below minimum allowed 
timestamp \"", RawFormatTimestamp(MinTimestamp()), "\"")); } if (timestamp > MaxTimestamp()) { return absl::InvalidArgumentError( absl::StrCat("Timestamp \"", RawFormatTimestamp(timestamp), "\" above maximum allowed timestamp \"", RawFormatTimestamp(MaxTimestamp()), "\"")); } return absl::OkStatus(); } absl::StatusOr ParseTimestamp(absl::string_view input) { absl::Time timestamp; std::string err; if (!absl::ParseTime(absl::RFC3339_full, input, absl::UTCTimeZone(), ×tamp, &err)) { return err.empty() ? absl::InvalidArgumentError( "Failed to parse timestamp from string") : absl::InvalidArgumentError(absl::StrCat( "Failed to parse timestamp from string: ", err)); } CEL_RETURN_IF_ERROR(ValidateTimestamp(timestamp)); return timestamp; } absl::StatusOr FormatTimestamp(absl::Time timestamp) { CEL_RETURN_IF_ERROR(ValidateTimestamp(timestamp)); return RawFormatTimestamp(timestamp); } std::string FormatNanos(int32_t nanos) { constexpr int32_t kNanosPerMillisecond = 1000000; constexpr int32_t kNanosPerMicrosecond = 1000; if (nanos % kNanosPerMillisecond == 0) { return absl::StrFormat("%03d", nanos / kNanosPerMillisecond); } else if (nanos % kNanosPerMicrosecond == 0) { return absl::StrFormat("%06d", nanos / kNanosPerMicrosecond); } return absl::StrFormat("%09d", nanos); } absl::StatusOr EncodeDurationToJson(absl::Duration duration) { // Adapted from protobuf time_util. CEL_RETURN_IF_ERROR(ValidateDuration(duration)); std::string result; int64_t seconds = absl::IDivDuration(duration, absl::Seconds(1), &duration); int64_t nanos = absl::IDivDuration(duration, absl::Nanoseconds(1), &duration); if (seconds < 0 || nanos < 0) { result = "-"; seconds = -seconds; nanos = -nanos; } absl::StrAppend(&result, seconds); if (nanos != 0) { absl::StrAppend(&result, ".", FormatNanos(nanos)); } absl::StrAppend(&result, "s"); return result; } absl::StatusOr EncodeTimestampToJson(absl::Time timestamp) { // Adapted from protobuf time_util. 
static constexpr absl::string_view kTimestampFormat = "%E4Y-%m-%dT%H:%M:%S"; CEL_RETURN_IF_ERROR(ValidateTimestamp(timestamp)); // Handle nanos and the seconds separately to match proto JSON format. absl::Time unix_seconds = absl::FromUnixSeconds(absl::ToUnixSeconds(timestamp)); int64_t n = (timestamp - unix_seconds) / absl::Nanoseconds(1); std::string result = absl::FormatTime(kTimestampFormat, unix_seconds, absl::UTCTimeZone()); if (n > 0) { absl::StrAppend(&result, ".", FormatNanos(n)); } absl::StrAppend(&result, "Z"); return result; } std::string DebugStringTimestamp(absl::Time timestamp) { return RawFormatTimestamp(timestamp); } } // namespace cel::internal ================================================ FILE: internal/time.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_TIME_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_TIME_H_ #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" namespace cel::internal { absl::Duration MaxDuration(); absl::Duration MinDuration(); absl::Time MaxTimestamp(); absl::Time MinTimestamp(); absl::Status ValidateDuration(absl::Duration duration); absl::StatusOr ParseDuration(absl::string_view input); // Human-friendly format for duration provided to match DebugString. // Checks that the duration is in the supported range for CEL values. 
absl::StatusOr FormatDuration(absl::Duration duration); // Encodes duration as a string for JSON. // This implementation is compatible with protobuf. absl::StatusOr EncodeDurationToJson(absl::Duration duration); std::string DebugStringDuration(absl::Duration duration); absl::Status ValidateTimestamp(absl::Time timestamp); absl::StatusOr ParseTimestamp(absl::string_view input); // Human-friendly format for timestamp provided to match DebugString. // Checks that the timestamp is in the supported range for CEL values. absl::StatusOr FormatTimestamp(absl::Time timestamp); // Encodes timestamp as a string for JSON. // This implementation is compatible with protobuf. absl::StatusOr EncodeTimestampToJson(absl::Time timestamp); std::string DebugStringTimestamp(absl::Time timestamp); } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_TIME_H_ ================================================ FILE: internal/time_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "internal/time.h"

#include <string>

#include "absl/status/status.h"
#include "absl/time/time.h"
#include "internal/testing.h"
#include "google/protobuf/util/time_util.h"

namespace cel::internal {
namespace {

using ::absl_testing::StatusIs;

TEST(MaxDuration, ProtoEquiv) {
  EXPECT_EQ(MaxDuration(),
            absl::Seconds(google::protobuf::util::TimeUtil::kDurationMaxSeconds) +
                absl::Nanoseconds(999999999));
}

TEST(MinDuration, ProtoEquiv) {
  EXPECT_EQ(MinDuration(),
            absl::Seconds(google::protobuf::util::TimeUtil::kDurationMinSeconds) +
                absl::Nanoseconds(-999999999));
}

TEST(MaxTimestamp, ProtoEquiv) {
  EXPECT_EQ(MaxTimestamp(),
            absl::UnixEpoch() +
                absl::Seconds(google::protobuf::util::TimeUtil::kTimestampMaxSeconds) +
                absl::Nanoseconds(999999999));
}

TEST(MinTimestamp, ProtoEquiv) {
  EXPECT_EQ(MinTimestamp(),
            absl::UnixEpoch() +
                absl::Seconds(google::protobuf::util::TimeUtil::kTimestampMinSeconds));
}

TEST(ParseDuration, Conformance) {
  absl::Duration parsed;
  ASSERT_OK_AND_ASSIGN(parsed, internal::ParseDuration("1s"));
  EXPECT_EQ(parsed, absl::Seconds(1));
  ASSERT_OK_AND_ASSIGN(parsed, internal::ParseDuration("0.010s"));
  EXPECT_EQ(parsed, absl::Milliseconds(10));
  ASSERT_OK_AND_ASSIGN(parsed, internal::ParseDuration("0.000010s"));
  EXPECT_EQ(parsed, absl::Microseconds(10));
  ASSERT_OK_AND_ASSIGN(parsed, internal::ParseDuration("0.000000010s"));
  EXPECT_EQ(parsed, absl::Nanoseconds(10));

  EXPECT_THAT(internal::ParseDuration("abc"),
              StatusIs(absl::StatusCode::kInvalidArgument));
  EXPECT_THAT(internal::ParseDuration("1c"),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

TEST(FormatDuration, Conformance) {
  std::string formatted;
  ASSERT_OK_AND_ASSIGN(formatted, internal::FormatDuration(absl::Seconds(1)));
  EXPECT_EQ(formatted, "1s");
  ASSERT_OK_AND_ASSIGN(formatted,
                       internal::FormatDuration(absl::Milliseconds(10)));
  EXPECT_EQ(formatted, "10ms");
  ASSERT_OK_AND_ASSIGN(formatted,
                       internal::FormatDuration(absl::Microseconds(10)));
  EXPECT_EQ(formatted, "10us");
  ASSERT_OK_AND_ASSIGN(formatted,
                       internal::FormatDuration(absl::Nanoseconds(10)));
  EXPECT_EQ(formatted, "10ns");

  EXPECT_THAT(internal::FormatDuration(absl::InfiniteDuration()),
              StatusIs(absl::StatusCode::kInvalidArgument));
  EXPECT_THAT(internal::FormatDuration(-absl::InfiniteDuration()),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

TEST(ParseTimestamp, Conformance) {
  absl::Time parsed;
  ASSERT_OK_AND_ASSIGN(parsed, internal::ParseTimestamp("1-01-01T00:00:00Z"));
  EXPECT_EQ(parsed, MinTimestamp());
  ASSERT_OK_AND_ASSIGN(
      parsed, internal::ParseTimestamp("9999-12-31T23:59:59.999999999Z"));
  EXPECT_EQ(parsed, MaxTimestamp());
  ASSERT_OK_AND_ASSIGN(parsed,
                       internal::ParseTimestamp("1970-01-01T00:00:00Z"));
  EXPECT_EQ(parsed, absl::UnixEpoch());
  ASSERT_OK_AND_ASSIGN(parsed,
                       internal::ParseTimestamp("1970-01-01T00:00:00.010Z"));
  EXPECT_EQ(parsed, absl::UnixEpoch() + absl::Milliseconds(10));
  ASSERT_OK_AND_ASSIGN(parsed,
                       internal::ParseTimestamp("1970-01-01T00:00:00.000010Z"));
  EXPECT_EQ(parsed, absl::UnixEpoch() + absl::Microseconds(10));
  ASSERT_OK_AND_ASSIGN(
      parsed, internal::ParseTimestamp("1970-01-01T00:00:00.000000010Z"));
  EXPECT_EQ(parsed, absl::UnixEpoch() + absl::Nanoseconds(10));

  EXPECT_THAT(internal::ParseTimestamp("abc"),
              StatusIs(absl::StatusCode::kInvalidArgument));
  EXPECT_THAT(internal::ParseTimestamp("10000-01-01T00:00:00Z"),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

TEST(FormatTimestamp, Conformance) {
  std::string formatted;
  ASSERT_OK_AND_ASSIGN(formatted, internal::FormatTimestamp(MinTimestamp()));
  EXPECT_EQ(formatted, "1-01-01T00:00:00Z");
  ASSERT_OK_AND_ASSIGN(formatted, internal::FormatTimestamp(MaxTimestamp()));
  EXPECT_EQ(formatted, "9999-12-31T23:59:59.999999999Z");
  ASSERT_OK_AND_ASSIGN(formatted, internal::FormatTimestamp(absl::UnixEpoch()));
  EXPECT_EQ(formatted, "1970-01-01T00:00:00Z");
  ASSERT_OK_AND_ASSIGN(
      formatted,
      internal::FormatTimestamp(absl::UnixEpoch() + absl::Milliseconds(10)));
  EXPECT_EQ(formatted, "1970-01-01T00:00:00.01Z");
  ASSERT_OK_AND_ASSIGN(
      formatted,
      internal::FormatTimestamp(absl::UnixEpoch() + absl::Microseconds(10)));
  EXPECT_EQ(formatted, "1970-01-01T00:00:00.00001Z");
  ASSERT_OK_AND_ASSIGN(
      formatted,
      internal::FormatTimestamp(absl::UnixEpoch() + absl::Nanoseconds(10)));
  EXPECT_EQ(formatted, "1970-01-01T00:00:00.00000001Z");

  EXPECT_THAT(internal::FormatTimestamp(absl::InfiniteFuture()),
              StatusIs(absl::StatusCode::kInvalidArgument));
  EXPECT_THAT(internal::FormatTimestamp(absl::InfinitePast()),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

TEST(EncodeDurationToJson, Conformance) {
  std::string formatted;
  ASSERT_OK_AND_ASSIGN(formatted, EncodeDurationToJson(absl::Seconds(1)));
  EXPECT_EQ(formatted, "1s");
  ASSERT_OK_AND_ASSIGN(formatted, EncodeDurationToJson(absl::Milliseconds(10)));
  EXPECT_EQ(formatted, "0.010s");
  ASSERT_OK_AND_ASSIGN(formatted, EncodeDurationToJson(absl::Microseconds(10)));
  EXPECT_EQ(formatted, "0.000010s");
  ASSERT_OK_AND_ASSIGN(formatted, EncodeDurationToJson(absl::Nanoseconds(10)));
  EXPECT_EQ(formatted, "0.000000010s");
  // Negative durations get a single leading sign and a positive fraction.
  ASSERT_OK_AND_ASSIGN(formatted,
                       EncodeDurationToJson(-absl::Milliseconds(1500)));
  EXPECT_EQ(formatted, "-1.500s");

  EXPECT_THAT(EncodeDurationToJson(absl::InfiniteDuration()),
              StatusIs(absl::StatusCode::kInvalidArgument));
  EXPECT_THAT(EncodeDurationToJson(-absl::InfiniteDuration()),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

TEST(EncodeTimestampToJson, Conformance) {
  std::string formatted;
  ASSERT_OK_AND_ASSIGN(formatted, EncodeTimestampToJson(MinTimestamp()));
  EXPECT_EQ(formatted, "0001-01-01T00:00:00Z");
  ASSERT_OK_AND_ASSIGN(formatted, EncodeTimestampToJson(MaxTimestamp()));
  EXPECT_EQ(formatted, "9999-12-31T23:59:59.999999999Z");
  ASSERT_OK_AND_ASSIGN(formatted, EncodeTimestampToJson(absl::UnixEpoch()));
  EXPECT_EQ(formatted, "1970-01-01T00:00:00Z");
  ASSERT_OK_AND_ASSIGN(
      formatted,
      EncodeTimestampToJson(absl::UnixEpoch() + absl::Milliseconds(10)));
  EXPECT_EQ(formatted, "1970-01-01T00:00:00.010Z");
  ASSERT_OK_AND_ASSIGN(
      formatted,
      EncodeTimestampToJson(absl::UnixEpoch() + absl::Microseconds(10)));
  EXPECT_EQ(formatted, "1970-01-01T00:00:00.000010Z");
  ASSERT_OK_AND_ASSIGN(formatted, EncodeTimestampToJson(absl::UnixEpoch() +
                                                        absl::Nanoseconds(10)));
  EXPECT_EQ(formatted, "1970-01-01T00:00:00.000000010Z");
  // Pre-epoch: whole seconds round down, so the fraction stays positive.
  ASSERT_OK_AND_ASSIGN(formatted, EncodeTimestampToJson(absl::UnixEpoch() -
                                                        absl::Nanoseconds(1)));
  EXPECT_EQ(formatted, "1969-12-31T23:59:59.999999999Z");

  EXPECT_THAT(EncodeTimestampToJson(absl::InfiniteFuture()),
              StatusIs(absl::StatusCode::kInvalidArgument));
  EXPECT_THAT(EncodeTimestampToJson(absl::InfinitePast()),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

}  // namespace
}  // namespace cel::internal

================================================
FILE: internal/to_address.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_TO_ADDRESS_H_
#define THIRD_PARTY_CEL_CPP_INTERNAL_TO_ADDRESS_H_

#include <memory>
#include <type_traits>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/meta/type_traits.h"

namespace cel::internal {

// -----------------------------------------------------------------------------
// Function Template: to_address()
// -----------------------------------------------------------------------------
//
// Backport of std::to_address introduced in C++20. Enables obtaining the
// address of an object regardless of whether the pointer is raw or fancy.
#if defined(__cpp_lib_to_address) && __cpp_lib_to_address >= 201711L using std::to_address; #else template constexpr T* to_address(T* ptr) noexcept { static_assert(!std::is_function::value, "T must not be a function"); return ptr; } template struct PointerTraitsToAddress { static constexpr auto Dispatch( const T& p ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { return internal::to_address(p.operator->()); } }; template struct PointerTraitsToAddress< T, std::void_t::to_address( std::declval()))> > { static constexpr auto Dispatch( const T& p ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { return std::pointer_traits::to_address(p); } }; template constexpr auto to_address(const T& ptr ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { return PointerTraitsToAddress::Dispatch(ptr); } #endif } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_TO_ADDRESS_H_ ================================================ FILE: internal/to_address_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "internal/to_address.h"

#include <memory>

#include "internal/testing.h"

namespace cel {
namespace {

TEST(ToAddress, RawPointer) {
  char c;
  EXPECT_EQ(internal::to_address(&c), &c);
}

// Exposes only operator->(), so to_address must use the operator-> path.
struct ImplicitFancyPointer {
  using element_type = char;

  char* operator->() const { return ptr; }

  char* ptr;
};

// Has no operator->(); its address is reachable only through the
// std::pointer_traits specialization below.
struct ExplicitFancyPointer {
  char* ptr;
};

}  // namespace
}  // namespace cel

namespace std {

template <>
struct pointer_traits<cel::ExplicitFancyPointer> : pointer_traits<char*> {
  static constexpr char* to_address(
      const cel::ExplicitFancyPointer& efp) noexcept {
    return efp.ptr;
  }
};

}  // namespace std

namespace cel {
namespace {

TEST(ToAddress, FancyPointerNoPointerTraits) {
  char c;
  ImplicitFancyPointer ip{&c};
  EXPECT_EQ(internal::to_address(ip), &c);
}

TEST(ToAddress, FancyPointerWithPointerTraits) {
  char c;
  ExplicitFancyPointer ip{&c};
  EXPECT_EQ(internal::to_address(ip), &c);
}

}  // namespace
}  // namespace cel

================================================
FILE: internal/unicode.h
================================================
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_UNICODE_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_UNICODE_H_ namespace cel::internal { inline constexpr char32_t kUnicodeReplacementCharacter = 0xfffd; constexpr bool UnicodeIsValid(char32_t code_point) { return code_point < 0xd800 || (code_point > 0xdfff && code_point <= 0x10ffff); } } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_UNICODE_H_ ================================================ FILE: internal/utf8.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "internal/utf8.h" #include #include #include #include #include #include "absl/base/macros.h" #include "absl/base/nullability.h" #include "absl/base/optimization.h" #include "absl/log/absl_check.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "internal/unicode.h" // Implementation is based on // https://go.googlesource.com/go/+/refs/heads/master/src/unicode/utf8/utf8.go // but adapted for C++. 
namespace cel::internal { namespace { constexpr uint8_t kUtf8RuneSelf = 0x80; constexpr size_t kUtf8Max = 4; constexpr uint8_t kLow = 0x80; constexpr uint8_t kHigh = 0xbf; constexpr uint8_t kMaskX = 0x3f; constexpr uint8_t kMask2 = 0x1f; constexpr uint8_t kMask3 = 0xf; constexpr uint8_t kMask4 = 0x7; constexpr uint8_t kTX = 0x80; constexpr uint8_t kT2 = 0xc0; constexpr uint8_t kT3 = 0xe0; constexpr uint8_t kT4 = 0xf0; constexpr uint8_t kXX = 0xf1; constexpr uint8_t kAS = 0xf0; constexpr uint8_t kS1 = 0x02; constexpr uint8_t kS2 = 0x13; constexpr uint8_t kS3 = 0x03; constexpr uint8_t kS4 = 0x23; constexpr uint8_t kS5 = 0x34; constexpr uint8_t kS6 = 0x04; constexpr uint8_t kS7 = 0x44; // NOLINTBEGIN // clang-format off constexpr uint8_t kLeading[256] = { // 1 2 3 4 5 6 7 8 9 A B C D E F kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, // 0x00-0x0F kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, // 0x10-0x1F kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, // 0x20-0x2F kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, // 0x30-0x3F kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, // 0x40-0x4F kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, // 0x50-0x5F kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, // 0x60-0x6F kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, kAS, // 0x70-0x7F // 1 2 3 4 5 6 7 8 9 A B C D E F kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, // 0x80-0x8F kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, // 0x90-0x9F kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, // 0xA0-0xAF kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, // 0xB0-0xBF kXX, kXX, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, 
kS1, kS1, kS1, kS1, kS1, // 0xC0-0xCF kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, kS1, // 0xD0-0xDF kS2, kS3, kS3, kS3, kS3, kS3, kS3, kS3, kS3, kS3, kS3, kS3, kS3, kS4, kS3, kS3, // 0xE0-0xEF kS5, kS6, kS6, kS6, kS7, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, kXX, // 0xF0-0xFF }; // clang-format on // NOLINTEND constexpr std::pair kAccept[16] = { {kLow, kHigh}, {0xa0, kHigh}, {kLow, 0x9f}, {0x90, kHigh}, {kLow, 0x8f}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, }; class StringReader final { public: constexpr explicit StringReader(absl::string_view input) : input_(input) {} size_t Remaining() const { return input_.size(); } bool HasRemaining() const { return !input_.empty(); } absl::string_view Peek(size_t n) { ABSL_ASSERT(n <= Remaining()); return input_.substr(0, n); } char Read() { ABSL_ASSERT(HasRemaining()); char value = input_.front(); input_.remove_prefix(1); return value; } void Advance(size_t n) { ABSL_ASSERT(n <= Remaining()); input_.remove_prefix(n); } void Reset(absl::string_view input) { input_ = input; } private: absl::string_view input_; }; class CordReader final { public: explicit CordReader(const absl::Cord& input) : input_(input), size_(input_.size()), buffer_(), index_(0) {} size_t Remaining() const { return size_; } bool HasRemaining() const { return size_ != 0; } absl::string_view Peek(size_t n) { ABSL_ASSERT(n <= Remaining()); if (n == 0) { return absl::string_view(); } if (n <= buffer_.size() - index_) { // Enough data remaining in temporary buffer. return absl::string_view(buffer_.data() + index_, n); } // We do not have enough data. See if we can fit it without allocating by // shifting data back to the beginning of the buffer. if (buffer_.capacity() >= n) { // It will fit in the current capacity, see if we need to shift the // existing data to make it fit. 
if (buffer_.capacity() - buffer_.size() < n && index_ != 0) { // We need to shift. buffer_.erase(buffer_.begin(), buffer_.begin() + index_); index_ = 0; } } // Ensure we never reserve less than kUtf8Max. buffer_.reserve(std::max(buffer_.size() + n, kUtf8Max)); size_t to_copy = n - (buffer_.size() - index_); absl::CopyCordToString(input_.Subcord(0, to_copy), &buffer_); input_.RemovePrefix(to_copy); return absl::string_view(buffer_.data() + index_, n); } char Read() { char value = Peek(1).front(); Advance(1); return value; } void Advance(size_t n) { ABSL_ASSERT(n <= Remaining()); if (n == 0) { return; } if (index_ < buffer_.size()) { size_t count = std::min(n, buffer_.size() - index_); index_ += count; n -= count; size_ -= count; if (index_ < buffer_.size()) { return; } // Temporary buffer is empty, clear it. buffer_.clear(); index_ = 0; } input_.RemovePrefix(n); size_ -= n; } void Reset(const absl::Cord& input) { input_ = input; size_ = input_.size(); buffer_.clear(); index_ = 0; } private: absl::Cord input_; size_t size_; std::string buffer_; size_t index_; }; template bool Utf8IsValidImpl(BufferedByteReader* reader) { while (reader->HasRemaining()) { const auto b = static_cast(reader->Read()); if (b < kUtf8RuneSelf) { continue; } const auto leading = kLeading[b]; if (leading == kXX) { return false; } const auto size = static_cast(leading & 7) - 1; if (size > reader->Remaining()) { return false; } const absl::string_view segment = reader->Peek(size); const auto& accept = kAccept[leading >> 4]; if (static_cast(segment[0]) < accept.first || static_cast(segment[0]) > accept.second) { return false; } else if (size == 1) { } else if (static_cast(segment[1]) < kLow || static_cast(segment[1]) > kHigh) { return false; } else if (size == 2) { } else if (static_cast(segment[2]) < kLow || static_cast(segment[2]) > kHigh) { return false; } reader->Advance(size); } return true; } template size_t Utf8CodePointCountImpl(BufferedByteReader* reader) { size_t count = 0; while 
(reader->HasRemaining()) { count++; const auto b = static_cast(reader->Read()); if (b < kUtf8RuneSelf) { continue; } const auto leading = kLeading[b]; if (leading == kXX) { continue; } auto size = static_cast(leading & 7) - 1; if (size > reader->Remaining()) { continue; } const absl::string_view segment = reader->Peek(size); const auto& accept = kAccept[leading >> 4]; if (static_cast(segment[0]) < accept.first || static_cast(segment[0]) > accept.second) { size = 0; } else if (size == 1) { } else if (static_cast(segment[1]) < kLow || static_cast(segment[1]) > kHigh) { size = 0; } else if (size == 2) { } else if (static_cast(segment[2]) < kLow || static_cast(segment[2]) > kHigh) { size = 0; } reader->Advance(size); } return count; } template std::pair Utf8ValidateImpl(BufferedByteReader* reader) { size_t count = 0; while (reader->HasRemaining()) { const auto b = static_cast(reader->Read()); if (b < kUtf8RuneSelf) { count++; continue; } const auto leading = kLeading[b]; if (leading == kXX) { return {count, false}; } const auto size = static_cast(leading & 7) - 1; if (size > reader->Remaining()) { return {count, false}; } const absl::string_view segment = reader->Peek(size); const auto& accept = kAccept[leading >> 4]; if (static_cast(segment[0]) < accept.first || static_cast(segment[0]) > accept.second) { return {count, false}; } else if (size == 1) { count++; } else if (static_cast(segment[1]) < kLow || static_cast(segment[1]) > kHigh) { return {count, false}; } else if (size == 2) { count++; } else if (static_cast(segment[2]) < kLow || static_cast(segment[2]) > kHigh) { return {count, false}; } else { count++; } reader->Advance(size); } return {count, true}; } } // namespace bool Utf8IsValid(absl::string_view str) { StringReader reader(str); bool valid = Utf8IsValidImpl(&reader); ABSL_ASSERT((reader.Reset(str), valid == Utf8ValidateImpl(&reader).second)); return valid; } bool Utf8IsValid(const absl::Cord& str) { CordReader reader(str); bool valid = 
Utf8IsValidImpl(&reader); ABSL_ASSERT((reader.Reset(str), valid == Utf8ValidateImpl(&reader).second)); return valid; } size_t Utf8CodePointCount(absl::string_view str) { StringReader reader(str); return Utf8CodePointCountImpl(&reader); } size_t Utf8CodePointCount(const absl::Cord& str) { CordReader reader(str); return Utf8CodePointCountImpl(&reader); } std::pair Utf8Validate(absl::string_view str) { StringReader reader(str); auto result = Utf8ValidateImpl(&reader); ABSL_ASSERT((reader.Reset(str), result.second == Utf8IsValidImpl(&reader))); return result; } std::pair Utf8Validate(const absl::Cord& str) { CordReader reader(str); auto result = Utf8ValidateImpl(&reader); ABSL_ASSERT((reader.Reset(str), result.second == Utf8IsValidImpl(&reader))); return result; } namespace { size_t Utf8DecodeImpl(uint8_t b, uint8_t leading, size_t size, absl::string_view str, char32_t* absl_nullable code_point) { const auto& accept = kAccept[leading >> 4]; const auto b1 = static_cast(str.front()); if (ABSL_PREDICT_FALSE(b1 < accept.first || b1 > accept.second)) { if (code_point != nullptr) { *code_point = kUnicodeReplacementCharacter; } return 1; } if (size <= 1) { if (code_point != nullptr) { *code_point = (static_cast(b & kMask2) << 6) | static_cast(b1 & kMaskX); } return 2; } str.remove_prefix(1); const auto b2 = static_cast(str.front()); if (ABSL_PREDICT_FALSE(b2 < kLow || b2 > kHigh)) { if (code_point != nullptr) { *code_point = kUnicodeReplacementCharacter; } return 1; } if (size <= 2) { if (code_point != nullptr) { *code_point = (static_cast(b & kMask3) << 12) | (static_cast(b1 & kMaskX) << 6) | static_cast(b2 & kMaskX); } return 3; } str.remove_prefix(1); const auto b3 = static_cast(str.front()); if (ABSL_PREDICT_FALSE(b3 < kLow || b3 > kHigh)) { if (code_point != nullptr) { *code_point = kUnicodeReplacementCharacter; } return 1; } if (code_point != nullptr) { *code_point = (static_cast(b & kMask4) << 18) | (static_cast(b1 & kMaskX) << 12) | (static_cast(b2 & kMaskX) << 6) | 
static_cast(b3 & kMaskX); } return 4; } } // namespace size_t Utf8Decode(absl::string_view str, char32_t* absl_nullable code_point) { ABSL_DCHECK(!str.empty()); const auto b = static_cast(str.front()); if (b < kUtf8RuneSelf) { if (code_point != nullptr) { *code_point = static_cast(b); } return 1; } const auto leading = kLeading[b]; if (ABSL_PREDICT_FALSE(leading == kXX)) { if (code_point != nullptr) { *code_point = kUnicodeReplacementCharacter; } return 1; } auto size = static_cast(leading & 7) - 1; str.remove_prefix(1); if (ABSL_PREDICT_FALSE(size > str.size())) { if (code_point != nullptr) { *code_point = kUnicodeReplacementCharacter; } return 1; } return Utf8DecodeImpl(b, leading, size, str, code_point); } size_t Utf8Decode(const absl::Cord::CharIterator& it, char32_t* absl_nullable code_point) { absl::string_view str = absl::Cord::ChunkRemaining(it); ABSL_DCHECK(!str.empty()); const auto b = static_cast(str.front()); if (b < kUtf8RuneSelf) { if (code_point != nullptr) { *code_point = static_cast(b); } return 1; } const auto leading = kLeading[b]; if (ABSL_PREDICT_FALSE(leading == kXX)) { if (code_point != nullptr) { *code_point = kUnicodeReplacementCharacter; } return 1; } auto size = static_cast(leading & 7) - 1; str.remove_prefix(1); if (ABSL_PREDICT_TRUE(size <= str.size())) { // Fast path. 
return Utf8DecodeImpl(b, leading, size, str, code_point); } absl::Cord::CharIterator current = it; absl::Cord::Advance(¤t, 1); char buffer[3]; size_t buffer_len = 0; while (buffer_len < size) { str = absl::Cord::ChunkRemaining(current); if (ABSL_PREDICT_FALSE(str.empty())) { if (code_point != nullptr) { *code_point = kUnicodeReplacementCharacter; } return 1; } size_t to_copy = std::min(size_t{3} - buffer_len, str.size()); std::memcpy(buffer + buffer_len, str.data(), to_copy); buffer_len += to_copy; absl::Cord::Advance(¤t, to_copy); } return Utf8DecodeImpl(b, leading, size, absl::string_view(buffer, buffer_len), code_point); } size_t Utf8Encode(char32_t code_point, std::string* absl_nonnull buffer) { ABSL_DCHECK(buffer != nullptr); char storage[4]; size_t storage_len = Utf8Encode(code_point, storage); buffer->append(storage, storage_len); return storage_len; } size_t Utf8Encode(char32_t code_point, char* absl_nonnull buffer) { ABSL_DCHECK(buffer != nullptr); if (ABSL_PREDICT_FALSE(!UnicodeIsValid(code_point))) { code_point = kUnicodeReplacementCharacter; } size_t storage_len = 0; if (code_point <= 0x7f) { buffer[storage_len++] = static_cast(static_cast(code_point)); } else if (code_point <= 0x7ff) { buffer[storage_len++] = static_cast(kT2 | static_cast(code_point >> 6)); buffer[storage_len++] = static_cast(kTX | (static_cast(code_point) & kMaskX)); } else if (code_point <= 0xffff) { buffer[storage_len++] = static_cast(kT3 | static_cast(code_point >> 12)); buffer[storage_len++] = static_cast( kTX | (static_cast(code_point >> 6) & kMaskX)); buffer[storage_len++] = static_cast(kTX | (static_cast(code_point) & kMaskX)); } else { buffer[storage_len++] = static_cast(kT4 | static_cast(code_point >> 18)); buffer[storage_len++] = static_cast( kTX | (static_cast(code_point >> 12) & kMaskX)); buffer[storage_len++] = static_cast( kTX | (static_cast(code_point >> 6) & kMaskX)); buffer[storage_len++] = static_cast(kTX | (static_cast(code_point) & kMaskX)); } return storage_len; } 
} // namespace cel::internal ================================================ FILE: internal/utf8.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_UTF8_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_UTF8_H_ #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" namespace cel::internal { // Returns true if the given UTF-8 encoded string is not malformed, false // otherwise. bool Utf8IsValid(absl::string_view str); bool Utf8IsValid(const absl::Cord& str); // Returns the number of Unicode code points in the UTF-8 encoded string. // // If there are any invalid bytes, they will each be counted as an invalid code // point. size_t Utf8CodePointCount(absl::string_view str); size_t Utf8CodePointCount(const absl::Cord& str); // Validates the given UTF-8 encoded string. The first return value is the // number of code points and its meaning depends on the second return value. If // the second return value is true the entire string is not malformed and the // first return value is the number of code points. If the second return value // is false the string is malformed and the first return value is the number of // code points up until the malformed sequence was encountered. 
std::pair Utf8Validate(absl::string_view str); std::pair Utf8Validate(const absl::Cord& str); // Decodes the next code point, returning the decoded code point and the number // of code units (a.k.a. bytes) consumed. In the event that an invalid code unit // sequence is returned the replacement character, U+FFFD, is returned with a // code unit count of 1. As U+FFFD requires 3 code units when encoded, this can // be used to differentiate valid input from malformed input. size_t Utf8Decode(absl::string_view str, char32_t* absl_nullable code_point); size_t Utf8Decode(const absl::Cord::CharIterator& it, char32_t* absl_nullable code_point); inline std::pair Utf8Decode(absl::string_view str) { char32_t code_point; size_t code_units = Utf8Decode(str, &code_point); return std::pair{code_point, code_units}; } inline std::pair Utf8Decode( const absl::Cord::CharIterator& it) { char32_t code_point; size_t code_units = Utf8Decode(it, &code_point); return std::pair{code_point, code_units}; } // Encodes the given code point and appends it to the buffer. If the code point // is an unpaired surrogate or outside of the valid Unicode range it is replaced // with the replacement character, U+FFFD. size_t Utf8Encode(char32_t code_point, std::string* absl_nonnull buffer); size_t Utf8Encode(char32_t code_point, char* absl_nonnull buffer); ABSL_DEPRECATED("Use other overload") inline size_t Utf8Encode(std::string& buffer, char32_t code_point) { return Utf8Encode(code_point, &buffer); } } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_UTF8_H_ ================================================ FILE: internal/utf8_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "internal/utf8.h" #include #include #include "absl/strings/cord.h" #include "absl/strings/cord_test_helpers.h" #include "absl/strings/escaping.h" #include "absl/strings/string_view.h" #include "internal/benchmark.h" #include "internal/testing.h" // Tests is based on // https://go.googlesource.com/go/+/refs/heads/master/src/unicode/utf8/utf8.go // but adapted for C++. namespace cel::internal { namespace { TEST(Utf8IsValid, String) { EXPECT_TRUE(Utf8IsValid("")); EXPECT_TRUE(Utf8IsValid("a")); EXPECT_TRUE(Utf8IsValid("abc")); EXPECT_TRUE(Utf8IsValid("\xd0\x96")); EXPECT_TRUE(Utf8IsValid("\xd0\x96\xd0\x96")); EXPECT_TRUE(Utf8IsValid( "\xd0\xb1\xd1\x80\xd1\x8d\xd0\xb4-\xd0\x9b\xd0\x93\xd0\xa2\xd0\x9c")); EXPECT_TRUE(Utf8IsValid("\xe2\x98\xba\xe2\x98\xbb\xe2\x98\xb9")); EXPECT_TRUE(Utf8IsValid("a\ufffdb")); EXPECT_TRUE(Utf8IsValid("\xf4\x8f\xbf\xbf")); EXPECT_FALSE(Utf8IsValid("\x42\xfa")); EXPECT_FALSE(Utf8IsValid("\x42\xfa\x43")); EXPECT_FALSE(Utf8IsValid("\xf4\x90\x80\x80")); EXPECT_FALSE(Utf8IsValid("\xf7\xbf\xbf\xbf")); EXPECT_FALSE(Utf8IsValid("\xfb\xbf\xbf\xbf\xbf")); EXPECT_FALSE(Utf8IsValid("\xc0\x80")); EXPECT_FALSE(Utf8IsValid("\xed\xa0\x80")); EXPECT_FALSE(Utf8IsValid("\xed\xbf\xbf")); } TEST(Utf8IsValid, Cord) { EXPECT_TRUE(Utf8IsValid(absl::Cord(""))); EXPECT_TRUE(Utf8IsValid(absl::Cord("a"))); EXPECT_TRUE(Utf8IsValid(absl::Cord("abc"))); EXPECT_TRUE(Utf8IsValid(absl::Cord("\xd0\x96"))); EXPECT_TRUE(Utf8IsValid(absl::Cord("\xd0\x96\xd0\x96"))); EXPECT_TRUE(Utf8IsValid(absl::Cord( 
"\xd0\xb1\xd1\x80\xd1\x8d\xd0\xb4-\xd0\x9b\xd0\x93\xd0\xa2\xd0\x9c"))); EXPECT_TRUE(Utf8IsValid(absl::Cord("\xe2\x98\xba\xe2\x98\xbb\xe2\x98\xb9"))); EXPECT_TRUE(Utf8IsValid(absl::Cord("a\ufffdb"))); EXPECT_TRUE(Utf8IsValid(absl::Cord("\xf4\x8f\xbf\xbf"))); EXPECT_FALSE(Utf8IsValid(absl::Cord("\x42\xfa"))); EXPECT_FALSE(Utf8IsValid(absl::Cord("\x42\xfa\x43"))); EXPECT_FALSE(Utf8IsValid(absl::Cord("\xf4\x90\x80\x80"))); EXPECT_FALSE(Utf8IsValid(absl::Cord("\xf7\xbf\xbf\xbf"))); EXPECT_FALSE(Utf8IsValid(absl::Cord("\xfb\xbf\xbf\xbf\xbf"))); EXPECT_FALSE(Utf8IsValid(absl::Cord("\xc0\x80"))); EXPECT_FALSE(Utf8IsValid(absl::Cord("\xed\xa0\x80"))); EXPECT_FALSE(Utf8IsValid(absl::Cord("\xed\xbf\xbf"))); } TEST(Utf8CodePointCount, String) { EXPECT_EQ(Utf8CodePointCount("abcd"), 4); EXPECT_EQ(Utf8CodePointCount("1,2,3,4"), 7); EXPECT_EQ(Utf8CodePointCount("\xe2\x98\xba\xe2\x98\xbb\xe2\x98\xb9"), 3); EXPECT_EQ(Utf8CodePointCount(absl::string_view("\xe2\x00", 2)), 2); EXPECT_EQ(Utf8CodePointCount("\xe2\x80"), 2); EXPECT_EQ(Utf8CodePointCount("a\xe2\x80"), 3); } TEST(Utf8CodePointCount, Cord) { EXPECT_EQ(Utf8CodePointCount(absl::Cord("abcd")), 4); EXPECT_EQ(Utf8CodePointCount(absl::Cord("1,2,3,4")), 7); EXPECT_EQ( Utf8CodePointCount(absl::Cord("\xe2\x98\xba\xe2\x98\xbb\xe2\x98\xb9")), 3); EXPECT_EQ(Utf8CodePointCount(absl::Cord(absl::string_view("\xe2\x00", 2))), 2); EXPECT_EQ(Utf8CodePointCount(absl::Cord("\xe2\x80")), 2); EXPECT_EQ(Utf8CodePointCount(absl::Cord("a\xe2\x80")), 3); } TEST(Utf8Validate, String) { EXPECT_TRUE(Utf8Validate("").second); EXPECT_TRUE(Utf8Validate("a").second); EXPECT_TRUE(Utf8Validate("abc").second); EXPECT_TRUE(Utf8Validate("\xd0\x96").second); EXPECT_TRUE(Utf8Validate("\xd0\x96\xd0\x96").second); EXPECT_TRUE( Utf8Validate( "\xd0\xb1\xd1\x80\xd1\x8d\xd0\xb4-\xd0\x9b\xd0\x93\xd0\xa2\xd0\x9c") .second); EXPECT_TRUE(Utf8Validate("\xe2\x98\xba\xe2\x98\xbb\xe2\x98\xb9").second); EXPECT_TRUE(Utf8Validate("a\ufffdb").second); 
EXPECT_TRUE(Utf8Validate("\xf4\x8f\xbf\xbf").second); EXPECT_FALSE(Utf8Validate("\x42\xfa").second); EXPECT_FALSE(Utf8Validate("\x42\xfa\x43").second); EXPECT_FALSE(Utf8Validate("\xf4\x90\x80\x80").second); EXPECT_FALSE(Utf8Validate("\xf7\xbf\xbf\xbf").second); EXPECT_FALSE(Utf8Validate("\xfb\xbf\xbf\xbf\xbf").second); EXPECT_FALSE(Utf8Validate("\xc0\x80").second); EXPECT_FALSE(Utf8Validate("\xed\xa0\x80").second); EXPECT_FALSE(Utf8Validate("\xed\xbf\xbf").second); EXPECT_EQ(Utf8Validate("abcd").first, 4); EXPECT_EQ(Utf8Validate("1,2,3,4").first, 7); EXPECT_EQ(Utf8Validate("\xe2\x98\xba\xe2\x98\xbb\xe2\x98\xb9").first, 3); EXPECT_EQ(Utf8Validate(absl::string_view("\xe2\x00", 2)).first, 0); EXPECT_EQ(Utf8Validate("\xe2\x80").first, 0); EXPECT_EQ(Utf8Validate("a\xe2\x80").first, 1); } TEST(Utf8Validate, Cord) { EXPECT_TRUE(Utf8Validate(absl::Cord("")).second); EXPECT_TRUE(Utf8Validate(absl::Cord("a")).second); EXPECT_TRUE(Utf8Validate(absl::Cord("abc")).second); EXPECT_TRUE(Utf8Validate(absl::Cord("\xd0\x96")).second); EXPECT_TRUE(Utf8Validate(absl::Cord("\xd0\x96\xd0\x96")).second); EXPECT_TRUE(Utf8Validate(absl::Cord("\xd0\xb1\xd1\x80\xd1\x8d\xd0\xb4-" "\xd0\x9b\xd0\x93\xd0\xa2\xd0\x9c")) .second); EXPECT_TRUE( Utf8Validate(absl::Cord("\xe2\x98\xba\xe2\x98\xbb\xe2\x98\xb9")).second); EXPECT_TRUE(Utf8Validate(absl::Cord("a\ufffdb")).second); EXPECT_TRUE(Utf8Validate(absl::Cord("\xf4\x8f\xbf\xbf")).second); EXPECT_FALSE(Utf8Validate(absl::Cord("\x42\xfa")).second); EXPECT_FALSE(Utf8Validate(absl::Cord("\x42\xfa\x43")).second); EXPECT_FALSE(Utf8Validate(absl::Cord("\xf4\x90\x80\x80")).second); EXPECT_FALSE(Utf8Validate(absl::Cord("\xf7\xbf\xbf\xbf")).second); EXPECT_FALSE(Utf8Validate(absl::Cord("\xfb\xbf\xbf\xbf\xbf")).second); EXPECT_FALSE(Utf8Validate(absl::Cord("\xc0\x80")).second); EXPECT_FALSE(Utf8Validate(absl::Cord("\xed\xa0\x80")).second); EXPECT_FALSE(Utf8Validate(absl::Cord("\xed\xbf\xbf")).second); EXPECT_EQ(Utf8Validate(absl::Cord("abcd")).first, 4); 
EXPECT_EQ(Utf8Validate(absl::Cord("1,2,3,4")).first, 7); EXPECT_EQ( Utf8Validate(absl::Cord("\xe2\x98\xba\xe2\x98\xbb\xe2\x98\xb9")).first, 3); EXPECT_EQ(Utf8Validate(absl::Cord(absl::string_view("\xe2\x00", 2))).first, 0); EXPECT_EQ(Utf8Validate(absl::Cord("\xe2\x80")).first, 0); EXPECT_EQ(Utf8Validate(absl::Cord("a\xe2\x80")).first, 1); } struct Utf8EncodeTestCase final { char32_t code_point; absl::string_view code_units; }; using Utf8EncodeTest = testing::TestWithParam; TEST_P(Utf8EncodeTest, Compliance) { const Utf8EncodeTestCase& test_case = GetParam(); std::string result; EXPECT_EQ(Utf8Encode(result, test_case.code_point), test_case.code_units.size()); EXPECT_EQ(result, test_case.code_units); } INSTANTIATE_TEST_SUITE_P(Utf8EncodeTest, Utf8EncodeTest, testing::ValuesIn({ {0x0000, absl::string_view("\x00", 1)}, {0x0001, "\x01"}, {0x007e, "\x7e"}, {0x007f, "\x7f"}, {0x0080, "\xc2\x80"}, {0x0081, "\xc2\x81"}, {0x00bf, "\xc2\xbf"}, {0x00c0, "\xc3\x80"}, {0x00c1, "\xc3\x81"}, {0x00c8, "\xc3\x88"}, {0x00d0, "\xc3\x90"}, {0x00e0, "\xc3\xa0"}, {0x00f0, "\xc3\xb0"}, {0x00f8, "\xc3\xb8"}, {0x00ff, "\xc3\xbf"}, {0x0100, "\xc4\x80"}, {0x07ff, "\xdf\xbf"}, {0x0400, "\xd0\x80"}, {0x0800, "\xe0\xa0\x80"}, {0x0801, "\xe0\xa0\x81"}, {0x1000, "\xe1\x80\x80"}, {0xd000, "\xed\x80\x80"}, {0xd7ff, "\xed\x9f\xbf"}, {0xe000, "\xee\x80\x80"}, {0xfffe, "\xef\xbf\xbe"}, {0xffff, "\xef\xbf\xbf"}, {0x10000, "\xf0\x90\x80\x80"}, {0x10001, "\xf0\x90\x80\x81"}, {0x40000, "\xf1\x80\x80\x80"}, {0x10fffe, "\xf4\x8f\xbf\xbe"}, {0x10ffff, "\xf4\x8f\xbf\xbf"}, {0xFFFD, "\xef\xbf\xbd"}, })); struct Utf8DecodeTestCase final { char32_t code_point; absl::string_view code_units; }; using Utf8DecodeTest = testing::TestWithParam; TEST_P(Utf8DecodeTest, StringView) { const Utf8DecodeTestCase& test_case = GetParam(); auto [code_point, code_units] = Utf8Decode(test_case.code_units); EXPECT_EQ(code_units, test_case.code_units.size()) << absl::CHexEscape(test_case.code_units); EXPECT_EQ(code_point, 
test_case.code_point) << absl::CHexEscape(test_case.code_units); EXPECT_EQ(Utf8Decode(test_case.code_units, nullptr), test_case.code_units.size()); } TEST_P(Utf8DecodeTest, Cord) { const Utf8DecodeTestCase& test_case = GetParam(); auto cord = absl::Cord(test_case.code_units); auto it = cord.char_begin(); auto [code_point, code_units] = Utf8Decode(it); absl::Cord::Advance(&it, code_units); EXPECT_EQ(it, cord.char_end()); EXPECT_EQ(code_units, test_case.code_units.size()) << absl::CHexEscape(test_case.code_units); EXPECT_EQ(code_point, test_case.code_point) << absl::CHexEscape(test_case.code_units); it = cord.char_begin(); EXPECT_EQ(Utf8Decode(it, nullptr), test_case.code_units.size()); } std::vector FragmentString(absl::string_view text) { std::vector fragments; fragments.reserve(text.size()); for (const auto& c : text) { fragments.emplace_back().push_back(c); } return fragments; } TEST_P(Utf8DecodeTest, CordFragmented) { const Utf8DecodeTestCase& test_case = GetParam(); auto cord = absl::MakeFragmentedCord(FragmentString(test_case.code_units)); auto it = cord.char_begin(); auto [code_point, code_units] = Utf8Decode(it); absl::Cord::Advance(&it, code_units); EXPECT_EQ(it, cord.char_end()); EXPECT_EQ(code_units, test_case.code_units.size()) << absl::CHexEscape(test_case.code_units); EXPECT_EQ(code_point, test_case.code_point) << absl::CHexEscape(test_case.code_units); } INSTANTIATE_TEST_SUITE_P(Utf8DecodeTest, Utf8DecodeTest, testing::ValuesIn({ {0x0000, absl::string_view("\x00", 1)}, {0x0001, "\x01"}, {0x007e, "\x7e"}, {0x007f, "\x7f"}, {0x0080, "\xc2\x80"}, {0x0081, "\xc2\x81"}, {0x00bf, "\xc2\xbf"}, {0x00c0, "\xc3\x80"}, {0x00c1, "\xc3\x81"}, {0x00c8, "\xc3\x88"}, {0x00d0, "\xc3\x90"}, {0x00e0, "\xc3\xa0"}, {0x00f0, "\xc3\xb0"}, {0x00f8, "\xc3\xb8"}, {0x00ff, "\xc3\xbf"}, {0x0100, "\xc4\x80"}, {0x07ff, "\xdf\xbf"}, {0x0400, "\xd0\x80"}, {0x0800, "\xe0\xa0\x80"}, {0x0801, "\xe0\xa0\x81"}, {0x1000, "\xe1\x80\x80"}, {0xd000, "\xed\x80\x80"}, {0xd7ff, "\xed\x9f\xbf"}, 
{0xe000, "\xee\x80\x80"}, {0xfffe, "\xef\xbf\xbe"}, {0xffff, "\xef\xbf\xbf"}, {0x10000, "\xf0\x90\x80\x80"}, {0x10001, "\xf0\x90\x80\x81"}, {0x40000, "\xf1\x80\x80\x80"}, {0x10fffe, "\xf4\x8f\xbf\xbe"}, {0x10ffff, "\xf4\x8f\xbf\xbf"}, {0xFFFD, "\xef\xbf\xbd"}, })); void BM_Utf8CodePointCount_String_AsciiTen(benchmark::State& state) { for (auto s : state) { benchmark::DoNotOptimize(Utf8CodePointCount("0123456789")); } } BENCHMARK(BM_Utf8CodePointCount_String_AsciiTen); void BM_Utf8CodePointCount_Cord_AsciiTen(benchmark::State& state) { absl::Cord value("0123456789"); for (auto s : state) { benchmark::DoNotOptimize(Utf8CodePointCount(value)); } } BENCHMARK(BM_Utf8CodePointCount_Cord_AsciiTen); void BM_Utf8CodePointCount_String_JapaneseTen(benchmark::State& state) { for (auto s : state) { benchmark::DoNotOptimize(Utf8CodePointCount( "\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa" "\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5")); } } BENCHMARK(BM_Utf8CodePointCount_String_JapaneseTen); void BM_Utf8CodePointCount_Cord_JapaneseTen(benchmark::State& state) { absl::Cord value( "\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa" "\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5"); for (auto s : state) { benchmark::DoNotOptimize(Utf8CodePointCount(value)); } } BENCHMARK(BM_Utf8CodePointCount_Cord_JapaneseTen); void BM_Utf8IsValid_String_AsciiTen(benchmark::State& state) { for (auto s : state) { benchmark::DoNotOptimize(Utf8IsValid("0123456789")); } } BENCHMARK(BM_Utf8IsValid_String_AsciiTen); void BM_Utf8IsValid_Cord_AsciiTen(benchmark::State& state) { absl::Cord value("0123456789"); for (auto s : state) { benchmark::DoNotOptimize(Utf8IsValid(value)); } } BENCHMARK(BM_Utf8IsValid_Cord_AsciiTen); void BM_Utf8IsValid_String_JapaneseTen(benchmark::State& state) { for (auto s : state) { benchmark::DoNotOptimize(Utf8IsValid( "\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa" 
"\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5")); } } BENCHMARK(BM_Utf8IsValid_String_JapaneseTen); void BM_Utf8IsValid_Cord_JapaneseTen(benchmark::State& state) { absl::Cord value( "\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa" "\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5"); for (auto s : state) { benchmark::DoNotOptimize(Utf8IsValid(value)); } } BENCHMARK(BM_Utf8IsValid_Cord_JapaneseTen); void BM_Utf8Validate_String_AsciiTen(benchmark::State& state) { for (auto s : state) { benchmark::DoNotOptimize(Utf8Validate("0123456789")); } } BENCHMARK(BM_Utf8Validate_String_AsciiTen); void BM_Utf8Validate_Cord_AsciiTen(benchmark::State& state) { absl::Cord value("0123456789"); for (auto s : state) { benchmark::DoNotOptimize(Utf8Validate(value)); } } BENCHMARK(BM_Utf8Validate_Cord_AsciiTen); void BM_Utf8Validate_String_JapaneseTen(benchmark::State& state) { for (auto s : state) { benchmark::DoNotOptimize(Utf8Validate( "\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa" "\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5")); } } BENCHMARK(BM_Utf8Validate_String_JapaneseTen); void BM_Utf8Validate_Cord_JapaneseTen(benchmark::State& state) { absl::Cord value( "\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa" "\x9e\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\xe6\x97\xa5"); for (auto s : state) { benchmark::DoNotOptimize(Utf8Validate(value)); } } BENCHMARK(BM_Utf8Validate_Cord_JapaneseTen); } // namespace } // namespace cel::internal ================================================ FILE: internal/well_known_types.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "internal/well_known_types.h"

// NOTE(review): the six `#include` directives below carry no header name in
// this copy of the file; the targets cannot be recovered from this text.
#include
#include
#include
#include
#include
#include

#include "google/protobuf/any.pb.h"
#include "google/protobuf/duration.pb.h"
#include "google/protobuf/field_mask.pb.h"
#include "google/protobuf/struct.pb.h"
#include "google/protobuf/timestamp.pb.h"
#include "google/protobuf/wrappers.pb.h"
#include "google/protobuf/descriptor.pb.h"
#include "absl/base/attributes.h"
#include "absl/base/call_once.h"
#include "absl/base/no_destructor.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/functional/overload.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "absl/time/time.h"
#include "absl/types/variant.h"
#include "common/json.h"
#include "common/memory.h"
#include "extensions/protobuf/internal/map_reflection.h"
#include "internal/protobuf_runtime_version.h"
#include "internal/status_macros.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/map_field.h"
#include "google/protobuf/message.h"
#include "google/protobuf/message_lite.h"
#include "google/protobuf/reflection.h"
#include "google/protobuf/util/time_util.h"

namespace cel::well_known_types {

// File-local helpers shared by the well-known-type reflection classes below.
namespace {

using ::google::protobuf::Descriptor;
using ::google::protobuf::DescriptorPool;
using ::google::protobuf::EnumDescriptor;
using ::google::protobuf::FieldDescriptor;
using ::google::protobuf::OneofDescriptor;
using ::google::protobuf::util::TimeUtil;

using CppStringType = ::google::protobuf::FieldDescriptor::CppStringType;

// Returns the cardinality label of `field`, derived from `is_required()` and
// `is_repeated()`; anything that is neither is reported as LABEL_OPTIONAL.
FieldDescriptor::Label GetFieldLabel(
    const FieldDescriptor* absl_nonnull field) {
  if (field->is_required()) {
    return FieldDescriptor::LABEL_REQUIRED;
  } else if (field->is_repeated()) {
    return FieldDescriptor::LABEL_REPEATED;
  } else {
    return FieldDescriptor::LABEL_OPTIONAL;
  }
}

// Returns the contents of `value` as one contiguous string_view.
// A string_view alternative is returned unchanged. A Cord alternative is
// returned directly when `TryFlat()` succeeds; otherwise it is copied into
// `scratch` and a view of `scratch` is returned. The result may therefore refer
// to either `value` or `scratch` (both are ABSL_ATTRIBUTE_LIFETIME_BOUND).
absl::string_view FlatStringValue(
    const StringValue& value ABSL_ATTRIBUTE_LIFETIME_BOUND,
    std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return absl::visit(
      absl::Overload(
          [](absl::string_view string) -> absl::string_view { return string; },
          [&](const absl::Cord& cord) -> absl::string_view {
            if (auto flat = cord.TryFlat(); flat) {
              return *flat;
            }
            scratch = static_cast(cord);
            return scratch;
          }),
      AsVariant(value));
}

// Returns a StringValue whose string_view alternative is backed by `scratch`.
// A string_view that does not already start at `scratch.data()` is copied into
// `scratch`. A view that already starts there is returned as-is, which avoids
// assigning `scratch` to itself. A Cord alternative is returned unchanged.
StringValue CopyStringValue(const StringValue& value,
                            std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return absl::visit(
      absl::Overload(
          [&](absl::string_view string) -> StringValue {
            if (string.data() != scratch.data()) {
              scratch.assign(string.data(), string.size());
              return scratch;
            }
            return string;
          },
          [](const absl::Cord& cord) -> StringValue { return cord; }),
      AsVariant(value));
}

// BytesValue counterpart of CopyStringValue; the logic is identical.
BytesValue CopyBytesValue(const BytesValue& value,
                          std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return absl::visit(
      absl::Overload(
          [&](absl::string_view string) -> BytesValue {
            if (string.data() != scratch.data()) {
              scratch.assign(string.data(), string.size());
              return scratch;
            }
            return string;
          },
          [](const absl::Cord& cord) -> BytesValue { return cord; }),
      AsVariant(value));
}

// Returns a single function-local static ScratchSpace that is never destroyed
// (absl::NoDestructor). All callers share it.
// NOTE(review): sharing is only safe if Reflection::GetStringView /
// GetRepeatedStringView never write to it for kString/kView fields, which is
// what the comments at the call sites below rely on. Confirm before using it
// for other field kinds.
google::protobuf::Reflection::ScratchSpace& GetScratchSpace() {
  static absl::NoDestructor scratch_space;
  return *scratch_space;
}

// Reads the singular string/bytes field `field` of `message` and returns it
// as `Variant`. The Reflection accessor is chosen by `string_type`, which is
// debug-checked to equal `field->cpp_string_type()`:
//   kCord           -> Reflection::GetCord
//   kView / kString -> Reflection::GetStringView with the shared scratch space
//   anything else   -> Reflection::GetStringReference, which is passed
//                      `scratch`, so the result may refer to `scratch`.
// NOTE(review): in this copy, the template parameter list after `template`
// and the explicit template arguments at call sites such as
// `GetStringField(...)` appear to be missing. `Variant` is used as the return
// type, so the original presumably declared `template <typename Variant>`;
// confirm against upstream.
template Variant GetStringField(const google::protobuf::Reflection* absl_nonnull reflection,
                       const google::protobuf::Message& message,
                       const FieldDescriptor* absl_nonnull field,
                       CppStringType string_type,
                       std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  ABSL_DCHECK(field->cpp_string_type() == string_type);
  switch (string_type) {
    case CppStringType::kCord:
      return reflection->GetCord(message, field);
    case CppStringType::kView:
      ABSL_FALLTHROUGH_INTENDED;
    case CppStringType::kString:
      // Message is guaranteed to be storing as some sort of contiguous array of
      // bytes, there is no need to copy. But unfortunately `GetStringView`
      // forces taking scratch space.
      return reflection->GetStringView(message, field, GetScratchSpace());
    default:
      return absl::string_view(
          reflection->GetStringReference(message, field, &scratch));
  }
}

// Convenience overload that takes the Reflection from `message` itself.
template Variant GetStringField(const google::protobuf::Message& message,
                       const FieldDescriptor* absl_nonnull field,
                       CppStringType string_type,
                       std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return GetStringField(message.GetReflection(), message, field,
                                 string_type, scratch);
}

// Reads element `index` of the repeated string/bytes field `field` and returns
// it as `Variant`. kView/kString use GetRepeatedStringView with the shared
// scratch space. Every other `string_type` (there is no dedicated kCord case
// here) goes through GetRepeatedStringReference, which is passed `scratch`.
template Variant GetRepeatedStringField(
    const google::protobuf::Reflection* absl_nonnull reflection,
    const google::protobuf::Message& message, const FieldDescriptor* absl_nonnull field,
    CppStringType string_type, int index,
    std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  ABSL_DCHECK(field->cpp_string_type() == string_type);
  switch (string_type) {
    case CppStringType::kView:
      ABSL_FALLTHROUGH_INTENDED;
    case CppStringType::kString:
      // Message is guaranteed to be storing as some sort of contiguous array of
      // bytes, there is no need to copy. But unfortunately `GetStringView`
      // forces taking scratch space.
      return reflection->GetRepeatedStringView(message, field, index,
                                               GetScratchSpace());
    default:
      return absl::string_view(reflection->GetRepeatedStringReference(
          message, field, index, &scratch));
  }
}

// Convenience overload that takes the Reflection from `message` itself.
template Variant GetRepeatedStringField(
    const google::protobuf::Message& message, const FieldDescriptor* absl_nonnull field,
    CppStringType string_type, int index,
    std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  return GetRepeatedStringField(message.GetReflection(), message, field,
                                         string_type, index, scratch);
}

// Looks up message type `name` in `pool`. Returns InvalidArgumentError if
// the pool does not contain it.
absl::StatusOr GetMessageTypeByName(
    const DescriptorPool* absl_nonnull pool, absl::string_view name) {
  const auto* descriptor = pool->FindMessageTypeByName(name);
  if (ABSL_PREDICT_FALSE(descriptor == nullptr)) {
    return absl::InvalidArgumentError(absl::StrCat(
        "descriptor missing for protocol buffer message well known type: ",
        name));
  }
  return descriptor;
}

// Looks up enum type `name` in `pool`. Returns InvalidArgumentError if the
// pool does not contain it.
absl::StatusOr GetEnumTypeByName(
    const DescriptorPool* absl_nonnull pool, absl::string_view name) {
  const auto* descriptor = pool->FindEnumTypeByName(name);
  if (ABSL_PREDICT_FALSE(descriptor == nullptr)) {
    return absl::InvalidArgumentError(absl::StrCat(
        "descriptor missing for protocol buffer enum well known type: ", name));
  }
  return descriptor;
}

// Looks up oneof `name` within `descriptor`. Returns InvalidArgumentError if
// it does not exist.
absl::StatusOr GetOneofByName(
    const Descriptor* absl_nonnull descriptor, absl::string_view name) {
  const auto* oneof = descriptor->FindOneofByName(name);
  if (ABSL_PREDICT_FALSE(oneof == nullptr)) {
    return absl::InvalidArgumentError(absl::StrCat(
        "oneof missing for protocol buffer message well known type: ",
        descriptor->full_name(), ".", name));
  }
  return oneof;
}

// Looks up the field numbered `number` within `descriptor`. Returns
// InvalidArgumentError if no such field exists.
absl::StatusOr GetFieldByNumber(
    const Descriptor* absl_nonnull descriptor, int32_t number) {
  const auto* field = descriptor->FindFieldByNumber(number);
  if (ABSL_PREDICT_FALSE(field == nullptr)) {
    return absl::InvalidArgumentError(absl::StrCat(
        "field missing for protocol buffer message well known type: ",
        descriptor->full_name(), ".", number));
  }
  return field;
}

// Fails with InvalidArgumentError unless `field` has wire type `type`. The
// error reports the field's actual type name.
absl::Status CheckFieldType(const FieldDescriptor* absl_nonnull field,
                            FieldDescriptor::Type type) {
  if (ABSL_PREDICT_FALSE(field->type() != type)) {
    return absl::InvalidArgumentError(absl::StrCat(
        "unexpected field type for protocol buffer message well known type: ",
        field->full_name(), " ", field->type_name()));
  }
  return absl::OkStatus();
}

// Fails with InvalidArgumentError unless `field` has C++ type `cpp_type`. The
// error reports the field's actual C++ type name.
absl::Status CheckFieldCppType(const FieldDescriptor* absl_nonnull field,
                               FieldDescriptor::CppType cpp_type) {
  if (ABSL_PREDICT_FALSE(field->cpp_type() != cpp_type)) {
    return absl::InvalidArgumentError(absl::StrCat(
        "unexpected field type for protocol buffer message well known type: ",
        field->full_name(), " ", field->cpp_type_name()));
  }
  return absl::OkStatus();
}

// Returns a human-readable name for `label`, used in error messages. Values
// outside the three known labels yield "ERROR".
absl::string_view LabelToString(FieldDescriptor::Label label) {
  switch (label) {
    case FieldDescriptor::LABEL_REPEATED:
      return "REPEATED";
    case FieldDescriptor::LABEL_REQUIRED:
      return "REQUIRED";
    case FieldDescriptor::LABEL_OPTIONAL:
      return "OPTIONAL";
    default:
      return "ERROR";
  }
}

// Fails with InvalidArgumentError unless GetFieldLabel(field) equals `label`.
// The error reports the field's actual label.
absl::Status CheckFieldCardinality(const FieldDescriptor* absl_nonnull field,
                                   FieldDescriptor::Label label) {
  if (ABSL_PREDICT_FALSE(GetFieldLabel(field) != label)) {
    return absl::InvalidArgumentError(absl::StrCat(
        "unexpected field cardinality for protocol buffer message "
        "well known type: ",
        field->full_name(), " ", LabelToString(GetFieldLabel(field))));
  }
  return absl::OkStatus();
}

// Returns a human-readable name for `well_known_type`, used in error
// messages. Any value not listed below yields "ERROR".
absl::string_view WellKnownTypeToString(
    Descriptor::WellKnownType well_known_type) {
  switch (well_known_type) {
    case Descriptor::WELLKNOWNTYPE_BOOLVALUE:
      return "BOOLVALUE";
    case Descriptor::WELLKNOWNTYPE_INT32VALUE:
      return "INT32VALUE";
    case Descriptor::WELLKNOWNTYPE_INT64VALUE:
      return "INT64VALUE";
    case Descriptor::WELLKNOWNTYPE_UINT32VALUE:
      return "UINT32VALUE";
    case Descriptor::WELLKNOWNTYPE_UINT64VALUE:
      return "UINT64VALUE";
    case Descriptor::WELLKNOWNTYPE_FLOATVALUE:
      return "FLOATVALUE";
    case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE:
      return "DOUBLEVALUE";
    case Descriptor::WELLKNOWNTYPE_BYTESVALUE:
      return "BYTESVALUE";
    case Descriptor::WELLKNOWNTYPE_STRINGVALUE:
      return "STRINGVALUE";
    case Descriptor::WELLKNOWNTYPE_ANY:
      return "ANY";
    case Descriptor::WELLKNOWNTYPE_DURATION:
      return "DURATION";
    case Descriptor::WELLKNOWNTYPE_TIMESTAMP:
      return "TIMESTAMP";
    case Descriptor::WELLKNOWNTYPE_VALUE:
      return "VALUE";
    case Descriptor::WELLKNOWNTYPE_LISTVALUE:
      return "LISTVALUE";
    case Descriptor::WELLKNOWNTYPE_STRUCT:
      return "STRUCT";
    case Descriptor::WELLKNOWNTYPE_FIELDMASK:
      return "FIELDMASK";
    default:
      return "ERROR";
  }
}

// Fails with InvalidArgumentError unless `descriptor` reports
// `well_known_type`. The error reports the descriptor's actual well-known type.
absl::Status CheckWellKnownType(const Descriptor* absl_nonnull descriptor,
                                Descriptor::WellKnownType well_known_type) {
  if (ABSL_PREDICT_FALSE(descriptor->well_known_type() != well_known_type)) {
    return absl::InvalidArgumentError(absl::StrCat(
        "expected message to be well known type: ", descriptor->full_name(),
        " ", WellKnownTypeToString(descriptor->well_known_type())));
  }
  return absl::OkStatus();
}

// Fails with InvalidArgumentError unless the message type of `field` reports
// `well_known_type`. It is a debug-checked precondition that `field` is
// message-typed.
absl::Status CheckFieldWellKnownType(
    const FieldDescriptor* absl_nonnull field,
    Descriptor::WellKnownType well_known_type) {
  ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_MESSAGE);
  if (ABSL_PREDICT_FALSE(field->message_type()->well_known_type() !=
                         well_known_type)) {
    return absl::InvalidArgumentError(absl::StrCat(
        "expected message field to be well known type for protocol buffer "
        "message well known type: ",
        field->full_name(), " ",
        WellKnownTypeToString(field->message_type()->well_known_type())));
  }
  return absl::OkStatus();
}

// Fails with InvalidArgumentError unless `field` belongs to `oneof` and sits
// at position `index` within it.
absl::Status CheckFieldOneof(const FieldDescriptor* absl_nonnull field,
                             const OneofDescriptor* absl_nonnull oneof,
                             int index) {
  if (ABSL_PREDICT_FALSE(field->containing_oneof() != oneof)) {
    return absl::InvalidArgumentError(
        absl::StrCat("expected field to be member of oneof for protocol buffer "
                     "message well known type: ",
                     field->full_name()));
  }
  if (ABSL_PREDICT_FALSE(field->index_in_oneof() != index)) {
    return absl::InvalidArgumentError(absl::StrCat(
        "expected field to have index in oneof of ", index,
        " for protocol buffer "
        "message well known type: ",
        field->full_name(), " oneof_index=", field->index_in_oneof()));
  }
  return absl::OkStatus();
}

// Fails with InvalidArgumentError unless `field` is a map field.
absl::Status CheckMapField(const FieldDescriptor* absl_nonnull field) {
  if (ABSL_PREDICT_FALSE(!field->is_map())) {
    return absl::InvalidArgumentError(
        absl::StrCat("expected field to be map for protocol buffer "
                     "message well known type: ",
                     field->full_name()));
  }
  return absl::OkStatus();
}

}  // namespace

// If this value starts with `prefix`, removes it in place and returns true.
// Otherwise leaves the value untouched and returns false. Both the
// string_view and the Cord alternatives are handled.
bool StringValue::ConsumePrefix(absl::string_view prefix) {
  return absl::visit(absl::Overload(
                         [&](absl::string_view& value) {
                           return absl::ConsumePrefix(&value, prefix);
                         },
                         [&](absl::Cord& cord) {
                           if (cord.StartsWith(prefix)) {
                             cord.RemovePrefix(prefix.size());
                             return true;
                           }
                           return false;
                         }),
                     AsVariant(*this));
}

// Reads a singular `string` field. Debug-checks that `reflection` belongs to
// `message` and that `field` is a non-map, non-repeated TYPE_STRING field,
// then dispatches on the field's own cpp_string_type(). The result may refer
// to `scratch`.
StringValue GetStringField(const google::protobuf::Reflection* absl_nonnull reflection,
                           const google::protobuf::Message& message,
                           const FieldDescriptor* absl_nonnull field,
                           std::string& scratch) {
  ABSL_DCHECK_EQ(reflection, message.GetReflection());
  ABSL_DCHECK(!field->is_map() && !field->is_repeated());
  ABSL_DCHECK_EQ(field->type(), FieldDescriptor::TYPE_STRING);
  ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_STRING);
  return GetStringField(reflection, message, field,
                                     field->cpp_string_type(), scratch);
}

// Reads a singular `bytes` field. Same contract as GetStringField, but expects
// TYPE_BYTES.
BytesValue GetBytesField(const google::protobuf::Reflection* absl_nonnull reflection,
                         const google::protobuf::Message& message,
                         const FieldDescriptor* absl_nonnull field,
                         std::string& scratch) {
  ABSL_DCHECK_EQ(reflection, message.GetReflection());
  ABSL_DCHECK(!field->is_map() && !field->is_repeated());
  ABSL_DCHECK_EQ(field->type(), FieldDescriptor::TYPE_BYTES);
  ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_STRING);
  return GetStringField(reflection, message, field,
                                    field->cpp_string_type(), scratch);
}

// Reads element `index` of a repeated (non-map) `string` field. The result may
// refer to `scratch`.
StringValue GetRepeatedStringField(
    const google::protobuf::Reflection* absl_nonnull reflection,
    const google::protobuf::Message& message, const FieldDescriptor* absl_nonnull field,
    int index, std::string& scratch) {
  ABSL_DCHECK_EQ(reflection, message.GetReflection());
  ABSL_DCHECK(!field->is_map() && field->is_repeated());
  ABSL_DCHECK_EQ(field->type(), FieldDescriptor::TYPE_STRING);
  ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_STRING);
  return GetRepeatedStringField(
      reflection, message, field, field->cpp_string_type(), index, scratch);
}

// Reads element `index` of a repeated (non-map) `bytes` field. The result may
// refer to `scratch`.
BytesValue GetRepeatedBytesField(
    const google::protobuf::Reflection* absl_nonnull reflection,
    const google::protobuf::Message& message, const FieldDescriptor* absl_nonnull field,
    int index, std::string& scratch) {
  ABSL_DCHECK_EQ(reflection, message.GetReflection());
  ABSL_DCHECK(!field->is_map() && field->is_repeated());
  ABSL_DCHECK_EQ(field->type(), FieldDescriptor::TYPE_BYTES);
  ABSL_DCHECK_EQ(field->cpp_type(), FieldDescriptor::CPPTYPE_STRING);
  return GetRepeatedStringField(
      reflection, message, field, field->cpp_string_type(), index, scratch);
}

// Resolves google.protobuf.NullValue from `pool` and initializes from it.
absl::Status NullValueReflection::Initialize(
    const DescriptorPool* absl_nonnull pool) {
  CEL_ASSIGN_OR_RETURN(const auto* descriptor,
                       GetEnumTypeByName(pool, "google.protobuf.NullValue"));
  return Initialize(descriptor);
}

// Validates that `descriptor` is google.protobuf.NullValue, has a value
// numbered 0, and has exactly one value. Caches that value in `value_`. This
// is a no-op when already initialized with the same descriptor.
absl::Status NullValueReflection::Initialize(
    const EnumDescriptor* absl_nonnull descriptor) {
  if (descriptor_ != descriptor) {
    if (ABSL_PREDICT_FALSE(descriptor->full_name() !=
                           "google.protobuf.NullValue")) {
      return absl::InvalidArgumentError(absl::StrCat(
          "expected enum to be well known type: ", descriptor->full_name(),
          " google.protobuf.NullValue"));
    }
    // `descriptor_` stays null while the checks below run and is only
    // assigned once all of them pass.
    // NOTE(review): `value_` is assigned before the value-count check, so it
    // keeps the new value even if that check fails.
    descriptor_ = nullptr;
    value_ = descriptor->FindValueByNumber(0);
    if (ABSL_PREDICT_FALSE(value_ == nullptr)) {
      return absl::InvalidArgumentError(
          "well known protocol buffer enum missing value: "
          "google.protobuf.NullValue.NULL_VALUE");
    }
    if (ABSL_PREDICT_FALSE(descriptor->value_count() != 1)) {
      // Collect every value name so the error lists them all.
      std::vector values;
      values.reserve(static_cast(descriptor->value_count()));
      for (int i = 0; i < descriptor->value_count(); ++i) {
        values.push_back(descriptor->value(i)->name());
      }
      return absl::InvalidArgumentError(
          absl::StrCat("well known protocol buffer enum has multiple values: [",
                       absl::StrJoin(values, ", "), "]"));
    }
    descriptor_ = descriptor;
  }
  return absl::OkStatus();
}

// Resolves google.protobuf.BoolValue from `pool` and initializes from it.
absl::Status BoolValueReflection::Initialize(
    const DescriptorPool* absl_nonnull pool) {
  CEL_ASSIGN_OR_RETURN(const auto* descriptor,
                       GetMessageTypeByName(pool, "google.protobuf.BoolValue"));
  return Initialize(descriptor);
}

// Checks `descriptor` against the class's `kWellKnownType` and requires field
// number 1 to be an optional CPPTYPE_BOOL field, which is cached. This is a
// no-op when already initialized with the same descriptor. `descriptor_` is
// only assigned after every check passes.
absl::Status BoolValueReflection::Initialize(
    const Descriptor* absl_nonnull descriptor) {
  if (descriptor_ != descriptor) {
    CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType));
    descriptor_ = nullptr;
    CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1));
    CEL_RETURN_IF_ERROR(
        CheckFieldCppType(value_field_, FieldDescriptor::CPPTYPE_BOOL));
    CEL_RETURN_IF_ERROR(
        CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL));
    descriptor_ = descriptor;
  }
  return absl::OkStatus();
}

// Returns the cached value field of `message`. Debug-checks initialization and
// that `message` uses the cached descriptor.
bool BoolValueReflection::GetValue(const google::protobuf::Message& message) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_);
  return message.GetReflection()->GetBool(message, value_field_);
}

// Sets the cached value field of `message` to `value`.
void BoolValueReflection::SetValue(google::protobuf::Message* absl_nonnull message,
                                   bool value) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_);
  message->GetReflection()->SetBool(message, value_field_, value);
}

// Returns a BoolValueReflection initialized from `descriptor`, or the
// initialization error.
absl::StatusOr GetBoolValueReflection(
    const Descriptor* absl_nonnull descriptor) {
  BoolValueReflection reflection;
  CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor));
  return reflection;
}

// Resolves google.protobuf.Int32Value from `pool` and initializes from it.
absl::Status Int32ValueReflection::Initialize(
    const DescriptorPool* absl_nonnull pool) {
  CEL_ASSIGN_OR_RETURN(
      const auto* descriptor,
      GetMessageTypeByName(pool, "google.protobuf.Int32Value"));
  return Initialize(descriptor);
}

// Same validation as BoolValueReflection::Initialize, but field 1 must be
// CPPTYPE_INT32.
absl::Status Int32ValueReflection::Initialize(
    const Descriptor* absl_nonnull descriptor) {
  if (descriptor_ != descriptor) {
    CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType));
    descriptor_ = nullptr;
    CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1));
    CEL_RETURN_IF_ERROR(
        CheckFieldCppType(value_field_, FieldDescriptor::CPPTYPE_INT32));
    CEL_RETURN_IF_ERROR(
        CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL));
    descriptor_ = descriptor;
  }
  return absl::OkStatus();
}

// Returns the cached int32 value field of `message`.
int32_t Int32ValueReflection::GetValue(const google::protobuf::Message& message) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_);
  return message.GetReflection()->GetInt32(message, value_field_);
}

// Sets the cached int32 value field of `message` to `value`.
void Int32ValueReflection::SetValue(google::protobuf::Message* absl_nonnull message,
                                    int32_t value) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_);
  message->GetReflection()->SetInt32(message, value_field_, value);
}

// Returns an Int32ValueReflection initialized from `descriptor`, or the
// initialization error.
absl::StatusOr GetInt32ValueReflection(
    const Descriptor* absl_nonnull descriptor) {
  Int32ValueReflection reflection;
  CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor));
  return reflection;
}

// Resolves google.protobuf.Int64Value from `pool` and initializes from it.
absl::Status Int64ValueReflection::Initialize(
    const DescriptorPool* absl_nonnull pool) {
  CEL_ASSIGN_OR_RETURN(
      const auto* descriptor,
      GetMessageTypeByName(pool, "google.protobuf.Int64Value"));
  return Initialize(descriptor);
}

// Same validation as BoolValueReflection::Initialize, but field 1 must be
// CPPTYPE_INT64.
absl::Status Int64ValueReflection::Initialize(
    const Descriptor* absl_nonnull descriptor) {
  if (descriptor_ != descriptor) {
    CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType));
    descriptor_ = nullptr;
    CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1));
    CEL_RETURN_IF_ERROR(
        CheckFieldCppType(value_field_, FieldDescriptor::CPPTYPE_INT64));
    CEL_RETURN_IF_ERROR(
        CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL));
    descriptor_ = descriptor;
  }
  return absl::OkStatus();
}

// Returns the cached int64 value field of `message`.
int64_t Int64ValueReflection::GetValue(const google::protobuf::Message& message) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_);
  return message.GetReflection()->GetInt64(message, value_field_);
}

// Sets the cached int64 value field of `message` to `value`.
void Int64ValueReflection::SetValue(google::protobuf::Message* absl_nonnull message,
                                    int64_t value) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_);
  message->GetReflection()->SetInt64(message, value_field_, value);
}

// Returns an Int64ValueReflection initialized from `descriptor`, or the
// initialization error.
absl::StatusOr GetInt64ValueReflection(
    const Descriptor* absl_nonnull descriptor) {
  Int64ValueReflection reflection;
  CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor));
  return reflection;
}

// Resolves google.protobuf.UInt32Value from `pool` and initializes from it.
absl::Status UInt32ValueReflection::Initialize(
    const DescriptorPool* absl_nonnull pool) {
  CEL_ASSIGN_OR_RETURN(
      const auto* descriptor,
      GetMessageTypeByName(pool, "google.protobuf.UInt32Value"));
  return Initialize(descriptor);
}

// Same validation as BoolValueReflection::Initialize, but field 1 must be
// CPPTYPE_UINT32.
absl::Status UInt32ValueReflection::Initialize(
    const Descriptor* absl_nonnull descriptor) {
  if (descriptor_ != descriptor) {
    CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType));
    descriptor_ = nullptr;
    CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1));
    CEL_RETURN_IF_ERROR(
        CheckFieldCppType(value_field_, FieldDescriptor::CPPTYPE_UINT32));
    CEL_RETURN_IF_ERROR(
        CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL));
    descriptor_ = descriptor;
  }
  return absl::OkStatus();
}

// Returns the cached uint32 value field of `message`.
uint32_t UInt32ValueReflection::GetValue(const google::protobuf::Message& message) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_);
  return message.GetReflection()->GetUInt32(message, value_field_);
}

// Sets the cached uint32 value field of `message` to `value`.
void UInt32ValueReflection::SetValue(google::protobuf::Message* absl_nonnull message,
                                     uint32_t value) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_);
  message->GetReflection()->SetUInt32(message, value_field_, value);
}

// Returns a UInt32ValueReflection initialized from `descriptor`, or the
// initialization error.
absl::StatusOr GetUInt32ValueReflection(
    const Descriptor* absl_nonnull descriptor) {
  UInt32ValueReflection reflection;
  CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor));
  return reflection;
}

// Resolves google.protobuf.UInt64Value from `pool` and initializes from it.
absl::Status UInt64ValueReflection::Initialize(
    const DescriptorPool* absl_nonnull pool) {
  CEL_ASSIGN_OR_RETURN(
      const auto* descriptor,
      GetMessageTypeByName(pool, "google.protobuf.UInt64Value"));
  return Initialize(descriptor);
}

// Same validation as BoolValueReflection::Initialize, but field 1 must be
// CPPTYPE_UINT64.
absl::Status UInt64ValueReflection::Initialize(
    const Descriptor* absl_nonnull descriptor) {
  if (descriptor_ != descriptor) {
    CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType));
    descriptor_ = nullptr;
    CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1));
    CEL_RETURN_IF_ERROR(
        CheckFieldCppType(value_field_, FieldDescriptor::CPPTYPE_UINT64));
    CEL_RETURN_IF_ERROR(
        CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL));
    descriptor_ = descriptor;
  }
  return absl::OkStatus();
}

// Returns the cached uint64 value field of `message`.
uint64_t UInt64ValueReflection::GetValue(const google::protobuf::Message& message) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_);
  return message.GetReflection()->GetUInt64(message, value_field_);
}

// Sets the cached uint64 value field of `message` to `value`.
void UInt64ValueReflection::SetValue(google::protobuf::Message* absl_nonnull message,
                                     uint64_t value) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_);
  message->GetReflection()->SetUInt64(message, value_field_, value);
}

// Returns a UInt64ValueReflection initialized from `descriptor`, or the
// initialization error.
absl::StatusOr GetUInt64ValueReflection(
    const Descriptor* absl_nonnull descriptor) {
  UInt64ValueReflection reflection;
  CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor));
  return reflection;
}

// Resolves google.protobuf.FloatValue from `pool` and initializes from it.
absl::Status FloatValueReflection::Initialize(
    const DescriptorPool* absl_nonnull pool) {
  CEL_ASSIGN_OR_RETURN(
      const auto* descriptor,
      GetMessageTypeByName(pool, "google.protobuf.FloatValue"));
  return Initialize(descriptor);
}

// Same validation as BoolValueReflection::Initialize, but field 1 must be
// CPPTYPE_FLOAT.
absl::Status FloatValueReflection::Initialize(
    const Descriptor* absl_nonnull descriptor) {
  if (descriptor_ != descriptor) {
    CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType));
    descriptor_ = nullptr;
    CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1));
    CEL_RETURN_IF_ERROR(
        CheckFieldCppType(value_field_, FieldDescriptor::CPPTYPE_FLOAT));
    CEL_RETURN_IF_ERROR(
        CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL));
    descriptor_ = descriptor;
  }
  return absl::OkStatus();
}

// Returns the cached float value field of `message`.
float FloatValueReflection::GetValue(const google::protobuf::Message& message) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_);
  return message.GetReflection()->GetFloat(message, value_field_);
}

// Sets the cached float value field of `message` to `value`.
void FloatValueReflection::SetValue(google::protobuf::Message* absl_nonnull message,
                                    float value) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_);
  message->GetReflection()->SetFloat(message, value_field_, value);
}

// Returns a FloatValueReflection initialized from `descriptor`, or the
// initialization error.
absl::StatusOr GetFloatValueReflection(
    const Descriptor* absl_nonnull descriptor) {
  FloatValueReflection reflection;
  CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor));
  return reflection;
}

// Resolves google.protobuf.DoubleValue from `pool` and initializes from it.
absl::Status DoubleValueReflection::Initialize(
    const DescriptorPool* absl_nonnull pool) {
  CEL_ASSIGN_OR_RETURN(
      const auto* descriptor,
      GetMessageTypeByName(pool, "google.protobuf.DoubleValue"));
  return Initialize(descriptor);
}

// Same validation as BoolValueReflection::Initialize, but field 1 must be
// CPPTYPE_DOUBLE.
absl::Status DoubleValueReflection::Initialize(
    const Descriptor* absl_nonnull descriptor) {
  if (descriptor_ != descriptor) {
    CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType));
    descriptor_ = nullptr;
    CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1));
    CEL_RETURN_IF_ERROR(
        CheckFieldCppType(value_field_, FieldDescriptor::CPPTYPE_DOUBLE));
    CEL_RETURN_IF_ERROR(
        CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL));
    descriptor_ = descriptor;
  }
  return absl::OkStatus();
}

// Returns the cached double value field of `message`.
double DoubleValueReflection::GetValue(const google::protobuf::Message& message) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_);
  return message.GetReflection()->GetDouble(message, value_field_);
}

// Sets the cached double value field of `message` to `value`.
void DoubleValueReflection::SetValue(google::protobuf::Message* absl_nonnull message,
                                     double value) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_);
  message->GetReflection()->SetDouble(message, value_field_, value);
}

// Returns a DoubleValueReflection initialized from `descriptor`, or the
// initialization error.
absl::StatusOr GetDoubleValueReflection(
    const Descriptor* absl_nonnull descriptor) {
  DoubleValueReflection reflection;
  CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor));
  return reflection;
}

// Resolves google.protobuf.BytesValue from `pool` and initializes from it.
absl::Status BytesValueReflection::Initialize(
    const DescriptorPool* absl_nonnull pool) {
  CEL_ASSIGN_OR_RETURN(
      const auto* descriptor,
      GetMessageTypeByName(pool, "google.protobuf.BytesValue"));
  return Initialize(descriptor);
}

// Checks `descriptor` against `kWellKnownType` and requires field 1 to be an
// optional TYPE_BYTES field. Caches the field and its cpp_string_type(), which
// is later used to choose the Reflection accessor in GetValue. This is a no-op
// when already initialized with the same descriptor.
absl::Status BytesValueReflection::Initialize(
    const Descriptor* absl_nonnull descriptor) {
  if (descriptor_ != descriptor) {
    CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType));
    descriptor_ = nullptr;
    CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1));
    CEL_RETURN_IF_ERROR(
        CheckFieldType(value_field_, FieldDescriptor::TYPE_BYTES));
    CEL_RETURN_IF_ERROR(
        CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL));
    value_field_string_type_ = value_field_->cpp_string_type();
    descriptor_ = descriptor;
  }
  return absl::OkStatus();
}

// Returns the bytes value of `message`. The result may refer to `scratch` or
// to storage owned by `message`.
BytesValue BytesValueReflection::GetValue(const google::protobuf::Message& message,
                                          std::string& scratch) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_);
  return GetStringField(message, value_field_,
                                    value_field_string_type_, scratch);
}

// Sets the bytes value of `message`, copying `value` into a std::string first.
void BytesValueReflection::SetValue(google::protobuf::Message* absl_nonnull message,
                                    absl::string_view value) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_);
  message->GetReflection()->SetString(message, value_field_,
                                      std::string(value));
}

// Sets the bytes value of `message` using Reflection's Cord overload.
void BytesValueReflection::SetValue(google::protobuf::Message* absl_nonnull message,
                                    const absl::Cord& value) const {
  ABSL_DCHECK(IsInitialized());
  ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_);
  message->GetReflection()->SetString(message, value_field_, value);
}
absl::StatusOr GetBytesValueReflection( const Descriptor* absl_nonnull descriptor) { BytesValueReflection reflection; CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); return reflection; } absl::Status StringValueReflection::Initialize( const DescriptorPool* absl_nonnull pool) { CEL_ASSIGN_OR_RETURN( const auto* descriptor, GetMessageTypeByName(pool, "google.protobuf.StringValue")); return Initialize(descriptor); } absl::Status StringValueReflection::Initialize( const Descriptor* absl_nonnull descriptor) { if (descriptor_ != descriptor) { CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); descriptor_ = nullptr; CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 1)); CEL_RETURN_IF_ERROR( CheckFieldType(value_field_, FieldDescriptor::TYPE_STRING)); CEL_RETURN_IF_ERROR( CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL)); value_field_string_type_ = value_field_->cpp_string_type(); descriptor_ = descriptor; } return absl::OkStatus(); } StringValue StringValueReflection::GetValue(const google::protobuf::Message& message, std::string& scratch) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return GetStringField(message, value_field_, value_field_string_type_, scratch); } void StringValueReflection::SetValue(google::protobuf::Message* absl_nonnull message, absl::string_view value) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); message->GetReflection()->SetString(message, value_field_, std::string(value)); } void StringValueReflection::SetValue(google::protobuf::Message* absl_nonnull message, const absl::Cord& value) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); message->GetReflection()->SetString(message, value_field_, value); } absl::StatusOr GetStringValueReflection( const Descriptor* absl_nonnull descriptor) { StringValueReflection reflection; 
CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); return reflection; } absl::Status AnyReflection::Initialize( const DescriptorPool* absl_nonnull pool) { CEL_ASSIGN_OR_RETURN(const auto* descriptor, GetMessageTypeByName(pool, "google.protobuf.Any")); return Initialize(descriptor); } absl::Status AnyReflection::Initialize( const Descriptor* absl_nonnull descriptor) { if (descriptor_ != descriptor) { CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); descriptor_ = nullptr; CEL_ASSIGN_OR_RETURN(type_url_field_, GetFieldByNumber(descriptor, 1)); CEL_RETURN_IF_ERROR( CheckFieldType(type_url_field_, FieldDescriptor::TYPE_STRING)); CEL_RETURN_IF_ERROR(CheckFieldCardinality(type_url_field_, FieldDescriptor::LABEL_OPTIONAL)); type_url_field_string_type_ = type_url_field_->cpp_string_type(); CEL_ASSIGN_OR_RETURN(value_field_, GetFieldByNumber(descriptor, 2)); CEL_RETURN_IF_ERROR( CheckFieldType(value_field_, FieldDescriptor::TYPE_BYTES)); CEL_RETURN_IF_ERROR( CheckFieldCardinality(value_field_, FieldDescriptor::LABEL_OPTIONAL)); value_field_string_type_ = value_field_->cpp_string_type(); descriptor_ = descriptor; } return absl::OkStatus(); } void AnyReflection::SetTypeUrl(google::protobuf::Message* absl_nonnull message, absl::string_view type_url) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); message->GetReflection()->SetString(message, type_url_field_, std::string(type_url)); } void AnyReflection::SetValue(google::protobuf::Message* absl_nonnull message, const absl::Cord& value) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); message->GetReflection()->SetString(message, value_field_, value); } StringValue AnyReflection::GetTypeUrl(const google::protobuf::Message& message, std::string& scratch) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return GetStringField(message, type_url_field_, type_url_field_string_type_, 
scratch); } BytesValue AnyReflection::GetValue(const google::protobuf::Message& message, std::string& scratch) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return GetStringField(message, value_field_, value_field_string_type_, scratch); } absl::StatusOr GetAnyReflection( const Descriptor* absl_nonnull descriptor) { AnyReflection reflection; CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); return reflection; } AnyReflection GetAnyReflectionOrDie( const google::protobuf::Descriptor* absl_nonnull descriptor) { AnyReflection reflection; ABSL_CHECK_OK(reflection.Initialize(descriptor)); // Crash OK return reflection; } absl::Status DurationReflection::Initialize( const DescriptorPool* absl_nonnull pool) { CEL_ASSIGN_OR_RETURN(const auto* descriptor, GetMessageTypeByName(pool, "google.protobuf.Duration")); return Initialize(descriptor); } absl::Status DurationReflection::Initialize( const Descriptor* absl_nonnull descriptor) { if (descriptor_ != descriptor) { CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); descriptor_ = nullptr; CEL_ASSIGN_OR_RETURN(seconds_field_, GetFieldByNumber(descriptor, 1)); CEL_RETURN_IF_ERROR( CheckFieldCppType(seconds_field_, FieldDescriptor::CPPTYPE_INT64)); CEL_RETURN_IF_ERROR( CheckFieldCardinality(seconds_field_, FieldDescriptor::LABEL_OPTIONAL)); CEL_ASSIGN_OR_RETURN(nanos_field_, GetFieldByNumber(descriptor, 2)); CEL_RETURN_IF_ERROR( CheckFieldCppType(nanos_field_, FieldDescriptor::CPPTYPE_INT32)); CEL_RETURN_IF_ERROR( CheckFieldCardinality(nanos_field_, FieldDescriptor::LABEL_OPTIONAL)); descriptor_ = descriptor; } return absl::OkStatus(); } int64_t DurationReflection::GetSeconds(const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return message.GetReflection()->GetInt64(message, seconds_field_); } int32_t DurationReflection::GetNanos(const google::protobuf::Message& message) const { 
ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return message.GetReflection()->GetInt32(message, nanos_field_); } void DurationReflection::SetSeconds(google::protobuf::Message* absl_nonnull message, int64_t value) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); message->GetReflection()->SetInt64(message, seconds_field_, value); } void DurationReflection::SetNanos(google::protobuf::Message* absl_nonnull message, int32_t value) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); message->GetReflection()->SetInt32(message, nanos_field_, value); } absl::Status DurationReflection::SetFromAbslDuration( google::protobuf::Message* absl_nonnull message, absl::Duration duration) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); int64_t seconds = absl::IDivDuration(duration, absl::Seconds(1), &duration); if (ABSL_PREDICT_FALSE(seconds < TimeUtil::kDurationMinSeconds || seconds > TimeUtil::kDurationMaxSeconds)) { return absl::InvalidArgumentError( absl::StrCat("invalid duration seconds: ", seconds)); } int32_t nanos = static_cast( absl::IDivDuration(duration, absl::Nanoseconds(1), &duration)); if (ABSL_PREDICT_FALSE(nanos < TimeUtil::kDurationMinNanoseconds || nanos > TimeUtil::kDurationMaxNanoseconds)) { return absl::InvalidArgumentError( absl::StrCat("invalid duration nanoseconds: ", nanos)); } if ((seconds < 0 && nanos > 0) || (seconds > 0 && nanos < 0)) { return absl::InvalidArgumentError(absl::StrCat( "duration sign mismatch: seconds=", seconds, ", nanoseconds=", nanos)); } SetSeconds(message, seconds); SetNanos(message, nanos); return absl::OkStatus(); } absl::Status DurationReflection::SetFromAbslDuration( GeneratedMessageType* absl_nonnull message, absl::Duration duration) { int64_t seconds = absl::IDivDuration(duration, absl::Seconds(1), &duration); if (ABSL_PREDICT_FALSE(seconds < 
TimeUtil::kDurationMinSeconds || seconds > TimeUtil::kDurationMaxSeconds)) { return absl::InvalidArgumentError( absl::StrCat("invalid duration seconds: ", seconds)); } int32_t nanos = static_cast( absl::IDivDuration(duration, absl::Nanoseconds(1), &duration)); if (ABSL_PREDICT_FALSE(nanos < TimeUtil::kDurationMinNanoseconds || nanos > TimeUtil::kDurationMaxNanoseconds)) { return absl::InvalidArgumentError( absl::StrCat("invalid duration nanoseconds: ", nanos)); } if ((seconds < 0 && nanos > 0) || (seconds > 0 && nanos < 0)) { return absl::InvalidArgumentError(absl::StrCat( "duration sign mismatch: seconds=", seconds, ", nanoseconds=", nanos)); } SetSeconds(message, seconds); SetNanos(message, nanos); return absl::OkStatus(); } void DurationReflection::UnsafeSetFromAbslDuration( google::protobuf::Message* absl_nonnull message, absl::Duration duration) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); int64_t seconds = absl::IDivDuration(duration, absl::Seconds(1), &duration); int32_t nanos = static_cast( absl::IDivDuration(duration, absl::Nanoseconds(1), &duration)); SetSeconds(message, seconds); SetNanos(message, nanos); } absl::StatusOr DurationReflection::ToAbslDuration( const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); int64_t seconds = GetSeconds(message); if (ABSL_PREDICT_FALSE(seconds < TimeUtil::kDurationMinSeconds || seconds > TimeUtil::kDurationMaxSeconds)) { return absl::InvalidArgumentError( absl::StrCat("invalid duration seconds: ", seconds)); } int32_t nanos = GetNanos(message); if (ABSL_PREDICT_FALSE(nanos < TimeUtil::kDurationMinNanoseconds || nanos > TimeUtil::kDurationMaxNanoseconds)) { return absl::InvalidArgumentError( absl::StrCat("invalid duration nanoseconds: ", nanos)); } if ((seconds < 0 && nanos > 0) || (seconds > 0 && nanos < 0)) { return absl::InvalidArgumentError(absl::StrCat( "duration sign mismatch: seconds=", 
seconds, ", nanoseconds=", nanos)); } return absl::Seconds(seconds) + absl::Nanoseconds(nanos); } absl::Duration DurationReflection::UnsafeToAbslDuration( const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); int64_t seconds = GetSeconds(message); int32_t nanos = GetNanos(message); return absl::Seconds(seconds) + absl::Nanoseconds(nanos); } absl::StatusOr GetDurationReflection( const Descriptor* absl_nonnull descriptor) { DurationReflection reflection; CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); return reflection; } absl::Status TimestampReflection::Initialize( const DescriptorPool* absl_nonnull pool) { CEL_ASSIGN_OR_RETURN(const auto* descriptor, GetMessageTypeByName(pool, "google.protobuf.Timestamp")); return Initialize(descriptor); } absl::Status TimestampReflection::Initialize( const Descriptor* absl_nonnull descriptor) { if (descriptor_ != descriptor) { CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); descriptor_ = nullptr; CEL_ASSIGN_OR_RETURN(seconds_field_, GetFieldByNumber(descriptor, 1)); CEL_RETURN_IF_ERROR( CheckFieldCppType(seconds_field_, FieldDescriptor::CPPTYPE_INT64)); CEL_RETURN_IF_ERROR( CheckFieldCardinality(seconds_field_, FieldDescriptor::LABEL_OPTIONAL)); CEL_ASSIGN_OR_RETURN(nanos_field_, GetFieldByNumber(descriptor, 2)); CEL_RETURN_IF_ERROR( CheckFieldCppType(nanos_field_, FieldDescriptor::CPPTYPE_INT32)); CEL_RETURN_IF_ERROR( CheckFieldCardinality(nanos_field_, FieldDescriptor::LABEL_OPTIONAL)); descriptor_ = descriptor; } return absl::OkStatus(); } int64_t TimestampReflection::GetSeconds(const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return message.GetReflection()->GetInt64(message, seconds_field_); } int32_t TimestampReflection::GetNanos(const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); 
ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return message.GetReflection()->GetInt32(message, nanos_field_); } void TimestampReflection::SetSeconds(google::protobuf::Message* absl_nonnull message, int64_t value) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); message->GetReflection()->SetInt64(message, seconds_field_, value); } void TimestampReflection::SetNanos(google::protobuf::Message* absl_nonnull message, int32_t value) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); message->GetReflection()->SetInt32(message, nanos_field_, value); } absl::Status TimestampReflection::SetFromAbslTime( google::protobuf::Message* absl_nonnull message, absl::Time time) const { int64_t seconds = absl::ToUnixSeconds(time); if (ABSL_PREDICT_FALSE(seconds < TimeUtil::kTimestampMinSeconds || seconds > TimeUtil::kTimestampMaxSeconds)) { return absl::InvalidArgumentError( absl::StrCat("invalid timestamp seconds: ", seconds)); } int64_t nanos = static_cast((time - absl::FromUnixSeconds(seconds)) / absl::Nanoseconds(1)); if (ABSL_PREDICT_FALSE(nanos < TimeUtil::kTimestampMinNanoseconds || nanos > TimeUtil::kTimestampMaxNanoseconds)) { return absl::InvalidArgumentError( absl::StrCat("invalid timestamp nanoseconds: ", nanos)); } SetSeconds(message, seconds); SetNanos(message, static_cast(nanos)); return absl::OkStatus(); } absl::Status TimestampReflection::SetFromAbslTime( GeneratedMessageType* absl_nonnull message, absl::Time time) { int64_t seconds = absl::ToUnixSeconds(time); if (ABSL_PREDICT_FALSE(seconds < TimeUtil::kTimestampMinSeconds || seconds > TimeUtil::kTimestampMaxSeconds)) { return absl::InvalidArgumentError( absl::StrCat("invalid timestamp seconds: ", seconds)); } int64_t nanos = static_cast((time - absl::FromUnixSeconds(seconds)) / absl::Nanoseconds(1)); if (ABSL_PREDICT_FALSE(nanos < TimeUtil::kTimestampMinNanoseconds || nanos > TimeUtil::kTimestampMaxNanoseconds)) { return 
absl::InvalidArgumentError( absl::StrCat("invalid timestamp nanoseconds: ", nanos)); } SetSeconds(message, seconds); SetNanos(message, static_cast(nanos)); return absl::OkStatus(); } void TimestampReflection::UnsafeSetFromAbslTime( google::protobuf::Message* absl_nonnull message, absl::Time time) const { int64_t seconds = absl::ToUnixSeconds(time); int32_t nanos = static_cast((time - absl::FromUnixSeconds(seconds)) / absl::Nanoseconds(1)); SetSeconds(message, seconds); SetNanos(message, nanos); } absl::StatusOr TimestampReflection::ToAbslTime( const google::protobuf::Message& message) const { int64_t seconds = GetSeconds(message); if (ABSL_PREDICT_FALSE(seconds < TimeUtil::kTimestampMinSeconds || seconds > TimeUtil::kTimestampMaxSeconds)) { return absl::InvalidArgumentError( absl::StrCat("invalid timestamp seconds: ", seconds)); } int32_t nanos = GetNanos(message); if (ABSL_PREDICT_FALSE(nanos < TimeUtil::kTimestampMinNanoseconds || nanos > TimeUtil::kTimestampMaxNanoseconds)) { return absl::InvalidArgumentError( absl::StrCat("invalid timestamp nanoseconds: ", nanos)); } return absl::UnixEpoch() + absl::Seconds(seconds) + absl::Nanoseconds(nanos); } absl::Time TimestampReflection::UnsafeToAbslTime( const google::protobuf::Message& message) const { int64_t seconds = GetSeconds(message); int32_t nanos = GetNanos(message); return absl::UnixEpoch() + absl::Seconds(seconds) + absl::Nanoseconds(nanos); } absl::StatusOr GetTimestampReflection( const Descriptor* absl_nonnull descriptor) { TimestampReflection reflection; CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); return reflection; } void ValueReflection::SetNumberValue( google::protobuf::Value* absl_nonnull message, int64_t value) { if (value < kJsonMinInt || value > kJsonMaxInt) { SetStringValue(message, absl::StrCat(value)); return; } SetNumberValue(message, static_cast(value)); } void ValueReflection::SetNumberValue( google::protobuf::Value* absl_nonnull message, uint64_t value) { if (value > kJsonMaxUint) 
{ SetStringValue(message, absl::StrCat(value)); return; } SetNumberValue(message, static_cast(value)); } absl::Status ValueReflection::Initialize( const DescriptorPool* absl_nonnull pool) { CEL_ASSIGN_OR_RETURN(const auto* descriptor, GetMessageTypeByName(pool, "google.protobuf.Value")); return Initialize(descriptor); } absl::Status ValueReflection::Initialize( const Descriptor* absl_nonnull descriptor) { if (descriptor_ != descriptor) { CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); descriptor_ = nullptr; CEL_ASSIGN_OR_RETURN(kind_field_, GetOneofByName(descriptor, "kind")); CEL_ASSIGN_OR_RETURN(null_value_field_, GetFieldByNumber(descriptor, 1)); CEL_RETURN_IF_ERROR( CheckFieldCppType(null_value_field_, FieldDescriptor::CPPTYPE_ENUM)); CEL_RETURN_IF_ERROR(CheckFieldCardinality(null_value_field_, FieldDescriptor::LABEL_OPTIONAL)); CEL_RETURN_IF_ERROR(CheckFieldOneof(null_value_field_, kind_field_, 0)); CEL_ASSIGN_OR_RETURN(bool_value_field_, GetFieldByNumber(descriptor, 4)); CEL_RETURN_IF_ERROR( CheckFieldCppType(bool_value_field_, FieldDescriptor::CPPTYPE_BOOL)); CEL_RETURN_IF_ERROR(CheckFieldCardinality(bool_value_field_, FieldDescriptor::LABEL_OPTIONAL)); CEL_RETURN_IF_ERROR(CheckFieldOneof(bool_value_field_, kind_field_, 3)); CEL_ASSIGN_OR_RETURN(number_value_field_, GetFieldByNumber(descriptor, 2)); CEL_RETURN_IF_ERROR(CheckFieldCppType(number_value_field_, FieldDescriptor::CPPTYPE_DOUBLE)); CEL_RETURN_IF_ERROR(CheckFieldCardinality(number_value_field_, FieldDescriptor::LABEL_OPTIONAL)); CEL_RETURN_IF_ERROR(CheckFieldOneof(number_value_field_, kind_field_, 1)); CEL_ASSIGN_OR_RETURN(string_value_field_, GetFieldByNumber(descriptor, 3)); CEL_RETURN_IF_ERROR(CheckFieldCppType(string_value_field_, FieldDescriptor::CPPTYPE_STRING)); CEL_RETURN_IF_ERROR(CheckFieldCardinality(string_value_field_, FieldDescriptor::LABEL_OPTIONAL)); CEL_RETURN_IF_ERROR(CheckFieldOneof(string_value_field_, kind_field_, 2)); string_value_field_string_type_ = 
string_value_field_->cpp_string_type(); CEL_ASSIGN_OR_RETURN(list_value_field_, GetFieldByNumber(descriptor, 6)); CEL_RETURN_IF_ERROR( CheckFieldCppType(list_value_field_, FieldDescriptor::CPPTYPE_MESSAGE)); CEL_RETURN_IF_ERROR(CheckFieldCardinality(list_value_field_, FieldDescriptor::LABEL_OPTIONAL)); CEL_RETURN_IF_ERROR(CheckFieldOneof(list_value_field_, kind_field_, 5)); CEL_RETURN_IF_ERROR(CheckFieldWellKnownType( list_value_field_, Descriptor::WELLKNOWNTYPE_LISTVALUE)); CEL_ASSIGN_OR_RETURN(struct_value_field_, GetFieldByNumber(descriptor, 5)); CEL_RETURN_IF_ERROR(CheckFieldCppType(struct_value_field_, FieldDescriptor::CPPTYPE_MESSAGE)); CEL_RETURN_IF_ERROR(CheckFieldCardinality(struct_value_field_, FieldDescriptor::LABEL_OPTIONAL)); CEL_RETURN_IF_ERROR(CheckFieldOneof(struct_value_field_, kind_field_, 4)); CEL_RETURN_IF_ERROR(CheckFieldWellKnownType( struct_value_field_, Descriptor::WELLKNOWNTYPE_STRUCT)); descriptor_ = descriptor; } return absl::OkStatus(); } google::protobuf::Value::KindCase ValueReflection::GetKindCase( const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); const auto* field = message.GetReflection()->GetOneofFieldDescriptor(message, kind_field_); return field != nullptr ? 
static_cast( field->index_in_oneof() + 1) : google::protobuf::Value::KIND_NOT_SET; } bool ValueReflection::GetBoolValue(const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return message.GetReflection()->GetBool(message, bool_value_field_); } double ValueReflection::GetNumberValue(const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return message.GetReflection()->GetDouble(message, number_value_field_); } StringValue ValueReflection::GetStringValue(const google::protobuf::Message& message, std::string& scratch) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return GetStringField(message, string_value_field_, string_value_field_string_type_, scratch); } const google::protobuf::Message& ValueReflection::GetListValue( const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); #undef GetMessage return message.GetReflection()->GetMessage(message, list_value_field_); } const google::protobuf::Message& ValueReflection::GetStructValue( const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); #undef GetMessage return message.GetReflection()->GetMessage(message, struct_value_field_); } void ValueReflection::SetNullValue( google::protobuf::Message* absl_nonnull message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); message->GetReflection()->SetEnumValue(message, null_value_field_, 0); } void ValueReflection::SetBoolValue(google::protobuf::Message* absl_nonnull message, bool value) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); message->GetReflection()->SetBool(message, bool_value_field_, value); } void 
ValueReflection::SetNumberValue(google::protobuf::Message* absl_nonnull message, int64_t value) const { if (value < kJsonMinInt || value > kJsonMaxInt) { SetStringValue(message, absl::StrCat(value)); return; } SetNumberValue(message, static_cast(value)); } void ValueReflection::SetNumberValue(google::protobuf::Message* absl_nonnull message, uint64_t value) const { if (value > kJsonMaxUint) { SetStringValue(message, absl::StrCat(value)); return; } SetNumberValue(message, static_cast(value)); } void ValueReflection::SetNumberValue(google::protobuf::Message* absl_nonnull message, double value) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); message->GetReflection()->SetDouble(message, number_value_field_, value); } void ValueReflection::SetStringValue(google::protobuf::Message* absl_nonnull message, absl::string_view value) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); message->GetReflection()->SetString(message, string_value_field_, std::string(value)); } void ValueReflection::SetStringValue(google::protobuf::Message* absl_nonnull message, const absl::Cord& value) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); message->GetReflection()->SetString(message, string_value_field_, value); } void ValueReflection::SetStringValueFromBytes( google::protobuf::Message* absl_nonnull message, absl::string_view value) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); if (value.empty()) { SetStringValue(message, value); return; } SetStringValue(message, absl::Base64Escape(value)); } void ValueReflection::SetStringValueFromBytes( google::protobuf::Message* absl_nonnull message, const absl::Cord& value) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); if (value.empty()) { SetStringValue(message, value); return; } if (auto flat = value.TryFlat(); flat) { 
SetStringValue(message, absl::Base64Escape(*flat)); return; } std::string flat; absl::CopyCordToString(value, &flat); SetStringValue(message, absl::Base64Escape(flat)); } void ValueReflection::SetStringValueFromDuration( google::protobuf::Message* absl_nonnull message, absl::Duration duration) const { google::protobuf::Duration proto; proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration)); proto.set_nanos(static_cast( absl::IDivDuration(duration, absl::Nanoseconds(1), &duration))); ABSL_DCHECK(TimeUtil::IsDurationValid(proto)); SetStringValue(message, TimeUtil::ToString(proto)); } void ValueReflection::SetStringValueFromTimestamp( google::protobuf::Message* absl_nonnull message, absl::Time time) const { google::protobuf::Timestamp proto; proto.set_seconds(absl::ToUnixSeconds(time)); proto.set_nanos((time - absl::FromUnixSeconds(proto.seconds())) / absl::Nanoseconds(1)); ABSL_DCHECK(TimeUtil::IsTimestampValid(proto)); SetStringValue(message, TimeUtil::ToString(proto)); } google::protobuf::Message* absl_nonnull ValueReflection::MutableListValue( google::protobuf::Message* absl_nonnull message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); return message->GetReflection()->MutableMessage(message, list_value_field_); } google::protobuf::Message* absl_nonnull ValueReflection::MutableStructValue( google::protobuf::Message* absl_nonnull message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); return message->GetReflection()->MutableMessage(message, struct_value_field_); } Unique ValueReflection::ReleaseListValue( google::protobuf::Message* absl_nonnull message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); const auto* reflection = message->GetReflection(); if (!reflection->HasField(*message, list_value_field_)) { reflection->MutableMessage(message, list_value_field_); } return WrapUnique( 
reflection->UnsafeArenaReleaseMessage(message, list_value_field_), message->GetArena()); } Unique ValueReflection::ReleaseStructValue( google::protobuf::Message* absl_nonnull message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); const auto* reflection = message->GetReflection(); if (!reflection->HasField(*message, struct_value_field_)) { reflection->MutableMessage(message, struct_value_field_); } return WrapUnique( reflection->UnsafeArenaReleaseMessage(message, struct_value_field_), message->GetArena()); } absl::StatusOr GetValueReflection( const Descriptor* absl_nonnull descriptor) { ValueReflection reflection; CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); return reflection; } ValueReflection GetValueReflectionOrDie( const google::protobuf::Descriptor* absl_nonnull descriptor) { ValueReflection reflection; ABSL_CHECK_OK(reflection.Initialize(descriptor)); // Crash OK; return reflection; } absl::Status ListValueReflection::Initialize( const DescriptorPool* absl_nonnull pool) { CEL_ASSIGN_OR_RETURN(const auto* descriptor, GetMessageTypeByName(pool, "google.protobuf.ListValue")); return Initialize(descriptor); } absl::Status ListValueReflection::Initialize( const Descriptor* absl_nonnull descriptor) { if (descriptor_ != descriptor) { CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); descriptor_ = nullptr; CEL_ASSIGN_OR_RETURN(values_field_, GetFieldByNumber(descriptor, 1)); CEL_RETURN_IF_ERROR( CheckFieldCppType(values_field_, FieldDescriptor::CPPTYPE_MESSAGE)); CEL_RETURN_IF_ERROR( CheckFieldCardinality(values_field_, FieldDescriptor::LABEL_REPEATED)); CEL_RETURN_IF_ERROR(CheckFieldWellKnownType( values_field_, Descriptor::WELLKNOWNTYPE_VALUE)); descriptor_ = descriptor; } return absl::OkStatus(); } int ListValueReflection::ValuesSize(const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return 
message.GetReflection()->FieldSize(message, values_field_); } google::protobuf::RepeatedFieldRef ListValueReflection::Values( const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return message.GetReflection()->GetRepeatedFieldRef( message, values_field_); } const google::protobuf::Message& ListValueReflection::Values( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, int index) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return message.GetReflection()->GetRepeatedMessage(message, values_field_, index); } google::protobuf::MutableRepeatedFieldRef ListValueReflection::MutableValues( google::protobuf::Message* absl_nonnull message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); return message->GetReflection()->GetMutableRepeatedFieldRef( message, values_field_); } google::protobuf::Message* absl_nonnull ListValueReflection::AddValues( google::protobuf::Message* absl_nonnull message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); return message->GetReflection()->AddMessage(message, values_field_); } absl::StatusOr GetListValueReflection( const Descriptor* absl_nonnull descriptor) { ListValueReflection reflection; CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); return reflection; } ListValueReflection GetListValueReflectionOrDie( const google::protobuf::Descriptor* absl_nonnull descriptor) { ListValueReflection reflection; ABSL_CHECK_OK(reflection.Initialize(descriptor)); // Crash OK return reflection; } absl::Status StructReflection::Initialize( const DescriptorPool* absl_nonnull pool) { CEL_ASSIGN_OR_RETURN(const auto* descriptor, GetMessageTypeByName(pool, "google.protobuf.Struct")); return Initialize(descriptor); } absl::Status StructReflection::Initialize( const Descriptor* absl_nonnull descriptor) { if (descriptor_ 
!= descriptor) { CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); descriptor_ = nullptr; CEL_ASSIGN_OR_RETURN(fields_field_, GetFieldByNumber(descriptor, 1)); CEL_RETURN_IF_ERROR(CheckMapField(fields_field_)); fields_key_field_ = fields_field_->message_type()->map_key(); CEL_RETURN_IF_ERROR( CheckFieldCppType(fields_key_field_, FieldDescriptor::CPPTYPE_STRING)); CEL_RETURN_IF_ERROR(CheckFieldCardinality(fields_key_field_, FieldDescriptor::LABEL_OPTIONAL)); fields_value_field_ = fields_field_->message_type()->map_value(); CEL_RETURN_IF_ERROR(CheckFieldCppType(fields_value_field_, FieldDescriptor::CPPTYPE_MESSAGE)); CEL_RETURN_IF_ERROR(CheckFieldCardinality(fields_value_field_, FieldDescriptor::LABEL_OPTIONAL)); CEL_RETURN_IF_ERROR(CheckFieldWellKnownType( fields_value_field_, Descriptor::WELLKNOWNTYPE_VALUE)); descriptor_ = descriptor; } return absl::OkStatus(); } int StructReflection::FieldsSize(const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return cel::extensions::protobuf_internal::MapSize(*message.GetReflection(), message, *fields_field_); } google::protobuf::ConstMapIterator StructReflection::BeginFields( const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return cel::extensions::protobuf_internal::ConstMapBegin( *message.GetReflection(), message, *fields_field_); } google::protobuf::ConstMapIterator StructReflection::EndFields( const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return cel::extensions::protobuf_internal::ConstMapEnd( *message.GetReflection(), message, *fields_field_); } bool StructReflection::ContainsField(const google::protobuf::Message& message, absl::string_view name) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); #if 
CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(5, 30, 0) google::protobuf::MapKey key; key.SetStringValue(name); #else std::string key_scratch(name); google::protobuf::MapKey key; key.SetStringValue(key_scratch); #endif return cel::extensions::protobuf_internal::ContainsMapKey( *message.GetReflection(), message, *fields_field_, key); } const google::protobuf::Message* absl_nullable StructReflection::FindField( const google::protobuf::Message& message, absl::string_view name) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); #if CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(5, 30, 0) google::protobuf::MapKey key; key.SetStringValue(name); #else std::string key_scratch(name); google::protobuf::MapKey key; key.SetStringValue(key_scratch); #endif google::protobuf::MapValueConstRef value; if (cel::extensions::protobuf_internal::LookupMapValue( *message.GetReflection(), message, *fields_field_, key, &value)) { return &value.GetMessageValue(); } return nullptr; } google::protobuf::Message* absl_nonnull StructReflection::InsertField( google::protobuf::Message* absl_nonnull message, absl::string_view name) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); #if CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(5, 30, 0) google::protobuf::MapKey key; key.SetStringValue(name); #else std::string key_scratch(name); google::protobuf::MapKey key; key.SetStringValue(key_scratch); #endif google::protobuf::MapValueRef value; cel::extensions::protobuf_internal::InsertOrLookupMapValue( *message->GetReflection(), message, *fields_field_, key, &value); return value.MutableMessageValue(); } bool StructReflection::DeleteField(google::protobuf::Message* absl_nonnull message, absl::string_view name) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message->GetDescriptor(), descriptor_); #if CEL_INTERNAL_PROTOBUF_OSS_VERSION_PREREQ(5, 30, 0) google::protobuf::MapKey key; key.SetStringValue(name); #else std::string 
key_scratch(name); google::protobuf::MapKey key; key.SetStringValue(key_scratch); #endif return cel::extensions::protobuf_internal::DeleteMapValue( message->GetReflection(), message, fields_field_, key); } absl::StatusOr GetStructReflection( const Descriptor* absl_nonnull descriptor) { StructReflection reflection; CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); return reflection; } StructReflection GetStructReflectionOrDie( const google::protobuf::Descriptor* absl_nonnull descriptor) { StructReflection reflection; ABSL_CHECK_OK(reflection.Initialize(descriptor)); // Crash OK return reflection; } absl::Status FieldMaskReflection::Initialize( const google::protobuf::DescriptorPool* absl_nonnull pool) { CEL_ASSIGN_OR_RETURN(const auto* descriptor, GetMessageTypeByName(pool, "google.protobuf.FieldMask")); return Initialize(descriptor); } absl::Status FieldMaskReflection::Initialize( const google::protobuf::Descriptor* absl_nonnull descriptor) { if (descriptor_ != descriptor) { CEL_RETURN_IF_ERROR(CheckWellKnownType(descriptor, kWellKnownType)); descriptor_ = nullptr; CEL_ASSIGN_OR_RETURN(paths_field_, GetFieldByNumber(descriptor, 1)); CEL_RETURN_IF_ERROR( CheckFieldCppType(paths_field_, FieldDescriptor::CPPTYPE_STRING)); CEL_RETURN_IF_ERROR( CheckFieldCardinality(paths_field_, FieldDescriptor::LABEL_REPEATED)); paths_field_string_type_ = paths_field_->cpp_string_type(); descriptor_ = descriptor; } return absl::OkStatus(); } int FieldMaskReflection::PathsSize(const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); return message.GetReflection()->FieldSize(message, paths_field_); } StringValue FieldMaskReflection::Paths(const google::protobuf::Message& message, int index, std::string& scratch) const { return GetRepeatedStringField( message, paths_field_, paths_field_string_type_, index, scratch); } absl::StatusOr GetFieldMaskReflection( const google::protobuf::Descriptor* absl_nonnull 
descriptor) { FieldMaskReflection reflection; CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor)); return reflection; } absl::Status JsonReflection::Initialize( const google::protobuf::DescriptorPool* absl_nonnull pool) { CEL_RETURN_IF_ERROR(Value().Initialize(pool)); CEL_RETURN_IF_ERROR(ListValue().Initialize(pool)); CEL_RETURN_IF_ERROR(Struct().Initialize(pool)); return absl::OkStatus(); } absl::Status JsonReflection::Initialize( const google::protobuf::Descriptor* absl_nonnull descriptor) { switch (descriptor->well_known_type()) { case google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE: CEL_RETURN_IF_ERROR(Value().Initialize(descriptor)); CEL_RETURN_IF_ERROR( ListValue().Initialize(Value().GetListValueDescriptor())); CEL_RETURN_IF_ERROR(Struct().Initialize(Value().GetStructDescriptor())); return absl::OkStatus(); case google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE: CEL_RETURN_IF_ERROR(ListValue().Initialize(descriptor)); CEL_RETURN_IF_ERROR(Value().Initialize(ListValue().GetValueDescriptor())); CEL_RETURN_IF_ERROR(Struct().Initialize(Value().GetStructDescriptor())); return absl::OkStatus(); case google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT: CEL_RETURN_IF_ERROR(Struct().Initialize(descriptor)); CEL_RETURN_IF_ERROR(Value().Initialize(Struct().GetValueDescriptor())); CEL_RETURN_IF_ERROR( ListValue().Initialize(Value().GetListValueDescriptor())); return absl::OkStatus(); default: return absl::InvalidArgumentError( absl::StrCat("expected message to be JSON-like well known type: ", descriptor->full_name(), " ", WellKnownTypeToString(descriptor->well_known_type()))); } } bool JsonReflection::IsInitialized() const { return Value().IsInitialized() && ListValue().IsInitialized() && Struct().IsInitialized(); } namespace { [[maybe_unused]] ABSL_CONST_INIT absl::once_flag link_well_known_message_reflection; void LinkWellKnownMessageReflection() { google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); 
google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); google::protobuf::LinkMessageReflection(); } } // namespace absl::Status Reflection::Initialize(const DescriptorPool* absl_nonnull pool) { if (pool == DescriptorPool::generated_pool()) { absl::call_once(link_well_known_message_reflection, &LinkWellKnownMessageReflection); } CEL_RETURN_IF_ERROR(NullValue().Initialize(pool)); CEL_RETURN_IF_ERROR(BoolValue().Initialize(pool)); CEL_RETURN_IF_ERROR(Int32Value().Initialize(pool)); CEL_RETURN_IF_ERROR(Int64Value().Initialize(pool)); CEL_RETURN_IF_ERROR(UInt32Value().Initialize(pool)); CEL_RETURN_IF_ERROR(UInt64Value().Initialize(pool)); CEL_RETURN_IF_ERROR(FloatValue().Initialize(pool)); CEL_RETURN_IF_ERROR(DoubleValue().Initialize(pool)); CEL_RETURN_IF_ERROR(BytesValue().Initialize(pool)); CEL_RETURN_IF_ERROR(StringValue().Initialize(pool)); CEL_RETURN_IF_ERROR(Any().Initialize(pool)); CEL_RETURN_IF_ERROR(Duration().Initialize(pool)); CEL_RETURN_IF_ERROR(Timestamp().Initialize(pool)); CEL_RETURN_IF_ERROR(Json().Initialize(pool)); // google.protobuf.FieldMask is not strictly mandatory, but we do have to // treat it specifically for JSON. So use it if we have it. if (const auto* descriptor = pool->FindMessageTypeByName("google.protobuf.FieldMask"); descriptor != nullptr) { CEL_RETURN_IF_ERROR(FieldMask().Initialize(descriptor)); } return absl::OkStatus(); } bool Reflection::IsInitialized() const { // Check that everything is initialized except field mask, which is optional. 
return NullValue().IsInitialized() && BoolValue().IsInitialized() && Int32Value().IsInitialized() && Int64Value().IsInitialized() && UInt32Value().IsInitialized() && UInt64Value().IsInitialized() && FloatValue().IsInitialized() && DoubleValue().IsInitialized() && BytesValue().IsInitialized() && StringValue().IsInitialized() && Any().IsInitialized() && Duration().IsInitialized() && Timestamp().IsInitialized() && Json().IsInitialized(); } namespace { // AdaptListValue verifies the message is the well known type // `google.protobuf.ListValue` and performs the complicated logic of reimaging // it as `ListValue`. If adapted is empty, we return as a reference. If adapted // is present, message must be a reference to the value held in adapted and it // will be returned by value. absl::StatusOr AdaptListValue(google::protobuf::Arena* absl_nullable arena, const google::protobuf::Message& message, Unique adapted) { ABSL_DCHECK(!adapted || &message == cel::to_address(adapted)); const auto* descriptor = message.GetDescriptor(); if (ABSL_PREDICT_FALSE(descriptor == nullptr)) { return absl::InvalidArgumentError( absl::StrCat("missing descriptor for protocol buffer message: ", message.GetTypeName())); } // Not much to do. Just verify the well known type is well-formed. CEL_RETURN_IF_ERROR(GetListValueReflection(descriptor).status()); if (adapted) { return ListValue(std::move(adapted)); } return ListValue(std::cref(message)); } // AdaptStruct verifies the message is the well known type // `google.protobuf.Struct` and performs the complicated logic of reimaging it // as `Struct`. If adapted is empty, we return as a reference. If adapted is // present, message must be a reference to the value held in adapted and it will // be returned by value. 
absl::StatusOr AdaptStruct(google::protobuf::Arena* absl_nullable arena,
                           const google::protobuf::Message& message,
                           Unique adapted) {
  // When `adapted` owns a message, `message` must be that very same object.
  ABSL_DCHECK(!adapted || &message == cel::to_address(adapted));
  const auto* descriptor = message.GetDescriptor();
  if (ABSL_PREDICT_FALSE(descriptor == nullptr)) {
    return absl::InvalidArgumentError(
        absl::StrCat("missing descriptor for protocol buffer message: ",
                     message.GetTypeName()));
  }
  // Not much to do. Just verify the well known type is well-formed.
  CEL_RETURN_IF_ERROR(GetStructReflection(descriptor).status());
  if (adapted) {
    // Owned: ownership of the message moves into the result.
    return Struct(std::move(adapted));
  }
  // Borrowed: the result refers to `message`, so `message` must outlive it.
  return Struct(std::cref(message));
}

// AdaptAny recursively unpacks a protocol buffer message which is an instance
// of `google.protobuf.Any`.
//
// Unpacking repeats for as long as the freshly unpacked message is itself a
// `google.protobuf.Any`. For each layer, the type URL must start with
// `type.googleapis.com/` or `type.googleprod.com/`. The remainder is looked up
// by full name in `pool`, instantiated from `factory`'s prototype (allocated on
// `arena`), and parsed from the packed `value` bytes.
//
// Parameters:
//   arena: passed to `prototype->New()` and `WrapUnique()` for each unpacked
//     message; may be null.
//   reflection: re-initialized against the current layer's descriptor on every
//     iteration.
//   message: the outermost `google.protobuf.Any`; it is never modified.
//   descriptor: the descriptor of `message`; must be `WELLKNOWNTYPE_ANY`.
//   pool: used to resolve packed type names.
//   factory: used to obtain prototypes for resolved types.
//   error_if_unresolveable: when false, an unrecognized type URL prefix or a
//     type name missing from `pool` stops unpacking instead of failing.
//
// Returns the innermost unpacked message, owned by the result. The result is
// empty only when `error_if_unresolveable` is false and the outermost layer
// could not be resolved. A missing prototype, a parse failure, or a missing
// descriptor on an unpacked message is always reported as an error.
absl::StatusOr> AdaptAny(
    google::protobuf::Arena* absl_nullable arena, AnyReflection& reflection,
    const google::protobuf::Message& message,
    const Descriptor* absl_nonnull descriptor,
    const DescriptorPool* absl_nonnull pool,
    google::protobuf::MessageFactory* absl_nonnull factory,
    bool error_if_unresolveable) {
  ABSL_DCHECK_EQ(descriptor->well_known_type(), Descriptor::WELLKNOWNTYPE_ANY);
  // The layer currently being unpacked. It starts at the caller's `message`
  // and later points into `unwrapped`.
  const google::protobuf::Message* absl_nonnull to_unwrap = &message;
  Unique unwrapped;
  // Scratch buffers handed to the reflection getters below.
  std::string type_url_scratch;
  std::string value_scratch;
  do {
    // `descriptor` is reassigned at the end of each iteration, so the
    // reflection must be re-initialized for the current layer.
    CEL_RETURN_IF_ERROR(reflection.Initialize(descriptor));
    StringValue type_url = reflection.GetTypeUrl(*to_unwrap, type_url_scratch);
    absl::string_view type_url_view =
        FlatStringValue(type_url, type_url_scratch);
    if (!absl::ConsumePrefix(&type_url_view, "type.googleapis.com/") &&
        !absl::ConsumePrefix(&type_url_view, "type.googleprod.com/")) {
      if (!error_if_unresolveable) {
        break;
      }
      return absl::InvalidArgumentError(absl::StrCat(
          "unable to find descriptor for type URL: ", type_url_view));
    }
    const auto* packed_descriptor = pool->FindMessageTypeByName(type_url_view);
    if (packed_descriptor == nullptr) {
      if (!error_if_unresolveable) {
        break;
      }
      return absl::InvalidArgumentError(absl::StrCat(
          "unable to find descriptor for type name: ", type_url_view));
    }
    const auto* prototype = factory->GetPrototype(packed_descriptor);
    if (prototype == nullptr) {
      return absl::InvalidArgumentError(absl::StrCat(
          "unable to build prototype for type name: ", type_url_view));
    }
    BytesValue value = reflection.GetValue(*to_unwrap, value_scratch);
    Unique unpacked = WrapUnique(prototype->New(arena), arena);
    // The packed bytes may be a flat string or a cord, so parse from whichever
    // alternative is held.
    const bool ok = absl::visit(absl::Overload(
                                    [&](absl::string_view string) -> bool {
                                      return unpacked->ParseFromString(string);
                                    },
                                    [&](const absl::Cord& cord) -> bool {
                                      return unpacked->ParseFromString(cord);
                                    }),
                                AsVariant(value));
    if (!ok) {
      return absl::InvalidArgumentError(absl::StrCat(
          "failed to unpack protocol buffer message: ", type_url_view));
    }
    // We can only update unwrapped at this point, not before. This is because
    // we could have been unpacking from unwrapped itself.
    unwrapped = std::move(unpacked);
    to_unwrap = cel::to_address(unwrapped);
    descriptor = to_unwrap->GetDescriptor();
    if (descriptor == nullptr) {
      return absl::InvalidArgumentError(
          absl::StrCat("missing descriptor for protocol buffer message: ",
                       to_unwrap->GetTypeName()));
    }
  } while (descriptor->well_known_type() == Descriptor::WELLKNOWNTYPE_ANY);
  return unwrapped;
}

}  // namespace

// Fully unpacks a (possibly nested) `google.protobuf.Any`. Fails if any layer's
// type URL or type name cannot be resolved against `pool`. Delegates to
// AdaptAny with `error_if_unresolveable=true`.
absl::StatusOr> UnpackAnyFrom(
    google::protobuf::Arena* absl_nullable arena, AnyReflection& reflection,
    const google::protobuf::Message& message,
    const google::protobuf::DescriptorPool* absl_nonnull pool,
    google::protobuf::MessageFactory* absl_nonnull factory) {
  ABSL_DCHECK_EQ(message.GetDescriptor()->well_known_type(),
                 Descriptor::WELLKNOWNTYPE_ANY);
  return AdaptAny(arena, reflection, message, message.GetDescriptor(), pool,
                  factory,
                  /*error_if_unresolveable=*/true);
}

// Like UnpackAnyFrom, but unresolvable layers do not produce an error.
// Unpacking stops at the first layer whose type URL prefix is unrecognized or
// whose type name is missing from `pool`. The result is whatever was unpacked
// before that point, which is empty if the outermost layer was unresolvable.
// Delegates to AdaptAny with `error_if_unresolveable=false`.
absl::StatusOr> UnpackAnyIfResolveable(
    google::protobuf::Arena* absl_nullable arena, AnyReflection& reflection,
    const google::protobuf::Message& message,
    const google::protobuf::DescriptorPool* absl_nonnull pool,
    google::protobuf::MessageFactory* absl_nonnull factory) {
  ABSL_DCHECK_EQ(message.GetDescriptor()->well_known_type(),
                 Descriptor::WELLKNOWNTYPE_ANY);
  return AdaptAny(arena, reflection, message, message.GetDescriptor(), pool,
                  factory,
                  /*error_if_unresolveable=*/false);
}

// AdaptFromMessage converts `message` into the value representation used for
// well known types.
//
// A `google.protobuf.Any` is first fully unpacked via UnpackAnyFrom, which
// fails if any layer cannot be resolved. The resulting message is then
// dispatched on its well known type:
//   - wrapper types (`DoubleValue`, ..., `BoolValue`) yield their wrapped
//     value;
//   - `Duration` / `Timestamp` yield `ToAbslDuration` / `ToAbslTime`;
//   - `Value` is dispatched on its kind case;
//   - `ListValue` / `Struct` are handed to AdaptListValue / AdaptStruct.
// Any other message yields the unpacked message when an `Any` was unwrapped,
// and `absl::monostate{}` otherwise.
//
// `scratch` may back string/bytes results, so it must outlive the returned
// value. When no `Any` was unwrapped, results may also refer into `message`.
absl::StatusOr AdaptFromMessage(
    google::protobuf::Arena* absl_nullable arena,
    const google::protobuf::Message& message,
    const DescriptorPool* absl_nonnull pool,
    google::protobuf::MessageFactory* absl_nonnull factory,
    std::string& scratch) {
  const auto* descriptor = message.GetDescriptor();
  if (ABSL_PREDICT_FALSE(descriptor == nullptr)) {
    return absl::InvalidArgumentError(
        absl::StrCat("missing descriptor for protocol buffer message: ",
                     message.GetTypeName()));
  }
  // The message actually being adapted. When `message` is an Any, this points
  // at the unpacked message owned by `adapted`; otherwise it is `message`.
  const google::protobuf::Message* absl_nonnull to_adapt;
  Unique adapted;
  Descriptor::WellKnownType well_known_type = descriptor->well_known_type();
  if (well_known_type == Descriptor::WELLKNOWNTYPE_ANY) {
    AnyReflection reflection;
    CEL_ASSIGN_OR_RETURN(
        adapted, UnpackAnyFrom(arena, reflection, message, pool, factory));
    to_adapt = cel::to_address(adapted);
    // GetDescriptor() is guaranteed to be nonnull by AdaptAny().
    descriptor = to_adapt->GetDescriptor();
    // NOTE(review): this refreshed value is never read, because the switch
    // below re-queries `descriptor->well_known_type()` directly. Consider
    // switching on `well_known_type` instead.
    well_known_type = descriptor->well_known_type();
  } else {
    to_adapt = &message;
  }
  switch (descriptor->well_known_type()) {
    case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: {
      CEL_ASSIGN_OR_RETURN(auto reflection,
                           GetDoubleValueReflection(descriptor));
      return reflection.GetValue(*to_adapt);
    }
    case Descriptor::WELLKNOWNTYPE_FLOATVALUE: {
      CEL_ASSIGN_OR_RETURN(auto reflection,
                           GetFloatValueReflection(descriptor));
      return reflection.GetValue(*to_adapt);
    }
    case Descriptor::WELLKNOWNTYPE_INT64VALUE: {
      CEL_ASSIGN_OR_RETURN(auto reflection,
                           GetInt64ValueReflection(descriptor));
      return reflection.GetValue(*to_adapt);
    }
    case Descriptor::WELLKNOWNTYPE_UINT64VALUE: {
      CEL_ASSIGN_OR_RETURN(auto reflection,
                           GetUInt64ValueReflection(descriptor));
      return reflection.GetValue(*to_adapt);
    }
    case Descriptor::WELLKNOWNTYPE_INT32VALUE: {
      CEL_ASSIGN_OR_RETURN(auto reflection,
                           GetInt32ValueReflection(descriptor));
      return reflection.GetValue(*to_adapt);
    }
    case Descriptor::WELLKNOWNTYPE_UINT32VALUE: {
      CEL_ASSIGN_OR_RETURN(auto reflection,
                           GetUInt32ValueReflection(descriptor));
      return reflection.GetValue(*to_adapt);
    }
    case Descriptor::WELLKNOWNTYPE_STRINGVALUE: {
      CEL_ASSIGN_OR_RETURN(auto reflection,
                           GetStringValueReflection(descriptor));
      auto value = reflection.GetValue(*to_adapt, scratch);
      if (adapted) {
        // value might actually be a view of data owned by adapted, force a copy
        // to scratch if that is the case.
        value = CopyStringValue(value, scratch);
      }
      return value;
    }
    case Descriptor::WELLKNOWNTYPE_BYTESVALUE: {
      CEL_ASSIGN_OR_RETURN(auto reflection,
                           GetBytesValueReflection(descriptor));
      auto value = reflection.GetValue(*to_adapt, scratch);
      if (adapted) {
        // value might actually be a view of data owned by adapted, force a copy
        // to scratch if that is the case.
        value = CopyBytesValue(value, scratch);
      }
      return value;
    }
    case Descriptor::WELLKNOWNTYPE_BOOLVALUE: {
      CEL_ASSIGN_OR_RETURN(auto reflection,
                           GetBoolValueReflection(descriptor));
      return reflection.GetValue(*to_adapt);
    }
    case Descriptor::WELLKNOWNTYPE_ANY:
      // This is unreachable, as AdaptAny() above recursively unpacks.
      ABSL_UNREACHABLE();
    case Descriptor::WELLKNOWNTYPE_DURATION: {
      CEL_ASSIGN_OR_RETURN(auto reflection,
                           GetDurationReflection(descriptor));
      return reflection.ToAbslDuration(*to_adapt);
    }
    case Descriptor::WELLKNOWNTYPE_TIMESTAMP: {
      CEL_ASSIGN_OR_RETURN(auto reflection,
                           GetTimestampReflection(descriptor));
      return reflection.ToAbslTime(*to_adapt);
    }
    case Descriptor::WELLKNOWNTYPE_VALUE: {
      CEL_ASSIGN_OR_RETURN(auto reflection, GetValueReflection(descriptor));
      const auto kind_case = reflection.GetKindCase(*to_adapt);
      switch (kind_case) {
        case google::protobuf::Value::KIND_NOT_SET:
          // An unset kind is treated the same as an explicit null.
          ABSL_FALLTHROUGH_INTENDED;
        case google::protobuf::Value::kNullValue:
          return nullptr;
        case google::protobuf::Value::kNumberValue:
          return reflection.GetNumberValue(*to_adapt);
        case google::protobuf::Value::kStringValue: {
          auto value = reflection.GetStringValue(*to_adapt, scratch);
          if (adapted) {
            // Same reasoning as the StringValue case above: do not return a
            // view into the message owned by `adapted`.
            value = CopyStringValue(value, scratch);
          }
          return value;
        }
        case google::protobuf::Value::kBoolValue:
          return reflection.GetBoolValue(*to_adapt);
        case google::protobuf::Value::kStructValue: {
          if (adapted) {
            // We can release. `adapted` owns the Value, so take the nested
            // struct out of it and pass that ownership on to AdaptStruct.
            adapted = reflection.ReleaseStructValue(cel::to_address(adapted));
            to_adapt = cel::to_address(adapted);
          } else {
            // Not owned: borrow the nested struct from `message`.
            to_adapt = &reflection.GetStructValue(*to_adapt);
          }
          return AdaptStruct(arena, *to_adapt, std::move(adapted));
        }
        case google::protobuf::Value::kListValue: {
          if (adapted) {
            // We can release. `adapted` owns the Value, so take the nested
            // list out of it and pass that ownership on to AdaptListValue.
            adapted = reflection.ReleaseListValue(cel::to_address(adapted));
            to_adapt = cel::to_address(adapted);
          } else {
            // Not owned: borrow the nested list from `message`.
            to_adapt = &reflection.GetListValue(*to_adapt);
          }
          return AdaptListValue(arena, *to_adapt, std::move(adapted));
        }
        default:
          return absl::InvalidArgumentError(
              absl::StrCat("unexpected value kind case: ", kind_case));
      }
    }
    case Descriptor::WELLKNOWNTYPE_LISTVALUE:
      return AdaptListValue(arena, *to_adapt, std::move(adapted));
    case Descriptor::WELLKNOWNTYPE_STRUCT:
      return AdaptStruct(arena, *to_adapt, std::move(adapted));
    default:
      // Not a well known type handled above. Hand back the unpacked message
      // if an Any was unwrapped; otherwise there is nothing to adapt.
      if (adapted) {
        return adapted;
      }
      return absl::monostate{};
  }
}

}  // namespace cel::well_known_types

================================================
FILE: internal/well_known_types.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This file provides handling for well known protocol buffer types, which is
// agnostic to whether the types are dynamic or generated. It also performs
// exhaustive verification of the structure of the well known message types,
// ensuring they will work as intended throughout the rest of our codebase.
//
// For each well known type, there is a class `XReflection` where `X` is the
// unqualified well known type name. Each class can be initialized from a
// descriptor pool or a descriptor. Once initialized, they can be used with
// messages which use that exact descriptor.
// Using them with a different version of the descriptor from a separate
// descriptor pool results in undefined behavior. If unsure, you can initialize
// multiple times. If initializing with the same descriptor, it is a no-op.

#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_WELL_KNOWN_TYPES_H_
#define THIRD_PARTY_CEL_CPP_INTERNAL_WELL_KNOWN_TYPES_H_

// NOTE(review): the four system includes below appear to have lost their
// `<...>` header names in this copy (template argument lists look stripped
// elsewhere too, e.g. `absl::variant`). Confirm against upstream.
#include
#include
#include
#include

#include "google/protobuf/any.pb.h"
#include "google/protobuf/duration.pb.h"
#include "google/protobuf/field_mask.pb.h"
#include "google/protobuf/struct.pb.h"
#include "google/protobuf/timestamp.pb.h"
#include "google/protobuf/wrappers.pb.h"
#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/variant.h"
#include "common/any.h"
#include "common/memory.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/map_field.h"
#include "google/protobuf/message.h"
#include "google/protobuf/reflection.h"

namespace cel::well_known_types {

// Strongly typed variant capable of holding the value representation of any
// protocol buffer message string field. We do this instead of type aliasing to
// avoid collisions in other variants such as `well_known_types::Value`.
class StringValue final : public absl::variant {
 public:
  // Inherit all of the underlying variant's constructors.
  using absl::variant::variant;

  // NOTE(review): defined out of line, not visible here. Presumably removes
  // `prefix` from the front of the held string or cord when present and
  // reports whether it did, mirroring `absl::ConsumePrefix`. Confirm against
  // the definition.
  bool ConsumePrefix(absl::string_view prefix);
};

// Older versions of GCC do not deal with inheriting from variant correctly when
// using `visit`, so we cheat by upcasting.
inline const absl::variant& AsVariant( const StringValue& value) { return static_cast&>( value); } inline absl::variant& AsVariant( StringValue& value) { return static_cast&>(value); } inline const absl::variant&& AsVariant( const StringValue&& value) { return static_cast&&>( value); } inline absl::variant&& AsVariant( StringValue&& value) { return static_cast&&>(value); } inline bool operator==(const StringValue& lhs, const StringValue& rhs) { return absl::visit( [](const auto& lhs, const auto& rhs) { return lhs == rhs; }, AsVariant(lhs), AsVariant(rhs)); } inline bool operator!=(const StringValue& lhs, const StringValue& rhs) { return !operator==(lhs, rhs); } template void AbslStringify(S& sink, const StringValue& value) { sink.Append(absl::visit( [&](const auto& value) -> std::string { return absl::StrCat(value); }, AsVariant(value))); } StringValue GetStringField(const google::protobuf::Reflection* absl_nonnull reflection, const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); inline StringValue GetStringField( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { return GetStringField(message.GetReflection(), message, field, scratch); } StringValue GetRepeatedStringField( const google::protobuf::Reflection* absl_nonnull reflection, const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field, int index, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); inline StringValue GetRepeatedStringField( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field, int index, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { return 
GetRepeatedStringField(message.GetReflection(), message, field, index, scratch); } // Strongly typed variant capable of holding the value representation of any // protocol buffer message bytes field. We do this instead of type aliasing to // avoid collisions in other variants such as `well_known_types::Value`. class BytesValue final : public absl::variant { public: using absl::variant::variant; }; // Older versions of GCC do not deal with inheriting from variant correctly when // using `visit`, so we cheat by upcasting. inline const absl::variant& AsVariant( const BytesValue& value) { return static_cast&>( value); } inline absl::variant& AsVariant( BytesValue& value) { return static_cast&>(value); } inline const absl::variant&& AsVariant( const BytesValue&& value) { return static_cast&&>( value); } inline absl::variant&& AsVariant( BytesValue&& value) { return static_cast&&>(value); } inline bool operator==(const BytesValue& lhs, const BytesValue& rhs) { return absl::visit( [](const auto& lhs, const auto& rhs) { return lhs == rhs; }, AsVariant(lhs), AsVariant(rhs)); } inline bool operator!=(const BytesValue& lhs, const BytesValue& rhs) { return !operator==(lhs, rhs); } template void AbslStringify(S& sink, const BytesValue& value) { sink.Append(absl::visit( [&](const auto& value) -> std::string { return absl::StrCat(value); }, AsVariant(value))); } BytesValue GetBytesField(const google::protobuf::Reflection* absl_nonnull reflection, const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); inline BytesValue GetBytesField( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { return GetBytesField(message.GetReflection(), message, field, scratch); } BytesValue GetRepeatedBytesField( const 
google::protobuf::Reflection* absl_nonnull reflection, const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field, int index, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); inline BytesValue GetRepeatedBytesField( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, const google::protobuf::FieldDescriptor* absl_nonnull field, int index, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { return GetRepeatedBytesField(message.GetReflection(), message, field, index, scratch); } class NullValueReflection final { public: NullValueReflection() = default; NullValueReflection(const NullValueReflection&) = default; NullValueReflection& operator=(const NullValueReflection&) = default; absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize( const google::protobuf::EnumDescriptor* absl_nonnull descriptor); bool IsInitialized() const { return descriptor_ != nullptr; } private: const google::protobuf::EnumDescriptor* absl_nullable descriptor_ = nullptr; const google::protobuf::EnumValueDescriptor* absl_nullable value_ = nullptr; }; class BoolValueReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = google::protobuf::Descriptor::WELLKNOWNTYPE_BOOLVALUE; using GeneratedMessageType = google::protobuf::BoolValue; static bool GetValue(const GeneratedMessageType& message) { return message.value(); } static void SetValue(GeneratedMessageType* absl_nonnull message, bool value) { message->set_value(value); } BoolValueReflection() = default; BoolValueReflection(const BoolValueReflection&) = default; BoolValueReflection& operator=(const BoolValueReflection&) = default; absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); bool IsInitialized() const { return descriptor_ != 
nullptr; } const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { ABSL_DCHECK(IsInitialized()); return descriptor_; } bool GetValue(const google::protobuf::Message& message) const; void SetValue(google::protobuf::Message* absl_nonnull message, bool value) const; private: const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; }; absl::StatusOr GetBoolValueReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); class Int32ValueReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = google::protobuf::Descriptor::WELLKNOWNTYPE_INT32VALUE; using GeneratedMessageType = google::protobuf::Int32Value; static int32_t GetValue(const GeneratedMessageType& message) { return message.value(); } static void SetValue(GeneratedMessageType* absl_nonnull message, int32_t value) { message->set_value(value); } Int32ValueReflection() = default; Int32ValueReflection(const Int32ValueReflection&) = default; Int32ValueReflection& operator=(const Int32ValueReflection&) = default; absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); bool IsInitialized() const { return descriptor_ != nullptr; } const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { ABSL_DCHECK(IsInitialized()); return descriptor_; } int32_t GetValue(const google::protobuf::Message& message) const; void SetValue(google::protobuf::Message* absl_nonnull message, int32_t value) const; private: const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; }; absl::StatusOr GetInt32ValueReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); class 
Int64ValueReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = google::protobuf::Descriptor::WELLKNOWNTYPE_INT64VALUE; using GeneratedMessageType = google::protobuf::Int64Value; static int64_t GetValue(const GeneratedMessageType& message) { return message.value(); } static void SetValue(GeneratedMessageType* absl_nonnull message, int64_t value) { message->set_value(value); } Int64ValueReflection() = default; Int64ValueReflection(const Int64ValueReflection&) = default; Int64ValueReflection& operator=(const Int64ValueReflection&) = default; absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); bool IsInitialized() const { return descriptor_ != nullptr; } const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { ABSL_DCHECK(IsInitialized()); return descriptor_; } int64_t GetValue(const google::protobuf::Message& message) const; void SetValue(google::protobuf::Message* absl_nonnull message, int64_t value) const; private: const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; }; absl::StatusOr GetInt64ValueReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); class UInt32ValueReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = google::protobuf::Descriptor::WELLKNOWNTYPE_UINT32VALUE; using GeneratedMessageType = google::protobuf::UInt32Value; static uint32_t GetValue(const GeneratedMessageType& message) { return message.value(); } static void SetValue(GeneratedMessageType* absl_nonnull message, uint32_t value) { message->set_value(value); } UInt32ValueReflection() = default; UInt32ValueReflection(const UInt32ValueReflection&) = default; UInt32ValueReflection& operator=(const 
UInt32ValueReflection&) = default; absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); bool IsInitialized() const { return descriptor_ != nullptr; } const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { ABSL_DCHECK(IsInitialized()); return descriptor_; } uint32_t GetValue(const google::protobuf::Message& message) const; void SetValue(google::protobuf::Message* absl_nonnull message, uint32_t value) const; private: const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; }; absl::StatusOr GetUInt32ValueReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); class UInt64ValueReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = google::protobuf::Descriptor::WELLKNOWNTYPE_UINT64VALUE; using GeneratedMessageType = google::protobuf::UInt64Value; static uint64_t GetValue(const GeneratedMessageType& message) { return message.value(); } static void SetValue(GeneratedMessageType* absl_nonnull message, uint64_t value) { message->set_value(value); } UInt64ValueReflection() = default; UInt64ValueReflection(const UInt64ValueReflection&) = default; UInt64ValueReflection& operator=(const UInt64ValueReflection&) = default; absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); bool IsInitialized() const { return descriptor_ != nullptr; } const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { ABSL_DCHECK(IsInitialized()); return descriptor_; } uint64_t GetValue(const google::protobuf::Message& message) const; void SetValue(google::protobuf::Message* absl_nonnull message, uint64_t value) const; private: const 
google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; }; absl::StatusOr GetUInt64ValueReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); class FloatValueReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = google::protobuf::Descriptor::WELLKNOWNTYPE_FLOATVALUE; using GeneratedMessageType = google::protobuf::FloatValue; static float GetValue(const GeneratedMessageType& message) { return message.value(); } static void SetValue(GeneratedMessageType* absl_nonnull message, float value) { message->set_value(value); } FloatValueReflection() = default; FloatValueReflection(const FloatValueReflection&) = default; FloatValueReflection& operator=(const FloatValueReflection&) = default; absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); bool IsInitialized() const { return descriptor_ != nullptr; } const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { ABSL_DCHECK(IsInitialized()); return descriptor_; } float GetValue(const google::protobuf::Message& message) const; void SetValue(google::protobuf::Message* absl_nonnull message, float value) const; private: const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; }; absl::StatusOr GetFloatValueReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); class DoubleValueReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = google::protobuf::Descriptor::WELLKNOWNTYPE_DOUBLEVALUE; using GeneratedMessageType = google::protobuf::DoubleValue; static double GetValue(const GeneratedMessageType& message) { return 
message.value(); } static void SetValue(GeneratedMessageType* absl_nonnull message, double value) { message->set_value(value); } DoubleValueReflection() = default; DoubleValueReflection(const DoubleValueReflection&) = default; DoubleValueReflection& operator=(const DoubleValueReflection&) = default; absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); bool IsInitialized() const { return descriptor_ != nullptr; } const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { ABSL_DCHECK(IsInitialized()); return descriptor_; } double GetValue(const google::protobuf::Message& message) const; void SetValue(google::protobuf::Message* absl_nonnull message, double value) const; private: const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; }; absl::StatusOr GetDoubleValueReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); class BytesValueReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = google::protobuf::Descriptor::WELLKNOWNTYPE_BYTESVALUE; using GeneratedMessageType = google::protobuf::BytesValue; static absl::Cord GetValue(const GeneratedMessageType& message) { return absl::Cord(message.value()); } static void SetValue(GeneratedMessageType* absl_nonnull message, const absl::Cord& value) { message->set_value(static_cast(value)); } BytesValueReflection() = default; BytesValueReflection(const BytesValueReflection&) = default; BytesValueReflection& operator=(const BytesValueReflection&) = default; absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); bool IsInitialized() const { return descriptor_ != nullptr; } const 
google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { ABSL_DCHECK(IsInitialized()); return descriptor_; } BytesValue GetValue(const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const; void SetValue(google::protobuf::Message* absl_nonnull message, absl::string_view value) const; void SetValue(google::protobuf::Message* absl_nonnull message, const absl::Cord& value) const; private: const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; google::protobuf::FieldDescriptor::CppStringType value_field_string_type_; }; absl::StatusOr GetBytesValueReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); class StringValueReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = google::protobuf::Descriptor::WELLKNOWNTYPE_STRINGVALUE; using GeneratedMessageType = google::protobuf::StringValue; static absl::string_view GetValue( const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND) { return message.value(); } static void SetValue(GeneratedMessageType* absl_nonnull message, absl::string_view value) { message->set_value(value); } StringValueReflection() = default; StringValueReflection(const StringValueReflection&) = default; StringValueReflection& operator=(const StringValueReflection&) = default; absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); bool IsInitialized() const { return descriptor_ != nullptr; } const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { ABSL_DCHECK(IsInitialized()); return descriptor_; } StringValue GetValue( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) 
const; void SetValue(google::protobuf::Message* absl_nonnull message, absl::string_view value) const; void SetValue(google::protobuf::Message* absl_nonnull message, const absl::Cord& value) const; private: const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; google::protobuf::FieldDescriptor::CppStringType value_field_string_type_; }; absl::StatusOr GetStringValueReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); class AnyReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = google::protobuf::Descriptor::WELLKNOWNTYPE_ANY; using GeneratedMessageType = google::protobuf::Any; static absl::string_view GetTypeUrl( const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND) { return message.type_url(); } static absl::Cord GetValue(const GeneratedMessageType& message) { return GetAnyValueAsCord(message); } static void SetTypeUrl(GeneratedMessageType* absl_nonnull message, absl::string_view type_url) { message->set_type_url(type_url); } static void SetValue(GeneratedMessageType* absl_nonnull message, const absl::Cord& value) { SetAnyValueFromCord(message, value); } AnyReflection() = default; AnyReflection(const AnyReflection&) = default; AnyReflection& operator=(const AnyReflection&) = default; absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); bool IsInitialized() const { return descriptor_ != nullptr; } const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { ABSL_DCHECK(IsInitialized()); return descriptor_; } void SetTypeUrl(google::protobuf::Message* absl_nonnull message, absl::string_view type_url) const; void SetValue(google::protobuf::Message* absl_nonnull message, const absl::Cord& value) const; StringValue GetTypeUrl( const 
google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const; BytesValue GetValue(const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const; private: const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable type_url_field_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable value_field_ = nullptr; google::protobuf::FieldDescriptor::CppStringType type_url_field_string_type_; google::protobuf::FieldDescriptor::CppStringType value_field_string_type_; }; absl::StatusOr GetAnyReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); AnyReflection GetAnyReflectionOrDie( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); class DurationReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = google::protobuf::Descriptor::WELLKNOWNTYPE_DURATION; using GeneratedMessageType = google::protobuf::Duration; static int64_t GetSeconds(const GeneratedMessageType& message) { return message.seconds(); } static int64_t GetNanos(const GeneratedMessageType& message) { return message.nanos(); } static void SetSeconds(GeneratedMessageType* absl_nonnull message, int64_t value) { message->set_seconds(value); } static void SetNanos(GeneratedMessageType* absl_nonnull message, int32_t value) { message->set_nanos(value); } static absl::Status SetFromAbslDuration( GeneratedMessageType* absl_nonnull message, absl::Duration duration); DurationReflection() = default; DurationReflection(const DurationReflection&) = default; DurationReflection& operator=(const DurationReflection&) = default; absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull 
descriptor); bool IsInitialized() const { return descriptor_ != nullptr; } const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { ABSL_DCHECK(IsInitialized()); return descriptor_; } int64_t GetSeconds(const google::protobuf::Message& message) const; int32_t GetNanos(const google::protobuf::Message& message) const; void SetSeconds(google::protobuf::Message* absl_nonnull message, int64_t value) const; void SetNanos(google::protobuf::Message* absl_nonnull message, int32_t value) const; absl::Status SetFromAbslDuration(google::protobuf::Message* absl_nonnull message, absl::Duration duration) const; // Converts `absl::Duration` to `google.protobuf.Duration` without performing // validity checks. Avoid use. void UnsafeSetFromAbslDuration(google::protobuf::Message* absl_nonnull message, absl::Duration duration) const; absl::StatusOr ToAbslDuration( const google::protobuf::Message& message) const; // Converts `google.protobuf.Duration` to `absl::Duration` without performing // validity checks. Avoid use. 
absl::Duration UnsafeToAbslDuration(const google::protobuf::Message& message) const; private: const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable seconds_field_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable nanos_field_ = nullptr; }; absl::StatusOr GetDurationReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); class TimestampReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = google::protobuf::Descriptor::WELLKNOWNTYPE_TIMESTAMP; using GeneratedMessageType = google::protobuf::Timestamp; static int64_t GetSeconds(const GeneratedMessageType& message) { return message.seconds(); } static int64_t GetNanos(const GeneratedMessageType& message) { return message.nanos(); } static void SetSeconds(GeneratedMessageType* absl_nonnull message, int64_t value) { message->set_seconds(value); } static void SetNanos(GeneratedMessageType* absl_nonnull message, int32_t value) { message->set_nanos(value); } static absl::Status SetFromAbslTime( GeneratedMessageType* absl_nonnull message, absl::Time time); TimestampReflection() = default; TimestampReflection(const TimestampReflection&) = default; TimestampReflection& operator=(const TimestampReflection&) = default; absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); bool IsInitialized() const { return descriptor_ != nullptr; } const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { ABSL_DCHECK(IsInitialized()); return descriptor_; } int64_t GetSeconds(const google::protobuf::Message& message) const; int32_t GetNanos(const google::protobuf::Message& message) const; void SetSeconds(google::protobuf::Message* absl_nonnull message, int64_t value) const; void SetNanos(google::protobuf::Message* absl_nonnull 
message, int32_t value) const; absl::StatusOr ToAbslTime(const google::protobuf::Message& message) const; // Converts `absl::Time` to `google.protobuf.Timestamp` without performing // validity checks. Avoid use. absl::Time UnsafeToAbslTime(const google::protobuf::Message& message) const; absl::Status SetFromAbslTime(google::protobuf::Message* absl_nonnull message, absl::Time time) const; // Converts `google.protobuf.Timestamp` to `absl::Time` without performing // validity checks. Avoid use. void UnsafeSetFromAbslTime(google::protobuf::Message* absl_nonnull message, absl::Time time) const; private: const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable seconds_field_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable nanos_field_ = nullptr; }; absl::StatusOr GetTimestampReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); class ValueReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE; using GeneratedMessageType = google::protobuf::Value; static google::protobuf::Value::KindCase GetKindCase( const google::protobuf::Value& message) { return message.kind_case(); } static bool GetBoolValue(const GeneratedMessageType& message) { return message.bool_value(); } static double GetNumberValue(const GeneratedMessageType& message) { return message.number_value(); } static absl::string_view GetStringValue( const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND) { return message.string_value(); } static const google::protobuf::ListValue& GetListValue( const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND) { return message.list_value(); } static const google::protobuf::Struct& GetStructValue( const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND) { return message.struct_value(); } static void 
SetNullValue(GeneratedMessageType* absl_nonnull message) { message->set_null_value(google::protobuf::NULL_VALUE); } static void SetBoolValue(GeneratedMessageType* absl_nonnull message, bool value) { message->set_bool_value(value); } static void SetNumberValue(GeneratedMessageType* absl_nonnull message, int64_t value); static void SetNumberValue(GeneratedMessageType* absl_nonnull message, uint64_t value); static void SetNumberValue(GeneratedMessageType* absl_nonnull message, double value) { message->set_number_value(value); } static void SetStringValue(GeneratedMessageType* absl_nonnull message, absl::string_view value) { message->set_string_value(value); } static void SetStringValue(GeneratedMessageType* absl_nonnull message, const absl::Cord& value) { message->set_string_value(static_cast(value)); } static google::protobuf::ListValue* absl_nonnull MutableListValue( GeneratedMessageType* absl_nonnull message) { return message->mutable_list_value(); } static google::protobuf::Struct* absl_nonnull MutableStructValue( GeneratedMessageType* absl_nonnull message) { return message->mutable_struct_value(); } ValueReflection() = default; ValueReflection(const ValueReflection&) = default; ValueReflection& operator=(const ValueReflection&) = default; absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); bool IsInitialized() const { return descriptor_ != nullptr; } const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { ABSL_DCHECK(IsInitialized()); return descriptor_; } const google::protobuf::Descriptor* absl_nonnull GetStructDescriptor() const { ABSL_DCHECK(IsInitialized()); return struct_value_field_->message_type(); } const google::protobuf::Descriptor* absl_nonnull GetListValueDescriptor() const { ABSL_DCHECK(IsInitialized()); return list_value_field_->message_type(); } google::protobuf::Value::KindCase GetKindCase( const 
google::protobuf::Message& message) const; bool GetBoolValue(const google::protobuf::Message& message) const; double GetNumberValue(const google::protobuf::Message& message) const; StringValue GetStringValue( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const; const google::protobuf::Message& GetListValue( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; const google::protobuf::Message& GetStructValue( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; void SetNullValue(google::protobuf::Message* absl_nonnull message) const; void SetBoolValue(google::protobuf::Message* absl_nonnull message, bool value) const; void SetNumberValue(google::protobuf::Message* absl_nonnull message, int64_t value) const; void SetNumberValue(google::protobuf::Message* absl_nonnull message, uint64_t value) const; void SetNumberValue(google::protobuf::Message* absl_nonnull message, double value) const; void SetStringValue(google::protobuf::Message* absl_nonnull message, absl::string_view value) const; void SetStringValue(google::protobuf::Message* absl_nonnull message, const absl::Cord& value) const; void SetStringValueFromBytes(google::protobuf::Message* absl_nonnull message, absl::string_view value) const; void SetStringValueFromBytes(google::protobuf::Message* absl_nonnull message, const absl::Cord& value) const; void SetStringValueFromDuration(google::protobuf::Message* absl_nonnull message, absl::Duration duration) const; void SetStringValueFromTimestamp(google::protobuf::Message* absl_nonnull message, absl::Time time) const; google::protobuf::Message* absl_nonnull MutableListValue( google::protobuf::Message* absl_nonnull message) const; google::protobuf::Message* absl_nonnull MutableStructValue( google::protobuf::Message* absl_nonnull message) const; Unique ReleaseListValue( google::protobuf::Message* absl_nonnull message) const; Unique 
ReleaseStructValue( google::protobuf::Message* absl_nonnull message) const; private: const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; const google::protobuf::OneofDescriptor* absl_nullable kind_field_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable null_value_field_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable bool_value_field_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable number_value_field_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable string_value_field_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable list_value_field_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable struct_value_field_ = nullptr; google::protobuf::FieldDescriptor::CppStringType string_value_field_string_type_; }; absl::StatusOr GetValueReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); // `GetValueReflectionOrDie()` is the same as `GetValueReflection` // except that it aborts if `descriptor` is not a well formed descriptor of // `google.protobuf.Value`. This should only be used in places where it is // guaranteed that the aforementioned prerequisites are met. 
ValueReflection GetValueReflectionOrDie( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); class ListValueReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE; using GeneratedMessageType = google::protobuf::ListValue; static int ValuesSize(const GeneratedMessageType& message) { return message.values_size(); } static const google::protobuf::RepeatedPtrField& Values( const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND) { return message.values(); } static const google::protobuf::Value& Values( const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND, int index) { return message.values(index); } static google::protobuf::RepeatedPtrField& MutableValues( GeneratedMessageType* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND) { return *message->mutable_values(); } static google::protobuf::Value* absl_nonnull AddValues( GeneratedMessageType* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND) { return message->add_values(); } absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); bool IsInitialized() const { return descriptor_ != nullptr; } const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { ABSL_DCHECK(IsInitialized()); return descriptor_; } const google::protobuf::Descriptor* absl_nonnull GetValueDescriptor() const { ABSL_DCHECK(IsInitialized()); return values_field_->message_type(); } const google::protobuf::FieldDescriptor* absl_nonnull GetValuesDescriptor() const { ABSL_DCHECK(IsInitialized()); return values_field_; } int ValuesSize(const google::protobuf::Message& message) const; google::protobuf::RepeatedFieldRef Values( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; const google::protobuf::Message& Values(const 
google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, int index) const; google::protobuf::MutableRepeatedFieldRef MutableValues( google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; google::protobuf::Message* absl_nonnull AddValues( google::protobuf::Message* absl_nonnull message) const; private: const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable values_field_ = nullptr; }; absl::StatusOr GetListValueReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); // `GetListValueReflectionOrDie()` is the same as `GetListValueReflection` // except that it aborts if `descriptor` is not a well formed descriptor of // `google.protobuf.ListValue`. This should only be used in places where it is // guaranteed that the aforementioned prerequisites are met. ListValueReflection GetListValueReflectionOrDie( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); class StructReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT; using GeneratedMessageType = google::protobuf::Struct; static int FieldsSize(const GeneratedMessageType& message) { return message.fields_size(); } static auto BeginFields( const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND) { return message.fields().begin(); } static auto EndFields( const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND) { return message.fields().end(); } static bool ContainsField(const GeneratedMessageType& message, absl::string_view name) { return message.fields().contains(name); } static const google::protobuf::Value* absl_nullable FindField( const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND, absl::string_view name) { if (auto it = message.fields().find(name); it != message.fields().end()) { 
return &it->second; } return nullptr; } static google::protobuf::Value* absl_nonnull InsertField( GeneratedMessageType* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, absl::string_view name) { return &(*message->mutable_fields())[name]; } static bool DeleteField(GeneratedMessageType* absl_nonnull message, absl::string_view name) { return message->mutable_fields()->erase(name) > 0; } absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); bool IsInitialized() const { return descriptor_ != nullptr; } const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { ABSL_DCHECK(IsInitialized()); return descriptor_; } const google::protobuf::Descriptor* absl_nonnull GetValueDescriptor() const { ABSL_DCHECK(IsInitialized()); return fields_value_field_->message_type(); } const google::protobuf::FieldDescriptor* absl_nonnull GetFieldsDescriptor() const { ABSL_DCHECK(IsInitialized()); return fields_field_; } int FieldsSize(const google::protobuf::Message& message) const; google::protobuf::ConstMapIterator BeginFields( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; google::protobuf::ConstMapIterator EndFields( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; bool ContainsField(const google::protobuf::Message& message, absl::string_view name) const; const google::protobuf::Message* absl_nullable FindField( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, absl::string_view name) const; google::protobuf::Message* absl_nonnull InsertField( google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, absl::string_view name) const; bool DeleteField(google::protobuf::Message* absl_nonnull message ABSL_ATTRIBUTE_LIFETIME_BOUND, absl::string_view name) const; private: const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; const 
google::protobuf::FieldDescriptor* absl_nullable fields_field_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable fields_key_field_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable fields_value_field_ = nullptr; }; absl::StatusOr GetStructReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); // `GetStructReflectionOrDie()` is the same as `GetStructReflection` // except that it aborts if `descriptor` is not a well formed descriptor of // `google.protobuf.Struct`. This should only be used in places where it is // guaranteed that the aforementioned prerequisites are met. StructReflection GetStructReflectionOrDie( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); class FieldMaskReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = google::protobuf::Descriptor::WELLKNOWNTYPE_FIELDMASK; using GeneratedMessageType = google::protobuf::FieldMask; static int PathsSize(const GeneratedMessageType& message) { return message.paths_size(); } static absl::string_view Paths(const GeneratedMessageType& message ABSL_ATTRIBUTE_LIFETIME_BOUND, int index) { return message.paths(index); } absl::Status Initialize(const google::protobuf::DescriptorPool* absl_nonnull pool); absl::Status Initialize(const google::protobuf::Descriptor* absl_nonnull descriptor); bool IsInitialized() const { return descriptor_ != nullptr; } const google::protobuf::Descriptor* absl_nonnull GetDescriptor() const { ABSL_DCHECK(IsInitialized()); return descriptor_; } int PathsSize(const google::protobuf::Message& message) const; StringValue Paths( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, int index, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const; private: const google::protobuf::Descriptor* absl_nullable descriptor_ = nullptr; const google::protobuf::FieldDescriptor* absl_nullable paths_field_ = nullptr; 
google::protobuf::FieldDescriptor::CppStringType paths_field_string_type_; }; absl::StatusOr GetFieldMaskReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); using ListValuePtr = Unique; using ListValueConstRef = std::reference_wrapper; using StructPtr = Unique; using StructConstRef = std::reference_wrapper; // Variant holding `std::reference_wrapper` or `Unique`, either of which is an // instance of `google.protobuf.ListValue` which is either a generated message // or dynamic message. class ListValue final : public absl::variant { using absl::variant::variant; }; // Older versions of GCC do not deal with inheriting from variant correctly when // using `visit`, so we cheat by upcasting. inline const absl::variant& AsVariant( const ListValue& value) { return static_cast&>( value); } inline absl::variant& AsVariant( ListValue& value) { return static_cast&>(value); } inline const absl::variant&& AsVariant( const ListValue&& value) { return static_cast&&>( value); } inline absl::variant&& AsVariant( ListValue&& value) { return static_cast&&>(value); } // Variant holding `std::reference_wrapper` or `Unique`, either of which is an // instance of `google.protobuf.Struct` which is either a generated message or // dynamic message. class Struct final : public absl::variant { public: using absl::variant::variant; }; // Older versions of GCC do not deal with inheriting from variant correctly when // using `visit`, so we cheat by upcasting. inline const absl::variant& AsVariant( const Struct& value) { return static_cast&>(value); } inline absl::variant& AsVariant(Struct& value) { return static_cast&>(value); } inline const absl::variant&& AsVariant( const Struct&& value) { return static_cast&&>(value); } inline absl::variant&& AsVariant(Struct&& value) { return static_cast&&>(value); } // Variant capable of representing any unwrapped well known type or message. 
using Value = absl::variant>;

// Unpacks the given instance of `google.protobuf.Any`.
//
// `arena` may be null (`absl_nullable`); `pool` and `factory` must be
// non-null. `arena`, `pool` and `factory` are lifetime-bound: the result may
// refer to them, so they must outlive it.
// NOTE(review): `reflection` is taken by mutable reference, presumably so it
// can be (re)initialized for `message`'s descriptor; confirm in the .cc.
absl::StatusOr> UnpackAnyFrom(
    google::protobuf::Arena* absl_nullable arena ABSL_ATTRIBUTE_LIFETIME_BOUND,
    AnyReflection& reflection, const google::protobuf::Message& message,
    const google::protobuf::DescriptorPool* absl_nonnull pool
        ABSL_ATTRIBUTE_LIFETIME_BOUND,
    google::protobuf::MessageFactory* absl_nonnull factory
        ABSL_ATTRIBUTE_LIFETIME_BOUND);

// Unpacks the given instance of `google.protobuf.Any` if it is resolvable.
//
// Takes the same arguments, with the same nullability and lifetime
// requirements, as `UnpackAnyFrom()`.
// NOTE(review): how an unresolvable type URL is reported is not visible from
// this declaration; check the implementation. The spelling "Resolveable" is
// part of the public name and must be kept for source compatibility.
absl::StatusOr> UnpackAnyIfResolveable(
    google::protobuf::Arena* absl_nullable arena ABSL_ATTRIBUTE_LIFETIME_BOUND,
    AnyReflection& reflection, const google::protobuf::Message& message,
    const google::protobuf::DescriptorPool* absl_nonnull pool
        ABSL_ATTRIBUTE_LIFETIME_BOUND,
    google::protobuf::MessageFactory* absl_nonnull factory
        ABSL_ATTRIBUTE_LIFETIME_BOUND);

// Performs any necessary unwrapping of a well known message type. If no
// unwrapping is necessary, the resulting `Value` holds the alternative
// `absl::monostate`.
// `arena`, `message`, `pool`, `factory` and `scratch` are all lifetime-bound:
// the returned `Value` may refer into any of them (presumably e.g. string or
// bytes data backed by `scratch`), so they must outlive it. `arena` may be
// null (`absl_nullable`); `pool` and `factory` must be non-null.
absl::StatusOr AdaptFromMessage(
    google::protobuf::Arena* absl_nullable arena ABSL_ATTRIBUTE_LIFETIME_BOUND,
    const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND,
    const google::protobuf::DescriptorPool* absl_nonnull pool
        ABSL_ATTRIBUTE_LIFETIME_BOUND,
    google::protobuf::MessageFactory* absl_nonnull factory
        ABSL_ATTRIBUTE_LIFETIME_BOUND,
    std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND);

// Bundles the reflection helpers for the JSON-like well known types:
// `google.protobuf.Value`, `google.protobuf.ListValue` and
// `google.protobuf.Struct`. The accessors return references into this object.
class JsonReflection final {
 public:
  JsonReflection() = default;
  JsonReflection(const JsonReflection&) = default;
  JsonReflection& operator=(const JsonReflection&) = default;

  absl::Status Initialize(
      const google::protobuf::DescriptorPool* absl_nonnull pool);

  // NOTE(review): which message descriptor this overload expects (presumably
  // `google.protobuf.Value`) is not visible from this declaration.
  absl::Status Initialize(
      const google::protobuf::Descriptor* absl_nonnull descriptor);

  bool IsInitialized() const;

  ValueReflection& Value() ABSL_ATTRIBUTE_LIFETIME_BOUND { return value_; }

  ListValueReflection& ListValue() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return list_value_;
  }

  StructReflection& Struct() ABSL_ATTRIBUTE_LIFETIME_BOUND { return struct_; }

  const ValueReflection& Value() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return value_;
  }

  const ListValueReflection& ListValue() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return list_value_;
  }

  const StructReflection& Struct() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return struct_;
  }

 private:
  ValueReflection value_;
  ListValueReflection list_value_;
  StructReflection struct_;
};

// Aggregates one reflection helper per supported well known type: the wrapper
// types, `Any`, `Duration`, `Timestamp`, the JSON types (held in a
// `JsonReflection`) and `FieldMask`. The accessors return references into this
// object.
class Reflection final {
 public:
  Reflection() = default;
  Reflection(const Reflection&) = default;
  Reflection& operator=(const Reflection&) = default;

  absl::Status Initialize(
      const google::protobuf::DescriptorPool* absl_nonnull pool);

  bool IsInitialized() const;

  // At the moment we only use this class for verifying well known types in
  // descriptor pools. We could eagerly initialize it and cache it somewhere to
  // make things faster.

  BoolValueReflection& BoolValue() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return bool_value_;
  }

  Int32ValueReflection& Int32Value() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return int32_value_;
  }

  Int64ValueReflection& Int64Value() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return int64_value_;
  }

  UInt32ValueReflection& UInt32Value() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return uint32_value_;
  }

  UInt64ValueReflection& UInt64Value() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return uint64_value_;
  }

  FloatValueReflection& FloatValue() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return float_value_;
  }

  DoubleValueReflection& DoubleValue() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return double_value_;
  }

  BytesValueReflection& BytesValue() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return bytes_value_;
  }

  StringValueReflection& StringValue() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return string_value_;
  }

  AnyReflection& Any() ABSL_ATTRIBUTE_LIFETIME_BOUND { return any_; }

  DurationReflection& Duration() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return duration_;
  }

  TimestampReflection& Timestamp() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return timestamp_;
  }

  JsonReflection& Json() ABSL_ATTRIBUTE_LIFETIME_BOUND { return json_; }

  // `Value()`, `ListValue()` and `Struct()` forward to the helpers held by the
  // `JsonReflection` returned from `Json()`.
  ValueReflection& Value() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return Json().Value();
  }

  ListValueReflection& ListValue() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return Json().ListValue();
  }

  StructReflection& Struct() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return Json().Struct();
  }

  FieldMaskReflection& FieldMask() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return field_mask_;
  }

  const BoolValueReflection& BoolValue() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return bool_value_;
  }

  const Int32ValueReflection& Int32Value() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return int32_value_;
  }

  const Int64ValueReflection& Int64Value() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return int64_value_;
  }

  const UInt32ValueReflection& UInt32Value() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return uint32_value_;
  }

  const UInt64ValueReflection& UInt64Value() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return uint64_value_;
  }

  const FloatValueReflection& FloatValue() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return float_value_;
  }

  const DoubleValueReflection& DoubleValue() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return double_value_;
  }

  const BytesValueReflection& BytesValue() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return bytes_value_;
  }

  const StringValueReflection& StringValue() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return string_value_;
  }

  const AnyReflection& Any() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return any_;
  }

  const DurationReflection& Duration() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return duration_;
  }

  const TimestampReflection& Timestamp() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return timestamp_;
  }

  const JsonReflection& Json() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return json_;
  }

  const ValueReflection& Value() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return Json().Value();
  }

  const ListValueReflection& ListValue() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return Json().ListValue();
  }

  const StructReflection& Struct() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return Json().Struct();
  }

  const FieldMaskReflection& FieldMask() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return field_mask_;
  }

 private:
  // Accessors for `null_value_`; unlike the other helpers these are private.
  NullValueReflection& NullValue() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return null_value_;
  }

  const NullValueReflection& NullValue() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return null_value_;
  }

  NullValueReflection null_value_;
  BoolValueReflection bool_value_;
  Int32ValueReflection int32_value_;
  Int64ValueReflection int64_value_;
  UInt32ValueReflection uint32_value_;
  UInt64ValueReflection uint64_value_;
  FloatValueReflection float_value_;
  DoubleValueReflection double_value_;
  BytesValueReflection bytes_value_;
  StringValueReflection string_value_;
  AnyReflection any_;
  DurationReflection duration_;
  TimestampReflection timestamp_;
  JsonReflection json_;
  FieldMaskReflection field_mask_;
};

}  // namespace cel::well_known_types

#endif // THIRD_PARTY_CEL_CPP_INTERNAL_WELL_KNOWN_TYPES_H_ ================================================ FILE:
internal/well_known_types_test.cc ================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "internal/well_known_types.h"

// NOTE(review): the header names of the next three directives are missing in
// this copy; restore them from the upstream file.
#include
#include
#include

#include "google/protobuf/any.pb.h"
#include "google/protobuf/duration.pb.h"
#include "google/protobuf/field_mask.pb.h"
#include "google/protobuf/struct.pb.h"
#include "google/protobuf/timestamp.pb.h"
#include "google/protobuf/wrappers.pb.h"
#include "google/protobuf/descriptor.pb.h"
#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/log/die_if_null.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/variant.h"
#include "common/memory.h"
#include "internal/message_type_name.h"
#include "internal/minimal_descriptor_pool.h"
#include "internal/parse_text_proto.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "internal/testing_message_factory.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::well_known_types {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::cel::internal::GetMinimalDescriptorPool;
using ::cel::internal::GetTestingDescriptorPool;
using ::cel::internal::GetTestingMessageFactory;
using ::testing::_;
using ::testing::HasSubstr;
using ::testing::IsNull;
using ::testing::NotNull;
using ::testing::Test;
using ::testing::VariantWith;

using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes;

// Fixture providing an arena, a scratch string, the testing descriptor pool
// and message factory, and helpers that create well known type messages either
// as generated C++ messages (`MakeGenerated`) or through the testing
// descriptor pool and message factory (`MakeDynamic`).
class ReflectionTest : public Test {
 public:
  // Arena owned by the fixture; both helpers below allocate on it.
  google::protobuf::Arena* absl_nonnull arena() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return &arena_;
  }

  std::string& scratch_space() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return scratch_space_;
  }

  const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() {
    return GetTestingDescriptorPool();
  }

  google::protobuf::MessageFactory* absl_nonnull message_factory() {
    return GetTestingMessageFactory();
  }

  // Creates a generated message on `arena()`.
  template T* absl_nonnull MakeGenerated() {
    return google::protobuf::Arena::Create(arena());
  }

  // Creates a message through the testing pool and factory, as opposed to
  // `MakeGenerated()`. The type is looked up by the name produced by
  // `internal::MessageTypeNameFor` (presumably the template argument's full
  // name); dies if either the descriptor or the prototype lookup returns null.
  // The new message is allocated on `arena()`.
  template google::protobuf::Message* absl_nonnull MakeDynamic() {
    const auto* descriptor =
        ABSL_DIE_IF_NULL(descriptor_pool()->FindMessageTypeByName(
            internal::MessageTypeNameFor()));
    const auto* prototype =
        ABSL_DIE_IF_NULL(message_factory()->GetPrototype(descriptor));
    return prototype->New(arena());
  }

 private:
  google::protobuf::Arena arena_;
  std::string scratch_space_;
};

// `Reflection` must initialize successfully against both the minimal and the
// testing descriptor pools.
TEST_F(ReflectionTest, MinimalDescriptorPool) {
  EXPECT_THAT(Reflection().Initialize(GetMinimalDescriptorPool()), IsOk());
}

TEST_F(ReflectionTest, TestingDescriptorPool) {
  EXPECT_THAT(Reflection().Initialize(GetTestingDescriptorPool()), IsOk());
}

// Each `*_Generated` test exercises the static accessors of a reflection class
// on a generated message. Each `*_Dynamic` test builds a reflection from the
// message's own descriptor and exercises the instance accessors. Both check
// the default value, then a set/get round trip.
TEST_F(ReflectionTest, BoolValue_Generated) {
  auto* value = MakeGenerated();
  EXPECT_EQ(BoolValueReflection::GetValue(*value), false);
  BoolValueReflection::SetValue(value, true);
  EXPECT_EQ(BoolValueReflection::GetValue(*value), true);
}

TEST_F(ReflectionTest, BoolValue_Dynamic) {
  auto* value = MakeDynamic();
  ASSERT_OK_AND_ASSIGN(
      auto reflection,
      GetBoolValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor())));
  EXPECT_EQ(reflection.GetValue(*value), false);
  reflection.SetValue(value, true);
  EXPECT_EQ(reflection.GetValue(*value), true);
}

TEST_F(ReflectionTest, Int32Value_Generated) {
  auto* value = MakeGenerated();
  EXPECT_EQ(Int32ValueReflection::GetValue(*value), 0);
  Int32ValueReflection::SetValue(value, 1);
  EXPECT_EQ(Int32ValueReflection::GetValue(*value), 1);
}

TEST_F(ReflectionTest, Int32Value_Dynamic) {
  auto* value = MakeDynamic();
  ASSERT_OK_AND_ASSIGN(
      auto reflection,
      GetInt32ValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor())));
  EXPECT_EQ(reflection.GetValue(*value), 0);
  reflection.SetValue(value, 1);
  EXPECT_EQ(reflection.GetValue(*value), 1);
}

TEST_F(ReflectionTest, Int64Value_Generated) {
  auto* value = MakeGenerated();
  EXPECT_EQ(Int64ValueReflection::GetValue(*value), 0);
  Int64ValueReflection::SetValue(value, 1);
  EXPECT_EQ(Int64ValueReflection::GetValue(*value), 1);
}

TEST_F(ReflectionTest, Int64Value_Dynamic) {
  auto* value = MakeDynamic();
  ASSERT_OK_AND_ASSIGN(
      auto reflection,
      GetInt64ValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor())));
  EXPECT_EQ(reflection.GetValue(*value), 0);
  reflection.SetValue(value, 1);
  EXPECT_EQ(reflection.GetValue(*value), 1);
}

TEST_F(ReflectionTest, UInt32Value_Generated) {
  auto* value = MakeGenerated();
  EXPECT_EQ(UInt32ValueReflection::GetValue(*value), 0);
  UInt32ValueReflection::SetValue(value, 1);
  EXPECT_EQ(UInt32ValueReflection::GetValue(*value), 1);
}

TEST_F(ReflectionTest, UInt32Value_Dynamic) {
  auto* value = MakeDynamic();
  ASSERT_OK_AND_ASSIGN(
      auto reflection,
      GetUInt32ValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor())));
  EXPECT_EQ(reflection.GetValue(*value), 0);
  reflection.SetValue(value, 1);
  EXPECT_EQ(reflection.GetValue(*value), 1);
}

TEST_F(ReflectionTest, UInt64Value_Generated) {
  auto* value = MakeGenerated();
  EXPECT_EQ(UInt64ValueReflection::GetValue(*value), 0);
  UInt64ValueReflection::SetValue(value, 1);
  EXPECT_EQ(UInt64ValueReflection::GetValue(*value), 1);
}

TEST_F(ReflectionTest, UInt64Value_Dynamic) {
  auto* value = MakeDynamic();
  ASSERT_OK_AND_ASSIGN(
      auto reflection,
      GetUInt64ValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor())));
  EXPECT_EQ(reflection.GetValue(*value), 0);
  reflection.SetValue(value, 1);
  EXPECT_EQ(reflection.GetValue(*value), 1);
}

TEST_F(ReflectionTest, FloatValue_Generated) {
  auto* value = MakeGenerated();
  EXPECT_EQ(FloatValueReflection::GetValue(*value), 0);
  FloatValueReflection::SetValue(value, 1);
  EXPECT_EQ(FloatValueReflection::GetValue(*value), 1);
}

TEST_F(ReflectionTest, FloatValue_Dynamic) {
  auto* value = MakeDynamic();
  ASSERT_OK_AND_ASSIGN(
      auto reflection,
      GetFloatValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor())));
  EXPECT_EQ(reflection.GetValue(*value), 0);
  reflection.SetValue(value, 1);
  EXPECT_EQ(reflection.GetValue(*value), 1);
}

TEST_F(ReflectionTest, DoubleValue_Generated) {
  auto* value = MakeGenerated();
  EXPECT_EQ(DoubleValueReflection::GetValue(*value), 0);
  DoubleValueReflection::SetValue(value, 1);
  EXPECT_EQ(DoubleValueReflection::GetValue(*value), 1);
}

TEST_F(ReflectionTest, DoubleValue_Dynamic) {
  auto* value = MakeDynamic();
  ASSERT_OK_AND_ASSIGN(
      auto reflection,
      GetDoubleValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor())));
  EXPECT_EQ(reflection.GetValue(*value), 0);
  reflection.SetValue(value, 1);
  EXPECT_EQ(reflection.GetValue(*value), 1);
}

TEST_F(ReflectionTest, BytesValue_Generated) {
  auto* value = MakeGenerated();
  EXPECT_EQ(BytesValueReflection::GetValue(*value), "");
  BytesValueReflection::SetValue(value, absl::Cord("Hello World!"));
  EXPECT_EQ(BytesValueReflection::GetValue(*value), "Hello World!");
}

// The instance `GetValue` takes a scratch buffer that may back the returned
// value. A local string is used here rather than the fixture's
// `scratch_space()`. The test also covers setting from both a string literal
// and an empty `absl::Cord`.
TEST_F(ReflectionTest, BytesValue_Dynamic) {
  auto* value = MakeDynamic();
  std::string scratch;
  ASSERT_OK_AND_ASSIGN(
      auto reflection,
      GetBytesValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor())));
  EXPECT_EQ(reflection.GetValue(*value, scratch), "");
  reflection.SetValue(value, "Hello World!");
  EXPECT_EQ(reflection.GetValue(*value, scratch), "Hello World!");
  reflection.SetValue(value, absl::Cord());
  EXPECT_EQ(reflection.GetValue(*value, scratch), "");
}
TEST_F(ReflectionTest, StringValue_Generated) { auto* value = MakeGenerated(); EXPECT_EQ(StringValueReflection::GetValue(*value), ""); StringValueReflection::SetValue(value, "Hello World!"); EXPECT_EQ(StringValueReflection::GetValue(*value), "Hello World!"); } TEST_F(ReflectionTest, StringValue_Dynamic) { auto* value = MakeDynamic(); std::string scratch; ASSERT_OK_AND_ASSIGN( auto reflection, GetStringValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); EXPECT_EQ(reflection.GetValue(*value, scratch), ""); reflection.SetValue(value, "Hello World!"); EXPECT_EQ(reflection.GetValue(*value, scratch), "Hello World!"); reflection.SetValue(value, absl::Cord()); EXPECT_EQ(reflection.GetValue(*value, scratch), ""); } TEST_F(ReflectionTest, Any_Generated) { auto* value = MakeGenerated(); EXPECT_EQ(AnyReflection::GetTypeUrl(*value), ""); AnyReflection::SetTypeUrl(value, "Hello World!"); EXPECT_EQ(AnyReflection::GetTypeUrl(*value), "Hello World!"); EXPECT_EQ(AnyReflection::GetValue(*value), ""); AnyReflection::SetValue(value, absl::Cord("Hello World!")); EXPECT_EQ(AnyReflection::GetValue(*value), "Hello World!"); } TEST_F(ReflectionTest, Any_Dynamic) { auto* value = MakeDynamic(); std::string scratch; ASSERT_OK_AND_ASSIGN( auto reflection, GetAnyReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); EXPECT_EQ(reflection.GetTypeUrl(*value, scratch), ""); reflection.SetTypeUrl(value, "Hello World!"); EXPECT_EQ(reflection.GetTypeUrl(*value, scratch), "Hello World!"); EXPECT_EQ(reflection.GetValue(*value, scratch), ""); reflection.SetValue(value, absl::Cord("Hello World!")); EXPECT_EQ(reflection.GetValue(*value, scratch), "Hello World!"); } TEST_F(ReflectionTest, Duration_Generated) { auto* value = MakeGenerated(); EXPECT_EQ(DurationReflection::GetSeconds(*value), 0); DurationReflection::SetSeconds(value, 1); EXPECT_EQ(DurationReflection::GetSeconds(*value), 1); EXPECT_EQ(DurationReflection::GetNanos(*value), 0); DurationReflection::SetNanos(value, 1); 
EXPECT_EQ(DurationReflection::GetNanos(*value), 1); EXPECT_THAT(DurationReflection::SetFromAbslDuration( value, absl::Seconds(1) + absl::Nanoseconds(1)), IsOk()); EXPECT_EQ(value->seconds(), 1); EXPECT_EQ(value->nanos(), 1); EXPECT_THAT( DurationReflection::SetFromAbslDuration(value, absl::InfiniteDuration()), StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT( DurationReflection::SetFromAbslDuration(value, -absl::InfiniteDuration()), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_F(ReflectionTest, Duration_Dynamic) { auto* value = MakeDynamic(); ASSERT_OK_AND_ASSIGN( auto reflection, GetDurationReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); EXPECT_EQ(reflection.GetSeconds(*value), 0); reflection.SetSeconds(value, 1); EXPECT_EQ(reflection.GetSeconds(*value), 1); EXPECT_EQ(reflection.GetNanos(*value), 0); reflection.SetNanos(value, 1); EXPECT_EQ(reflection.GetNanos(*value), 1); EXPECT_THAT(reflection.SetFromAbslDuration( value, absl::Seconds(1) + absl::Nanoseconds(1)), IsOk()); EXPECT_EQ(reflection.GetSeconds(*value), 1); EXPECT_EQ(reflection.GetNanos(*value), 1); EXPECT_THAT(reflection.SetFromAbslDuration(value, absl::InfiniteDuration()), StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT(reflection.SetFromAbslDuration(value, -absl::InfiniteDuration()), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_F(ReflectionTest, Timestamp_Generated) { auto* value = MakeGenerated(); EXPECT_EQ(TimestampReflection::GetSeconds(*value), 0); TimestampReflection::SetSeconds(value, 1); EXPECT_EQ(TimestampReflection::GetSeconds(*value), 1); EXPECT_EQ(TimestampReflection::GetNanos(*value), 0); TimestampReflection::SetNanos(value, 1); EXPECT_EQ(TimestampReflection::GetNanos(*value), 1); EXPECT_THAT( TimestampReflection::SetFromAbslTime( value, absl::UnixEpoch() + absl::Seconds(1) + absl::Nanoseconds(1)), IsOk()); EXPECT_EQ(value->seconds(), 1); EXPECT_EQ(value->nanos(), 1); EXPECT_THAT( TimestampReflection::SetFromAbslTime(value, 
absl::InfiniteFuture()), StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT(TimestampReflection::SetFromAbslTime(value, absl::InfinitePast()), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_F(ReflectionTest, Timestamp_Dynamic) { auto* value = MakeDynamic(); ASSERT_OK_AND_ASSIGN( auto reflection, GetTimestampReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); EXPECT_EQ(reflection.GetSeconds(*value), 0); reflection.SetSeconds(value, 1); EXPECT_EQ(reflection.GetSeconds(*value), 1); EXPECT_EQ(reflection.GetNanos(*value), 0); reflection.SetNanos(value, 1); EXPECT_EQ(reflection.GetNanos(*value), 1); EXPECT_THAT( reflection.SetFromAbslTime( value, absl::UnixEpoch() + absl::Seconds(1) + absl::Nanoseconds(1)), IsOk()); EXPECT_EQ(reflection.GetSeconds(*value), 1); EXPECT_EQ(reflection.GetNanos(*value), 1); EXPECT_THAT(reflection.SetFromAbslTime(value, absl::InfiniteFuture()), StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT(reflection.SetFromAbslTime(value, absl::InfinitePast()), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_F(ReflectionTest, Value_Generated) { auto* value = MakeGenerated(); EXPECT_EQ(ValueReflection::GetKindCase(*value), google::protobuf::Value::KIND_NOT_SET); ValueReflection::SetNullValue(value); EXPECT_EQ(ValueReflection::GetKindCase(*value), google::protobuf::Value::kNullValue); ValueReflection::SetBoolValue(value, true); EXPECT_EQ(ValueReflection::GetKindCase(*value), google::protobuf::Value::kBoolValue); EXPECT_EQ(ValueReflection::GetBoolValue(*value), true); ValueReflection::SetNumberValue(value, 1.0); EXPECT_EQ(ValueReflection::GetKindCase(*value), google::protobuf::Value::kNumberValue); EXPECT_EQ(ValueReflection::GetNumberValue(*value), 1.0); ValueReflection::SetStringValue(value, "Hello World!"); EXPECT_EQ(ValueReflection::GetKindCase(*value), google::protobuf::Value::kStringValue); EXPECT_EQ(ValueReflection::GetStringValue(*value), "Hello World!"); ValueReflection::MutableListValue(value); 
EXPECT_EQ(ValueReflection::GetKindCase(*value), google::protobuf::Value::kListValue); EXPECT_EQ(ValueReflection::GetListValue(*value).ByteSizeLong(), 0); ValueReflection::MutableStructValue(value); EXPECT_EQ(ValueReflection::GetKindCase(*value), google::protobuf::Value::kStructValue); EXPECT_EQ(ValueReflection::GetStructValue(*value).ByteSizeLong(), 0); } TEST_F(ReflectionTest, Value_Dynamic) { auto* value = MakeDynamic(); std::string scratch; ASSERT_OK_AND_ASSIGN( auto reflection, GetValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); EXPECT_EQ(reflection.GetKindCase(*value), google::protobuf::Value::KIND_NOT_SET); reflection.SetNullValue(value); EXPECT_EQ(reflection.GetKindCase(*value), google::protobuf::Value::kNullValue); reflection.SetBoolValue(value, true); EXPECT_EQ(reflection.GetKindCase(*value), google::protobuf::Value::kBoolValue); EXPECT_EQ(reflection.GetBoolValue(*value), true); reflection.SetNumberValue(value, 1.0); EXPECT_EQ(reflection.GetKindCase(*value), google::protobuf::Value::kNumberValue); EXPECT_EQ(reflection.GetNumberValue(*value), 1.0); reflection.SetStringValue(value, "Hello World!"); EXPECT_EQ(reflection.GetKindCase(*value), google::protobuf::Value::kStringValue); EXPECT_EQ(reflection.GetStringValue(*value, scratch), "Hello World!"); reflection.MutableListValue(value); EXPECT_EQ(reflection.GetKindCase(*value), google::protobuf::Value::kListValue); EXPECT_EQ(reflection.GetListValue(*value).ByteSizeLong(), 0); EXPECT_THAT(reflection.ReleaseListValue(value), NotNull()); reflection.MutableStructValue(value); EXPECT_EQ(reflection.GetKindCase(*value), google::protobuf::Value::kStructValue); EXPECT_EQ(reflection.GetStructValue(*value).ByteSizeLong(), 0); EXPECT_THAT(reflection.ReleaseStructValue(value), NotNull()); } TEST_F(ReflectionTest, ListValue_Generated) { auto* value = MakeGenerated(); EXPECT_EQ(ListValueReflection::ValuesSize(*value), 0); EXPECT_EQ(ListValueReflection::Values(*value).size(), 0); 
EXPECT_EQ(ListValueReflection::MutableValues(value).size(), 0); } TEST_F(ReflectionTest, ListValue_Dynamic) { auto* value = MakeDynamic(); ASSERT_OK_AND_ASSIGN( auto reflection, GetListValueReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); EXPECT_EQ(reflection.ValuesSize(*value), 0); EXPECT_EQ(reflection.Values(*value).size(), 0); EXPECT_EQ(reflection.MutableValues(value).size(), 0); } TEST_F(ReflectionTest, StructValue_Generated) { auto* value = MakeGenerated(); EXPECT_EQ(StructReflection::FieldsSize(*value), 0); EXPECT_EQ(StructReflection::BeginFields(*value), StructReflection::EndFields(*value)); EXPECT_FALSE(StructReflection::ContainsField(*value, "foo")); EXPECT_THAT(StructReflection::FindField(*value, "foo"), IsNull()); EXPECT_THAT(StructReflection::InsertField(value, "foo"), NotNull()); EXPECT_TRUE(StructReflection::DeleteField(value, "foo")); } TEST_F(ReflectionTest, StructValue_Dynamic) { auto* value = MakeDynamic(); ASSERT_OK_AND_ASSIGN( auto reflection, GetStructReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); EXPECT_EQ(reflection.FieldsSize(*value), 0); EXPECT_EQ(reflection.BeginFields(*value), reflection.EndFields(*value)); EXPECT_FALSE(reflection.ContainsField(*value, "foo")); EXPECT_THAT(reflection.FindField(*value, "foo"), IsNull()); EXPECT_THAT(reflection.InsertField(value, "foo"), NotNull()); EXPECT_TRUE(reflection.DeleteField(value, "foo")); } TEST_F(ReflectionTest, FieldMask_Generated) { auto* value = MakeGenerated(); EXPECT_EQ(FieldMaskReflection::PathsSize(*value), 0); value->add_paths("foo"); EXPECT_EQ(FieldMaskReflection::PathsSize(*value), 1); EXPECT_EQ(FieldMaskReflection::Paths(*value, 0), "foo"); } TEST_F(ReflectionTest, FieldMask_Dynamic) { auto* value = MakeDynamic(); ASSERT_OK_AND_ASSIGN( auto reflection, GetFieldMaskReflection(ABSL_DIE_IF_NULL(value->GetDescriptor()))); EXPECT_EQ(reflection.PathsSize(*value), 0); value->GetReflection()->AddString( &*value, ABSL_DIE_IF_NULL(value->GetDescriptor()->FindFieldByName("paths")), 
"foo"); EXPECT_EQ(reflection.PathsSize(*value), 1); EXPECT_EQ(reflection.Paths(*value, 0, scratch_space()), "foo"); } TEST_F(ReflectionTest, NullValue_MissingValue) { google::protobuf::DescriptorPool descriptor_pool; { google::protobuf::FileDescriptorProto file_proto; file_proto.set_name("google/protobuf/struct.proto"); file_proto.set_syntax("editions"); file_proto.set_edition(google::protobuf::EDITION_2023); file_proto.set_package("google.protobuf"); auto* enum_proto = file_proto.add_enum_type(); enum_proto->set_name("NullValue"); auto* value_proto = enum_proto->add_value(); value_proto->set_number(1); value_proto->set_name("NULL_VALUE"); enum_proto->mutable_options()->mutable_features()->set_enum_type( google::protobuf::FeatureSet::CLOSED); ASSERT_THAT(descriptor_pool.BuildFile(file_proto), NotNull()); } EXPECT_THAT( NullValueReflection().Initialize(&descriptor_pool), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("well known protocol buffer enum missing value: "))); } TEST_F(ReflectionTest, NullValue_MultipleValues) { google::protobuf::DescriptorPool descriptor_pool; { google::protobuf::FileDescriptorProto file_proto; file_proto.set_name("google/protobuf/struct.proto"); file_proto.set_syntax("proto3"); file_proto.set_package("google.protobuf"); auto* enum_proto = file_proto.add_enum_type(); enum_proto->set_name("NullValue"); auto* value_proto = enum_proto->add_value(); value_proto->set_number(0); value_proto->set_name("NULL_VALUE"); value_proto = enum_proto->add_value(); value_proto->set_number(1); value_proto->set_name("NULL_VALUE2"); ASSERT_THAT(descriptor_pool.BuildFile(file_proto), NotNull()); } EXPECT_THAT( NullValueReflection().Initialize(&descriptor_pool), StatusIs( absl::StatusCode::kInvalidArgument, HasSubstr("well known protocol buffer enum has multiple values: "))); } TEST_F(ReflectionTest, EnumDescriptorMissing) { google::protobuf::DescriptorPool descriptor_pool; EXPECT_THAT(NullValueReflection().Initialize(&descriptor_pool), 
StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("descriptor missing for protocol buffer enum " "well known type: "))); } TEST_F(ReflectionTest, MessageDescriptorMissing) { google::protobuf::DescriptorPool descriptor_pool; EXPECT_THAT(BoolValueReflection().Initialize(&descriptor_pool), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("descriptor missing for protocol buffer " "message well known type: "))); } class AdaptFromMessageTest : public Test { public: google::protobuf::Arena* absl_nonnull arena() ABSL_ATTRIBUTE_LIFETIME_BOUND { return &arena_; } std::string& scratch_space() ABSL_ATTRIBUTE_LIFETIME_BOUND { return scratch_space_; } const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() { return GetTestingDescriptorPool(); } google::protobuf::MessageFactory* absl_nonnull message_factory() { return GetTestingMessageFactory(); } template google::protobuf::Message* absl_nonnull MakeDynamic() { const auto* descriptor_pool = GetTestingDescriptorPool(); const auto* descriptor = ABSL_DIE_IF_NULL(descriptor_pool->FindMessageTypeByName( internal::MessageTypeNameFor())); const auto* prototype = ABSL_DIE_IF_NULL(GetTestingMessageFactory()->GetPrototype(descriptor)); return prototype->New(arena()); } template google::protobuf::Message* DynamicParseTextProto(absl::string_view text) { return ::cel::internal::DynamicParseTextProto( arena(), text, descriptor_pool(), message_factory()); } absl::StatusOr AdaptFromMessage(const google::protobuf::Message& message) { return well_known_types::AdaptFromMessage( arena(), message, descriptor_pool(), message_factory(), scratch_space()); } private: google::protobuf::Arena arena_; std::string scratch_space_; }; TEST_F(AdaptFromMessageTest, BoolValue) { auto message = DynamicParseTextProto(R"pb(value: true)pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(true))); } TEST_F(AdaptFromMessageTest, Int32Value) { auto message = DynamicParseTextProto(R"pb(value: 1)pb"); 
EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(1))); } TEST_F(AdaptFromMessageTest, Int64Value) { auto message = DynamicParseTextProto(R"pb(value: 1)pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(1))); } TEST_F(AdaptFromMessageTest, UInt32Value) { auto message = DynamicParseTextProto(R"pb(value: 1)pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(1))); } TEST_F(AdaptFromMessageTest, UInt64Value) { auto message = DynamicParseTextProto(R"pb(value: 1)pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(1))); } TEST_F(AdaptFromMessageTest, FloatValue) { auto message = DynamicParseTextProto(R"pb(value: 1.0)pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(1))); } TEST_F(AdaptFromMessageTest, DoubleValue) { auto message = DynamicParseTextProto(R"pb(value: 1.0)pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(1))); } TEST_F(AdaptFromMessageTest, BytesValue) { auto message = DynamicParseTextProto( R"pb(value: "foo")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(BytesValue("foo")))); } TEST_F(AdaptFromMessageTest, StringValue) { auto message = DynamicParseTextProto( R"pb(value: "foo")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(StringValue("foo")))); } TEST_F(AdaptFromMessageTest, Duration) { auto message = DynamicParseTextProto( R"pb(seconds: 1 nanos: 1)pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(absl::Seconds(1) + absl::Nanoseconds(1)))); } TEST_F(AdaptFromMessageTest, Duration_SecondsOutOfRange) { auto message = DynamicParseTextProto( R"pb(seconds: 0x7fffffffffffffff nanos: 1)pb"); EXPECT_THAT(AdaptFromMessage(*message), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("invalid duration seconds: "))); } TEST_F(AdaptFromMessageTest, Duration_NanosOutOfRange) { auto message = DynamicParseTextProto( R"pb(seconds: 1 nanos: 0x7fffffff)pb"); 
EXPECT_THAT(AdaptFromMessage(*message), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("invalid duration nanoseconds: "))); } TEST_F(AdaptFromMessageTest, Duration_SignMismatch) { auto message = DynamicParseTextProto(R"pb(seconds: -1 nanos: 1)pb"); EXPECT_THAT(AdaptFromMessage(*message), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("duration sign mismatch: "))); } TEST_F(AdaptFromMessageTest, Timestamp) { auto message = DynamicParseTextProto(R"pb(seconds: 1 nanos: 1)pb"); EXPECT_THAT( AdaptFromMessage(*message), IsOkAndHolds(VariantWith( absl::UnixEpoch() + absl::Seconds(1) + absl::Nanoseconds(1)))); } TEST_F(AdaptFromMessageTest, Timestamp_SecondsOutOfRange) { auto message = DynamicParseTextProto( R"pb(seconds: 0x7fffffffffffffff nanos: 1)pb"); EXPECT_THAT(AdaptFromMessage(*message), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("invalid timestamp seconds: "))); } TEST_F(AdaptFromMessageTest, Timestamp_NanosOutOfRange) { auto message = DynamicParseTextProto( R"pb(seconds: 1 nanos: 0x7fffffff)pb"); EXPECT_THAT(AdaptFromMessage(*message), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("invalid timestamp nanoseconds: "))); } TEST_F(AdaptFromMessageTest, Value_NullValue) { auto message = DynamicParseTextProto( R"pb(null_value: NULL_VALUE)pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(nullptr))); } TEST_F(AdaptFromMessageTest, Value_BoolValue) { auto message = DynamicParseTextProto(R"pb(bool_value: true)pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(true))); } TEST_F(AdaptFromMessageTest, Value_NumberValue) { auto message = DynamicParseTextProto( R"pb(number_value: 1.0)pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(1.0))); } TEST_F(AdaptFromMessageTest, Value_StringValue) { auto message = DynamicParseTextProto( R"pb(string_value: "foo")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(StringValue("foo")))); } TEST_F(AdaptFromMessageTest, 
Value_ListValue) { auto message = DynamicParseTextProto(R"pb(list_value: {})pb"); EXPECT_THAT( AdaptFromMessage(*message), IsOkAndHolds(VariantWith(VariantWith(_)))); } TEST_F(AdaptFromMessageTest, Value_StructValue) { auto message = DynamicParseTextProto(R"pb(struct_value: {})pb"); EXPECT_THAT( AdaptFromMessage(*message), IsOkAndHolds(VariantWith(VariantWith(_)))); } TEST_F(AdaptFromMessageTest, ListValue) { auto message = DynamicParseTextProto(R"pb()pb"); EXPECT_THAT( AdaptFromMessage(*message), IsOkAndHolds(VariantWith(VariantWith(_)))); } TEST_F(AdaptFromMessageTest, Struct) { auto message = DynamicParseTextProto(R"pb()pb"); EXPECT_THAT( AdaptFromMessage(*message), IsOkAndHolds(VariantWith(VariantWith(_)))); } TEST_F(AdaptFromMessageTest, TestAllTypesProto3) { auto message = DynamicParseTextProto(R"pb()pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(absl::monostate()))); } TEST_F(AdaptFromMessageTest, Any_BoolValue) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.BoolValue")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(false))); } TEST_F(AdaptFromMessageTest, Any_Int32Value) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.Int32Value")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(0))); } TEST_F(AdaptFromMessageTest, Any_Int64Value) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.Int64Value")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(0))); } TEST_F(AdaptFromMessageTest, Any_UInt32Value) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.UInt32Value")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(0))); } TEST_F(AdaptFromMessageTest, Any_UInt64Value) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.UInt64Value")pb"); 
EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(0))); } TEST_F(AdaptFromMessageTest, Any_FloatValue) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.FloatValue")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(0))); } TEST_F(AdaptFromMessageTest, Any_DoubleValue) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.DoubleValue")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(0))); } TEST_F(AdaptFromMessageTest, Any_BytesValue) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.BytesValue")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(BytesValue()))); } TEST_F(AdaptFromMessageTest, Any_StringValue) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.StringValue")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(StringValue()))); } TEST_F(AdaptFromMessageTest, Any_Duration) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.Duration")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(absl::ZeroDuration()))); } TEST_F(AdaptFromMessageTest, Any_Timestamp) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.Timestamp")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(absl::UnixEpoch()))); } TEST_F(AdaptFromMessageTest, Any_Value_NullValue) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.Value")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(nullptr))); } TEST_F(AdaptFromMessageTest, Any_Value_BoolValue) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.Value" value: "\x20\x01")pb"); // bool_value: true EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(true))); } 
TEST_F(AdaptFromMessageTest, Any_Value_NumberValue) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.Value" value: "\x11\x00\x00\x00\x00\x00\x00\x00\x00")pb"); // number_value: // 1.0 EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(0.0))); } TEST_F(AdaptFromMessageTest, Any_Value_StringValue) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.Value" value: "\x1a\x03\x66\x6f\x6f")pb"); // string_value: "foo" EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith(StringValue("foo")))); } TEST_F(AdaptFromMessageTest, Any_Value_ListValue) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.Value" value: "\x32\x00")pb"); // list_value: {} EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith( VariantWith(NotNull())))); } TEST_F(AdaptFromMessageTest, Any_Value_StructValue) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.Value" value: "\x2a\x00")pb"); // struct_value: {} EXPECT_THAT( AdaptFromMessage(*message), IsOkAndHolds(VariantWith(VariantWith(NotNull())))); } TEST_F(AdaptFromMessageTest, Any_ListValue) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.ListValue")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith( VariantWith(NotNull())))); } TEST_F(AdaptFromMessageTest, Any_Struct) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/google.protobuf.Struct")pb"); EXPECT_THAT( AdaptFromMessage(*message), IsOkAndHolds(VariantWith(VariantWith(NotNull())))); } TEST_F(AdaptFromMessageTest, Any_TestAllTypesProto3) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes")pb"); EXPECT_THAT(AdaptFromMessage(*message), IsOkAndHolds(VariantWith>(NotNull()))); } TEST_F(AdaptFromMessageTest, Any_BadTypeUrlDomain) { auto message = 
DynamicParseTextProto( R"pb(type_url: "type.example.com/google.protobuf.BoolValue")pb"); EXPECT_THAT(AdaptFromMessage(*message), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("unable to find descriptor for type URL: "))); } TEST_F(AdaptFromMessageTest, Any_UnknownMessage) { auto message = DynamicParseTextProto( R"pb(type_url: "type.googleapis.com/message.that.does.not.Exist")pb"); EXPECT_THAT(AdaptFromMessage(*message), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("unable to find descriptor for type name: "))); } } // namespace } // namespace cel::well_known_types ================================================ FILE: parser/BUILD ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") package(default_visibility = ["//visibility:public"]) licenses(["notice"]) cc_library( name = "parser", srcs = [ "parser.cc", ], hdrs = [ "parser.h", ], copts = [ "-fexceptions", ], defines = [ "ANTLR4CPP_STATIC", ], deps = [ ":macro", ":macro_expr_factory", ":macro_registry", ":options", ":parser_interface", ":source_factory", "//common:ast", "//common:constant", "//common:expr_factory", "//common:operators", "//common:source", "//common/ast:expr_proto", "//common/ast:source_info_proto", "//internal:lexis", "//internal:status_macros", "//internal:strings", "//internal:utf8", "//parser/internal:cel_cc_parser", "@antlr4-cpp-runtime", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:overload", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_library( name = "macro", srcs = [ "macro.cc", ], hdrs = [ "macro.h", ], deps = [ ":macro_expr_factory", "//common:expr", "//common:operators", "//internal:lexis", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) cc_library( name = "macro_registry", srcs = [ 
"macro_registry.cc", ], hdrs = [ "macro_registry.h", ], deps = [ ":macro", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) cc_test( name = "macro_registry_test", srcs = ["macro_registry_test.cc"], deps = [ ":macro", ":macro_registry", "//internal:testing", "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", ], ) cc_library( name = "macro_expr_factory", srcs = ["macro_expr_factory.cc"], hdrs = ["macro_expr_factory.h"], deps = [ "//common:constant", "//common:expr", "//common:expr_factory", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:overload", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", ], ) cc_test( name = "macro_expr_factory_test", srcs = ["macro_expr_factory_test.cc"], deps = [ ":macro_expr_factory", "//common:expr", "//common:expr_factory", "//internal:testing", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", ], ) cc_library( name = "source_factory", hdrs = [ "source_factory.h", ], ) cc_library( name = "options", hdrs = ["options.h"], deps = [ "//parser/internal:options", "@com_google_absl//absl/base:core_headers", ], ) cc_test( name = "parser_test", srcs = ["parser_test.cc"], deps = [ ":macro", ":options", ":parser", ":parser_interface", ":source_factory", "//common:constant", "//common:expr", "//common:source", "//internal:testing", "//testutil:expr_printer", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_test( name = 
"parser_benchmarks", srcs = ["parser_benchmarks.cc"], tags = ["benchmark"], deps = [ ":macro", ":options", ":parser", ":source_factory", "//common:constant", "//common:expr", "//common:source", "//internal:benchmark", "//internal:testing", "//testutil:expr_printer", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_library( name = "standard_macros", srcs = ["standard_macros.cc"], hdrs = ["standard_macros.h"], deps = [ ":macro", ":macro_registry", ":options", "//internal:status_macros", "@com_google_absl//absl/status", ], ) cc_library( name = "parser_interface", hdrs = ["parser_interface.h"], deps = [ ":macro", ":options", "//common:ast", "//common:source", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], ) cc_library( name = "parser_subset_factory", srcs = ["parser_subset_factory.cc"], hdrs = ["parser_subset_factory.h"], deps = [ ":macro", ":parser_interface", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", ], ) cc_test( name = "standard_macros_test", srcs = ["standard_macros_test.cc"], deps = [ ":macro_registry", ":options", ":parser", ":standard_macros", "//common:source", "//internal:testing", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", ], ) ================================================ FILE: parser/internal/BUILD ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. load("@rules_cc//cc:cc_library.bzl", "cc_library") load("//bazel:antlr.bzl", "antlr_cc_library") package(default_visibility = ["//visibility:public"]) licenses(["notice"]) cc_library( name = "options", hdrs = ["options.h"], ) antlr_cc_library( name = "cel", src = "Cel.g4", package = "cel_parser_internal", ) ================================================ FILE: parser/internal/Cel.g4 ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. grammar Cel; // Grammar Rules // ============= start : e=expr EOF ; expr : e=conditionalOr (op='?' e1=conditionalOr ':' e2=expr)? ; conditionalOr : e=conditionalAnd (ops+='||' e1+=conditionalAnd)* ; conditionalAnd : e=relation (ops+='&&' e1+=relation)* ; relation : calc | relation op=('<'|'<='|'>='|'>'|'=='|'!='|'in') relation ; calc : unary | calc op=('*'|'/'|'%') calc | calc op=('+'|'-') calc ; unary : member # MemberExpr | (ops+='!')+ member # LogicalNot | (ops+='-')+ member # Negate ; member : primary # PrimaryExpr | member op='.' (opt='?')? 
id=escapeIdent # Select | member op='.' id=IDENTIFIER open='(' args=exprList? ')' # MemberCall | member op='[' (opt='?')? index=expr ']' # Index ; primary : leadingDot='.'? id=IDENTIFIER # Ident | leadingDot='.'? id=IDENTIFIER (op='(' args=exprList? ')') # GlobalCall | '(' e=expr ')' # Nested | op='[' elems=listInit? ','? ']' # CreateList | op='{' entries=mapInitializerList? ','? '}' # CreateMap | leadingDot='.'? ids+=IDENTIFIER (ops+='.' ids+=IDENTIFIER)* op='{' entries=fieldInitializerList? ','? '}' # CreateMessage | literal # ConstantLiteral ; exprList : e+=expr (',' e+=expr)* ; listInit : elems+=optExpr (',' elems+=optExpr)* ; fieldInitializerList : fields+=optField cols+=':' values+=expr (',' fields+=optField cols+=':' values+=expr)* ; optField : (opt='?')? escapeIdent ; mapInitializerList : keys+=optExpr cols+=':' values+=expr (',' keys+=optExpr cols+=':' values+=expr)* ; escapeIdent : id=IDENTIFIER # SimpleIdentifier | id=ESC_IDENTIFIER # EscapedIdentifier ; optExpr : (opt='?')? e=expr ; literal : sign=MINUS? tok=NUM_INT # Int | tok=NUM_UINT # Uint | sign=MINUS? tok=NUM_FLOAT # Double | tok=STRING # String | tok=BYTES # Bytes | tok=CEL_TRUE # BoolTrue | tok=CEL_FALSE # BoolFalse | tok=NUL # Null ; // Lexer Rules // =========== EQUALS : '=='; NOT_EQUALS : '!='; IN: 'in'; LESS : '<'; LESS_EQUALS : '<='; GREATER_EQUALS : '>='; GREATER : '>'; LOGICAL_AND : '&&'; LOGICAL_OR : '||'; LBRACKET : '['; RPRACKET : ']'; LBRACE : '{'; RBRACE : '}'; LPAREN : '('; RPAREN : ')'; DOT : '.'; COMMA : ','; MINUS : '-'; EXCLAM : '!'; QUESTIONMARK : '?'; COLON : ':'; PLUS : '+'; STAR : '*'; SLASH : '/'; PERCENT : '%'; CEL_TRUE : 'true'; CEL_FALSE : 'false'; NUL : 'null'; fragment BACKSLASH : '\\'; fragment LETTER : 'A'..'Z' | 'a'..'z' ; fragment DIGIT : '0'..'9' ; fragment EXPONENT : ('e' | 'E') ( '+' | '-' )? 
DIGIT+ ; fragment HEXDIGIT : ('0'..'9'|'a'..'f'|'A'..'F') ; fragment RAW : 'r' | 'R'; fragment ESC_SEQ : ESC_CHAR_SEQ | ESC_BYTE_SEQ | ESC_UNI_SEQ | ESC_OCT_SEQ ; fragment ESC_CHAR_SEQ : BACKSLASH ('a'|'b'|'f'|'n'|'r'|'t'|'v'|'"'|'\''|'\\'|'?'|'`') ; fragment ESC_OCT_SEQ : BACKSLASH ('0'..'3') ('0'..'7') ('0'..'7') ; fragment ESC_BYTE_SEQ : BACKSLASH ( 'x' | 'X' ) HEXDIGIT HEXDIGIT ; fragment ESC_UNI_SEQ : BACKSLASH 'u' HEXDIGIT HEXDIGIT HEXDIGIT HEXDIGIT | BACKSLASH 'U' HEXDIGIT HEXDIGIT HEXDIGIT HEXDIGIT HEXDIGIT HEXDIGIT HEXDIGIT HEXDIGIT ; WHITESPACE : ( '\t' | ' ' | '\r' | '\n'| '\u000C' )+ -> channel(HIDDEN) ; COMMENT : '//' (~'\n')* -> channel(HIDDEN) ; NUM_FLOAT : ( DIGIT+ ('.' DIGIT+) EXPONENT? | DIGIT+ EXPONENT | '.' DIGIT+ EXPONENT? ) ; NUM_INT : ( DIGIT+ | '0x' HEXDIGIT+ ); NUM_UINT : DIGIT+ ( 'u' | 'U' ) | '0x' HEXDIGIT+ ( 'u' | 'U' ) ; STRING : '"' (ESC_SEQ | ~('\\'|'"'|'\n'|'\r'))* '"' | '\'' (ESC_SEQ | ~('\\'|'\''|'\n'|'\r'))* '\'' | '"""' (ESC_SEQ | ~('\\'))*? '"""' | '\'\'\'' (ESC_SEQ | ~('\\'))*? '\'\'\'' | RAW '"' ~('"'|'\n'|'\r')* '"' | RAW '\'' ~('\''|'\n'|'\r')* '\'' | RAW '"""' .*? '"""' | RAW '\'\'\'' .*? '\'\'\'' ; BYTES : ('b' | 'B') STRING; IDENTIFIER : (LETTER | '_') ( LETTER | DIGIT | '_')*; ESC_IDENTIFIER : '`' (LETTER | DIGIT | '_' | '.' | '-' | '/' | ' ')+ '`'; ================================================ FILE: parser/internal/options.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_PARSER_INTERNAL_OPTIONS_H_ #define THIRD_PARTY_CEL_CPP_PARSER_INTERNAL_OPTIONS_H_ namespace cel_parser_internal { inline constexpr int kDefaultErrorRecoveryLimit = 12; inline constexpr int kDefaultMaxRecursionDepth = 32; inline constexpr int kExpressionSizeCodepointLimit = 100'000; inline constexpr int kDefaultErrorRecoveryTokenLookaheadLimit = 512; inline constexpr bool kDefaultAddMacroCalls = false; } // namespace cel_parser_internal #endif // THIRD_PARTY_CEL_CPP_PARSER_INTERNAL_OPTIONS_H_ ================================================ FILE: parser/macro.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "parser/macro.h" #include #include #include #include #include #include #include "absl/base/no_destructor.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "common/expr.h" #include "common/operators.h" #include "internal/lexis.h" #include "parser/macro_expr_factory.h" namespace cel { namespace { using google::api::expr::common::CelOperator; inline MacroExpander ToMacroExpander(GlobalMacroExpander expander) { ABSL_DCHECK(expander); return [expander = std::move(expander)]( MacroExprFactory& factory, absl::optional> target, absl::Span arguments) -> absl::optional { ABSL_DCHECK(!target.has_value()); return (expander)(factory, arguments); }; } inline MacroExpander ToMacroExpander(ReceiverMacroExpander expander) { ABSL_DCHECK(expander); return [expander = std::move(expander)]( MacroExprFactory& factory, absl::optional> target, absl::Span arguments) -> absl::optional { ABSL_DCHECK(target.has_value()); return (expander)(factory, *target, arguments); }; } absl::optional ExpandHasMacro(MacroExprFactory& factory, absl::Span args) { if (args.size() != 1) { return factory.ReportError("has() requires 1 arguments"); } if (!args[0].has_select_expr() || args[0].select_expr().test_only()) { return factory.ReportErrorAt(args[0], "has() argument must be a field selection"); } return factory.NewPresenceTest( args[0].mutable_select_expr().release_operand(), args[0].mutable_select_expr().release_field()); } Macro MakeHasMacro() { auto macro_or_status = Macro::Global(CelOperator::HAS, 1, ExpandHasMacro); ABSL_CHECK_OK(macro_or_status); // Crash OK return std::move(*macro_or_status); } absl::optional ExpandAllMacro(MacroExprFactory& factory, Expr& target, absl::Span args) { if (args.size() != 2) { return factory.ReportError("all() requires 2 arguments"); } if (!args[0].has_ident_expr() || 
args[0].ident_expr().name().empty()) { return factory.ReportErrorAt( args[0], "all() variable name must be a simple identifier"); } if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("all() variable name cannot be ", kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(true); auto condition = factory.NewCall(CelOperator::NOT_STRICTLY_FALSE, factory.NewAccuIdent()); auto step = factory.NewCall(CelOperator::LOGICAL_AND, factory.NewAccuIdent(), std::move(args[1])); auto result = factory.NewAccuIdent(); return factory.NewComprehension(args[0].ident_expr().name(), std::move(target), factory.AccuVarName(), std::move(init), std::move(condition), std::move(step), std::move(result)); } Macro MakeAllMacro() { auto status_or_macro = Macro::Receiver(CelOperator::ALL, 2, ExpandAllMacro); ABSL_CHECK_OK(status_or_macro); // Crash OK return std::move(*status_or_macro); } absl::optional ExpandExistsMacro(MacroExprFactory& factory, Expr& target, absl::Span args) { if (args.size() != 2) { return factory.ReportError("exists() requires 2 arguments"); } if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { return factory.ReportErrorAt( args[0], "exists() variable name must be a simple identifier"); } if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("exists() variable name cannot be ", kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(false); auto condition = factory.NewCall( CelOperator::NOT_STRICTLY_FALSE, factory.NewCall(CelOperator::LOGICAL_NOT, factory.NewAccuIdent())); auto step = factory.NewCall(CelOperator::LOGICAL_OR, factory.NewAccuIdent(), std::move(args[1])); auto result = factory.NewAccuIdent(); return factory.NewComprehension(args[0].ident_expr().name(), std::move(target), factory.AccuVarName(), std::move(init), std::move(condition), std::move(step), 
std::move(result)); } Macro MakeExistsMacro() { auto status_or_macro = Macro::Receiver(CelOperator::EXISTS, 2, ExpandExistsMacro); ABSL_CHECK_OK(status_or_macro); // Crash OK return std::move(*status_or_macro); } absl::optional ExpandExistsOneMacro(MacroExprFactory& factory, Expr& target, absl::Span args) { if (args.size() != 2) { return factory.ReportError("exists_one() requires 2 arguments"); } if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { return factory.ReportErrorAt( args[0], "exists_one() variable name must be a simple identifier"); } if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("exists_one() variable name cannot be ", kDeprecatedAccumulatorVariableName)); } auto init = factory.NewIntConst(0); auto condition = factory.NewBoolConst(true); auto accu_ident = factory.NewAccuIdent(); auto const_1 = factory.NewIntConst(1); auto inc_step = factory.NewCall(CelOperator::ADD, std::move(accu_ident), std::move(const_1)); auto step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[1]), std::move(inc_step), factory.NewAccuIdent()); accu_ident = factory.NewAccuIdent(); auto result = factory.NewCall(CelOperator::EQUALS, std::move(accu_ident), factory.NewIntConst(1)); return factory.NewComprehension(args[0].ident_expr().name(), std::move(target), factory.AccuVarName(), std::move(init), std::move(condition), std::move(step), std::move(result)); } Macro MakeExistsOneMacro() { auto status_or_macro = Macro::Receiver(CelOperator::EXISTS_ONE, 2, ExpandExistsOneMacro); ABSL_CHECK_OK(status_or_macro); // Crash OK return std::move(*status_or_macro); } absl::optional ExpandMap2Macro(MacroExprFactory& factory, Expr& target, absl::Span args) { if (args.size() != 2) { return factory.ReportError("map() requires 2 arguments"); } if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { return factory.ReportErrorAt( args[0], "map() variable name must be a simple 
identifier"); } if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("map() variable name cannot be ", kDeprecatedAccumulatorVariableName)); } auto init = factory.NewList(); auto condition = factory.NewBoolConst(true); auto accu_ref = factory.NewAccuIdent(); auto accu_update = factory.NewList(factory.NewListElement(std::move(args[1]))); auto step = factory.NewCall(CelOperator::ADD, std::move(accu_ref), std::move(accu_update)); return factory.NewComprehension(args[0].ident_expr().name(), std::move(target), factory.AccuVarName(), std::move(init), std::move(condition), std::move(step), factory.NewAccuIdent()); } Macro MakeMap2Macro() { auto status_or_macro = Macro::Receiver(CelOperator::MAP, 2, ExpandMap2Macro); ABSL_CHECK_OK(status_or_macro); // Crash OK return std::move(*status_or_macro); } absl::optional ExpandMap3Macro(MacroExprFactory& factory, Expr& target, absl::Span args) { if (args.size() != 3) { return factory.ReportError("map() requires 3 arguments"); } if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { return factory.ReportErrorAt( args[0], "map() variable name must be a simple identifier"); } if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("map() variable name cannot be ", kDeprecatedAccumulatorVariableName)); } auto init = factory.NewList(); auto condition = factory.NewBoolConst(true); auto accu_ref = factory.NewAccuIdent(); auto accu_update = factory.NewList(factory.NewListElement(std::move(args[2]))); auto step = factory.NewCall(CelOperator::ADD, std::move(accu_ref), std::move(accu_update)); step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[1]), std::move(step), factory.NewAccuIdent()); return factory.NewComprehension(args[0].ident_expr().name(), std::move(target), factory.AccuVarName(), std::move(init), std::move(condition), std::move(step), factory.NewAccuIdent()); } 
Macro MakeMap3Macro() { auto status_or_macro = Macro::Receiver(CelOperator::MAP, 3, ExpandMap3Macro); ABSL_CHECK_OK(status_or_macro); // Crash OK return std::move(*status_or_macro); } absl::optional ExpandFilterMacro(MacroExprFactory& factory, Expr& target, absl::Span args) { if (args.size() != 2) { return factory.ReportError("filter() requires 2 arguments"); } if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { return factory.ReportErrorAt( args[0], "filter() variable name must be a simple identifier"); } if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("filter() variable name cannot be ", kDeprecatedAccumulatorVariableName)); } auto name = args[0].ident_expr().name(); auto init = factory.NewList(); auto condition = factory.NewBoolConst(true); auto accu_ref = factory.NewAccuIdent(); auto accu_update = factory.NewList(factory.NewListElement(std::move(args[0]))); auto step = factory.NewCall(CelOperator::ADD, std::move(accu_ref), std::move(accu_update)); step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[1]), std::move(step), factory.NewAccuIdent()); return factory.NewComprehension(std::move(name), std::move(target), factory.AccuVarName(), std::move(init), std::move(condition), std::move(step), factory.NewAccuIdent()); } Macro MakeFilterMacro() { auto status_or_macro = Macro::Receiver(CelOperator::FILTER, 2, ExpandFilterMacro); ABSL_CHECK_OK(status_or_macro); // Crash OK return std::move(*status_or_macro); } absl::optional ExpandOptMapMacro(MacroExprFactory& factory, Expr& target, absl::Span args) { if (args.size() != 2) { return factory.ReportError("optMap() requires 2 arguments"); } if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { return factory.ReportErrorAt( args[0], "optMap() variable name must be a simple identifier"); } if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], 
absl::StrCat("optMap() variable name cannot be ", kDeprecatedAccumulatorVariableName)); } auto var_name = args[0].ident_expr().name(); auto target_copy = factory.Copy(target); std::vector call_args; call_args.reserve(3); call_args.push_back(factory.NewMemberCall("hasValue", std::move(target))); auto iter_range = factory.NewList(); auto accu_init = factory.NewMemberCall("value", std::move(target_copy)); auto condition = factory.NewBoolConst(false); auto fold = factory.NewComprehension( "#unused", std::move(iter_range), std::move(var_name), std::move(accu_init), std::move(condition), std::move(args[0]), std::move(args[1])); call_args.push_back(factory.NewCall("optional.of", std::move(fold))); call_args.push_back(factory.NewCall("optional.none")); return factory.NewCall(CelOperator::CONDITIONAL, std::move(call_args)); } Macro MakeOptMapMacro() { auto status_or_macro = Macro::Receiver("optMap", 2, ExpandOptMapMacro); ABSL_CHECK_OK(status_or_macro); // Crash OK return std::move(*status_or_macro); } absl::optional ExpandOptFlatMapMacro(MacroExprFactory& factory, Expr& target, absl::Span args) { if (args.size() != 2) { return factory.ReportError("optFlatMap() requires 2 arguments"); } if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { return factory.ReportErrorAt( args[0], "optFlatMap() variable name must be a simple identifier"); } if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("optFlatMap() variable name cannot be ", kDeprecatedAccumulatorVariableName)); } auto var_name = args[0].ident_expr().name(); auto target_copy = factory.Copy(target); std::vector call_args; call_args.reserve(3); call_args.push_back(factory.NewMemberCall("hasValue", std::move(target))); auto iter_range = factory.NewList(); auto accu_init = factory.NewMemberCall("value", std::move(target_copy)); auto condition = factory.NewBoolConst(false); call_args.push_back(factory.NewComprehension( "#unused", 
std::move(iter_range), std::move(var_name), std::move(accu_init), std::move(condition), std::move(args[0]), std::move(args[1]))); call_args.push_back(factory.NewCall("optional.none")); return factory.NewCall(CelOperator::CONDITIONAL, std::move(call_args)); } Macro MakeOptFlatMapMacro() { auto status_or_macro = Macro::Receiver("optFlatMap", 2, ExpandOptFlatMapMacro); ABSL_CHECK_OK(status_or_macro); // Crash OK return std::move(*status_or_macro); } } // namespace absl::StatusOr Macro::Global(absl::string_view name, size_t argument_count, GlobalMacroExpander expander) { if (!expander) { return absl::InvalidArgumentError( absl::StrCat("macro expander for `", name, "` cannot be empty")); } return Make(name, argument_count, ToMacroExpander(std::move(expander)), /*receiver_style=*/false, /*var_arg_style=*/false); } absl::StatusOr Macro::GlobalVarArg(absl::string_view name, GlobalMacroExpander expander) { if (!expander) { return absl::InvalidArgumentError( absl::StrCat("macro expander for `", name, "` cannot be empty")); } return Make(name, 0, ToMacroExpander(std::move(expander)), /*receiver_style=*/false, /*var_arg_style=*/true); } absl::StatusOr Macro::Receiver(absl::string_view name, size_t argument_count, ReceiverMacroExpander expander) { if (!expander) { return absl::InvalidArgumentError( absl::StrCat("macro expander for `", name, "` cannot be empty")); } return Make(name, argument_count, ToMacroExpander(std::move(expander)), /*receiver_style=*/true, /*var_arg_style=*/false); } absl::StatusOr Macro::ReceiverVarArg(absl::string_view name, ReceiverMacroExpander expander) { if (!expander) { return absl::InvalidArgumentError( absl::StrCat("macro expander for `", name, "` cannot be empty")); } return Make(name, 0, ToMacroExpander(std::move(expander)), /*receiver_style=*/true, /*var_arg_style=*/true); } std::vector Macro::AllMacros() { return {HasMacro(), AllMacro(), ExistsMacro(), ExistsOneMacro(), Map2Macro(), Map3Macro(), FilterMacro()}; } std::string 
Macro::Key(absl::string_view name, size_t argument_count, bool receiver_style, bool var_arg_style) { if (var_arg_style) { return absl::StrCat(name, ":*:", receiver_style ? "true" : "false"); } return absl::StrCat(name, ":", argument_count, ":", receiver_style ? "true" : "false"); } absl::StatusOr Macro::Make(absl::string_view name, size_t argument_count, MacroExpander expander, bool receiver_style, bool var_arg_style) { if (!internal::LexisIsIdentifier(name)) { return absl::InvalidArgumentError(absl::StrCat( "macro function name `", name, "` is not a valid identifier")); } if (!expander) { return absl::InvalidArgumentError( absl::StrCat("macro expander for `", name, "` cannot be empty")); } return Macro(std::make_shared( std::string(name), Key(name, argument_count, receiver_style, var_arg_style), argument_count, std::move(expander), receiver_style, var_arg_style)); } const Macro& HasMacro() { static const absl::NoDestructor macro(MakeHasMacro()); return *macro; } const Macro& AllMacro() { static const absl::NoDestructor macro(MakeAllMacro()); return *macro; } const Macro& ExistsMacro() { static const absl::NoDestructor macro(MakeExistsMacro()); return *macro; } const Macro& ExistsOneMacro() { static const absl::NoDestructor macro(MakeExistsOneMacro()); return *macro; } const Macro& Map2Macro() { static const absl::NoDestructor macro(MakeMap2Macro()); return *macro; } const Macro& Map3Macro() { static const absl::NoDestructor macro(MakeMap3Macro()); return *macro; } const Macro& FilterMacro() { static const absl::NoDestructor macro(MakeFilterMacro()); return *macro; } const Macro& OptMapMacro() { static const absl::NoDestructor macro(MakeOptMapMacro()); return *macro; } const Macro& OptFlatMapMacro() { static const absl::NoDestructor macro(MakeOptFlatMapMacro()); return *macro; } } // namespace cel ================================================ FILE: parser/macro.h ================================================ // Copyright 2021 Google LLC // // Licensed under 
the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_PARSER_MACRO_H_ #define THIRD_PARTY_CEL_CPP_PARSER_MACRO_H_ #include #include #include #include #include #include #include "absl/base/attributes.h" #include "absl/functional/any_invocable.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "common/expr.h" #include "parser/macro_expr_factory.h" namespace cel { // MacroExpander converts the arguments of a function call that matches a // Macro. // // If this is a receiver-style macro, the second argument (optional expr) will // be engaged. In the case of a global call, it will be `absl::nullopt`. // // Should return the replacement subexpression if replacement should occur, // otherwise absl::nullopt. If `absl::nullopt` is returned, none of the // arguments including the target must have been modified. Doing so is undefined // behavior. Otherwise the expander is free to mutate the arguments and either // include or exclude them from the result. // // We use `std::reference_wrapper` to be consistent with the fact that we // do not use raw pointers elsewhere with `Expr` and friends. Ideally we would // just use `absl::optional`, but that is not currently allowed and our // `optional_ref` is internal. using MacroExpander = absl::AnyInvocable( MacroExprFactory&, absl::optional>, absl::Span) const>; // `GlobalMacroExpander` is a `MacroExpander` for global macros. 
using GlobalMacroExpander = absl::AnyInvocable( MacroExprFactory&, absl::Span) const>; // `ReceiverMacroExpander` is a `MacroExpander` for receiver-style macros. using ReceiverMacroExpander = absl::AnyInvocable( MacroExprFactory&, Expr&, absl::Span) const>; // Macro interface for describing the function signature to match and the // MacroExpander to apply. // // Note: when a Macro should apply to multiple overloads (based on arg count) of // a given function, a Macro should be created per arg-count. class Macro final { public: static absl::StatusOr Global(absl::string_view name, size_t argument_count, GlobalMacroExpander expander); static absl::StatusOr GlobalVarArg(absl::string_view name, GlobalMacroExpander expander); static absl::StatusOr Receiver(absl::string_view name, size_t argument_count, ReceiverMacroExpander expander); static absl::StatusOr ReceiverVarArg(absl::string_view name, ReceiverMacroExpander expander); Macro(const Macro&) = default; Macro(Macro&&) = default; Macro& operator=(const Macro&) = default; Macro& operator=(Macro&&) = default; // Function name to match. absl::string_view function() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return rep_->function; } // argument_count() for the function call. // // When the macro is a var-arg style macro, the return value will be zero, but // the MacroKey will contain a `*` where the arg count would have been. size_t argument_count() const { return rep_->arg_count; } // is_receiver_style returns true if the macro matches a receiver style call. bool is_receiver_style() const { return rep_->receiver_style; } bool is_variadic() const { return rep_->var_arg_style; } // key() returns the macro signatures accepted by this macro. // // Format: `::`. // // When the macros is a var-arg style macro, the `arg-count` value is // represented as a `*`. 
absl::string_view key() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return rep_->key; } // Expander returns the MacroExpander to apply when the macro key matches the // parsed call signature. const MacroExpander& expander() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return rep_->expander; } ABSL_MUST_USE_RESULT absl::optional Expand( MacroExprFactory& factory, absl::optional> target, absl::Span arguments) const { return (expander())(factory, target, arguments); } friend void swap(Macro& lhs, Macro& rhs) noexcept { using std::swap; swap(lhs.rep_, rhs.rep_); } ABSL_DEPRECATED("use MacroRegistry and RegisterStandardMacros") static std::vector AllMacros(); private: struct Rep final { Rep(std::string function, std::string key, size_t arg_count, MacroExpander expander, bool receiver_style, bool var_arg_style) : function(std::move(function)), key(std::move(key)), arg_count(arg_count), expander(std::move(expander)), receiver_style(receiver_style), var_arg_style(var_arg_style) {} std::string function; std::string key; size_t arg_count; MacroExpander expander; bool receiver_style; bool var_arg_style; }; static std::string Key(absl::string_view name, size_t argument_count, bool receiver_style, bool var_arg_style); static absl::StatusOr Make(absl::string_view name, size_t argument_count, MacroExpander expander, bool receiver_style, bool var_arg_style); explicit Macro(std::shared_ptr rep) : rep_(std::move(rep)) {} std::shared_ptr rep_; }; // The macro "has(m.f)" which tests the presence of a field, avoiding the // need to specify the field as a string. const Macro& HasMacro(); // The macro "range.all(var, predicate)", which is true if for all // elements in range the predicate holds. const Macro& AllMacro(); // The macro "range.exists(var, predicate)", which is true if for at least // one element in range the predicate holds. const Macro& ExistsMacro(); // The macro "range.exists_one(var, predicate)", which is true if for // exactly one element in range the predicate holds. 
const Macro& ExistsOneMacro(); // The macro "range.map(var, function)", applies the function to the vars // in the range. const Macro& Map2Macro(); // The macro "range.map(var, predicate, function)", applies the function // to the vars in the range for which the predicate holds true. The other // variables are filtered out. const Macro& Map3Macro(); // The macro "range.filter(var, predicate)", filters out the variables for // which the predicate is false. const Macro& FilterMacro(); // `OptMapMacro` // // Apply a transformation to the optional's underlying value if it is not empty // and return an optional typed result based on the transformation. The // transformation expression type must return a type T which is wrapped into // an optional. // // msg.?elements.optMap(e, e.size()).orValue(0) const Macro& OptMapMacro(); // `OptFlatMapMacro` // // Apply a transformation to the optional's underlying value if it is not empty // and return the result. The transform expression must return an optional(T) // rather than type T. This can be useful when dealing with zero values and // conditionally generating an empty or non-empty result in ways which cannot // be expressed with `optMap`. // // msg.?elements.optFlatMap(e, e[?0]) // return the first element if present. const Macro& OptFlatMapMacro(); } // namespace cel namespace google::api::expr::parser { using MacroExpander = cel::MacroExpander; using Macro = cel::Macro; } // namespace google::api::expr::parser #endif // THIRD_PARTY_CEL_CPP_PARSER_MACRO_H_ ================================================ FILE: parser/macro_expr_factory.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "parser/macro_expr_factory.h" #include #include #include "absl/functional/overload.h" #include "absl/types/optional.h" #include "absl/types/variant.h" #include "common/constant.h" #include "common/expr.h" namespace cel { // Returns a deep copy of `expr`. Every sub-expression is copied, and each // copied node that carries an ID receives a new one from `CopyId` instead of // reusing the original. A node's own ID is requested before its children are // copied, so IDs are assigned parent-first. Expr MacroExprFactory::Copy(const Expr& expr) { // Copying is currently recursive; it may be made iterative in the future. return absl::visit( absl::Overload( [this, &expr](const UnspecifiedExpr&) -> Expr { return NewUnspecified(CopyId(expr)); }, [this, &expr](const Constant& const_expr) -> Expr { return NewConst(CopyId(expr), const_expr); }, [this, &expr](const IdentExpr& ident_expr) -> Expr { return NewIdent(CopyId(expr), ident_expr.name()); }, [this, &expr](const SelectExpr& select_expr) -> Expr { const auto id = CopyId(expr); return select_expr.test_only() ? NewPresenceTest(id, Copy(select_expr.operand()), select_expr.field()) : NewSelect(id, Copy(select_expr.operand()), select_expr.field()); }, [this, &expr](const CallExpr& call_expr) -> Expr { const auto id = CopyId(expr); // The target (if any) is copied before the arguments. absl::optional target; if (call_expr.has_target()) { target = Copy(call_expr.target()); } std::vector args; args.reserve(call_expr.args().size()); for (const auto& arg : call_expr.args()) { args.push_back(Copy(arg)); } return target.has_value() ?
NewMemberCall(id, call_expr.function(), std::move(*target), std::move(args)) : NewCall(id, call_expr.function(), std::move(args)); }, [this, &expr](const ListExpr& list_expr) -> Expr { const auto id = CopyId(expr); std::vector elements; elements.reserve(list_expr.elements().size()); for (const auto& element : list_expr.elements()) { elements.push_back(Copy(element)); } return NewList(id, std::move(elements)); }, [this, &expr](const StructExpr& struct_expr) -> Expr { const auto id = CopyId(expr); std::vector fields; fields.reserve(struct_expr.fields().size()); for (const auto& field : struct_expr.fields()) { fields.push_back(Copy(field)); } return NewStruct(id, struct_expr.name(), std::move(fields)); }, [this, &expr](const MapExpr& map_expr) -> Expr { const auto id = CopyId(expr); std::vector entries; entries.reserve(map_expr.entries().size()); for (const auto& entry : map_expr.entries()) { entries.push_back(Copy(entry)); } return NewMap(id, std::move(entries)); }, [this, &expr](const ComprehensionExpr& comprehension_expr) -> Expr { const auto id = CopyId(expr); auto iter_range = Copy(comprehension_expr.iter_range()); auto accu_init = Copy(comprehension_expr.accu_init()); auto loop_condition = Copy(comprehension_expr.loop_condition()); auto loop_step = Copy(comprehension_expr.loop_step()); auto result = Copy(comprehension_expr.result()); // Pass iter_var2 through so that two-variable comprehensions keep their // second iteration variable in the copy. return NewComprehension( id, comprehension_expr.iter_var(), comprehension_expr.iter_var2(), std::move(iter_range), comprehension_expr.accu_var(), std::move(accu_init), std::move(loop_condition), std::move(loop_step), std::move(result)); }), expr.kind()); } // Copies a list element: its expression is deep-copied and its `optional` // flag is preserved. ListExprElement MacroExprFactory::Copy(const ListExprElement& element) { return NewListElement(Copy(element.expr()), element.optional()); } // Copies a struct field: the field ID is remapped through `CopyId`, the value // is deep-copied, and the name and `optional` flag are preserved. StructExprField MacroExprFactory::Copy(const StructExprField& field) { auto field_id = CopyId(field.id()); auto field_value = Copy(field.value()); return NewStructField(field_id, field.name(), std::move(field_value), field.optional()); } // Copies a map entry: the entry ID is remapped through `CopyId`, the key and // value are deep-copied, and the `optional` flag is preserved. MapExprEntry
MacroExprFactory::Copy(const MapExprEntry& entry) { auto entry_id = CopyId(entry.id()); auto entry_key = Copy(entry.key()); auto entry_value = Copy(entry.value()); return NewMapEntry(entry_id, std::move(entry_key), std::move(entry_value), entry.optional()); } } // namespace cel ================================================ FILE: parser/macro_expr_factory.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_PARSER_MACRO_EXPR_FACTORY_H_ #define THIRD_PARTY_CEL_CPP_PARSER_MACRO_EXPR_FACTORY_H_ #include #include #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/strings/string_view.h" #include "common/expr.h" #include "common/expr_factory.h" namespace cel { class ParserMacroExprFactory; class TestMacroExprFactory; // `MacroExprFactory` is a specialization of `ExprFactory` for `MacroExpander` // which disallows explicitly specifying IDs.
class MacroExprFactory : protected ExprFactory { protected: // Type traits from ExprFactory, used to constrain the templates below. using ExprFactory::IsArrayLike; using ExprFactory::IsExprLike; using ExprFactory::IsStringLike; template struct IsRValue : std::bool_constant< std::disjunction_v, std::is_same>> {}; public: // Deep-copy helpers. Each copied node that carries an ID gets a new ID from // `CopyId` instead of reusing the original one. ABSL_MUST_USE_RESULT Expr Copy(const Expr& expr); ABSL_MUST_USE_RESULT ListExprElement Copy(const ListExprElement& element); ABSL_MUST_USE_RESULT StructExprField Copy(const StructExprField& field); ABSL_MUST_USE_RESULT MapExprEntry Copy(const MapExprEntry& entry); // The factory methods below mirror the ExprFactory ones but obtain the node // ID from `NextId()` rather than accepting it as a parameter. ABSL_MUST_USE_RESULT Expr NewUnspecified() { return NewUnspecified(NextId()); } ABSL_MUST_USE_RESULT Expr NewNullConst() { return NewNullConst(NextId()); } ABSL_MUST_USE_RESULT Expr NewBoolConst(bool value) { return NewBoolConst(NextId(), value); } ABSL_MUST_USE_RESULT Expr NewIntConst(int64_t value) { return NewIntConst(NextId(), value); } ABSL_MUST_USE_RESULT Expr NewUintConst(uint64_t value) { return NewUintConst(NextId(), value); } ABSL_MUST_USE_RESULT Expr NewDoubleConst(double value) { return NewDoubleConst(NextId(), value); } ABSL_MUST_USE_RESULT Expr NewBytesConst(std::string value) { return NewBytesConst(NextId(), std::move(value)); } ABSL_MUST_USE_RESULT Expr NewBytesConst(absl::string_view value) { return NewBytesConst(NextId(), value); } ABSL_MUST_USE_RESULT Expr NewBytesConst(const char* absl_nullable value) { return NewBytesConst(NextId(), value); } ABSL_MUST_USE_RESULT Expr NewStringConst(std::string value) { return NewStringConst(NextId(), std::move(value)); } ABSL_MUST_USE_RESULT Expr NewStringConst(absl::string_view value) { return NewStringConst(NextId(), value); } ABSL_MUST_USE_RESULT Expr NewStringConst(const char* absl_nullable value) { return NewStringConst(NextId(), value); } template ::value>> ABSL_MUST_USE_RESULT Expr NewIdent(Name name) { return NewIdent(NextId(), std::move(name)); } // Returns the accumulator variable name provided by ExprFactory. absl::string_view AccuVarName() { return ExprFactory::AccuVarName(); } // Creates an accumulator identifier (via ExprFactory::NewAccuIdent) with an // ID from NextId(). ABSL_MUST_USE_RESULT Expr NewAccuIdent() { return NewAccuIdent(NextId()); }
// Creates a select expression on `operand.field`. template ::value>, typename = std::enable_if_t::value>> ABSL_MUST_USE_RESULT Expr NewSelect(Operand operand, Field field) { return NewSelect(NextId(), std::move(operand), std::move(field)); } // Creates a test-only (presence test) select of `field` on `operand`. template ::value>, typename = std::enable_if_t::value>> ABSL_MUST_USE_RESULT Expr NewPresenceTest(Operand operand, Field field) { return NewPresenceTest(NextId(), std::move(operand), std::move(field)); } // Creates a call without a target; the variadic `args` are collected, in // order, into the argument list. template < typename Function, typename... Args, typename = std::enable_if_t::value>, typename = std::enable_if_t...>>> ABSL_MUST_USE_RESULT Expr NewCall(Function function, Args&&... args) { std::vector array; array.reserve(sizeof...(Args)); (array.push_back(std::forward(args)), ...); return NewCall(NextId(), std::move(function), std::move(array)); } // Creates a call without a target from an array-like container of arguments. template ::value>, typename = std::enable_if_t::value>> ABSL_MUST_USE_RESULT Expr NewCall(Function function, Args args) { return NewCall(NextId(), std::move(function), std::move(args)); } // Creates a receiver-style call on `target`; the variadic `args` are // collected, in order, into the argument list. template < typename Function, typename Target, typename... Args, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t...>>> ABSL_MUST_USE_RESULT Expr NewMemberCall(Function function, Target target, Args&&... args) { std::vector array; array.reserve(sizeof...(Args)); (array.push_back(std::forward(args)), ...); return NewMemberCall(NextId(), std::move(function), std::move(target), std::move(array)); } // Creates a receiver-style call on `target` from an array-like container of // arguments. template ::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>> ABSL_MUST_USE_RESULT Expr NewMemberCall(Function function, Target target, Args args) { return NewMemberCall(NextId(), std::move(function), std::move(target), std::move(args)); } using ExprFactory::NewListElement; // Creates a list from individually passed elements. template ...>>> ABSL_MUST_USE_RESULT Expr NewList(Elements&&...
elements) { std::vector array; array.reserve(sizeof...(Elements)); (array.push_back(std::forward(elements)), ...); return NewList(NextId(), std::move(array)); } // Creates a list from an array-like container of elements. template ::value>> ABSL_MUST_USE_RESULT Expr NewList(Elements elements) { return NewList(NextId(), std::move(elements)); } // Creates a struct field with an ID from NextId(). template ::value>, typename = std::enable_if_t::value>> ABSL_MUST_USE_RESULT StructExprField NewStructField(Name name, Value value, bool optional = false) { return NewStructField(NextId(), std::move(name), std::move(value), optional); } // Creates a struct expression from individually passed fields. template ::value>, typename = std::enable_if_t< std::conjunction_v...>>> ABSL_MUST_USE_RESULT Expr NewStruct(Name name, Fields&&... fields) { std::vector array; array.reserve(sizeof...(Fields)); (array.push_back(std::forward(fields)), ...); return NewStruct(NextId(), std::move(name), std::move(array)); } // Creates a struct expression from an array-like container of fields. template < typename Name, typename Fields, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>> ABSL_MUST_USE_RESULT Expr NewStruct(Name name, Fields fields) { return NewStruct(NextId(), std::move(name), std::move(fields)); } // Creates a map entry with an ID from NextId(). template ::value>, typename = std::enable_if_t::value>> ABSL_MUST_USE_RESULT MapExprEntry NewMapEntry(Key key, Value value, bool optional = false) { return NewMapEntry(NextId(), std::move(key), std::move(value), optional); } // Creates a map from individually passed entries. template ...>>> ABSL_MUST_USE_RESULT Expr NewMap(Entries&&...
entries) { std::vector array; array.reserve(sizeof...(Entries)); (array.push_back(std::forward(entries)), ...); return NewMap(NextId(), std::move(array)); } // Creates a map from an array-like container of entries. template ::value>> ABSL_MUST_USE_RESULT Expr NewMap(Entries entries) { return NewMap(NextId(), std::move(entries)); } // Creates a single-variable comprehension. template ::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>> ABSL_MUST_USE_RESULT Expr NewComprehension(IterVar iter_var, IterRange iter_range, AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, LoopStep loop_step, Result result) { return NewComprehension(NextId(), std::move(iter_var), std::move(iter_range), std::move(accu_var), std::move(accu_init), std::move(loop_condition), std::move(loop_step), std::move(result)); } // Creates a two-variable comprehension (`iter_var`, `iter_var2`). template ::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>, typename = std::enable_if_t::value>> ABSL_MUST_USE_RESULT Expr NewComprehension( IterVar iter_var, IterVar2 iter_var2, IterRange iter_range, AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, LoopStep loop_step, Result result) { return NewComprehension(NextId(), std::move(iter_var), std::move(iter_var2), std::move(iter_range), std::move(accu_var), std::move(accu_init), std::move(loop_condition), std::move(loop_step), std::move(result)); } // Error-reporting hooks implemented by concrete factories. Both return an // Expr chosen by the implementation; `ReportErrorAt` additionally receives an // `expr`, presumably the one the error refers to. ABSL_MUST_USE_RESULT virtual Expr ReportError(absl::string_view message) = 0; ABSL_MUST_USE_RESULT virtual Expr ReportErrorAt( const Expr& expr, absl::string_view message) = 0; protected: // The ID-taking ExprFactory overloads are only re-exposed as protected, so // code using the public API cannot specify IDs explicitly. using ExprFactory::AccuVarName; using ExprFactory::NewAccuIdent; using ExprFactory::NewBoolConst; using ExprFactory::NewBytesConst; using ExprFactory::NewCall; using ExprFactory::NewComprehension; using
ExprFactory::NewConst; using ExprFactory::NewDoubleConst; using ExprFactory::NewIdent; using ExprFactory::NewIntConst; using ExprFactory::NewList; using ExprFactory::NewMap; using ExprFactory::NewMapEntry; using ExprFactory::NewMemberCall; using ExprFactory::NewNullConst; using ExprFactory::NewPresenceTest; using ExprFactory::NewSelect; using ExprFactory::NewStringConst; using ExprFactory::NewStruct; using ExprFactory::NewStructField; using ExprFactory::NewUintConst; using ExprFactory::NewUnspecified; // Returns the ID for a newly created node. ABSL_MUST_USE_RESULT virtual ExprId NextId() = 0; // Returns the ID to assign to a copy of the node whose ID is `id`. ABSL_MUST_USE_RESULT virtual ExprId CopyId(ExprId id) = 0; // Convenience overload that forwards `expr.id()`. ABSL_MUST_USE_RESULT ExprId CopyId(const Expr& expr) { return CopyId(expr.id()); } private: // The constructor is private, so only the friend classes listed here can // construct this factory. friend class ParserMacroExprFactory; friend class TestMacroExprFactory; explicit MacroExprFactory() = default; }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_PARSER_MACRO_EXPR_FACTORY_H_ ================================================ FILE: parser/macro_expr_factory_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "parser/macro_expr_factory.h" #include #include #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "common/expr.h" #include "common/expr_factory.h" #include "internal/testing.h" namespace cel { // Minimal concrete factory for tests. IDs are handed out sequentially starting // at 1, and both error hooks return a new unspecified expression. class TestMacroExprFactory final : public MacroExprFactory { public: TestMacroExprFactory() = default; ExprId id() const { return id_; } Expr ReportError(absl::string_view) override { return NewUnspecified(NextId()); } Expr ReportErrorAt(const Expr&, absl::string_view) override { return NewUnspecified(NextId()); } // Expose the ID-taking overloads so tests can build expected values with // explicit IDs. using MacroExprFactory::NewBoolConst; using MacroExprFactory::NewCall; using MacroExprFactory::NewComprehension; using MacroExprFactory::NewIdent; using MacroExprFactory::NewList; using MacroExprFactory::NewListElement; using MacroExprFactory::NewMap; using MacroExprFactory::NewMapEntry; using MacroExprFactory::NewMemberCall; using MacroExprFactory::NewSelect; using MacroExprFactory::NewStruct; using MacroExprFactory::NewStructField; using MacroExprFactory::NewUnspecified; protected: ExprId NextId() override { return id_++; } // ID 0 is kept as 0 on copy; any other ID is replaced by the next ID. ExprId CopyId(ExprId id) override { if (id == 0) { return 0; } return NextId(); } private: int64_t id_ = 1; }; namespace { // IDs are assigned in creation order, so each expected value accounts for the // nodes created for the input. Copy() assigns a node's new ID before copying // its children. TEST(MacroExprFactory, CopyUnspecified) { TestMacroExprFactory factory; EXPECT_EQ(factory.Copy(factory.NewUnspecified()), factory.NewUnspecified(2)); } TEST(MacroExprFactory, CopyIdent) { TestMacroExprFactory factory; EXPECT_EQ(factory.Copy(factory.NewIdent("foo")), factory.NewIdent(2, "foo")); } TEST(MacroExprFactory, CopyConst) { TestMacroExprFactory factory; EXPECT_EQ(factory.Copy(factory.NewBoolConst(true)), factory.NewBoolConst(2, true)); } TEST(MacroExprFactory, CopySelect) { TestMacroExprFactory factory; EXPECT_EQ(factory.Copy(factory.NewSelect(factory.NewIdent("foo"), "bar")), factory.NewSelect(3, factory.NewIdent(4, "foo"), "bar")); } TEST(MacroExprFactory, CopyCall) { TestMacroExprFactory factory; std::vector copied_args; copied_args.reserve(1);
copied_args.push_back(factory.NewIdent(6, "baz")); EXPECT_EQ(factory.Copy(factory.NewMemberCall("bar", factory.NewIdent("foo"), factory.NewIdent("baz"))), factory.NewMemberCall(4, "bar", factory.NewIdent(5, "foo"), absl::MakeSpan(copied_args))); } TEST(MacroExprFactory, CopyList) { TestMacroExprFactory factory; std::vector copied_elements; copied_elements.reserve(1); copied_elements.push_back(factory.NewListElement(factory.NewIdent(4, "foo"))); EXPECT_EQ(factory.Copy(factory.NewList( factory.NewListElement(factory.NewIdent("foo")))), factory.NewList(3, absl::MakeSpan(copied_elements))); } TEST(MacroExprFactory, CopyStruct) { TestMacroExprFactory factory; std::vector copied_fields; copied_fields.reserve(1); copied_fields.push_back( factory.NewStructField(5, "bar", factory.NewIdent(6, "baz"))); EXPECT_EQ(factory.Copy(factory.NewStruct( "foo", factory.NewStructField("bar", factory.NewIdent("baz")))), factory.NewStruct(4, "foo", absl::MakeSpan(copied_fields))); } TEST(MacroExprFactory, CopyMap) { TestMacroExprFactory factory; std::vector copied_entries; copied_entries.reserve(1); copied_entries.push_back(factory.NewMapEntry(6, factory.NewIdent(7, "bar"), factory.NewIdent(8, "baz"))); EXPECT_EQ(factory.Copy(factory.NewMap(factory.NewMapEntry( factory.NewIdent("bar"), factory.NewIdent("baz")))), factory.NewMap(5, absl::MakeSpan(copied_entries))); } // The copy assigns IDs to the comprehension first, then to iter_range, // accu_init, loop_condition, loop_step and result, in that order. TEST(MacroExprFactory, CopyComprehension) { TestMacroExprFactory factory; EXPECT_EQ( factory.Copy(factory.NewComprehension( "foo", factory.NewList(), "bar", factory.NewBoolConst(true), factory.NewIdent("baz"), factory.NewIdent("foo"), factory.NewIdent("bar"))), factory.NewComprehension( 7, "foo", factory.NewList(8, std::vector()), "bar", factory.NewBoolConst(9, true), factory.NewIdent(10, "baz"), factory.NewIdent(11, "foo"), factory.NewIdent(12, "bar"))); } // TODO(review): add coverage for copying a two-variable comprehension // (iter_var2). } // namespace } // namespace cel ================================================ FILE: parser/macro_registry.cc ================================================ //
Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "parser/macro_registry.h" #include #include #include #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "parser/macro.h" namespace cel { // Registers a single macro, failing with AlreadyExists if its key is taken. absl::Status MacroRegistry::RegisterMacro(const Macro& macro) { if (!RegisterMacroImpl(macro)) { return absl::AlreadyExistsError( absl::StrCat("macro already exists: ", macro.key())); } return absl::OkStatus(); } // All-or-nothing registration: on the first key conflict, the macros // registered earlier in this call are erased again before the error is // returned, leaving the registry as it was before the call. absl::Status MacroRegistry::RegisterMacros(absl::Span macros) { for (size_t i = 0; i < macros.size(); ++i) { const auto& macro = macros[i]; if (!RegisterMacroImpl(macro)) { for (size_t j = 0; j < i; ++j) { macros_.erase(macros[j].key()); } return absl::AlreadyExistsError( absl::StrCat("macro already exists: ", macro.key())); } } return absl::OkStatus(); } absl::optional MacroRegistry::FindMacro(absl::string_view name, size_t arg_count, bool receiver_style) const { // Keys have the form "name:arg_count:receiver_style" (or // "name:*:receiver_style" for variadic macros), so a name that is empty or // contains ':' can never match. if (name.empty() || absl::StrContains(name, ':')) { return absl::nullopt; } // Try argument count specific key first. auto key = absl::StrCat(name, ":", arg_count, ":", receiver_style ? "true" : "false"); if (auto it = macros_.find(key); it != macros_.end()) { return it->second; } // Next try variadic. key = absl::StrCat(name, ":*:", receiver_style ?
"true" : "false"); if (auto it = macros_.find(key); it != macros_.end()) { return it->second; } return absl::nullopt; } // The order of the result follows flat_hash_map iteration and is // unspecified. std::vector MacroRegistry::ListMacros() const { std::vector macros; macros.reserve(macros_.size()); for (auto it = macros_.begin(); it != macros_.end(); ++it) { macros.push_back(it->second); } return macros; } // Inserts `macro` under its key; returns false and leaves the registry // unchanged if the key is already present. bool MacroRegistry::RegisterMacroImpl(const Macro& macro) { return macros_.insert(std::pair{macro.key(), macro}).second; } } // namespace cel ================================================ FILE: parser/macro_registry.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_PARSER_MACRO_REGISTRY_H_ #define THIRD_PARTY_CEL_CPP_PARSER_MACRO_REGISTRY_H_ #include #include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "parser/macro.h" namespace cel { class MacroRegistry final { public: MacroRegistry() = default; // Move-only. MacroRegistry(MacroRegistry&&) = default; MacroRegistry& operator=(MacroRegistry&&) = default; // Registers `macro`. Returns an AlreadyExists error if a macro with the same // key is already registered. absl::Status RegisterMacro(const Macro& macro); // Registers all `macros`. If registering one fails, the macros already // registered by this call are removed again, the rest are not registered, // and the error is returned; the registry is left as it was before the call.
absl::Status RegisterMacros(absl::Span macros); // Returns the macro registered for `name` with exactly `arg_count` arguments // and the given receiver style, falling back to a variadic macro with that // name and receiver style; returns nullopt if neither exists. absl::optional FindMacro(absl::string_view name, size_t arg_count, bool receiver_style) const; // Returns a copy of all registered macros, in unspecified order. std::vector ListMacros() const; private: bool RegisterMacroImpl(const Macro& macro); // Registered macros, keyed by Macro::key(). absl::flat_hash_map macros_; }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_PARSER_MACRO_REGISTRY_H_ ================================================ FILE: parser/macro_registry_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "parser/macro_registry.h" #include "absl/status/status.h" #include "absl/types/optional.h" #include "internal/testing.h" #include "parser/macro.h" namespace cel { namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::testing::Eq; using ::testing::Ne; TEST(MacroRegistry, RegisterAndFind) { MacroRegistry macros; EXPECT_THAT(macros.RegisterMacro(HasMacro()), IsOk()); EXPECT_THAT(macros.FindMacro("has", 1, false), Ne(absl::nullopt)); } // AllMacro() appears twice, so the third registration fails and the // HasMacro() registered earlier in the same call must be rolled back. TEST(MacroRegistry, RegisterRollsback) { MacroRegistry macros; EXPECT_THAT(macros.RegisterMacros({HasMacro(), AllMacro(), AllMacro()}), StatusIs(absl::StatusCode::kAlreadyExists)); EXPECT_THAT(macros.FindMacro("has", 1, false), Eq(absl::nullopt)); } } // namespace } // namespace cel ================================================ FILE: parser/options.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_PARSER_OPTIONS_H_ #define THIRD_PARTY_CEL_CPP_PARSER_OPTIONS_H_ #include "absl/base/attributes.h" #include "parser/internal/options.h" namespace cel { // Options for configuring the limits and features of the parser. struct ParserOptions final { // Limit on the number of error recovery attempts made by the ANTLR parser // when processing an input. This limit, when reached, will halt further // parsing of the expression.
int error_recovery_limit = ::cel_parser_internal::kDefaultErrorRecoveryLimit; // Limit on the amount of recursive parse instructions permitted when building // the abstract syntax tree for the expression. This prevents pathological // inputs from causing stack overflows. int max_recursion_depth = ::cel_parser_internal::kDefaultMaxRecursionDepth; // Limit on the number of codepoints in the input string which the parser will // attempt to parse. int expression_size_codepoint_limit = ::cel_parser_internal::kExpressionSizeCodepointLimit; // Limit on the number of lookahead tokens to consume when attempting to // recover from an error. int error_recovery_token_lookahead_limit = ::cel_parser_internal::kDefaultErrorRecoveryTokenLookaheadLimit; // When true, macro calls are added to the macro_calls list in source_info. bool add_macro_calls = ::cel_parser_internal::kDefaultAddMacroCalls; // Enables support for optional syntax. bool enable_optional_syntax = false; // Disables the standard macros (has, all, exists, exists_one, filter, map). bool disable_standard_macros = false; // Deprecated: The builtin and extension macros now always use the new // accumulator variable name. // This option has no effect.
bool enable_hidden_accumulator_var = true; // Enables support for identifier quoting syntax: // "message.`skewer-case-field`" // // Limited to field specifiers in select and message creation expressions. // Enabled by default. bool enable_quoted_identifiers = true; }; } // namespace cel // Aliases for the google::api::expr::parser namespace. The constants below are // deprecated in favor of the corresponding ParserOptions fields. namespace google::api::expr::parser { using ParserOptions = ::cel::ParserOptions; ABSL_DEPRECATED("Use ParserOptions().error_recovery_limit instead.") inline constexpr int kDefaultErrorRecoveryLimit = ::cel_parser_internal::kDefaultErrorRecoveryLimit; ABSL_DEPRECATED("Use ParserOptions().max_recursion_depth instead.") inline constexpr int kDefaultMaxRecursionDepth = ::cel_parser_internal::kDefaultMaxRecursionDepth; ABSL_DEPRECATED("Use ParserOptions().expression_size_codepoint_limit instead.") inline constexpr int kExpressionSizeCodepointLimit = ::cel_parser_internal::kExpressionSizeCodepointLimit; ABSL_DEPRECATED( "Use ParserOptions().error_recovery_token_lookahead_limit instead.") inline constexpr int kDefaultErrorRecoveryTokenLookaheadLimit = ::cel_parser_internal::kDefaultErrorRecoveryTokenLookaheadLimit; ABSL_DEPRECATED("Use ParserOptions().add_macro_calls instead.") inline constexpr bool kDefaultAddMacroCalls = ::cel_parser_internal::kDefaultAddMacroCalls; } // namespace google::api::expr::parser #endif // THIRD_PARTY_CEL_CPP_PARSER_OPTIONS_H_ ================================================ FILE: parser/parser.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "parser/parser.h"

// NOTE(review): in this copy the standard-library header names on the
// following #include lines are missing (as are template argument lists
// throughout the file); restore them from upstream before building.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#include "cel/expr/syntax.pb.h"
#include "absl/base/macros.h"
#include "absl/base/optimization.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/functional/overload.h"
#include "absl/log/absl_check.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "antlr4-runtime.h"
#include "common/ast.h"
#include "common/ast/expr_proto.h"
#include "common/ast/source_info_proto.h"
#include "common/constant.h"
#include "common/expr_factory.h"
#include "common/operators.h"
#include "common/source.h"
#include "internal/lexis.h"
#include "internal/status_macros.h"
#include "internal/strings.h"
#include "internal/utf8.h"
// The generated ANTLR headers are included with any `IN` macro temporarily
// undefined; push_macro/pop_macro restore the previous definition afterwards.
#pragma push_macro("IN")
#undef IN
#include "parser/internal/CelBaseVisitor.h"
#include "parser/internal/CelLexer.h"
#include "parser/internal/CelParser.h"
#pragma pop_macro("IN")
#include "parser/macro.h"
#include "parser/macro_expr_factory.h"
#include "parser/macro_registry.h"
#include "parser/options.h"
#include "parser/parser_interface.h"
#include "parser/source_factory.h"

namespace google::api::expr::parser {
namespace {
class ParserVisitor;
}
}  // namespace google::api::expr::parser

namespace cel {

namespace {

constexpr const char kHiddenAccumulatorVariableName[] = "@result";

// The visitor methods return std::any; these helpers move Expr values into and
// out of it. The std::any holds the raw pointer released from a
// std::unique_ptr, and ExprPtrFromAny re-wraps it with absl::WrapUnique so
// that ownership is restored on the way out.
std::any ExprPtrToAny(std::unique_ptr&& expr) {
  return std::make_any(expr.release());
}

// Heap-allocates `expr` and stores it in a std::any (see ExprPtrToAny).
std::any ExprToAny(Expr&& expr) {
  return ExprPtrToAny(std::make_unique(std::move(expr)));
}

// Takes back ownership of the pointer stored by ExprPtrToAny.
std::unique_ptr ExprPtrFromAny(std::any&& any) {
  return absl::WrapUnique(std::any_cast(std::move(any)));
}

// Moves the Expr out of a std::any produced by ExprToAny/ExprPtrToAny. The
// temporary heap allocation is released when `expr` goes out of scope.
Expr ExprFromAny(std::any&& any) {
  auto expr = ExprPtrFromAny(std::move(any));
  return std::move(*expr);
}

// A single parse diagnostic: its message and the source range it refers to.
struct ParserError {
  std::string message;
  SourceRange range;
};

// Renders `error` as "ERROR: <description>:<line>:<column>: <message>",
// followed by source.DisplayErrorLocation(location). If the begin offset
// cannot be resolved, a default-constructed SourceLocation is used.
std::string DisplayParserError(const cel::Source& source,
                               const ParserError& error) {
  auto location =
      source.GetLocation(error.range.begin).value_or(SourceLocation{});
  return absl::StrCat(absl::StrFormat("ERROR: %s:%zu:%zu: %s",
                                      source.description(), location.line,
                                      // add one to the 0-based column
                                      location.column + 1, error.message),
                      source.DisplayErrorLocation(location));
}

// Returns `value` when it is non-negative and the numeric_limits max()
// otherwise. ErrorMessage() uses this so that errors with unset (negative)
// offsets sort after all errors with real positions.
int32_t PositiveOrMax(int32_t value) {
  return value >= 0 ? value : std::numeric_limits::max();
}

// Converts a token to the half-open range [start index, stop index + 1).
// An end is left at its SourceRange default when the token is null or that
// index is INVALID_INDEX.
SourceRange SourceRangeFromToken(const antlr4::Token* token) {
  SourceRange range;
  if (token != nullptr) {
    if (auto start = token->getStartIndex(); start != INVALID_INDEX) {
      range.begin = static_cast(start);
    }
    if (auto end = token->getStopIndex(); end != INVALID_INDEX) {
      range.end = static_cast(end + 1);
    }
  }
  return range;
}

// Like SourceRangeFromToken, but covers a whole rule context: from the start
// index of its first token to one past the stop index of its last token.
SourceRange SourceRangeFromParserRuleContext(
    const antlr4::ParserRuleContext* context) {
  SourceRange range;
  if (context != nullptr) {
    if (auto start = context->getStart() != nullptr
                         ? context->getStart()->getStartIndex()
                         : INVALID_INDEX;
        start != INVALID_INDEX) {
      range.begin = static_cast(start);
    }
    if (auto end = context->getStop() != nullptr
                       ? context->getStop()->getStopIndex()
                       : INVALID_INDEX;
        end != INVALID_INDEX) {
      range.end = static_cast(end + 1);
    }
  }
  return range;
}

}  // namespace

// Expression factory used while parsing. It does four things:
//   - allocates expression ids and records each id's source range;
//   - collects parse errors, storing a bounded number of messages;
//   - records the unexpanded form of macro calls;
//   - builds copies of expressions for use as macro-call arguments.
class ParserMacroExprFactory final : public MacroExprFactory {
 public:
  explicit ParserMacroExprFactory(const cel::Source& source)
      : source_(source) {}

  // Sets the range used by NextId() and by ReportError(message) until
  // EndMacro() resets it.
  void BeginMacro(SourceRange macro_position) {
    macro_position_ = macro_position;
  }

  void EndMacro() { macro_position_ = SourceRange{}; }

  // Reports an error at the current macro position.
  Expr ReportError(absl::string_view message) override {
    return ReportError(macro_position_, message);
  }

  // Reports an error at the range recorded for `expr_id`.
  Expr ReportError(int64_t expr_id, absl::string_view message) {
    return ReportError(GetSourceRange(expr_id), message);
  }

  // Counts the error and stores its message. Returns an unspecified expression,
  // with a fresh id positioned at `range`, to stand in for the bad input.
  // NOTE(review): `<= 100` stores up to 101 messages, while ErrorMessage()
  // treats more than 100 errors as truncated and reports `error_count_ - 100`
  // as dropped. With exactly 101 errors, all 101 are printed and
  // "1 more errors were truncated." is still appended. Using `< 100` here
  // would make the two counts agree.
  Expr ReportError(SourceRange range, absl::string_view message) {
    ++error_count_;
    if (errors_.size() <= 100) {
      errors_.push_back(ParserError{std::string(message), range});
    }
    return NewUnspecified(NextId(range));
  }

  // Reports an error at the range recorded for `expr`'s id.
  Expr ReportErrorAt(const Expr& expr, absl::string_view message) override {
    return ReportError(GetSourceRange(expr.id()), message);
  }

  // Returns the range recorded for `id`, or a default SourceRange if none.
  SourceRange GetSourceRange(int64_t id) const {
    if (auto it = positions_.find(id); it != positions_.end()) {
      return it->second;
    }
    return SourceRange{};
  }

  // Allocates the next expression id. The range is recorded only when at
  // least one of its ends is not -1.
  int64_t NextId(const SourceRange& range) {
    auto id = expr_id_++;
    if (range.begin != -1 || range.end != -1) {
      positions_.insert(std::pair{id, range});
    }
    return id;
  }

  bool HasErrors() const { return error_count_ != 0; }

  // Sorts the stored errors in place by source position and returns them
  // rendered one per line. A trailing count line is added when more than 100
  // errors were reported.
  std::string ErrorMessage() {
    // Errors are collected as they are encountered, not by their location
    // within the source. To have a more stable error message as implementation
    // details change, we sort the collected errors by their source location
    // first.
    std::stable_sort(
        errors_.begin(), errors_.end(),
        [](const ParserError& lhs, const ParserError& rhs) -> bool {
          auto lhs_begin = PositiveOrMax(lhs.range.begin);
          auto lhs_end = PositiveOrMax(lhs.range.end);
          auto rhs_begin = PositiveOrMax(rhs.range.begin);
          auto rhs_end = PositiveOrMax(rhs.range.end);
          return lhs_begin < rhs_begin ||
                 (lhs_begin == rhs_begin && lhs_end < rhs_end);
        });
    // Build the summary error message using the sorted errors.
    bool errors_truncated = error_count_ > 100;
    std::vector messages;
    messages.reserve(
        errors_.size() +
        errors_truncated);  // Reserve space for the transform and an
                            // additional element when truncation occurs.
    std::transform(errors_.begin(), errors_.end(),
                   std::back_inserter(messages),
                   [this](const ParserError& error) {
                     return cel::DisplayParserError(source_, error);
                   });
    if (errors_truncated) {
      messages.emplace_back(
          absl::StrCat(error_count_ - 100, " more errors were truncated."));
    }
    return absl::StrJoin(messages, "\n");
  }

  // Records the unexpanded form of a macro call under `macro_id`: a member
  // call when `target` is present, otherwise a global call. The recorded call
  // node itself is built with id 0.
  void AddMacroCall(int64_t macro_id, absl::string_view function,
                    absl::optional target, std::vector arguments) {
    macro_calls_.insert(
        {macro_id, target.has_value()
                       ? NewMemberCall(0, function, std::move(*target),
                                       std::move(arguments))
                       : NewCall(0, function, std::move(arguments))});
  }

  // Returns a copy of `expr` that keeps every original id, for use as an
  // argument of a recorded macro call. If `expr` (or, recursively, any of its
  // sub-expressions) has an id with a recorded macro call, that node is
  // replaced by an unspecified expression carrying the same id.
  Expr BuildMacroCallArg(const Expr& expr) {
    if (auto it = macro_calls_.find(expr.id()); it != macro_calls_.end()) {
      return NewUnspecified(expr.id());
    }
    return absl::visit(
        absl::Overload(
            [this, &expr](const UnspecifiedExpr&) -> Expr {
              return NewUnspecified(expr.id());
            },
            [this, &expr](const Constant& const_expr) -> Expr {
              return NewConst(expr.id(), const_expr);
            },
            [this, &expr](const IdentExpr& ident_expr) -> Expr {
              return NewIdent(expr.id(), ident_expr.name());
            },
            [this, &expr](const SelectExpr& select_expr) -> Expr {
              return select_expr.test_only()
                         ? NewPresenceTest(
                               expr.id(),
                               BuildMacroCallArg(select_expr.operand()),
                               select_expr.field())
                         : NewSelect(expr.id(),
                                     BuildMacroCallArg(select_expr.operand()),
                                     select_expr.field());
            },
            [this, &expr](const CallExpr& call_expr) -> Expr {
              std::vector macro_arguments;
              macro_arguments.reserve(call_expr.args().size());
              for (const auto& argument : call_expr.args()) {
                macro_arguments.push_back(BuildMacroCallArg(argument));
              }
              absl::optional macro_target;
              if (call_expr.has_target()) {
                macro_target = BuildMacroCallArg(call_expr.target());
              }
              return macro_target.has_value()
                         ? NewMemberCall(expr.id(), call_expr.function(),
                                         std::move(*macro_target),
                                         std::move(macro_arguments))
                         : NewCall(expr.id(), call_expr.function(),
                                   std::move(macro_arguments));
            },
            [this, &expr](const ListExpr& list_expr) -> Expr {
              std::vector macro_elements;
              macro_elements.reserve(list_expr.elements().size());
              for (const auto& element : list_expr.elements()) {
                auto& cloned_element = macro_elements.emplace_back();
                if (element.has_expr()) {
                  cloned_element.set_expr(BuildMacroCallArg(element.expr()));
                }
                cloned_element.set_optional(element.optional());
              }
              return NewList(expr.id(), std::move(macro_elements));
            },
            [this, &expr](const StructExpr& struct_expr) -> Expr {
              std::vector macro_fields;
              macro_fields.reserve(struct_expr.fields().size());
              for (const auto& field : struct_expr.fields()) {
                auto& macro_field = macro_fields.emplace_back();
                macro_field.set_id(field.id());
                macro_field.set_name(field.name());
                macro_field.set_value(BuildMacroCallArg(field.value()));
                macro_field.set_optional(field.optional());
              }
              return NewStruct(expr.id(), struct_expr.name(),
                               std::move(macro_fields));
            },
            [this, &expr](const MapExpr& map_expr) -> Expr {
              std::vector macro_entries;
              macro_entries.reserve(map_expr.entries().size());
              for (const auto& entry : map_expr.entries()) {
                auto& macro_entry = macro_entries.emplace_back();
                macro_entry.set_id(entry.id());
                macro_entry.set_key(BuildMacroCallArg(entry.key()));
                macro_entry.set_value(BuildMacroCallArg(entry.value()));
                macro_entry.set_optional(entry.optional());
              }
              return NewMap(expr.id(), std::move(macro_entries));
            },
            [this, &expr](const ComprehensionExpr& comprehension_expr) -> Expr {
              return NewComprehension(
                  expr.id(), comprehension_expr.iter_var(),
                  BuildMacroCallArg(comprehension_expr.iter_range()),
                  comprehension_expr.accu_var(),
                  BuildMacroCallArg(comprehension_expr.accu_init()),
                  BuildMacroCallArg(comprehension_expr.loop_condition()),
                  BuildMacroCallArg(comprehension_expr.loop_step()),
                  BuildMacroCallArg(comprehension_expr.result()));
            }),
        expr.kind());
  }

  // These using-declarations sit in the public section, so the ExprFactory
  // builders are callable on this class (ParserVisitor uses them via factory_).
  using ExprFactory::NewBoolConst;
  using ExprFactory::NewBytesConst;
  using ExprFactory::NewCall;
  using ExprFactory::NewComprehension;
  using ExprFactory::NewConst;
  using ExprFactory::NewDoubleConst;
  using ExprFactory::NewIdent;
  using ExprFactory::NewIntConst;
  using ExprFactory::NewList;
  using ExprFactory::NewListElement;
  using ExprFactory::NewMap;
  using ExprFactory::NewMapEntry;
  using ExprFactory::NewMemberCall;
  using ExprFactory::NewNullConst;
  using ExprFactory::NewPresenceTest;
  using ExprFactory::NewSelect;
  using ExprFactory::NewStringConst;
  using ExprFactory::NewStruct;
  using ExprFactory::NewStructField;
  using ExprFactory::NewUintConst;
  using ExprFactory::NewUnspecified;

  const absl::btree_map& positions() const { return positions_; }

  const absl::flat_hash_map& macro_calls() const {
    return macro_calls_;
  }

  // Moves the recorded macro calls out, leaving this factory's map empty.
  absl::flat_hash_map release_macro_calls() {
    using std::swap;
    absl::flat_hash_map result;
    swap(result, macro_calls_);
    return result;
  }

  // Drops the recorded range for `id`. If `id` was the most recently
  // allocated id, it will be handed out again by the next allocation.
  void EraseId(ExprId id) {
    positions_.erase(id);
    if (expr_id_ == id + 1) {
      --expr_id_;
    }
  }

 protected:
  // Allocates an id positioned at the current macro position.
  int64_t NextId() override { return NextId(macro_position_); }

  // Returns 0 for id 0. Otherwise allocates a fresh id that inherits the
  // source range recorded for `id`.
  int64_t CopyId(int64_t id) override {
    if (id == 0) {
      return 0;
    }
    return NextId(GetSourceRange(id));
  }

 private:
  // Next id to hand out; allocation starts at 1.
  int64_t expr_id_ = 1;
  absl::btree_map positions_;
  absl::flat_hash_map macro_calls_;
  // Stored error messages (bounded; see ReportError).
  std::vector errors_;
  // Total number of errors reported, including any not stored in errors_.
  size_t error_count_ = 0;
  const Source& source_;
  SourceRange macro_position_;
};

}  // namespace cel

namespace google::api::expr::parser {

namespace {

using ::antlr4::CharStream;
using ::antlr4::CommonTokenStream;
using ::antlr4::DefaultErrorStrategy;
using ::antlr4::ParseCancellationException;
using ::antlr4::Parser;
using ::antlr4::ParserRuleContext;
using ::antlr4::Token;
using ::antlr4::misc::IntervalSet;
using ::antlr4::tree::ErrorNode;
using ::antlr4::tree::ParseTreeListener;
using ::antlr4::tree::TerminalNode;
using ::cel::Expr;
using ::cel::ExprFromAny;
using ::cel::ExprKind;
using ::cel::ExprToAny;
using ::cel::IdentExpr;
using ::cel::ListExprElement;
using ::cel::MapExprEntry;
using ::cel::SelectExpr;
using ::cel::SourceRangeFromParserRuleContext;
using ::cel::SourceRangeFromToken;
using ::cel::StructExprField;
using ::cel_parser_internal::CelBaseVisitor;
using ::cel_parser_internal::CelLexer;
using ::cel_parser_internal::CelParser;
using common::CelOperator;
using common::ReverseLookupOperator;
using ::cel::expr::ParsedExpr;

// ANTLR CharStream over a cel::SourceContentView. LA() reads elements via
// SourceContentView::at(), mark()/release() are no-ops, and seek() clamps to
// the buffer size. Only a view of `source_name` is stored, so the caller must
// keep it alive for the stream's lifetime.
class CodePointStream final : public CharStream {
 public:
  CodePointStream(cel::SourceContentView buffer, absl::string_view source_name)
      : buffer_(buffer),
        source_name_(source_name),
        size_(buffer_.size()),
        index_(0) {}

  void consume() override {
    if (ABSL_PREDICT_FALSE(index_ >= size_)) {
      ABSL_ASSERT(LA(1) == IntStream::EOF);
      throw antlr4::IllegalStateException("cannot consume EOF");
    }
    index_++;
  }

  // Returns the element at offset `i` from the current index: i == 1 is the
  // current element and i == -1 the previous one. Returns 0 for i == 0 and
  // IntStream::EOF when the offset falls outside the buffer.
  size_t LA(ptrdiff_t i) override {
    if (ABSL_PREDICT_FALSE(i == 0)) {
      return 0;
    }
    auto p = static_cast(index_);
    if (i < 0) {
      i++;
      if (p + i - 1 < 0) {
        return IntStream::EOF;
      }
    }
    if (p + i - 1 >= static_cast(size_)) {
      return IntStream::EOF;
    }
    return buffer_.at(static_cast(p + i - 1));
  }

  ptrdiff_t mark() override { return -1; }

  void release(ptrdiff_t marker) override {}

  size_t index() override { return index_; }

  void seek(size_t index) override { index_ = std::min(index, size_); }

  size_t size() override { return size_; }

  std::string getSourceName() const override {
    return source_name_.empty() ? IntStream::UNKNOWN_SOURCE_NAME
                                : std::string(source_name_);
  }

  // Returns the text of the inclusive interval [a, b]. Negative bounds, or a
  // start at or past the end, yield an empty string; `b` is clamped to the
  // last element.
  std::string getText(const antlr4::misc::Interval& interval) override {
    if (ABSL_PREDICT_FALSE(interval.a < 0 || interval.b < 0)) {
      return std::string();
    }
    size_t start = static_cast(interval.a);
    if (ABSL_PREDICT_FALSE(start >= size_)) {
      return std::string();
    }
    size_t stop = static_cast(interval.b);
    if (ABSL_PREDICT_FALSE(stop >= size_)) {
      stop = size_ - 1;
    }
    return buffer_.ToString(static_cast(start),
                            static_cast(stop) + 1);
  }

  std::string toString() const override { return buffer_.ToString(); }

 private:
  cel::SourceContentView const buffer_;
  const absl::string_view source_name_;
  const size_t size_;
  size_t index_;
};

// Scoped helper for incrementing the parse recursion count.
// Increments on creation, decrements on destruction (stack unwind).
class ScopedIncrement final {
 public:
  explicit ScopedIncrement(int& recursion_depth)
      : recursion_depth_(recursion_depth) {
    ++recursion_depth_;
  }

  ~ScopedIncrement() { --recursion_depth_; }

 private:
  int& recursion_depth_;
};

// balancer performs tree balancing on operators whose arguments are of equal
// precedence.
//
// The purpose of the balancer is to ensure a compact serialization format for
// the logical &&, || operators which have a tendency to create long DAGs which
// are skewed in one direction. Since the operators are commutative re-ordering
// the terms *must not* affect the evaluation result.
//
// Based on code from //third_party/cel/go/parser/helper.go
class ExpressionBalancer final {
 public:
  ExpressionBalancer(cel::ParserMacroExprFactory& factory,
                     std::string function, Expr expr);

  // addTerm adds an operation identifier and term to the set of terms to be
  // balanced.
  void AddTerm(int64_t op, Expr term);

  // balance creates a balanced tree from the sub-terms and returns the final
  // Expr value.
  Expr Balance();

 private:
  // balancedTree recursively balances the terms provided to a commutative
  // operator.
  Expr BalancedTree(int lo, int hi);

 private:
  cel::ParserMacroExprFactory& factory_;
  std::string function_;
  // ops_[i] is the call id joining terms_[i] and terms_[i + 1], so terms_
  // always holds one more element than ops_.
  std::vector terms_;
  std::vector ops_;
};

ExpressionBalancer::ExpressionBalancer(cel::ParserMacroExprFactory& factory,
                                       std::string function, Expr expr)
    : factory_(factory), function_(std::move(function)) {
  terms_.push_back(std::move(expr));
}

void ExpressionBalancer::AddTerm(int64_t op, Expr term) {
  terms_.push_back(std::move(term));
  ops_.push_back(op);
}

// A single term (no operators were added) is returned unchanged.
Expr ExpressionBalancer::Balance() {
  if (terms_.size() == 1) {
    return std::move(terms_[0]);
  }
  return BalancedTree(0, ops_.size() - 1);
}

// Builds the subtree for ops_[lo..hi] (inclusive). The middle operator becomes
// the call node.
//   - Left operand: terms_[mid] when mid == lo, otherwise the balanced subtree
//     of ops_[lo..mid-1].
//   - Right operand: terms_[mid + 1] when mid == hi, otherwise the subtree of
//     ops_[mid+1..hi].
// Terms are moved out of terms_, so this must only run once.
Expr ExpressionBalancer::BalancedTree(int lo, int hi) {
  int mid = (lo + hi + 1) / 2;

  std::vector arguments;
  arguments.reserve(2);

  if (mid == lo) {
    arguments.push_back(std::move(terms_[mid]));
  } else {
    arguments.push_back(BalancedTree(lo, mid - 1));
  }

  if (mid == hi) {
    arguments.push_back(std::move(terms_[mid + 1]));
  } else {
    arguments.push_back(BalancedTree(mid + 1, hi));
  }
  return factory_.NewCall(ops_[mid], function_, std::move(arguments));
}

// Converts the ANTLR parse tree for a CEL expression into cel::Expr nodes built
// through factory_. It also acts as an ANTLR error listener: it overrides
// BaseErrorListener::syntaxError.
class ParserVisitor final : public CelBaseVisitor,
                            public antlr4::BaseErrorListener {
 public:
  ParserVisitor(const cel::Source& source, int max_recursion_depth,
                const cel::MacroRegistry& macro_registry,
                bool add_macro_calls = false,
                bool enable_optional_syntax = false,
                bool enable_quoted_identifiers = false)
      : source_(source),
        factory_(source_),
        macro_registry_(macro_registry),
        recursion_depth_(0),
        max_recursion_depth_(max_recursion_depth),
        add_macro_calls_(add_macro_calls),
        enable_optional_syntax_(enable_optional_syntax),
        enable_quoted_identifiers_(enable_quoted_identifiers) {}

  ~ParserVisitor() override = default;

  std::any visit(antlr4::tree::ParseTree* tree) override;

  std::any visitStart(CelParser::StartContext* ctx) override;
  std::any visitExpr(CelParser::ExprContext* ctx) override;
  std::any visitConditionalOr(CelParser::ConditionalOrContext* ctx) override;
  std::any visitConditionalAnd(CelParser::ConditionalAndContext* ctx) override;
  std::any visitRelation(CelParser::RelationContext* ctx) override;
  std::any visitCalc(CelParser::CalcContext* ctx) override;
  std::any visitUnary(CelParser::UnaryContext* ctx);
  std::any visitLogicalNot(CelParser::LogicalNotContext* ctx) override;
  std::any visitNegate(CelParser::NegateContext* ctx) override;
  std::any visitSelect(CelParser::SelectContext* ctx) override;
  std::any visitMemberCall(CelParser::MemberCallContext* ctx) override;
  std::any visitIndex(CelParser::IndexContext* ctx) override;
  std::any visitCreateMessage(CelParser::CreateMessageContext* ctx) override;
  std::any visitFieldInitializerList(
      CelParser::FieldInitializerListContext* ctx) override;
  std::vector visitFields(
      CelParser::FieldInitializerListContext* ctx);
  std::any visitGlobalCall(CelParser::GlobalCallContext* ctx) override;
  std::any visitIdent(CelParser::IdentContext* ctx) override;
  std::any visitNested(CelParser::NestedContext* ctx) override;
  std::any visitCreateList(CelParser::CreateListContext* ctx) override;
  std::vector visitList(CelParser::ListInitContext* ctx);
  std::vector visitList(CelParser::ExprListContext* ctx);
  std::any visitCreateMap(CelParser::CreateMapContext* ctx) override;
  std::any visitConstantLiteral(
      CelParser::ConstantLiteralContext* ctx) override;
  std::any visitPrimaryExpr(CelParser::PrimaryExprContext* ctx) override;
  std::any visitMemberExpr(CelParser::MemberExprContext* ctx) override;
  std::any visitMapInitializerList(
      CelParser::MapInitializerListContext* ctx) override;
  std::vector visitEntries(
      CelParser::MapInitializerListContext* ctx);
  std::any visitInt(CelParser::IntContext* ctx) override;
  std::any visitUint(CelParser::UintContext* ctx) override;
  std::any visitDouble(CelParser::DoubleContext* ctx) override;
  std::any visitString(CelParser::StringContext* ctx) override;
  std::any visitBytes(CelParser::BytesContext* ctx) override;
  std::any visitBoolTrue(CelParser::BoolTrueContext* ctx) override;
  std::any visitBoolFalse(CelParser::BoolFalseContext* ctx) override;
  std::any visitNull(CelParser::NullContext* ctx) override;

  // Note: this is destructive and intended to be called after the parse is
  // finished.
  cel::SourceInfo GetSourceInfo();
  EnrichedSourceInfo enriched_source_info() const;
  void syntaxError(antlr4::Recognizer* recognizer,
                   antlr4::Token* offending_symbol, size_t line, size_t col,
                   const std::string& msg, std::exception_ptr e) override;
  bool HasErrored() const;

  std::string ErrorMessage();

 private:
  // Packs the operands into an argument vector and forwards to
  // GlobalCallOrMacroImpl.
  template
  Expr GlobalCallOrMacro(int64_t expr_id, absl::string_view function,
                         Args&&... args) {
    std::vector arguments;
    arguments.reserve(sizeof...(Args));
    (arguments.push_back(std::forward(args)), ...);
    return GlobalCallOrMacroImpl(expr_id, function, std::move(arguments));
  }

  Expr GlobalCallOrMacroImpl(int64_t expr_id, absl::string_view function,
                             std::vector args);
  Expr ReceiverCallOrMacroImpl(int64_t expr_id, absl::string_view function,
                               Expr target, std::vector args);
  std::string ExtractQualifiedName(antlr4::ParserRuleContext* ctx,
                                   const Expr& e);
  std::string NormalizeIdentifier(CelParser::EscapeIdentContext* ctx);

  // Attempt to unnest parse context.
  //
  // Walk the parse tree to the first complex term to reduce recursive depth in
  // the visit* calls.
  antlr4::tree::ParseTree* UnnestContext(antlr4::tree::ParseTree* tree);

 private:
  const cel::Source& source_;
  cel::ParserMacroExprFactory factory_;
  const cel::MacroRegistry& macro_registry_;
  // Current depth of nested visit() calls, maintained by ScopedIncrement.
  int recursion_depth_;
  const int max_recursion_depth_;
  const bool add_macro_calls_;
  const bool enable_optional_syntax_;
  const bool enable_quoted_identifiers_;
};

// Checked downcast of a parse-tree node (dynamic_cast). Returns nullptr when
// `tree` is not of the requested type.
template ::value>>
T* tree_as(antlr4::tree::ParseTree* tree) {
  return dynamic_cast(tree);
}

// Central dispatch. It reports an error once max_recursion_depth_ is exceeded,
// collapses operator-less wrapper contexts via UnnestContext, and then calls
// the visit* method for the concrete context type.
std::any ParserVisitor::visit(antlr4::tree::ParseTree* tree) {
  ScopedIncrement inc(recursion_depth_);
  if (recursion_depth_ > max_recursion_depth_) {
    return ExprToAny(factory_.ReportError(
        absl::StrFormat("Exceeded max recursion depth of %d when parsing.",
                        max_recursion_depth_)));
  }
  tree = UnnestContext(tree);
  if (auto* ctx = tree_as(tree)) {
    return visitStart(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitExpr(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitConditionalAnd(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitConditionalOr(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitRelation(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitCalc(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitLogicalNot(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitPrimaryExpr(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitMemberExpr(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitSelect(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitMemberCall(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitMapInitializerList(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitNegate(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitIndex(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitUnary(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitCreateList(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitCreateMessage(ctx);
  } else if (auto* ctx = tree_as(tree)) {
    return visitCreateMap(ctx);
  }

  if (tree) {
    return ExprToAny(
        factory_.ReportError(SourceRangeFromParserRuleContext(
                                 tree_as(tree)),
                             "unknown parsetree type"));
  }
  return ExprToAny(factory_.ReportError("<> parsetree"));
}

// Dispatches on the concrete primary context type.
std::any ParserVisitor::visitPrimaryExpr(CelParser::PrimaryExprContext* pctx) {
  CelParser::PrimaryContext* primary = pctx->primary();
  if (auto* ctx = tree_as(primary)) {
    return visitNested(ctx);
  } else if (auto* ctx = tree_as(primary)) {
    return visitIdent(ctx);
  } else if (auto* ctx = tree_as(primary)) {
    return visitGlobalCall(ctx);
  } else if (auto* ctx = tree_as(primary)) {
    return visitCreateList(ctx);
  } else if (auto* ctx = tree_as(primary)) {
    return visitCreateMap(ctx);
  } else if (auto* ctx = tree_as(primary)) {
    return visitCreateMessage(ctx);
  } else if (auto* ctx = tree_as(primary)) {
    return visitConstantLiteral(ctx);
  }
  if (factory_.HasErrors()) {
    // ANTLR creates PrimaryContext rather than a derived class during certain
    // error conditions. This is odd, but we ignore it as we already have errors
    // that occurred.
    return ExprToAny(factory_.NewUnspecified(factory_.NextId({})));
  }
  return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(pctx),
                                        "invalid primary expression"));
}

// Dispatches on the concrete member context type.
std::any ParserVisitor::visitMemberExpr(CelParser::MemberExprContext* mctx) {
  CelParser::MemberContext* member = mctx->member();
  if (auto* ctx = tree_as(member)) {
    return visitPrimaryExpr(ctx);
  } else if (auto* ctx = tree_as(member)) {
    return visitSelect(ctx);
  } else if (auto* ctx = tree_as(member)) {
    return visitMemberCall(ctx);
  } else if (auto* ctx = tree_as(member)) {
    return visitIndex(ctx);
  }
  return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(mctx),
                                        "unsupported simple expression"));
}

std::any ParserVisitor::visitStart(CelParser::StartContext* ctx) {
  return visit(ctx->expr());
}

// Repeatedly steps from a context to its only child while the context carries
// no operator of its own (op is null, ops is empty, or there is a single
// calc()/unary()/member() child). It also steps through parenthesized
// (nested) primaries. Returns the first node that has its own operator, or the
// node at which no further progress is made.
antlr4::tree::ParseTree* ParserVisitor::UnnestContext(
    antlr4::tree::ParseTree* tree) {
  antlr4::tree::ParseTree* last = nullptr;
  while (tree != last) {
    last = tree;

    if (auto* ctx = tree_as(tree)) {
      tree = ctx->expr();
    }

    if (auto* ctx = tree_as(tree)) {
      if (ctx->op != nullptr) {
        return ctx;
      }
      tree = ctx->e;
    }

    if (auto* ctx = tree_as(tree)) {
      if (!ctx->ops.empty()) {
        return ctx;
      }
      tree = ctx->e;
    }

    if (auto* ctx = tree_as(tree)) {
      if (!ctx->ops.empty()) {
        return ctx;
      }
      tree = ctx->e;
    }

    if (auto* ctx = tree_as(tree)) {
      if (ctx->calc() == nullptr) {
        return ctx;
      }
      tree = ctx->calc();
    }

    if (auto* ctx = tree_as(tree)) {
      if (ctx->unary() == nullptr) {
        return ctx;
      }
      tree = ctx->unary();
    }

    if (auto* ctx = tree_as(tree)) {
      tree = ctx->member();
    }

    if (auto* ctx = tree_as(tree)) {
      if (auto* nested = tree_as(ctx->primary())) {
        tree = nested->e;
      } else {
        return ctx;
      }
    }
  }

  return tree;
}

// Ternary conditional: returns the condition alone when there is no operator.
// Otherwise builds a CONDITIONAL call with (condition, e1, e2); the call id is
// positioned at the operator token.
std::any ParserVisitor::visitExpr(CelParser::ExprContext* ctx) {
  auto result = ExprFromAny(visit(ctx->e));
  if (!ctx->op) {
    return ExprToAny(std::move(result));
  }
  std::vector arguments;
  arguments.reserve(3);
  arguments.push_back(std::move(result));
  int64_t op_id = factory_.NextId(SourceRangeFromToken(ctx->op));
  arguments.push_back(ExprFromAny(visit(ctx->e1)));
  arguments.push_back(ExprFromAny(visit(ctx->e2)));

  return ExprToAny(
      factory_.NewCall(op_id, CelOperator::CONDITIONAL, std::move(arguments)));
}

// '||' chain: the first operand and every following operand feed an
// ExpressionBalancer so long chains become a balanced call tree. More
// operators than operands indicates a syntax error.
std::any ParserVisitor::visitConditionalOr(
    CelParser::ConditionalOrContext* ctx) {
  auto result = ExprFromAny(visit(ctx->e));
  if (ctx->ops.empty()) {
    return ExprToAny(std::move(result));
  }
  ExpressionBalancer b(factory_, CelOperator::LOGICAL_OR, std::move(result));
  for (size_t i = 0; i < ctx->ops.size(); ++i) {
    auto op = ctx->ops[i];
    if (i >= ctx->e1.size()) {
      return ExprToAny(
          factory_.ReportError(SourceRangeFromParserRuleContext(ctx),
                               "unexpected character, wanted '||'"));
    }
    auto next = ExprFromAny(visit(ctx->e1[i]));
    int64_t op_id = factory_.NextId(SourceRangeFromToken(op));
    b.AddTerm(op_id, std::move(next));
  }
  return ExprToAny(b.Balance());
}

// '&&' chain; same structure as visitConditionalOr.
std::any ParserVisitor::visitConditionalAnd(
    CelParser::ConditionalAndContext* ctx) {
  auto result = ExprFromAny(visit(ctx->e));
  if (ctx->ops.empty()) {
    return ExprToAny(std::move(result));
  }
  ExpressionBalancer b(factory_, CelOperator::LOGICAL_AND, std::move(result));
  for (size_t i = 0; i < ctx->ops.size(); ++i) {
    auto op = ctx->ops[i];
    if (i >= ctx->e1.size()) {
      return ExprToAny(
          factory_.ReportError(SourceRangeFromParserRuleContext(ctx),
                               "unexpected character, wanted '&&'"));
    }
    auto next = ExprFromAny(visit(ctx->e1[i]));
    int64_t op_id = factory_.NextId(SourceRangeFromToken(op));
    b.AddTerm(op_id, std::move(next));
  }
  return ExprToAny(b.Balance());
}

// Relational operator. A lone calc() child is visited directly. Otherwise the
// operator text is mapped with ReverseLookupOperator and a binary call is
// built; "operator not found" is reported if the lookup fails.
std::any ParserVisitor::visitRelation(CelParser::RelationContext* ctx) {
  if (ctx->calc()) {
    return visit(ctx->calc());
  }
  std::string op_text;
  if (ctx->op) {
    op_text = ctx->op->getText();
  }
  auto op = ReverseLookupOperator(op_text);
  if (op) {
    auto lhs = ExprFromAny(visit(ctx->relation(0)));
    int64_t op_id = factory_.NextId(SourceRangeFromToken(ctx->op));
    auto rhs = ExprFromAny(visit(ctx->relation(1)));
    return ExprToAny(
        GlobalCallOrMacro(op_id, *op, std::move(lhs), std::move(rhs)));
  }
  return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx),
                                        "operator not found"));
}

// Arithmetic operator; same structure as visitRelation, over unary()/calc().
std::any ParserVisitor::visitCalc(CelParser::CalcContext* ctx) {
  if (ctx->unary()) {
    return visit(ctx->unary());
  }
  std::string op_text;
  if (ctx->op) {
    op_text = ctx->op->getText();
  }
  auto op = ReverseLookupOperator(op_text);
  if (op) {
    auto lhs = ExprFromAny(visit(ctx->calc(0)));
    int64_t op_id = factory_.NextId(SourceRangeFromToken(ctx->op));
    auto rhs = ExprFromAny(visit(ctx->calc(1)));
    return ExprToAny(
        GlobalCallOrMacro(op_id, *op, std::move(lhs), std::move(rhs)));
  }
  return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx),
                                        "operator not found"));
}

// Fallback for a bare UnaryContext: yields a placeholder string constant
// rather than a real expression.
std::any ParserVisitor::visitUnary(CelParser::UnaryContext* ctx) {
  return ExprToAny(factory_.NewStringConst(
      factory_.NextId(SourceRangeFromParserRuleContext(ctx)), "<>"));
}

// An even number of '!' cancels out. An odd number yields a single
// LOGICAL_NOT call positioned at the first '!'.
std::any ParserVisitor::visitLogicalNot(CelParser::LogicalNotContext* ctx) {
  if (ctx->ops.size() % 2 == 0) {
    return visit(ctx->member());
  }
  int64_t op_id = factory_.NextId(SourceRangeFromToken(ctx->ops[0]));
  auto target = ExprFromAny(visit(ctx->member()));
  return ExprToAny(
      GlobalCallOrMacro(op_id, CelOperator::LOGICAL_NOT, std::move(target)));
}

// Same as visitLogicalNot, for unary '-' (NEGATE).
std::any ParserVisitor::visitNegate(CelParser::NegateContext* ctx) {
  if (ctx->ops.size() % 2 == 0) {
    return visit(ctx->member());
  }
  int64_t op_id = factory_.NextId(SourceRangeFromToken(ctx->ops[0]));
  auto target = ExprFromAny(visit(ctx->member()));
  return ExprToAny(
      GlobalCallOrMacro(op_id, CelOperator::NEGATE, std::move(target)));
}

// Returns the identifier text. For a backquoted identifier the surrounding
// quote characters are stripped. If quoted identifiers are disabled an error
// is reported, but the stripped name is still returned.
std::string ParserVisitor::NormalizeIdentifier(
    CelParser::EscapeIdentContext* ctx) {
  if (auto* raw_id = tree_as(ctx); raw_id) {
    return raw_id->id->getText();
  }
  if (auto* escaped_id = tree_as(ctx);
      escaped_id) {
    if (!enable_quoted_identifiers_) {
      factory_.ReportError(SourceRangeFromParserRuleContext(ctx),
                           "unsupported syntax '`'");
    }
    auto escaped_id_text = escaped_id->id->getText();
    return escaped_id_text.substr(1, escaped_id_text.size() - 2);
  }

  // Fallthrough might occur if the parser is in an error state.
  return "";
}

// Field selection `operand.field`. With `.?` (optional syntax enabled) it
// builds an "_?._" call whose arguments are the operand and the field name as
// a string constant.
std::any ParserVisitor::visitSelect(CelParser::SelectContext* ctx) {
  auto operand = ExprFromAny(visit(ctx->member()));
  // Handle the error case where no valid identifier is specified.
  if (!ctx->id || !ctx->op) {
    return ExprToAny(factory_.NewUnspecified(
        factory_.NextId(SourceRangeFromParserRuleContext(ctx))));
  }
  auto id = NormalizeIdentifier(ctx->id);
  if (ctx->opt != nullptr) {
    if (!enable_optional_syntax_) {
      return ExprToAny(factory_.ReportError(
          SourceRangeFromParserRuleContext(ctx), "unsupported syntax '.?'"));
    }
    auto op_id = factory_.NextId(SourceRangeFromToken(ctx->op));
    std::vector arguments;
    arguments.reserve(2);
    arguments.push_back(std::move(operand));
    arguments.push_back(factory_.NewStringConst(
        factory_.NextId(SourceRangeFromParserRuleContext(ctx)), std::move(id)));
    return ExprToAny(factory_.NewCall(op_id, "_?._", std::move(arguments)));
  }
  return ExprToAny(
      factory_.NewSelect(factory_.NextId(SourceRangeFromToken(ctx->op)),
                         std::move(operand), std::move(id)));
}

// Receiver-style call `operand.id(args)`; the call id is positioned at the
// opening parenthesis token (ctx->open).
std::any ParserVisitor::visitMemberCall(CelParser::MemberCallContext* ctx) {
  auto operand = ExprFromAny(visit(ctx->member()));
  // Handle the error case where no valid identifier is specified.
  if (!ctx->id) {
    return ExprToAny(factory_.NewUnspecified(
        factory_.NextId(SourceRangeFromParserRuleContext(ctx))));
  }
  auto id = ctx->id->getText();
  int64_t op_id = factory_.NextId(SourceRangeFromToken(ctx->open));
  auto args = visitList(ctx->args);
  return ExprToAny(
      ReceiverCallOrMacroImpl(op_id, id, std::move(operand), std::move(args)));
}

// Indexing: `target[index]` becomes an INDEX call, and the optional form
// (ctx->opt set) becomes "_[?_]".
// NOTE(review): the error message below names '.?', but this is the
// optional-index form (`[?`, cf. the "_[?_]" function); the message most
// likely should read "unsupported syntax '[?'".
std::any ParserVisitor::visitIndex(CelParser::IndexContext* ctx) {
  auto target = ExprFromAny(visit(ctx->member()));
  int64_t op_id = factory_.NextId(SourceRangeFromToken(ctx->op));
  auto index = ExprFromAny(visit(ctx->index));
  if (!enable_optional_syntax_ && ctx->opt != nullptr) {
    return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx),
                                          "unsupported syntax '.?'"));
  }
  return ExprToAny(GlobalCallOrMacro(
      op_id, ctx->opt != nullptr ? "_[?_]" : CelOperator::INDEX,
      std::move(target), std::move(index)));
}

// Message construction `Name{field: value, ...}`. The type name is the ids
// joined with '.', prefixed with '.' when a leading dot was written.
std::any ParserVisitor::visitCreateMessage(
    CelParser::CreateMessageContext* ctx) {
  std::vector parts;
  parts.reserve(ctx->ids.size());
  for (const auto* id : ctx->ids) {
    parts.push_back(id->getText());
  }
  std::string name;
  if (ctx->leadingDot) {
    name.push_back('.');
    name.append(absl::StrJoin(parts, "."));
  } else {
    name = absl::StrJoin(parts, ".");
  }
  int64_t obj_id = factory_.NextId(SourceRangeFromToken(ctx->op));
  std::vector fields;
  if (ctx->entries) {
    fields = visitFields(ctx->entries);
  }
  return ExprToAny(
      factory_.NewStruct(obj_id, std::move(name), std::move(fields)));
}

// Not expected to be reached through the generic visitor (fields are built by
// visitFields); reports an error if it is.
std::any ParserVisitor::visitFieldInitializerList(
    CelParser::FieldInitializerListContext* ctx) {
  return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx),
                                        "<>"));
}

// Builds the struct fields of a message initializer. Stops early if the
// field, ':' and value lists are out of step (a syntax error reported
// elsewhere). Optional ('?') fields are reported and skipped unless optional
// syntax is enabled.
std::vector ParserVisitor::visitFields(
    CelParser::FieldInitializerListContext* ctx) {
  std::vector res;
  if (!ctx || ctx->fields.empty()) {
    return res;
  }

  res.reserve(ctx->fields.size());
  for (size_t i = 0; i < ctx->fields.size(); ++i) {
    if (i >= ctx->cols.size() || i >= ctx->values.size()) {
      // This is the result of a syntax error detected elsewhere.
      return res;
    }
    auto* f = ctx->fields[i];
    if (!f->escapeIdent()) {
      ABSL_DCHECK(HasErrored());
      // This is the result of a syntax error detected elsewhere.
      return res;
    }

    std::string id = NormalizeIdentifier(f->escapeIdent());

    int64_t init_id = factory_.NextId(SourceRangeFromToken(ctx->cols[i]));
    if (!enable_optional_syntax_ && f->opt) {
      factory_.ReportError(SourceRangeFromParserRuleContext(ctx),
                           "unsupported syntax '?'");
      continue;
    }
    auto value = ExprFromAny(visit(ctx->values[i]));
    res.push_back(factory_.NewStructField(init_id, std::move(id),
                                          std::move(value), f->opt != nullptr));
  }

  return res;
}

// Identifier reference, with an optional leading '.'; reserved identifiers
// are reported as errors.
std::any ParserVisitor::visitIdent(CelParser::IdentContext* ctx) {
  std::string ident_name;
  if (ctx->leadingDot) {
    ident_name = ".";
  }
  if (!ctx->id) {
    return ExprToAny(factory_.NewUnspecified(
        factory_.NextId(SourceRangeFromParserRuleContext(ctx))));
  }
  // check if ID is in reserved identifiers
  if (cel::internal::LexisIsReserved(ctx->id->getText())) {
    return ExprToAny(factory_.ReportError(
        SourceRangeFromParserRuleContext(ctx),
        absl::StrFormat("reserved identifier: %s", ctx->id->getText())));
  }

  ident_name += ctx->id->getText();

  return ExprToAny(factory_.NewIdent(
      factory_.NextId(SourceRangeFromToken(ctx->id)), std::move(ident_name)));
}

// Global function call `name(args)` (optionally with a leading '.'). Reserved
// identifiers are rejected; the call may be expanded as a macro by
// GlobalCallOrMacroImpl.
std::any ParserVisitor::visitGlobalCall(CelParser::GlobalCallContext* ctx) {
  std::string ident_name;
  if (ctx->leadingDot) {
    ident_name = ".";
  }
  if (!ctx->id || !ctx->op) {
    return ExprToAny(factory_.NewUnspecified(
        factory_.NextId(SourceRangeFromParserRuleContext(ctx))));
  }
  // check if ID is in reserved identifiers
  if (cel::internal::LexisIsReserved(ctx->id->getText())) {
    return ExprToAny(factory_.ReportError(
        SourceRangeFromParserRuleContext(ctx),
        absl::StrFormat("reserved identifier: %s", ctx->id->getText())));
  }

  ident_name += ctx->id->getText();

  int64_t op_id = factory_.NextId(SourceRangeFromToken(ctx->op));
  auto args = visitList(ctx->args);
  return ExprToAny(
      GlobalCallOrMacroImpl(op_id, std::move(ident_name), std::move(args)));
}

std::any ParserVisitor::visitNested(CelParser::NestedContext* ctx) {
  return visit(ctx->e);
}

// List literal `[a, b, ...]`; the list id is positioned at the opening token.
std::any ParserVisitor::visitCreateList(CelParser::CreateListContext* ctx) {
  int64_t list_id = factory_.NextId(SourceRangeFromToken(ctx->op));
  auto elems = visitList(ctx->elems);
  return ExprToAny(factory_.NewList(list_id, std::move(elems)));
}

// Builds list elements, stopping at the first null element context. When
// optional syntax is disabled, a '?' element is reported and replaced by an
// unspecified expression with id 0.
std::vector ParserVisitor::visitList(
    CelParser::ListInitContext* ctx) {
  std::vector rv;
  if (!ctx) return rv;
  rv.reserve(ctx->elems.size());
  for (size_t i = 0; i < ctx->elems.size(); ++i) {
    auto* expr_ctx = ctx->elems[i];
    if (expr_ctx == nullptr) {
      return rv;
    }
    if (!enable_optional_syntax_ && expr_ctx->opt != nullptr) {
      factory_.ReportError(SourceRangeFromParserRuleContext(ctx),
                           "unsupported syntax '?'");
      rv.push_back(factory_.NewListElement(factory_.NewUnspecified(0), false));
      continue;
    }
    rv.push_back(factory_.NewListElement(ExprFromAny(visitExpr(expr_ctx->e)),
                                         expr_ctx->opt != nullptr));
  }
  return rv;
}

// Converts each expression in an argument list; a null context yields an
// empty list.
std::vector ParserVisitor::visitList(CelParser::ExprListContext* ctx) {
  std::vector rv;
  if (!ctx) return rv;
  std::transform(ctx->e.begin(), ctx->e.end(), std::back_inserter(rv),
                 [this](CelParser::ExprContext* expr_ctx) {
                   return ExprFromAny(visitExpr(expr_ctx));
                 });
  return rv;
}

// Map literal `{k: v, ...}`; the map id is positioned at the opening token.
std::any ParserVisitor::visitCreateMap(CelParser::CreateMapContext* ctx) {
  int64_t struct_id = factory_.NextId(SourceRangeFromToken(ctx->op));
  std::vector entries;
  if (ctx->entries) {
    entries = visitEntries(ctx->entries);
  }
  return ExprToAny(factory_.NewMap(struct_id, std::move(entries)));
}

// Dispatches on the concrete literal context type.
std::any ParserVisitor::visitConstantLiteral(
    CelParser::ConstantLiteralContext* clctx) {
  CelParser::LiteralContext* literal = clctx->literal();
  if (auto* ctx = tree_as(literal)) {
    return visitInt(ctx);
  } else if (auto* ctx = tree_as(literal)) {
    return visitUint(ctx);
  } else if (auto* ctx = tree_as(literal)) {
    return visitDouble(ctx);
  } else if (auto* ctx = tree_as(literal)) {
    return visitString(ctx);
  } else if (auto* ctx = tree_as(literal)) {
    return visitBytes(ctx);
  } else if (auto* ctx = tree_as(literal)) {
    return visitBoolFalse(ctx);
  } else if (auto* ctx = tree_as(literal)) {
    return visitBoolTrue(ctx);
  } else if (auto* ctx =
tree_as(literal)) { return visitNull(ctx); } return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(clctx), "invalid constant literal expression")); } std::any ParserVisitor::visitMapInitializerList( CelParser::MapInitializerListContext* ctx) { return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx), "<>")); } std::vector ParserVisitor::visitEntries( CelParser::MapInitializerListContext* ctx) { std::vector res; if (!ctx || ctx->keys.empty()) { return res; } res.reserve(ctx->cols.size()); for (size_t i = 0; i < ctx->cols.size(); ++i) { auto id = factory_.NextId(SourceRangeFromToken(ctx->cols[i])); if (!enable_optional_syntax_ && ctx->keys[i]->opt) { factory_.ReportError(SourceRangeFromParserRuleContext(ctx), "unsupported syntax '?'"); res.push_back(factory_.NewMapEntry(0, factory_.NewUnspecified(0), factory_.NewUnspecified(0), false)); continue; } auto key = ExprFromAny(visit(ctx->keys[i]->e)); auto value = ExprFromAny(visit(ctx->values[i])); res.push_back(factory_.NewMapEntry(id, std::move(key), std::move(value), ctx->keys[i]->opt != nullptr)); } return res; } std::any ParserVisitor::visitInt(CelParser::IntContext* ctx) { std::string value; if (ctx->sign) { value = ctx->sign->getText(); } value += ctx->tok->getText(); int64_t int_value; if (absl::StartsWith(ctx->tok->getText(), "0x")) { if (absl::SimpleHexAtoi(value, &int_value)) { return ExprToAny(factory_.NewIntConst( factory_.NextId(SourceRangeFromParserRuleContext(ctx)), int_value)); } else { return ExprToAny(factory_.ReportError( SourceRangeFromParserRuleContext(ctx), "invalid hex int literal")); } } if (absl::SimpleAtoi(value, &int_value)) { return ExprToAny(factory_.NewIntConst( factory_.NextId(SourceRangeFromParserRuleContext(ctx)), int_value)); } else { return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx), "invalid int literal")); } } std::any ParserVisitor::visitUint(CelParser::UintContext* ctx) { std::string value = ctx->tok->getText(); // trim 
the 'u' designator included in the uint literal. if (!value.empty()) { value.resize(value.size() - 1); } uint64_t uint_value; if (absl::StartsWith(ctx->tok->getText(), "0x")) { if (absl::SimpleHexAtoi(value, &uint_value)) { return ExprToAny(factory_.NewUintConst( factory_.NextId(SourceRangeFromParserRuleContext(ctx)), uint_value)); } else { return ExprToAny(factory_.ReportError( SourceRangeFromParserRuleContext(ctx), "invalid hex uint literal")); } } if (absl::SimpleAtoi(value, &uint_value)) { return ExprToAny(factory_.NewUintConst( factory_.NextId(SourceRangeFromParserRuleContext(ctx)), uint_value)); } else { return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx), "invalid uint literal")); } } std::any ParserVisitor::visitDouble(CelParser::DoubleContext* ctx) { std::string value; if (ctx->sign) { value = ctx->sign->getText(); } value += ctx->tok->getText(); double double_value; if (absl::SimpleAtod(value, &double_value)) { return ExprToAny(factory_.NewDoubleConst( factory_.NextId(SourceRangeFromParserRuleContext(ctx)), double_value)); } else { return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx), "invalid double literal")); } } std::any ParserVisitor::visitString(CelParser::StringContext* ctx) { auto status_or_value = cel::internal::ParseStringLiteral(ctx->tok->getText()); if (!status_or_value.ok()) { return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx), status_or_value.status().message())); } return ExprToAny(factory_.NewStringConst( factory_.NextId(SourceRangeFromParserRuleContext(ctx)), std::move(status_or_value).value())); } std::any ParserVisitor::visitBytes(CelParser::BytesContext* ctx) { auto status_or_value = cel::internal::ParseBytesLiteral(ctx->tok->getText()); if (!status_or_value.ok()) { return ExprToAny(factory_.ReportError(SourceRangeFromParserRuleContext(ctx), status_or_value.status().message())); } return ExprToAny(factory_.NewBytesConst( 
factory_.NextId(SourceRangeFromParserRuleContext(ctx)), std::move(status_or_value).value())); } std::any ParserVisitor::visitBoolTrue(CelParser::BoolTrueContext* ctx) { return ExprToAny(factory_.NewBoolConst( factory_.NextId(SourceRangeFromParserRuleContext(ctx)), true)); } std::any ParserVisitor::visitBoolFalse(CelParser::BoolFalseContext* ctx) { return ExprToAny(factory_.NewBoolConst( factory_.NextId(SourceRangeFromParserRuleContext(ctx)), false)); } std::any ParserVisitor::visitNull(CelParser::NullContext* ctx) { return ExprToAny(factory_.NewNullConst( factory_.NextId(SourceRangeFromParserRuleContext(ctx)))); } cel::SourceInfo ParserVisitor::GetSourceInfo() { cel::SourceInfo source_info; source_info.set_location(std::string(source_.description())); for (const auto& positions : factory_.positions()) { source_info.mutable_positions().insert( std::pair{positions.first, positions.second.begin}); } source_info.mutable_line_offsets().reserve(source_.line_offsets().size()); for (const auto& line_offset : source_.line_offsets()) { source_info.mutable_line_offsets().push_back(line_offset); } source_info.mutable_macro_calls() = factory_.release_macro_calls(); return source_info; } EnrichedSourceInfo ParserVisitor::enriched_source_info() const { std::map> offsets; for (const auto& positions : factory_.positions()) { offsets.insert( std::pair{positions.first, std::pair{positions.second.begin, positions.second.end - 1}}); } return EnrichedSourceInfo(std::move(offsets)); } void ParserVisitor::syntaxError(antlr4::Recognizer* recognizer, antlr4::Token* offending_symbol, size_t line, size_t col, const std::string& msg, std::exception_ptr e) { cel::SourceRange range; if (auto position = source_.GetPosition(cel::SourceLocation{ static_cast(line), static_cast(col)}); position) { range.begin = *position; } factory_.ReportError(range, absl::StrCat("Syntax error: ", msg)); } bool ParserVisitor::HasErrored() const { return factory_.HasErrors(); } std::string 
ParserVisitor::ErrorMessage() { return factory_.ErrorMessage(); } Expr ParserVisitor::GlobalCallOrMacroImpl(int64_t expr_id, absl::string_view function, std::vector args) { if (auto macro = macro_registry_.FindMacro(function, args.size(), false); macro) { std::vector macro_args; if (add_macro_calls_) { macro_args.reserve(args.size()); for (const auto& arg : args) { macro_args.push_back(factory_.BuildMacroCallArg(arg)); } } factory_.BeginMacro(factory_.GetSourceRange(expr_id)); auto expr = macro->Expand(factory_, absl::nullopt, absl::MakeSpan(args)); factory_.EndMacro(); if (expr) { if (add_macro_calls_) { factory_.AddMacroCall(expr->id(), function, absl::nullopt, std::move(macro_args)); } // We did not end up using `expr_id`. Delete metadata. factory_.EraseId(expr_id); return std::move(*expr); } } return factory_.NewCall(expr_id, function, std::move(args)); } Expr ParserVisitor::ReceiverCallOrMacroImpl(int64_t expr_id, absl::string_view function, Expr target, std::vector args) { if (auto macro = macro_registry_.FindMacro(function, args.size(), true); macro) { Expr macro_target; std::vector macro_args; if (add_macro_calls_) { macro_args.reserve(args.size()); macro_target = factory_.BuildMacroCallArg(target); for (const auto& arg : args) { macro_args.push_back(factory_.BuildMacroCallArg(arg)); } } factory_.BeginMacro(factory_.GetSourceRange(expr_id)); auto expr = macro->Expand(factory_, std::ref(target), absl::MakeSpan(args)); factory_.EndMacro(); if (expr) { if (add_macro_calls_) { factory_.AddMacroCall(expr->id(), function, std::move(macro_target), std::move(macro_args)); } // We did not end up using `expr_id`. Delete metadata. 
factory_.EraseId(expr_id); return std::move(*expr); } } return factory_.NewMemberCall(expr_id, function, std::move(target), std::move(args)); } std::string ParserVisitor::ExtractQualifiedName(antlr4::ParserRuleContext* ctx, const Expr& e) { if (e == Expr{}) { return ""; } if (const auto* ident_expr = absl::get_if(&e.kind()); ident_expr) { return ident_expr->name(); } if (const auto* select_expr = absl::get_if(&e.kind()); select_expr) { std::string prefix = ExtractQualifiedName(ctx, select_expr->operand()); if (!prefix.empty()) { return absl::StrCat(prefix, ".", select_expr->field()); } } factory_.ReportError(factory_.GetSourceRange(e.id()), "expected a qualified name"); return ""; } // Replacements for absl::StrReplaceAll for escaping standard whitespace // characters. static constexpr auto kStandardReplacements = std::array, 3>{ std::make_pair("\n", "\\n"), std::make_pair("\r", "\\r"), std::make_pair("\t", "\\t"), }; static constexpr absl::string_view kSingleQuote = "'"; // ExprRecursionListener extends the standard ANTLR CelParser to ensure that // recursive entries into the 'expr' rule are limited to a configurable depth so // as to prevent stack overflows. class ExprRecursionListener final : public ParseTreeListener { public: explicit ExprRecursionListener( const int max_recursion_depth = kDefaultMaxRecursionDepth) : max_recursion_depth_(max_recursion_depth), recursion_depth_(0) {} ~ExprRecursionListener() override {} void visitTerminal(TerminalNode* node) override {}; void visitErrorNode(ErrorNode* error) override {}; void enterEveryRule(ParserRuleContext* ctx) override; void exitEveryRule(ParserRuleContext* ctx) override; private: const int max_recursion_depth_; int recursion_depth_; }; void ExprRecursionListener::enterEveryRule(ParserRuleContext* ctx) { // Throw a ParseCancellationException since the parsing would otherwise // continue if this were treated as a syntax error and the problem would // continue to manifest. 
if (ctx->getRuleIndex() == CelParser::RuleExpr) { if (recursion_depth_ > max_recursion_depth_) { throw ParseCancellationException( absl::StrFormat("Expression recursion limit exceeded. limit: %d", max_recursion_depth_)); } recursion_depth_++; } } void ExprRecursionListener::exitEveryRule(ParserRuleContext* ctx) { if (ctx->getRuleIndex() == CelParser::RuleExpr) { recursion_depth_--; } } class RecoveryLimitErrorStrategy final : public DefaultErrorStrategy { public: explicit RecoveryLimitErrorStrategy( int recovery_limit = kDefaultErrorRecoveryLimit, int recovery_token_lookahead_limit = kDefaultErrorRecoveryTokenLookaheadLimit) : recovery_limit_(recovery_limit), recovery_attempts_(0), recovery_token_lookahead_limit_(recovery_token_lookahead_limit) {} void recover(Parser* recognizer, std::exception_ptr e) override { checkRecoveryLimit(recognizer); DefaultErrorStrategy::recover(recognizer, e); } Token* recoverInline(Parser* recognizer) override { checkRecoveryLimit(recognizer); return DefaultErrorStrategy::recoverInline(recognizer); } // Override the ANTLR implementation to introduce a token lookahead limit as // this prevents pathologically constructed, yet small (< 16kb) inputs from // consuming inordinate amounts of compute. // // This method is only called on error recovery paths. void consumeUntil(Parser* recognizer, const IntervalSet& set) override { size_t ttype = recognizer->getInputStream()->LA(1); int recovery_search_depth = 0; while (ttype != Token::EOF && !set.contains(ttype) && recovery_search_depth++ < recovery_token_lookahead_limit_) { recognizer->consume(); ttype = recognizer->getInputStream()->LA(1); } // Halt all parsing if the lookahead limit is reached during error recovery. 
if (recovery_search_depth == recovery_token_lookahead_limit_) { throw ParseCancellationException("Unable to find a recovery token"); } } protected: std::string escapeWSAndQuote(const std::string& s) const override { std::string result; result.reserve(s.size() + 2); absl::StrAppend(&result, kSingleQuote, s, kSingleQuote); absl::StrReplaceAll(kStandardReplacements, &result); return result; } private: void checkRecoveryLimit(Parser* recognizer) { if (recovery_attempts_++ >= recovery_limit_) { std::string too_many_errors = absl::StrFormat("More than %d parse errors.", recovery_limit_); recognizer->notifyErrorListeners(too_many_errors); throw ParseCancellationException(too_many_errors); } } int recovery_limit_; int recovery_attempts_; int recovery_token_lookahead_limit_; }; struct ParseResult { cel::Expr expr; cel::SourceInfo source_info; EnrichedSourceInfo enriched_source_info; }; absl::StatusOr ParseImpl(const cel::Source& source, const cel::MacroRegistry& registry, const ParserOptions& options) { try { CodePointStream input(source.content(), source.description()); if (input.size() > options.expression_size_codepoint_limit) { return absl::InvalidArgumentError(absl::StrCat( "expression size exceeds codepoint limit.", " input size: ", input.size(), ", limit: ", options.expression_size_codepoint_limit)); } CelLexer lexer(&input); CommonTokenStream tokens(&lexer); CelParser parser(&tokens); ExprRecursionListener listener(options.max_recursion_depth); ParserVisitor visitor( source, options.max_recursion_depth, registry, options.add_macro_calls, options.enable_optional_syntax, options.enable_quoted_identifiers); lexer.removeErrorListeners(); parser.removeErrorListeners(); lexer.addErrorListener(&visitor); parser.addErrorListener(&visitor); parser.addParseListener(&listener); // Limit the number of error recovery attempts to prevent bad expressions // from consuming lots of cpu / memory. 
parser.setErrorHandler(std::make_shared( options.error_recovery_limit, options.error_recovery_token_lookahead_limit)); Expr expr; try { expr = ExprFromAny(visitor.visit(parser.start())); } catch (const ParseCancellationException& e) { if (visitor.HasErrored()) { return absl::InvalidArgumentError(visitor.ErrorMessage()); } return absl::CancelledError(e.what()); } if (visitor.HasErrored()) { return absl::InvalidArgumentError(visitor.ErrorMessage()); } return { ParseResult{.expr = std::move(expr), .source_info = visitor.GetSourceInfo(), .enriched_source_info = visitor.enriched_source_info()}}; } catch (const std::exception& e) { return absl::AbortedError(e.what()); } catch (const char* what) { // ANTLRv4 has historically thrown C string literals. return absl::AbortedError(what); } catch (...) { // We guarantee to never throw and always return a status. return absl::UnknownError("An unknown exception occurred"); } } class ParserImpl : public cel::Parser { public: explicit ParserImpl(const ParserOptions& options, cel::MacroRegistry macro_registry, absl::flat_hash_set library_ids) : options_(options), macro_registry_(std::move(macro_registry)), library_ids_(std::move(library_ids)) {} absl::StatusOr> Parse( const cel::Source& source) const override { CEL_ASSIGN_OR_RETURN(auto parse_result, ParseImpl(source, macro_registry_, options_)); return std::make_unique(std::move(parse_result.expr), std::move(parse_result.source_info)); } std::unique_ptr ToBuilder() const override; private: const ParserOptions options_; const cel::MacroRegistry macro_registry_; absl::flat_hash_set library_ids_; }; class ParserBuilderImpl : public cel::ParserBuilder { public: explicit ParserBuilderImpl(const ParserOptions& options) : options_(options) {} ParserOptions& GetOptions() override { return options_; } absl::Status AddMacro(const cel::Macro& macro) override { for (const auto& existing_macro : macros_) { if (existing_macro.key() == macro.key()) { return absl::AlreadyExistsError( 
absl::StrCat("macro already exists: ", macro.key())); } } macros_.push_back(macro); return absl::OkStatus(); } absl::Status AddLibrary(cel::ParserLibrary library) override { if (!library.id.empty()) { auto [it, inserted] = library_ids_.insert(library.id); if (!inserted) { return absl::AlreadyExistsError( absl::StrCat("parser library already exists: ", library.id)); } } libraries_.push_back(std::move(library)); return absl::OkStatus(); } absl::Status AddLibrarySubset(cel::ParserLibrarySubset subset) override { if (subset.library_id.empty()) { return absl::InvalidArgumentError("subset must have a library id"); } std::string library_id = subset.library_id; auto [it, inserted] = library_subsets_.insert({library_id, std::move(subset)}); if (!inserted) { return absl::AlreadyExistsError( absl::StrCat("parser library subset already exists: ", library_id)); } return absl::OkStatus(); } absl::StatusOr> Build() override { using std::swap; // Save the old configured macros so they aren't affected by applying the // libraries and can be restored if an error occurs. std::vector individual_macros; swap(individual_macros, macros_); absl::Cleanup cleanup([&] { swap(macros_, individual_macros); }); cel::MacroRegistry macro_registry; for (const auto& library : libraries_) { CEL_RETURN_IF_ERROR(library.configure(*this)); if (!library.id.empty()) { auto it = library_subsets_.find(library.id); if (it != library_subsets_.end()) { const cel::ParserLibrarySubset& subset = it->second; for (const auto& macro : macros_) { if (subset.should_include_macro(macro)) { CEL_RETURN_IF_ERROR(macro_registry.RegisterMacro(macro)); } } macros_.clear(); continue; } } CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(macros_)); macros_.clear(); } absl::flat_hash_set library_ids(library_ids_); // Hack to support adding the standard library macros either by option or // with a library configurer. 
if (!options_.disable_standard_macros && !library_ids_.contains("stdlib")) { CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(Macro::AllMacros())); library_ids.insert("stdlib"); } if (options_.enable_optional_syntax && !library_ids_.contains("optional")) { CEL_RETURN_IF_ERROR(macro_registry.RegisterMacro(cel::OptMapMacro())); CEL_RETURN_IF_ERROR(macro_registry.RegisterMacro(cel::OptFlatMapMacro())); library_ids.insert("optional"); } CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(individual_macros)); return std::make_unique(options_, std::move(macro_registry), std::move(library_ids)); } private: friend class ParserImpl; ParserOptions options_; std::vector macros_; absl::flat_hash_set library_ids_; std::vector libraries_; absl::flat_hash_map library_subsets_; }; std::unique_ptr ParserImpl::ToBuilder() const { auto ins = std::make_unique(options_); ins->library_ids_ = library_ids_; ins->macros_ = macro_registry_.ListMacros(); return ins; } } // namespace absl::StatusOr Parse(absl::string_view expression, absl::string_view description, const ParserOptions& options) { std::vector macros; if (!options.disable_standard_macros) { macros = Macro::AllMacros(); } if (options.enable_optional_syntax) { macros.push_back(cel::OptMapMacro()); macros.push_back(cel::OptFlatMapMacro()); } return ParseWithMacros(expression, macros, description, options); } absl::StatusOr ParseWithMacros(absl::string_view expression, const std::vector& macros, absl::string_view description, const ParserOptions& options) { CEL_ASSIGN_OR_RETURN(auto verbose_parsed_expr, EnrichedParse(expression, macros, description, options)); return verbose_parsed_expr.parsed_expr(); } absl::StatusOr EnrichedParse( absl::string_view expression, const std::vector& macros, absl::string_view description, const ParserOptions& options) { CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(expression, std::string(description))); cel::MacroRegistry macro_registry; CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(macros)); 
return EnrichedParse(*source, macro_registry, options); } absl::StatusOr EnrichedParse( const cel::Source& source, const cel::MacroRegistry& registry, const ParserOptions& options) { CEL_ASSIGN_OR_RETURN(ParseResult parse_result, ParseImpl(source, registry, options)); ParsedExpr parsed_expr; CEL_RETURN_IF_ERROR(cel::ast_internal::ExprToProto( parse_result.expr, parsed_expr.mutable_expr())); CEL_RETURN_IF_ERROR(cel::ast_internal::SourceInfoToProto( parse_result.source_info, parsed_expr.mutable_source_info())); return VerboseParsedExpr(std::move(parsed_expr), std::move(parse_result.enriched_source_info)); } absl::StatusOr Parse( const cel::Source& source, const cel::MacroRegistry& registry, const ParserOptions& options) { CEL_ASSIGN_OR_RETURN(auto verbose_expr, EnrichedParse(source, registry, options)); return verbose_expr.parsed_expr(); } } // namespace google::api::expr::parser namespace cel { // Creates a new parser builder. // // Intended for use with the Compiler class, most users should prefer the free // functions above for independent parsing of expressions. std::unique_ptr NewParserBuilder(const ParserOptions& options) { return std::make_unique( options); } } // namespace cel ================================================ FILE: parser/parser.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // CEL does not support calling the parser during C++ static initialization. 
// Callers must ensure the parser is only invoked after C++ static initializers
// are run. Failing to do so is undefined behavior. The current reason for this
// is the parser uses ANTLRv4, which also makes no guarantees about being safe
// with regard to C++ static initialization. As such, neither do we.

#ifndef THIRD_PARTY_CEL_CPP_PARSER_PARSER_H_
#define THIRD_PARTY_CEL_CPP_PARSER_PARSER_H_

#include
#include
#include

#include "cel/expr/syntax.pb.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "common/source.h"
#include "parser/macro.h"
#include "parser/macro_registry.h"
#include "parser/options.h"
#include "parser/parser_interface.h"
#include "parser/source_factory.h"

namespace google::api::expr::parser {

// Result of an "enriched" parse: the expression in protobuf form (with its
// SourceInfo) plus the EnrichedSourceInfo produced by the parser.
class VerboseParsedExpr {
 public:
  VerboseParsedExpr(cel::expr::ParsedExpr parsed_expr,
                    EnrichedSourceInfo enriched_source_info)
      : parsed_expr_(std::move(parsed_expr)),
        enriched_source_info_(std::move(enriched_source_info)) {}

  // The parsed expression and its source info in protobuf form.
  const cel::expr::ParsedExpr& parsed_expr() const { return parsed_expr_; }
  // Extra source information; the parser's implementation maps each
  // expression id to the (begin, end - 1) offsets of its source range.
  const EnrichedSourceInfo& enriched_source_info() const {
    return enriched_source_info_;
  }

 private:
  cel::expr::ParsedExpr parsed_expr_;
  EnrichedSourceInfo enriched_source_info_;
};

// Parses `expression` using exactly the given `macros` (nothing is added
// implicitly) and returns the protobuf AST plus enriched source info.
// `description` becomes the source location reported in errors and SourceInfo.
// Syntax errors and inputs exceeding `options.expression_size_codepoint_limit`
// yield an InvalidArgument status.
//
// See comments at the top of the file for information about usage during C++
// static initialization.
absl::StatusOr EnrichedParse(
    absl::string_view expression, const std::vector& macros,
    absl::string_view description = "",
    const ParserOptions& options = ParserOptions());

// Parses `expression` with the standard macros (unless
// `options.disable_standard_macros` is set) and, when
// `options.enable_optional_syntax` is set, the optMap / optFlatMap macros.
//
// See comments at the top of the file for information about usage during C++
// static initialization.
absl::StatusOr Parse(
    absl::string_view expression, absl::string_view description = "",
    const ParserOptions& options = ParserOptions());

// Like Parse(), but registers exactly the given `macros`.
//
// See comments at the top of the file for information about usage during C++
// static initialization.
absl::StatusOr ParseWithMacros(
    absl::string_view expression, const std::vector& macros,
    absl::string_view description = "",
    const ParserOptions& options = ParserOptions());

// Parses an existing `source` using a pre-built macro `registry`.
//
// See comments at the top of the file for information about usage during C++
// static initialization.
absl::StatusOr EnrichedParse(
    const cel::Source& source, const cel::MacroRegistry& registry,
    const ParserOptions& options = ParserOptions());

// Same as the EnrichedParse() overload above, returning only the protobuf AST.
//
// See comments at the top of the file for information about usage during C++
// static initialization.
absl::StatusOr Parse(
    const cel::Source& source, const cel::MacroRegistry& registry,
    const ParserOptions& options = ParserOptions());

}  // namespace google::api::expr::parser

namespace cel {

// Creates a new parser builder.
//
// Intended for use with the Compiler class, most users should prefer the free
// functions above for independent parsing of expressions.
std::unique_ptr NewParserBuilder(
    const ParserOptions& options = {});

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_PARSER_PARSER_H_
================================================
FILE: parser/parser_benchmarks.cc
================================================
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include #include #include "cel/expr/syntax.pb.h" #include "absl/log/absl_check.h" #include "absl/status/status_matchers.h" #include "absl/strings/string_view.h" #include "internal/benchmark.h" #include "internal/testing.h" #include "parser/macro.h" #include "parser/options.h" #include "parser/parser.h" namespace google::api::expr::parser { namespace { using ::absl_testing::IsOk; using ::testing::Not; enum class ParseResult { kSuccess, kError }; struct TestInfo { static TestInfo ErrorCase(absl::string_view expr) { TestInfo info; info.expr = expr; info.result = ParseResult::kError; return info; } // The expression to parse. std::string expr = ""; // The expected result of the parse. ParseResult result = ParseResult::kSuccess; }; const std::vector& GetTestCases() { static const std::vector* kInstance = new std::vector{ // Simple test cases we started with {"x * 2"}, {"x * 2u"}, {"x * 2.0"}, {"\"\\u2764\""}, {"\"\u2764\""}, {"! false"}, {"-a"}, {"a.b(5)"}, {"a[3]"}, {"SomeMessage{foo: 5, bar: \"xyz\"}"}, {"[3, 4, 5]"}, {"{foo: 5, bar: \"xyz\"}"}, {"a > 5 && a < 10"}, {"a < 5 || a > 10"}, TestInfo::ErrorCase("{"), // test cases from Go {"\"A\""}, {"true"}, {"false"}, {"0"}, {"42"}, {"0u"}, {"23u"}, {"24u"}, {"0xAu"}, {"-0xA"}, {"0xA"}, {"-1"}, {"4--4"}, {"4--4.1"}, {"b\"abc\""}, {"23.39"}, {"!a"}, {"a"}, {"a?b:c"}, {"a || b"}, {"a || b || c || d || e || f "}, {"a && b"}, {"a && b && c && d && e && f && g"}, {"a && b && c && d || e && f && g && h"}, {"a + b"}, {"a - b"}, {"a * b"}, {"a / b"}, {"a % b"}, {"a in b"}, {"a == b"}, {"a != b"}, {"a > b"}, {"a >= b"}, {"a < b"}, {"a <= b"}, {"a.b"}, {"a.b.c"}, {"a[b]"}, {"foo{ }"}, {"foo{ a:b }"}, {"foo{ a:b, c:d }"}, {"{}"}, {"{a:b, c:d}"}, {"[]"}, {"[a]"}, {"[a, b, c]"}, {"(a)"}, {"((a))"}, {"a()"}, {"a(b)"}, {"a(b, c)"}, {"a.b()"}, {"a.b(c)"}, {"aaa.bbb(ccc)"}, // Parse error tests TestInfo::ErrorCase("*@a | b"), TestInfo::ErrorCase("a | b"), TestInfo::ErrorCase("?"), TestInfo::ErrorCase("t{>C}"), // Macro tests 
{"has(m.f)"}, {"m.exists_one(v, f)"}, {"m.map(v, f)"}, {"m.map(v, p, f)"}, {"m.filter(v, p)"}, // Tests from Java parser {"[] + [1,2,3,] + [4]"}, {"{1:2u, 2:3u}"}, {"TestAllTypes{single_int32: 1, single_int64: 2}"}, TestInfo::ErrorCase("TestAllTypes(){single_int32: 1, single_int64: 2}"), {"size(x) == x.size()"}, TestInfo::ErrorCase("1 + $"), TestInfo::ErrorCase("1 + 2\n" "3 +"), {"\"\\\"\""}, {"[1,3,4][0]"}, TestInfo::ErrorCase("1.all(2, 3)"), {"x[\"a\"].single_int32 == 23"}, {"x.single_nested_message != null"}, {"false && !true || false ? 2 : 3"}, {"b\"abc\" + B\"def\""}, {"1 + 2 * 3 - 1 / 2 == 6 % 1"}, {"---a"}, TestInfo::ErrorCase("1 + +"), {"\"abc\" + \"def\""}, TestInfo::ErrorCase("{\"a\": 1}.\"a\""), {"\"\\xC3\\XBF\""}, {"\"\\303\\277\""}, {"\"hi\\u263A \\u263Athere\""}, {"\"\\U000003A8\\?\""}, {"\"\\a\\b\\f\\n\\r\\t\\v'\\\"\\\\\\? Legal escapes\""}, TestInfo::ErrorCase("\"\\xFh\""), TestInfo::ErrorCase( "\"\\a\\b\\f\\n\\r\\t\\v\\'\\\"\\\\\\? Illegal escape \\>\""), {"'😁' in ['😁', '😑', '😦']"}, {"'\u00ff' in ['\u00ff', '\u00ff', '\u00ff']"}, {"'\u00ff' in ['\uffff', '\U00100000', '\U0010ffff']"}, {"'\u00ff' in ['\U00100000', '\uffff', '\U0010ffff']"}, TestInfo::ErrorCase("'😁' in ['😁', '😑', '😦']\n" " && in.😁"), TestInfo::ErrorCase("as"), TestInfo::ErrorCase("break"), TestInfo::ErrorCase("const"), TestInfo::ErrorCase("continue"), TestInfo::ErrorCase("else"), TestInfo::ErrorCase("for"), TestInfo::ErrorCase("function"), TestInfo::ErrorCase("if"), TestInfo::ErrorCase("import"), TestInfo::ErrorCase("in"), TestInfo::ErrorCase("let"), TestInfo::ErrorCase("loop"), TestInfo::ErrorCase("package"), TestInfo::ErrorCase("namespace"), TestInfo::ErrorCase("return"), TestInfo::ErrorCase("var"), TestInfo::ErrorCase("void"), TestInfo::ErrorCase("while"), TestInfo::ErrorCase("[1, 2, 3].map(var, var * var)"), TestInfo::ErrorCase("[\n\t\r[\n\t\r[\n\t\r]\n\t\r]\n\t\r"), // Identifier quoting syntax tests. 
{"a.`b`"}, {"a.`b-c`"}, {"a.`b c`"}, {"a.`b/c`"}, {"a.`b.c`"}, {"a.`in`"}, {"A{`b`: 1}"}, {"A{`b-c`: 1}"}, {"A{`b c`: 1}"}, {"A{`b/c`: 1}"}, {"A{`b.c`: 1}"}, {"A{`in`: 1}"}, {"has(a.`b/c`)"}, // Unsupported quoted identifiers. TestInfo::ErrorCase("a.`b\tc`"), TestInfo::ErrorCase("a.`@foo`"), TestInfo::ErrorCase("a.`$foo`"), TestInfo::ErrorCase("`a.b`"), TestInfo::ErrorCase("`a.b`()"), TestInfo::ErrorCase("foo.`a.b`()"), // Macro calls tests {"x.filter(y, y.filter(z, z > 0))"}, {"has(a.b).filter(c, c)"}, {"x.filter(y, y.exists(z, has(z.a)) && y.exists(z, has(z.b)))"}, {"has(a.b).asList().exists(c, c)"}, TestInfo::ErrorCase("b'\\UFFFFFFFF'"), {"a.?b[?0] && a[?c]"}, {"{?'key': value}"}, {"[?a, ?b]"}, {"[?a[?b]]"}, {"Msg{?field: value}"}, {"m.optMap(v, f)"}, {"m.optFlatMap(v, f)"}}; return *kInstance; } class BenchmarkCaseTest : public testing::TestWithParam {}; TEST_P(BenchmarkCaseTest, ExpectedResult) { std::vector macros = Macro::AllMacros(); macros.push_back(cel::OptMapMacro()); macros.push_back(cel::OptFlatMapMacro()); const TestInfo& test_info = GetParam(); ParserOptions options; options.enable_optional_syntax = true; options.enable_quoted_identifiers = true; auto result = EnrichedParse(test_info.expr, macros, "", options); switch (test_info.result) { case ParseResult::kSuccess: ASSERT_THAT(result, IsOk()); break; case ParseResult::kError: ASSERT_THAT(result, Not(IsOk())); break; } } INSTANTIATE_TEST_SUITE_P(CelParserTest, BenchmarkCaseTest, testing::ValuesIn(GetTestCases())); // This is not a proper microbenchmark, but is used to check for major // regressions in the ANTLR generated code or concurrency issues. Each benchmark // iteration parses all of the basic test cases from the unit-tests. 
void BM_Parse(benchmark::State& state) { std::vector macros = Macro::AllMacros(); macros.push_back(cel::OptMapMacro()); macros.push_back(cel::OptFlatMapMacro()); ParserOptions options; options.enable_optional_syntax = true; options.enable_quoted_identifiers = true; for (auto s : state) { for (const auto& test_case : GetTestCases()) { auto result = ParseWithMacros(test_case.expr, macros, "", options); ABSL_DCHECK_EQ(result.ok(), test_case.result == ParseResult::kSuccess); benchmark::DoNotOptimize(result); } } } BENCHMARK(BM_Parse)->ThreadRange(1, std::thread::hardware_concurrency()); } // namespace } // namespace google::api::expr::parser ================================================ FILE: parser/parser_interface.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_PARSER_PARSER_INTERFACE_H_ #define THIRD_PARTY_CEL_CPP_PARSER_PARSER_INTERFACE_H_ #include #include #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "common/ast.h" #include "common/source.h" #include "parser/macro.h" #include "parser/options.h" namespace cel { class Parser; class ParserBuilder; // Callable for configuring a ParserBuilder. using ParserBuilderConfigurer = absl::AnyInvocable; struct ParserLibrary { // Optional identifier to avoid collisions re-adding the same macros. 
If // empty, it is not considered for collision detection. std::string id; ParserBuilderConfigurer configure; }; // Declares a subset of a parser library. struct ParserLibrarySubset { // The id of the library to subset. Only one subset can be applied per // library id. // // Must be non-empty. std::string library_id; using MacroPredicate = absl::AnyInvocable; MacroPredicate should_include_macro; }; // Interface for building a CEL parser, see comments on `Parser` below. class ParserBuilder { public: virtual ~ParserBuilder() = default; // Returns the (mutable) current parser options. virtual ParserOptions& GetOptions() = 0; // Adds a macro to the parser. // Standard macros should be automatically added based on parser options. virtual absl::Status AddMacro(const cel::Macro& macro) = 0; virtual absl::Status AddLibrary(ParserLibrary library) = 0; virtual absl::Status AddLibrarySubset(ParserLibrarySubset subset) = 0; // Builds a new parser instance, may error if incompatible macros are added. virtual absl::StatusOr> Build() = 0; }; // Interface for stateful CEL parser objects for use with a `Compiler` // (bundled parse and type check). This is not needed for most users: // prefer using the free functions in `parser.h` for more flexibility. class Parser { public: virtual ~Parser() = default; // Parses the given source into a CEL AST. virtual absl::StatusOr> Parse( const cel::Source& source) const = 0; // Returns a builder initialized with the configuration of this parser. virtual std::unique_ptr ToBuilder() const = 0; }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_PARSER_PARSER_INTERFACE_H_ ================================================ FILE: parser/parser_subset_factory.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "parser/parser_subset_factory.h" #include #include #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "parser/macro.h" #include "parser/parser_interface.h" namespace cel { cel::ParserLibrarySubset::MacroPredicate IncludeMacrosByNamePredicate( absl::flat_hash_set macro_names) { return [macro_names_set = std::move(macro_names)](const Macro& macro) { return macro_names_set.contains(macro.function()); }; } cel::ParserLibrarySubset::MacroPredicate IncludeMacrosByNamePredicate( absl::Span macro_names) { return IncludeMacrosByNamePredicate( absl::flat_hash_set(macro_names.begin(), macro_names.end())); } cel::ParserLibrarySubset::MacroPredicate ExcludeMacrosByNamePredicate( absl::flat_hash_set macro_names) { return [macro_names_set = std::move(macro_names)](const Macro& macro) { return !macro_names_set.contains(macro.function()); }; } cel::ParserLibrarySubset::MacroPredicate ExcludeMacrosByNamePredicate( absl::Span macro_names) { return ExcludeMacrosByNamePredicate( absl::flat_hash_set(macro_names.begin(), macro_names.end())); } } // namespace cel ================================================ FILE: parser/parser_subset_factory.h ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_PARSER_PARSER_SUBSET_FACTORY_H_ #define THIRD_PARTY_CEL_CPP_PARSER_PARSER_SUBSET_FACTORY_H_ #include #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "parser/parser_interface.h" namespace cel { // Predicate that only includes the given macro by name. cel::ParserLibrarySubset::MacroPredicate IncludeMacrosByNamePredicate( absl::flat_hash_set macro_names); cel::ParserLibrarySubset::MacroPredicate IncludeMacrosByNamePredicate( absl::Span macro_names); // Predicate that excludes the given macros by name. cel::ParserLibrarySubset::MacroPredicate ExcludeMacrosByNamePredicate( absl::flat_hash_set macro_names); cel::ParserLibrarySubset::MacroPredicate ExcludeMacrosByNamePredicate( absl::Span macro_names); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_PARSER_PARSER_SUBSET_FACTORY_H_ ================================================ FILE: parser/parser_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "parser/parser.h" #include #include #include #include #include #include "cel/expr/syntax.pb.h" #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "common/constant.h" #include "common/expr.h" #include "common/source.h" #include "internal/testing.h" #include "parser/macro.h" #include "parser/options.h" #include "parser/parser_interface.h" #include "parser/source_factory.h" #include "testutil/expr_printer.h" namespace google::api::expr::parser { namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::ConstantKindCase; using ::cel::ExprKindCase; using ::cel::test::ExprPrinter; using ::cel::expr::Expr; using ::testing::HasSubstr; using ::testing::Not; struct TestInfo { TestInfo(const std::string& I, const std::string& P, const std::string& E = "", const std::string& L = "", const std::string& R = "", const std::string& M = "") : I(I), P(P), E(E), L(L), R(R), M(M) {} // I contains the input expression to be parsed. std::string I; // P contains the type/id adorned debug output of the expression tree. std::string P; // E contains the expected error output for a failed parse, or "" if the parse // is expected to be successful. std::string E; // L contains the expected source adorned debug output of the expression tree. std::string L; // R contains the expected enriched source info output of the expression tree. std::string R; // M contains the expected macro call output of hte expression tree. 
std::string M; }; std::vector test_cases = { // Simple test cases we started with {"x * 2", "_*_(\n" " x^#1:Expr.Ident#,\n" " 2^#3:int64#\n" ")^#2:Expr.Call#"}, {"x * 2u", "_*_(\n" " x^#1:Expr.Ident#,\n" " 2u^#3:uint64#\n" ")^#2:Expr.Call#"}, {"x * 2.0", "_*_(\n" " x^#1:Expr.Ident#,\n" " 2.0^#3:double#\n" ")^#2:Expr.Call#"}, {"\"\\u2764\"", "\"\u2764\"^#1:string#"}, {"\"\u2764\"", "\"\u2764\"^#1:string#"}, {"! false", "!_(\n" " false^#2:bool#\n" ")^#1:Expr.Call#"}, {"-a", "-_(\n" " a^#2:Expr.Ident#\n" ")^#1:Expr.Call#"}, {"a.b(5)", "a^#1:Expr.Ident#.b(\n" " 5^#3:int64#\n" ")^#2:Expr.Call#"}, {"a[3]", "_[_](\n" " a^#1:Expr.Ident#,\n" " 3^#3:int64#\n" ")^#2:Expr.Call#"}, {"SomeMessage{foo: 5, bar: \"xyz\"}", "SomeMessage{\n" " foo:5^#3:int64#^#2:Expr.CreateStruct.Entry#,\n" " bar:\"xyz\"^#5:string#^#4:Expr.CreateStruct.Entry#\n" "}^#1:Expr.CreateStruct#"}, {"[3, 4, 5]", "[\n" " 3^#2:int64#,\n" " 4^#3:int64#,\n" " 5^#4:int64#\n" "]^#1:Expr.CreateList#"}, {"{foo: 5, bar: \"xyz\"}", "{\n" " foo^#3:Expr.Ident#:5^#4:int64#^#2:Expr.CreateStruct.Entry#,\n" " bar^#6:Expr.Ident#:\"xyz\"^#7:string#^#5:Expr.CreateStruct.Entry#\n" "}^#1:Expr.CreateStruct#"}, {"a > 5 && a < 10", "_&&_(\n" " _>_(\n" " a^#1:Expr.Ident#,\n" " 5^#3:int64#\n" " )^#2:Expr.Call#,\n" " _<_(\n" " a^#4:Expr.Ident#,\n" " 10^#6:int64#\n" " )^#5:Expr.Call#\n" ")^#7:Expr.Call#"}, {"a < 5 || a > 10", "_||_(\n" " _<_(\n" " a^#1:Expr.Ident#,\n" " 5^#3:int64#\n" " )^#2:Expr.Call#,\n" " _>_(\n" " a^#4:Expr.Ident#,\n" " 10^#6:int64#\n" " )^#5:Expr.Call#\n" ")^#7:Expr.Call#"}, {"{", "", "ERROR: :1:2: Syntax error: mismatched input '' expecting " "{'[', " "'{', '}', '(', '.', ',', '-', '!', '\\u003F', 'true', 'false', 'null', " "NUM_FLOAT, " "NUM_INT, " "NUM_UINT, STRING, BYTES, IDENTIFIER}\n | {\n" " | .^"}, // test cases from Go {"\"A\"", "\"A\"^#1:string#"}, {"true", "true^#1:bool#"}, {"false", "false^#1:bool#"}, {"0", "0^#1:int64#"}, {"42", "42^#1:int64#"}, {"0u", "0u^#1:uint64#"}, {"23u", "23u^#1:uint64#"}, 
{"24u", "24u^#1:uint64#"}, {"0xAu", "10u^#1:uint64#"}, {"-0xA", "-10^#1:int64#"}, {"0xA", "10^#1:int64#"}, {"-1", "-1^#1:int64#"}, {"4--4", "_-_(\n" " 4^#1:int64#,\n" " -4^#3:int64#\n" ")^#2:Expr.Call#"}, {"4--4.1", "_-_(\n" " 4^#1:int64#,\n" " -4.1^#3:double#\n" ")^#2:Expr.Call#"}, {"b\"abc\"", "b\"abc\"^#1:bytes#"}, {"23.39", "23.39^#1:double#"}, {"!a", "!_(\n" " a^#2:Expr.Ident#\n" ")^#1:Expr.Call#"}, {"null", "null^#1:NullValue#"}, {"a", "a^#1:Expr.Ident#"}, {"a?b:c", "_?_:_(\n" " a^#1:Expr.Ident#,\n" " b^#3:Expr.Ident#,\n" " c^#4:Expr.Ident#\n" ")^#2:Expr.Call#"}, {"a || b", "_||_(\n" " a^#1:Expr.Ident#,\n" " b^#2:Expr.Ident#\n" ")^#3:Expr.Call#"}, {"a || b || c || d || e || f ", "_||_(\n" " _||_(\n" " _||_(\n" " a^#1:Expr.Ident#,\n" " b^#2:Expr.Ident#\n" " )^#3:Expr.Call#,\n" " c^#4:Expr.Ident#\n" " )^#5:Expr.Call#,\n" " _||_(\n" " _||_(\n" " d^#6:Expr.Ident#,\n" " e^#8:Expr.Ident#\n" " )^#9:Expr.Call#,\n" " f^#10:Expr.Ident#\n" " )^#11:Expr.Call#\n" ")^#7:Expr.Call#"}, {"a && b", "_&&_(\n" " a^#1:Expr.Ident#,\n" " b^#2:Expr.Ident#\n" ")^#3:Expr.Call#"}, {"a && b && c && d && e && f && g", "_&&_(\n" " _&&_(\n" " _&&_(\n" " a^#1:Expr.Ident#,\n" " b^#2:Expr.Ident#\n" " )^#3:Expr.Call#,\n" " _&&_(\n" " c^#4:Expr.Ident#,\n" " d^#6:Expr.Ident#\n" " )^#7:Expr.Call#\n" " )^#5:Expr.Call#,\n" " _&&_(\n" " _&&_(\n" " e^#8:Expr.Ident#,\n" " f^#10:Expr.Ident#\n" " )^#11:Expr.Call#,\n" " g^#12:Expr.Ident#\n" " )^#13:Expr.Call#\n" ")^#9:Expr.Call#"}, {"a && b && c && d || e && f && g && h", "_||_(\n" " _&&_(\n" " _&&_(\n" " a^#1:Expr.Ident#,\n" " b^#2:Expr.Ident#\n" " )^#3:Expr.Call#,\n" " _&&_(\n" " c^#4:Expr.Ident#,\n" " d^#6:Expr.Ident#\n" " )^#7:Expr.Call#\n" " )^#5:Expr.Call#,\n" " _&&_(\n" " _&&_(\n" " e^#8:Expr.Ident#,\n" " f^#9:Expr.Ident#\n" " )^#10:Expr.Call#,\n" " _&&_(\n" " g^#11:Expr.Ident#,\n" " h^#13:Expr.Ident#\n" " )^#14:Expr.Call#\n" " )^#12:Expr.Call#\n" ")^#15:Expr.Call#"}, {"a + b", "_+_(\n" " a^#1:Expr.Ident#,\n" " b^#3:Expr.Ident#\n" 
")^#2:Expr.Call#"}, {"a - b", "_-_(\n" " a^#1:Expr.Ident#,\n" " b^#3:Expr.Ident#\n" ")^#2:Expr.Call#"}, {"a * b", "_*_(\n" " a^#1:Expr.Ident#,\n" " b^#3:Expr.Ident#\n" ")^#2:Expr.Call#"}, {"a / b", "_/_(\n" " a^#1:Expr.Ident#,\n" " b^#3:Expr.Ident#\n" ")^#2:Expr.Call#"}, { "a % b", "_%_(\n" " a^#1:Expr.Ident#,\n" " b^#3:Expr.Ident#\n" ")^#2:Expr.Call#", }, {"a in b", "@in(\n" " a^#1:Expr.Ident#,\n" " b^#3:Expr.Ident#\n" ")^#2:Expr.Call#"}, {"a == b", "_==_(\n" " a^#1:Expr.Ident#,\n" " b^#3:Expr.Ident#\n" ")^#2:Expr.Call#"}, {"a != b", "_!=_(\n" " a^#1:Expr.Ident#,\n" " b^#3:Expr.Ident#\n" ")^#2:Expr.Call#"}, {"a > b", "_>_(\n" " a^#1:Expr.Ident#,\n" " b^#3:Expr.Ident#\n" ")^#2:Expr.Call#"}, {"a >= b", "_>=_(\n" " a^#1:Expr.Ident#,\n" " b^#3:Expr.Ident#\n" ")^#2:Expr.Call#"}, {"a < b", "_<_(\n" " a^#1:Expr.Ident#,\n" " b^#3:Expr.Ident#\n" ")^#2:Expr.Call#"}, {"a <= b", "_<=_(\n" " a^#1:Expr.Ident#,\n" " b^#3:Expr.Ident#\n" ")^#2:Expr.Call#"}, {"a.b", "a^#1:Expr.Ident#.b^#2:Expr.Select#"}, {"a.b.c", "a^#1:Expr.Ident#.b^#2:Expr.Select#.c^#3:Expr.Select#"}, {"a[b]", "_[_](\n" " a^#1:Expr.Ident#,\n" " b^#3:Expr.Ident#\n" ")^#2:Expr.Call#"}, {"foo{ }", "foo{}^#1:Expr.CreateStruct#"}, {"foo{ a:b }", "foo{\n" " a:b^#3:Expr.Ident#^#2:Expr.CreateStruct.Entry#\n" "}^#1:Expr.CreateStruct#"}, {"foo{ a:b, c:d }", "foo{\n" " a:b^#3:Expr.Ident#^#2:Expr.CreateStruct.Entry#,\n" " c:d^#5:Expr.Ident#^#4:Expr.CreateStruct.Entry#\n" "}^#1:Expr.CreateStruct#"}, {"{}", "{}^#1:Expr.CreateStruct#"}, {"{a:b, c:d}", "{\n" " a^#3:Expr.Ident#:b^#4:Expr.Ident#^#2:Expr.CreateStruct.Entry#,\n" " c^#6:Expr.Ident#:d^#7:Expr.Ident#^#5:Expr.CreateStruct.Entry#\n" "}^#1:Expr.CreateStruct#"}, {"[]", "[]^#1:Expr.CreateList#"}, {"[a]", "[\n" " a^#2:Expr.Ident#\n" "]^#1:Expr.CreateList#"}, {"[a, b, c]", "[\n" " a^#2:Expr.Ident#,\n" " b^#3:Expr.Ident#,\n" " c^#4:Expr.Ident#\n" "]^#1:Expr.CreateList#"}, {"(a)", "a^#1:Expr.Ident#"}, {"((a))", "a^#1:Expr.Ident#"}, {"a()", "a()^#1:Expr.Call#"}, {"a(b)", "a(\n" 
" b^#2:Expr.Ident#\n" ")^#1:Expr.Call#"}, {"a(b, c)", "a(\n" " b^#2:Expr.Ident#,\n" " c^#3:Expr.Ident#\n" ")^#1:Expr.Call#"}, {"a.b()", "a^#1:Expr.Ident#.b()^#2:Expr.Call#"}, { "a.b(c)", "a^#1:Expr.Ident#.b(\n" " c^#3:Expr.Ident#\n" ")^#2:Expr.Call#", /* E */ "", "a^#1[1,0]#.b(\n" " c^#3[1,4]#\n" ")^#2[1,3]#", "[1,0,0]^#[2,3,3]^#[3,4,4]", }, { "aaa.bbb(ccc)", "aaa^#1:Expr.Ident#.bbb(\n" " ccc^#3:Expr.Ident#\n" ")^#2:Expr.Call#", /* E */ "", "aaa^#1[1,0]#.bbb(\n" " ccc^#3[1,8]#\n" ")^#2[1,7]#", "[1,0,2]^#[2,7,7]^#[3,8,10]", }, // Parse error tests {"*@a | b", "", "ERROR: :1:1: Syntax error: extraneous input '*' expecting {'[', " "'{', " "'(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, NUM_INT, " "NUM_UINT, STRING, BYTES, IDENTIFIER}\n" " | *@a | b\n" " | ^\n" "ERROR: :1:2: Syntax error: token recognition error at: '@'\n" " | *@a | b\n" " | .^\n" "ERROR: :1:5: Syntax error: token recognition error at: '| '\n" " | *@a | b\n" " | ....^\n" "ERROR: :1:7: Syntax error: extraneous input 'b' expecting \n" " | *@a | b\n" " | ......^"}, {"a | b", "", "ERROR: :1:3: Syntax error: token recognition error at: '| '\n" " | a | b\n" " | ..^\n" "ERROR: :1:5: Syntax error: extraneous input 'b' expecting \n" " | a | b\n" " | ....^"}, {"?", "", "ERROR: :1:1: Syntax error: mismatched input '?' 
expecting " "{'[', '{', '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, " "NUM_INT, NUM_UINT, STRING, BYTES, IDENTIFIER}\n | ?\n | ^\n" "ERROR: :1:2: Syntax error: mismatched input '' expecting " "{'[', '{', '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, " "NUM_INT, NUM_UINT, STRING, BYTES, IDENTIFIER}\n | ?\n | .^\n" "ERROR: :4294967295:0: <> parsetree"}, {"t{>C}", "", "ERROR: :1:3: Syntax error: extraneous input '>' expecting {'}', " "',', '\\u003F', IDENTIFIER, ESC_IDENTIFIER}\n | t{>C}\n | ..^\nERROR: " ":1:5: " "Syntax error: " "mismatched input '}' expecting ':'\n | t{>C}\n | ....^"}, // Macro tests {"has(m.f)", "m^#2:Expr.Ident#.f~test-only~^#4:Expr.Select#", "", "m^#2[1,4]#.f~test-only~^#4[1,3]#", "[2,4,4]^#[3,5,5]^#[4,3,3]", "has(\n" " m^#2:Expr.Ident#.f^#3:Expr.Select#\n" ")^#4:has"}, {"m.exists_one(v, f)", "__comprehension__(\n" " // Variable\n" " v,\n" " // Target\n" " m^#1:Expr.Ident#,\n" " // Accumulator\n" " @result,\n" " // Init\n" " 0^#5:int64#,\n" " // LoopCondition\n" " true^#6:bool#,\n" " // LoopStep\n" " _?_:_(\n" " f^#4:Expr.Ident#,\n" " _+_(\n" " @result^#7:Expr.Ident#,\n" " 1^#8:int64#\n" " )^#9:Expr.Call#,\n" " @result^#10:Expr.Ident#\n" " )^#11:Expr.Call#,\n" " // Result\n" " _==_(\n" " @result^#12:Expr.Ident#,\n" " 1^#13:int64#\n" " )^#14:Expr.Call#)^#15:Expr.Comprehension#", "", "", "", "m^#1:Expr.Ident#.exists_one(\n" " v^#3:Expr.Ident#,\n" " f^#4:Expr.Ident#\n" ")^#15:exists_one"}, {"m.map(v, f)", "__comprehension__(\n" " // Variable\n" " v,\n" " // Target\n" " m^#1:Expr.Ident#,\n" " // Accumulator\n" " @result,\n" " // Init\n" " []^#5:Expr.CreateList#,\n" " // LoopCondition\n" " true^#6:bool#,\n" " // LoopStep\n" " _+_(\n" " @result^#7:Expr.Ident#,\n" " [\n" " f^#4:Expr.Ident#\n" " ]^#8:Expr.CreateList#\n" " )^#9:Expr.Call#,\n" " // Result\n" " @result^#10:Expr.Ident#)^#11:Expr.Comprehension#", "", "", "", "m^#1:Expr.Ident#.map(\n" " v^#3:Expr.Ident#,\n" " f^#4:Expr.Ident#\n" ")^#11:map"}, {"m.map(v, p, f)", 
"__comprehension__(\n" " // Variable\n" " v,\n" " // Target\n" " m^#1:Expr.Ident#,\n" " // Accumulator\n" " @result,\n" " // Init\n" " []^#6:Expr.CreateList#,\n" " // LoopCondition\n" " true^#7:bool#,\n" " // LoopStep\n" " _?_:_(\n" " p^#4:Expr.Ident#,\n" " _+_(\n" " @result^#8:Expr.Ident#,\n" " [\n" " f^#5:Expr.Ident#\n" " ]^#9:Expr.CreateList#\n" " )^#10:Expr.Call#,\n" " @result^#11:Expr.Ident#\n" " )^#12:Expr.Call#,\n" " // Result\n" " @result^#13:Expr.Ident#)^#14:Expr.Comprehension#", "", "", "", "m^#1:Expr.Ident#.map(\n" " v^#3:Expr.Ident#,\n" " p^#4:Expr.Ident#,\n" " f^#5:Expr.Ident#\n" ")^#14:map"}, {"m.filter(v, p)", "__comprehension__(\n" " // Variable\n" " v,\n" " // Target\n" " m^#1:Expr.Ident#,\n" " // Accumulator\n" " @result,\n" " // Init\n" " []^#5:Expr.CreateList#,\n" " // LoopCondition\n" " true^#6:bool#,\n" " // LoopStep\n" " _?_:_(\n" " p^#4:Expr.Ident#,\n" " _+_(\n" " @result^#7:Expr.Ident#,\n" " [\n" " v^#3:Expr.Ident#\n" " ]^#8:Expr.CreateList#\n" " )^#9:Expr.Call#,\n" " @result^#10:Expr.Ident#\n" " )^#11:Expr.Call#,\n" " // Result\n" " @result^#12:Expr.Ident#)^#13:Expr.Comprehension#", "", "", "", "m^#1:Expr.Ident#.filter(\n" " v^#3:Expr.Ident#,\n" " p^#4:Expr.Ident#\n" ")^#13:filter"}, // Tests from Java parser {"[] + [1,2,3,] + [4]", "_+_(\n" " _+_(\n" " []^#1:Expr.CreateList#,\n" " [\n" " 1^#4:int64#,\n" " 2^#5:int64#,\n" " 3^#6:int64#\n" " ]^#3:Expr.CreateList#\n" " )^#2:Expr.Call#,\n" " [\n" " 4^#9:int64#\n" " ]^#8:Expr.CreateList#\n" ")^#7:Expr.Call#"}, {"{1:2u, 2:3u}", "{\n" " 1^#3:int64#:2u^#4:uint64#^#2:Expr.CreateStruct.Entry#,\n" " 2^#6:int64#:3u^#7:uint64#^#5:Expr.CreateStruct.Entry#\n" "}^#1:Expr.CreateStruct#"}, {"TestAllTypes{single_int32: 1, single_int64: 2}", "TestAllTypes{\n" " single_int32:1^#3:int64#^#2:Expr.CreateStruct.Entry#,\n" " single_int64:2^#5:int64#^#4:Expr.CreateStruct.Entry#\n" "}^#1:Expr.CreateStruct#"}, {"TestAllTypes(){single_int32: 1, single_int64: 2}", "", "ERROR: :1:15: Syntax error: mismatched input '{' 
expecting \n" " | TestAllTypes(){single_int32: 1, single_int64: 2}\n" " | ..............^"}, {"size(x) == x.size()", "_==_(\n" " size(\n" " x^#2:Expr.Ident#\n" " )^#1:Expr.Call#,\n" " x^#4:Expr.Ident#.size()^#5:Expr.Call#\n" ")^#3:Expr.Call#"}, {"1 + $", "", "ERROR: :1:5: Syntax error: token recognition error at: '$'\n" " | 1 + $\n" " | ....^\n" "ERROR: :1:6: Syntax error: mismatched input '' expecting " "{'[', " "'{', '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, NUM_INT, " "NUM_UINT, STRING, BYTES, IDENTIFIER}\n" " | 1 + $\n" " | .....^"}, {"1 + 2\n" "3 +", "", "ERROR: :2:1: Syntax error: mismatched input '3' expecting \n" " | 3 +\n" " | ^"}, {"\"\\\"\"", "\"\\\"\"^#1:string#"}, {"[1,3,4][0]", "_[_](\n" " [\n" " 1^#2:int64#,\n" " 3^#3:int64#,\n" " 4^#4:int64#\n" " ]^#1:Expr.CreateList#,\n" " 0^#6:int64#\n" ")^#5:Expr.Call#"}, {"1.all(2, 3)", "", "ERROR: :1:7: all() variable name must be a simple identifier\n" " | 1.all(2, 3)\n" " | ......^"}, {"x[\"a\"].single_int32 == 23", "_==_(\n" " _[_](\n" " x^#1:Expr.Ident#,\n" " \"a\"^#3:string#\n" " )^#2:Expr.Call#.single_int32^#4:Expr.Select#,\n" " 23^#6:int64#\n" ")^#5:Expr.Call#"}, {"x.single_nested_message != null", "_!=_(\n" " x^#1:Expr.Ident#.single_nested_message^#2:Expr.Select#,\n" " null^#4:NullValue#\n" ")^#3:Expr.Call#"}, {"false && !true || false ? 
2 : 3", "_?_:_(\n" " _||_(\n" " _&&_(\n" " false^#1:bool#,\n" " !_(\n" " true^#3:bool#\n" " )^#2:Expr.Call#\n" " )^#4:Expr.Call#,\n" " false^#5:bool#\n" " )^#6:Expr.Call#,\n" " 2^#8:int64#,\n" " 3^#9:int64#\n" ")^#7:Expr.Call#"}, {"b\"abc\" + B\"def\"", "_+_(\n" " b\"abc\"^#1:bytes#,\n" " b\"def\"^#3:bytes#\n" ")^#2:Expr.Call#"}, {"1 + 2 * 3 - 1 / 2 == 6 % 1", "_==_(\n" " _-_(\n" " _+_(\n" " 1^#1:int64#,\n" " _*_(\n" " 2^#3:int64#,\n" " 3^#5:int64#\n" " )^#4:Expr.Call#\n" " )^#2:Expr.Call#,\n" " _/_(\n" " 1^#7:int64#,\n" " 2^#9:int64#\n" " )^#8:Expr.Call#\n" " )^#6:Expr.Call#,\n" " _%_(\n" " 6^#11:int64#,\n" " 1^#13:int64#\n" " )^#12:Expr.Call#\n" ")^#10:Expr.Call#"}, {"---a", "-_(\n" " a^#2:Expr.Ident#\n" ")^#1:Expr.Call#"}, {"1 + +", "", "ERROR: :1:5: Syntax error: mismatched input '+' expecting {'[', " "'{'," " '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, NUM_INT, " "NUM_UINT," " STRING, BYTES, IDENTIFIER}\n" " | 1 + +\n" " | ....^\n" "ERROR: :1:6: Syntax error: mismatched input '' expecting " "{'[', " "'{', '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, NUM_INT, " "NUM_UINT, STRING, BYTES, IDENTIFIER}\n" " | 1 + +\n" " | .....^"}, {"\"abc\" + \"def\"", "_+_(\n" " \"abc\"^#1:string#,\n" " \"def\"^#3:string#\n" ")^#2:Expr.Call#"}, {"{\"a\": 1}.\"a\"", "", "ERROR: :1:10: Syntax error: no viable alternative at input " "'.\"a\"'\n" " | {\"a\": 1}.\"a\"\n" " | .........^"}, {"\"\\xC3\\XBF\"", "\"ÿ\"^#1:string#"}, {"\"\\303\\277\"", "\"ÿ\"^#1:string#"}, {"\"hi\\u263A \\u263Athere\"", "\"hi☺ ☺there\"^#1:string#"}, {"\"\\U000003A8\\?\"", "\"Ψ?\"^#1:string#"}, {"\"\\a\\b\\f\\n\\r\\t\\v'\\\"\\\\\\? Legal escapes\"", "\"\\x07\\x08\\x0c\\n\\r\\t\\x0b'\\\"\\\\? 
Legal escapes\"^#1:string#"}, {"\"\\xFh\"", "", "ERROR: :1:1: Syntax error: token recognition error at: '\"\\xFh'\n" " | \"\\xFh\"\n" " | ^\n" "ERROR: :1:6: Syntax error: token recognition error at: '\"'\n" " | \"\\xFh\"\n" " | .....^\n" "ERROR: :1:7: Syntax error: mismatched input '' expecting " "{'[', " "'{', '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, NUM_INT, " "NUM_UINT, STRING, BYTES, IDENTIFIER}\n" " | \"\\xFh\"\n" " | ......^"}, {"\"\\a\\b\\f\\n\\r\\t\\v\\'\\\"\\\\\\? Illegal escape \\>\"", "", "ERROR: :1:1: Syntax error: token recognition error at: " "'\"\\a\\b\\f\\n\\r\\t\\v\\'\\\"\\\\\\? Illegal escape \\>'\n" " | \"\\a\\b\\f\\n\\r\\t\\v\\'\\\"\\\\\\? Illegal escape \\>\"\n" " | ^\n" "ERROR: :1:42: Syntax error: token recognition error at: '\"'\n" " | \"\\a\\b\\f\\n\\r\\t\\v\\'\\\"\\\\\\? Illegal escape \\>\"\n" " | .........................................^\n" "ERROR: :1:43: Syntax error: mismatched input '' expecting " "{'['," " '{', '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, NUM_INT, " "NUM_UINT, STRING, BYTES, IDENTIFIER}\n" " | \"\\a\\b\\f\\n\\r\\t\\v\\'\\\"\\\\\\? 
Illegal escape \\>\"\n" " | ..........................................^"}, {"'😁' in ['😁', '😑', '😦']", "@in(\n" " \"😁\"^#1:string#,\n" " [\n" " \"😁\"^#4:string#,\n" " \"😑\"^#5:string#,\n" " \"😦\"^#6:string#\n" " ]^#3:Expr.CreateList#\n" ")^#2:Expr.Call#"}, {"'\u00ff' in ['\u00ff', '\u00ff', '\u00ff']", "@in(\n" " \"\u00ff\"^#1:string#,\n" " [\n" " \"\u00ff\"^#4:string#,\n" " \"\u00ff\"^#5:string#,\n" " \"\u00ff\"^#6:string#\n" " ]^#3:Expr.CreateList#\n" ")^#2:Expr.Call#"}, {"'\u00ff' in ['\uffff', '\U00100000', '\U0010ffff']", "@in(\n" " \"\u00ff\"^#1:string#,\n" " [\n" " \"\uffff\"^#4:string#,\n" " \"\U00100000\"^#5:string#,\n" " \"\U0010ffff\"^#6:string#\n" " ]^#3:Expr.CreateList#\n" ")^#2:Expr.Call#"}, {"'\u00ff' in ['\U00100000', '\uffff', '\U0010ffff']", "@in(\n" " \"\u00ff\"^#1:string#,\n" " [\n" " \"\U00100000\"^#4:string#,\n" " \"\uffff\"^#5:string#,\n" " \"\U0010ffff\"^#6:string#\n" " ]^#3:Expr.CreateList#\n" ")^#2:Expr.Call#"}, {"'😁' in ['😁', '😑', '😦']\n" " && in.😁", "", "ERROR: :2:7: Syntax error: extraneous input 'in' expecting {'[', " "'{', '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, NUM_INT, " "NUM_UINT, STRING, BYTES, IDENTIFIER}\n" " | && in.😁\n" " | ......^\n" "ERROR: :2:10: Syntax error: token recognition error at: '😁'\n" " | && in.😁\n" " | .........^\n" "ERROR: :2:11: Syntax error: no viable alternative at input '.'\n" " | && in.😁\n" " | ..........^"}, {"as", "", "ERROR: :1:1: reserved identifier: as\n" " | as\n" " | ^"}, {"break", "", "ERROR: :1:1: reserved identifier: break\n" " | break\n" " | ^"}, {"const", "", "ERROR: :1:1: reserved identifier: const\n" " | const\n" " | ^"}, {"continue", "", "ERROR: :1:1: reserved identifier: continue\n" " | continue\n" " | ^"}, {"else", "", "ERROR: :1:1: reserved identifier: else\n" " | else\n" " | ^"}, {"for", "", "ERROR: :1:1: reserved identifier: for\n" " | for\n" " | ^"}, {"function", "", "ERROR: :1:1: reserved identifier: function\n" " | function\n" " | ^"}, {"if", "", "ERROR: :1:1: reserved 
identifier: if\n" " | if\n" " | ^"}, {"import", "", "ERROR: :1:1: reserved identifier: import\n" " | import\n" " | ^"}, {"in", "", "ERROR: :1:1: Syntax error: mismatched input 'in' expecting {'[', " "'{', '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, NUM_INT, " "NUM_UINT, STRING, BYTES, IDENTIFIER}\n" " | in\n" " | ^\n" "ERROR: :1:3: Syntax error: mismatched input '' expecting " "{'[', " "'{', '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, NUM_INT, " "NUM_UINT, STRING, BYTES, IDENTIFIER}\n" " | in\n" " | ..^"}, {"let", "", "ERROR: :1:1: reserved identifier: let\n" " | let\n" " | ^"}, {"loop", "", "ERROR: :1:1: reserved identifier: loop\n" " | loop\n" " | ^"}, {"package", "", "ERROR: :1:1: reserved identifier: package\n" " | package\n" " | ^"}, {"namespace", "", "ERROR: :1:1: reserved identifier: namespace\n" " | namespace\n" " | ^"}, {"return", "", "ERROR: :1:1: reserved identifier: return\n" " | return\n" " | ^"}, {"var", "", "ERROR: :1:1: reserved identifier: var\n" " | var\n" " | ^"}, {"void", "", "ERROR: :1:1: reserved identifier: void\n" " | void\n" " | ^"}, {"while", "", "ERROR: :1:1: reserved identifier: while\n" " | while\n" " | ^"}, {"[1, 2, 3].map(var, var * var)", "", "ERROR: :1:15: reserved identifier: var\n" " | [1, 2, 3].map(var, var * var)\n" " | ..............^\n" "ERROR: :1:15: map() variable name must be a simple identifier\n" " | [1, 2, 3].map(var, var * var)\n" " | ..............^\n" "ERROR: :1:20: reserved identifier: var\n" " | [1, 2, 3].map(var, var * var)\n" " | ...................^\n" "ERROR: :1:26: reserved identifier: var\n" " | [1, 2, 3].map(var, var * var)\n" " | .........................^"}, {"[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[['too many']]]]]]]]]]]]]]]]]]]]]]]]]]]]" 
"]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]" "]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]" "]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]" "]]]]]]", "", "Expression recursion limit exceeded. limit: 32", "", "", ""}, { // Note, the ANTLR parse stack may recurse much more deeply and permit // more detailed expressions than the visitor can recurse over in // practice. "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[['just fine'],[1],[2],[3],[4],[5]]]]]]]" "]]]]]]]]]]]]]]]]]]]]]]]]", "", // parse output not validated as it is too large. "", "", "", "", }, { "[\n\t\r[\n\t\r[\n\t\r]\n\t\r]\n\t\r", "", // parse output not validated as it is too large. "ERROR: :6:3: Syntax error: mismatched input '' expecting " "{']', ','}\n" " | \r\n" " | ..^", }, // Identifier quoting syntax tests. {"a.`b`", "a^#1:Expr.Ident#.b^#2:Expr.Select#"}, {"a.`b-c`", "a^#1:Expr.Ident#.b-c^#2:Expr.Select#"}, {"a.`b c`", "a^#1:Expr.Ident#.b c^#2:Expr.Select#"}, {"a.`b/c`", "a^#1:Expr.Ident#.b/c^#2:Expr.Select#"}, {"a.`b.c`", "a^#1:Expr.Ident#.b.c^#2:Expr.Select#"}, {"a.`in`", "a^#1:Expr.Ident#.in^#2:Expr.Select#"}, {"A{`b`: 1}", "A{\n" " b:1^#3:int64#^#2:Expr.CreateStruct.Entry#\n" "}^#1:Expr.CreateStruct#"}, {"A{`b-c`: 1}", "A{\n" " b-c:1^#3:int64#^#2:Expr.CreateStruct.Entry#\n" "}^#1:Expr.CreateStruct#"}, {"A{`b c`: 1}", "A{\n" " b c:1^#3:int64#^#2:Expr.CreateStruct.Entry#\n" "}^#1:Expr.CreateStruct#"}, {"A{`b/c`: 1}", "A{\n" " b/c:1^#3:int64#^#2:Expr.CreateStruct.Entry#\n" "}^#1:Expr.CreateStruct#"}, {"A{`b.c`: 1}", "A{\n" " b.c:1^#3:int64#^#2:Expr.CreateStruct.Entry#\n" "}^#1:Expr.CreateStruct#"}, {"A{`in`: 1}", "A{\n" " in:1^#3:int64#^#2:Expr.CreateStruct.Entry#\n" "}^#1:Expr.CreateStruct#"}, {"has(a.`b/c`)", "a^#2:Expr.Ident#.b/c~test-only~^#4:Expr.Select#"}, // Unsupported quoted identifiers. 
{"a.`b\tc`", "", "ERROR: :1:3: Syntax error: token recognition error at: '`b\\t'\n" " | a.`b c`\n" " | ..^\n" "ERROR: :1:7: Syntax error: token recognition error at: '`'\n" " | a.`b c`\n" " | ......^"}, {"a.`@foo`", "", "ERROR: :1:3: Syntax error: token recognition error at: '`@'\n" " | a.`@foo`\n" " | ..^\n" "ERROR: :1:8: Syntax error: token recognition error at: '`'\n" " | a.`@foo`\n" " | .......^"}, {"a.`$foo`", "", "ERROR: :1:3: Syntax error: token recognition error at: '`$'\n" " | a.`$foo`\n" " | ..^\n" "ERROR: :1:8: Syntax error: token recognition error at: '`'\n" " | a.`$foo`\n" " | .......^"}, {"`a.b`", "", "ERROR: :1:1: Syntax error: mismatched input '`a.b`' expecting " "{'[', '{', " "'(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, NUM_INT, " "NUM_UINT, STRING, " "BYTES, IDENTIFIER}\n" " | `a.b`\n" " | ^"}, {"`a.b`()", "", "ERROR: :1:1: Syntax error: extraneous input '`a.b`' expecting " "{'[', '{', '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, " "NUM_INT, NUM_UINT, STRING, BYTES, IDENTIFIER}\n" " | `a.b`()\n" " | ^\n" "ERROR: :1:7: Syntax error: mismatched input ')' expecting {'[', " "'{', '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, NUM" "_INT, NUM_UINT, STRING, BYTES, IDENTIFIER}\n" " | `a.b`()\n" " | ......^"}, {"foo.`a.b`()", "", "ERROR: :1:10: Syntax error: mismatched input '(' expecting \n" " | foo.`a.b`()\n" " | .........^"}, // Macro calls tests {"x.filter(y, y.filter(z, z > 0))", "__comprehension__(\n" " // Variable\n" " y,\n" " // Target\n" " x^#1:Expr.Ident#,\n" " // Accumulator\n" " @result,\n" " // Init\n" " []^#19:Expr.CreateList#,\n" " // LoopCondition\n" " true^#20:bool#,\n" " // LoopStep\n" " _?_:_(\n" " __comprehension__(\n" " // Variable\n" " z,\n" " // Target\n" " y^#4:Expr.Ident#,\n" " // Accumulator\n" " @result,\n" " // Init\n" " []^#10:Expr.CreateList#,\n" " // LoopCondition\n" " true^#11:bool#,\n" " // LoopStep\n" " _?_:_(\n" " _>_(\n" " z^#7:Expr.Ident#,\n" " 0^#9:int64#\n" " )^#8:Expr.Call#,\n" " 
_+_(\n" " @result^#12:Expr.Ident#,\n" " [\n" " z^#6:Expr.Ident#\n" " ]^#13:Expr.CreateList#\n" " )^#14:Expr.Call#,\n" " @result^#15:Expr.Ident#\n" " )^#16:Expr.Call#,\n" " // Result\n" " @result^#17:Expr.Ident#)^#18:Expr.Comprehension#,\n" " _+_(\n" " @result^#21:Expr.Ident#,\n" " [\n" " y^#3:Expr.Ident#\n" " ]^#22:Expr.CreateList#\n" " )^#23:Expr.Call#,\n" " @result^#24:Expr.Ident#\n" " )^#25:Expr.Call#,\n" " // Result\n" " @result^#26:Expr.Ident#)^#27:Expr.Comprehension#" "", "", "", "", "x^#1:Expr.Ident#.filter(\n" " y^#3:Expr.Ident#,\n" " ^#18:filter#\n" ")^#27:filter#,\n" "y^#4:Expr.Ident#.filter(\n" " z^#6:Expr.Ident#,\n" " _>_(\n" " z^#7:Expr.Ident#,\n" " 0^#9:int64#\n" " )^#8:Expr.Call#\n" ")^#18:filter"}, {"has(a.b).filter(c, c)", "__comprehension__(\n" " // Variable\n" " c,\n" " // Target\n" " a^#2:Expr.Ident#.b~test-only~^#4:Expr.Select#,\n" " // Accumulator\n" " @result,\n" " // Init\n" " []^#8:Expr.CreateList#,\n" " // LoopCondition\n" " true^#9:bool#,\n" " // LoopStep\n" " _?_:_(\n" " c^#7:Expr.Ident#,\n" " _+_(\n" " @result^#10:Expr.Ident#,\n" " [\n" " c^#6:Expr.Ident#\n" " ]^#11:Expr.CreateList#\n" " )^#12:Expr.Call#,\n" " @result^#13:Expr.Ident#\n" " )^#14:Expr.Call#,\n" " // Result\n" " @result^#15:Expr.Ident#)^#16:Expr.Comprehension#", "", "", "", "^#4:has#.filter(\n" " c^#6:Expr.Ident#,\n" " c^#7:Expr.Ident#\n" ")^#16:filter#,\n" "has(\n" " a^#2:Expr.Ident#.b^#3:Expr.Select#\n" ")^#4:has"}, {"x.filter(y, y.exists(z, has(z.a)) && y.exists(z, has(z.b)))", "__comprehension__(\n" " // Variable\n" " y,\n" " // Target\n" " x^#1:Expr.Ident#,\n" " // Accumulator\n" " @result,\n" " // Init\n" " []^#35:Expr.CreateList#,\n" " // LoopCondition\n" " true^#36:bool#,\n" " // LoopStep\n" " _?_:_(\n" " _&&_(\n" " __comprehension__(\n" " // Variable\n" " z,\n" " // Target\n" " y^#4:Expr.Ident#,\n" " // Accumulator\n" " @result,\n" " // Init\n" " false^#11:bool#,\n" " // LoopCondition\n" " @not_strictly_false(\n" " !_(\n" " @result^#12:Expr.Ident#\n" " 
)^#13:Expr.Call#\n" " )^#14:Expr.Call#,\n" " // LoopStep\n" " _||_(\n" " @result^#15:Expr.Ident#,\n" " z^#8:Expr.Ident#.a~test-only~^#10:Expr.Select#\n" " )^#16:Expr.Call#,\n" " // Result\n" " @result^#17:Expr.Ident#)^#18:Expr.Comprehension#,\n" " __comprehension__(\n" " // Variable\n" " z,\n" " // Target\n" " y^#19:Expr.Ident#,\n" " // Accumulator\n" " @result,\n" " // Init\n" " false^#26:bool#,\n" " // LoopCondition\n" " @not_strictly_false(\n" " !_(\n" " @result^#27:Expr.Ident#\n" " )^#28:Expr.Call#\n" " )^#29:Expr.Call#,\n" " // LoopStep\n" " _||_(\n" " @result^#30:Expr.Ident#,\n" " z^#23:Expr.Ident#.b~test-only~^#25:Expr.Select#\n" " )^#31:Expr.Call#,\n" " // Result\n" " @result^#32:Expr.Ident#)^#33:Expr.Comprehension#\n" " )^#34:Expr.Call#,\n" " _+_(\n" " @result^#37:Expr.Ident#,\n" " [\n" " y^#3:Expr.Ident#\n" " ]^#38:Expr.CreateList#\n" " )^#39:Expr.Call#,\n" " @result^#40:Expr.Ident#\n" " )^#41:Expr.Call#,\n" " // Result\n" " @result^#42:Expr.Ident#)^#43:Expr.Comprehension#", "", "", "", "x^#1:Expr.Ident#.filter(\n" " y^#3:Expr.Ident#,\n" " _&&_(\n" " ^#18:exists#,\n" " ^#33:exists#\n" " )^#34:Expr.Call#\n" ")^#43:filter#,\n" "y^#19:Expr.Ident#.exists(\n" " z^#21:Expr.Ident#,\n" " ^#25:has#\n" ")^#33:exists#,\n" "has(\n" " z^#23:Expr.Ident#.b^#24:Expr.Select#\n" ")^#25:has#,\n" "y^#4:Expr.Ident#." 
"exists(\n" " z^#6:Expr.Ident#,\n" " ^#10:has#\n" ")^#18:exists#,\n" "has(\n" " z^#8:Expr.Ident#.a^#9:Expr.Select#\n" ")^#10:has"}, {"has(a.b).asList().exists(c, c)", "__comprehension__(\n" " // Variable\n" " c,\n" " // Target\n" " a^#2:Expr.Ident#.b~test-only~^#4:Expr.Select#.asList()^#5:Expr.Call#,\n" " // Accumulator\n" " @result,\n" " // Init\n" " false^#9:bool#,\n" " // LoopCondition\n" " @not_strictly_false(\n" " !_(\n" " @result^#10:Expr.Ident#\n" " )^#11:Expr.Call#\n" " )^#12:Expr.Call#,\n" " // LoopStep\n" " _||_(\n" " @result^#13:Expr.Ident#,\n" " c^#8:Expr.Ident#\n" " )^#14:Expr.Call#,\n" " // Result\n" " @result^#15:Expr.Ident#)^#16:Expr.Comprehension#", "", "", "", "^#4:has#.asList()^#5:Expr.Call#.exists(\n" " c^#7:Expr.Ident#,\n" " c^#8:Expr.Ident#\n" ")^#16:exists#,\n" "has(\n" " a^#2:Expr.Ident#.b^#3:Expr.Select#\n" ")^#4:has"}, {"[has(a.b), has(c.d)].exists(e, e)", "__comprehension__(\n" " // Variable\n" " e,\n" " // Target\n" " [\n" " a^#3:Expr.Ident#.b~test-only~^#5:Expr.Select#,\n" " c^#7:Expr.Ident#.d~test-only~^#9:Expr.Select#\n" " ]^#1:Expr.CreateList#,\n" " // Accumulator\n" " @result,\n" " // Init\n" " false^#13:bool#,\n" " // LoopCondition\n" " @not_strictly_false(\n" " !_(\n" " @result^#14:Expr.Ident#\n" " )^#15:Expr.Call#\n" " )^#16:Expr.Call#,\n" " // LoopStep\n" " _||_(\n" " @result^#17:Expr.Ident#,\n" " e^#12:Expr.Ident#\n" " )^#18:Expr.Call#,\n" " // Result\n" " @result^#19:Expr.Ident#)^#20:Expr.Comprehension#", "", "", "", "[\n" " ^#5:has#,\n" " ^#9:has#\n" "]^#1:Expr.CreateList#.exists(\n" " e^#11:Expr.Ident#,\n" " e^#12:Expr.Ident#\n" ")^#20:exists#,\n" "has(\n" " c^#7:Expr.Ident#.d^#8:Expr.Select#\n" ")^#9:has#,\n" "has(\n" " a^#3:Expr.Ident#.b^#4:Expr.Select#\n" ")^#5:has"}, {"b'\\UFFFFFFFF'", "", "ERROR: :1:1: Invalid bytes literal: Illegal escape sequence: " "Unicode escape sequence \\U cannot be used in bytes literals\n | " "b'\\UFFFFFFFF'\n | ^"}, {"a.?b[?0] && a[?c]", "_&&_(\n _[?_](\n _?._(\n a^#1:Expr.Ident#,\n " 
"\"b\"^#3:string#\n )^#2:Expr.Call#,\n 0^#5:int64#\n " ")^#4:Expr.Call#,\n _[?_](\n a^#6:Expr.Ident#,\n " "c^#8:Expr.Ident#\n )^#7:Expr.Call#\n)^#9:Expr.Call#"}, {"{?'key': value}", "{\n " "?\"key\"^#3:string#:value^#4:Expr.Ident#^#2:Expr.CreateStruct.Entry#\n}^#" "1:Expr.CreateStruct#"}, {"[?a, ?b]", "[\n ?a^#2:Expr.Ident#,\n ?b^#3:Expr.Ident#\n]^#1:Expr.CreateList#"}, {"[?a[?b]]", "[\n ?_[?_](\n a^#2:Expr.Ident#,\n b^#4:Expr.Ident#\n " ")^#3:Expr.Call#\n]^#1:Expr.CreateList#"}, {"Msg{?field: value}", "Msg{\n " "?field:value^#3:Expr.Ident#^#2:Expr.CreateStruct.Entry#\n}^#1:Expr." "CreateStruct#"}, {"m.optMap(v, f)", "_?_:_(\n m^#1:Expr.Ident#.hasValue()^#6:Expr.Call#,\n optional.of(\n " " __comprehension__(\n // Variable\n #unused,\n // " "Target\n []^#7:Expr.CreateList#,\n // Accumulator\n v,\n " " // Init\n m^#5:Expr.Ident#.value()^#8:Expr.Call#,\n // " "LoopCondition\n false^#9:bool#,\n // LoopStep\n " "v^#3:Expr.Ident#,\n // Result\n " "f^#4:Expr.Ident#)^#10:Expr.Comprehension#\n )^#11:Expr.Call#,\n " "optional.none()^#12:Expr.Call#\n)^#13:Expr.Call#"}, {"m.optFlatMap(v, f)", "_?_:_(\n m^#1:Expr.Ident#.hasValue()^#6:Expr.Call#,\n " "__comprehension__(\n // Variable\n #unused,\n // Target\n " "[]^#7:Expr.CreateList#,\n // Accumulator\n v,\n // Init\n " "m^#5:Expr.Ident#.value()^#8:Expr.Call#,\n // LoopCondition\n " "false^#9:bool#,\n // LoopStep\n v^#3:Expr.Ident#,\n // Result\n " " f^#4:Expr.Ident#)^#10:Expr.Comprehension#,\n " "optional.none()^#11:Expr.Call#\n)^#12:Expr.Call#"}}; absl::string_view ConstantKind(const cel::Constant& c) { switch (c.kind_case()) { case ConstantKindCase::kBool: return "bool"; case ConstantKindCase::kInt: return "int64"; case ConstantKindCase::kUint: return "uint64"; case ConstantKindCase::kDouble: return "double"; case ConstantKindCase::kString: return "string"; case ConstantKindCase::kBytes: return "bytes"; case ConstantKindCase::kNull: return "NullValue"; default: return "unspecified_constant"; } } absl::string_view 
ExprKind(const cel::Expr& e) { switch (e.kind_case()) { case ExprKindCase::kConstant: // special cased, this doesn't appear. return "Expr.Constant"; case ExprKindCase::kIdentExpr: return "Expr.Ident"; case ExprKindCase::kSelectExpr: return "Expr.Select"; case ExprKindCase::kCallExpr: return "Expr.Call"; case ExprKindCase::kListExpr: return "Expr.CreateList"; case ExprKindCase::kMapExpr: case ExprKindCase::kStructExpr: return "Expr.CreateStruct"; case ExprKindCase::kComprehensionExpr: return "Expr.Comprehension"; default: return "unspecified_expr"; } } class KindAndIdAdorner : public cel::test::ExpressionAdorner { public: // Use default source_info constructor to make source_info "optional". This // will prevent macro_calls lookups from interfering with adorning expressions // that don't need to use macro_calls, such as the parsed AST. explicit KindAndIdAdorner( const cel::expr::SourceInfo& source_info = cel::expr::SourceInfo::default_instance()) : source_info_(source_info) {} std::string Adorn(const cel::Expr& e) const override { // source_info_ might be empty on non-macro_calls tests if (source_info_.macro_calls_size() != 0 && source_info_.macro_calls().contains(e.id())) { return absl::StrFormat( "^#%d:%s#", e.id(), source_info_.macro_calls().at(e.id()).call_expr().function()); } if (e.has_const_expr()) { auto& const_expr = e.const_expr(); return absl::StrCat("^#", e.id(), ":", ConstantKind(const_expr), "#"); } else { return absl::StrCat("^#", e.id(), ":", ExprKind(e), "#"); } } std::string AdornStructField(const cel::StructExprField& e) const override { return absl::StrFormat("^#%d:Expr.CreateStruct.Entry#", e.id()); } std::string AdornMapEntry(const cel::MapExprEntry& e) const override { return absl::StrFormat("^#%d:Expr.CreateStruct.Entry#", e.id()); } private: const cel::expr::SourceInfo& source_info_; }; class LocationAdorner : public cel::test::ExpressionAdorner { public: explicit LocationAdorner(const cel::expr::SourceInfo& source_info) : 
source_info_(source_info) {} std::string Adorn(const cel::Expr& e) const override { return LocationToString(e.id()); } std::string AdornStructField(const cel::StructExprField& e) const override { return LocationToString(e.id()); } std::string AdornMapEntry(const cel::MapExprEntry& e) const override { return LocationToString(e.id()); } private: std::string LocationToString(int64_t id) const { auto loc = GetLocation(id); if (loc) { return absl::StrFormat("^#%d[%d,%d]#", id, loc->first, loc->second); } else { return absl::StrFormat("^#%d[NO_POS]#", id); } } absl::optional> GetLocation(int64_t id) const { absl::optional> location; const auto& positions = source_info_.positions(); if (positions.find(id) == positions.end()) { return location; } int32_t pos = positions.at(id); int32_t line = 1; for (int i = 0; i < source_info_.line_offsets_size(); ++i) { if (source_info_.line_offsets(i) > pos) { break; } else { line += 1; } } int32_t col = pos; if (line > 1) { col = pos - source_info_.line_offsets(line - 2); } return std::make_pair(line, col); } const cel::expr::SourceInfo& source_info_; }; std::string ConvertEnrichedSourceInfoToString( const EnrichedSourceInfo& enriched_source_info) { std::vector offsets; for (const auto& offset : enriched_source_info.offsets()) { offsets.push_back(absl::StrFormat( "[%d,%d,%d]", offset.first, offset.second.first, offset.second.second)); } return absl::StrJoin(offsets, "^#"); } std::string ConvertMacroCallsToString( const cel::expr::SourceInfo& source_info) { KindAndIdAdorner macro_calls_adorner(source_info); ExprPrinter w(macro_calls_adorner); // Use a list so we can sort the macro calls ensuring order for appending std::vector> macro_calls; for (auto pair : source_info.macro_calls()) { // Set ID to the map key for the adorner pair.second.set_id(pair.first); macro_calls.push_back(pair); } // Sort in reverse because the first macro will have the highest id absl::c_sort(macro_calls, [](const std::pair& p1, const std::pair& p2) { return 
p1.first > p2.first; }); std::string result = ""; for (const auto& pair : macro_calls) { result += w.PrintProto(pair.second) += ",\n"; } // substring last ",\n" return result.substr(0, result.size() - 3); } class ExpressionTest : public testing::TestWithParam {}; TEST_P(ExpressionTest, Parse) { const TestInfo& test_info = GetParam(); ParserOptions options; if (!test_info.M.empty()) { options.add_macro_calls = true; } options.enable_optional_syntax = true; options.enable_quoted_identifiers = true; std::vector macros = Macro::AllMacros(); macros.push_back(cel::OptMapMacro()); macros.push_back(cel::OptFlatMapMacro()); auto result = EnrichedParse(test_info.I, macros, "", options); if (test_info.E.empty()) { ASSERT_THAT(result, IsOk()); } else { EXPECT_THAT(result, Not(IsOk())); EXPECT_EQ(test_info.E, result.status().message()); } if (!test_info.P.empty()) { KindAndIdAdorner kind_and_id_adorner; ExprPrinter w(kind_and_id_adorner); std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); EXPECT_EQ(test_info.P, adorned_string) << result->parsed_expr().ShortDebugString(); } if (!test_info.L.empty()) { LocationAdorner location_adorner(result->parsed_expr().source_info()); ExprPrinter w(location_adorner); std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); EXPECT_EQ(test_info.L, adorned_string) << result->parsed_expr().ShortDebugString(); ; } if (!test_info.R.empty()) { EXPECT_EQ(test_info.R, ConvertEnrichedSourceInfoToString( result->enriched_source_info())); } if (!test_info.M.empty()) { EXPECT_EQ(test_info.M, ConvertMacroCallsToString( result.value().parsed_expr().source_info())) << result->parsed_expr().ShortDebugString(); ; } } TEST(ExpressionTest, TsanOom) { Parse( "[[a([[???[a[[??[a([[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" 
"[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" "[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[???[" "a([[????") .IgnoreError(); } TEST(ExpressionTest, ErrorRecoveryLimits) { ParserOptions options; options.error_recovery_limit = 1; auto result = Parse("......", "", options); EXPECT_THAT(result, Not(IsOk())); EXPECT_EQ(result.status().message(), "ERROR: :1:1: Syntax error: More than 1 parse errors.\n | ......\n " "| ^\nERROR: :1:2: Syntax error: no viable alternative at input " "'..'\n | ......\n | .^"); } TEST(ExpressionTest, ExpressionSizeLimit) { ParserOptions options; options.expression_size_codepoint_limit = 10; auto result = Parse("...............", "", options); EXPECT_THAT(result, Not(IsOk())); EXPECT_EQ( result.status().message(), "expression size exceeds codepoint limit. input size: 15, limit: 10"); } TEST(ExpressionTest, RecursionDepthLongArgList) { ParserOptions options; // The particular number here is an implementation detail: the underlying // visitor will recurse up to 8 times before branching to the create list or // const steps. 
The call graph looks something like: // visit->visitStart->visit->visitExpr->visit->visitOr->visit->visitAnd->visit // ->visitRelation->visit->visitCalc->visit->visitUnary->visit->visitPrimary // ->visitCreateList->visit[arg]->visitExpr... // The expected max depth for create list with an arbitrary number of elements // is 15. options.max_recursion_depth = 16; EXPECT_THAT(Parse("[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]", "", options), IsOk()); } TEST(ExpressionTest, RecursionDepthExceeded) { ParserOptions options; // AST visitor will recurse a variable amount depending on the terms used in // the expression. This check occurs in the business logic converting the raw // Antlr parse tree into an Expr. There is a separate check (via a custom // listener) for AST depth while running the antlr generated parser. options.max_recursion_depth = 6; auto result = Parse("1 + 2 + 3 + 4 + 5 + 6 + 7", "", options); EXPECT_THAT(result, Not(IsOk())); EXPECT_THAT(result.status().message(), HasSubstr("Exceeded max recursion depth of 6 when parsing.")); } TEST(ExpressionTest, DisableQuotedIdentifiers) { ParserOptions options; options.enable_quoted_identifiers = false; auto result = Parse("foo.`bar`", "", options); EXPECT_THAT(result, Not(IsOk())); EXPECT_THAT(result.status().message(), HasSubstr("ERROR: :1:5: unsupported syntax '`'\n" " | foo.`bar`\n" " | ....^")); } TEST(ExpressionTest, DisableStandardMacros) { ParserOptions options; options.disable_standard_macros = true; auto result = Parse("has(foo.bar)", "", options); ASSERT_THAT(result, IsOk()); KindAndIdAdorner kind_and_id_adorner; ExprPrinter w(kind_and_id_adorner); std::string adorned_string = w.PrintProto(result->expr()); EXPECT_EQ(adorned_string, "has(\n" " foo^#2:Expr.Ident#.bar^#3:Expr.Select#\n" ")^#1:Expr.Call#") << adorned_string; } TEST(ExpressionTest, RecursionDepthIgnoresParentheses) { ParserOptions options; options.max_recursion_depth = 6; auto result = Parse("(((1 + 2 + 3 + 4 + (5 + 6))))", "", options); 
EXPECT_THAT(result, IsOk()); } TEST(NewParserBuilderTest, Defaults) { auto builder = cel::NewParserBuilder(); ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("has(a.b) && [].exists(x, x > 0)")); ASSERT_OK_AND_ASSIGN(auto ast, parser->Parse(*source)); EXPECT_FALSE(ast->IsChecked()); } TEST(NewParserBuilderTest, CustomMacros) { auto builder = cel::NewParserBuilder(); builder->GetOptions().disable_standard_macros = true; ASSERT_THAT(builder->AddMacro(cel::HasMacro()), IsOk()); ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); builder.reset(); ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("has(a.b) && [].map(x, x)")); ASSERT_OK_AND_ASSIGN(auto ast, parser->Parse(*source)); EXPECT_FALSE(ast->IsChecked()); KindAndIdAdorner kind_and_id_adorner; ExprPrinter w(kind_and_id_adorner); EXPECT_EQ(w.Print(ast->root_expr()), "_&&_(\n" " a^#2:Expr.Ident#.b~test-only~^#4:Expr.Select#,\n" " []^#5:Expr.CreateList#.map(\n" " x^#7:Expr.Ident#,\n" " x^#8:Expr.Ident#\n" " )^#6:Expr.Call#\n" ")^#9:Expr.Call#"); } TEST(NewParserBuilderTest, StandardMacrosNotAddedWithStdlib) { auto builder = cel::NewParserBuilder(); builder->GetOptions().disable_standard_macros = false; // Add a fake stdlib to check that we don't try to add the standard macros // again. Emulates what happens when we add support for subsetting stdlib by // ids. 
ASSERT_THAT(builder->AddLibrary({"stdlib", [](cel::ParserBuilder& b) { return b.AddMacro(cel::HasMacro()); }}), IsOk()); ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); builder.reset(); ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("has(a.b) && [].map(x, x)")); ASSERT_OK_AND_ASSIGN(auto ast, parser->Parse(*source)); EXPECT_FALSE(ast->IsChecked()); KindAndIdAdorner kind_and_id_adorner; ExprPrinter w(kind_and_id_adorner); EXPECT_EQ(w.Print(ast->root_expr()), "_&&_(\n" " a^#2:Expr.Ident#.b~test-only~^#4:Expr.Select#,\n" " []^#5:Expr.CreateList#.map(\n" " x^#7:Expr.Ident#,\n" " x^#8:Expr.Ident#\n" " )^#6:Expr.Call#\n" ")^#9:Expr.Call#"); } TEST(NewParserBuilderTest, ForwardsOptions) { auto builder = cel::NewParserBuilder(); builder->GetOptions().enable_optional_syntax = true; ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("a.?b")); ASSERT_OK_AND_ASSIGN(auto ast, parser->Parse(*source)); EXPECT_FALSE(ast->IsChecked()); builder = cel::NewParserBuilder(); builder->GetOptions().enable_optional_syntax = false; ASSERT_OK_AND_ASSIGN(parser, std::move(*builder).Build()); ASSERT_OK_AND_ASSIGN(source, cel::NewSource("a.?b")); EXPECT_THAT(parser->Parse(*source), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(NewParserBuilderTest, ToBuilderCopiesConfig) { auto builder = cel::NewParserBuilder(); builder->GetOptions().enable_optional_syntax = true; builder->GetOptions().disable_standard_macros = true; ASSERT_THAT(builder->AddLibrary({"custom_lib", [](cel::ParserBuilder& b) { return b.AddMacro(cel::HasMacro()); }}), IsOk()); ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); auto derived_builder = parser->ToBuilder(); EXPECT_TRUE(derived_builder->GetOptions().enable_optional_syntax); ASSERT_OK_AND_ASSIGN(auto derived_parser, std::move(*derived_builder).Build()); ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("a.?b && has(a.b)")); ASSERT_OK_AND_ASSIGN(auto ast, 
derived_parser->Parse(*source)); EXPECT_FALSE(ast->IsChecked()); } TEST(NewParserBuilderTest, ToBuilderHandlesStdlibAndOptionalByLibrary) { auto builder = cel::NewParserBuilder(); builder->GetOptions().disable_standard_macros = true; builder->GetOptions().enable_optional_syntax = false; // Abusing the library ids for testing. Real uses should use subsetting. ASSERT_THAT( builder->AddLibrary( {"stdlib", [](cel::ParserBuilder& b) { return absl::OkStatus(); }}), IsOk()); ASSERT_THAT( builder->AddLibrary( {"optional", [](cel::ParserBuilder& b) { return absl::OkStatus(); }}), IsOk()); ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); auto derived_builder = parser->ToBuilder(); // Should be ignored now. derived_builder->GetOptions().disable_standard_macros = false; derived_builder->GetOptions().enable_optional_syntax = true; ASSERT_OK_AND_ASSIGN(auto derived_parser, std::move(*derived_builder).Build()); ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("has(a.b)")); ASSERT_OK_AND_ASSIGN(auto ast, derived_parser->Parse(*source)); KindAndIdAdorner kind_and_id_adorner; ExprPrinter w(kind_and_id_adorner); EXPECT_EQ(w.Print(ast->root_expr()), "has(\n" " a^#2:Expr.Ident#.b^#3:Expr.Select#\n" ")^#1:Expr.Call#"); } TEST(NewParserBuilderTest, ToBuilderPreservesStdlibAndOptionalFromOptions) { auto builder = cel::NewParserBuilder(); builder->GetOptions().disable_standard_macros = false; builder->GetOptions().enable_optional_syntax = true; ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); auto derived_builder = parser->ToBuilder(); ASSERT_OK_AND_ASSIGN(auto derived_parser, std::move(*derived_builder).Build()); ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("has(a.b) && [?a]")); ASSERT_OK_AND_ASSIGN(auto ast, derived_parser->Parse(*source)); EXPECT_FALSE(ast->IsChecked()); } std::string TestName(const testing::TestParamInfo& test_info) { std::string name = absl::StrCat(test_info.index, "-", test_info.param.I); absl::c_replace_if(name, [](char c) { return 
!absl::ascii_isalnum(c); }, '_'); return name; return name; } INSTANTIATE_TEST_SUITE_P(CelParserTest, ExpressionTest, testing::ValuesIn(test_cases), TestName); } // namespace } // namespace google::api::expr::parser ================================================ FILE: parser/source_factory.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_PARSER_SOURCE_FACTORY_H_ #define THIRD_PARTY_CEL_CPP_PARSER_SOURCE_FACTORY_H_ #include #include #include namespace google::api::expr::parser { class EnrichedSourceInfo { public: explicit EnrichedSourceInfo( std::map> offsets) : offsets_(std::move(offsets)) {} EnrichedSourceInfo() = default; EnrichedSourceInfo(const EnrichedSourceInfo& other) = default; EnrichedSourceInfo& operator=(const EnrichedSourceInfo& other) = default; EnrichedSourceInfo(EnrichedSourceInfo&& other) = default; EnrichedSourceInfo& operator=(EnrichedSourceInfo&& other) = default; const std::map>& offsets() const { return offsets_; } private: // A map between node_id and pair of start position and end position std::map> offsets_; }; } // namespace google::api::expr::parser #endif // THIRD_PARTY_CEL_CPP_PARSER_SOURCE_FACTORY_H_ ================================================ FILE: parser/standard_macros.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you 
may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "parser/standard_macros.h" #include "absl/status/status.h" #include "internal/status_macros.h" #include "parser/macro.h" #include "parser/macro_registry.h" #include "parser/options.h" namespace cel { absl::Status RegisterStandardMacros(MacroRegistry& registry, const ParserOptions& options) { CEL_RETURN_IF_ERROR(registry.RegisterMacro(HasMacro())); CEL_RETURN_IF_ERROR(registry.RegisterMacro(AllMacro())); CEL_RETURN_IF_ERROR(registry.RegisterMacro(ExistsMacro())); CEL_RETURN_IF_ERROR(registry.RegisterMacro(ExistsOneMacro())); CEL_RETURN_IF_ERROR(registry.RegisterMacro(Map2Macro())); CEL_RETURN_IF_ERROR(registry.RegisterMacro(Map3Macro())); CEL_RETURN_IF_ERROR(registry.RegisterMacro(FilterMacro())); if (options.enable_optional_syntax) { CEL_RETURN_IF_ERROR(registry.RegisterMacro(OptMapMacro())); CEL_RETURN_IF_ERROR(registry.RegisterMacro(OptFlatMapMacro())); } return absl::OkStatus(); } } // namespace cel ================================================ FILE: parser/standard_macros.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_PARSER_STANDARD_MACROS_H_ #define THIRD_PARTY_CEL_CPP_PARSER_STANDARD_MACROS_H_ #include "absl/status/status.h" #include "parser/macro_registry.h" #include "parser/options.h" namespace cel { // Registers the standard macros defined by the Common Expression Language. // https://github.com/google/cel-spec/blob/master/doc/langdef.md#macros absl::Status RegisterStandardMacros(MacroRegistry& registry, const ParserOptions& options); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_PARSER_STANDARD_MACROS_H_ ================================================ FILE: parser/standard_macros_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "parser/standard_macros.h"

#include <string>

#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "common/source.h"
#include "internal/testing.h"
#include "parser/macro_registry.h"
#include "parser/options.h"
#include "parser/parser.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::google::api::expr::parser::EnrichedParse;
using ::testing::HasSubstr;

// A single parameterized case: an expression that must fail to parse, and a
// substring that the resulting error message must contain.
struct StandardMacrosTestCase {
  std::string expression;
  std::string error;
};

using StandardMacrosTest = ::testing::TestWithParam<StandardMacrosTestCase>;

// Parses each expression with every standard macro registered and expects an
// kInvalidArgument status whose message contains the case's `error`.
TEST_P(StandardMacrosTest, Errors) {
  const auto& test_param = GetParam();
  ASSERT_OK_AND_ASSIGN(auto source, NewSource(test_param.expression));

  ParserOptions options;
  // Enabled so that optMap/optFlatMap are registered and can be exercised.
  options.enable_optional_syntax = true;

  MacroRegistry registry;
  ASSERT_THAT(RegisterStandardMacros(registry, options), IsOk());

  EXPECT_THAT(EnrichedParse(*source, registry, options),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr(test_param.error)));
}

// Every comprehension-style macro must reject `__result__` as its iteration
// variable name.
INSTANTIATE_TEST_SUITE_P(
    StandardMacrosTest, StandardMacrosTest,
    ::testing::ValuesIn<StandardMacrosTestCase>({
        {
            .expression = "[].all(__result__, __result__ == 0)",
            .error = "variable name cannot be __result__",
        },
        {
            .expression = "[].exists(__result__, __result__ == 0)",
            .error = "variable name cannot be __result__",
        },
        {
            .expression = "[].exists_one(__result__, __result__ == 0)",
            .error = "variable name cannot be __result__",
        },
        {
            .expression = "[].map(__result__, __result__)",
            .error = "variable name cannot be __result__",
        },
        {
            .expression = "[].map(__result__, true, __result__)",
            .error = "variable name cannot be __result__",
        },
        {
            .expression = "[].filter(__result__, __result__ == 0)",
            .error = "variable name cannot be __result__",
        },
        {
            .expression = "foo.optMap(__result__, __result__)",
            .error = "variable name cannot be __result__",
        },
        {
            .expression = "foo.optFlatMap(__result__, __result__)",
            .error = "variable name cannot be __result__",
        },
    }));

}  // namespace
}  // namespace cel
================================================
FILE: runtime/BUILD
================================================
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

package(
    # Under active development, not yet being released.
    default_visibility = ["//visibility:public"],
)

licenses(["notice"])

# Abstract interface the runtime uses to look up variables and context
# function overloads during evaluation.
cc_library(
    name = "activation_interface",
    hdrs = ["activation_interface.h"],
    deps = [
        ":function_overload_reference",
        "//base:attributes",
        "//common:value",
        "//internal:status_macros",
        "//runtime/internal:attribute_matcher",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "function_overload_reference",
    hdrs = ["function_overload_reference.h"],
    deps = [
        ":function",
        "//common:function_descriptor",
    ],
)

cc_library(
    name = "function_provider",
    hdrs = ["function_provider.h"],
    deps = [
        ":activation_interface",
        ":function_overload_reference",
        "//common:function_descriptor",
        "@com_google_absl//absl/status:statusor",
    ],
)

# Thread-compatible ActivationInterface implementation that supports eagerly
# bound values as well as lazily evaluated value providers.
cc_library(
    name = "activation",
    srcs = ["activation.cc"],
    hdrs = ["activation.h"],
    deps = [
        ":activation_interface",
        ":function",
        ":function_overload_reference",
        "//base:attributes",
        "//common:function_descriptor",
        "//common:value",
        "//internal:status_macros",
        "//runtime/internal:attribute_matcher",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "activation_test",
    srcs = ["activation_test.cc"],
    deps = [
        ":activation",
        ":function",
        ":function_overload_reference",
        "//base:attributes",
        "//common:function_descriptor",
        "//common:value",
        "//common:value_testing",
        "//internal:testing",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "register_function_helper",
    hdrs = ["register_function_helper.h"],
    deps = [
        ":function_registry",
        "//common:function_descriptor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "function_registry",
    srcs = ["function_registry.cc"],
    hdrs = ["function_registry.h"],
    deps = [
        ":activation_interface",
        ":function",
        ":function_overload_reference",
        ":function_provider",
        "//common:function_descriptor",
        "//common:kind",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:node_hash_map",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
    ],
)

cc_test(
    name = "function_registry_test",
    srcs = ["function_registry_test.cc"],
    deps = [
        ":activation",
        ":function",
        ":function_adapter",
        ":function_overload_reference",
        ":function_provider",
        ":function_registry",
        "//common:function_descriptor",
        "//common:kind",
        "//common:value",
        "//internal:testing",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "runtime_options",
    hdrs = ["runtime_options.h"],
    deps = ["@com_google_absl//absl/base:core_headers"],
)

cc_library(
    name = "type_registry",
    srcs = ["type_registry.cc"],
    hdrs = ["type_registry.h"],
    deps = [
        "//base:data",
        "//common:type",
        "//common:value",
        "//runtime/internal:legacy_runtime_type_provider",
        "//runtime/internal:runtime_type_provider",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "runtime",
    hdrs = ["runtime.h"],
    deps = [
        ":activation_interface",
        ":runtime_issue",
        "//base:ast",
        "//base:data",
        "//common:native_type",
        "//common:value",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "runtime_builder",
    hdrs = ["runtime_builder.h"],
    deps = [
        ":function_registry",
        ":runtime",
        ":runtime_options",
        ":type_registry",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status:statusor",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "runtime_builder_factory",
    srcs = ["runtime_builder_factory.cc"],
    hdrs = ["runtime_builder_factory.h"],
    deps = [
        ":runtime_builder",
        ":runtime_options",
        "//internal:noop_delete",
        "//internal:status_macros",
        "//runtime/internal:runtime_env",
        "//runtime/internal:runtime_impl",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status:statusor",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "standard_runtime_builder_factory",
    srcs = ["standard_runtime_builder_factory.cc"],
    hdrs = ["standard_runtime_builder_factory.h"],
    deps = [
        ":runtime_builder",
        ":runtime_builder_factory",
        ":runtime_options",
        ":standard_functions",
        "//internal:noop_delete",
        "//internal:status_macros",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status:statusor",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "standard_runtime_builder_factory_test",
    srcs = ["standard_runtime_builder_factory_test.cc"],
    deps = [
        ":activation",
        ":runtime",
        ":runtime_issue",
        ":runtime_options",
        ":standard_runtime_builder_factory",
        "//base:builtins",
        "//common:source",
        "//common:value",
        "//common:value_testing",
        "//extensions:bindings_ext",
        "//extensions/protobuf:runtime_adapter",
        "//internal:testing",
        "//parser",
        "//parser:macro_registry",
        "//parser:standard_macros",
        "//runtime/internal:runtime_impl",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "standard_functions",
    srcs = ["standard_functions.cc"],
    hdrs = ["standard_functions.h"],
    deps = [
        ":function_registry",
        ":runtime_options",
        "//internal:status_macros",
        "//runtime/standard:arithmetic_functions",
        "//runtime/standard:comparison_functions",
        "//runtime/standard:container_functions",
        "//runtime/standard:container_membership_functions",
        "//runtime/standard:equality_functions",
        "//runtime/standard:logical_functions",
        "//runtime/standard:regex_functions",
        "//runtime/standard:string_functions",
        "//runtime/standard:time_functions",
        "//runtime/standard:type_conversion_functions",
        "@com_google_absl//absl/status",
    ],
)

cc_library(
    name = "constant_folding",
    srcs = ["constant_folding.cc"],
    hdrs = ["constant_folding.h"],
    deps = [
        ":runtime",
        ":runtime_builder",
        "//common:typeinfo",
        "//eval/compiler:constant_folding",
        "//internal:casts",
        "//internal:noop_delete",
        "//internal:status_macros",
        "//runtime/internal:runtime_friend_access",
        "//runtime/internal:runtime_impl",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "constant_folding_test",
    srcs = ["constant_folding_test.cc"],
    deps = [
        ":activation",
        ":constant_folding",
        ":runtime_builder",
        ":runtime_options",
        ":standard_runtime_builder_factory",
        "//base:function_adapter",
        "//common:function_descriptor",
        "//common:value",
        "//extensions/protobuf:runtime_adapter",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//parser",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "regex_precompilation",
    srcs = ["regex_precompilation.cc"],
    hdrs = ["regex_precompilation.h"],
    deps = [
        ":runtime",
        ":runtime_builder",
        "//common:native_type",
        "//eval/compiler:regex_precompilation_optimization",
        "//internal:casts",
        "//internal:status_macros",
        "//runtime/internal:runtime_friend_access",
        "//runtime/internal:runtime_impl",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
)

cc_test(
    name = "regex_precompilation_test",
    srcs = ["regex_precompilation_test.cc"],
    deps = [
        ":activation",
        ":constant_folding",
        ":regex_precompilation",
        ":register_function_helper",
        ":runtime_builder",
        ":runtime_options",
        ":standard_runtime_builder_factory",
        "//base:function_adapter",
        "//common:value",
        "//extensions/protobuf:runtime_adapter",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//parser",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "reference_resolver",
    srcs = ["reference_resolver.cc"],
    hdrs = ["reference_resolver.h"],
    deps = [
        ":runtime",
        ":runtime_builder",
        "//common:native_type",
        "//eval/compiler:qualified_reference_resolver",
        "//internal:casts",
        "//internal:status_macros",
        "//runtime/internal:runtime_friend_access",
        "//runtime/internal:runtime_impl",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/log:absl_log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
)

cc_test(
    name = "reference_resolver_test",
    srcs = ["reference_resolver_test.cc"],
    deps = [
        ":activation",
        ":reference_resolver",
        ":register_function_helper",
        ":runtime_builder",
        ":runtime_options",
        ":standard_runtime_builder_factory",
        "//base:function_adapter",
        "//common:value",
        "//extensions/protobuf:runtime_adapter",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//parser",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "runtime_issue",
    hdrs = ["runtime_issue.h"],
    deps = ["@com_google_absl//absl/status"],
)

cc_library(
    name = "comprehension_vulnerability_check",
    srcs = ["comprehension_vulnerability_check.cc"],
    hdrs = ["comprehension_vulnerability_check.h"],
    deps = [
        ":runtime",
        ":runtime_builder",
        "//common:native_type",
        "//eval/compiler:comprehension_vulnerability_check",
        "//internal:casts",
        "//internal:status_macros",
        "//runtime/internal:runtime_friend_access",
        "//runtime/internal:runtime_impl",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
)

cc_test(
    name = "comprehension_vulnerability_check_test",
    srcs = ["comprehension_vulnerability_check_test.cc"],
    deps = [
        ":comprehension_vulnerability_check",
        ":runtime_builder",
        ":runtime_options",
        ":standard_runtime_builder_factory",
        "//extensions/protobuf:runtime_adapter",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//parser",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "function_adapter",
    hdrs = ["function_adapter.h"],
    deps = [
        ":function",
        ":register_function_helper",
        "//common:function_descriptor",
        "//common:value",
        "//internal:status_macros",
        "//runtime/internal:function_adapter",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "function_adapter_test",
    srcs = ["function_adapter_test.cc"],
    deps = [
        ":function",
        ":function_adapter",
        "//common:function_descriptor",
        "//common:kind",
        "//common:value",
        "//common:value_testing",
        "//internal:testing",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
    ],
)

cc_library(
    name = "optional_types",
    srcs = ["optional_types.cc"],
    hdrs = ["optional_types.h"],
    deps = [
        ":function_registry",
        ":runtime_builder",
        ":runtime_options",
        "//base:function_adapter",
        "//common:casting",
        "//common:type",
        "//common:value",
        "//internal:casts",
        "//internal:number",
        "//internal:status_macros",
        "//runtime/internal:errors",
        "//runtime/internal:runtime_friend_access",
        "//runtime/internal:runtime_impl",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:optional",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "optional_types_test",
    srcs = ["optional_types_test.cc"],
    deps = [
        ":activation",
        ":function",
        ":optional_types",
        ":reference_resolver",
        ":runtime",
        ":runtime_builder",
        ":runtime_options",
        ":standard_runtime_builder_factory",
        "//common:function_descriptor",
        "//common:kind",
        "//common:value",
        "//common:value_testing",
        "//extensions/protobuf:runtime_adapter",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//parser",
        "//parser:options",
        "//runtime/internal:runtime_impl",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:span",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "function",
    hdrs = [
        "function.h",
    ],
    deps = [
        "//common:value",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:span",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "memory_safety_test",
    srcs = ["memory_safety_test.cc"],
    deps = [
        ":activation",
        ":constant_folding",
        ":function_adapter",
        ":reference_resolver",
        ":regex_precompilation",
        ":runtime",
        ":runtime_builder",
        ":runtime_options",
        ":standard_runtime_builder_factory",
        "//checker:validation_result",
        "//common:decl",
        "//common:type",
        "//common:value",
        "//common:value_testing",
        "//compiler",
        "//compiler:compiler_factory",
        "//compiler:optional",
        "//compiler:standard_library",
        "//internal:status_macros",
        "//internal:testing",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:variant",
        "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto",
        "@com_google_protobuf//:any_cc_proto",
        "@com_google_protobuf//:differencer",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "embedder_context",
    hdrs = ["embedder_context.h"],
    deps = [
        "//common:typeinfo",
        "//common:value",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/types:optional",
    ],
)

cc_test(
    name = "embedder_context_test",
    srcs = ["embedder_context_test.cc"],
    deps = [
        ":embedder_context",
        "//common:typeinfo",
        "//internal:testing",
        "@com_google_absl//absl/types:optional",
    ],
)

================================================
FILE: runtime/activation.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/activation.h"

#include <memory>
#include <utility>
#include <vector>

#include "absl/base/macros.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "common/function_descriptor.h"
#include "common/value.h"
#include "internal/status_macros.h"
#include "runtime/function.h"
#include "runtime/function_overload_reference.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel {

// Looks up `name` and, on success, copies the bound value into `*result`.
//
// Entries backed by a provider are delegated to ProvideValue(), which handles
// locking and caching. Returns false if `name` is unbound, or if the entry
// holds neither a provider nor a value.
absl::StatusOr<bool> Activation::FindVariable(
    absl::string_view name,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  ABSL_DCHECK(descriptor_pool != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(result != nullptr);

  auto iter = values_.find(name);
  if (iter == values_.end()) {
    return false;
  }

  const ValueEntry& entry = iter->second;
  if (entry.provider.has_value()) {
    // The cached value of a provider-backed entry is only written under
    // mutex_, so it must be read there as well.
    return ProvideValue(name, descriptor_pool, message_factory, arena, result);
  }
  if (entry.value.has_value()) {
    // Eagerly bound values are never mutated after insertion; read lock-free.
    *result = *entry.value;
    return true;
  }
  return false;
}

// Resolves a provider-backed entry while holding mutex_.
//
// A previously cached value is returned directly. Otherwise the provider is
// invoked; a non-empty result is cached in the entry and copied into
// `*result`. Provider errors are propagated and nothing is cached in that
// case. Returns false if the provider yields no value.
absl::StatusOr<bool> Activation::ProvideValue(
    absl::string_view name,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena,
    Value* absl_nonnull result) const {
  absl::MutexLock lock(mutex_);
  auto iter = values_.find(name);
  ABSL_ASSERT(iter != values_.end());
  ValueEntry& entry = iter->second;
  if (entry.value.has_value()) {
    *result = *entry.value;
    return true;
  }

  CEL_ASSIGN_OR_RETURN(
      auto provided,
      (*entry.provider)(name, descriptor_pool, message_factory, arena));

  if (provided.has_value()) {
    entry.value = std::move(provided);
    *result = *entry.value;
    return true;
  }
  return false;
}

// Returns references to every overload registered under `name`, in insertion
// order. The references point into storage owned by this activation.
std::vector<FunctionOverloadReference> Activation::FindFunctionOverloads(
    absl::string_view name) const {
  std::vector<FunctionOverloadReference> result;
  auto iter = functions_.find(name);
  if (iter != functions_.end()) {
    const std::vector<FunctionEntry>& overloads = iter->second;
    result.reserve(overloads.size());
    for (const auto& overload : overloads) {
      result.push_back({*overload.descriptor, *overload.implementation});
    }
  }
  return result;
}

// Binds `value` to `name`, replacing any existing value or provider.
// Returns true if a new entry was inserted, false if one was overwritten.
bool Activation::InsertOrAssignValue(absl::string_view name, Value value) {
  return values_
      .insert_or_assign(name, ValueEntry{std::move(value), absl::nullopt})
      .second;
}

// Binds `provider` to `name`, replacing any existing value or provider.
// Returns true if a new entry was inserted, false if one was overwritten.
bool Activation::InsertOrAssignValueProvider(absl::string_view name,
                                             ValueProvider provider) {
  return values_
      .insert_or_assign(name, ValueEntry{absl::nullopt, std::move(provider)})
      .second;
}

// Adds an overload for `descriptor.name()` unless an existing overload's
// descriptor already ShapeMatches() `descriptor`, in which case nothing is
// added and false is returned.
bool Activation::InsertFunction(const cel::FunctionDescriptor& descriptor,
                                std::unique_ptr<cel::Function> impl) {
  auto& overloads = functions_[descriptor.name()];
  for (auto& overload : overloads) {
    if (overload.descriptor->ShapeMatches(descriptor)) {
      return false;
    }
  }
  overloads.push_back(
      {std::make_unique<FunctionDescriptor>(descriptor), std::move(impl)});
  return true;
}

// Move construction swaps `other` into a default-constructed activation, so
// only the members exchanged by swap() are transferred.
Activation::Activation(Activation&& other) {
  using std::swap;
  swap(*this, other);
}

// Move assignment moves `other` into a temporary and swaps that temporary in.
// The previous contents of *this are destroyed along with the temporary.
Activation& Activation::operator=(Activation&& other) {
  using std::swap;
  Activation tmp(std::move(other));
  swap(*this, tmp);
  return *this;
}

}  // namespace cel

================================================
FILE: runtime/activation.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_ACTIVATION_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_ACTIVATION_H_ #include #include #include #include #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/functional/any_invocable.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "base/attribute.h" #include "common/function_descriptor.h" #include "common/value.h" #include "runtime/activation_interface.h" #include "runtime/function.h" #include "runtime/function_overload_reference.h" #include "runtime/internal/attribute_matcher.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel { namespace runtime_internal { class ActivationAttributeMatcherAccess; } // Thread-compatible implementation of a CEL Activation. // // Values can either be provided eagerly or via a provider. class Activation final : public ActivationInterface { public: // Definition for value providers. using ValueProvider = absl::AnyInvocable>( absl::string_view, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull)>; Activation() = default; // Move only. Activation(Activation&& other); Activation& operator=(Activation&& other); // Implements ActivationInterface. 
absl::StatusOr FindVariable( absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const override; using ActivationInterface::FindVariable; std::vector FindFunctionOverloads( absl::string_view name) const override; absl::Span GetUnknownAttributes() const override { return unknown_patterns_; } absl::Span GetMissingAttributes() const override { return missing_patterns_; } // Bind a value to a named variable. // // Returns false if the entry for name was overwritten. bool InsertOrAssignValue(absl::string_view name, Value value); // Bind a provider to a named variable. The result of the provider may be // memoized by the activation. // // Returns false if the entry for name was overwritten. bool InsertOrAssignValueProvider(absl::string_view name, ValueProvider provider); void AddUnknownPattern(cel::AttributePattern pattern) { unknown_patterns_.push_back(std::move(pattern)); } void SetUnknownPatterns(std::vector patterns) { unknown_patterns_ = std::move(patterns); } void AddMissingPattern(cel::AttributePattern pattern) { missing_patterns_.push_back(std::move(pattern)); } void SetMissingPatterns(std::vector patterns) { missing_patterns_ = std::move(patterns); } // Returns true if the function was inserted (no other registered function has // a matching descriptor). bool InsertFunction(const cel::FunctionDescriptor& descriptor, std::unique_ptr impl); private: struct ValueEntry { // If provider is present, then access must be synchronized to maintain // thread-compatible semantics for the lazily provided value. 
absl::optional value; absl::optional provider; }; struct FunctionEntry { std::unique_ptr descriptor; std::unique_ptr implementation; }; friend class runtime_internal::ActivationAttributeMatcherAccess; void SetAttributeMatcher(const runtime_internal::AttributeMatcher* matcher) { attribute_matcher_ = matcher; } void SetAttributeMatcher( std::unique_ptr matcher) { owned_attribute_matcher_ = std::move(matcher); attribute_matcher_ = owned_attribute_matcher_.get(); } const runtime_internal::AttributeMatcher* absl_nullable GetAttributeMatcher() const override { return attribute_matcher_; } friend void swap(Activation& a, Activation& b) { using std::swap; swap(a.values_, b.values_); swap(a.functions_, b.functions_); swap(a.unknown_patterns_, b.unknown_patterns_); swap(a.missing_patterns_, b.missing_patterns_); } // Internal getter for provided values. // Assumes entry for name is present and is a provided value. // Handles synchronization for caching the provided value. absl::StatusOr ProvideValue( absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; // mutex_ used for safe caching of provided variables mutable absl::Mutex mutex_; mutable absl::flat_hash_map values_; std::vector unknown_patterns_; std::vector missing_patterns_; const runtime_internal::AttributeMatcher* attribute_matcher_ = nullptr; std::unique_ptr owned_attribute_matcher_; absl::flat_hash_map> functions_; }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_RUNTIME_ACTIVATION_H_ ================================================ FILE: runtime/activation_interface.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_ACTIVATION_INTERFACE_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_ACTIVATION_INTERFACE_H_ #include #include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "base/attribute.h" #include "common/value.h" #include "internal/status_macros.h" #include "runtime/function_overload_reference.h" #include "runtime/internal/attribute_matcher.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel { namespace runtime_internal { class ActivationAttributeMatcherAccess; } // namespace runtime_internal // Interface for providing runtime with variable lookups. // // Clients should prefer to use one of the concrete implementations provided by // the CEL library rather than implementing this interface directly. // TODO(uncreated-issue/40): After finalizing, make this public and add instructions // for clients to migrate. class ActivationInterface { public: virtual ~ActivationInterface() = default; // Find value for a string (possibly qualified) variable name. 
virtual absl::StatusOr FindVariable( absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const = 0; absl::StatusOr> FindVariable( absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { Value result; CEL_ASSIGN_OR_RETURN( auto found, FindVariable(name, descriptor_pool, message_factory, arena, &result)); if (found) { return result; } return absl::nullopt; } // Find a set of context function overloads by name. virtual std::vector FindFunctionOverloads( absl::string_view name) const = 0; // Return a list of unknown attribute patterns. // // If an attribute (select path) encountered during evaluation matches any of // the patterns, the value will be treated as unknown and propagated in an // unknown set. // // The returned span must remain valid for the duration of any evaluation // using this this activation. virtual absl::Span GetUnknownAttributes() const = 0; // Return a list of missing attribute patterns. // // If an attribute (select path) encountered during evaluation matches any of // the patterns, the value will be treated as missing and propagated as an // error. // // The returned span must remain valid for the duration of any evaluation // using this activation. virtual absl::Span GetMissingAttributes() const = 0; private: friend class runtime_internal::ActivationAttributeMatcherAccess; // Returns the attribute matcher for this activation. 
virtual const runtime_internal::AttributeMatcher* absl_nullable GetAttributeMatcher() const { return nullptr; } }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_RUNTIME_ACTIVATION_INTERFACE_H_ ================================================ FILE: runtime/activation_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/activation.h" #include #include #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "base/attribute.h" #include "common/function_descriptor.h" #include "common/value.h" #include "common/value_testing.h" #include "internal/testing.h" #include "runtime/function.h" #include "runtime/function_overload_reference.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel { namespace { using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using testing::ElementsAre; using testing::Eq; using testing::IsEmpty; using testing::Optional; using testing::SizeIs; using testing::Truly; using testing::UnorderedElementsAre; MATCHER_P(IsIntValue, x, absl::StrCat("is IntValue Handle with value ", x)) { const Value& handle = arg; return handle->Is() && handle.GetInt().NativeValue() == x; 
} MATCHER_P(AttributePatternMatches, val, "matches AttributePattern") { const AttributePattern& pattern = arg; const Attribute& expected = val; return pattern.IsMatch(expected) == AttributePattern::MatchType::FULL; } class FunctionImpl : public cel::Function { public: FunctionImpl() = default; absl::StatusOr Invoke(absl::Span args, const InvokeContext& context) const override { return NullValue(); } }; using ActivationTest = common_internal::ValueTest<>; TEST_F(ActivationTest, ValueNotFound) { Activation activation; EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Eq(absl::nullopt))); } TEST_F(ActivationTest, InsertValue) { Activation activation; EXPECT_TRUE(activation.InsertOrAssignValue("var1", IntValue(42))); EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional(IsIntValue(42)))); } TEST_F(ActivationTest, InsertValueOverwrite) { Activation activation; EXPECT_TRUE(activation.InsertOrAssignValue("var1", IntValue(42))); EXPECT_FALSE(activation.InsertOrAssignValue("var1", IntValue(0))); EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional(IsIntValue(0)))); } TEST_F(ActivationTest, InsertProvider) { Activation activation; EXPECT_TRUE(activation.InsertOrAssignValueProvider( "var1", [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) { return IntValue(42); })); EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional(IsIntValue(42)))); } TEST_F(ActivationTest, InsertProviderForwardsNotFound) { Activation activation; EXPECT_TRUE(activation.InsertOrAssignValueProvider( "var1", [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* 
absl_nonnull) { return absl::nullopt; })); EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Eq(absl::nullopt))); } TEST_F(ActivationTest, InsertProviderForwardsStatus) { Activation activation; EXPECT_TRUE(activation.InsertOrAssignValueProvider( "var1", [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) { return absl::InternalError("test"); })); EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), message_factory(), arena()), StatusIs(absl::StatusCode::kInternal, "test")); } TEST_F(ActivationTest, ProviderMemoized) { Activation activation; int call_count = 0; EXPECT_TRUE(activation.InsertOrAssignValueProvider( "var1", [&call_count](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) { call_count++; return IntValue(42); })); EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional(IsIntValue(42)))); EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional(IsIntValue(42)))); EXPECT_EQ(call_count, 1); } TEST_F(ActivationTest, InsertProviderOverwrite) { Activation activation; EXPECT_TRUE(activation.InsertOrAssignValueProvider( "var1", [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) { return IntValue(42); })); EXPECT_FALSE(activation.InsertOrAssignValueProvider( "var1", [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) { return IntValue(0); })); EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), message_factory(), arena()), 
IsOkAndHolds(Optional(IsIntValue(0)))); } TEST_F(ActivationTest, ValuesAndProvidersShareNamespace) { Activation activation; bool called = false; EXPECT_TRUE(activation.InsertOrAssignValue("var1", IntValue(41))); EXPECT_TRUE(activation.InsertOrAssignValue("var2", IntValue(41))); EXPECT_FALSE(activation.InsertOrAssignValueProvider( "var1", [&called](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) { called = true; return IntValue(42); })); EXPECT_THAT(activation.FindVariable("var1", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional(IsIntValue(42)))); EXPECT_THAT(activation.FindVariable("var2", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional(IsIntValue(41)))); EXPECT_TRUE(called); } TEST_F(ActivationTest, SetUnknownAttributes) { Activation activation; activation.SetUnknownPatterns( {AttributePattern("var1", {AttributeQualifierPattern::OfString("field1")}), AttributePattern("var1", {AttributeQualifierPattern::OfString("field2")})}); EXPECT_THAT( activation.GetUnknownAttributes(), ElementsAre(AttributePatternMatches(Attribute( "var1", {AttributeQualifier::OfString("field1")})), AttributePatternMatches(Attribute( "var1", {AttributeQualifier::OfString("field2")})))); } TEST_F(ActivationTest, ClearUnknownAttributes) { Activation activation; activation.SetUnknownPatterns( {AttributePattern("var1", {AttributeQualifierPattern::OfString("field1")}), AttributePattern("var1", {AttributeQualifierPattern::OfString("field2")})}); activation.SetUnknownPatterns({}); EXPECT_THAT(activation.GetUnknownAttributes(), IsEmpty()); } TEST_F(ActivationTest, SetMissingAttributes) { Activation activation; activation.SetMissingPatterns( {AttributePattern("var1", {AttributeQualifierPattern::OfString("field1")}), AttributePattern("var1", {AttributeQualifierPattern::OfString("field2")})}); EXPECT_THAT( activation.GetMissingAttributes(), 
ElementsAre(AttributePatternMatches(Attribute( "var1", {AttributeQualifier::OfString("field1")})), AttributePatternMatches(Attribute( "var1", {AttributeQualifier::OfString("field2")})))); } TEST_F(ActivationTest, ClearMissingAttributes) { Activation activation; activation.SetMissingPatterns( {AttributePattern("var1", {AttributeQualifierPattern::OfString("field1")}), AttributePattern("var1", {AttributeQualifierPattern::OfString("field2")})}); activation.SetMissingPatterns({}); EXPECT_THAT(activation.GetMissingAttributes(), IsEmpty()); } TEST_F(ActivationTest, InsertFunctionOk) { Activation activation; EXPECT_TRUE( activation.InsertFunction(FunctionDescriptor("Fn", false, {Kind::kUint}), std::make_unique())); EXPECT_TRUE( activation.InsertFunction(FunctionDescriptor("Fn", false, {Kind::kInt}), std::make_unique())); EXPECT_TRUE( activation.InsertFunction(FunctionDescriptor("Fn2", false, {Kind::kInt}), std::make_unique())); EXPECT_THAT( activation.FindFunctionOverloads("Fn"), UnorderedElementsAre( Truly([](const FunctionOverloadReference& ref) { return ref.descriptor.name() == "Fn" && ref.descriptor.types() == std::vector{Kind::kUint}; }), Truly([](const FunctionOverloadReference& ref) { return ref.descriptor.name() == "Fn" && ref.descriptor.types() == std::vector{Kind::kInt}; }))) << "expected overloads Fn(int), Fn(uint)"; } TEST_F(ActivationTest, InsertFunctionFails) { Activation activation; EXPECT_TRUE( activation.InsertFunction(FunctionDescriptor("Fn", false, {Kind::kAny}), std::make_unique())); EXPECT_FALSE( activation.InsertFunction(FunctionDescriptor("Fn", false, {Kind::kInt}), std::make_unique())); EXPECT_THAT(activation.FindFunctionOverloads("Fn"), ElementsAre(Truly([](const FunctionOverloadReference& ref) { return ref.descriptor.name() == "Fn" && ref.descriptor.types() == std::vector{Kind::kAny}; }))) << "expected overload Fn(any)"; } TEST_F(ActivationTest, MoveAssignment) { Activation moved_from; ASSERT_TRUE( 
moved_from.InsertFunction(FunctionDescriptor("Fn", false, {Kind::kAny}), std::make_unique())); ASSERT_TRUE(moved_from.InsertOrAssignValue("val", IntValue(42))); ASSERT_TRUE(moved_from.InsertOrAssignValueProvider( "val_provided", [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) -> absl::StatusOr> { return IntValue(42); })); moved_from.SetUnknownPatterns( {AttributePattern("var1", {AttributeQualifierPattern::OfString("field1")}), AttributePattern("var1", {AttributeQualifierPattern::OfString("field2")})}); moved_from.SetMissingPatterns( {AttributePattern("var1", {AttributeQualifierPattern::OfString("field1")}), AttributePattern("var1", {AttributeQualifierPattern::OfString("field2")})}); Activation moved_to; moved_to = std::move(moved_from); EXPECT_THAT(moved_to.FindVariable("val", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional(IsIntValue(42)))); EXPECT_THAT(moved_to.FindVariable("val_provided", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional(IsIntValue(42)))); EXPECT_THAT(moved_to.FindFunctionOverloads("Fn"), SizeIs(1)); EXPECT_THAT(moved_to.GetUnknownAttributes(), SizeIs(2)); EXPECT_THAT(moved_to.GetMissingAttributes(), SizeIs(2)); // moved from value is empty. 
(well defined but not specified state) // NOLINTBEGIN(bugprone-use-after-move) EXPECT_THAT(moved_from.FindVariable("val", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(moved_from.FindVariable("val_provided", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(moved_from.FindFunctionOverloads("Fn"), SizeIs(0)); EXPECT_THAT(moved_from.GetUnknownAttributes(), SizeIs(0)); EXPECT_THAT(moved_from.GetMissingAttributes(), SizeIs(0)); // NOLINTEND(bugprone-use-after-move) } TEST_F(ActivationTest, MoveCtor) { Activation moved_from; ASSERT_TRUE( moved_from.InsertFunction(FunctionDescriptor("Fn", false, {Kind::kAny}), std::make_unique())); ASSERT_TRUE(moved_from.InsertOrAssignValue("val", IntValue(42))); ASSERT_TRUE(moved_from.InsertOrAssignValueProvider( "val_provided", [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) -> absl::StatusOr> { return IntValue(42); })); moved_from.SetUnknownPatterns( {AttributePattern("var1", {AttributeQualifierPattern::OfString("field1")}), AttributePattern("var1", {AttributeQualifierPattern::OfString("field2")})}); moved_from.SetMissingPatterns( {AttributePattern("var1", {AttributeQualifierPattern::OfString("field1")}), AttributePattern("var1", {AttributeQualifierPattern::OfString("field2")})}); Activation moved_to = std::move(moved_from); EXPECT_THAT(moved_to.FindVariable("val", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional(IsIntValue(42)))); EXPECT_THAT(moved_to.FindVariable("val_provided", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Optional(IsIntValue(42)))); EXPECT_THAT(moved_to.FindFunctionOverloads("Fn"), SizeIs(1)); EXPECT_THAT(moved_to.GetUnknownAttributes(), SizeIs(2)); EXPECT_THAT(moved_to.GetMissingAttributes(), SizeIs(2)); // moved from value is empty. 
// NOLINTBEGIN(bugprone-use-after-move) EXPECT_THAT(moved_from.FindVariable("val", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(moved_from.FindVariable("val_provided", descriptor_pool(), message_factory(), arena()), IsOkAndHolds(Eq(absl::nullopt))); EXPECT_THAT(moved_from.FindFunctionOverloads("Fn"), SizeIs(0)); EXPECT_THAT(moved_from.GetUnknownAttributes(), SizeIs(0)); EXPECT_THAT(moved_from.GetMissingAttributes(), SizeIs(0)); // NOLINTEND(bugprone-use-after-move) } } // namespace } // namespace cel ================================================ FILE: runtime/comprehension_vulnerability_check.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "runtime/comprehension_vulnerability_check.h"

#include "absl/base/macros.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "common/native_type.h"
#include "eval/compiler/comprehension_vulnerability_check.h"
#include "internal/casts.h"
#include "internal/status_macros.h"
#include "runtime/internal/runtime_friend_access.h"
#include "runtime/internal/runtime_impl.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"

namespace cel {

namespace {

using ::cel::internal::down_cast;
using ::cel::runtime_internal::RuntimeFriendAccess;
using ::cel::runtime_internal::RuntimeImpl;
using ::google::api::expr::runtime::CreateComprehensionVulnerabilityCheck;

// Returns the default RuntimeImpl wrapped by `builder`.
//
// Fails with kUnimplemented if the builder holds some other cel::Runtime
// implementation, since the check is installed through RuntimeImpl's
// expression builder.
absl::StatusOr RuntimeImplFromBuilder(
    RuntimeBuilder& builder) {
  Runtime& runtime = RuntimeFriendAccess::GetMutableRuntime(builder);

  if (RuntimeFriendAccess::RuntimeTypeId(runtime) !=
      NativeTypeId::For()) {
    // The message names this feature; it previously said "constant folding"
    // (copied from runtime/constant_folding.cc), which misreported the
    // failing option.
    return absl::UnimplementedError(
        "comprehension vulnerability check only supported on the default "
        "cel::Runtime implementation.");
  }

  RuntimeImpl& runtime_impl = down_cast(runtime);

  return &runtime_impl;
}

}  // namespace

absl::Status EnableComprehensionVulnerabiltyCheck(
    cel::RuntimeBuilder& builder) {
  CEL_ASSIGN_OR_RETURN(RuntimeImpl * runtime_impl,
                       RuntimeImplFromBuilder(builder));
  ABSL_ASSERT(runtime_impl != nullptr);
  // Registers the check as a program optimizer on the runtime's
  // expression builder.
  runtime_impl->expr_builder().AddProgramOptimizer(
      CreateComprehensionVulnerabilityCheck());
  return absl::OkStatus();
}

}  // namespace cel
================================================
FILE: runtime/comprehension_vulnerability_check.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_COMPREHENSION_VULNERABILITY_CHECK_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_COMPREHENSION_VULNERABILITY_CHECK_H_

#include "absl/status/status.h"
#include "runtime/runtime_builder.h"

namespace cel {

// Enable a check for memory vulnerabilities within comprehension
// sub-expressions.
//
// Note: This flag is not necessary if you are only using Core CEL macros.
//
// Consider enabling this feature when using custom comprehensions, and
// absolutely enable the feature when using hand-written ASTs for
// comprehension expressions.
//
// This check is not exhaustive and shouldn't be used with deeply nested ASTs.
//
// Returns an error with code kUnimplemented if `builder` does not hold the
// default cel::Runtime implementation.
//
// Note: the spelling "Vulnerabilty" is part of the published function name;
// callers must use it as spelled.
absl::Status EnableComprehensionVulnerabiltyCheck(RuntimeBuilder& builder);

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_COMPREHENSION_VULNERABILITY_CHECK_H_
================================================
FILE: runtime/comprehension_vulnerability_check_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/comprehension_vulnerability_check.h"

#include 

#include "cel/expr/syntax.pb.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "extensions/protobuf/runtime_adapter.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "parser/parser.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "google/protobuf/text_format.h"

namespace cel {

namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::extensions::ProtobufRuntimeAdapter;
using ::cel::expr::ParsedExpr;
using ::google::api::expr::parser::Parse;
using ::google::protobuf::TextFormat;
using ::testing::HasSubstr;

// Hand-written comprehension whose accumulator starts as [0] and whose loop
// step computes `accu + accu` for each of the four range elements, so the
// accumulated list doubles in size on every iteration.
constexpr absl::string_view kVulnerableExpr = R"pb(
  expr {
    id: 1
    comprehension_expr {
      iter_var: "unused"
      accu_var: "accu"
      result {
        id: 2
        ident_expr { name: "accu" }
      }
      accu_init {
        id: 11
        list_expr {
          elements {
            id: 12
            const_expr { int64_value: 0 }
          }
        }
      }
      loop_condition {
        id: 13
        const_expr { bool_value: true }
      }
      loop_step {
        id: 3
        call_expr {
          function: "_+_"
          args {
            id: 4
            ident_expr { name: "accu" }
          }
          args {
            id: 5
            ident_expr { name: "accu" }
          }
        }
      }
      iter_range {
        id: 6
        list_expr {
          elements {
            id: 7
            const_expr { int64_value: 0 }
          }
          elements {
            id: 8
            const_expr { int64_value: 0 }
          }
          elements {
            id: 9
            const_expr { int64_value: 0 }
          }
          elements {
            id: 10
            const_expr { int64_value: 0 }
          }
        }
      }
    }
  }
)pb";

// With the check enabled, planning the doubling comprehension is rejected.
TEST(ComprehensionVulnerabilityCheck, EnabledVulnerable) {
  ParsedExpr parsed;
  ASSERT_TRUE(TextFormat::ParseFromString(kVulnerableExpr, &parsed));

  RuntimeOptions options;
  ASSERT_OK_AND_ASSIGN(
      RuntimeBuilder builder,
      CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(),
                                   options));
  ASSERT_OK(EnableComprehensionVulnerabiltyCheck(builder));
  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  EXPECT_THAT(
      ProtobufRuntimeAdapter::CreateProgram(*runtime, parsed),
      StatusIs(
          absl::StatusCode::kInvalidArgument,
          HasSubstr("Comprehension contains memory exhaustion vulnerability")));
}

// With the check enabled, an ordinary `map` macro still plans successfully.
TEST(ComprehensionVulnerabilityCheck, EnabledNotVulnerable) {
  RuntimeOptions options;
  ASSERT_OK_AND_ASSIGN(
      RuntimeBuilder builder,
      CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(),
                                   options));
  ASSERT_OK(EnableComprehensionVulnerabiltyCheck(builder));
  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed, Parse("[0, 0, 0, 0].map(x, x + 1)"));

  EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, parsed), IsOk());
}

// Without the check, the same doubling comprehension is accepted.
TEST(ComprehensionVulnerabilityCheck, DisabledVulnerable) {
  ParsedExpr parsed;
  ASSERT_TRUE(TextFormat::ParseFromString(kVulnerableExpr, &parsed));

  RuntimeOptions options;
  ASSERT_OK_AND_ASSIGN(
      RuntimeBuilder builder,
      CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(),
                                   options));
  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, parsed), IsOk());
}

}  // namespace
}  // namespace cel
================================================
FILE: runtime/constant_folding.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/constant_folding.h"

#include 
#include 

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "common/typeinfo.h"
#include "eval/compiler/constant_folding.h"
#include "internal/casts.h"
#include "internal/noop_delete.h"
#include "internal/status_macros.h"
#include "runtime/internal/runtime_friend_access.h"
#include "runtime/internal/runtime_impl.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/message.h"

namespace cel::extensions {
namespace {

using ::cel::internal::down_cast;
using ::cel::runtime_internal::RuntimeFriendAccess;
using ::cel::runtime_internal::RuntimeImpl;

// Returns the default RuntimeImpl wrapped by `builder`, or kUnimplemented if
// the builder holds a different cel::Runtime implementation.
absl::StatusOr RuntimeImplFromBuilder(
    RuntimeBuilder& builder ABSL_ATTRIBUTE_LIFETIME_BOUND) {
  Runtime& runtime = RuntimeFriendAccess::GetMutableRuntime(builder);
  if (RuntimeFriendAccess::RuntimeTypeId(runtime) != TypeId()) {
    return absl::UnimplementedError(
        "constant folding only supported on the default cel::Runtime "
        "implementation.");
  }
  return down_cast(&runtime);
}

// Shared implementation for every public overload.
//
// Null `arena` / `message_factory` mean "not provided". Non-null ones are
// handed to the runtime environment's KeepAlive() and then moved into the
// constant folding optimizer registered on the expression builder.
absl::Status EnableConstantFoldingImpl(
    RuntimeBuilder& builder,
    absl_nullable std::shared_ptr arena,
    absl_nullable std::shared_ptr message_factory) {
  CEL_ASSIGN_OR_RETURN(RuntimeImpl* absl_nonnull runtime_impl,
                       RuntimeImplFromBuilder(builder));
  if (arena != nullptr) {
    runtime_impl->environment().KeepAlive(arena);
  }
  if (message_factory != nullptr) {
    runtime_impl->environment().KeepAlive(message_factory);
  }
  runtime_impl->expr_builder().AddProgramOptimizer(
      runtime_internal::CreateConstantFoldingOptimizer(
          std::move(arena), std::move(message_factory)));
  return absl::OkStatus();
}

}  // namespace

absl::Status EnableConstantFolding(RuntimeBuilder& builder) {
  return EnableConstantFoldingImpl(builder, nullptr, nullptr);
}

// Raw-pointer overloads wrap the caller's object in a shared_ptr with a no-op
// deleter: ownership stays with the caller, which must keep it alive.
absl::Status EnableConstantFolding(RuntimeBuilder& builder,
                                   google::protobuf::Arena* absl_nonnull arena) {
  ABSL_DCHECK(arena != nullptr);
  return EnableConstantFoldingImpl(
      builder,
      std::shared_ptr(arena,
                                     internal::NoopDeleteFor()),
      nullptr);
}

absl::Status EnableConstantFolding(
    RuntimeBuilder& builder,
    absl_nonnull std::shared_ptr arena) {
  ABSL_DCHECK(arena != nullptr);
  return EnableConstantFoldingImpl(builder, std::move(arena), nullptr);
}

absl::Status EnableConstantFolding(
    RuntimeBuilder& builder,
    google::protobuf::MessageFactory* absl_nonnull message_factory) {
  ABSL_DCHECK(message_factory != nullptr);
  return EnableConstantFoldingImpl(
      builder, nullptr,
      std::shared_ptr(
          message_factory, internal::NoopDeleteFor()));
}

absl::Status EnableConstantFolding(
    RuntimeBuilder& builder,
    absl_nonnull std::shared_ptr message_factory) {
  ABSL_DCHECK(message_factory != nullptr);
  return EnableConstantFoldingImpl(builder, nullptr,
                                   std::move(message_factory));
}

absl::Status EnableConstantFolding(
    RuntimeBuilder& builder, google::protobuf::Arena* absl_nonnull arena,
    google::protobuf::MessageFactory* absl_nonnull message_factory) {
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  return EnableConstantFoldingImpl(
      builder,
      std::shared_ptr(arena,
                                     internal::NoopDeleteFor()),
      std::shared_ptr(
          message_factory, internal::NoopDeleteFor()));
}

absl::Status EnableConstantFolding(
    RuntimeBuilder& builder, google::protobuf::Arena* absl_nonnull arena,
    absl_nonnull std::shared_ptr message_factory) {
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  return EnableConstantFoldingImpl(
      builder,
      std::shared_ptr(arena,
                                     internal::NoopDeleteFor()),
      std::move(message_factory));
}

absl::Status EnableConstantFolding(
    RuntimeBuilder& builder,
    absl_nonnull std::shared_ptr arena,
    google::protobuf::MessageFactory* absl_nonnull message_factory) {
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  return EnableConstantFoldingImpl(
      builder, std::move(arena),
      std::shared_ptr(
          message_factory, internal::NoopDeleteFor()));
}

absl::Status EnableConstantFolding(
    RuntimeBuilder& builder,
    absl_nonnull std::shared_ptr arena,
    absl_nonnull std::shared_ptr message_factory) {
  ABSL_DCHECK(arena != nullptr);
  ABSL_DCHECK(message_factory != nullptr);
  return EnableConstantFoldingImpl(builder, std::move(arena),
                                   std::move(message_factory));
}

}  // namespace cel::extensions
================================================
FILE: runtime/constant_folding.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_CONSTANT_FOLDING_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_CONSTANT_FOLDING_H_

#include 

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "runtime/runtime_builder.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/message.h"

namespace cel::extensions {

// Enable constant folding in the runtime being built.
//
// Constant folding eagerly evaluates sub-expressions with all constant inputs
// at plan time to simplify the resulting program. User functions are executed
// if they are eagerly bound.
//
// If provided, the `google::protobuf::Arena` must outlive the resulting runtime
// and any program it creates. Otherwise the runtime will create one as needed
// during planning for each program, unless one is explicitly provided during
// planning.
//
// If provided, the `google::protobuf::MessageFactory` must outlive the resulting
// runtime and any program it creates. Otherwise the runtime will create one as
// needed and use it for all planning and the resulting programs created from
// the runtime, unless one is explicitly provided during planning or
// evaluation.
//
// The lifetime requirement above is the caller's responsibility for the
// raw-pointer overloads, which do not take ownership. The `std::shared_ptr`
// overloads share ownership: the pointer is passed to the runtime
// environment's KeepAlive() as well as to the optimizer.
//
// Pointer arguments must be non-null (checked with ABSL_DCHECK in debug
// builds). Returns an error with code kUnimplemented if `builder` does not hold
// the default cel::Runtime implementation.
absl::Status EnableConstantFolding(RuntimeBuilder& builder);
absl::Status EnableConstantFolding(RuntimeBuilder& builder,
                                   google::protobuf::Arena* absl_nonnull arena);
absl::Status EnableConstantFolding(
    RuntimeBuilder& builder,
    absl_nonnull std::shared_ptr arena);
absl::Status EnableConstantFolding(
    RuntimeBuilder& builder,
    google::protobuf::MessageFactory* absl_nonnull message_factory);
absl::Status EnableConstantFolding(
    RuntimeBuilder& builder,
    absl_nonnull std::shared_ptr message_factory);
absl::Status EnableConstantFolding(
    RuntimeBuilder& builder, google::protobuf::Arena* absl_nonnull arena,
    google::protobuf::MessageFactory* absl_nonnull message_factory);
absl::Status EnableConstantFolding(
    RuntimeBuilder& builder, google::protobuf::Arena* absl_nonnull arena,
    absl_nonnull std::shared_ptr message_factory);
absl::Status EnableConstantFolding(
    RuntimeBuilder& builder,
    absl_nonnull std::shared_ptr arena,
    google::protobuf::MessageFactory* absl_nonnull message_factory);
absl::Status EnableConstantFolding(
    RuntimeBuilder& builder,
    absl_nonnull std::shared_ptr arena,
    absl_nonnull std::shared_ptr message_factory);

}  // namespace cel::extensions

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_CONSTANT_FOLDING_H_
================================================
FILE: runtime/constant_folding_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/constant_folding.h" #include #include #include #include #include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "base/function_adapter.h" #include "common/function_descriptor.h" #include "common/value.h" #include "extensions/protobuf/runtime_adapter.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "parser/parser.h" #include "runtime/activation.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "google/protobuf/arena.h" namespace cel::extensions { namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::testing::HasSubstr; using ValueMatcher = testing::Matcher; struct TestCase { std::string name; std::string expression; ValueMatcher result_matcher; absl::Status status; }; MATCHER_P(IsIntValue, expected, "") { const Value& value = arg; return value->Is() && value.GetInt().NativeValue() == expected; } MATCHER_P(IsBoolValue, expected, "") { const Value& value = arg; return value->Is() && value.GetBool().NativeValue() == expected; } MATCHER_P(IsErrorValue, expected_substr, "") { const Value& value = arg; return value->Is() && absl::StrContains(value.GetError().NativeValue().message(), expected_substr); } class ConstantFoldingExtTest : public testing::TestWithParam {}; 
// Registers a custom `prepend` function, enables constant folding, plans the
// case's expression and checks either the evaluated value or the expected
// error status.
TEST_P(ConstantFoldingExtTest, Runner) {
  google::protobuf::Arena arena;
  RuntimeOptions options;
  const TestCase& test_case = GetParam();
  ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder,
                       CreateStandardRuntimeBuilder(
                           internal::GetTestingDescriptorPool(), options));

  auto status =
      BinaryFunctionAdapter, const StringValue&, const StringValue&>::
          RegisterGlobalOverload(
              "prepend",
              [](const StringValue& value, const StringValue& prefix) {
                return StringValue(
                    absl::StrCat(prefix.ToString(), value.ToString()));
              },
              builder.function_registry());
  ASSERT_THAT(status, IsOk());

  ASSERT_THAT(EnableConstantFolding(builder), IsOk());

  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(test_case.expression));

  ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram(
                                         *runtime, parsed_expr));

  Activation activation;

  auto result = program->Evaluate(&arena, activation);
  if (test_case.status.ok()) {
    ASSERT_OK_AND_ASSIGN(Value value, std::move(result));

    EXPECT_THAT(value, test_case.result_matcher);
    return;
  }

  EXPECT_THAT(result.status(), StatusIs(test_case.status.code(),
                                        HasSubstr(test_case.status.message())));
}

INSTANTIATE_TEST_SUITE_P(
    Cases, ConstantFoldingExtTest,
    testing::ValuesIn(std::vector{
        {"sum", "1 + 2 + 3", IsIntValue(6)},
        {"list_create", "[1, 2, 3, 4].filter(x, x < 4).size()",
         IsIntValue(3)},
        {"string_concat", "('12' + '34' + '56' + '78' + '90').size()",
         IsIntValue(10)},
        {"comprehension", "[1, 2, 3, 4].exists(x, x in [4, 5, 6, 7])",
         IsBoolValue(true)},
        {"nested_comprehension",
         "[1, 2, 3, 4].exists(x, [1, 2, 3, 4].all(y, y <= x))",
         IsBoolValue(true)},
        // Comparing a string with an int has no overload; the error must
        // surface as an error value rather than a planning failure.
        {"runtime_error", "[1, 2, 3, 4].exists(x, ['4'].all(y, y <= x))",
         IsErrorValue("No matching overloads")},
        {"map_create", "{'abc': 'def', 'abd': 'deg'}.size()", IsIntValue(2)},
        {"custom_function", "prepend('def', 'abc') == 'abcdef'",
         IsBoolValue(true)}}),

    [](const testing::TestParamInfo& info) {
      return info.param.name;
    });

// A lazily bound function is only supplied through the Activation at
// evaluation time. It must not be invoked while planning, and it must be
// invoked once per evaluation.
TEST(ConstantFoldingExtTest, LazyFunctionNotFolded) {
  google::protobuf::Arena arena;
  RuntimeOptions options;
  ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder,
                       CreateStandardRuntimeBuilder(
                           internal::GetTestingDescriptorPool(), options));

  int call_count = 0;

  using FunctionAdapter =
      BinaryFunctionAdapter, const StringValue&, const StringValue&>;

  auto fn = FunctionAdapter::WrapFunction(
      [&call_count](const StringValue& value, const StringValue& prefix) {
        call_count++;
        return StringValue(absl::StrCat(prefix.ToString(), value.ToString()));
      });
  FunctionDescriptor descriptor = FunctionAdapter::CreateDescriptor(
      "lazy_prepend", /*receiver_style=*/false);

  ASSERT_THAT(builder.function_registry().RegisterLazyFunction(descriptor),
              IsOk());

  ASSERT_THAT(EnableConstantFolding(builder), IsOk());

  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,
                       Parse("lazy_prepend('def', 'abc') == 'abcdef'"));

  ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram(
                                         *runtime, parsed_expr));
  // Planning must not have called the function.
  EXPECT_EQ(call_count, 0);

  Activation activation;
  activation.InsertFunction(descriptor, std::move(fn));
  ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));
  EXPECT_EQ(call_count, 1);
  EXPECT_THAT(result, IsBoolValue(true));

  ASSERT_OK_AND_ASSIGN(result, program->Evaluate(&arena, activation));
  EXPECT_EQ(call_count, 2);
  EXPECT_THAT(result, IsBoolValue(true));
}

// A function registered as contextual must not be folded at plan time; it is
// invoked on every evaluation instead.
TEST(ConstantFoldingExtTest, ContextualFunctionNotFolded) {
  google::protobuf::Arena arena;
  RuntimeOptions options;
  ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder,
                       CreateStandardRuntimeBuilder(
                           internal::GetTestingDescriptorPool(), options));

  int call_count = 0;
  auto status = BinaryFunctionAdapter<
      absl::StatusOr, const StringValue&,
      const StringValue&>::Register("contextual_prepend",
                                    /*receiver_style=*/false,
                                    [&call_count](const StringValue& value,
                                                  const StringValue& prefix) {
                                      call_count++;
                                      return StringValue(absl::StrCat(
                                          prefix.ToString(), value.ToString()));
                                    },
                                    builder.function_registry(),
                                    {/*is_strict=*/true, /*is_contextual=*/true});
  ASSERT_THAT(status, IsOk());

  ASSERT_THAT(EnableConstantFolding(builder), IsOk());

  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,
                       Parse("contextual_prepend('def', 'abc') == 'abcdef'"));

  ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram(
                                         *runtime, parsed_expr));
  // Planning must not have called the function.
  EXPECT_EQ(call_count, 0);

  Activation activation;
  ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation));
  EXPECT_EQ(call_count, 1);
  EXPECT_THAT(value, IsBoolValue(true));

  ASSERT_OK_AND_ASSIGN(value, program->Evaluate(&arena, activation));
  EXPECT_EQ(call_count, 2);
  EXPECT_THAT(value, IsBoolValue(true));
}

}  // namespace
}  // namespace cel::extensions

// ================================================
// FILE: runtime/embedder_context.h
// ================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_EMBEDDER_CONTEXT_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_EMBEDDER_CONTEXT_H_

// NOTE(review): the two include targets below are empty; they look like text
// lost when this dump was extracted. Restore from upstream before compiling.
#include
#include

#include "absl/container/inlined_vector.h"
#include "absl/log/absl_check.h"
#include "absl/types/optional.h"
#include "common/typeinfo.h"
#include "common/value.h"

namespace cel {

// EmbedderContext is used to package custom content defined by the embedder
// during CEL evaluation. The custom content is indexed by type.
// Value types are returned as absl::optional<T>, where T is the value type.
// Pointer types are returned as T*.
//
// The content values must be trivially copyable and have a size <= 16 bytes.
// These are typically pointers or small value types (e.g. primitives, enums).
//
// An all-zero memory value is used to represent an empty value. The caller
// must provide some way to disambiguate if that is a meaningfully distinct
// value from nullopt / nullptr.
//
// Scope is used to distinguish between multiple usages of CEL in the same
// binary.
//
// NOTE(review): template parameter lists and template arguments in this class
// were lost when this dump was extracted. Restore from upstream before
// compiling.
class EmbedderContext {
 public:
  // Creates a context for the given scope, storing each argument in the slot
  // reserved for its type.
  template
  static EmbedderContext From(Args... args);

  // Convenience using a default scope.
  template
  static EmbedderContext From(Args... args) {
    return From(args...);
  }

  // Value-type lookup: yields an empty optional when the value is absent.
  template
  std::enable_if_t, absl::optional> Get() const;

  // Pointer-type lookup: yields nullptr when the value is absent.
  template
  std::enable_if_t, T> Get() const;

  // Default-scope counterparts of the two lookups above.
  template
  std::enable_if_t, absl::optional> Get() const {
    return Get();
  }

  template
  std::enable_if_t, T> Get() const {
    return Get();
  }

 private:
  // Stores the first argument, then recurses on the rest.
  template
  void Set(T arg, Ts... args);

  // Recursion terminator for Set().
  template
  void Set() {}

  absl::InlinedVector values_;
  // These are included to check for bad accesses in debug mode.
  absl::InlinedVector type_ids_;
  TypeInfo scope_;
};

// Stores `arg` in the slot that TypeIdInSet assigns to its decayed type.
// When that slot is past the current end, `values_` grows (zero-filled) and
// `type_ids_` grows with it. The type id is recorded for the debug checks in
// Get(). Then recurses on the remaining arguments.
template
void EmbedderContext::Set(Arg arg, Args... args) {
  using IndexType = std::decay_t;
  size_t index = TypeIdInSet::template IndexFor();
  if (index >= values_.size()) {
    values_.resize(index + 1, cel::CustomValueContent::Zero());
    type_ids_.resize(index + 1);
  }
  values_[index] = cel::CustomValueContent::From(arg);
  type_ids_[index] = cel::TypeId();
  Set(args...);
}

// Returns nullopt when the type has no slot in this context or its stored
// content is all zero. Debug builds check the requested scope and type id
// against what was recorded.
template
std::enable_if_t, absl::optional>
EmbedderContext::Get() const {
  ABSL_DCHECK_EQ(cel::TypeId(), scope_) << "EmbedderContext::Get wrong scope";
  using IndexType = std::decay_t;
  size_t index = TypeIdInSet::template IndexFor();
  if (index >= values_.size()) {
    return absl::nullopt;
  }
  const auto& content = values_[index];
  if (content.IsZero()) return absl::nullopt;
  ABSL_DCHECK_EQ(type_ids_.size(), values_.size());
  ABSL_DCHECK_EQ(type_ids_[index], cel::TypeId())
      << "EmbedderContext::Get wrong type id";
  return content.To();
}

// Pointer-type variant: returns nullptr when the type has no slot in this
// context or its stored content is all zero. Debug checks match the overload
// above.
template
std::enable_if_t, T> EmbedderContext::Get() const {
  ABSL_DCHECK_EQ(cel::TypeId(), scope_) << "EmbedderContext::Get wrong scope";
  using IndexType = std::decay_t;
  size_t index = TypeIdInSet::template IndexFor();
  if (index >= values_.size()) {
    return nullptr;
  }
  const auto& content = values_[index];
  if (content.IsZero()) return nullptr;
  ABSL_DCHECK_EQ(type_ids_.size(), values_.size());
  ABSL_DCHECK_EQ(type_ids_[index], cel::TypeId())
      << "EmbedderContext::Get wrong type id";
  return content.To();
}

// Records the scope's type id, then stores every argument.
template
EmbedderContext EmbedderContext::From(Args... args) {
  EmbedderContext context;
  context.scope_ = TypeId();
  context.Set(args...);
  return context;
}

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_EMBEDDER_CONTEXT_H_

// ================================================
// FILE: runtime/embedder_context_test.cc
// ================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/embedder_context.h"

// NOTE(review): the include target below, and the template arguments of the
// From/Get/TypeIdInSet calls in this file, were lost when this dump was
// extracted. Restore from upstream before compiling.
#include

#include "absl/types/optional.h"
#include "common/typeinfo.h"
#include "internal/testing.h"

namespace cel {
namespace {

using ::testing::Optional;

// Stores a single value per context and reads it back; an absent type yields
// nullopt.
TEST(EmbedderContextTest, From) {
  struct TestScope {};
  EmbedderContext context = EmbedderContext::From(int64_t{42});

  EXPECT_THAT((context.Get()), Optional(42));
  EXPECT_EQ((context.Get()), absl::nullopt);

  EmbedderContext context2 = EmbedderContext::From(uint64_t{42});

  EXPECT_THAT((context2.Get()), Optional(42));
  EXPECT_EQ((context2.Get()), absl::nullopt);

  // Side effect, but checking that we keep a dense range.
  EXPECT_EQ(cel::TypeIdInSet::Size(), 2);
}

// Stores several values of different types in one context.
TEST(EmbedderContextTest, FromOutOfLine) {
  struct TestScope {};
  EmbedderContext context =
      EmbedderContext::From(int64_t{42}, uint64_t{43}, double{44});

  EXPECT_THAT((context.Get()), Optional(42));
  EXPECT_THAT((context.Get()), Optional(43));
  EXPECT_THAT((context.Get()), Optional(44));
  EXPECT_EQ((context.Get()), absl::nullopt);

  // Note: Referencing a type not intended to be stored will still reserve a
  // slot in the TypeIdInSet.
  EXPECT_EQ(cel::TypeIdInSet::Size(), 4);
}

// Stores pointer types; an absent pointer type yields nullptr.
TEST(EmbedderContextTest, FromPtrs) {
  struct TestScope {};
  struct TestPointee {
  } foo;
  int64_t pointee2;

  EmbedderContext context = EmbedderContext::From(
      &foo, const_cast(&pointee2));

  EXPECT_EQ((context.Get()), &pointee2);
  EXPECT_EQ((context.Get()), &foo);

  EmbedderContext context2 = EmbedderContext::From(&foo);

  EXPECT_EQ((context2.Get()), nullptr);
  EXPECT_EQ((context2.Get()), &foo);

  // Note: const int64_t* is a distinct type from int64_t*, so each gets its
  // own slot.
  EXPECT_EQ(cel::TypeIdInSet::Size(), 3);
}

// The default-scope From/Get overloads work together.
TEST(EmbedderContextTest, FromDefaultScope) {
  EmbedderContext context = EmbedderContext::From(int64_t{42});

  EXPECT_THAT((context.Get()), Optional(42));
  EXPECT_EQ((context.Get()), absl::nullopt);
}

// These death assertions are only enabled when compiled in debug mode.
// Caller is responsible for adequately testing since we're limited in what
// we can statically check due to the type-erasure.
TEST(EmbedderContextDeathTest, GetWithWrongScope) {
  struct TestScope {};
  EmbedderContext context = EmbedderContext::From(int64_t{42});

  EXPECT_DEBUG_DEATH(
      { context.Get(); }, "EmbedderContext::Get wrong scope");
}

}  // namespace
}  // namespace cel

// ================================================
// FILE: runtime/function.h
// ================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the guard name refers to a COMMON_ path although this file is
// runtime/function.h; presumably kept from an earlier location. Confirm it
// cannot collide with another header's guard.
#ifndef THIRD_PARTY_CEL_CPP_COMMON_FUNCTION_H_
#define THIRD_PARTY_CEL_CPP_COMMON_FUNCTION_H_

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "common/value.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel {

class EmbedderContext;

// Interface for extension functions.
//
// The host for the CEL environment may provide implementations to define custom
// extension functions.
// // The runtime expects functions to be deterministic and side-effect free. class Function { public: virtual ~Function() = default; // Context for the function invocation. // // Collects evaluation state that may be needed for the function to operate. // // The function implementation should not retain a reference to the context // object beyond the duration of the function call or modify the InvokeContext // itself. class InvokeContext { public: InvokeContext( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, const EmbedderContext* absl_nullable embedder_context = nullptr) : descriptor_pool_(descriptor_pool), message_factory_(message_factory), arena_(arena), embedder_context_(embedder_context) {} const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const { return descriptor_pool_; } google::protobuf::MessageFactory* absl_nonnull message_factory() const { return message_factory_; } google::protobuf::Arena* absl_nonnull arena() const { return arena_; } const EmbedderContext* absl_nullable embedder_context() const { return embedder_context_; } void set_embedder_context( const EmbedderContext* absl_nullable embedder_context) { embedder_context_ = embedder_context; } private: const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_; google::protobuf::MessageFactory* absl_nonnull message_factory_; google::protobuf::Arena* absl_nonnull arena_; const EmbedderContext* absl_nullable embedder_context_; }; ABSL_DEPRECATED("Use the InvokeContext overload instead.") inline absl::StatusOr Invoke( absl::Span args, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; // Attempt to evaluate an extension function based on the runtime arguments // during the evaluation of a CEL expression. 
// // A non-ok status is interpreted as an unrecoverable error in evaluation ( // e.g. data corruption). This stops evaluation and is propagated immediately. // // A cel::ErrorValue typed result is considered a recoverable error and // follows CEL's logical short-circuiting behavior. virtual absl::StatusOr Invoke(absl::Span args, const InvokeContext& context) const = 0; }; absl::StatusOr Function::Invoke( absl::Span args, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const { InvokeContext context(descriptor_pool, message_factory, arena); return Invoke(args, context); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_FUNCTION_H_ ================================================ FILE: runtime/function_adapter.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Definitions for template helpers to wrap C++ functions as CEL extension // function implementations. 
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_

// NOTE(review): the seven include targets below, and the template parameter
// lists / specialization arguments in this section, were lost when this dump
// was extracted. Restore from upstream before compiling.
#include
#include
#include
#include
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "common/function_descriptor.h"
#include "common/value.h"
#include "internal/status_macros.h"
#include "runtime/function.h"
#include "runtime/internal/function_adapter.h"
#include "runtime/register_function_helper.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel {
namespace runtime_internal {

// Describes how a parameter of an adapted C++ function is handled.
// `AssignableType` is the type the converted CEL argument is stored in, and
// `ToArg` turns that storage into what is passed to the wrapped function. By
// default the argument is stored and passed by value.
template
struct AdaptedTypeTraits {
  using AssignableType = T;

  static T ToArg(AssignableType v) { return v; }
};

// Specialization for cref parameters without forcing a temporary copy of the
// underlying handle argument. The argument is stored as a pointer and passed
// on as a reference_wrapper to the pointee.
template <>
struct AdaptedTypeTraits {
  using AssignableType = const Value*;

  static std::reference_wrapper ToArg(AssignableType v) { return *v; }
};

template <>
struct AdaptedTypeTraits {
  using AssignableType = const StringValue*;

  static std::reference_wrapper ToArg(AssignableType v) {
    return *v;
  }
};

template <>
struct AdaptedTypeTraits {
  using AssignableType = const BytesValue*;

  static std::reference_wrapper ToArg(AssignableType v) {
    return *v;
  }
};

// Partial specialization for other cases.
//
// These types aren't referenceable since they aren't actually
// represented as alternatives in the underlying variant.
//
// This still requires an implicit copy and corresponding ref-count increase.
// NOTE(review): the specialized pattern was lost when this dump was extracted;
// presumably `const T&`. Confirm against upstream.
template
struct AdaptedTypeTraits {
  using AssignableType = T;

  static T ToArg(AssignableType v) { return v; }
};

// Converts `input[I]` into the I-th element of the `output` tuple via
// ValueToAdaptedVisitor, then recurses on the next index until the last
// argument. Returns the first conversion error.
// NOTE(review): `input.size()` is not checked here; presumably callers
// validate it before calling.
template
struct AdaptHelperImpl {
  template
  static absl::Status Apply(absl::Span input, T& output) {
    static_assert(sizeof...(Args) > 0);
    static_assert(std::tuple_size_v == sizeof...(Args));
    CEL_RETURN_IF_ERROR(ValueToAdaptedVisitor{input[I]}(&std::get(output)));
    if constexpr (I == sizeof...(Args) - 1) {
      return absl::OkStatus();
    } else {
      CEL_RETURN_IF_ERROR(
          (AdaptHelperImpl::template Apply(input, output)));
    }
    return absl::OkStatus();
  }
};

// Entry point for AdaptHelperImpl, starting the conversion at index 0.
template
struct AdaptHelper {
  template
  static absl::Status Apply(absl::Span input, T& output) {
    return AdaptHelperImpl<0, Args...>::template Apply(input, output);
  }
};

// Pairs each argument type with its index (El). This lets the elements of the
// argument buffer be expanded through AdaptedTypeTraits::ToArg and passed to
// `op`, followed by the invoke context.
template
struct ToArgsImpl {
  template
  struct El {
    using type = T;
    constexpr static size_t index = I;
  };

  template
  struct ZipHolder {
    template
    static ResultType ToArgs(Op&& op, const TupleType& argbuffer,
                             const Function::InvokeContext& context) {
      return std::forward(op)(
          runtime_internal::AdaptedTypeTraits::ToArg(
              std::get(argbuffer))...,
          context);
    }
  };

  template
  static ZipHolder...> MakeZip(const std::index_sequence&) {
    return ZipHolder...>{};
  }
};

// Invokes `op` with the converted contents of `argbuffer` followed by
// `context`, and returns its result.
template
struct ToArgsHelper {
  template
  static ResultType Apply(Op&& op, const TupleType& argbuffer,
                          const Function::InvokeContext& context) {
    using Impl = ToArgsImpl;
    using Zip = decltype(Impl::MakeZip(std::index_sequence_for{}));
    return Zip::template ToArgs(std::forward(op), argbuffer,
                                            context);
  }
};

}  // namespace runtime_internal

// Adapter class for generating CEL extension functions from a zero argument
// (nullary) function.
//
// See documentation for Binary Function adapter for general recommendations.
//
// Example Usage:
//   int64_t Answer() { return 42; }
//
//   {
//     RuntimeBuilder builder;
//     // Initialize the runtime builder with built-ins as needed.
//
//     CEL_RETURN_IF_ERROR(
//         builder.function_registry().Register(
//             NullaryFunctionAdapter<int64_t>::CreateDescriptor(
//                 "answer", /*receiver_style=*/false),
//             NullaryFunctionAdapter<int64_t>::WrapFunction(&Answer)));
//   }
//
//   // example CEL expression
//   answer() == 42 [true]
//
// NOTE(review): template parameter lists and template arguments in the
// adapters below were lost when this dump was extracted. Restore from upstream
// before compiling.
template
class NullaryFunctionAdapter
    : public RegisterHelper> {
 public:
  // Adapted callable signature: receives only the per-call InvokeContext.
  using FunctionType = absl::AnyInvocable;

  static std::unique_ptr WrapFunction(FunctionType fn) {
    return std::make_unique(std::move(fn));
  }

  // Wraps a callable that takes the descriptor pool, message factory and
  // arena from the InvokeContext.
  template
  static std::enable_if_t<
      std::is_invocable_v,
      std::unique_ptr>
  WrapFunction(F&& function) {
    return WrapFunction([function = std::forward(function)](
                            const Function::InvokeContext& context) -> T {
      return function(context.descriptor_pool(), context.message_factory(),
                      context.arena());
    });
  }

  // Wraps a callable that takes no arguments at all.
  template
  static std::enable_if_t, std::unique_ptr>
  WrapFunction(F&& function) {
    return WrapFunction([function = std::forward(function)](
                            const Function::InvokeContext& context) -> T {
      return function();
    });
  }

  static FunctionDescriptor CreateDescriptor(absl::string_view name,
                                             bool receiver_style,
                                             bool is_strict) {
    return CreateDescriptor(name, receiver_style,
                            {is_strict, /*is_contextual=*/false});
  }

  // The descriptor declares an empty argument-kind list.
  static FunctionDescriptor CreateDescriptor(
      absl::string_view name, bool receiver_style,
      FunctionDescriptorOptions options = {}) {
    return FunctionDescriptor(name, receiver_style, {}, options);
  }

 private:
  // NOTE: despite its name, this implementation handles zero-argument calls.
  class UnaryFunctionImpl : public Function {
   public:
    explicit UnaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {}
    absl::StatusOr Invoke(
        absl::Span args,
        const Function::InvokeContext& context) const final {
      if (args.size() != 0) {
        return absl::InvalidArgumentError(
            "unexpected number of arguments for nullary function");
      }

      // Results whose type passes the is_same_v test are returned directly;
      // anything else is converted with AdaptedToValueVisitor. (The operands
      // were lost in extraction; presumably Value / absl::StatusOr<Value>.)
      if constexpr (std::is_same_v ||
                    std::is_same_v>) {
        return fn_(context);
      } else {
        T result = fn_(context);

        return runtime_internal::AdaptedToValueVisitor{}(std::move(result));
      }
    }

   private:
    FunctionType fn_;
  };
};

// Adapter class for generating CEL extension functions from a one argument
// function.
//
// See documentation for Binary Function adapter for general recommendations.
//
// Example Usage:
//   double Invert(double x) {
//     return 1 / x;
//   }
//
//   {
//     RuntimeBuilder builder;
//     // Initialize the runtime builder with built-ins as needed.
//
//     CEL_RETURN_IF_ERROR(
//         builder.function_registry().Register(
//             UnaryFunctionAdapter<double, double>::CreateDescriptor(
//                 "inv", /*receiver_style=*/false),
//             UnaryFunctionAdapter<double, double>::WrapFunction(&Invert)));
//   }
//
//   // example CEL expression (CEL does not implicitly convert int to double,
//   // so a double literal is required to match this overload)
//   inv(4.0) == 0.25 [true]
template
class UnaryFunctionAdapter : public RegisterHelper> {
 public:
  // Adapted callable signature: the argument followed by the InvokeContext.
  using FunctionType = absl::AnyInvocable;

  static std::unique_ptr WrapFunction(FunctionType fn) {
    return std::make_unique(std::move(fn));
  }

  // Wraps a callable taking the argument plus the descriptor pool, message
  // factory and arena from the InvokeContext.
  template
  static std::enable_if_t<
      std::is_invocable_v,
      std::unique_ptr>
  WrapFunction(F&& function) {
    return WrapFunction(
        [function = std::forward(function)](
            U arg1, const Function::InvokeContext& context) -> T {
          return function(arg1, context.descriptor_pool(),
                          context.message_factory(), context.arena());
        });
  }

  // Wraps a callable taking only the argument.
  template
  static std::enable_if_t, std::unique_ptr>
  WrapFunction(F&& function) {
    return WrapFunction(
        [function = std::forward(function)](
            U arg1, const Function::InvokeContext& context) -> T {
          return function(arg1);
        });
  }

  static FunctionDescriptor CreateDescriptor(absl::string_view name,
                                             bool receiver_style,
                                             bool is_strict) {
    return CreateDescriptor(
        name, receiver_style,
        FunctionDescriptorOptions{is_strict, /*is_contextual=*/false});
  }

  static FunctionDescriptor CreateDescriptor(
      absl::string_view name, bool receiver_style,
      FunctionDescriptorOptions options = {}) {
    return FunctionDescriptor(name, receiver_style,
                              {runtime_internal::AdaptedKind()}, options);
  }

 private:
  class UnaryFunctionImpl : public Function {
   public:
    explicit UnaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {}
    absl::StatusOr Invoke(
        absl::Span args,
        const Function::InvokeContext& context) const final {
      using ArgTraits = runtime_internal::AdaptedTypeTraits;
      if (args.size() != 1) {
        return absl::InvalidArgumentError(
            "unexpected number of arguments for unary function");
      }
      typename ArgTraits::AssignableType arg1;

      // A failed conversion is returned as a non-ok status.
      CEL_RETURN_IF_ERROR(
          runtime_internal::ValueToAdaptedVisitor{args[0]}(&arg1));
      if constexpr (std::is_same_v ||
                    std::is_same_v>) {
        return fn_(ArgTraits::ToArg(arg1), context);
      } else {
        T result = fn_(ArgTraits::ToArg(arg1), context);

        return runtime_internal::AdaptedToValueVisitor{}(std::move(result));
      }
    }

   private:
    FunctionType fn_;
  };
};

// Adapter class for generating CEL extension functions from a two argument
// function. Generates an implementation of the cel::Function interface that
// calls the function to wrap.
//
// Extension functions must distinguish between recoverable errors (errors that
// should participate in CEL's error pruning) and unrecoverable errors (a non-ok
// absl::Status that stops evaluation). The function to wrap may return
// absl::StatusOr<T> to propagate a Status, or return a Value with an Error
// value to introduce a CEL error.
//
// To introduce an extension function that may accept any kind of CEL value as
// an argument, the wrapped function should use a Value parameter and
// check the type of the argument at evaluation time.
//
// Supported CEL to C++ type mappings:
//   bool -> bool
//   double -> double
//   uint -> uint64_t
//   int -> int64_t
//   timestamp -> absl::Time
//   duration -> absl::Duration
//
// Complex types may be referred to by cref or value.
// To return these, users should return a Value.
//   any/dyn -> Value, const Value&
//   string -> StringValue | const StringValue&
//   bytes -> BytesValue | const BytesValue&
//   list -> ListValue | const ListValue&
//   map -> MapValue | const MapValue&
//   struct -> StructValue | const StructValue&
//   null -> NullValue | const NullValue&
//
// To intercept error and unknown arguments, users must use a non-strict
// overload with all arguments typed as any and check the kind of the
// Value argument.
//
// Example Usage:
//   double SquareDifference(double x, double y) {
//     return x * x - y * y;
//   }
//
//   {
//     RuntimeBuilder builder;
//     // Initialize Expression builder with built-ins as needed.
//
//     CEL_RETURN_IF_ERROR(
//         builder.function_registry().Register(
//             BinaryFunctionAdapter<double, double, double>::CreateDescriptor(
//                 "sq_diff", /*receiver_style=*/false),
//             BinaryFunctionAdapter<double, double, double>::WrapFunction(
//                 &SquareDifference)));
//
//
//     // Alternative shorthand for the registration above.
//     // See RegisterHelper (template base class) for details.
//     // runtime/register_function_helper.h
//     auto status = BinaryFunctionAdapter<double, double, double>::
//         RegisterGlobalOverload(
//             "sq_diff",
//             &SquareDifference,
//             builder.function_registry());
//     CEL_RETURN_IF_ERROR(status);
//   }
//
//   example CEL expression (CEL does not implicitly convert int to double,
//   so double literals are required to match this overload):
//   sq_diff(4.0, 3.0) == 7.0 [true]
//
// NOTE(review): template parameter lists and template arguments in the
// adapters below were lost when this dump was extracted. Restore from upstream
// before compiling.
template
class BinaryFunctionAdapter
    : public RegisterHelper> {
 public:
  // Adapted callable signature: both arguments followed by the InvokeContext.
  using FunctionType = absl::AnyInvocable;

  static std::unique_ptr WrapFunction(FunctionType fn) {
    return std::make_unique(std::move(fn));
  }

  // Wraps a callable taking both arguments plus the descriptor pool, message
  // factory and arena from the InvokeContext.
  template
  static std::enable_if_t<
      std::is_invocable_v,
      std::unique_ptr>
  WrapFunction(F&& function) {
    return WrapFunction(
        [function = std::forward(function)](
            U arg1, V arg2, const Function::InvokeContext& context) -> T {
          return function(arg1, arg2, context.descriptor_pool(),
                          context.message_factory(), context.arena());
        });
  }

  // Wraps a callable taking only the two arguments.
  template
  static std::enable_if_t, std::unique_ptr>
  WrapFunction(F&& function) {
    return WrapFunction(
        [function = std::forward(function)](
            U arg1, V arg2, const Function::InvokeContext& context) -> T {
          return function(arg1, arg2);
        });
  }

  static FunctionDescriptor CreateDescriptor(absl::string_view name,
                                             bool receiver_style,
                                             bool is_strict) {
    return CreateDescriptor(name, receiver_style,
                            {is_strict, /*is_contextual=*/false});
  }

  static FunctionDescriptor CreateDescriptor(
      absl::string_view name, bool receiver_style,
      FunctionDescriptorOptions options = {}) {
    return FunctionDescriptor(name, receiver_style,
                              {runtime_internal::AdaptedKind(),
                               runtime_internal::AdaptedKind()},
                              options);
  }

 private:
  class BinaryFunctionImpl : public Function {
   public:
    explicit BinaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {}
    absl::StatusOr Invoke(
        absl::Span args,
        const Function::InvokeContext& context) const final {
      using Arg1Traits = runtime_internal::AdaptedTypeTraits;
      using Arg2Traits = runtime_internal::AdaptedTypeTraits;
      if (args.size() != 2) {
        return absl::InvalidArgumentError(
            "unexpected number of arguments for binary function");
      }
      typename Arg1Traits::AssignableType arg1;
      typename Arg2Traits::AssignableType arg2;
      // A failed conversion is returned as a non-ok status.
      CEL_RETURN_IF_ERROR(
          runtime_internal::ValueToAdaptedVisitor{args[0]}(&arg1));
      CEL_RETURN_IF_ERROR(
          runtime_internal::ValueToAdaptedVisitor{args[1]}(&arg2));

      if constexpr (std::is_same_v ||
                    std::is_same_v>) {
        return fn_(Arg1Traits::ToArg(arg1), Arg2Traits::ToArg(arg2), context);
      } else {
        T result =
            fn_(Arg1Traits::ToArg(arg1), Arg2Traits::ToArg(arg2), context);

        return runtime_internal::AdaptedToValueVisitor{}(std::move(result));
      }
    }

   private:
    BinaryFunctionAdapter::FunctionType fn_;
  };
};

// Adapter for three-argument functions. Follows the same pattern as
// BinaryFunctionAdapter; see its documentation for usage.
template
class TernaryFunctionAdapter
    : public RegisterHelper> {
 public:
  using FunctionType = absl::AnyInvocable;

  static std::unique_ptr WrapFunction(FunctionType fn) {
    return std::make_unique(std::move(fn));
  }

  // Wraps a callable taking the three arguments plus the descriptor pool,
  // message factory and arena from the InvokeContext.
  template
  static std::enable_if_t<
      std::is_invocable_v<
          F, U, V, W, const google::protobuf::DescriptorPool* absl_nonnull,
          google::protobuf::MessageFactory* absl_nonnull,
          google::protobuf::Arena* absl_nonnull>,
      std::unique_ptr>
  WrapFunction(F&& function) {
    return WrapFunction([function = std::forward(function)](
                            U arg1, V arg2, W arg3,
                            const Function::InvokeContext& context) -> T {
      return function(arg1, arg2, arg3, context.descriptor_pool(),
                      context.message_factory(), context.arena());
    });
  }

  // Wraps a callable taking only the three arguments.
  template
  static std::enable_if_t, std::unique_ptr>
  WrapFunction(F&& function) {
    return WrapFunction([function = std::forward(function)](
                            U arg1, V arg2, W arg3,
                            const Function::InvokeContext& context) -> T {
      return function(arg1, arg2, arg3);
    });
  }

  static FunctionDescriptor CreateDescriptor(absl::string_view name,
                                             bool receiver_style,
                                             bool is_strict) {
    return CreateDescriptor(
        name, receiver_style,
        FunctionDescriptorOptions{is_strict, /*is_contextual=*/false});
  }

  static FunctionDescriptor CreateDescriptor(
      absl::string_view name, bool receiver_style,
      FunctionDescriptorOptions options = {}) {
    return FunctionDescriptor(
        name, receiver_style,
        {runtime_internal::AdaptedKind(), runtime_internal::AdaptedKind(),
         runtime_internal::AdaptedKind()},
        options);
  }

 private:
  class TernaryFunctionImpl : public Function {
   public:
    explicit TernaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {}
    absl::StatusOr Invoke(
        absl::Span args,
        const Function::InvokeContext& context) const final {
      using Arg1Traits = runtime_internal::AdaptedTypeTraits;
      using Arg2Traits = runtime_internal::AdaptedTypeTraits;
      using Arg3Traits = runtime_internal::AdaptedTypeTraits;
      if (args.size() != 3) {
        return absl::InvalidArgumentError(
            "unexpected number of arguments for ternary function");
      }
      typename Arg1Traits::AssignableType arg1;
      typename Arg2Traits::AssignableType arg2;
      typename Arg3Traits::AssignableType arg3;
      CEL_RETURN_IF_ERROR(
          runtime_internal::ValueToAdaptedVisitor{args[0]}(&arg1));
      CEL_RETURN_IF_ERROR(
          runtime_internal::ValueToAdaptedVisitor{args[1]}(&arg2));
      CEL_RETURN_IF_ERROR(
          runtime_internal::ValueToAdaptedVisitor{args[2]}(&arg3));

      if constexpr (std::is_same_v ||
                    std::is_same_v>) {
        return fn_(Arg1Traits::ToArg(arg1), Arg2Traits::ToArg(arg2),
                   Arg3Traits::ToArg(arg3), context);
      } else {
        T result = fn_(Arg1Traits::ToArg(arg1), Arg2Traits::ToArg(arg2),
                       Arg3Traits::ToArg(arg3), context);

        return runtime_internal::AdaptedToValueVisitor{}(std::move(result));
      }
    }

   private:
    TernaryFunctionAdapter::FunctionType fn_;
  };
};

// Adapter for four-argument functions. Follows the same pattern as
// BinaryFunctionAdapter; see its documentation for usage.
template
class QuaternaryFunctionAdapter
    : public RegisterHelper> {
 public:
  using FunctionType = absl::AnyInvocable;

  static std::unique_ptr WrapFunction(FunctionType fn) {
    return std::make_unique(std::move(fn));
  }

  // Wraps a callable taking the four arguments plus the descriptor pool,
  // message factory and arena from the InvokeContext.
  template
  static std::enable_if_t<
      std::is_invocable_v<
          F, U, V, W, X, const google::protobuf::DescriptorPool* absl_nonnull,
          google::protobuf::MessageFactory* absl_nonnull,
          google::protobuf::Arena* absl_nonnull>,
      std::unique_ptr>
  WrapFunction(F&& function) {
    return WrapFunction([function = std::forward(function)](
                            U arg1, V arg2, W arg3, X arg4,
                            const Function::InvokeContext& context) -> T {
      return function(arg1, arg2, arg3, arg4, context.descriptor_pool(),
                      context.message_factory(), context.arena());
    });
  }

  // Wraps a callable taking only the four arguments.
  template
  static std::enable_if_t, std::unique_ptr>
  WrapFunction(F&& function) {
    return WrapFunction([function = std::forward(function)](
                            U arg1, V arg2, W arg3, X arg4,
                            const Function::InvokeContext& context) -> T {
      return function(arg1, arg2, arg3, arg4);
    });
  }

  static FunctionDescriptor CreateDescriptor(absl::string_view name,
                                             bool receiver_style,
                                             bool is_strict) {
    return CreateDescriptor(name, receiver_style,
                            {is_strict, /*is_contextual=*/false});
  }

  static FunctionDescriptor CreateDescriptor(
      absl::string_view name, bool receiver_style,
      FunctionDescriptorOptions options = {}) {
    return FunctionDescriptor(
        name, receiver_style,
        {runtime_internal::AdaptedKind(), runtime_internal::AdaptedKind(),
         runtime_internal::AdaptedKind(), runtime_internal::AdaptedKind()},
        options);
  }

 private:
  class QuaternaryFunctionImpl : public Function {
   public:
    explicit QuaternaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {}
    absl::StatusOr Invoke(
        absl::Span args,
        const Function::InvokeContext& context) const final {
      using Arg1Traits = runtime_internal::AdaptedTypeTraits;
      using Arg2Traits = runtime_internal::AdaptedTypeTraits;
      using Arg3Traits = runtime_internal::AdaptedTypeTraits;
      using Arg4Traits = runtime_internal::AdaptedTypeTraits;
      if (args.size() != 4) {
        return absl::InvalidArgumentError(
            "unexpected number of arguments for quaternary function");
      }
      typename Arg1Traits::AssignableType arg1;
      typename Arg2Traits::AssignableType arg2;
      typename Arg3Traits::AssignableType arg3;
      typename Arg4Traits::AssignableType arg4;
      CEL_RETURN_IF_ERROR(
          runtime_internal::ValueToAdaptedVisitor{args[0]}(&arg1));
      CEL_RETURN_IF_ERROR(
          runtime_internal::ValueToAdaptedVisitor{args[1]}(&arg2));
      CEL_RETURN_IF_ERROR(
          runtime_internal::ValueToAdaptedVisitor{args[2]}(&arg3));
      CEL_RETURN_IF_ERROR(
          runtime_internal::ValueToAdaptedVisitor{args[3]}(&arg4));

      if constexpr (std::is_same_v ||
                    std::is_same_v>) {
        return fn_(Arg1Traits::ToArg(arg1), Arg2Traits::ToArg(arg2),
                   Arg3Traits::ToArg(arg3), Arg4Traits::ToArg(arg4), context);
      } else {
        T result = fn_(Arg1Traits::ToArg(arg1), Arg2Traits::ToArg(arg2),
                       Arg3Traits::ToArg(arg3), Arg4Traits::ToArg(arg4),
                       context);

        return runtime_internal::AdaptedToValueVisitor{}(std::move(result));
      }
    }

   private:
    QuaternaryFunctionAdapter::FunctionType fn_;
  };
};

// Primary template for n-ary adapter.
template
class NaryFunctionAdapter;

// Specializations for zero through four arguments, each inheriting from the
// matching fixed-arity adapter above.
template
class NaryFunctionAdapter : public NullaryFunctionAdapter {};

template
class NaryFunctionAdapter : public UnaryFunctionAdapter {};

template
class NaryFunctionAdapter : public BinaryFunctionAdapter {};

template
class NaryFunctionAdapter
    : public TernaryFunctionAdapter {};

template
class NaryFunctionAdapter
    : public QuaternaryFunctionAdapter {};

// N-ary function adapter.
//
// Prefer using one of the specific count adapters above for readability and
// better error messages.
template class NaryFunctionAdapter
    : public RegisterHelper> {
 public:
  // Type of the stored callable; NaryFunctionImpl invokes it through
  // runtime_internal::ToArgsHelper together with the InvokeContext.
  // NOTE(review): template arguments appear stripped in this excerpt.
  using FunctionType = absl::AnyInvocable;

  // Convenience overload: same as passing
  // {is_strict, /*is_contextual=*/false} as the descriptor options.
  static FunctionDescriptor CreateDescriptor(absl::string_view name,
                                             bool receiver_style,
                                             bool is_strict) {
    return CreateDescriptor(name, receiver_style,
                            {is_strict, /*is_contextual=*/false});
  }

  // Builds a descriptor with one runtime_internal::AdaptedKind() per entry in
  // the Args... pack.
  static FunctionDescriptor CreateDescriptor(
      absl::string_view name, bool receiver_style,
      FunctionDescriptorOptions options = {}) {
    return FunctionDescriptor(name, receiver_style,
                              {runtime_internal::AdaptedKind()...}, options);
  }

  // Wraps `fn` as a heap-allocated Function (presumably a NaryFunctionImpl;
  // the make_unique template argument is not visible here).
  static std::unique_ptr WrapFunction(FunctionType fn) {
    return std::make_unique(std::move(fn));
  }

  // Overload for callables taking Args... followed by a descriptor pool,
  // message factory and arena, which are read from the InvokeContext on each
  // call.
  template
  static std::enable_if_t<
      std::is_invocable_v<
          F, Args..., const google::protobuf::DescriptorPool* absl_nonnull,
          google::protobuf::MessageFactory* absl_nonnull,
          google::protobuf::Arena* absl_nonnull>,
      std::unique_ptr>
  WrapFunction(F&& function) {
    return WrapFunction(
        [function = std::forward(function)](
            Args... args, const Function::InvokeContext& context) -> T {
          return function(args..., context.descriptor_pool(),
                          context.message_factory(), context.arena());
        });
  }

  // Overload for callables taking only Args...; the InvokeContext is not
  // used. NOTE(review): enable_if_t condition appears stripped here.
  template
  static std::enable_if_t, std::unique_ptr>
  WrapFunction(F&& function) {
    return WrapFunction(
        [function = std::forward(function)](
            Args... args, const Function::InvokeContext& context) -> T {
          return function(args...);
        });
  }

 private:
  // Function implementation for an arbitrary argument count.
  class NaryFunctionImpl : public Function {
   private:
    // Storage for the converted arguments: one AssignableType per Args entry.
    using ArgBuffer = std::tuple<
        typename runtime_internal::AdaptedTypeTraits::AssignableType...>;

   public:
    explicit NaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {}

    // Returns InvalidArgumentError (e.g. "...for 5-ary function") when the
    // argument count differs from sizeof...(Args). Otherwise fills
    // `arg_buffer` via AdaptHelper (errors propagated by CEL_RETURN_IF_ERROR)
    // and calls fn_ through ToArgsHelper; non-Value results are converted
    // back with AdaptedToValueVisitor.
    absl::StatusOr Invoke(
        absl::Span args,
        const Function::InvokeContext& context) const final {
      if (args.size() != sizeof...(Args)) {
        return absl::InvalidArgumentError(
            absl::StrCat("unexpected number of arguments for ", sizeof...(Args),
                         "-ary function"));
      }
      ArgBuffer arg_buffer;
      CEL_RETURN_IF_ERROR(
          runtime_internal::AdaptHelper::Apply(args, arg_buffer));
      if constexpr (std::is_same_v || std::is_same_v>) {
        return runtime_internal::ToArgsHelper::template Apply(
            fn_, arg_buffer, context);
      } else {
        T result = runtime_internal::ToArgsHelper::template Apply(
            fn_, arg_buffer, context);
        return runtime_internal::AdaptedToValueVisitor{}(std::move(result));
      }
    }

   private:
    FunctionType fn_;
  };
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_
================================================
FILE: runtime/function_adapter_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/function_adapter.h"

// NOTE(review): the five system include targets below (and many template
// arguments in this file, e.g. `result->Is()`, `std::vector args`,
// `UnaryFunctionAdapter;`) appear to have been stripped during extraction.
#include
#include
#include
#include
#include

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "common/function_descriptor.h"
#include "common/kind.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "internal/testing.h"
#include "runtime/function.h"

namespace cel {
namespace {

using ::absl_testing::StatusIs;
using ::testing::ElementsAre;
using ::testing::HasSubstr;
using ::testing::IsEmpty;

// Fixture supplying descriptor_pool(), message_factory() and arena() from
// common_internal::ValueTest<>, plus a Function::InvokeContext built from
// them so tests can call the InvokeContext overload of Function::Invoke.
class FunctionAdapterTest : public common_internal::ValueTest<> {
  using Base = common_internal::ValueTest<>;

 public:
  FunctionAdapterTest()
      : Base(),
        test_context_(descriptor_pool(), message_factory(), arena()) {}

  const Function::InvokeContext& test_invoke_context() const {
    return test_context_;
  }

 protected:
  cel::Function::InvokeContext test_context_;
};

// ---------------------------------------------------------------------------
// UnaryFunctionAdapter::WrapFunction
// ---------------------------------------------------------------------------

// The callable takes the InvokeContext explicitly; it is invoked through both
// the (descriptor_pool, message_factory, arena) overload and the
// InvokeContext overload, and must give the same result each time.
TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionOldOverload) {
  using FunctionAdapter = UnaryFunctionAdapter;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](const StringValue& x,
         const Function::InvokeContext& context) -> StringValue {
        std::string buf;
        // `buf` is handed to ToStringView as scratch storage, so `s` may
        // refer into it; StrCat builds a fresh string before the assignment.
        absl::string_view s = x.ToStringView(&buf);
        buf = absl::StrCat("pre_", s);
        return StringValue::From(std::move(buf), context.arena());
      });
  std::vector args{StringValue::Wrap(absl::string_view("foo"), arena())};
  ASSERT_OK_AND_ASSIGN(
      auto result,
      wrapped->Invoke(args, descriptor_pool(), message_factory(), arena()));
  EXPECT_THAT(result, test::StringValueIs("pre_foo"));
  ASSERT_OK_AND_ASSIGN(result, wrapped->Invoke(args, test_invoke_context()));
  EXPECT_THAT(result, test::StringValueIs("pre_foo"));
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionInt) {
  using FunctionAdapter = UnaryFunctionAdapter;
  std::unique_ptr wrapped =
      FunctionAdapter::WrapFunction([](int64_t x) -> int64_t { return x + 2; });

  std::vector args{IntValue(40)};
  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetInt().NativeValue(), 42);
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionDouble) {
  using FunctionAdapter = UnaryFunctionAdapter;
  std::unique_ptr wrapped =
      FunctionAdapter::WrapFunction([](double x) -> double { return x * 2; });

  std::vector args{DoubleValue(40.0)};
  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetDouble().NativeValue(), 80.0);
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionUint) {
  using FunctionAdapter = UnaryFunctionAdapter;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](uint64_t x) -> uint64_t { return x - 2; });

  std::vector args{UintValue(44)};
  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetUint().NativeValue(), 42);
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionBool) {
  using FunctionAdapter = UnaryFunctionAdapter;
  std::unique_ptr wrapped =
      FunctionAdapter::WrapFunction([](bool x) -> bool { return !x; });

  std::vector args{BoolValue(true)};
  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetBool().NativeValue(), false);
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionTimestamp) {
  using FunctionAdapter = UnaryFunctionAdapter;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](absl::Time x) -> absl::Time { return x + absl::Minutes(1); });

  std::vector args;
  args.emplace_back() = TimestampValue(absl::UnixEpoch());
  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetTimestamp().NativeValue(),
            absl::UnixEpoch() + absl::Minutes(1));
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionDuration) {
  using FunctionAdapter = UnaryFunctionAdapter;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](absl::Duration x) -> absl::Duration { return x + absl::Seconds(2); });

  std::vector args;
  args.emplace_back() = DurationValue(absl::Seconds(6));
  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetDuration().NativeValue(), absl::Seconds(8));
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionString) {
  using FunctionAdapter = UnaryFunctionAdapter;
  std::unique_ptr wrapped =
      FunctionAdapter::WrapFunction([](const StringValue& x) -> StringValue {
        return StringValue("pre_" + x.ToString());
      });

  std::vector args;
  args.emplace_back() = StringValue("string");
  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetString().ToString(), "pre_string");
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionBytes) {
  using FunctionAdapter = UnaryFunctionAdapter;
  std::unique_ptr wrapped =
      FunctionAdapter::WrapFunction([](const BytesValue& x) -> BytesValue {
        return BytesValue("pre_" + x.ToString());
      });

  std::vector args;
  args.emplace_back() = BytesValue("bytes");
  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetBytes().ToString(), "pre_bytes");
}

// A `const Value&` parameter accepts the argument without kind conversion.
TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionAny) {
  using FunctionAdapter = UnaryFunctionAdapter;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](const Value& x) -> uint64_t { return x.GetUint().NativeValue() - 2; });

  std::vector args{UintValue(44)};
  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetUint().NativeValue(), 42);
}

// An ErrorValue returned by the callable is an ordinary (OK) result of
// Invoke, unlike a non-OK status (see the PropagateStatus test below).
TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionReturnError) {
  using FunctionAdapter = UnaryFunctionAdapter;
  std::unique_ptr wrapped =
      FunctionAdapter::WrapFunction([](uint64_t x) -> Value {
        return ErrorValue(absl::InvalidArgumentError("test_error"));
      });

  std::vector args{UintValue(44)};
  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_THAT(result.GetError().NativeValue(),
              StatusIs(absl::StatusCode::kInvalidArgument, "test_error"));
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionPropagateStatus) {
  using FunctionAdapter = UnaryFunctionAdapter, uint64_t>;
  std::unique_ptr wrapped =
      FunctionAdapter::WrapFunction([](uint64_t x) -> absl::StatusOr {
        // Returning a status directly stops CEL evaluation and
        // immediately returns.
        return absl::InternalError("test_error");
      });

  std::vector args{UintValue(44)};
  EXPECT_THAT(wrapped->Invoke(args, test_invoke_context()),
              StatusIs(absl::StatusCode::kInternal, "test_error"));
}

TEST_F(FunctionAdapterTest,
       UnaryFunctionAdapterWrapFunctionReturnStatusOrValue) {
  using FunctionAdapter = UnaryFunctionAdapter, uint64_t>;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](uint64_t x) -> absl::StatusOr { return x; });

  std::vector args{UintValue(44)};
  ASSERT_OK_AND_ASSIGN(Value result,
                       wrapped->Invoke(args, test_invoke_context()));
  EXPECT_EQ(result.GetUint().NativeValue(), 44);
}

TEST_F(FunctionAdapterTest,
       UnaryFunctionAdapterWrapFunctionWrongArgCountError) {
  using FunctionAdapter = UnaryFunctionAdapter, uint64_t>;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](uint64_t x) -> absl::StatusOr { return 42; });

  // Two arguments passed to a one-argument function.
  std::vector args{UintValue(44), UintValue(43)};
  EXPECT_THAT(wrapped->Invoke(args, test_invoke_context()),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "unexpected number of arguments for unary function"));
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterWrapFunctionWrongArgTypeError) {
  using FunctionAdapter = UnaryFunctionAdapter, uint64_t>;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](uint64_t x) -> absl::StatusOr { return 42; });

  // A double argument where a uint is expected.
  std::vector args{DoubleValue(44)};
  EXPECT_THAT(wrapped->Invoke(args, test_invoke_context()),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("expected uint value")));
}

// ---------------------------------------------------------------------------
// UnaryFunctionAdapter::CreateDescriptor
// ---------------------------------------------------------------------------

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorInt) {
  FunctionDescriptor desc =
      UnaryFunctionAdapter, int64_t>::CreateDescriptor(
          "Increment", false);

  EXPECT_EQ(desc.name(), "Increment");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kInt64));
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorDouble) {
  FunctionDescriptor desc =
      UnaryFunctionAdapter, double>::CreateDescriptor(
          "Mult2", true);

  EXPECT_EQ(desc.name(), "Mult2");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_TRUE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kDouble));
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorUint) {
  FunctionDescriptor desc =
      UnaryFunctionAdapter, uint64_t>::CreateDescriptor(
          "Increment", false);

  EXPECT_EQ(desc.name(), "Increment");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kUint64));
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorBool) {
  FunctionDescriptor desc =
      UnaryFunctionAdapter, bool>::CreateDescriptor(
          "Not", false);

  EXPECT_EQ(desc.name(), "Not");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kBool));
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorTimestamp) {
  FunctionDescriptor desc =
      UnaryFunctionAdapter, absl::Time>::CreateDescriptor(
          "AddMinute", false);

  EXPECT_EQ(desc.name(), "AddMinute");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kTimestamp));
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorDuration) {
  FunctionDescriptor desc =
      UnaryFunctionAdapter,
                           absl::Duration>::CreateDescriptor("AddFiveSeconds",
                                                             false);

  EXPECT_EQ(desc.name(), "AddFiveSeconds");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kDuration));
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorString) {
  FunctionDescriptor desc =
      UnaryFunctionAdapter,
                           StringValue>::CreateDescriptor("Prepend", false);

  EXPECT_EQ(desc.name(), "Prepend");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kString));
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorBytes) {
  FunctionDescriptor desc =
      UnaryFunctionAdapter, BytesValue>::CreateDescriptor(
          "Prepend", false);

  EXPECT_EQ(desc.name(), "Prepend");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kBytes));
}

TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorAny) {
  FunctionDescriptor desc =
      UnaryFunctionAdapter, Value>::CreateDescriptor(
          "Increment", false);

  EXPECT_EQ(desc.name(), "Increment");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kAny));
}

// The three-argument CreateDescriptor overload sets strictness explicitly.
TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorNonStrict) {
  FunctionDescriptor desc =
      UnaryFunctionAdapter, Value>::CreateDescriptor(
          "Increment", false,
          /*is_strict=*/false);

  EXPECT_EQ(desc.name(), "Increment");
  EXPECT_FALSE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kAny));
}

// ---------------------------------------------------------------------------
// BinaryFunctionAdapter::WrapFunction
// ---------------------------------------------------------------------------

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionInt) {
  using FunctionAdapter = BinaryFunctionAdapter;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](int64_t x, int64_t y) -> int64_t { return x + y; });

  std::vector args{IntValue(21), IntValue(21)};
  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetInt().NativeValue(), 42);
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionDouble) {
  using FunctionAdapter = BinaryFunctionAdapter;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](double x, double y) -> double { return x * y; });

  std::vector args{DoubleValue(40.0), DoubleValue(2.0)};
  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetDouble().NativeValue(), 80.0);
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionUint) {
  using FunctionAdapter = BinaryFunctionAdapter;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](uint64_t x, uint64_t y) -> uint64_t { return x - y; });

  std::vector args{UintValue(44), UintValue(2)};
  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetUint().NativeValue(), 42);
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionBool) {
  using FunctionAdapter = BinaryFunctionAdapter;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](bool x, bool y) -> bool { return x != y; });

  std::vector args{BoolValue(false), BoolValue(true)};
  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetBool().NativeValue(), true);
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionTimestamp) {
  using FunctionAdapter = BinaryFunctionAdapter;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](absl::Time x, absl::Time y) -> absl::Time { return x > y ? x : y; });

  std::vector args;
  args.emplace_back() = TimestampValue(absl::UnixEpoch() + absl::Seconds(1));
  args.emplace_back() = TimestampValue(absl::UnixEpoch() + absl::Seconds(2));

  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetTimestamp().NativeValue(),
            absl::UnixEpoch() + absl::Seconds(2));
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionDuration) {
  using FunctionAdapter = BinaryFunctionAdapter;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](absl::Duration x, absl::Duration y) -> absl::Duration {
        return x > y ? x : y;
      });

  std::vector args;
  args.emplace_back() = DurationValue(absl::Seconds(5));
  args.emplace_back() = DurationValue(absl::Seconds(2));

  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetDuration().NativeValue(), absl::Seconds(5));
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionString) {
  using FunctionAdapter =
      BinaryFunctionAdapter, const StringValue&,
                            const StringValue&>;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](const StringValue& x,
         const StringValue& y) -> absl::StatusOr {
        return StringValue(x.ToString() + y.ToString());
      });

  std::vector args;
  args.emplace_back() = StringValue("abc");
  args.emplace_back() = StringValue("def");

  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetString().ToString(), "abcdef");
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionBytes) {
  using FunctionAdapter =
      BinaryFunctionAdapter, const BytesValue&,
                            const BytesValue&>;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](const BytesValue& x,
         const BytesValue& y) -> absl::StatusOr {
        return BytesValue(x.ToString() + y.ToString());
      });

  std::vector args;
  args.emplace_back() = BytesValue("abc");
  args.emplace_back() = BytesValue("def");

  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetBytes().ToString(), "abcdef");
}

// Both parameters are `const Value&`, so arguments of different kinds (uint
// and double) are accepted as-is.
TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionAny) {
  using FunctionAdapter = BinaryFunctionAdapter;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](const Value& x, const Value& y) -> uint64_t {
        return x.GetUint().NativeValue() -
               static_cast(y.GetDouble().NativeValue());
      });

  std::vector args{UintValue(44), DoubleValue(2)};
  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetUint().NativeValue(), 42);
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionReturnError) {
  using FunctionAdapter = BinaryFunctionAdapter;
  std::unique_ptr wrapped =
      FunctionAdapter::WrapFunction([](int64_t x, uint64_t y) -> Value {
        return ErrorValue(absl::InvalidArgumentError("test_error"));
      });

  std::vector args{IntValue(44), UintValue(44)};
  ASSERT_OK_AND_ASSIGN(auto result,
                       wrapped->Invoke(args, test_invoke_context()));

  ASSERT_TRUE(result->Is());
  EXPECT_THAT(result.GetError().NativeValue(),
              StatusIs(absl::StatusCode::kInvalidArgument, "test_error"));
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionPropagateStatus) {
  using FunctionAdapter =
      BinaryFunctionAdapter, int64_t, uint64_t>;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](int64_t, uint64_t x) -> absl::StatusOr {
        // Returning a status directly stops CEL evaluation and
        // immediately returns.
        return absl::InternalError("test_error");
      });

  std::vector args{IntValue(43), UintValue(44)};
  EXPECT_THAT(wrapped->Invoke(args, test_invoke_context()),
              StatusIs(absl::StatusCode::kInternal, "test_error"));
}

TEST_F(FunctionAdapterTest,
       BinaryFunctionAdapterWrapFunctionWrongArgCountError) {
  using FunctionAdapter =
      BinaryFunctionAdapter, uint64_t, double>;
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](uint64_t x, double y) -> absl::StatusOr { return 42; });

  // One argument passed to a two-argument function.
  std::vector args{UintValue(44)};
  EXPECT_THAT(wrapped->Invoke(args, test_invoke_context()),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "unexpected number of arguments for binary function"));
}

TEST_F(FunctionAdapterTest,
       BinaryFunctionAdapterWrapFunctionWrongArgTypeError) {
  using FunctionAdapter =
      BinaryFunctionAdapter, uint64_t, uint64_t>;
  // NOTE(review): the lambda takes int64_t while the adapter declares
  // uint64_t parameters; the values convert implicitly, and the lambda is
  // never reached here because argument conversion fails first.
  std::unique_ptr wrapped = FunctionAdapter::WrapFunction(
      [](int64_t x, int64_t y) -> absl::StatusOr { return 42; });

  std::vector args{DoubleValue(44), DoubleValue(44)};
  EXPECT_THAT(wrapped->Invoke(args, test_invoke_context()),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("expected uint value")));
}

// ---------------------------------------------------------------------------
// BinaryFunctionAdapter::CreateDescriptor
// ---------------------------------------------------------------------------

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorInt) {
  FunctionDescriptor desc =
      BinaryFunctionAdapter, int64_t,
                            int64_t>::CreateDescriptor("Add", false);

  EXPECT_EQ(desc.name(), "Add");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kInt64, Kind::kInt64));
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorDouble) {
  FunctionDescriptor desc =
      BinaryFunctionAdapter, double,
                            double>::CreateDescriptor("Mult", true);

  EXPECT_EQ(desc.name(), "Mult");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_TRUE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kDouble, Kind::kDouble));
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorUint) {
  FunctionDescriptor desc =
      BinaryFunctionAdapter, uint64_t,
                            uint64_t>::CreateDescriptor("Add", false);

  EXPECT_EQ(desc.name(), "Add");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kUint64, Kind::kUint64));
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorBool) {
  FunctionDescriptor desc =
      BinaryFunctionAdapter, bool,
                            bool>::CreateDescriptor("Xor", false);

  EXPECT_EQ(desc.name(), "Xor");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kBool, Kind::kBool));
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorTimestamp) {
  FunctionDescriptor desc =
      BinaryFunctionAdapter, absl::Time,
                            absl::Time>::CreateDescriptor("Max", false);

  EXPECT_EQ(desc.name(), "Max");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kTimestamp, Kind::kTimestamp));
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorDuration) {
  FunctionDescriptor desc =
      BinaryFunctionAdapter, absl::Duration,
                            absl::Duration>::CreateDescriptor("Max", false);

  EXPECT_EQ(desc.name(), "Max");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kDuration, Kind::kDuration));
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorString) {
  FunctionDescriptor desc =
      BinaryFunctionAdapter, StringValue,
                            StringValue>::CreateDescriptor("Concat", false);

  EXPECT_EQ(desc.name(), "Concat");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kString, Kind::kString));
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorBytes) {
  FunctionDescriptor desc =
      BinaryFunctionAdapter, BytesValue,
                            BytesValue>::CreateDescriptor("Concat", false);

  EXPECT_EQ(desc.name(), "Concat");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kBytes, Kind::kBytes));
}

TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorAny) {
  FunctionDescriptor desc =
      BinaryFunctionAdapter, Value,
                            Value>::CreateDescriptor("Add", false);
  EXPECT_EQ(desc.name(), "Add");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kAny, Kind::kAny));
}

// Third CreateDescriptor argument is `is_strict`.
TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorNonStrict) {
  FunctionDescriptor desc =
      BinaryFunctionAdapter, Value,
                            Value>::CreateDescriptor("Add", false, false);
  EXPECT_EQ(desc.name(), "Add");
  EXPECT_FALSE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), ElementsAre(Kind::kAny, Kind::kAny));
}

// ---------------------------------------------------------------------------
// Nullary / n-ary adapters
// ---------------------------------------------------------------------------

TEST_F(FunctionAdapterTest, NaryFunctionAdapterCreateDescriptor0Args) {
  FunctionDescriptor desc =
      NullaryFunctionAdapter>::CreateDescriptor(
          "ZeroArgs", false);

  EXPECT_EQ(desc.name(), "ZeroArgs");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(), IsEmpty());
}

TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction0Args) {
  std::unique_ptr fn =
      NullaryFunctionAdapter>::WrapFunction(
          []() { return StringValue("abc"); });

  ASSERT_OK_AND_ASSIGN(auto result, fn->Invoke({}, descriptor_pool(),
                                               message_factory(), arena()));
  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetString().ToString(), "abc");
}

// NOTE(review): despite the test name, this exercises TernaryFunctionAdapter
// directly rather than NaryFunctionAdapter.
TEST_F(FunctionAdapterTest, NaryFunctionAdapterCreateDescriptor3Args) {
  FunctionDescriptor desc = TernaryFunctionAdapter<
      absl::StatusOr, int64_t, bool,
      const StringValue&>::CreateDescriptor("MyFormatter", false);

  EXPECT_EQ(desc.name(), "MyFormatter");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(),
              ElementsAre(Kind::kInt64, Kind::kBool, Kind::kString));
}

TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction3Args) {
  std::unique_ptr fn = NaryFunctionAdapter<
      absl::StatusOr, int64_t, bool,
      const StringValue&>::WrapFunction([](int64_t int_val, bool bool_val,
                                           const StringValue& string_val)
                                            -> absl::StatusOr {
    return StringValue(absl::StrCat(int_val, "_", (bool_val ? "true" : "false"),
                                    "_", string_val.ToString()));
  });

  std::vector args{IntValue(42), BoolValue(false)};
  args.emplace_back() = StringValue("abcd");
  ASSERT_OK_AND_ASSIGN(auto result, fn->Invoke(args, descriptor_pool(),
                                               message_factory(), arena()));
  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetString().ToString(), "42_false_abcd");
}

TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction3ArgsBadArgType) {
  std::unique_ptr fn = NaryFunctionAdapter<
      absl::StatusOr, int64_t, bool,
      const StringValue&>::WrapFunction([](int64_t int_val, bool bool_val,
                                           const StringValue& string_val)
                                            -> absl::StatusOr {
    return StringValue(absl::StrCat(int_val, "_", (bool_val ? "true" : "false"),
                                    "_", string_val.ToString()));
  });

  // Third argument is a timestamp where a string is expected.
  std::vector args{IntValue(42), BoolValue(false)};
  args.emplace_back() = TimestampValue(absl::UnixEpoch());
  EXPECT_THAT(fn->Invoke(args, test_invoke_context()),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("expected string value")));
}

TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction3ArgsBadArgCount) {
  std::unique_ptr fn = NaryFunctionAdapter<
      absl::StatusOr, int64_t, bool,
      const StringValue&>::WrapFunction([](int64_t int_val, bool bool_val,
                                           const StringValue& string_val)
                                            -> absl::StatusOr {
    return StringValue(absl::StrCat(int_val, "_", (bool_val ? "true" : "false"),
                                    "_", string_val.ToString()));
  });

  // Only two of the three arguments are supplied.
  std::vector args{IntValue(42), BoolValue(false)};
  EXPECT_THAT(fn->Invoke(args, test_invoke_context()),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("unexpected number of arguments")));
}

// Five arguments exceed the fixed-arity adapters, so the general
// NaryFunctionAdapter template is used.
TEST_F(FunctionAdapterTest, NaryFunctionAdapterCreateDescriptor5Args) {
  FunctionDescriptor desc =
      NaryFunctionAdapter, int64_t, bool,
                          const StringValue&, int64_t,
                          int64_t>::CreateDescriptor("MyFormatter", false);

  EXPECT_EQ(desc.name(), "MyFormatter");
  EXPECT_TRUE(desc.is_strict());
  EXPECT_FALSE(desc.receiver_style());
  EXPECT_THAT(desc.types(),
              ElementsAre(Kind::kInt64, Kind::kBool, Kind::kString,
                          Kind::kInt64, Kind::kInt64));
}

TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction5Args) {
  std::unique_ptr fn = NaryFunctionAdapter<
      absl::StatusOr, int64_t, bool, const StringValue&, int64_t,
      int64_t>::WrapFunction([](int64_t int_val, bool bool_val,
                                const StringValue& string_val,
                                int64_t extra_arg, int64_t extra_arg2)
                                 -> absl::StatusOr {
    return StringValue(absl::StrCat(int_val, "_", (bool_val ? "true" : "false"),
                                    "_", string_val.ToString(), "_", extra_arg,
                                    "_", extra_arg2));
  });

  std::vector args{IntValue(42), BoolValue(false)};
  args.emplace_back() = StringValue("abcd");
  args.push_back(IntValue(123));
  args.push_back(IntValue(456));
  ASSERT_OK_AND_ASSIGN(auto result, fn->Invoke(args, descriptor_pool(),
                                               message_factory(), arena()));
  ASSERT_TRUE(result->Is());
  EXPECT_EQ(result.GetString().ToString(), "42_false_abcd_123_456");
}

TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction5ArgsBadArgType) {
  std::unique_ptr fn = NaryFunctionAdapter<
      absl::StatusOr, int64_t, bool, const StringValue&, int64_t,
      int64_t>::WrapFunction([](int64_t int_val, bool bool_val,
                                const StringValue& string_val,
                                int64_t extra_arg, int64_t extra_arg2)
                                 -> absl::StatusOr {
    // Presumably `static_cast<void>` (template argument stripped) to mark the
    // extra arguments as intentionally unused.
    static_cast(extra_arg);
    static_cast(extra_arg2);
    return StringValue(absl::StrCat(int_val, "_", (bool_val ? "true" : "false"),
                                    "_", string_val.ToString()));
  });

  // Third argument is a timestamp where a string is expected.
  std::vector args{IntValue(42), BoolValue(false)};
  args.emplace_back() = TimestampValue(absl::UnixEpoch());
  args.push_back(IntValue(123));
  args.push_back(IntValue(456));
  EXPECT_THAT(fn->Invoke(args, test_invoke_context()),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("expected string value")));
}

TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction5ArgsBadArgCount) {
  std::unique_ptr fn = NaryFunctionAdapter<
      absl::StatusOr, int64_t, bool, const StringValue&, int64_t,
      int64_t>::WrapFunction([](int64_t int_val, bool bool_val,
                                const StringValue& string_val,
                                int64_t extra_arg, int64_t extra_arg2)
                                 -> absl::StatusOr {
    static_cast(extra_arg);
    static_cast(extra_arg2);
    return StringValue(absl::StrCat(int_val, "_", (bool_val ? "true" : "false"),
                                    "_", string_val.ToString()));
  });

  // Only two of the five arguments are supplied.
  std::vector args{IntValue(42), BoolValue(false)};
  EXPECT_THAT(fn->Invoke(args, test_invoke_context()),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("unexpected number of arguments")));
}

}  // namespace
}  // namespace cel
================================================
FILE: runtime/function_overload_reference.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_OVERLOAD_REFERENCE_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_OVERLOAD_REFERENCE_H_

#include "common/function_descriptor.h"
#include "runtime/function.h"

namespace cel {

// Represents a view to a single overload for a function.
//
// Clients must take care to not persist instances beyond the lifetime of the
// owning object.
//
// Both members are references, so instances can be copy-constructed but not
// copy-assigned or default-constructed.
struct FunctionOverloadReference {
  // Signature of the overload (name, receiver style, argument kinds).
  const FunctionDescriptor& descriptor;
  // Implementation to invoke for calls that match `descriptor`.
  const Function& implementation;
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_OVERLOAD_REFERENCE_H_
================================================
FILE: runtime/function_provider.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_PROVIDER_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_PROVIDER_H_

#include "absl/status/statusor.h"
#include "common/function_descriptor.h"
#include "runtime/activation_interface.h"
#include "runtime/function_overload_reference.h"

namespace cel::runtime_internal {

// Interface for providers of lazily bound functions.
//
// Lazily bound functions may have an implementation that is dependent on the
// evaluation context (as represented by the Activation).
class FunctionProvider {
 public:
  virtual ~FunctionProvider() = default;

  // Returns a reference to a function implementation based on the provided
  // Activation. Given the same activation, this should return the same Function
  // instance. The cel::FunctionOverloadReference is assumed to be stable for
  // the life of the Activation.
  //
  // An empty optional result is interpreted as no matching overload.
  //
  // `descriptor` is the call shape being resolved; `activation` supplies the
  // evaluation context. Implementations may also return a non-OK status when
  // resolution fails -- e.g. the activation-backed provider in
  // function_registry.cc returns InvalidArgument when more than one overload
  // matches.
  virtual absl::StatusOr> GetFunction(
      const FunctionDescriptor& descriptor,
      const ActivationInterface& activation) const = 0;
};

}  // namespace cel::runtime_internal

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_PROVIDER_H_
================================================
FILE: runtime/function_registry.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/function_registry.h"

#include
#include
#include
#include
#include

#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "common/function_descriptor.h"
#include "common/kind.h"
#include "runtime/activation_interface.h"
#include "runtime/function.h"
#include "runtime/function_overload_reference.h"
#include "runtime/function_provider.h"

namespace cel {
namespace {

// Impl for simple provider that looks up functions in an activation function
// registry.
class ActivationFunctionProviderImpl : public cel::runtime_internal::FunctionProvider { public: ActivationFunctionProviderImpl() = default; absl::StatusOr> GetFunction( const cel::FunctionDescriptor& descriptor, const cel::ActivationInterface& activation) const override { std::vector overloads = activation.FindFunctionOverloads(descriptor.name()); absl::optional matching_overload = absl::nullopt; for (const auto& overload : overloads) { if (overload.descriptor.ShapeMatches(descriptor)) { if (matching_overload.has_value()) { return absl::Status(absl::StatusCode::kInvalidArgument, "Couldn't resolve function."); } matching_overload.emplace(overload); } } return matching_overload; } }; // Create a CelFunctionProvider that just looks up the functions inserted in the // Activation. This is a convenience implementation for a simple, common // use-case. std::unique_ptr CreateActivationFunctionProvider() { return std::make_unique(); } } // namespace absl::Status FunctionRegistry::Register( const cel::FunctionDescriptor& descriptor, std::unique_ptr implementation) { if (DescriptorRegistered(descriptor)) { return absl::Status( absl::StatusCode::kAlreadyExists, "CelFunction with specified parameters already registered"); } if (!ValidateNonStrictOverload(descriptor)) { return absl::Status(absl::StatusCode::kAlreadyExists, "Only one overload is allowed for non-strict function"); } auto& overloads = functions_[descriptor.name()]; overloads.static_overloads.push_back( StaticFunctionEntry(descriptor, std::move(implementation))); return absl::OkStatus(); } absl::Status FunctionRegistry::RegisterLazyFunction( const cel::FunctionDescriptor& descriptor) { if (DescriptorRegistered(descriptor)) { return absl::Status( absl::StatusCode::kAlreadyExists, "CelFunction with specified parameters already registered"); } if (!ValidateNonStrictOverload(descriptor)) { return absl::Status(absl::StatusCode::kAlreadyExists, "Only one overload is allowed for non-strict function"); } auto& overloads = 
functions_[descriptor.name()]; overloads.lazy_overloads.push_back( LazyFunctionEntry(descriptor, CreateActivationFunctionProvider())); return absl::OkStatus(); } std::vector FunctionRegistry::FindStaticOverloads(absl::string_view name, bool receiver_style, absl::Span types) const { std::vector matched_funcs; auto overloads = functions_.find(name); if (overloads == functions_.end()) { return matched_funcs; } for (const auto& overload : overloads->second.static_overloads) { if (overload.descriptor->ShapeMatches(receiver_style, types)) { matched_funcs.push_back({*overload.descriptor, *overload.implementation}); } } return matched_funcs; } std::vector FunctionRegistry::FindStaticOverloadsByArity(absl::string_view name, bool receiver_style, size_t arity) const { std::vector matched_funcs; auto overloads = functions_.find(name); if (overloads == functions_.end()) { return matched_funcs; } for (const auto& overload : overloads->second.static_overloads) { if (overload.descriptor->receiver_style() == receiver_style && overload.descriptor->types().size() == arity) { matched_funcs.push_back({*overload.descriptor, *overload.implementation}); } } return matched_funcs; } std::vector FunctionRegistry::FindLazyOverloads( absl::string_view name, bool receiver_style, absl::Span types) const { std::vector matched_funcs; auto overloads = functions_.find(name); if (overloads == functions_.end()) { return matched_funcs; } for (const auto& entry : overloads->second.lazy_overloads) { if (entry.descriptor->ShapeMatches(receiver_style, types)) { matched_funcs.push_back({*entry.descriptor, *entry.function_provider}); } } return matched_funcs; } std::vector FunctionRegistry::FindLazyOverloadsByArity(absl::string_view name, bool receiver_style, size_t arity) const { std::vector matched_funcs; auto overloads = functions_.find(name); if (overloads == functions_.end()) { return matched_funcs; } for (const auto& entry : overloads->second.lazy_overloads) { if (entry.descriptor->receiver_style() == 
receiver_style && entry.descriptor->types().size() == arity) { matched_funcs.push_back({*entry.descriptor, *entry.function_provider}); } } return matched_funcs; } absl::node_hash_map> FunctionRegistry::ListFunctions() const { absl::node_hash_map> descriptor_map; for (const auto& entry : functions_) { std::vector descriptors; const RegistryEntry& function_entry = entry.second; descriptors.reserve(function_entry.static_overloads.size() + function_entry.lazy_overloads.size()); for (const auto& entry : function_entry.static_overloads) { descriptors.push_back(entry.descriptor.get()); } for (const auto& entry : function_entry.lazy_overloads) { descriptors.push_back(entry.descriptor.get()); } descriptor_map[entry.first] = std::move(descriptors); } return descriptor_map; } bool FunctionRegistry::DescriptorRegistered( const cel::FunctionDescriptor& descriptor) const { auto overloads = functions_.find(descriptor.name()); if (overloads == functions_.end()) { return false; } const RegistryEntry& entry = overloads->second; for (const auto& static_ovl : entry.static_overloads) { if (static_ovl.descriptor->ShapeMatches(descriptor)) { return true; } } for (const auto& lazy_ovl : entry.lazy_overloads) { if (lazy_ovl.descriptor->ShapeMatches(descriptor)) { return true; } } return false; } bool FunctionRegistry::ValidateNonStrictOverload( const cel::FunctionDescriptor& descriptor) const { auto overloads = functions_.find(descriptor.name()); if (overloads == functions_.end()) { return true; } const RegistryEntry& entry = overloads->second; if (!descriptor.is_strict()) { // If the newly added overload is a non-strict function, we require that // there are no other overloads, which is not possible here. return false; } // If the newly added overload is a strict function, we need to make sure // that no previous overloads are registered non-strict. If the list of // overload is not empty, we only need to check the first overload. 
This is // because if the first overload is strict, other overloads must also be // strict by the rule. return (entry.static_overloads.empty() || entry.static_overloads[0].descriptor->is_strict()) && (entry.lazy_overloads.empty() || entry.lazy_overloads[0].descriptor->is_strict()); } } // namespace cel ================================================ FILE: runtime/function_registry.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_REGISTRY_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_REGISTRY_H_ #include #include #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_map.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "common/function_descriptor.h" #include "common/kind.h" #include "runtime/function.h" #include "runtime/function_overload_reference.h" #include "runtime/function_provider.h" namespace cel { // FunctionRegistry manages binding builtin or custom CEL functions to // implementations. // // The registry is consulted during program planning to tie overload candidates // to the CEL function in the AST getting planned. // // The registry takes ownership of the cel::Function objects -- the registry // must outlive any program planned using it. // // This class is move-only. 
class FunctionRegistry {
 public:
  // Represents a single overload for a lazily provided function.
  struct LazyOverload {
    const cel::FunctionDescriptor& descriptor;
    const cel::runtime_internal::FunctionProvider& provider;
  };

  FunctionRegistry() = default;

  // Move-only
  // (declaring the move operations suppresses the implicit copy operations).
  FunctionRegistry(FunctionRegistry&&) = default;
  FunctionRegistry& operator=(FunctionRegistry&&) = default;

  // Register a function implementation for the given descriptor.
  // Function registration should be performed prior to CelExpression creation.
  //
  // Returns kAlreadyExists if an overload with the same shape is already
  // registered (static or lazy), or if the registration would give a
  // non-strict function more than one overload.
  absl::Status Register(const cel::FunctionDescriptor& descriptor,
                        std::unique_ptr implementation);

  // Register a lazily provided function.
  // Internally, the registry binds a FunctionProvider that provides an overload
  // at evaluation time by resolving against the overloads provided by an
  // implementation of cel::ActivationInterface.
  //
  // Same failure conditions as Register(); lazy and static overloads share
  // one descriptor space.
  absl::Status RegisterLazyFunction(const cel::FunctionDescriptor& descriptor);

  // Find subset of cel::Function implementations that match overload conditions
  // As types may not be available during expression compilation,
  // further narrowing of this subset will happen at evaluation stage.
  //
  // name - the name of CEL function (as distinct from overload ID);
  // receiver_style - indicates whether function has receiver style;
  // types - argument types. If type is not known during compilation,
  //       cel::Kind::kAny should be passed.
  //
  // Results refer to underlying registry entries by reference. Results are
  // invalid after the registry is deleted.
  std::vector FindStaticOverloads(
      absl::string_view name, bool receiver_style,
      absl::Span types) const;

  // Like FindStaticOverloads, but matches only receiver style and argument
  // count.
  std::vector FindStaticOverloadsByArity(
      absl::string_view name, bool receiver_style, size_t arity) const;

  // Find subset of cel::Function providers that match overload conditions.
  // As types may not be available during expression compilation,
  // further narrowing of this subset will happen at evaluation stage.
  //
  // name - the name of CEL function (as distinct from overload ID);
  // receiver_style - indicates whether function has receiver style;
  // types - argument types. If type is not known during compilation,
  //       cel::Kind::kAny should be passed.
  //
  // Results refer to underlying registry entries by reference. Results are
  // invalid after the registry is deleted.
  std::vector FindLazyOverloads(
      absl::string_view name, bool receiver_style,
      absl::Span types) const;

  // Like FindLazyOverloads, but matches only receiver style and argument
  // count.
  std::vector FindLazyOverloadsByArity(absl::string_view name,
                                                    bool receiver_style,
                                                    size_t arity) const;

  // Retrieve list of registered function descriptors. This includes both
  // static and lazy functions.
  absl::node_hash_map>
  ListFunctions() const;

 private:
  // One statically bound overload: its descriptor plus the owned
  // implementation.
  struct StaticFunctionEntry {
    StaticFunctionEntry(const cel::FunctionDescriptor& descriptor,
                        std::unique_ptr impl)
        : descriptor(std::make_unique(descriptor)),
          implementation(std::move(impl)) {}

    // Extra indirection needed to preserve pointer stability for the
    // descriptors.
    std::unique_ptr descriptor;
    std::unique_ptr implementation;
  };

  // One lazily bound overload: its descriptor plus the owned provider that
  // resolves it at evaluation time.
  struct LazyFunctionEntry {
    LazyFunctionEntry(
        const cel::FunctionDescriptor& descriptor,
        std::unique_ptr provider)
        : descriptor(std::make_unique(descriptor)),
          function_provider(std::move(provider)) {}

    // Extra indirection needed to preserve pointer stability for the
    // descriptors.
    std::unique_ptr descriptor;
    std::unique_ptr function_provider;
  };

  // All overloads registered under a single function name.
  struct RegistryEntry {
    std::vector static_overloads;
    std::vector lazy_overloads;
  };

  // Returns whether the descriptor is registered either as a lazy function or
  // as a static function.
  bool DescriptorRegistered(const cel::FunctionDescriptor& descriptor) const;

  // Returns true if after adding this function, the rule "a non-strict
  // function should have only a single overload" will be preserved.
  bool ValidateNonStrictOverload(
      const cel::FunctionDescriptor& descriptor) const;

  // indexed by function name (not type checker overload id).
  absl::flat_hash_map functions_;
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_REGISTRY_H_
================================================
FILE: runtime/function_registry_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/function_registry.h"

#include
#include
#include
#include

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "common/function_descriptor.h"
#include "common/kind.h"
#include "common/value.h"
#include "internal/testing.h"
#include "runtime/activation.h"
#include "runtime/function.h"
#include "runtime/function_adapter.h"
#include "runtime/function_overload_reference.h"
#include "runtime/function_provider.h"

namespace cel {

namespace {

using ::absl_testing::StatusIs;
using ::cel::runtime_internal::FunctionProvider;
using ::testing::ElementsAre;
using ::testing::HasSubstr;
using ::testing::SizeIs;
using ::testing::Truly;

// Nullary test function named "ConstFunction" that always returns
// IntValue(42).
class ConstIntFunction : public cel::Function {
 public:
  static cel::FunctionDescriptor MakeDescriptor() {
    return {"ConstFunction", false, {}};
  }

  absl::StatusOr Invoke(absl::Span args,
                        const InvokeContext& context) const override {
    return IntValue(42);
  }
};

TEST(FunctionRegistryTest, InsertAndRetrieveLazyFunction) {
  cel::FunctionDescriptor lazy_function_desc{"LazyFunction", false, {}};
  FunctionRegistry registry;
  // NOTE(review): `activation` is not used by this test.
  Activation activation;
  ASSERT_OK(registry.RegisterLazyFunction(lazy_function_desc));

  const auto descriptors =
      registry.FindLazyOverloads("LazyFunction", false, {});
  EXPECT_THAT(descriptors, SizeIs(1));
}

// Confirm that lazy and static functions share the same descriptor space:
// i.e. you can't insert both a lazy function and a static function for the same
// descriptors.
TEST(FunctionRegistryTest, LazyAndStaticFunctionShareDescriptorSpace) {
  FunctionRegistry registry;
  cel::FunctionDescriptor desc = ConstIntFunction::MakeDescriptor();
  ASSERT_OK(registry.RegisterLazyFunction(desc));

  absl::Status status = registry.Register(ConstIntFunction::MakeDescriptor(),
                                          std::make_unique());
  EXPECT_FALSE(status.ok());
}

TEST(FunctionRegistryTest, FindStaticOverloadsReturns) {
  FunctionRegistry registry;
  cel::FunctionDescriptor desc = ConstIntFunction::MakeDescriptor();
  ASSERT_OK(registry.Register(desc, std::make_unique()));

  std::vector overloads =
      registry.FindStaticOverloads(desc.name(), false, {});

  EXPECT_THAT(overloads,
              ElementsAre(Truly(
                  [](const cel::FunctionOverloadReference& overload) -> bool {
                    return overload.descriptor.name() == "ConstFunction";
                  })))
      << "Expected single ConstFunction()";
}

// ListFunctions reports both static and lazy registrations, keyed by name.
TEST(FunctionRegistryTest, ListFunctions) {
  cel::FunctionDescriptor lazy_function_desc{"LazyFunction", false, {}};
  FunctionRegistry registry;

  ASSERT_OK(registry.RegisterLazyFunction(lazy_function_desc));
  EXPECT_OK(registry.Register(ConstIntFunction::MakeDescriptor(),
                              std::make_unique()));

  auto registered_functions = registry.ListFunctions();

  EXPECT_THAT(registered_functions, SizeIs(2));
  EXPECT_THAT(registered_functions["LazyFunction"], SizeIs(1));
  EXPECT_THAT(registered_functions["ConstFunction"], SizeIs(1));
}

TEST(FunctionRegistryTest, DefaultLazyProviderNoOverloadFound) {
  FunctionRegistry registry;
  Activation activation;
  cel::FunctionDescriptor lazy_function_desc{"LazyFunction", false, {}};
  EXPECT_OK(registry.RegisterLazyFunction(lazy_function_desc));

  auto providers = registry.FindLazyOverloads("LazyFunction", false, {});
  ASSERT_THAT(providers, SizeIs(1));
  const FunctionProvider& provider = providers[0].provider;
  // NOTE(review): this queries "LazyFunc", not the registered "LazyFunction".
  // The activation has no functions inserted, so the result is nullopt either
  // way; consider using the registered name so the test checks what its
  // title says.
  ASSERT_OK_AND_ASSIGN(
      absl::optional func,
      provider.GetFunction({"LazyFunc", false, {cel::Kind::kInt64}},
                           activation));
  EXPECT_EQ(func, absl::nullopt);
}

// Two overloads (kInt, kDouble) are inserted into the activation; a kInt query
// is expected to resolve to the kInt overload only.
TEST(FunctionRegistryTest, DefaultLazyProviderReturnsImpl) {
  FunctionRegistry registry;
  Activation activation;
  EXPECT_OK(registry.RegisterLazyFunction(
      FunctionDescriptor("LazyFunction", false, {Kind::kAny})));
  EXPECT_TRUE(activation.InsertFunction(
      FunctionDescriptor("LazyFunction", false, {Kind::kInt}),
      UnaryFunctionAdapter::WrapFunction(
          [](int64_t x) { return 2 * x; })));
  EXPECT_TRUE(activation.InsertFunction(
      FunctionDescriptor("LazyFunction", false, {Kind::kDouble}),
      UnaryFunctionAdapter::WrapFunction(
          [](double x) { return 2 * x; })));

  auto providers =
      registry.FindLazyOverloads("LazyFunction", false, {Kind::kInt});
  ASSERT_THAT(providers, SizeIs(1));
  const FunctionProvider& provider = providers[0].provider;
  ASSERT_OK_AND_ASSIGN(
      absl::optional func,
      provider.GetFunction(
          FunctionDescriptor("LazyFunction", false, {Kind::kInt}),
          activation));

  ASSERT_TRUE(func.has_value());
  EXPECT_EQ(func->descriptor.name(), "LazyFunction");
  EXPECT_EQ(func->descriptor.types(), std::vector{cel::Kind::kInt64});
}

// A kAny query is expected to shape-match both inserted overloads; the
// activation-backed provider reports that ambiguity as kInvalidArgument.
TEST(FunctionRegistryTest, DefaultLazyProviderAmbiguousOverload) {
  FunctionRegistry registry;
  Activation activation;
  EXPECT_OK(registry.RegisterLazyFunction(
      FunctionDescriptor("LazyFunction", false, {Kind::kAny})));
  EXPECT_TRUE(activation.InsertFunction(
      FunctionDescriptor("LazyFunction", false, {Kind::kInt}),
      UnaryFunctionAdapter::WrapFunction(
          [](int64_t x) { return 2 * x; })));
  EXPECT_TRUE(activation.InsertFunction(
      FunctionDescriptor("LazyFunction", false, {Kind::kDouble}),
      UnaryFunctionAdapter::WrapFunction(
          [](double x) { return 2 * x; })));

  auto providers =
      registry.FindLazyOverloads("LazyFunction", false, {Kind::kInt});
  ASSERT_THAT(providers, SizeIs(1));
  const FunctionProvider& provider = providers[0].provider;

  EXPECT_THAT(
      provider.GetFunction(
          FunctionDescriptor("LazyFunction", false, {Kind::kAny}), activation),
      StatusIs(absl::StatusCode::kInvalidArgument,
               HasSubstr("Couldn't resolve function")));
}

// A non-strict function may be registered when it is the only overload, for
// both static and lazy registration.
TEST(FunctionRegistryTest, CanRegisterNonStrictFunction) {
  {
    FunctionRegistry registry;
    cel::FunctionDescriptor descriptor("NonStrictFunction",
                                       /*receiver_style=*/false, {Kind::kAny},
                                       /*is_strict=*/false);
    ASSERT_OK(
        registry.Register(descriptor, std::make_unique()));

    EXPECT_THAT(
        registry.FindStaticOverloads("NonStrictFunction", false, {Kind::kAny}),
        SizeIs(1));
  }
  {
    FunctionRegistry registry;
    cel::FunctionDescriptor descriptor("NonStrictLazyFunction",
                                       /*receiver_style=*/false, {Kind::kAny},
                                       /*is_strict=*/false);
    EXPECT_OK(registry.RegisterLazyFunction(descriptor));

    EXPECT_THAT(registry.FindLazyOverloads("NonStrictLazyFunction", false,
                                           {Kind::kAny}),
                SizeIs(1));
  }
}

// Parameter: (existing overload registered lazily, new overload registered
// lazily).
using NonStrictTestCase = std::tuple;
using NonStrictRegistrationFailTest = testing::TestWithParam;

TEST_P(NonStrictRegistrationFailTest,
       IfOtherOverloadExistsRegisteringNonStrictFails) {
  bool existing_function_is_lazy, new_function_is_lazy;
  std::tie(existing_function_is_lazy, new_function_is_lazy) = GetParam();
  FunctionRegistry registry;
  cel::FunctionDescriptor descriptor("OverloadedFunction",
                                     /*receiver_style=*/false, {Kind::kAny},
                                     /*is_strict=*/true);
  if (existing_function_is_lazy) {
    ASSERT_OK(registry.RegisterLazyFunction(descriptor));
  } else {
    ASSERT_OK(
        registry.Register(descriptor, std::make_unique()));
  }
  cel::FunctionDescriptor new_descriptor(
      "OverloadedFunction",
      /*receiver_style=*/false, {Kind::kAny, Kind::kAny}, /*is_strict=*/false);
  absl::Status status;
  if (new_function_is_lazy) {
    status = registry.RegisterLazyFunction(new_descriptor);
  } else {
    status =
        registry.Register(new_descriptor, std::make_unique());
  }
  EXPECT_THAT(status, StatusIs(absl::StatusCode::kAlreadyExists,
                               HasSubstr("Only one overload")));
}

TEST_P(NonStrictRegistrationFailTest,
       IfOtherNonStrictExistsRegisteringStrictFails) {
  bool existing_function_is_lazy, new_function_is_lazy;
  std::tie(existing_function_is_lazy, new_function_is_lazy) = GetParam();
  FunctionRegistry registry;
  cel::FunctionDescriptor descriptor("OverloadedFunction",
                                     /*receiver_style=*/false, {Kind::kAny},
                                     /*is_strict=*/false);
  if (existing_function_is_lazy) {
    ASSERT_OK(registry.RegisterLazyFunction(descriptor));
  } else {
    ASSERT_OK(
        registry.Register(descriptor, std::make_unique()));
  }
  cel::FunctionDescriptor new_descriptor(
      "OverloadedFunction",
      /*receiver_style=*/false, {Kind::kAny, Kind::kAny}, /*is_strict=*/true);
  absl::Status status;
  if (new_function_is_lazy) {
    status = registry.RegisterLazyFunction(new_descriptor);
  } else {
    status =
        registry.Register(new_descriptor, std::make_unique());
  }
  EXPECT_THAT(status, StatusIs(absl::StatusCode::kAlreadyExists,
                               HasSubstr("Only one overload")));
}

TEST_P(NonStrictRegistrationFailTest, CanRegisterStrictFunctionsWithoutLimit) {
  bool existing_function_is_lazy, new_function_is_lazy;
  std::tie(existing_function_is_lazy, new_function_is_lazy) = GetParam();
  FunctionRegistry registry;
  cel::FunctionDescriptor descriptor("OverloadedFunction",
                                     /*receiver_style=*/false, {Kind::kAny},
                                     /*is_strict=*/true);
  if (existing_function_is_lazy) {
    ASSERT_OK(registry.RegisterLazyFunction(descriptor));
  } else {
    ASSERT_OK(
        registry.Register(descriptor, std::make_unique()));
  }
  cel::FunctionDescriptor new_descriptor(
      "OverloadedFunction",
      /*receiver_style=*/false, {Kind::kAny, Kind::kAny}, /*is_strict=*/true);
  absl::Status status;
  if (new_function_is_lazy) {
    status = registry.RegisterLazyFunction(new_descriptor);
  } else {
    status =
        registry.Register(new_descriptor, std::make_unique());
  }
  EXPECT_OK(status);
}

// Runs each case for all four (lazy/static) x (lazy/static) combinations.
INSTANTIATE_TEST_SUITE_P(NonStrictRegistrationFailTest,
                         NonStrictRegistrationFailTest,
                         testing::Combine(testing::Bool(), testing::Bool()));

}  // namespace

}  // namespace cel
================================================
FILE: runtime/internal/BUILD
================================================
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

package(
    # Internals for cel/runtime.
    # NOTE(review): visibility is public despite the "internals" comment;
    # confirm whether it should be narrowed.
    default_visibility = ["//visibility:public"],
)

licenses(["notice"])

# Header-only (no srcs).
cc_library(
    name = "runtime_friend_access",
    hdrs = ["runtime_friend_access.h"],
    deps = [
        "//common:native_type",
        "//runtime",
        "//runtime:runtime_builder",
    ],
)

cc_library(
    name = "runtime_env",
    srcs = ["runtime_env.cc"],
    hdrs = ["runtime_env.h"],
    deps = [
        "//eval/public:cel_function_registry",
        "//eval/public:cel_type_registry",
        "//internal:noop_delete",
        "//internal:well_known_types",
        "//runtime:function_registry",
        "//runtime:type_registry",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/synchronization",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "runtime_impl",
    srcs = ["runtime_impl.cc"],
    hdrs = ["runtime_impl.h"],
    deps = [
        ":runtime_env",
        "//base:ast",
        "//base:data",
        "//common:native_type",
        "//common:value",
        "//eval/compiler:flat_expr_builder",
        "//eval/eval:attribute_trail",
        "//eval/eval:comprehension_slots",
        "//eval/eval:direct_expression_step",
        "//eval/eval:evaluator_core",
        "//internal:casts",
        "//internal:status_macros",
        "//internal:well_known_types",
        "//runtime",
        "//runtime:activation_interface",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "//runtime:type_registry",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status:statusor",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "convert_constant",
    srcs = ["convert_constant.cc"],
    hdrs = ["convert_constant.h"],
    deps = [
        "//common:allocator",
        "//common:ast",
        "//common:constant",
        "//common:value",
        "//eval/internal:errors",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:variant",
    ],
)

cc_library(
    name = "errors",
    srcs = ["errors.cc"],
    hdrs = ["errors.h"],
    deps = [
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@com_google_absl//absl/time",
    ],
)

# Header-only (no srcs).
cc_library(
    name = "issue_collector",
    hdrs = ["issue_collector.h"],
    deps = [
        "//runtime:runtime_issue",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/types:span",
    ],
)

cc_test(
    name = "issue_collector_test",
    srcs = ["issue_collector_test.cc"],
    deps = [
        ":issue_collector",
        "//internal:testing",
        "//runtime:runtime_issue",
        "@com_google_absl//absl/status",
    ],
)

# Header-only (no srcs).
cc_library(
    name = "function_adapter",
    hdrs = [
        "function_adapter.h",
    ],
    deps = [
        "//common:casting",
        "//common:kind",
        "//common:value",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/time",
    ],
)

cc_test(
    name = "function_adapter_test",
    srcs = ["function_adapter_test.cc"],
    deps = [
        ":function_adapter",
        "//common:kind",
        "//common:value",
        "//internal:testing",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/time",
    ],
)

# Test-only helper (testonly = True); only usable from testonly targets.
cc_library(
    name = "runtime_env_testing",
    testonly = True,
    srcs = ["runtime_env_testing.cc"],
    hdrs = ["runtime_env_testing.h"],
    deps = [
        ":runtime_env",
        "//internal:noop_delete",
        "//internal:testing_descriptor_pool",
        "//internal:testing_message_factory",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_protobuf//:protobuf",
    ],
)

# Header-only (no srcs).
cc_library(
    name = "legacy_runtime_type_provider",
    hdrs = ["legacy_runtime_type_provider.h"],
    deps = [
        "//eval/public/structs:protobuf_descriptor_type_provider",
        "@com_google_absl//absl/base:nullability",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "runtime_type_provider",
    srcs = ["runtime_type_provider.cc"],
    hdrs = ["runtime_type_provider.h"],
    deps = [
        "//common:type",
        "//common:value",
        "//internal:status_macros",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:optional",
        "@com_google_protobuf//:protobuf",
    ],
)

# Header-only interface used to match unknown/missing attribute patterns.
cc_library(
    name = "attribute_matcher",
    hdrs = ["attribute_matcher.h"],
    deps = ["//base:attributes"],
)

# Accessors for the attribute matcher on both activation implementations
# (//eval/public:activation and //runtime:activation).
cc_library(
    name = "activation_attribute_matcher_access",
    srcs = ["activation_attribute_matcher_access.cc"],
    hdrs = ["activation_attribute_matcher_access.h"],
    deps = [
        ":attribute_matcher",
        "//eval/public:activation",
        "//runtime:activation",
        "@com_google_absl//absl/base:nullability",
    ],
)
================================================
FILE: runtime/internal/activation_attribute_matcher_access.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include "runtime/internal/activation_attribute_matcher_access.h" #include #include #include "absl/base/nullability.h" #include "eval/public/activation.h" #include "runtime/activation.h" #include "runtime/internal/attribute_matcher.h" namespace cel::runtime_internal { void ActivationAttributeMatcherAccess::SetAttributeMatcher( google::api::expr::runtime::Activation& activation, const AttributeMatcher* matcher) { activation.SetAttributeMatcher(matcher); } void ActivationAttributeMatcherAccess::SetAttributeMatcher( google::api::expr::runtime::Activation& activation, std::unique_ptr matcher) { activation.SetAttributeMatcher(std::move(matcher)); } const AttributeMatcher* absl_nullable ActivationAttributeMatcherAccess::GetAttributeMatcher( const google::api::expr::runtime::BaseActivation& activation) { return activation.GetAttributeMatcher(); } void ActivationAttributeMatcherAccess::SetAttributeMatcher( Activation& activation, const AttributeMatcher* matcher) { activation.SetAttributeMatcher(matcher); } void ActivationAttributeMatcherAccess::SetAttributeMatcher( Activation& activation, std::unique_ptr matcher) { activation.SetAttributeMatcher(std::move(matcher)); } const AttributeMatcher* absl_nullable ActivationAttributeMatcherAccess::GetAttributeMatcher( const ActivationInterface& activation) { return activation.GetAttributeMatcher(); } } // namespace cel::runtime_internal ================================================ FILE: runtime/internal/activation_attribute_matcher_access.h ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ACTIVATION_MATCHER_ACCESS_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ACTIVATION_MATCHER_ACCESS_H_ #include #include "absl/base/nullability.h" #include "runtime/internal/attribute_matcher.h" namespace google::api::expr::runtime { class Activation; class BaseActivation; } // namespace google::api::expr::runtime namespace cel { class Activation; class ActivationInterface; } // namespace cel namespace cel::runtime_internal { class ActivationAttributeMatcherAccess { public: static void SetAttributeMatcher( google::api::expr::runtime::Activation& activation, const AttributeMatcher* matcher); static void SetAttributeMatcher( google::api::expr::runtime::Activation& activation, std::unique_ptr matcher); static const AttributeMatcher* absl_nullable GetAttributeMatcher( const google::api::expr::runtime::BaseActivation& activation); static void SetAttributeMatcher(Activation& activation, const AttributeMatcher* matcher); static void SetAttributeMatcher( Activation& activation, std::unique_ptr matcher); static const AttributeMatcher* absl_nullable GetAttributeMatcher( const ActivationInterface& activation); }; } // namespace cel::runtime_internal #endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ACTIVATION_MATCHER_ACCESS_H_ ================================================ FILE: runtime/internal/attribute_matcher.h ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the 
License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ATTRIBUTE_MATCHER_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ATTRIBUTE_MATCHER_H_ #include "base/attribute.h" namespace cel::runtime_internal { // Interface for matching unknown and missing attributes against the // observed attribute trail at runtime. class AttributeMatcher { public: using MatchResult = cel::AttributePattern::MatchType; virtual ~AttributeMatcher() = default; // Checks whether the attribute trail matches any unknown patterns. // Used to identify and collect referenced unknowns in an UnknownValue. virtual MatchResult CheckForUnknown(const Attribute& attr [[maybe_unused]]) const { return MatchResult::NONE; }; // Checks whether the attribute trail matches any missing patterns. // Used to identify missing attributes, and report an error if referenced // directly. virtual MatchResult CheckForMissing(const Attribute& attr [[maybe_unused]]) const { return MatchResult::NONE; }; }; } // namespace cel::runtime_internal #endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ATTRIBUTE_MATCHER_H_ ================================================ FILE: runtime/internal/convert_constant.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/internal/convert_constant.h" #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/time/time.h" #include "absl/types/variant.h" #include "common/allocator.h" #include "common/constant.h" #include "common/value.h" #include "eval/internal/errors.h" namespace cel::runtime_internal { namespace { using ::cel::Constant; struct ConvertVisitor { Allocator<> allocator; absl::StatusOr operator()(absl::monostate) { return absl::InvalidArgumentError("unspecified constant"); } absl::StatusOr operator()(std::nullptr_t) { return NullValue(); } absl::StatusOr operator()(bool value) { return BoolValue(value); } absl::StatusOr operator()(int64_t value) { return IntValue(value); } absl::StatusOr operator()(uint64_t value) { return UintValue(value); } absl::StatusOr operator()(double value) { return DoubleValue(value); } absl::StatusOr operator()(const cel::StringConstant& value) { return StringValue(allocator, value); } absl::StatusOr operator()(const cel::BytesConstant& value) { return BytesValue(allocator, value); } absl::StatusOr operator()(const absl::Duration duration) { if (duration >= kDurationHigh || duration <= kDurationLow) { return ErrorValue(*DurationOverflowError()); } return UnsafeDurationValue(duration); } absl::StatusOr operator()(const absl::Time timestamp) { return UnsafeTimestampValue(timestamp); } }; } // namespace // Converts an Ast constant into a runtime value, managed according to the // given value factory. // // A status maybe returned if value creation fails. 
absl::StatusOr ConvertConstant(const Constant& constant, Allocator<> allocator) { return absl::visit(ConvertVisitor{allocator}, constant.constant_kind()); } } // namespace cel::runtime_internal ================================================ FILE: runtime/internal/convert_constant.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_CONVERT_CONSTANT_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_CONVERT_CONSTANT_H_ #include "absl/status/statusor.h" #include "common/allocator.h" #include "common/ast.h" #include "common/value.h" namespace cel::runtime_internal { // Adapt AST constant to a Value. // // Underlying data is copied for string types to keep the program independent // from the input AST. // // The evaluator assumes most ast constants are valid so unchecked ValueManager // methods are used. // // A status may still be returned if value creation fails according to // value_factory's policy. 
absl::StatusOr ConvertConstant(const Constant& constant, Allocator<> allocator); } // namespace cel::runtime_internal #endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_CONVERT_CONSTANT_H_ ================================================ FILE: runtime/internal/errors.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/internal/errors.h" #include "absl/status/status.h" #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" namespace cel::runtime_internal { const absl::Status* DurationOverflowError() { static const auto* const kDurationOverflow = new absl::Status( absl::StatusCode::kInvalidArgument, "Duration is out of range"); return kDurationOverflow; } absl::Status CreateNoSuchKeyError(absl::string_view key) { return absl::NotFoundError(absl::StrCat(kErrNoSuchKey, " : ", key)); } absl::Status CreateNoMatchingOverloadError(absl::string_view fn) { return absl::UnknownError( absl::StrCat(kErrNoMatchingOverload, fn.empty() ? "" : " : ", fn)); } absl::Status CreateNoSuchFieldError(absl::string_view field) { return absl::Status( absl::StatusCode::kNotFound, absl::StrCat(kErrNoSuchField, field.empty() ? 
"" : " : ", field)); } absl::Status CreateMissingAttributeError( absl::string_view missing_attribute_path) { absl::Status result = absl::InvalidArgumentError( absl::StrCat(kErrMissingAttribute, missing_attribute_path)); result.SetPayload(kPayloadUrlMissingAttributePath, absl::Cord(missing_attribute_path)); return result; } absl::Status CreateInvalidMapKeyTypeError(absl::string_view key_type) { return absl::InvalidArgumentError( absl::StrCat("Invalid map key type: '", key_type, "'")); } absl::Status CreateUnknownFunctionResultError(absl::string_view help_message) { absl::Status result = absl::UnavailableError( absl::StrCat("Unknown function result: ", help_message)); result.SetPayload(kPayloadUrlUnknownFunctionResult, absl::Cord("true")); return result; } absl::Status CreateError(absl::string_view message, absl::StatusCode code) { return absl::Status(code, message); } } // namespace cel::runtime_internal ================================================ FILE: runtime/internal/errors.h ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Factories and constants for well-known CEL errors. 
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ERRORS_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ERRORS_H_ #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" namespace cel::runtime_internal { constexpr absl::string_view kErrNoMatchingOverload = "No matching overloads found"; constexpr absl::string_view kErrNoSuchField = "no_such_field"; constexpr absl::string_view kErrNoSuchKey = "Key not found in map"; // Error name for MissingAttributeError indicating that evaluation has // accessed an attribute whose value is undefined. go/terminal-unknown constexpr absl::string_view kErrMissingAttribute = "MissingAttributeError: "; constexpr absl::string_view kPayloadUrlMissingAttributePath = "missing_attribute_path"; constexpr absl::string_view kPayloadUrlUnknownFunctionResult = "cel_is_unknown_function_result"; // Exclusive bounds for valid duration values. constexpr absl::Duration kDurationHigh = absl::Seconds(315576000001); constexpr absl::Duration kDurationLow = absl::Seconds(-315576000001); const absl::Status* DurationOverflowError(); // At runtime, no matching overload could be found for a function invocation. absl::Status CreateNoMatchingOverloadError(absl::string_view fn); // No such field for struct access. absl::Status CreateNoSuchFieldError(absl::string_view field); // No such key for map access. absl::Status CreateNoSuchKeyError(absl::string_view key); // Invalid key type used for map index. absl::Status CreateInvalidMapKeyTypeError(absl::string_view key_type); // A missing attribute was accessed. Attributes may be declared as missing to // they are not well defined at evaluation time. absl::Status CreateMissingAttributeError( absl::string_view missing_attribute_path); // Function result is unknown. The evaluator may convert this to an // UnknownValue if enabled. absl::Status CreateUnknownFunctionResultError(absl::string_view help_message); // The default error type uses absl::StatusCode::kUnknown. 
In general, a more // specific error should be used. absl::Status CreateError(absl::string_view message, absl::StatusCode code = absl::StatusCode::kUnknown); } // namespace cel::runtime_internal #endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ERRORS_H_ ================================================ FILE: runtime/internal/function_adapter.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Definitions for implementation details of the function adapter utility. #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_FUNCTION_ADAPTER_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_FUNCTION_ADAPTER_H_ #include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/time/time.h" #include "common/casting.h" #include "common/kind.h" #include "common/value.h" namespace cel::runtime_internal { // Helper for triggering static asserts in an unspecialized template overload. template struct UnhandledType : std::false_type {}; // Adapts the type param Type to the appropriate Kind. // A static assertion fails if the provided type does not map to a cel::Value // kind. 
template <typename Type>
constexpr Kind AdaptedKind() {
  static_assert(UnhandledType<Type>::value,
                "Unsupported primitive type to cel::Kind conversion");
  return Kind::kNotForUseWithExhaustiveSwitchStatements;
}

// Primitive C++ types passed by value.
template <>
constexpr Kind AdaptedKind<int64_t>() {
  return Kind::kInt64;
}

template <>
constexpr Kind AdaptedKind<uint64_t>() {
  return Kind::kUint64;
}

template <>
constexpr Kind AdaptedKind<double>() {
  return Kind::kDouble;
}

template <>
constexpr Kind AdaptedKind<bool>() {
  return Kind::kBool;
}

template <>
constexpr Kind AdaptedKind<absl::Time>() {
  return Kind::kTimestamp;
}

template <>
constexpr Kind AdaptedKind<absl::Duration>() {
  return Kind::kDuration;
}

// Value types without a generic C++ type representation can be referenced by
// cref or value of the cel::*Value type. The macro emits both specializations.
#define VALUE_ADAPTED_KIND_OVL(value_type, kind)      \
  template <>                                         \
  constexpr Kind AdaptedKind<value_type>() {          \
    return kind;                                      \
  }                                                   \
                                                      \
  template <>                                         \
  constexpr Kind AdaptedKind<const value_type&>() {   \
    return kind;                                      \
  }

VALUE_ADAPTED_KIND_OVL(Value, Kind::kAny);
VALUE_ADAPTED_KIND_OVL(StringValue, Kind::kString);
VALUE_ADAPTED_KIND_OVL(BytesValue, Kind::kBytes);
VALUE_ADAPTED_KIND_OVL(StructValue, Kind::kStruct);
VALUE_ADAPTED_KIND_OVL(MapValue, Kind::kMap);
VALUE_ADAPTED_KIND_OVL(ListValue, Kind::kList);
VALUE_ADAPTED_KIND_OVL(NullValue, Kind::kNullType);
VALUE_ADAPTED_KIND_OVL(OpaqueValue, Kind::kOpaque);
VALUE_ADAPTED_KIND_OVL(TypeValue, Kind::kType);

#undef VALUE_ADAPTED_KIND_OVL

// Adapt a Value to its corresponding argument type in a wrapped c++
// function.
struct ValueToAdaptedVisitor { absl::Status operator()(int64_t* out) const { if (!input.IsInt()) { return absl::InvalidArgumentError("expected int value"); } *out = input.GetInt().NativeValue(); return absl::OkStatus(); } absl::Status operator()(uint64_t* out) const { if (!input.IsUint()) { return absl::InvalidArgumentError("expected uint value"); } *out = input.GetUint().NativeValue(); return absl::OkStatus(); } absl::Status operator()(double* out) const { if (!input.IsDouble()) { return absl::InvalidArgumentError("expected double value"); } *out = input.GetDouble().NativeValue(); return absl::OkStatus(); } absl::Status operator()(bool* out) const { if (!input.IsBool()) { return absl::InvalidArgumentError("expected bool value"); } *out = input.GetBool().NativeValue(); return absl::OkStatus(); } absl::Status operator()(absl::Time* out) const { if (!input.IsTimestamp()) { return absl::InvalidArgumentError("expected timestamp value"); } *out = input.GetTimestamp().ToTime(); return absl::OkStatus(); } absl::Status operator()(absl::Duration* out) const { if (!input.IsDuration()) { return absl::InvalidArgumentError("expected duration value"); } *out = input.GetDuration().ToDuration(); return absl::OkStatus(); } absl::Status operator()(Value* out) const { *out = input; return absl::OkStatus(); } absl::Status operator()(const Value** out) const { *out = &input; return absl::OkStatus(); } template absl::Status operator()(T* out) const { if (!InstanceOf>(input)) { return absl::InvalidArgumentError( absl::StrCat("expected ", ValueKindToString(T::kKind), " value")); } *out = Cast>(input); return absl::OkStatus(); } template absl::Status operator()(T** out) const { if (!InstanceOf>(input)) { return absl::InvalidArgumentError( absl::StrCat("expected ", ValueKindToString(T::kKind), " value")); } static_assert(std::is_lvalue_reference_v< decltype(Cast>(input))>, "expected l-value reference return type for Cast."); *out = &Cast>(input); return absl::OkStatus(); } const Value& 
input; }; // Adapts the return value of a wrapped C++ function to its corresponding // Value representation. struct AdaptedToValueVisitor { absl::StatusOr operator()(int64_t in) { return IntValue(in); } absl::StatusOr operator()(uint64_t in) { return UintValue(in); } absl::StatusOr operator()(double in) { return DoubleValue(in); } absl::StatusOr operator()(bool in) { return BoolValue(in); } absl::StatusOr operator()(absl::Time in) { // Type matching may have already occurred. It's too late to change up the // type and return an error. return TimestampValue(in); } absl::StatusOr operator()(absl::Duration in) { // Type matching may have already occurred. It's too late to change up the // type and return an error. return DurationValue(in); } absl::StatusOr operator()(Value in) { return in; } template absl::StatusOr operator()(T in) { return in; } // Special case for StatusOr return value -- wrap the underlying value if // present, otherwise return the status. template absl::StatusOr operator()(absl::StatusOr wrapped) { if (!wrapped.ok()) { return std::move(wrapped).status(); } return this->operator()(std::move(wrapped).value()); } }; } // namespace cel::runtime_internal #endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_FUNCTION_ADAPTER_H_ ================================================ FILE: runtime/internal/function_adapter_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "runtime/internal/function_adapter.h"

#include <cstdint>

#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/time/time.h"
#include "common/kind.h"
#include "common/value.h"
#include "internal/testing.h"

namespace cel::runtime_internal {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;

static_assert(AdaptedKind<int64_t>() == Kind::kInt, "int adapts to int64_t");
static_assert(AdaptedKind<uint64_t>() == Kind::kUint,
              "uint adapts to uint64_t");
static_assert(AdaptedKind<double>() == Kind::kDouble,
              "double adapts to double");
static_assert(AdaptedKind<bool>() == Kind::kBool, "bool adapts to bool");
static_assert(AdaptedKind<absl::Time>() == Kind::kTimestamp,
              "timestamp adapts to absl::Time");
static_assert(AdaptedKind<absl::Duration>() == Kind::kDuration,
              "duration adapts to absl::Duration");
// Handle types.
static_assert(AdaptedKind<Value>() == Kind::kAny, "any adapts to Value");
static_assert(AdaptedKind<StringValue>() == Kind::kString,
              "string adapts to String");
static_assert(AdaptedKind<BytesValue>() == Kind::kBytes,
              "bytes adapts to Bytes");
static_assert(AdaptedKind<StructValue>() == Kind::kStruct,
              "struct adapts to StructValue");
static_assert(AdaptedKind<ListValue>() == Kind::kList,
              "list adapts to ListValue");
static_assert(AdaptedKind<MapValue>() == Kind::kMap, "map adapts to MapValue");
static_assert(AdaptedKind<NullValue>() == Kind::kNullType,
              "null adapts to NullValue");
static_assert(AdaptedKind<OpaqueValue>() == Kind::kOpaque,
              "opaque adapts to OpaqueValue");
static_assert(AdaptedKind<TypeValue>() == Kind::kType,
              "type adapts to TypeValue");
static_assert(AdaptedKind<const Value&>() == Kind::kAny,
              "any adapts to const Value&");
static_assert(AdaptedKind<const StringValue&>() == Kind::kString,
              "string adapts to const String&");
static_assert(AdaptedKind<const BytesValue&>() == Kind::kBytes,
              "bytes adapts to const Bytes&");
static_assert(AdaptedKind<const StructValue&>() == Kind::kStruct,
              "struct adapts to const StructValue&");
static_assert(AdaptedKind<const ListValue&>() == Kind::kList,
              "list adapts to const ListValue&");
static_assert(AdaptedKind<const MapValue&>() == Kind::kMap,
              "map adapts to const MapValue&");
static_assert(AdaptedKind<const NullValue&>() == Kind::kNullType,
              "null adapts to const NullValue&");
static_assert(AdaptedKind<const OpaqueValue&>() == Kind::kOpaque,
              "opaque adapts to const OpaqueValue&");
static_assert(AdaptedKind<const TypeValue&>() == Kind::kType,
              "type adapts to const TypeValue&");

class ValueToAdaptedVisitorTest : public ::testing::Test {};

TEST_F(ValueToAdaptedVisitorTest, Int) {
  Value v = cel::IntValue(10);

  int64_t out;
  ASSERT_THAT(ValueToAdaptedVisitor{v}(&out), IsOk());

  EXPECT_EQ(out, 10);
}

TEST_F(ValueToAdaptedVisitorTest, IntWrongKind) {
  Value v = cel::UintValue(10);

  int64_t out;
  EXPECT_THAT(
      ValueToAdaptedVisitor{v}(&out),
      StatusIs(absl::StatusCode::kInvalidArgument, "expected int value"));
}

TEST_F(ValueToAdaptedVisitorTest, Uint) {
  Value v = cel::UintValue(11);

  uint64_t out;
  ASSERT_THAT(ValueToAdaptedVisitor{v}(&out), IsOk());

  // Unsigned literal avoids a signed/unsigned comparison inside EXPECT_EQ.
  EXPECT_EQ(out, 11u);
}

TEST_F(ValueToAdaptedVisitorTest, UintWrongKind) {
  Value v = cel::IntValue(11);

  uint64_t out;
  EXPECT_THAT(
      ValueToAdaptedVisitor{v}(&out),
      StatusIs(absl::StatusCode::kInvalidArgument, "expected uint value"));
}

TEST_F(ValueToAdaptedVisitorTest, Double) {
  Value v = cel::DoubleValue(12.0);

  double out;
  ASSERT_THAT(ValueToAdaptedVisitor{v}(&out), IsOk());

  EXPECT_EQ(out, 12.0);
}

TEST_F(ValueToAdaptedVisitorTest, DoubleWrongKind) {
  Value v = cel::UintValue(10);

  double out;
  EXPECT_THAT(
      ValueToAdaptedVisitor{v}(&out),
      StatusIs(absl::StatusCode::kInvalidArgument, "expected double value"));
}

TEST_F(ValueToAdaptedVisitorTest, Bool) {
  Value v = cel::BoolValue(false);

  bool out;
  ASSERT_THAT(ValueToAdaptedVisitor{v}(&out), IsOk());

  EXPECT_EQ(out, false);
}

TEST_F(ValueToAdaptedVisitorTest, BoolWrongKind) {
  Value v = cel::UintValue(10);

  bool out;
  EXPECT_THAT(
      ValueToAdaptedVisitor{v}(&out),
      StatusIs(absl::StatusCode::kInvalidArgument, "expected bool value"));
}

TEST_F(ValueToAdaptedVisitorTest, Timestamp) {
  Value v = cel::TimestampValue(absl::UnixEpoch() + absl::Seconds(1));

  absl::Time out;
  ASSERT_THAT(ValueToAdaptedVisitor{v}(&out), IsOk());

  EXPECT_EQ(out, absl::UnixEpoch() + absl::Seconds(1));
}

TEST_F(ValueToAdaptedVisitorTest, TimestampWrongKind) {
  Value v = cel::UintValue(10);

  absl::Time out;
  EXPECT_THAT(
      ValueToAdaptedVisitor{v}(&out),
      StatusIs(absl::StatusCode::kInvalidArgument, "expected timestamp value"));
}

TEST_F(ValueToAdaptedVisitorTest, Duration) {
  Value v = cel::DurationValue(absl::Seconds(5));

  absl::Duration out;
  ASSERT_THAT(ValueToAdaptedVisitor{v}(&out), IsOk());

  EXPECT_EQ(out, absl::Seconds(5));
}

TEST_F(ValueToAdaptedVisitorTest, DurationWrongKind) {
  Value v = cel::UintValue(10);

  absl::Duration out;
  EXPECT_THAT(
      ValueToAdaptedVisitor{v}(&out),
      StatusIs(absl::StatusCode::kInvalidArgument, "expected duration value"));
}

TEST_F(ValueToAdaptedVisitorTest, String) {
  Value v = cel::StringValue("string");

  StringValue out;
  ASSERT_THAT(ValueToAdaptedVisitor{v}(&out), IsOk());

  EXPECT_EQ(out.ToString(), "string");
}

TEST_F(ValueToAdaptedVisitorTest, StringWrongKind) {
  Value v = cel::UintValue(10);

  StringValue out;
  EXPECT_THAT(
      ValueToAdaptedVisitor{v}(&out),
      StatusIs(absl::StatusCode::kInvalidArgument, "expected string value"));
}

TEST_F(ValueToAdaptedVisitorTest, Bytes) {
  Value v = cel::BytesValue("bytes");

  BytesValue out;
  ASSERT_THAT(ValueToAdaptedVisitor{v}(&out), IsOk());

  EXPECT_EQ(out.ToString(), "bytes");
}

TEST_F(ValueToAdaptedVisitorTest, BytesWrongKind) {
  Value v = cel::UintValue(10);

  BytesValue out;
  EXPECT_THAT(
      ValueToAdaptedVisitor{v}(&out),
      StatusIs(absl::StatusCode::kInvalidArgument, "expected bytes value"));
}

class AdaptedToValueVisitorTest : public ::testing::Test {};

TEST_F(AdaptedToValueVisitorTest, Int) {
  int64_t value = 10;

  ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value));

  ASSERT_TRUE(result.IsInt());
  EXPECT_EQ(result.GetInt().NativeValue(), 10);
}

TEST_F(AdaptedToValueVisitorTest, Double) {
  double value = 10;

  ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value));

  ASSERT_TRUE(result.IsDouble());
  EXPECT_EQ(result.GetDouble().NativeValue(), 10.0);
}

TEST_F(AdaptedToValueVisitorTest, Uint) {
  uint64_t value = 10;

  ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value));

  ASSERT_TRUE(result.IsUint());
  // Unsigned literal avoids a signed/unsigned comparison inside EXPECT_EQ.
  EXPECT_EQ(result.GetUint().NativeValue(), 10u);
}

TEST_F(AdaptedToValueVisitorTest, Bool) {
  bool value = true;

  ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value));

  ASSERT_TRUE(result.IsBool());
  EXPECT_EQ(result.GetBool().NativeValue(), true);
}

TEST_F(AdaptedToValueVisitorTest, Timestamp) {
  absl::Time value = absl::UnixEpoch() + absl::Seconds(10);

  ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value));

  ASSERT_TRUE(result.IsTimestamp());
  EXPECT_EQ(result.GetTimestamp().ToTime(),
            absl::UnixEpoch() + absl::Seconds(10));
}

TEST_F(AdaptedToValueVisitorTest, Duration) {
  absl::Duration value = absl::Seconds(5);

  ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value));

  ASSERT_TRUE(result.IsDuration());
  EXPECT_EQ(result.GetDuration().ToDuration(), absl::Seconds(5));
}

TEST_F(AdaptedToValueVisitorTest, String) {
  StringValue value = cel::StringValue("str");

  ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value));

  ASSERT_TRUE(result.IsString());
  EXPECT_EQ(result.GetString().ToString(), "str");
}

TEST_F(AdaptedToValueVisitorTest, Bytes) {
  BytesValue value = cel::BytesValue("bytes");

  ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value));

  ASSERT_TRUE(result.IsBytes());
  EXPECT_EQ(result.GetBytes().ToString(), "bytes");
}

TEST_F(AdaptedToValueVisitorTest, StatusOrValue) {
  absl::StatusOr<int64_t> value = 10;

  ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(value));

  ASSERT_TRUE(result.IsInt());
  EXPECT_EQ(result.GetInt().NativeValue(), 10);
}

TEST_F(AdaptedToValueVisitorTest, StatusOrError) {
  absl::StatusOr<int64_t> value = absl::InternalError("test_error");

  EXPECT_THAT(AdaptedToValueVisitor{}(value).status(),
              StatusIs(absl::StatusCode::kInternal, "test_error"));
}

TEST_F(AdaptedToValueVisitorTest, Any) {
  auto handle = cel::ErrorValue(absl::InternalError("test_error"));

  ASSERT_OK_AND_ASSIGN(auto result, AdaptedToValueVisitor{}(handle));

  ASSERT_TRUE(result.IsError());
  EXPECT_THAT(result.GetError().NativeValue(),
              StatusIs(absl::StatusCode::kInternal, "test_error"));
}

}  // namespace
}  // namespace cel::runtime_internal

================================================
FILE: 
runtime/internal/issue_collector.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ISSUE_COLLECTOR_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ISSUE_COLLECTOR_H_ #include #include #include "absl/status/status.h" #include "absl/types/span.h" #include "runtime/runtime_issue.h" namespace cel::runtime_internal { // IssueCollector collects issues and reports absl::Status according to the // configured severity limit. class IssueCollector { public: // Args: // severity: inclusive limit for issues to return as non-ok absl::Status. explicit IssueCollector(RuntimeIssue::Severity severity_limit) : severity_limit_(severity_limit) {} // move-only. IssueCollector(const IssueCollector&) = delete; IssueCollector& operator=(const IssueCollector&) = delete; IssueCollector(IssueCollector&&) = default; IssueCollector& operator=(IssueCollector&&) = default; // Collect an Issue. // Returns a status according to the IssueCollector's policy and the given // Issue. // The Issue is always added to issues, regardless of whether AddIssue returns // a non-ok status. 
absl::Status AddIssue(RuntimeIssue issue) { issues_.push_back(std::move(issue)); if (issues_.back().severity() >= severity_limit_) { return issues_.back().ToStatus(); } return absl::OkStatus(); } absl::Span issues() const { return issues_; } std::vector ExtractIssues() { return std::move(issues_); } private: RuntimeIssue::Severity severity_limit_; std::vector issues_; }; } // namespace cel::runtime_internal #endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_ISSUE_COLLECTOR_H_ ================================================ FILE: runtime/internal/issue_collector_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "runtime/internal/issue_collector.h"

#include "absl/status/status.h"
#include "internal/testing.h"
#include "runtime/runtime_issue.h"

namespace cel::runtime_internal {
namespace {

using ::absl_testing::StatusIs;
using ::testing::ElementsAre;
using ::testing::Truly;

// Evaluates matcher `m` against `t` and returns whether it matched. This lets
// gMock matchers be used inside the plain-bool predicates passed to Truly().
// NOTE(review): the template parameter list and the static_cast target type
// are missing in this dump (likely stripped during extraction).
template
bool ApplyMatcher(Matcher m, const T& t) {
  return static_cast>(m).Matches(t);
}

// With a kError limit, the error comes back as a failing status while the
// warning is accepted (OK). Both issues are still recorded, in insertion
// order. An error created without an explicit code reports kOther.
TEST(IssueCollector, CollectsIssues) {
  IssueCollector issue_collector(RuntimeIssue::Severity::kError);
  EXPECT_THAT(issue_collector.AddIssue(
                  RuntimeIssue::CreateError(absl::InvalidArgumentError("e1"))),
              StatusIs(absl::StatusCode::kInvalidArgument, "e1"));
  ASSERT_OK(issue_collector.AddIssue(RuntimeIssue::CreateWarning(
      absl::InvalidArgumentError("w1"),
      RuntimeIssue::ErrorCode::kNoMatchingOverload)));

  EXPECT_THAT(
      issue_collector.issues(),
      ElementsAre(
          Truly([](const RuntimeIssue& issue) {
            return issue.severity() == RuntimeIssue::Severity::kError &&
                   issue.error_code() == RuntimeIssue::ErrorCode::kOther &&
                   ApplyMatcher(
                       StatusIs(absl::StatusCode::kInvalidArgument, "e1"),
                       issue.ToStatus());
          }),
          Truly([](const RuntimeIssue& issue) {
            return issue.severity() == RuntimeIssue::Severity::kWarning &&
                   issue.error_code() ==
                       RuntimeIssue::ErrorCode::kNoMatchingOverload &&
                   ApplyMatcher(
                       StatusIs(absl::StatusCode::kInvalidArgument, "w1"),
                       issue.ToStatus());
          })));
}

// With a kWarning limit, the limit is inclusive: both the error and the
// warning are at or above it, so both AddIssue calls return the issue's status.
TEST(IssueCollector, ReturnsStatusAtLimit) {
  IssueCollector issue_collector(RuntimeIssue::Severity::kWarning);

  EXPECT_THAT(issue_collector.AddIssue(
                  RuntimeIssue::CreateError(absl::InvalidArgumentError("e1"))),
              StatusIs(absl::StatusCode::kInvalidArgument, "e1"));

  EXPECT_THAT(issue_collector.AddIssue(RuntimeIssue::CreateWarning(
                  absl::InvalidArgumentError("w1"),
                  RuntimeIssue::ErrorCode::kNoMatchingOverload)),
              StatusIs(absl::StatusCode::kInvalidArgument, "w1"));

  EXPECT_THAT(
      issue_collector.issues(),
      ElementsAre(
          Truly([](const RuntimeIssue& issue) {
            return issue.severity() == RuntimeIssue::Severity::kError &&
                   issue.error_code() == RuntimeIssue::ErrorCode::kOther &&
                   ApplyMatcher(
                       StatusIs(absl::StatusCode::kInvalidArgument, "e1"),
                       issue.ToStatus());
          }),
          Truly([](const RuntimeIssue& issue) {
            return issue.severity() == RuntimeIssue::Severity::kWarning &&
                   issue.error_code() ==
                       RuntimeIssue::ErrorCode::kNoMatchingOverload &&
                   ApplyMatcher(
                       StatusIs(absl::StatusCode::kInvalidArgument, "w1"),
                       issue.ToStatus());
          })));
}

}  // namespace
}  // namespace cel::runtime_internal

================================================
FILE: runtime/internal/legacy_runtime_type_provider.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_LEGACY_RUNTIME_TYPE_PROVIDER_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_LEGACY_RUNTIME_TYPE_PROVIDER_H_ #include "absl/base/nullability.h" #include "eval/public/structs/protobuf_descriptor_type_provider.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel::runtime_internal { class LegacyRuntimeTypeProvider final : public google::api::expr::runtime::ProtobufDescriptorProvider { public: LegacyRuntimeTypeProvider( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nullable message_factory) : google::api::expr::runtime::ProtobufDescriptorProvider( descriptor_pool, message_factory) {} }; } // namespace cel::runtime_internal #endif // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_LEGACY_RUNTIME_TYPE_PROVIDER_H_ ================================================ FILE: runtime/internal/runtime_env.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "runtime/internal/runtime_env.h"

// NOTE(review): the <...> targets of these three includes, and the template
// arguments of std::shared_ptr / std::make_shared / NoopDeleteFor below, are
// missing in this dump (likely stripped during extraction).
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/synchronization/mutex.h"
#include "internal/noop_delete.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/message.h"

namespace cel::runtime_internal {

// Pops keep-alives one at a time from the back, so they are released in the
// reverse order of registration. The C++ standard does not specify the order
// in which a container destroys its elements, hence the explicit loop.
RuntimeEnv::KeepAlives::~KeepAlives() {
  while (!deque.empty()) {
    deque.pop_back();
  }
}

// Returns the environment's message factory, creating it on first use.
//
// Fast path: one atomic load of `message_factory_ptr`.
// Slow path: lock `message_factory_mutex`, re-check, and if the factory is
// still missing, create it:
//   * generated descriptor pool: use the generated message factory, wrapped
//     with NoopDeleteFor so the shared_ptr does not delete it;
//   * any other pool: create a new factory that does not delegate to the
//     generated factory.
// The pointer is then published through `message_factory_ptr`.
//
// The fast-path load must use acquire ordering so that it synchronizes with
// the publishing store below. A thread that observes a non-null pointer is
// then guaranteed to also observe the fully constructed factory. A relaxed
// load only guarantees the pointer value itself, so dereferencing the result
// would race with the factory's construction on another thread.
google::protobuf::MessageFactory* absl_nonnull RuntimeEnv::MutableMessageFactory()
    const {
  google::protobuf::MessageFactory* absl_nullable shared_message_factory =
      message_factory_ptr.load(std::memory_order_acquire);
  if (shared_message_factory != nullptr) {
    return shared_message_factory;
  }
  absl::MutexLock lock(message_factory_mutex);
  // Relaxed is enough here. Every slow-path store happens while holding this
  // mutex, so acquiring the lock already orders this load after it.
  shared_message_factory = message_factory_ptr.load(std::memory_order_relaxed);
  if (shared_message_factory == nullptr) {
    if (descriptor_pool.get() == google::protobuf::DescriptorPool::generated_pool()) {
      // Using the generated descriptor pool, just use the generated message
      // factory.
      message_factory = std::shared_ptr(
          google::protobuf::MessageFactory::generated_factory(),
          internal::NoopDeleteFor());
    } else {
      auto dynamic_message_factory = std::make_shared();
      // Ensure we do not delegate to the generated factory, if the default
      // ever changes. We prefer being hermetic.
      dynamic_message_factory->SetDelegateToGeneratedFactory(false);
      message_factory = std::move(dynamic_message_factory);
    }
    shared_message_factory = message_factory.get();
    // seq_cst includes release semantics, pairing with the acquire load on
    // the fast path above.
    message_factory_ptr.store(shared_message_factory, std::memory_order_seq_cst);
  }
  return shared_message_factory;
}

// Registers `keep_alive` with this environment; null pointers are ignored.
void RuntimeEnv::KeepAlive(std::shared_ptr keep_alive) {
  if (keep_alive == nullptr) {
    return;
  }
  keep_alives.deque.push_back(std::move(keep_alive));
}

}  // namespace cel::runtime_internal

================================================
FILE: runtime/internal/runtime_env.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_H_

// NOTE(review): the <...> targets of these four includes are missing in this
// dump (likely stripped during extraction).
#include
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/base/thread_annotations.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_type_registry.h"
#include "internal/well_known_types.h"
#include "runtime/function_registry.h"
#include "runtime/type_registry.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::runtime_internal {

// Shared state used by the runtime during creation, configuration, planning,
// and evaluation. Passed around via `std::shared_ptr`.
//
// TODO(uncreated-issue/66): Make this a class.
//
// NOTE(review): template arguments (std::shared_ptr, std::atomic, std::deque)
// are missing in this dump (likely stripped during extraction).
struct RuntimeEnv final {
  // `message_factory` may be null. In that case one is created lazily by
  // MutableMessageFactory(). A non-null factory is published immediately so
  // the lazy path is never taken.
  // NOTE(review): `legacy_type_registry` receives
  // `this->message_factory.get()` at construction time, so with a null
  // factory the legacy registry does not see the lazily created one. Confirm
  // this is intended.
  explicit RuntimeEnv(absl_nonnull std::shared_ptr descriptor_pool,
                      absl_nullable std::shared_ptr message_factory = nullptr)
      : descriptor_pool(std::move(descriptor_pool)),
        message_factory(std::move(message_factory)),
        legacy_type_registry(this->descriptor_pool.get(),
                             this->message_factory.get()),
        type_registry(legacy_type_registry.InternalGetModernRegistry()),
        function_registry(legacy_function_registry.InternalGetRegistry()) {
    if (this->message_factory != nullptr) {
      message_factory_ptr.store(this->message_factory.get(),
                                std::memory_order_seq_cst);
    }
  }

  // Not copyable or moveable.
  RuntimeEnv(const RuntimeEnv&) = delete;
  RuntimeEnv(RuntimeEnv&&) = delete;
  RuntimeEnv& operator=(const RuntimeEnv&) = delete;
  RuntimeEnv& operator=(RuntimeEnv&&) = delete;

  // Ideally the environment would already be initialized, but things are a bit
  // awkward. This should only be called once immediately after construction.
  // Initializes the well-known-type reflection from `descriptor_pool`.
  absl::Status Initialize() {
    return well_known_types.Initialize(descriptor_pool.get());
  }

  // True once Initialize() has successfully set up well-known-type reflection.
  bool IsInitialized() const { return well_known_types.IsInitialized(); }

  ABSL_ATTRIBUTE_UNUSED
  const absl_nonnull std::shared_ptr descriptor_pool;

 private:
  // These fields deal with a message factory that is lazily initialized as
  // needed. This might be called during the planning phase of an expression or
  // during evaluation. We want the ability to get the message factory when it
  // is already created to be cheap, so we use an atomic and a mutex for the
  // slow path.
  //
  // Do not access any of these fields directly, use member functions.
  mutable absl::Mutex message_factory_mutex;
  mutable absl_nullable std::shared_ptr message_factory
      ABSL_GUARDED_BY(message_factory_mutex);
  // std::atomic<std::shared_ptr<...>> is not really a simple atomic, so we
  // avoid it and publish the raw pointer instead.
  mutable std::atomic message_factory_ptr = nullptr;

  // Owns the objects registered through KeepAlive(). Its destructor releases
  // them in reverse registration order.
  struct KeepAlives final {
    KeepAlives() = default;

    ~KeepAlives();

    // Not copyable or moveable.
    KeepAlives(const KeepAlives&) = delete;
    KeepAlives(KeepAlives&&) = delete;
    KeepAlives& operator=(const KeepAlives&) = delete;
    KeepAlives& operator=(KeepAlives&&) = delete;

    std::deque> deque;
  };

  // Declared before the registries below. Members are destroyed in reverse
  // declaration order, so the keep-alives outlive those registries.
  KeepAlives keep_alives;

 public:
  // Because of legacy shenanigans, we use shared_ptr here. For legacy, this is
  // an unowned shared_ptr (a noop deleter) pointing to the modern equivalent
  // which is a member of the legacy variant.
  // NOTE(review): this comment looks stale. `type_registry` and
  // `function_registry` are plain references into the legacy registries
  // (bound in the constructor), not shared_ptrs.
  google::api::expr::runtime::CelTypeRegistry legacy_type_registry;
  google::api::expr::runtime::CelFunctionRegistry legacy_function_registry;
  TypeRegistry& type_registry;
  FunctionRegistry& function_registry;

  well_known_types::Reflection well_known_types;

  // Returns a non-null message factory: the one supplied at construction, or
  // one created on first call (creation is serialized by
  // `message_factory_mutex`).
  google::protobuf::MessageFactory* absl_nonnull MutableMessageFactory() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND;

  // Not thread safe. Adds `keep_alive` to a list owned by this environment
  // and ensures it survives at least as long as this environment. Keep alives
  // are released in reverse order of their registration. This mimics normal
  // destructor rules of members. Null pointers are ignored.
  //
  // IMPORTANT: This should only be called when building the runtime, and not
  // after.
  void KeepAlive(std::shared_ptr keep_alive);
};

}  // namespace cel::runtime_internal

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_H_

================================================
FILE: runtime/internal/runtime_env_testing.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/internal/runtime_env_testing.h"

// NOTE(review): the <...> target of this include, and the template arguments
// of std::shared_ptr / std::make_shared / NoopDeleteFor below, are missing in
// this dump (likely stripped during extraction).
#include

#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "internal/noop_delete.h"
#include "internal/testing_descriptor_pool.h"
#include "internal/testing_message_factory.h"
#include "runtime/internal/runtime_env.h"
#include "google/protobuf/message.h"

namespace cel::runtime_internal {

// Builds a runtime environment (presumably RuntimeEnv; the make_shared target
// is stripped here) from the shared testing descriptor pool and the testing
// message factory. It then initializes the environment and returns it.
//
// The message factory is wrapped with internal::NoopDeleteFor, so (per its
// name) the shared_ptr does not delete it. A failed Initialize() aborts via
// ABSL_CHECK_OK, which is acceptable for this test-only helper.
absl_nonnull std::shared_ptr NewTestingRuntimeEnv() {
  auto env = std::make_shared(
      internal::GetSharedTestingDescriptorPool(),
      std::shared_ptr(
          internal::GetTestingMessageFactory(),
          internal::NoopDeleteFor()));
  ABSL_CHECK_OK(env->Initialize());  // Crash OK
  return env;
}

}  // namespace cel::runtime_internal

================================================
FILE: runtime/internal/runtime_env_testing.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_TESTING_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_TESTING_H_

// NOTE(review): the <...> target of this include and the shared_ptr template
// argument below are missing in this dump (likely stripped during extraction).
#include

#include "absl/base/nullability.h"
#include "runtime/internal/runtime_env.h"

namespace cel::runtime_internal {

// Returns a new, initialized environment backed by the testing descriptor pool
// and testing message factory. Intended for tests only.
absl_nonnull std::shared_ptr NewTestingRuntimeEnv();

}  // namespace cel::runtime_internal

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_ENV_TESTING_H_

================================================
FILE: runtime/internal/runtime_friend_access.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_FRIEND_ACCESS_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_FRIEND_ACCESS_H_

#include "common/native_type.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"

namespace cel::runtime_internal {

// Provide accessors for friend-visibility internal runtime details.
//
// CEL supported runtime extensions need implementation specific details to work
// correctly. We restrict access to prevent external usages since we don't
// guarantee stability on the implementation details.
class RuntimeFriendAccess {
 public:
  // Access underlying runtime instance.
  static Runtime& GetMutableRuntime(RuntimeBuilder& builder) {
    return builder.runtime();
  }

  // Return the internal type_id for the runtime instance for checked down
  // casting.
  static NativeTypeId RuntimeTypeId(Runtime& runtime) {
    return runtime.GetNativeTypeId();
  }
};

}  // namespace cel::runtime_internal

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_FRIEND_ACCESS_H_

================================================
FILE: runtime/internal/runtime_impl.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/internal/runtime_impl.h"

// NOTE(review): the <...> targets of these two includes, and several template
// arguments below (shared_ptr, StatusOr, unique_ptr, make_unique, down_cast,
// dynamic_cast, NativeTypeId::For), are missing in this dump (likely stripped
// during extraction).
#include
#include

#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/statusor.h"
#include "base/ast.h"
#include "base/type_provider.h"
#include "common/native_type.h"
#include "common/value.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/comprehension_slots.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "internal/casts.h"
#include "internal/status_macros.h"
#include "runtime/activation_interface.h"
#include "runtime/runtime.h"
#include "google/protobuf/arena.h"

namespace cel::runtime_internal {
namespace {

using ::google::api::expr::runtime::AttributeTrail;
using ::google::api::expr::runtime::ComprehensionSlots;
using ::google::api::expr::runtime::DirectExpressionStep;
using ::google::api::expr::runtime::ExecutionFrameBase;
using ::google::api::expr::runtime::FlatExpression;
using ::google::api::expr::runtime::WrappedDirectStep;

// General-purpose program. Each TraceImpl call builds a fresh evaluator state
// for the planned FlatExpression and evaluates it with the listener callback.
class ProgramImpl final : public TraceableProgram {
 public:
  using EvaluationListener = TraceableProgram::EvaluationListener;
  ProgramImpl(
      const std::shared_ptr& environment,
      FlatExpression impl)
      : environment_(environment), impl_(std::move(impl)) {}

  absl::StatusOr TraceImpl(
      const ActivationInterface& activation,
      EvaluationListener evaluation_listener,
      google::protobuf::Arena* absl_nonnull arena,
      const EvaluateOptions& options) const override {
    ABSL_DCHECK(arena != nullptr);
    // Prefer the caller-supplied message factory; otherwise fall back to the
    // environment's (lazily created) factory.
    auto state = impl_.MakeEvaluatorState(environment_->descriptor_pool.get(),
                                          options.message_factory != nullptr
                                              ? options.message_factory
                                              : environment_->MutableMessageFactory(),
                                          arena);
    return impl_.EvaluateWithCallback(activation, options.embedder_context,
                                      std::move(evaluation_listener), state);
  }

  const TypeProvider& GetTypeProvider() const override {
    return environment_->type_registry.GetComposedTypeProvider();
  }

 private:
  // Keep the Runtime environment alive while programs reference it.
  std::shared_ptr environment_;
  FlatExpression impl_;
};

// Fast path for fully recursive plans. It evaluates the single root
// DirectExpressionStep against a stack-allocated ExecutionFrameBase and skips
// the evaluator-state setup that ProgramImpl performs.
class RecursiveProgramImpl final : public TraceableProgram {
 public:
  using EvaluationListener = TraceableProgram::EvaluationListener;
  // `root` is a non-owning pointer to a step owned by `impl`.
  RecursiveProgramImpl(
      const std::shared_ptr& environment,
      FlatExpression impl, const DirectExpressionStep* absl_nonnull root)
      : environment_(environment), impl_(std::move(impl)), root_(root) {}

  absl::StatusOr TraceImpl(
      const ActivationInterface& activation,
      EvaluationListener evaluation_listener,
      google::protobuf::Arena* absl_nonnull arena,
      const EvaluateOptions& options) const override {
    ABSL_DCHECK(arena != nullptr);
    ComprehensionSlots slots(impl_.comprehension_slots_size());
    // Same message-factory fallback as ProgramImpl::TraceImpl.
    ExecutionFrameBase frame(activation, std::move(evaluation_listener),
                             impl_.options(), GetTypeProvider(),
                             environment_->descriptor_pool.get(),
                             options.message_factory != nullptr
                                 ? options.message_factory
                                 : environment_->MutableMessageFactory(),
                             arena, options.embedder_context, slots);

    Value result;
    AttributeTrail attribute;
    CEL_RETURN_IF_ERROR(root_->Evaluate(frame, result, attribute));

    return result;
  }

  const TypeProvider& GetTypeProvider() const override {
    return environment_->type_registry.GetComposedTypeProvider();
  }

 private:
  // Keep the Runtime environment alive while programs reference it.
  std::shared_ptr environment_;
  FlatExpression impl_;
  const DirectExpressionStep* absl_nonnull root_;
};

}  // namespace

// Forwards to CreateTraceableProgram; a TraceableProgram is also a Program.
absl::StatusOr> RuntimeImpl::CreateProgram(
    std::unique_ptr ast,
    const Runtime::CreateProgramOptions& options) const {
  return CreateTraceableProgram(std::move(ast), options);
}

// Plans `ast` into a FlatExpression. Returns the recursive program when the
// plan qualifies (see below), and the general ProgramImpl path otherwise.
absl::StatusOr>
RuntimeImpl::CreateTraceableProgram(
    std::unique_ptr ast,
    const Runtime::CreateProgramOptions& options) const {
  CEL_ASSIGN_OR_RETURN(auto flat_expr, expr_builder_.CreateExpressionImpl(
                                           std::move(ast), options.issues));

  // Special case if the program is fully recursive.
  //
  // This implementation avoids unnecessary allocs at evaluation time which
  // improves performance notably for small expressions.
  if (expr_builder_.options().max_recursion_depth != 0 &&
      !flat_expr.subexpressions().empty() &&
      // mainline expression is exactly one recursive step.
      flat_expr.subexpressions().front().size() == 1 &&
      flat_expr.subexpressions().front().front()->GetNativeTypeId() ==
          NativeTypeId::For()) {
    // `root` points into a step owned by `flat_expr`. It presumably stays
    // valid after `flat_expr` is moved into the program because the steps
    // are held through smart pointers (note the `.get()`). Confirm.
    const DirectExpressionStep* root =
        internal::down_cast(
            flat_expr.subexpressions().front().front().get())
            ->wrapped();
    return std::make_unique(environment_,
                                                  std::move(flat_expr), root);
  }

  return std::make_unique(environment_, std::move(flat_expr));
}

// Returns whether `program` was planned recursively, via a dynamic_cast.
// NOTE(review): the cast target is stripped in this dump; presumably
// RecursiveProgramImpl.
bool TestOnly_IsRecursiveImpl(const Program* program) {
  return dynamic_cast(program) != nullptr;
}

}  // namespace cel::runtime_internal

================================================
FILE: runtime/internal/runtime_impl.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_IMPL_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_IMPL_H_

// NOTE(review): the <...> targets of these two includes, and template
// arguments below (shared_ptr, StatusOr, unique_ptr, NativeTypeId::For), are
// missing in this dump (likely stripped during extraction).
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/statusor.h"
#include "base/ast.h"
#include "base/type_provider.h"
#include "common/native_type.h"
#include "eval/compiler/flat_expr_builder.h"
#include "internal/well_known_types.h"
#include "runtime/function_registry.h"
#include "runtime/internal/runtime_env.h"
#include "runtime/runtime.h"
#include "runtime/runtime_options.h"
#include "runtime/type_registry.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::runtime_internal {

// Runtime implementation backed by a shared RuntimeEnv (registries, descriptor
// pool, message factory) and a FlatExprBuilder that plans ASTs into programs.
class RuntimeImpl : public Runtime {
 public:
  using Environment = RuntimeEnv;

  // `environment` must already be initialized (checked only in debug builds).
  // `expr_builder_` is constructed from the member `environment_`, which is
  // declared first and therefore already holds the moved-in environment.
  RuntimeImpl(absl_nonnull std::shared_ptr environment,
              const RuntimeOptions& options)
      : environment_(std::move(environment)),
        expr_builder_(environment_, options) {
    ABSL_DCHECK(environment_->well_known_types.IsInitialized());
  }

  TypeRegistry& type_registry() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return environment_->type_registry;
  }
  const TypeRegistry& type_registry() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return environment_->type_registry;
  }

  FunctionRegistry& function_registry() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return environment_->function_registry;
  }
  const FunctionRegistry& function_registry() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return environment_->function_registry;
  }

  const well_known_types::Reflection& well_known_types() const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return environment_->well_known_types;
  }

  Environment& environment() ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return *environment_;
  }
  const Environment& environment() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return *environment_;
  }

  // implement Runtime
  absl::StatusOr> CreateProgram(
      std::unique_ptr ast,
      const Runtime::CreateProgramOptions& options) const final;

  absl::StatusOr> CreateTraceableProgram(
      std::unique_ptr ast,
      const Runtime::CreateProgramOptions& options) const override;

  const TypeProvider& GetTypeProvider() const override {
    return environment_->type_registry.GetComposedTypeProvider();
  }

  const google::protobuf::DescriptorPool* absl_nonnull GetDescriptorPool()
      const override {
    return environment_->descriptor_pool.get();
  }

  // May create the environment's message factory on first use.
  google::protobuf::MessageFactory* absl_nonnull GetMessageFactory()
      const override {
    return environment_->MutableMessageFactory();
  }

  // exposed for extensions access
  google::api::expr::runtime::FlatExprBuilder& expr_builder()
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return expr_builder_;
  }

 private:
  NativeTypeId GetNativeTypeId() const override {
    return NativeTypeId::For();
  }

  // Note: this is mutable, but should only be accessed in a const context after
  // building is complete.
  //
  // This is used to keep alive the registries while programs reference them.
  std::shared_ptr environment_;
  google::api::expr::runtime::FlatExprBuilder expr_builder_;
};

// Exposed for testing to validate program is recursively planned.
//
// Uses dynamic_casts to test.
bool TestOnly_IsRecursiveImpl(const Program* program);

}  // namespace cel::runtime_internal

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_IMPL_H_

================================================
FILE: runtime/internal/runtime_type_provider.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/internal/runtime_type_provider.h"

// NOTE(review): the <...> target of this include, and the StatusOr template
// arguments on the signatures below, are missing in this dump (likely
// stripped during extraction).
#include

#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/type.h"
#include "common/type_introspector.h"
#include "common/value.h"
#include "common/values/value_builder.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::runtime_internal {

// Adds an opaque type keyed by its name. Fails with AlreadyExists if that
// name is already registered, and leaves the existing entry untouched.
absl::Status RuntimeTypeProvider::RegisterType(const OpaqueType& type) {
  auto [existing, inserted] = types_.insert(std::pair{type.name(), Type(type)});
  if (inserted) {
    return absl::OkStatus();
  }
  return absl::AlreadyExistsError(
      absl::StrCat("type already registered: ", existing->first));
}

// Resolves `name` in priority order: well-known types first, then message
// types in the descriptor pool, then opaque types added via RegisterType.
absl::StatusOr> RuntimeTypeProvider::FindTypeImpl(
    absl::string_view name) const {
  if (auto well_known = FindWellKnownType(name); well_known.has_value()) {
    return well_known;
  }
  if (const auto* message_desc = descriptor_pool_->FindMessageTypeByName(name);
      message_desc != nullptr) {
    return MessageType(message_desc);
  }
  const auto registered = types_.find(name);
  if (registered == types_.end()) {
    return absl::nullopt;
  }
  return registered->second;
}

// Resolves enum constant `value` of enum `type`. Well-known enums are checked
// first, then enums in the descriptor pool. An unknown type or unknown value
// name yields an empty optional rather than an error.
absl::StatusOr>
RuntimeTypeProvider::FindEnumConstantImpl(absl::string_view type,
                                          absl::string_view value) const {
  if (auto well_known = FindWellKnownTypeEnumConstant(type, value);
      well_known.has_value()) {
    return well_known;
  }

  const google::protobuf::EnumDescriptor* enum_descriptor =
      descriptor_pool_->FindEnumTypeByName(type);
  if (enum_descriptor == nullptr) {
    return absl::nullopt;
  }

  // Strong enum typing is not supported, so only fully qualified enum values
  // are meaningful. Finding the enum type but not the value name is therefore
  // reported as "not found", with no further signal.
  const google::protobuf::EnumValueDescriptor* value_descriptor =
      enum_descriptor->FindValueByName(value);
  if (value_descriptor == nullptr) {
    return absl::nullopt;
  }

  return TypeIntrospector::EnumConstant{
      EnumType(enum_descriptor), enum_descriptor->full_name(),
      value_descriptor->name(), value_descriptor->number()};
}

// Resolves field `name` on struct type `type`. Well-known types are checked
// first. For messages in the descriptor pool, a regular field takes priority
// over an extension looked up by its printable name.
absl::StatusOr>
RuntimeTypeProvider::FindStructTypeFieldByNameImpl(
    absl::string_view type, absl::string_view name) const {
  if (auto well_known = FindWellKnownTypeFieldByName(type, name);
      well_known.has_value()) {
    return well_known;
  }

  const auto* message_desc = descriptor_pool_->FindMessageTypeByName(type);
  if (message_desc == nullptr) {
    return absl::nullopt;
  }

  const auto* field_desc = message_desc->FindFieldByName(name);
  if (field_desc == nullptr) {
    field_desc =
        descriptor_pool_->FindExtensionByPrintableName(message_desc, name);
  }
  if (field_desc == nullptr) {
    return absl::nullopt;
  }
  return MessageTypeField(field_desc);
}

// Delegates to the shared value-builder factory, using this provider's
// descriptor pool.
absl::StatusOr RuntimeTypeProvider::NewValueBuilder(
    absl::string_view name,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) const {
  return common_internal::NewValueBuilder(arena, descriptor_pool_,
                                          message_factory, name);
}

}  // namespace cel::runtime_internal

================================================
FILE: runtime/internal/runtime_type_provider.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_TYPE_PROVIDER_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_TYPE_PROVIDER_H_

#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/type.h"
#include "common/type_reflector.h"
#include "common/value.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel::runtime_internal {

// TypeReflector that resolves types, enum constants and struct fields from
// well-known types, then from a protobuf descriptor pool, and finally (for
// type lookup) from opaque types registered with RegisterType().
//
// NOTE(review): template arguments (StatusOr, flat_hash_map) in this header
// are missing in this dump (likely stripped during extraction).
class RuntimeTypeProvider final : public TypeReflector {
 public:
  // The pool is held by raw pointer and not owned; presumably it must outlive
  // this provider. Confirm with callers.
  explicit RuntimeTypeProvider(
      const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool)
      : descriptor_pool_(descriptor_pool) {}

  // Registers an opaque type by name; AlreadyExists if the name is taken.
  absl::Status RegisterType(const OpaqueType& type);

  absl::StatusOr NewValueBuilder(
      absl::string_view name,
      google::protobuf::MessageFactory* absl_nonnull message_factory,
      google::protobuf::Arena* absl_nonnull arena) const override;

 protected:
  absl::StatusOr> FindTypeImpl(
      absl::string_view name) const override;

  absl::StatusOr> FindEnumConstantImpl(
      absl::string_view type, absl::string_view value) const override;

  absl::StatusOr> FindStructTypeFieldByNameImpl(
      absl::string_view type, absl::string_view name) const override;

 private:
  const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_;
  // Opaque types added via RegisterType(), keyed by type name.
  absl::flat_hash_map types_;
};

}  // namespace cel::runtime_internal

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_INTERNAL_RUNTIME_TYPE_PROVIDER_H_

================================================
FILE: runtime/memory_safety_test.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Tests for memory safety using the CEL Evaluator. #include #include #include #include #include #include #include "google/protobuf/any.pb.h" #include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" #include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "checker/validation_result.h" #include "common/decl.h" #include "common/type.h" #include "common/value.h" #include "common/value_testing.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" #include "compiler/optional.h" #include "compiler/standard_library.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "runtime/activation.h" #include "runtime/constant_folding.h" #include "runtime/function_adapter.h" #include "runtime/reference_resolver.h" #include "runtime/regex_precompilation.h" #include "runtime/runtime.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" #include "google/protobuf/util/message_differencer.h" namespace cel { namespace { using ::absl_testing::IsOkAndHolds; using 
::cel::expr::conformance::proto3::NestedTestAllTypes;
using ::cel::expr::conformance::proto3::TestAllTypes;
using ::cel::test::StringValueIs;
using ::cel::test::ValueMatcher;
using ::google::protobuf::Any;
using ::testing::Not;

// One parameterized expression test. It holds the CEL source to compile, the
// variables to bind, and a matcher for the expected evaluation result.
struct TestCase {
  std::string name;
  std::string expression;
  // Each value is either a plain Value or a message. InitActivation wraps
  // messages in place, without copying them.
  absl::flat_hash_map> activation;
  test::ValueMatcher expected_matcher;
  // Forwarded to ConfigureRuntimeImpl as `resolve_references`. That enables
  // qualified type identifiers and the reference resolver.
  bool reference_resolver_enabled = false;
};

// Runtime configurations that every test case is evaluated under.
enum Options { kDefault, kExhaustive, kFoldConstants };

using ParamType = std::tuple;

// Builds the compiler the tests use. It has:
// - the standard and optional libraries;
// - the cel.expr.conformance.proto3 container;
// - declarations for every variable and function the test expressions
//   reference.
absl::StatusOr> CreateCompiler() {
  google::protobuf::LinkMessageReflection();
  google::protobuf::LinkMessageReflection<
      cel::expr::conformance::proto3::NestedTestAllTypes>();
  CEL_ASSIGN_OR_RETURN(
      std::unique_ptr b,
      NewCompilerBuilder(google::protobuf::DescriptorPool::generated_pool()));
  CEL_RETURN_IF_ERROR(b->AddLibrary(StandardCompilerLibrary()));
  CEL_RETURN_IF_ERROR(b->AddLibrary(OptionalCompilerLibrary()));
  b->GetCheckerBuilder().set_container("cel.expr.conformance.proto3");
  auto& cb = b->GetCheckerBuilder();
  CEL_RETURN_IF_ERROR(cb.AddVariable(MakeVariableDecl("bool_var", BoolType())));
  CEL_RETURN_IF_ERROR(
      cb.AddVariable(MakeVariableDecl("string_var", StringType())));
  CEL_RETURN_IF_ERROR(
      cb.AddVariable(MakeVariableDecl("condition", BoolType())));
  CEL_RETURN_IF_ERROR(cb.AddVariable(MakeVariableDecl(
      "nested_test_all_types", MessageType(NestedTestAllTypes::descriptor()))));
  // Declarations only. The implementations are registered with the runtime in
  // ConfigureRuntimeImpl.
  CEL_RETURN_IF_ERROR(cb.AddFunction(
      MakeFunctionDecl("IsPrivate",
                       MakeOverloadDecl("IsPrivate_string", BoolType(),
                                        StringType()))
          .value()));
  CEL_RETURN_IF_ERROR(cb.AddFunction(
      MakeFunctionDecl(
          "net.IsPrivate",
          MakeOverloadDecl("net_IsPrivate_string", BoolType(), StringType()))
          .value()));
  return b->Build();
}

// Returns a process-wide compiler that is built on first use.
//
// The compiler is released from its unique_ptr and never freed, so it stays
// valid for the whole test binary. Construction failure aborts via
// ABSL_QCHECK_OK.
const Compiler& GetCompiler() {
  static const Compiler* compiler = []() {
    auto compiler = CreateCompiler();
    ABSL_QCHECK_OK(compiler.status());
    return compiler->release();
  }();
  return *compiler;
}

// gtest name generator producing "<case name>_<default|exhaustive|opt>".
std::string TestCaseName(const testing::TestParamInfo& param_info) {
  const ParamType& param = param_info.param;
  absl::string_view opt;
  switch (std::get<1>(param)) {
    case Options::kDefault:
      opt = "default";
      break;
    case Options::kExhaustive:
      opt = "exhaustive";
      break;
    case Options::kFoldConstants:
      opt = "opt";
      break;
  }

  return absl::StrCat(std::get<0>(param).name, "_", opt);
}

// Toy implementation shared by the IsPrivate and net.IsPrivate extension
// functions. It is a plain prefix match on "192.168." and "10.".
bool IsPrivateIpv4Impl(const StringValue& addr) {
  // Implementation for demonstration, this is simple but incomplete and
  // brittle.
  std::string buf;
  return absl::StartsWith(addr.ToStringView(&buf), "192.168.") ||
         absl::StartsWith(addr.ToStringView(&buf), "10.");
}

// Creates a standard runtime for the requested evaluation mode:
//   kDefault       - short-circuiting.
//   kExhaustive    - short-circuiting disabled.
//   kFoldConstants - short-circuiting and comprehension list append, plus the
//                    constant-folding and regex-precompilation extensions.
// `resolve_references` enables qualified type identifiers and the reference
// resolver. IsPrivate and net.IsPrivate are registered in every mode.
absl::StatusOr> ConfigureRuntimeImpl(
    bool resolve_references, Options evaluation_options) {
  RuntimeOptions options;
  switch (evaluation_options) {
    case Options::kDefault:
      options.short_circuiting = true;
      break;
    case Options::kExhaustive:
      options.short_circuiting = false;
      break;
    case Options::kFoldConstants:
      options.enable_comprehension_list_append = true;
      options.short_circuiting = true;
      break;
  }
  options.enable_qualified_type_identifiers = resolve_references;
  options.container = "cel.expr.conformance.proto3";

  CEL_ASSIGN_OR_RETURN(cel::RuntimeBuilder runtime_builder,
                       CreateStandardRuntimeBuilder(
                           google::protobuf::DescriptorPool::generated_pool(),
                           options));

  if (resolve_references) {
    CEL_RETURN_IF_ERROR(EnableReferenceResolver(
        runtime_builder, ReferenceResolverEnabled::kAlways));
  }

  if (evaluation_options == Options::kFoldConstants) {
    CEL_RETURN_IF_ERROR(extensions::EnableConstantFolding(runtime_builder));
    CEL_RETURN_IF_ERROR(extensions::EnableRegexPrecompilation(runtime_builder));
  }

  auto s = UnaryFunctionAdapter::Register(
      "IsPrivate", false, &IsPrivateIpv4Impl,
      runtime_builder.function_registry());
  CEL_RETURN_IF_ERROR(s);
  s.Update(UnaryFunctionAdapter::Register(
      "net.IsPrivate", false, &IsPrivateIpv4Impl,
      runtime_builder.function_registry()));
  CEL_RETURN_IF_ERROR(s);

  return std::move(runtime_builder).Build();
}

// Fixture parameterized over (TestCase, Options).
class EvaluatorMemorySafetyTest : public testing::TestWithParam {
 public:
  EvaluatorMemorySafetyTest() = default;

 protected:
  const TestCase& GetTestCase() { return std::get<0>(GetParam()); }

  // Builds a runtime from the current case's reference-resolver setting and
  // the parameterized evaluation mode.
  absl::StatusOr> ConfigureRuntime() {
    return ConfigureRuntimeImpl(GetTestCase().reference_resolver_enabled,
                                std::get<1>(GetParam()));
  }
};

// Binds every entry of `test_case.activation` into `activation`.
//
// Plain values are inserted directly. Messages are wrapped with
// WrapMessageUnsafe, so the resulting Value aliases the message stored inside
// `test_case`.
void InitActivation(const TestCase& test_case, google::protobuf::Arena& arena,
                    Activation& activation) {
  for (const auto& [key, value] : test_case.activation) {
    if (absl::holds_alternative(value)) {
      activation.InsertOrAssignValue(key, std::get(value));
    } else {
      // Note: This assumes that the TestCase is valid for the given TEST.
      // Changes to the activation map will invalidate the pointer to message
      // that gets wrapped here.
      activation.InsertOrAssignValue(
          key, Value::WrapMessageUnsafe(
                   &std::get(value),
                   google::protobuf::DescriptorPool::generated_pool(),
                   google::protobuf::MessageFactory::generated_factory(),
                   &arena));
    }
  }
}

// Compiles the case's expression, evaluates it with a freshly configured
// runtime, and checks the result.
TEST_P(EvaluatorMemorySafetyTest, Basic) {
  const auto& test_case = GetTestCase();
  ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, ConfigureRuntime());

  ASSERT_OK_AND_ASSIGN(ValidationResult validation,
                       GetCompiler().Compile(test_case.expression));
  ASSERT_TRUE(validation.IsValid()) << validation.FormatError();
  ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst());

  ASSERT_OK_AND_ASSIGN(std::unique_ptr program,
                       runtime->CreateProgram(std::move(ast)));

  Activation activation;
  google::protobuf::Arena arena;
  InitActivation(test_case, arena, activation);

  absl::StatusOr got = program->Evaluate(&arena, activation);

  EXPECT_THAT(got, IsOkAndHolds(test_case.expected_matcher));
}

// Same as Basic, but the runtime is destroyed (runtime.reset()) before the
// program is evaluated. The program must still produce the expected result.
TEST_P(EvaluatorMemorySafetyTest, ProgramSafeAfterRuntimeDestroyed) {
  const auto& test_case = GetTestCase();
  ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, ConfigureRuntime());

  ASSERT_OK_AND_ASSIGN(ValidationResult validation,
                       GetCompiler().Compile(test_case.expression));
  ASSERT_TRUE(validation.IsValid()) << validation.FormatError();
  ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst());

  ASSERT_OK_AND_ASSIGN(std::unique_ptr program,
                       runtime->CreateProgram(std::move(ast)));

  Activation activation;
  google::protobuf::Arena arena;
  InitActivation(test_case, arena, activation);

  runtime.reset();

  absl::StatusOr got = program->Evaluate(&arena, activation);

  EXPECT_THAT(got, IsOkAndHolds(test_case.expected_matcher));
}

// Helper for making an eternal string value without looking like a memory leak.
// The arena is a function-local NoDestructor, so it is never destroyed.
// Presumably StringValue::Wrap does not copy `str`; callers here pass string
// literals.
Value MakeStringValue(absl::string_view str) {
  static absl::NoDestructor kArena;
  return StringValue::Wrap(str, kArena.get());
}

// Parses `textproto` into a NestedTestAllTypes. CHECK-fails on malformed input.
NestedTestAllTypes MakeNestedTestAllTypes(absl::string_view textproto) {
  NestedTestAllTypes msg;
  ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(textproto, &msg));
  return msg;
}

// Matches a StructValue backed by a parsed message that equals, per
// MessageDifferencer::Equals, the textproto `expected`. The textproto is parsed
// as the same message type as the actual value. Struct values that are not
// parsed messages never match.
MATCHER_P(ParsedProtoStructEquals, expected, "") {
  const cel::StructValue& got = arg;

  if (!got.IsParsedMessage()) {
    return false;
  }

  auto& msg = got.GetParsedMessage();
  // New heap instance of the same message type to parse `expected` into.
  auto cmp = absl::WrapUnique(msg->New());
  if (!google::protobuf::TextFormat::ParseFromString(expected, cmp.get())) {
    *result_listener << "Failed to parse expected proto";
    return false;
  }

  return google::protobuf::util::MessageDifferencer::Equals(*msg, *cmp);
}

// Runs every TestCase under each of the three evaluation modes.
INSTANTIATE_TEST_SUITE_P(
    Expression, EvaluatorMemorySafetyTest,
    testing::Combine(
        testing::ValuesIn(std::vector{
            {
                "bool",
                "(true && false) || bool_var || string_var == 'test_str'",
                {{"bool_var", BoolValue(false)},
                 {"string_var", MakeStringValue("test_str")}},
                test::BoolValueIs(true),
            },
            {
                "const_str",
                "condition ? 'left_hand_string' : 'right_hand_string'",
                {{"condition", BoolValue(false)}},
                test::StringValueIs("right_hand_string"),
            },
            {
                "long_const_string",
                "condition ? 'left_hand_string' : "
                "'long_right_hand_string_0123456789'",
                {{"condition", BoolValue(false)}},
                test::StringValueIs("long_right_hand_string_0123456789"),
            },
            {
                "computed_string",
                "(condition ? 'a.b' : 'b.c') + '.d.e.f'",
                {{"condition", BoolValue(false)}},
                test::StringValueIs("b.c.d.e.f"),
            },
            {
                "regex",
                R"('192.168.128.64'.matches(r'^192\.168\.[0-2]?[0-9]?[0-9]\.[0-2]?[0-9]?[0-9]') )",
                {},
                test::BoolValueIs(true),
            },
            {
                "list_create",
                "[1, 2, 3, 4, 5, 6][3] == 4",
                {},
                test::BoolValueIs(true),
            },
            {
                "list_create_strings",
                "['1', '2', '3', '4', '5', '6'][2] == '3'",
                {},
                test::BoolValueIs(true),
            },
            {
                "map_create",
                "{'1': 'one', '2': 'two'}['2']",
                {},
                test::StringValueIs("two"),
            },
            {
                "struct_create",
                R"cel( NestedTestAllTypes{ child: NestedTestAllTypes{ payload: TestAllTypes{ repeated_int32: [1, 2, 3] } }, payload: TestAllTypes{ repeated_string: ["foo", "bar", "baz"] } })cel",
                {},
                test::StructValueIs(ParsedProtoStructEquals(R"pb( child { payload { repeated_int32: [ 1, 2, 3 ] } } payload { repeated_string: [ "foo", "bar", "baz" ] } )pb")),
            },
            {"extension_function", "IsPrivate('8.8.8.8')", {},
             test::BoolValueIs(false),
             /*reference_resolver_enabled=*/false},
            {"namespaced_function", "net.IsPrivate('192.168.0.1')", {},
             test::BoolValueIs(true),
             /*reference_resolver_enabled=*/true},
            {
                "comprehension",
                "['abc', 'def', 'ghi', 'jkl'].exists(el, el == 'mno')",
                {},
                test::BoolValueIs(false),
            },
            {
                "comprehension_complex",
                "['a' + 'b' + 'c', 'd' + 'ef', 'g' + 'hi', 'j' + 'kl']"
                ".exists(el, el.startsWith('g'))",
                {},
                test::BoolValueIs(true),
            },
            TestCase{
                "unsafe_message_access",
                "nested_test_all_types.child.payload",
                {{"nested_test_all_types",
                  MakeNestedTestAllTypes(R"pb(child { payload { single_int32: 1 } })pb")}},
                test::StructValueIs(
                    ParsedProtoStructEquals(R"pb(single_int32: 1)pb")),
            },
            TestCase{
                "unsafe_message_access_repeated_field",
                "nested_test_all_types.payload.repeated_int32.size() == 3",
                {{"nested_test_all_types",
                  MakeNestedTestAllTypes(R"pb(payload { repeated_int32: 1 repeated_int32: 2 repeated_int32: 3 })pb")}},
                test::BoolValueIs(true),
            },
            TestCase{
                "unsafe_message_access_repeated_field_index",
                "nested_test_all_types.payload.repeated_int32[1] == 2",
                {{"nested_test_all_types",
                  MakeNestedTestAllTypes(R"pb(payload { repeated_int32: 1 repeated_int32: 2 repeated_int32: 3 })pb")}},
                test::BoolValueIs(true),
            },
            TestCase{
                "unsafe_message_access_map_field",
                "nested_test_all_types.payload.map_int32_string.size() == 2",
                {{"nested_test_all_types",
                  MakeNestedTestAllTypes(
                      R"pb(payload { map_int32_string { key: 1 value: "foo" } map_int32_string { key: 2 value: "bar" } })pb")}},
                test::BoolValueIs(true),
            },
            TestCase{
                "unsafe_message_access_map_field_index",
                "nested_test_all_types.payload.map_int32_string[1] == 'foo'",
                {{"nested_test_all_types",
                  MakeNestedTestAllTypes(
                      R"pb(payload { map_int32_string { key: 1 value: "foo" } map_int32_string { key: 2 value: "bar" } })pb")}},
                test::BoolValueIs(true),
            },
            TestCase{
                "unsafe_message_access_string_field",
                "nested_test_all_types.payload.single_string == 'foo'",
                {{"nested_test_all_types",
                  MakeNestedTestAllTypes(
                      R"pb(payload { single_string: "foo" })pb")}},
                test::BoolValueIs(true),
            },
            TestCase{
                "unsafe_message_access_assign",
                "NestedTestAllTypes{payload: "
                "nested_test_all_types.child.payload}",
                {{"nested_test_all_types",
                  MakeNestedTestAllTypes(R"pb(child { payload { single_int32: 1 } })pb")}},
                test::StructValueIs(ParsedProtoStructEquals(R"pb(payload { single_int32: 1 })pb")),
            },
            TestCase{
                "unsafe_message_access_assign_repeated_field",
                "TestAllTypes{repeated_int32: "
                "nested_test_all_types.payload.repeated_int32}",
                {{"nested_test_all_types",
                  MakeNestedTestAllTypes(R"pb( payload { repeated_int32: [ 1, 2, 3 ] } )pb")}},
                test::StructValueIs(ParsedProtoStructEquals(
                    R"pb(repeated_int32: [ 1, 2, 3 ])pb")),
            },
            TestCase{
                "unsafe_message_access_assign_map_field",
                "TestAllTypes{map_int32_string: "
                "nested_test_all_types.payload.map_int32_string}",
                {{"nested_test_all_types",
                  MakeNestedTestAllTypes(R"pb( payload { map_int32_string { key: 1 value: "foo" } map_int32_string { key: 2 value: "bar" } } )pb")}},
                test::StructValueIs(ParsedProtoStructEquals(
                    R"pb(map_int32_string { key: 1 value: "foo" } map_int32_string { key: 2 value: "bar" })pb")),
            },
            TestCase{
                "unsafe_message_access_assign_string_field",
                "TestAllTypes{single_string: "
                "nested_test_all_types.payload.single_string}",
                {{"nested_test_all_types",
                  MakeNestedTestAllTypes(R"pb( payload { single_string: 'foo is a long string that is not inlined abcdef' } )pb")}},
                test::StructValueIs(ParsedProtoStructEquals(
                    R"pb(single_string: 'foo is a long string that is not inlined abcdef')pb")),
            },
            TestCase{
                "unsafe_message_access_assign_bytes_field",
                "TestAllTypes{single_bytes: "
                "nested_test_all_types.payload.single_bytes}",
                {{"nested_test_all_types",
                  MakeNestedTestAllTypes(R"pb( payload { single_bytes: 'foo is a long string that is not inlined abcdef' } )pb")}},
                test::StructValueIs(ParsedProtoStructEquals(
                    R"pb(single_bytes: 'foo is a long string that is not inlined abcdef')pb")),
            },
            TestCase{
                "unsafe_message_access_assign_from_repeated_string_field",
                "TestAllTypes{single_string: "
                "nested_test_all_types.payload.repeated_string[0]}",
                {{"nested_test_all_types",
                  MakeNestedTestAllTypes(R"pb( payload { repeated_string: 'foo is a long string that is not inlined abcdef' } )pb")}},
                test::StructValueIs(ParsedProtoStructEquals(
                    R"pb(single_string: 'foo is a long string that is not inlined abcdef')pb")),
            },
            TestCase{
                "unsafe_message_access_assign_from_map_string_field",
                "TestAllTypes{single_string: "
                "nested_test_all_types.payload.map_int32_string[1]}",
                {{"nested_test_all_types",
                  MakeNestedTestAllTypes(R"pb( payload { map_int32_string { key: 1 value: "foo is a long string that is not inlined abcdef" } } )pb")}},
                test::StructValueIs(ParsedProtoStructEquals(
                    R"pb(single_string: "foo is a long string that is not inlined abcdef")pb")),
            },
        }),
        testing::Values(Options::kDefault, Options::kExhaustive,
                        Options::kFoldConstants)),
    &TestCaseName);

// Matches a ParsedMessageValue whose underlying message pointer (as returned by
// operator->) equals `expected`. This checks identity, not equality.
MATCHER_P(IsSameInstance, expected, "") {
  return std::mem_fn(&ParsedMessageValue::operator->)(&arg) == expected;
}

// Returns true if the string value is backed by the same instance as the
// expected string. Note: this only applies for string values that are too big // to be inlined in the StringValue and not represented as a absl::Cord. MATCHER_P(IsSameStringInstance, expected, "") { const StringValue& got = arg; std::string buf; absl::string_view got_view = got.ToStringView(&buf); bool result = got_view.data() == expected.data() && got_view.size() == expected.size(); if (!result) { *result_listener << absl::StrFormat("got: %p, wanted: %p", got_view.data(), expected.data()); } return result; } class ViewTypesMemorySafetyTest : public testing::TestWithParam { protected: Options EvaluationOptions() { return GetParam(); } }; // Test cases demonstrating how inputs as views are handled. TEST_P(ViewTypesMemorySafetyTest, WrappedMessage) { // Arrange: create the runtime and expression. ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, ConfigureRuntimeImpl(false, EvaluationOptions())); constexpr absl::string_view kProtoValue = R"pb( child { payload { repeated_int32: [ 1, 2, 3 ] } } payload { repeated_string: [ "foo", "bar", "baz" ] } )pb"; ASSERT_OK_AND_ASSIGN( ValidationResult validation, GetCompiler().Compile( "condition ? nested_test_all_types : NestedTestAllTypes{}")); ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); ASSERT_OK_AND_ASSIGN(std::unique_ptr program, runtime->CreateProgram(std::move(ast))); // Act: wrap the message and evaluate the expression. 
google::protobuf::Arena arena; NestedTestAllTypes* proto = NestedTestAllTypes::default_instance().New(&arena); ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, proto)); Activation activation; activation.InsertOrAssignValue("condition", BoolValue(true)); activation.InsertOrAssignValue( "nested_test_all_types", Value::WrapMessage(proto, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), &arena)); ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); // Assert: the result is the input message. ASSERT_TRUE(result.IsParsedMessage()); const ParsedMessageValue& result_msg = result.GetParsedMessage(); EXPECT_THAT(result_msg, test::StructValueIs(ParsedProtoStructEquals(kProtoValue))); EXPECT_EQ(result_msg->GetArena(), &arena); EXPECT_THAT(result_msg, IsSameInstance(proto)); } // Test cases demonstrating how inputs as views are handled. TEST_P(ViewTypesMemorySafetyTest, WrappedMessageFields) { // Arrange: create the runtime and expression. ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, ConfigureRuntimeImpl(false, EvaluationOptions())); constexpr absl::string_view kProtoValue = R"pb( child { payload { repeated_int32: [ 1, 2, 3 ] } } payload { repeated_string: [ "foo", "bar", "baz" ] } )pb"; ASSERT_OK_AND_ASSIGN( ValidationResult validation, GetCompiler().Compile("nested_test_all_types.child.payload")); ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); ASSERT_OK_AND_ASSIGN(std::unique_ptr program, runtime->CreateProgram(std::move(ast))); // Act: wrap the message and evaluate the expression. 
google::protobuf::Arena arena; NestedTestAllTypes* proto = NestedTestAllTypes::default_instance().New(&arena); ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, proto)); Activation activation; activation.InsertOrAssignValue("condition", BoolValue(true)); activation.InsertOrAssignValue( "nested_test_all_types", Value::WrapMessage(proto, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), &arena)); ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); // Assert: the result is an alias of a sub-message in the input. ASSERT_TRUE(result.IsParsedMessage()); const ParsedMessageValue& result_msg = result.GetParsedMessage(); EXPECT_THAT(result_msg, test::StructValueIs(ParsedProtoStructEquals( "repeated_int32: [ 1, 2, 3 ]"))); EXPECT_EQ(result_msg->GetArena(), &arena); EXPECT_THAT(result_msg, IsSameInstance(&(proto->child().payload()))); } TEST_P(ViewTypesMemorySafetyTest, WrappedMessageDifferentArena) { // Arrange: create the runtime and expression. ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, ConfigureRuntimeImpl(false, EvaluationOptions())); constexpr absl::string_view kProtoValue = R"pb( child { payload { repeated_int32: [ 1, 2, 3 ] } } payload { repeated_string: [ "foo", "bar", "baz" ] } )pb"; ASSERT_OK_AND_ASSIGN( ValidationResult validation, GetCompiler().Compile( "condition ? nested_test_all_types : NestedTestAllTypes{}")); ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); ASSERT_OK_AND_ASSIGN(std::unique_ptr program, runtime->CreateProgram(std::move(ast))); // Act: wrap the message and evaluate the expression. 
google::protobuf::Arena arena; google::protobuf::Arena other_arena; NestedTestAllTypes* proto = NestedTestAllTypes::default_instance().New(&other_arena); ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, proto)); Activation activation; activation.InsertOrAssignValue("condition", BoolValue(true)); activation.InsertOrAssignValue( "nested_test_all_types", Value::WrapMessage(proto, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), &arena)); ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); // Assert: the result is a copy of the input message. ASSERT_TRUE(result.IsParsedMessage()); const ParsedMessageValue& result_msg = result.GetParsedMessage(); EXPECT_THAT(result_msg, test::StructValueIs(ParsedProtoStructEquals(kProtoValue))); EXPECT_EQ(result_msg->GetArena(), &arena); EXPECT_THAT(result_msg, Not(IsSameInstance(proto))); } TEST_P(ViewTypesMemorySafetyTest, WrappedMessageFromAny) { // Arrange: create the runtime. ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, ConfigureRuntimeImpl(false, EvaluationOptions())); constexpr absl::string_view kProtoValue = R"pb( child { payload { repeated_int32: [ 1, 2, 3 ] } } payload { repeated_string: [ "foo", "bar", "baz" ] } )pb"; ASSERT_OK_AND_ASSIGN( ValidationResult validation, GetCompiler().Compile( "condition ? nested_test_all_types : NestedTestAllTypes{}")); ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); ASSERT_OK_AND_ASSIGN(std::unique_ptr program, runtime->CreateProgram(std::move(ast))); // Act: wrap the message and evaluate the expression. 
google::protobuf::Arena arena; NestedTestAllTypes proto; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, &proto)); Any any; any.PackFrom(proto); Activation activation; activation.InsertOrAssignValue("condition", BoolValue(true)); activation.InsertOrAssignValue( "nested_test_all_types", Value::WrapMessage(&any, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), &arena)); // Assert ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); ASSERT_TRUE(result.IsParsedMessage()); const ParsedMessageValue& result_msg = result.GetParsedMessage(); EXPECT_THAT(result_msg, test::StructValueIs(ParsedProtoStructEquals(kProtoValue))); EXPECT_EQ(result_msg->GetArena(), &arena); } TEST_P(ViewTypesMemorySafetyTest, UnsafeWrappedMessageDifferentArena) { // Arrange: create the runtime and expression. ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, ConfigureRuntimeImpl(false, EvaluationOptions())); constexpr absl::string_view kProtoValue = R"pb( child { payload { repeated_int32: [ 1, 2, 3 ] } } payload { repeated_string: [ "foo", "bar", "baz" ] } )pb"; ASSERT_OK_AND_ASSIGN( ValidationResult validation, GetCompiler().Compile( "condition ? nested_test_all_types : NestedTestAllTypes{}")); ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); ASSERT_OK_AND_ASSIGN(std::unique_ptr program, runtime->CreateProgram(std::move(ast))); // Act: wrap the message and evaluate the expression. // The unsafe version will alias the input message, so caller must ensure // the input outlives the use of the `Value` rather than assuming it // is managed by the evaluation arena. 
google::protobuf::Arena arena; NestedTestAllTypes proto; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, &proto)); Activation activation; activation.InsertOrAssignValue("condition", BoolValue(true)); activation.InsertOrAssignValue( "nested_test_all_types", Value::WrapMessageUnsafe(&proto, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), &arena)); ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); // Assert: the result is an alias of the input message. ASSERT_TRUE(result.IsParsedMessage()); const ParsedMessageValue& result_msg = result.GetParsedMessage(); EXPECT_THAT(result_msg, test::StructValueIs(ParsedProtoStructEquals(kProtoValue))); EXPECT_EQ(result_msg->GetArena(), nullptr); EXPECT_THAT(result_msg, IsSameInstance(&proto)); } TEST_P(ViewTypesMemorySafetyTest, UnsafeWrappedMessageFields) { // Arrange: create the runtime and expression. ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, ConfigureRuntimeImpl(false, EvaluationOptions())); constexpr absl::string_view kProtoValue = R"pb( child { payload { repeated_int32: [ 1, 2, 3 ] } } payload { repeated_string: [ "foo", "bar", "baz" ] } )pb"; ASSERT_OK_AND_ASSIGN( ValidationResult validation, GetCompiler().Compile("nested_test_all_types.child.payload")); ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); ASSERT_OK_AND_ASSIGN(std::unique_ptr program, runtime->CreateProgram(std::move(ast))); // Act: wrap the message and evaluate the expression. 
google::protobuf::Arena arena; NestedTestAllTypes proto; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, &proto)); Activation activation; activation.InsertOrAssignValue("condition", BoolValue(true)); activation.InsertOrAssignValue( "nested_test_all_types", Value::WrapMessageUnsafe(&proto, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), &arena)); ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); // Assert: the result is an alias of a sub-message in the input. ASSERT_TRUE(result.IsParsedMessage()); const ParsedMessageValue& result_msg = result.GetParsedMessage(); EXPECT_THAT(result_msg, test::StructValueIs(ParsedProtoStructEquals( "repeated_int32: [ 1, 2, 3 ]"))); EXPECT_EQ(result_msg->GetArena(), nullptr); EXPECT_THAT(result_msg, IsSameInstance(&(proto.child().payload()))); } TEST_P(ViewTypesMemorySafetyTest, UnsafeWrappedMessageRepeatedField) { // Arrange: create the runtime and expression. ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, ConfigureRuntimeImpl(false, EvaluationOptions())); constexpr absl::string_view kProtoValue = R"pb( payload { repeated_nested_message: { bb: 42 } } )pb"; ASSERT_OK_AND_ASSIGN( ValidationResult validation, GetCompiler().Compile( "nested_test_all_types.payload.repeated_nested_message[0]")); ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); ASSERT_OK_AND_ASSIGN(std::unique_ptr program, runtime->CreateProgram(std::move(ast))); // Act: wrap the message and evaluate the expression. 
google::protobuf::Arena arena; NestedTestAllTypes proto; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, &proto)); Activation activation; activation.InsertOrAssignValue( "nested_test_all_types", Value::WrapMessageUnsafe(&proto, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), &arena)); ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); // Assert: the result is an alias of a sub-message in the input. ASSERT_TRUE(result.IsParsedMessage()); const ParsedMessageValue& result_msg = result.GetParsedMessage(); EXPECT_THAT(result_msg, test::StructValueIs(ParsedProtoStructEquals("bb: 42"))); EXPECT_EQ(result_msg->GetArena(), nullptr); EXPECT_THAT(result_msg, IsSameInstance(&(proto.payload().repeated_nested_message(0)))); } TEST_P(ViewTypesMemorySafetyTest, UnsafeWrappedMessageMapField) { // Arrange: create the runtime and expression. ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, ConfigureRuntimeImpl(false, EvaluationOptions())); ASSERT_OK_AND_ASSIGN( ValidationResult validation, GetCompiler().Compile( "nested_test_all_types.payload.map_string_message['foo']")); ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); ASSERT_OK_AND_ASSIGN(std::unique_ptr program, runtime->CreateProgram(std::move(ast))); // Act: wrap the message and evaluate the expression. 
google::protobuf::Arena arena; NestedTestAllTypes proto; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"pb( payload { map_string_message: { key: "foo" value: { bb: 42 } } map_string_message: { key: "baz" value: { bb: 43 } } })pb", &proto)); Activation activation; activation.InsertOrAssignValue( "nested_test_all_types", Value::WrapMessageUnsafe(&proto, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), &arena)); ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); // Assert: the result is an alias of a sub-message in the input. ASSERT_TRUE(result.IsParsedMessage()); const ParsedMessageValue& result_msg = result.GetParsedMessage(); EXPECT_THAT(result_msg, test::StructValueIs(ParsedProtoStructEquals(R"pb(bb: 42)pb"))); EXPECT_THAT( result_msg, IsSameInstance(&(proto.payload().map_string_message().at("foo")))); } TEST_P(ViewTypesMemorySafetyTest, UnsafeWrappedMessageStringFields) { // Arrange: create the runtime and expression. ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, ConfigureRuntimeImpl(false, EvaluationOptions())); constexpr absl::string_view kProtoValue = R"pb( child { payload { single_string: "foo that is too big to be inlined..." } } )pb"; ASSERT_OK_AND_ASSIGN( ValidationResult validation, GetCompiler().Compile( "nested_test_all_types.child.payload.single_string")); ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); ASSERT_OK_AND_ASSIGN(std::unique_ptr program, runtime->CreateProgram(std::move(ast))); // Act: wrap the message and evaluate the expression. 
google::protobuf::Arena arena; NestedTestAllTypes proto; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, &proto)); Activation activation; activation.InsertOrAssignValue( "nested_test_all_types", Value::WrapMessageUnsafe(&proto, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), &arena)); ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); // Assert: the result is an alias of a sub-message in the input. ASSERT_TRUE(result.IsString()); const StringValue& result_string = result.GetString(); EXPECT_THAT(result_string, StringValueIs("foo that is too big to be inlined...")); EXPECT_THAT(result_string, IsSameStringInstance(absl::string_view( proto.child().payload().single_string()))); } TEST_P(ViewTypesMemorySafetyTest, UnsafeWrappedMessageRepeatedStringField) { // Arrange: create the runtime and expression. ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, ConfigureRuntimeImpl(false, EvaluationOptions())); constexpr absl::string_view kProtoValue = R"pb( payload { repeated_string: "foo that is too big to be inlined..." } )pb"; ASSERT_OK_AND_ASSIGN(ValidationResult validation, GetCompiler().Compile( "nested_test_all_types.payload.repeated_string[0]")); ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); ASSERT_OK_AND_ASSIGN(std::unique_ptr program, runtime->CreateProgram(std::move(ast))); // Act: wrap the message and evaluate the expression. 
google::protobuf::Arena arena; NestedTestAllTypes proto; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, &proto)); Activation activation; activation.InsertOrAssignValue( "nested_test_all_types", Value::WrapMessageUnsafe(&proto, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), &arena)); ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); // Assert: the result is an alias of a sub-message in the input. ASSERT_TRUE(result.IsString()); const StringValue& result_string = result.GetString(); EXPECT_THAT(result_string, StringValueIs("foo that is too big to be inlined...")); EXPECT_THAT(result_string, IsSameStringInstance(absl::string_view( proto.payload().repeated_string(0)))); } TEST_P(ViewTypesMemorySafetyTest, UnsafeWrappedMessageMapStringField) { // Arrange: create the runtime and expression. ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, ConfigureRuntimeImpl(false, EvaluationOptions())); constexpr absl::string_view kProtoValue = R"pb( payload { map_string_string: { key: "foo" value: "bar that is too big to be inlined..." } })pb"; ASSERT_OK_AND_ASSIGN( ValidationResult validation, GetCompiler().Compile( "nested_test_all_types.payload.map_string_string['foo']")); ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); ASSERT_OK_AND_ASSIGN(std::unique_ptr program, runtime->CreateProgram(std::move(ast))); // Act: wrap the message and evaluate the expression. 
google::protobuf::Arena arena; NestedTestAllTypes proto; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kProtoValue, &proto)); Activation activation; activation.InsertOrAssignValue( "nested_test_all_types", Value::WrapMessageUnsafe(&proto, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), &arena)); ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); // Assert: the result is an alias of a sub-message in the input. ASSERT_TRUE(result.IsString()); const StringValue& result_string = result.GetString(); EXPECT_THAT(result_string, StringValueIs("bar that is too big to be inlined...")); EXPECT_THAT(result_string, IsSameStringInstance(absl::string_view( proto.payload().map_string_string().at("foo")))); } TEST_P(ViewTypesMemorySafetyTest, UnsafeWrappedMessageStringFieldAssign) { // Arrange: create the runtime and expression. ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, ConfigureRuntimeImpl(false, EvaluationOptions())); ASSERT_OK_AND_ASSIGN( ValidationResult validation, GetCompiler().Compile( "TestAllTypes{single_string: " "nested_test_all_types.child.payload.single_string}.single_string")); ASSERT_TRUE(validation.IsValid()) << validation.FormatError(); ASSERT_OK_AND_ASSIGN(auto ast, validation.ReleaseAst()); ASSERT_OK_AND_ASSIGN(std::unique_ptr program, runtime->CreateProgram(std::move(ast))); // Act: wrap the message and evaluate the expression. google::protobuf::Arena arena; NestedTestAllTypes proto; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( child { payload { single_string: "foo that is too big to be inlined..." 
} })pb", &proto)); Activation activation; activation.InsertOrAssignValue( "nested_test_all_types", Value::WrapMessageUnsafe(&proto, google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory(), &arena)); ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); // Assert: check that the result is not tied to the alias. // This is not a safe assumption generally, but making sure that the runtime // is making a defensive copy when building a message assumed to be on the // arena. Callers cannot safely assume this for arbitrary expressions. proto.Clear(); ASSERT_TRUE(result.IsString()); const StringValue& result_string = result.GetString(); EXPECT_THAT(result_string, StringValueIs("foo that is too big to be inlined...")); EXPECT_THAT(result_string, Not(IsSameStringInstance(absl::string_view( proto.child().payload().single_string())))); } INSTANTIATE_TEST_SUITE_P(Cases, ViewTypesMemorySafetyTest, testing::Values(Options::kDefault, Options::kExhaustive, Options::kFoldConstants), [](const testing::TestParamInfo& info) { switch (info.param) { case Options::kDefault: return "default"; case Options::kExhaustive: return "exhaustive"; case Options::kFoldConstants: return "opt"; } }); } // namespace } // namespace cel ================================================ FILE: runtime/optional_types.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "runtime/optional_types.h" #include #include #include #include #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "base/function_adapter.h" #include "common/casting.h" #include "common/type.h" #include "common/value.h" #include "internal/casts.h" #include "internal/number.h" #include "internal/status_macros.h" #include "runtime/function_registry.h" #include "runtime/internal/errors.h" #include "runtime/internal/runtime_friend_access.h" #include "runtime/internal/runtime_impl.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel::extensions { namespace { Value OptionalOf(const Value& value, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull arena) { return OptionalValue::Of(value, arena); } Value OptionalNone() { return OptionalValue::None(); } Value OptionalOfNonZeroValue( const Value& value, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { if (value.IsZeroValue()) { return OptionalNone(); } return OptionalOf(value, descriptor_pool, message_factory, arena); } absl::StatusOr OptionalGetValue(const OpaqueValue& opaque_value) { if (auto optional_value = opaque_value.AsOptional(); optional_value) { return optional_value->Value(); } return ErrorValue{runtime_internal::CreateNoMatchingOverloadError("value")}; } absl::StatusOr OptionalHasValue(const OpaqueValue& opaque_value) { if (auto optional_value = 
opaque_value.AsOptional(); optional_value) { return BoolValue{optional_value->HasValue()}; } return ErrorValue{ runtime_internal::CreateNoMatchingOverloadError("hasValue")}; } absl::StatusOr SelectOptionalFieldStruct( const StructValue& struct_value, const StringValue& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { std::string field_name; auto field_name_view = key.NativeString(field_name); CEL_ASSIGN_OR_RETURN(auto has_field, struct_value.HasFieldByName(field_name_view)); if (!has_field) { return OptionalValue::None(); } CEL_ASSIGN_OR_RETURN( auto field, struct_value.GetFieldByName(field_name_view, descriptor_pool, message_factory, arena)); return OptionalValue::Of(std::move(field), arena); } absl::StatusOr SelectOptionalFieldMap( const MapValue& map, const StringValue& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { absl::optional value; CEL_ASSIGN_OR_RETURN(value, map.Find(key, descriptor_pool, message_factory, arena)); if (value) { return OptionalValue::Of(std::move(*value), arena); } return OptionalValue::None(); } absl::StatusOr SelectOptionalField( const OpaqueValue& opaque_value, const StringValue& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { if (auto optional_value = opaque_value.AsOptional(); optional_value) { if (!optional_value->HasValue()) { return OptionalValue::None(); } auto container = optional_value->Value(); if (auto map_value = container.AsMap(); map_value) { return SelectOptionalFieldMap(*map_value, key, descriptor_pool, message_factory, arena); } if (auto struct_value = container.AsStruct(); struct_value) { return 
SelectOptionalFieldStruct(*struct_value, key, descriptor_pool, message_factory, arena); } } return ErrorValue{runtime_internal::CreateNoMatchingOverloadError("_[?_]")}; } absl::StatusOr MapOptIndexOptionalValue( const MapValue& map, const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { absl::optional value; if (auto double_key = cel::As(key); double_key) { // Try int/uint. auto number = internal::Number::FromDouble(double_key->NativeValue()); if (number.LosslessConvertibleToInt()) { CEL_ASSIGN_OR_RETURN(value, map.Find(IntValue{number.AsInt()}, descriptor_pool, message_factory, arena)); if (value) { return OptionalValue::Of(std::move(*value), arena); } } if (number.LosslessConvertibleToUint()) { CEL_ASSIGN_OR_RETURN(value, map.Find(UintValue{number.AsUint()}, descriptor_pool, message_factory, arena)); if (value) { return OptionalValue::Of(std::move(*value), arena); } } } else { CEL_ASSIGN_OR_RETURN( value, map.Find(key, descriptor_pool, message_factory, arena)); if (value) { return OptionalValue::Of(std::move(*value), arena); } if (auto int_key = key.AsInt(); int_key && int_key->NativeValue() >= 0) { CEL_ASSIGN_OR_RETURN( value, map.Find(UintValue{static_cast(int_key->NativeValue())}, descriptor_pool, message_factory, arena)); if (value) { return OptionalValue::Of(std::move(*value), arena); } } else if (auto uint_key = key.AsUint(); uint_key && uint_key->NativeValue() <= static_cast(std::numeric_limits::max())) { CEL_ASSIGN_OR_RETURN( value, map.Find(IntValue{static_cast(uint_key->NativeValue())}, descriptor_pool, message_factory, arena)); if (value) { return OptionalValue::Of(std::move(*value), arena); } } } return OptionalValue::None(); } absl::StatusOr ListOptIndexOptionalInt( const ListValue& list, int64_t key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* 
absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { CEL_ASSIGN_OR_RETURN(auto list_size, list.Size()); if (key < 0 || static_cast(key) >= list_size) { return OptionalValue::None(); } CEL_ASSIGN_OR_RETURN(auto element, list.Get(static_cast(key), descriptor_pool, message_factory, arena)); return OptionalValue::Of(std::move(element), arena); } absl::StatusOr OptionalOptIndexOptionalValue( const OpaqueValue& opaque_value, const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { if (auto optional_value = As(opaque_value); optional_value) { if (!optional_value->HasValue()) { return OptionalValue::None(); } auto container = optional_value->Value(); if (auto map_value = cel::As(container); map_value) { return MapOptIndexOptionalValue(*map_value, key, descriptor_pool, message_factory, arena); } if (auto list_value = cel::As(container); list_value) { if (auto int_value = cel::As(key); int_value) { return ListOptIndexOptionalInt(*list_value, int_value->NativeValue(), descriptor_pool, message_factory, arena); } } } return ErrorValue{runtime_internal::CreateNoMatchingOverloadError("_[?_]")}; } absl::StatusOr ListFirst(const cel::ListValue& list, const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::MessageFactory* message_factory, google::protobuf::Arena* arena) { CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); if (size == 0) { return Value(OptionalValue::None()); } CEL_ASSIGN_OR_RETURN(Value value, list.Get(0, descriptor_pool, message_factory, arena)); return Value(OptionalValue::Of(std::move(value), arena)); } absl::StatusOr ListLast(const cel::ListValue& list, const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::MessageFactory* message_factory, google::protobuf::Arena* arena) { CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); if (size == 0) { return 
Value(OptionalValue::None()); } CEL_ASSIGN_OR_RETURN(Value value, list.Get(static_cast(size) - 1, descriptor_pool, message_factory, arena)); return Value(OptionalValue::Of(std::move(value), arena)); } absl::StatusOr ListUnwrapOpt( const ListValue& list, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { auto builder = NewListValueBuilder(arena); CEL_ASSIGN_OR_RETURN(auto list_size, list.Size()); builder->Reserve(list_size); absl::Status status = list.ForEach( [&](const Value& value) -> absl::StatusOr { if (auto optional_value = value.AsOptional(); optional_value) { if (optional_value->HasValue()) { CEL_RETURN_IF_ERROR(builder->Add(optional_value->Value())); } } else { return absl::InvalidArgumentError(absl::StrFormat( "optional.unwrap() expected a list(optional(T)), but %s " "was found in the list.", value.GetTypeName())); } return true; }, descriptor_pool, message_factory, arena); if (!status.ok()) { return ErrorValue(status); } return std::move(*builder).Build(); } absl::Status RegisterOptionalTypeFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { if (!options.enable_qualified_type_identifiers) { return absl::FailedPreconditionError( "optional_type requires " "RuntimeOptions.enable_qualified_type_identifiers"); } if (!options.enable_heterogeneous_equality) { return absl::FailedPreconditionError( "optional_type requires RuntimeOptions.enable_heterogeneous_equality"); } CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter::CreateDescriptor("optional.of", false), UnaryFunctionAdapter::WrapFunction(&OptionalOf))); CEL_RETURN_IF_ERROR( registry.Register(UnaryFunctionAdapter::CreateDescriptor( "optional.ofNonZeroValue", false), UnaryFunctionAdapter::WrapFunction( &OptionalOfNonZeroValue))); CEL_RETURN_IF_ERROR(registry.Register( NullaryFunctionAdapter::CreateDescriptor("optional.none", false), 
NullaryFunctionAdapter::WrapFunction(&OptionalNone))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, OpaqueValue>::CreateDescriptor("value", true), UnaryFunctionAdapter, OpaqueValue>::WrapFunction( &OptionalGetValue))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, OpaqueValue>::CreateDescriptor("hasValue", true), UnaryFunctionAdapter, OpaqueValue>::WrapFunction( &OptionalHasValue))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StructValue, StringValue>::CreateDescriptor("_?._", false), BinaryFunctionAdapter, StructValue, StringValue>:: WrapFunction(&SelectOptionalFieldStruct))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, MapValue, StringValue>::CreateDescriptor("_?._", false), BinaryFunctionAdapter, MapValue, StringValue>:: WrapFunction(&SelectOptionalFieldMap))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, OpaqueValue, StringValue>::CreateDescriptor("_?._", false), BinaryFunctionAdapter, OpaqueValue, StringValue>::WrapFunction(&SelectOptionalField))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, MapValue, Value>::CreateDescriptor("_[?_]", false), BinaryFunctionAdapter, MapValue, Value>::WrapFunction(&MapOptIndexOptionalValue))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, ListValue, int64_t>::CreateDescriptor("_[?_]", false), BinaryFunctionAdapter, ListValue, int64_t>::WrapFunction(&ListOptIndexOptionalInt))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, OpaqueValue, Value>::CreateDescriptor("_[?_]", false), BinaryFunctionAdapter, OpaqueValue, Value>:: WrapFunction(&OptionalOptIndexOptionalValue))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, ListValue>::CreateDescriptor( "optional.unwrap", false), UnaryFunctionAdapter, ListValue>::WrapFunction( &ListUnwrapOpt))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, ListValue>::CreateDescriptor( "unwrapOpt", true), UnaryFunctionAdapter, 
ListValue>::WrapFunction( &ListUnwrapOpt))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, ListValue>::CreateDescriptor( "first", true), UnaryFunctionAdapter, ListValue>::WrapFunction( &ListFirst))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, ListValue>::CreateDescriptor( "last", true), UnaryFunctionAdapter, ListValue>::WrapFunction( &ListLast))); return absl::OkStatus(); } } // namespace absl::Status EnableOptionalTypes(RuntimeBuilder& builder) { auto& runtime = cel::internal::down_cast( runtime_internal::RuntimeFriendAccess::GetMutableRuntime(builder)); CEL_RETURN_IF_ERROR(RegisterOptionalTypeFunctions( builder.function_registry(), runtime.expr_builder().options())); CEL_RETURN_IF_ERROR(builder.type_registry().RegisterType(OptionalType())); runtime.expr_builder().enable_optional_types(); return absl::OkStatus(); } } // namespace cel::extensions ================================================ FILE: runtime/optional_types.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_OPTIONAL_TYPES_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_OPTIONAL_TYPES_H_ #include "absl/status/status.h" #include "runtime/runtime_builder.h" namespace cel::extensions { // EnableOptionalTypes enable support for optional syntax and types in CEL. 
//
// The optional value type makes it possible to express whether variables have
// been provided, whether a result has been computed, and in the future whether
// an object field path, map key value, or list index has a value.
//
// # Syntax Changes
//
// OptionalTypes are unlike other CEL extensions because they modify the CEL
// syntax itself, notably through the use of a `?` preceding a field name or
// index value.
//
// ## Field Selection
//
// The optional syntax in field selection is denoted as `obj.?field`. In other
// words, if a field is set, return `optional.of(obj.field)`, else
// `optional.none()`. The optional field selection is viral in the sense that
// after the first optional selection all subsequent selections or indices
// are treated as optional, i.e. the following expressions are equivalent:
//
//   obj.?field.subfield
//   obj.?field.?subfield
//
// ## Indexing
//
// Similar to field selection, the optional syntax can be used in index
// expressions on maps and lists:
//
//   list[?0]
//   map[?key]
//
// ## Optional Field Setting
//
// When creating map or message literals, if a field may be optionally set
// based on its presence, then placing a `?` before the field name or key
// will ensure the type on the right-hand side must be optional(T) where T
// is the type of the field or key-value.
//
// The following returns a map with the key expression set only if the
// subfield is present, otherwise an empty map is created:
//
//   {?key: obj.?field.subfield}
//
// ## Optional Element Setting
//
// When creating list literals, an element in the list may be optionally added
// when the element expression is preceded by a `?`:
//
//   [a, ?b, ?c]  // return a list with either [a], [a, b], [a, b, c], or [a, c]
//
// # Optional.Of
//
// Create an optional(T) value of a given value with type T.
//
//   optional.of(10)
//
// # Optional.OfNonZeroValue
//
// Create an optional(T) value of a given value with type T if it is not a
// zero-value. A zero-value is the default empty value for any given CEL type,
// including empty protobuf message types. If the value is empty, the result
// of this call will be optional.none().
//
//   optional.ofNonZeroValue([1, 2, 3])  // optional(list(int))
//   optional.ofNonZeroValue([])  // optional.none()
//   optional.ofNonZeroValue(0)  // optional.none()
//   optional.ofNonZeroValue("")  // optional.none()
//
// # Optional.None
//
// Create an empty optional value.
//
// # HasValue
//
// Determine whether the optional contains a value.
//
//   optional.of(b'hello').hasValue()  // true
//   optional.ofNonZeroValue({}).hasValue()  // false
//
// # Value
//
// Get the value contained by the optional. If the optional does not have a
// value, the result will be a CEL error.
//
//   optional.of(b'hello').value()  // b'hello'
//   optional.ofNonZeroValue({}).value()  // error
//
// # Or
//
// If the value on the left-hand side is optional.none(), the optional value
// on the right-hand side is returned. If the optional on the left-hand side
// has a value, then it is returned. This operation is short-circuiting and
// will only evaluate as many links in the `or` chain as are needed to return a
// non-empty optional value.
//
//   obj.?field.or(m[?key])
//   l[?index].or(obj.?field.subfield).or(obj.?other)
//
// # OrValue
//
// Either return the value contained within the optional on the left-hand side
// or return the alternative value on the right-hand side.
//
//   m[?key].orValue("none")
//
// # OptMap
//
// Apply a transformation to the optional's underlying value if it is not empty
// and return an optional typed result based on the transformation. The
// transformation expression type must return a type T which is wrapped into
// an optional.
//
//   msg.?elements.optMap(e, e.size()).orValue(0)
//
// # OptFlatMap
//
// Introduced in version: 1
//
// Apply a transformation to the optional's underlying value if it is not empty
// and return the result. The transform expression must return an optional(T)
// rather than type T. This can be useful when dealing with zero values and
// conditionally generating an empty or non-empty result in ways which cannot
// be expressed with `optMap`.
//
//   msg.?elements.optFlatMap(e, e[?0])  // return the first element if present.
//
// # Preconditions
//
// Returns a FailedPrecondition error unless the RuntimeOptions used to create
// `builder` set both `enable_qualified_type_identifiers` and
// `enable_heterogeneous_equality`.
absl::Status EnableOptionalTypes(RuntimeBuilder& builder);

}  // namespace cel::extensions

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_OPTIONAL_TYPES_H_

================================================
FILE: runtime/optional_types_test.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/optional_types.h"

#include
#include
#include
#include
#include
#include

#include "cel/expr/syntax.pb.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "common/function_descriptor.h"
#include "common/kind.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "extensions/protobuf/runtime_adapter.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "parser/options.h"
#include "parser/parser.h"
#include "runtime/activation.h"
#include "runtime/function.h"
#include "runtime/internal/runtime_impl.h"
#include "runtime/reference_resolver.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "google/protobuf/arena.h"

namespace cel::extensions {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::extensions::ProtobufRuntimeAdapter;
using ::cel::test::BoolValueIs;
using ::cel::test::IntValueIs;
using ::cel::test::OptionalValueIs;
using ::cel::test::OptionalValueIsEmpty;
using ::cel::expr::ParsedExpr;
using ::google::api::expr::parser::Parse;
using ::google::api::expr::parser::ParserOptions;
using ::testing::ElementsAre;
using ::testing::HasSubstr;
using ::testing::TestWithParam;

// Matches a receiver-style overload named `name` whose only argument is of
// Kind::kOpaque (the optional receiver).
MATCHER_P(MatchesOptionalReceiver1, name, "") {
  const FunctionDescriptor& descriptor = arg.descriptor;
  std::vector types{Kind::kOpaque};
  return descriptor.name() == name && descriptor.receiver_style() == true &&
         descriptor.types() == types;
}

// Matches a receiver-style overload named `name` taking a Kind::kOpaque
// receiver followed by one argument of `kind`.
MATCHER_P2(MatchesOptionalReceiver2, name, kind, "") {
  const FunctionDescriptor& descriptor = arg.descriptor;
  std::vector types{Kind::kOpaque, kind};
  return descriptor.name() == name && descriptor.receiver_style() == true &&
         descriptor.types() == types;
}

// Matches the global optional select operator "_?._" over (kind1, kind2).
MATCHER_P2(MatchesOptionalSelect, kind1, kind2, "") {
  const FunctionDescriptor& descriptor = arg.descriptor;
  std::vector types{kind1, kind2};
  return descriptor.name() == "_?._" && descriptor.receiver_style() == false &&
         descriptor.types() == types;
}

// Matches the global optional index operator "_[?_]" over (kind1, kind2).
MATCHER_P2(MatchesOptionalIndex, kind1, kind2, "") {
  const FunctionDescriptor& descriptor = arg.descriptor;
  std::vector types{kind1, kind2};
  return descriptor.name() == "_[?_]" && descriptor.receiver_style() == false &&
         descriptor.types() == types;
}

TEST(EnableOptionalTypes, HeterogeneousEqualityRequired) {
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      CreateStandardRuntimeBuilder(
          internal::GetTestingDescriptorPool(),
          RuntimeOptions{.enable_qualified_type_identifiers = true,
                         .enable_heterogeneous_equality = false}));
  EXPECT_THAT(EnableOptionalTypes(builder),
              StatusIs(absl::StatusCode::kFailedPrecondition));
}

TEST(EnableOptionalTypes, QualifiedTypeIdentifiersRequired) {
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      CreateStandardRuntimeBuilder(
          internal::GetTestingDescriptorPool(),
          RuntimeOptions{.enable_qualified_type_identifiers = false,
                         .enable_heterogeneous_equality = true}));
  EXPECT_THAT(EnableOptionalTypes(builder),
              StatusIs(absl::StatusCode::kFailedPrecondition));
}

TEST(EnableOptionalTypes, PreconditionsSatisfied) {
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      CreateStandardRuntimeBuilder(
          internal::GetTestingDescriptorPool(),
          RuntimeOptions{.enable_qualified_type_identifiers = true,
                         .enable_heterogeneous_equality = true}));
  EXPECT_THAT(EnableOptionalTypes(builder), IsOk());
}

// Checks that the expected overloads of the optional accessors, select, and
// index operators are present in the function registry.
TEST(EnableOptionalTypes, Functions) {
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      CreateStandardRuntimeBuilder(
          internal::GetTestingDescriptorPool(),
          RuntimeOptions{.enable_qualified_type_identifiers = true,
                         .enable_heterogeneous_equality = true}));
  ASSERT_THAT(EnableOptionalTypes(builder), IsOk());

  EXPECT_THAT(builder.function_registry().FindStaticOverloads("hasValue", true,
                                                              {Kind::kOpaque}),
              ElementsAre(MatchesOptionalReceiver1("hasValue")));

  EXPECT_THAT(builder.function_registry().FindStaticOverloads("value", true,
                                                              {Kind::kOpaque}),
              ElementsAre(MatchesOptionalReceiver1("value")));

  EXPECT_THAT(builder.function_registry().FindStaticOverloads(
                  "_?._", false, {Kind::kStruct, Kind::kString}),
              ElementsAre(MatchesOptionalSelect(Kind::kStruct, Kind::kString)));
  EXPECT_THAT(builder.function_registry().FindStaticOverloads(
                  "_?._", false, {Kind::kMap, Kind::kString}),
              ElementsAre(MatchesOptionalSelect(Kind::kMap, Kind::kString)));
  EXPECT_THAT(builder.function_registry().FindStaticOverloads(
                  "_?._", false, {Kind::kOpaque, Kind::kString}),
              ElementsAre(MatchesOptionalSelect(Kind::kOpaque, Kind::kString)));

  EXPECT_THAT(builder.function_registry().FindStaticOverloads(
                  "_[?_]", false, {Kind::kMap, Kind::kAny}),
              ElementsAre(MatchesOptionalIndex(Kind::kMap, Kind::kAny)));
  EXPECT_THAT(builder.function_registry().FindStaticOverloads(
                  "_[?_]", false, {Kind::kList, Kind::kInt}),
              ElementsAre(MatchesOptionalIndex(Kind::kList, Kind::kInt)));
  EXPECT_THAT(builder.function_registry().FindStaticOverloads(
                  "_[?_]", false, {Kind::kOpaque, Kind::kAny}),
              ElementsAre(MatchesOptionalIndex(Kind::kOpaque, Kind::kAny)));
}

// One expression-evaluation case: `expression` is parsed with optional syntax
// enabled, evaluated, and the result checked with `value_matcher`. `name` is
// used as the test-name suffix.
struct EvaluateResultTestCase {
  std::string name;
  std::string expression;
  test::ValueMatcher value_matcher;

  template
  friend void AbslStringify(S& sink, const EvaluateResultTestCase& tc) {
    sink.Append(tc.name);
  }
};

// Parameterized over (test case, whether short-circuiting is enabled).
class OptionalTypesTest
    : public TestWithParam> {
 public:
  const EvaluateResultTestCase& GetTestCase() {
    return std::get<0>(GetParam());
  }

  bool EnableShortCircuiting() { return std::get<1>(GetParam()); }
};

// Same as Defaults, but with `max_recursion_depth = -1`; the test asserts that
// the planned program uses the recursive implementation.
TEST_P(OptionalTypesTest, RecursivePlan) {
  RuntimeOptions opts;
  opts.enable_qualified_type_identifiers = true;
  opts.max_recursion_depth = -1;
  opts.short_circuiting = EnableShortCircuiting();

  const EvaluateResultTestCase& test_case = GetTestCase();

  ASSERT_OK_AND_ASSIGN(
      auto builder,
      CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts));

  ASSERT_OK(EnableOptionalTypes(builder));
  ASSERT_OK(
      EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways));

  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr,
                       Parse(test_case.expression, "",
                             ParserOptions{.enable_optional_syntax = true}));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr program,
                       ProtobufRuntimeAdapter::CreateProgram(*runtime, expr));
  EXPECT_TRUE(runtime_internal::TestOnly_IsRecursiveImpl(program.get()));

  google::protobuf::Arena arena;
  Activation activation;

  ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));

  EXPECT_THAT(result, test_case.value_matcher) << test_case.expression;
}

// Evaluates each case with the default planner settings.
TEST_P(OptionalTypesTest, Defaults) {
  RuntimeOptions opts;
  opts.enable_qualified_type_identifiers = true;
  opts.short_circuiting = EnableShortCircuiting();

  const EvaluateResultTestCase& test_case = GetTestCase();

  ASSERT_OK_AND_ASSIGN(
      auto builder,
      CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts));

  ASSERT_OK(EnableOptionalTypes(builder));
  ASSERT_OK(
      EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways));

  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr,
                       Parse(test_case.expression, "",
                             ParserOptions{.enable_optional_syntax = true}));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr program,
                       ProtobufRuntimeAdapter::CreateProgram(*runtime, expr));

  google::protobuf::Arena arena;
  Activation activation;

  ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));

  EXPECT_THAT(result, test_case.value_matcher) << test_case.expression;
}

INSTANTIATE_TEST_SUITE_P(
    Basic, OptionalTypesTest,
    testing::Combine(
        testing::ValuesIn(std::vector{
            {"optional_none_hasValue", "optional.none().hasValue()",
             BoolValueIs(false)},
            {"optional_of_hasValue", "optional.of(0).hasValue()",
             BoolValueIs(true)},
            {"optional_ofNonZeroValue_hasValue",
             "optional.ofNonZeroValue(0).hasValue()", BoolValueIs(false)},
            {"optional_or_absent",
             "optional.ofNonZeroValue(0).or(optional.ofNonZeroValue(0))",
             OptionalValueIsEmpty()},
            {"optional_or_present", "optional.of(1).or(optional.none())",
             OptionalValueIs(IntValueIs(1))},
            {"optional_orValue_absent", "optional.ofNonZeroValue(0).orValue(1)",
             IntValueIs(1)},
            {"optional_orValue_present", "optional.of(1).orValue(2)",
             IntValueIs(1)},
            {"list_of_optional", "[optional.of(1)][0].orValue(1)",
             IntValueIs(1)},
            {"list_unwrap_empty", "optional.unwrap([]) == []",
             BoolValueIs(true)},
            {"list_unwrap_empty_optional_none",
             "optional.unwrap([optional.none(), optional.none()]) == []",
             BoolValueIs(true)},
            {"list_unwrap_three_elements",
             "optional.unwrap([optional.of(42), optional.none(), "
             "optional.of(\"a\")]) == [42, \"a\"]",
             BoolValueIs(true)},
            {"list_unwrap_no_none",
             "optional.unwrap([optional.of(42), optional.of(\"a\")]) == [42, "
             "\"a\"]",
             BoolValueIs(true)},
            {"list_unwrapOpt_empty", "[].unwrapOpt() == []",
             BoolValueIs(true)},
            {"list_unwrapOpt_empty_optional_none",
             "[optional.none(), optional.none()].unwrapOpt() == []",
             BoolValueIs(true)},
            {"list_unwrapOpt_three_elements",
             "[optional.of(42), optional.none(), "
             "optional.of(\"a\")].unwrapOpt() == [42, \"a\"]",
             BoolValueIs(true)},
            {"list_unwrapOpt_no_none",
             "[optional.of(42), optional.of(\"a\")].unwrapOpt() == [42, \"a\"]",
             BoolValueIs(true)},
            {"list_first", "[1, 2, 3].first()", OptionalValueIs(IntValueIs(1))},
            {"list_first_empty", "[].first()", OptionalValueIsEmpty()},
            {"list_last", "[1, 2, 3].last()", OptionalValueIs(IntValueIs(3))},
            {"list_last_empty", "[].last()", OptionalValueIsEmpty()},
        }),
        /*enable_short_circuiting*/ testing::Bool()));

// A function that counts its invocations in `*count` and returns a cancelled
// ErrorValue. Used to detect evaluation of a branch that should be skipped.
class UnreachableFunction final : public cel::Function {
 public:
  explicit UnreachableFunction(int64_t* count) : count_(count) {}

  absl::StatusOr Invoke(absl::Span args,
                        const InvokeContext& context) const override {
    ++(*count_);
    return ErrorValue(absl::CancelledError());
  }

 private:
  int64_t* const count_;
};

// An error inside `optional.of(...)` must surface as the result without the
// `orValue` fallback being evaluated.
TEST(OptionalTypesTest, ErrorShortCircuiting) {
  RuntimeOptions opts{.enable_qualified_type_identifiers = true};
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts));

  int64_t unreachable_count = 0;

  ASSERT_OK(EnableOptionalTypes(builder));
  ASSERT_OK(
      EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways));
  ASSERT_OK(builder.function_registry().Register(
      cel::FunctionDescriptor("unreachable", false, {}),
      std::make_unique(&unreachable_count)));

  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(
      ParsedExpr expr,
      Parse("optional.of(1 / 0).orValue(unreachable())", "",
            ParserOptions{.enable_optional_syntax = true}));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr program,
                       ProtobufRuntimeAdapter::CreateProgram(*runtime, expr));

  Activation activation;
  ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));

  EXPECT_EQ(unreachable_count, 0);
  // Checked the same way as the other error-result tests in this file.
  ASSERT_TRUE(result.IsError()) << result.DebugString();
  EXPECT_THAT(result.GetError().ToStatus(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("divide by zero")));
}

// `foo` is bound to a plain int, so the optional list element `?foo` is
// expected to fail with a type conversion error.
TEST(OptionalTypesTest, CreateList_TypeConversionError) {
  RuntimeOptions opts{.enable_qualified_type_identifiers = true};
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts));

  ASSERT_THAT(EnableOptionalTypes(builder), IsOk());
  ASSERT_THAT(
      EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways),
      IsOk());

  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr,
                       Parse("[?foo]", "",
                             ParserOptions{.enable_optional_syntax = true}));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr program,
                       ProtobufRuntimeAdapter::CreateProgram(*runtime, expr));

  Activation activation;
  activation.InsertOrAssignValue("foo", IntValue(1));
  ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));

  ASSERT_TRUE(result.IsError()) << result.DebugString();
  EXPECT_THAT(result.GetError().ToStatus(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("type conversion error")));
}

// `foo` is bound to a plain int, so the optional map entry `?1: foo` is
// expected to fail with a type conversion error.
TEST(OptionalTypesTest, CreateMap_TypeConversionError) {
  RuntimeOptions opts{.enable_qualified_type_identifiers = true};
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts));

  ASSERT_THAT(EnableOptionalTypes(builder), IsOk());
  ASSERT_THAT(
      EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways),
      IsOk());

  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr,
                       Parse("{?1: foo}", "",
                             ParserOptions{.enable_optional_syntax = true}));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr program,
                       ProtobufRuntimeAdapter::CreateProgram(*runtime, expr));

  Activation activation;
  activation.InsertOrAssignValue("foo", IntValue(1));
  ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));

  ASSERT_TRUE(result.IsError()) << result.DebugString();
  EXPECT_THAT(result.GetError().ToStatus(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("type conversion error")));
}

// `foo` is bound to a plain int, so the optional message field initializer
// `?single_int32: foo` is expected to fail with a type conversion error.
TEST(OptionalTypesTest, CreateStruct_KeyTypeConversionError) {
  RuntimeOptions opts{.enable_qualified_type_identifiers = true};
  google::protobuf::Arena arena;
  ASSERT_OK_AND_ASSIGN(
      auto builder,
      CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts));

  ASSERT_THAT(EnableOptionalTypes(builder), IsOk());
  ASSERT_THAT(
      EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways),
      IsOk());

  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(
      ParsedExpr expr,
      Parse("cel.expr.conformance.proto2.TestAllTypes{?single_int32: foo}", "",
            ParserOptions{.enable_optional_syntax = true}));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr program,
                       ProtobufRuntimeAdapter::CreateProgram(*runtime, expr));

  Activation activation;
  activation.InsertOrAssignValue("foo", IntValue(1));
  ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));

  ASSERT_TRUE(result.IsError()) << result.DebugString();
  EXPECT_THAT(result.GetError().ToStatus(),
StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("type conversion error"))); } } // namespace } // namespace cel::extensions ================================================ FILE: runtime/reference_resolver.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/reference_resolver.h" #include "absl/base/macros.h" #include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "common/native_type.h" #include "eval/compiler/qualified_reference_resolver.h" #include "internal/casts.h" #include "internal/status_macros.h" #include "runtime/internal/runtime_friend_access.h" #include "runtime/internal/runtime_impl.h" #include "runtime/runtime.h" #include "runtime/runtime_builder.h" namespace cel { namespace { using ::cel::internal::down_cast; using ::cel::runtime_internal::RuntimeFriendAccess; using ::cel::runtime_internal::RuntimeImpl; absl::StatusOr RuntimeImplFromBuilder(RuntimeBuilder& builder) { Runtime& runtime = RuntimeFriendAccess::GetMutableRuntime(builder); if (RuntimeFriendAccess::RuntimeTypeId(runtime) != NativeTypeId::For()) { return absl::UnimplementedError( "regex precompilation only supported on the default cel::Runtime " "implementation."); } RuntimeImpl& runtime_impl = down_cast(runtime); return &runtime_impl; } google::api::expr::runtime::ReferenceResolverOption Convert( ReferenceResolverEnabled enabled) { switch 
(enabled) { case ReferenceResolverEnabled::kCheckedExpressionOnly: return google::api::expr::runtime::ReferenceResolverOption::kCheckedOnly; case ReferenceResolverEnabled::kAlways: return google::api::expr::runtime::ReferenceResolverOption::kAlways; } ABSL_LOG(FATAL) << "unsupported ReferenceResolverEnabled enumerator: " << static_cast(enabled); } } // namespace absl::Status EnableReferenceResolver(RuntimeBuilder& builder, ReferenceResolverEnabled enabled) { CEL_ASSIGN_OR_RETURN(RuntimeImpl * runtime_impl, RuntimeImplFromBuilder(builder)); ABSL_ASSERT(runtime_impl != nullptr); runtime_impl->expr_builder().AddAstTransform( NewReferenceResolverExtension(Convert(enabled))); return absl::OkStatus(); } } // namespace cel ================================================ FILE: runtime/reference_resolver.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_REFERENCE_RESOLVER_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_REFERENCE_RESOLVER_H_ #include "absl/status/status.h" #include "runtime/runtime_builder.h" namespace cel { enum class ReferenceResolverEnabled { kCheckedExpressionOnly, kAlways }; // Enables expression rewrites to normalize the AST representation of // references to qualified names of enum constants, variables and functions. // // For parse-only expressions, this is only able to disambiguate functions based // on registered overloads in the runtime. 
// // Note: This may require making a deep copy of the input expression in order to // apply the rewrites. // // Applied adjustments: // - for dot-qualified variable names represented as select operations, // replaces select operations with an identifier. // - for dot-qualified functions, replaces receiver call with a global // function call. // - for compile time constants (such as enum values), inlines the constant // value as a literal. absl::Status EnableReferenceResolver(RuntimeBuilder& builder, ReferenceResolverEnabled enabled); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_RUNTIME_REFERENCE_RESOLVER_H_ ================================================ FILE: runtime/reference_resolver_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "runtime/reference_resolver.h"

#include 
#include 

#include "cel/expr/checked.pb.h"
#include "cel/expr/syntax.pb.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "base/function_adapter.h"
#include "common/value.h"
#include "extensions/protobuf/runtime_adapter.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "parser/parser.h"
#include "runtime/activation.h"
#include "runtime/register_function_helper.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/text_format.h"

namespace cel {
namespace {

using ::cel::extensions::ProtobufRuntimeAdapter;
using ::cel::expr::CheckedExpr;
using ::cel::expr::Expr;
using ::cel::expr::ParsedExpr;
using ::google::api::expr::parser::Parse;
using ::absl_testing::StatusIs;
using ::testing::HasSubstr;

// With kAlways, the parse-only dot-qualified call `com.example.Exp(...)` is
// rewritten to the global overload registered under that qualified name.
TEST(ReferenceResolver, ResolveQualifiedFunctions) {
  RuntimeOptions options;
  ASSERT_OK_AND_ASSIGN(RuntimeBuilder builder,
                       CreateStandardRuntimeBuilder(
                           internal::GetTestingDescriptorPool(), options));
  ASSERT_OK(
      EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways));

  absl::Status status =
      RegisterHelper>::
          RegisterGlobalOverload(
              "com.example.Exp",
              [](int64_t base, int64_t exp) -> int64_t {
                int64_t result = 1;
                for (int64_t i = 0; i < exp; ++i) {
                  result *= base;
                }
                return result;
              },
              builder.function_registry());
  ASSERT_OK(status);

  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,
                       Parse("com.example.Exp(2, 3) == 8"));

  ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram(
                                         *runtime, parsed_expr));

  google::protobuf::Arena arena;
  Activation activation;

  ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation));
  ASSERT_TRUE(value->Is());
  EXPECT_TRUE(value.GetBool().NativeValue());
}

// With kCheckedExpressionOnly, the same parse-only expression is not
// rewritten, so planning cannot find an overload for the receiver-style call.
TEST(ReferenceResolver, ResolveQualifiedFunctionsCheckedOnly) {
  RuntimeOptions options;
  ASSERT_OK_AND_ASSIGN(RuntimeBuilder builder,
                       CreateStandardRuntimeBuilder(
                           internal::GetTestingDescriptorPool(), options));
  ASSERT_OK(EnableReferenceResolver(
      builder, ReferenceResolverEnabled::kCheckedExpressionOnly));

  absl::Status status =
      RegisterHelper>::
          RegisterGlobalOverload(
              "com.example.Exp",
              [](int64_t base, int64_t exp) -> int64_t {
                int64_t result = 1;
                for (int64_t i = 0; i < exp; ++i) {
                  result *= base;
                }
                return result;
              },
              builder.function_registry());
  ASSERT_OK(status);

  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,
                       Parse("com.example.Exp(2, 3) == 8"));

  EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, parsed_expr),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("No overloads provided")));
}

// com.example.x + com.example.y
constexpr absl::string_view kIdentifierExpression = R"pb(
  reference_map: {
    key: 3
    value: { name: "com.example.x" }
  }
  reference_map: {
    key: 4
    value: { overload_id: "add_int64" }
  }
  reference_map: {
    key: 7
    value: { name: "com.example.y" }
  }
  type_map: {
    key: 3
    value: { primitive: INT64 }
  }
  type_map: {
    key: 4
    value: { primitive: INT64 }
  }
  type_map: {
    key: 7
    value: { primitive: INT64 }
  }
  source_info: {
    location: ""
    line_offsets: 30
    positions: { key: 1 value: 0 }
    positions: { key: 2 value: 3 }
    positions: { key: 3 value: 11 }
    positions: { key: 4 value: 14 }
    positions: { key: 5 value: 16 }
    positions: { key: 6 value: 19 }
    positions: { key: 7 value: 27 }
  }
  expr: {
    id: 4
    call_expr: {
      function: "_+_"
      args: {
        id: 3
        # compilers typically already apply this rewrite, but older saved
        # expressions might preserve the original parse.
        select_expr {
          operand {
            id: 8
            select_expr {
              operand: {
                id: 9
                ident_expr { name: "com" }
              }
              field: "example"
            }
          }
          field: "x"
        }
      }
      args: {
        id: 7
        ident_expr: { name: "com.example.y" }
      }
    }
  })pb";

TEST(ReferenceResolver, ResolveQualifiedIdentifiers) {
  RuntimeOptions options;
  ASSERT_OK_AND_ASSIGN(RuntimeBuilder builder,
                       CreateStandardRuntimeBuilder(
                           internal::GetTestingDescriptorPool(), options));
  ASSERT_OK(EnableReferenceResolver(
      builder, ReferenceResolverEnabled::kCheckedExpressionOnly));

  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  CheckedExpr checked_expr;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kIdentifierExpression,
                                                  &checked_expr));

  ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram(
                                         *runtime, checked_expr));

  google::protobuf::Arena arena;
  Activation activation;

  activation.InsertOrAssignValue("com.example.x", IntValue(3));
  activation.InsertOrAssignValue("com.example.y", IntValue(4));

  ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation));
  ASSERT_TRUE(value->Is());
  EXPECT_EQ(value.GetInt().NativeValue(), 7);
}

TEST(ReferenceResolver, ResolveQualifiedIdentifiersSkipParseOnly) {
  RuntimeOptions options;
  ASSERT_OK_AND_ASSIGN(RuntimeBuilder builder,
                       CreateStandardRuntimeBuilder(
                           internal::GetTestingDescriptorPool(), options));
  ASSERT_OK(EnableReferenceResolver(
      builder, ReferenceResolverEnabled::kCheckedExpressionOnly));

  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  CheckedExpr checked_expr;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kIdentifierExpression,
                                                  &checked_expr));
  // Discard type-check information
  Expr unchecked_expr = checked_expr.expr();

  // Plan the stripped expression (previously `unchecked_expr` was computed but
  // unused), matching ResolveEnumConstantsSkipParseOnly below.
  ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram(
                                         *runtime, unchecked_expr));

  google::protobuf::Arena arena;
  Activation activation;

  activation.InsertOrAssignValue("com.example.x", IntValue(3));
  activation.InsertOrAssignValue("com.example.y", IntValue(4));

  ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation));

  // Without the rewrite the select chain is evaluated literally, so the lookup
  // of the root identifier `com` fails.
  ASSERT_TRUE(value->Is());
  EXPECT_THAT(value.GetError().NativeValue(),
              StatusIs(absl::StatusCode::kUnknown, HasSubstr("\"com\"")));
}

// cel.expr.conformance.proto2.GlobalEnum.GAZ == 2
constexpr absl::string_view kEnumExpr = R"pb(
  reference_map: {
    key: 8
    value: {
      name: "cel.expr.conformance.proto2.GlobalEnum.GAZ"
      value: { int64_value: 2 }
    }
  }
  reference_map: {
    key: 9
    value: { overload_id: "equals" }
  }
  type_map: {
    key: 8
    value: { primitive: INT64 }
  }
  type_map: {
    key: 9
    value: { primitive: BOOL }
  }
  type_map: {
    key: 10
    value: { primitive: INT64 }
  }
  source_info: {
    location: ""
    line_offsets: 1
    line_offsets: 64
    line_offsets: 77
    positions: { key: 1 value: 13 }
    positions: { key: 2 value: 19 }
    positions: { key: 3 value: 23 }
    positions: { key: 4 value: 28 }
    positions: { key: 5 value: 33 }
    positions: { key: 6 value: 36 }
    positions: { key: 7 value: 43 }
    positions: { key: 8 value: 54 }
    positions: { key: 9 value: 59 }
    positions: { key: 10 value: 62 }
  }
  expr: {
    id: 9
    call_expr: {
      function: "_==_"
      args: {
        id: 8
        ident_expr: { name: "cel.expr.conformance.proto2.GlobalEnum.GAZ" }
      }
      args: {
        id: 10
        const_expr: { int64_value: 2 }
      }
    }
  })pb";

TEST(ReferenceResolver, ResolveEnumConstants) {
  RuntimeOptions options;
  ASSERT_OK_AND_ASSIGN(RuntimeBuilder builder,
                       CreateStandardRuntimeBuilder(
                           internal::GetTestingDescriptorPool(), options));
  ASSERT_OK(EnableReferenceResolver(
      builder, ReferenceResolverEnabled::kCheckedExpressionOnly));

  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  CheckedExpr checked_expr;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kEnumExpr, &checked_expr));

  ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram(
                                         *runtime, checked_expr));

  google::protobuf::Arena arena;
  Activation activation;

  ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation));
  ASSERT_TRUE(value->Is());
  EXPECT_TRUE(value.GetBool().NativeValue());
}

TEST(ReferenceResolver, ResolveEnumConstantsSkipParseOnly) {
  RuntimeOptions options;
  ASSERT_OK_AND_ASSIGN(RuntimeBuilder builder,
                       CreateStandardRuntimeBuilder(
                           internal::GetTestingDescriptorPool(), options));
  ASSERT_OK(EnableReferenceResolver(
      builder, ReferenceResolverEnabled::kCheckedExpressionOnly));

  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  CheckedExpr checked_expr;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kEnumExpr, &checked_expr));

  // Discard type-check information so the enum constant is not inlined.
  Expr unchecked_expr = checked_expr.expr();

  ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram(
                                         *runtime, unchecked_expr));

  google::protobuf::Arena arena;
  Activation activation;

  ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation));

  ASSERT_TRUE(value->Is());
  EXPECT_THAT(
      value.GetError().NativeValue(),
      StatusIs(absl::StatusCode::kUnknown,
               HasSubstr("\"cel.expr.conformance.proto2.GlobalEnum.GAZ\"")));
}

}  // namespace
}  // namespace cel

================================================
FILE: runtime/regex_precompilation.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/regex_precompilation.h"

#include "absl/base/macros.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "common/native_type.h"
#include "eval/compiler/regex_precompilation_optimization.h"
#include "internal/casts.h"
#include "internal/status_macros.h"
#include "runtime/internal/runtime_friend_access.h"
#include "runtime/internal/runtime_impl.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"

namespace cel::extensions {
namespace {

using ::cel::internal::down_cast;
using ::cel::runtime_internal::RuntimeFriendAccess;
using ::cel::runtime_internal::RuntimeImpl;
using ::google::api::expr::runtime::CreateRegexPrecompilationExtension;

// Returns the default RuntimeImpl backing `builder`.
//
// The regex optimizer is installed through RuntimeImpl's expression builder,
// so any other cel::Runtime implementation is rejected with an
// UnimplementedError.
absl::StatusOr RuntimeImplFromBuilder(RuntimeBuilder& builder) {
  Runtime& runtime = RuntimeFriendAccess::GetMutableRuntime(builder);

  if (RuntimeFriendAccess::RuntimeTypeId(runtime) !=
      NativeTypeId::For()) {
    return absl::UnimplementedError(
        "regex precompilation only supported on the default cel::Runtime "
        "implementation.");
  }

  RuntimeImpl& runtime_impl = down_cast(runtime);

  return &runtime_impl;
}

}  // namespace

absl::Status EnableRegexPrecompilation(RuntimeBuilder& builder) {
  CEL_ASSIGN_OR_RETURN(RuntimeImpl * runtime_impl,
                       RuntimeImplFromBuilder(builder));
  ABSL_ASSERT(runtime_impl != nullptr);

  // The optimizer receives the runtime's configured regex program-size limit.
  runtime_impl->expr_builder().AddProgramOptimizer(
      CreateRegexPrecompilationExtension(
          runtime_impl->expr_builder().options().regex_max_program_size));
  return absl::OkStatus();
}

}  // namespace cel::extensions

================================================
FILE: runtime/regex_precompilation.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Include guard follows the THIRD_PARTY_CEL_CPP_RUNTIME_<FILE>_H_ convention
// (it was previously a leftover *_FOLDING_H_ name).
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_REGEX_PRECOMPILATION_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_REGEX_PRECOMPILATION_H_

#include "absl/status/status.h"
#include "runtime/runtime_builder.h"

namespace cel::extensions {

// Enable regular expression precompilation.
//
// Attempts to precompile regular expression patterns that are known to be
// constant in 'matches' calls. If an invalid pattern is encountered, expression
// planning will fail instead of returning a program.
//
// Returns an UnimplementedError if `builder` does not wrap the default
// cel::Runtime implementation.
absl::Status EnableRegexPrecompilation(RuntimeBuilder& builder);

}  // namespace cel::extensions

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_REGEX_PRECOMPILATION_H_

================================================
FILE: runtime/regex_precompilation_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/regex_precompilation.h" #include #include #include #include "cel/expr/syntax.pb.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "base/function_adapter.h" #include "common/value.h" #include "extensions/protobuf/runtime_adapter.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "parser/parser.h" #include "runtime/activation.h" #include "runtime/constant_folding.h" #include "runtime/register_function_helper.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "google/protobuf/arena.h" namespace cel::extensions { namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::testing::_; using ::testing::HasSubstr; using ValueMatcher = testing::Matcher; struct TestCase { std::string name; std::string expression; ValueMatcher result_matcher; absl::Status create_status; }; MATCHER_P(IsIntValue, expected, "") { const Value& value = arg; return value->Is() && value.GetInt().NativeValue() == expected; } MATCHER_P(IsBoolValue, expected, "") { const Value& value = arg; return value->Is() && value.GetBool().NativeValue() == expected; } MATCHER_P(IsErrorValue, expected_substr, "") { const Value& value = arg; return value->Is() && absl::StrContains(value.GetError().NativeValue().message(), expected_substr); } class RegexPrecompilationTest : public testing::TestWithParam {}; TEST_P(RegexPrecompilationTest, Basic) { RuntimeOptions options; const TestCase& test_case = GetParam(); ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder, CreateStandardRuntimeBuilder( internal::GetTestingDescriptorPool(), options)); auto status = RegisterHelper, const StringValue&, const StringValue&>>:: RegisterGlobalOverload( "prepend", [](const 
StringValue& value, const StringValue& prefix) { return StringValue( absl::StrCat(prefix.ToString(), value.ToString())); }, builder.function_registry()); ASSERT_THAT(status, IsOk()); ASSERT_THAT(EnableRegexPrecompilation(builder), IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(test_case.expression)); auto program_or = ProtobufRuntimeAdapter::CreateProgram(*runtime, parsed_expr); if (!test_case.create_status.ok()) { ASSERT_THAT(program_or.status(), StatusIs(test_case.create_status.code(), HasSubstr(test_case.create_status.message()))); return; } ASSERT_OK_AND_ASSIGN(auto program, std::move(program_or)); google::protobuf::Arena arena; Activation activation; activation.InsertOrAssignValue("string_var", StringValue(&arena, "string_var")); ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); EXPECT_THAT(value, test_case.result_matcher); } TEST_P(RegexPrecompilationTest, WithConstantFolding) { RuntimeOptions options; const TestCase& test_case = GetParam(); ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder, CreateStandardRuntimeBuilder( internal::GetTestingDescriptorPool(), options)); auto status = RegisterHelper, const StringValue&, const StringValue&>>:: RegisterGlobalOverload( "prepend", [](const StringValue& value, const StringValue& prefix) { return StringValue( absl::StrCat(prefix.ToString(), value.ToString())); }, builder.function_registry()); ASSERT_THAT(status, IsOk()); ASSERT_THAT(EnableConstantFolding(builder), IsOk()); ASSERT_THAT(EnableRegexPrecompilation(builder), IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse(test_case.expression)); auto program_or = ProtobufRuntimeAdapter::CreateProgram(*runtime, parsed_expr); if (!test_case.create_status.ok()) { ASSERT_THAT(program_or.status(), StatusIs(test_case.create_status.code(), HasSubstr(test_case.create_status.message()))); return; } 
ASSERT_OK_AND_ASSIGN(auto program, std::move(program_or)); google::protobuf::Arena arena; Activation activation; activation.InsertOrAssignValue("string_var", StringValue(&arena, "string_var")); ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); EXPECT_THAT(value, test_case.result_matcher); } INSTANTIATE_TEST_SUITE_P( Cases, RegexPrecompilationTest, testing::ValuesIn(std::vector{ {"matches_receiver", R"(string_var.matches(r's\w+_var'))", IsBoolValue(true)}, {"matches_receiver_false", R"(string_var.matches(r'string_var\d+'))", IsBoolValue(false)}, {"matches_global_true", R"(matches(string_var, r's\w+_var'))", IsBoolValue(true)}, {"matches_global_false", R"(matches(string_var, r'string_var\d+'))", IsBoolValue(false)}, {"matches_bad_re2_expression", "matches('123', r'(?& info) { return info.param.name; }); } // namespace } // namespace cel::extensions ================================================ FILE: runtime/register_function_helper.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_REGISTER_FUNCTION_HELPER_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_REGISTER_FUNCTION_HELPER_H_ #include #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "common/function_descriptor.h" #include "runtime/function_registry.h" namespace cel { // Helper class for performing registration with function adapter. 
//
// Usage:
//
//   auto status =
//       RegisterHelper<BinaryFunctionAdapter<bool, int64_t, int64_t>>::
//           RegisterGlobalOverload(
//               "_<_",
//               [](int64_t x, int64_t y) -> bool { return x < y; },
//               registry);
//
//   if (!status.ok()) return status;
//
// Note: if using this with status macros (*RETURN_IF_ERROR), an extra set of
// parentheses is needed around the multi-argument template specifier.
//
// NOTE(review): in this copy the template parameter lists and explicit
// template arguments appear to be missing (e.g. `template class
// RegisterHelper`, `std::forward(fn)`), so the code below does not compile as
// shown; restore them from upstream.
template class RegisterHelper {
 public:
  // Generic registration for an adapted function. Prefer using one of the more
  // specific Register* functions.
  //
  // Builds a descriptor with AdapterT::CreateDescriptor(name, receiver_style,
  // strict), wraps `fn` with AdapterT::WrapFunction, and returns the status of
  // registry.Register.
  template static absl::Status Register(absl::string_view name,
                                        bool receiver_style, FunctionT&& fn,
                                        FunctionRegistry& registry,
                                        bool strict) {
    return registry.Register(
        AdapterT::CreateDescriptor(name, receiver_style, strict),
        AdapterT::WrapFunction(std::forward(fn)));
  }

  // Same as above, but forwards a FunctionDescriptorOptions (default
  // constructed when omitted) to AdapterT::CreateDescriptor instead of a
  // single strictness flag.
  template static absl::Status Register(
      absl::string_view name, bool receiver_style, FunctionT&& fn,
      FunctionRegistry& registry, FunctionDescriptorOptions options = {}) {
    return registry.Register(
        AdapterT::CreateDescriptor(name, receiver_style, options),
        AdapterT::WrapFunction(std::forward(fn)));
  }

  // Registers a global overload (e.g. `size(x)`).
  template static absl::Status RegisterGlobalOverload(
      absl::string_view name, FunctionT&& fn, FunctionRegistry& registry) {
    return Register(name, /*receiver_style=*/false, std::forward(fn),
                    registry);
  }

  // Registers a member overload (e.g. `x.size()`).
  template static absl::Status RegisterMemberOverload(
      absl::string_view name, FunctionT&& fn, FunctionRegistry& registry) {
    return Register(name, /*receiver_style=*/true, std::forward(fn),
                    registry);
  }

  // Registers a non-strict global overload.
  //
  // Non-strict functions may receive errors or unknown values as arguments,
  // and must correctly propagate them.
  //
  // Most extension functions should prefer 'strict' overloads where the
  // evaluator handles unknown and error propagation.
  template static absl::Status RegisterNonStrictOverload(
      absl::string_view name, FunctionT&& fn, FunctionRegistry& registry) {
    return Register(name, /*receiver_style=*/false, std::forward(fn),
                    registry, /*strict=*/false);
  }
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_REGISTER_FUNCTION_HELPER_H_
================================================
FILE: runtime/runtime.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Interfaces for runtime concepts.

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_H_

// NOTE(review): the header names of the next four #include lines, and the
// template arguments of absl::StatusOr / absl::AnyInvocable further down, are
// missing in this copy (likely lost during text extraction); restore them
// from upstream.
#include
#include
#include
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "base/ast.h"
#include "base/type_provider.h"
#include "common/native_type.h"
#include "common/value.h"
#include "runtime/activation_interface.h"
#include "runtime/runtime_issue.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel {

namespace runtime_internal {
class RuntimeFriendAccess;
}  // namespace runtime_internal

class EmbedderContext;

// Options for the Program::Evaluate call.
struct EvaluateOptions {
  // Optional message factory to use for the duration of the Evaluate call.
  // If unset, a default message factory will be provided by the runtime.
  google::protobuf::MessageFactory* absl_nullable message_factory = nullptr;

  // Optional embedder context to use for the duration of the Evaluate call.
  // This is used to access custom data in extension functions.
  // This is only propagated to functions that are marked as context sensitive.
  const EmbedderContext* absl_nullable embedder_context = nullptr;
};

// Representation of an evaluable CEL expression.
//
// See Runtime below for creating new programs.
class Program {
 public:
  virtual ~Program() = default;

  // Evaluate the program.
  //
  // Non-recoverable errors (i.e. outside of CEL's notion of an error) are
  // returned as a non-ok absl::Status. These are propagated immediately and do
  // not participate in CEL's notion of error handling.
  //
  // CEL errors are represented as a result with an OK status and a held
  // cel::ErrorValue result.
  //
  // Activation manages instances of variables available in the CEL
  // expression's environment.
  //
  // The arena will be used as necessary to allocate values and must outlive
  // the returned value, as must this program.
  //
  // For consistency, users should use the same arena to create values
  // in the activation and for Program evaluation.
  absl::StatusOr Evaluate(
      google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND,
      const ActivationInterface& activation,
      const EvaluateOptions& options = {}) const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return EvaluateImpl(activation, arena, options);
  }

  // Legacy overload: equivalent to the overload above with an EvaluateOptions
  // whose only set field is `message_factory`.
  ABSL_DEPRECATED("Use the EvaluateOptions overload instead.")
  absl::StatusOr Evaluate(
      google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND,
      google::protobuf::MessageFactory* absl_nullable message_factory
          ABSL_ATTRIBUTE_LIFETIME_BOUND,
      const ActivationInterface& activation) const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return EvaluateImpl(activation, arena, {message_factory});
  }

  virtual const TypeProvider& GetTypeProvider() const = 0;

 protected:
  // Implementation hook shared by both Evaluate overloads.
  virtual absl::StatusOr EvaluateImpl(
      const ActivationInterface& activation,
      google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND,
      const EvaluateOptions& options) const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0;
};

// Representation for a traceable CEL expression.
//
// Implementations provide an additional Trace method that evaluates the
// expression and invokes a callback allowing callers to inspect intermediate
// state during evaluation.
class TraceableProgram : public Program {
 public:
  // EvaluationListener may be provided to a Trace call to inspect intermediate
  // values during evaluation.
  //
  // The callback is called after every program step that corresponds
  // to an AST expression node. The value provided is the top of the value
  // stack, corresponding to the result of evaluating the given sub expression.
  //
  // Returning a non-ok status from the callback stops evaluation and forwards
  // the error.
  //
  // NOTE(review): the AnyInvocable call signature is missing in this copy;
  // restore it from upstream.
  using EvaluationListener = absl::AnyInvocable;

  using Program::Evaluate;

  // Evaluate the Program plan with a Listener.
  //
  // The given callback will be invoked after evaluating any program step
  // that corresponds to an AST node in the planned CEL expression.
  //
  // If the callback returns a non-ok status, evaluation stops and the Status
  // is forwarded as the result of the Trace call.
  absl::StatusOr Trace(
      google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND,
      const ActivationInterface& activation,
      EvaluationListener evaluation_listener,
      const EvaluateOptions& options = {}) const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return TraceImpl(activation, std::move(evaluation_listener), arena,
                     options);
  }

  // Legacy overload: equivalent to the overload above with an EvaluateOptions
  // whose only set field is `message_factory`.
  ABSL_DEPRECATED("Use the EvaluateOptions overload instead.")
  absl::StatusOr Trace(
      google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND,
      google::protobuf::MessageFactory* absl_nullable message_factory
          ABSL_ATTRIBUTE_LIFETIME_BOUND,
      const ActivationInterface& activation,
      EvaluationListener evaluation_listener) const
      ABSL_ATTRIBUTE_LIFETIME_BOUND {
    return TraceImpl(activation, std::move(evaluation_listener), arena,
                     {message_factory});
  }

 protected:
  // Plain evaluation is implemented as tracing with no listener (nullptr).
  absl::StatusOr EvaluateImpl(const ActivationInterface& activation,
                              google::protobuf::Arena* absl_nonnull arena
                                  ABSL_ATTRIBUTE_LIFETIME_BOUND,
                              const EvaluateOptions& options) const
      ABSL_ATTRIBUTE_LIFETIME_BOUND override {
    return TraceImpl(activation, nullptr, arena, options);
  }

  // Implementation hook for Trace (and, via EvaluateImpl, for Evaluate).
  virtual absl::StatusOr TraceImpl(
      const ActivationInterface& activation,
      EvaluationListener evaluation_listener,
      google::protobuf::Arena* absl_nonnull arena ABSL_ATTRIBUTE_LIFETIME_BOUND,
      const EvaluateOptions& options) const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0;
};

// Interface for a CEL runtime.
//
// Manages the state necessary to generate Programs.
//
// Runtime instances should be created from a RuntimeBuilder rather than
// instantiated directly.
//
// Implementations provided by CEL will be thread-compatible, but write
// operations on the underlying environment (TypeRegistry, FunctionRegistry) or
// on the implementation via downcasting must be synchronized by the caller and
// may invalidate any Programs created from the Runtime.
// NOTE(review): the template arguments of std::vector, absl::StatusOr and
// std::unique_ptr in this class are missing in this copy (likely lost during
// text extraction); restore them from upstream.
class Runtime {
 public:
  // Options for CreateProgram and CreateTraceableProgram.
  struct CreateProgramOptions {
    // Optional output for collecting issues encountered while planning.
    // If non-null, the vector is cleared and encountered issues are added.
    std::vector* issues = nullptr;
  };

  virtual ~Runtime() = default;

  // Creates a Program from `ast` using default CreateProgramOptions.
  absl::StatusOr> CreateProgram(
      std::unique_ptr ast) const {
    return CreateProgram(std::move(ast), CreateProgramOptions{});
  }

  // Creates a Program from `ast` using `options`. Implemented by concrete
  // runtimes.
  virtual absl::StatusOr> CreateProgram(
      std::unique_ptr ast, const CreateProgramOptions& options) const = 0;

  // Like CreateProgram, but for a traceable program, using default
  // CreateProgramOptions.
  absl::StatusOr> CreateTraceableProgram(
      std::unique_ptr ast) const {
    return CreateTraceableProgram(std::move(ast), CreateProgramOptions{});
  }

  // Like CreateProgram, but for a traceable program, using `options`.
  // Implemented by concrete runtimes.
  virtual absl::StatusOr> CreateTraceableProgram(
      std::unique_ptr ast, const CreateProgramOptions& options) const = 0;

  virtual const TypeProvider& GetTypeProvider() const = 0;

  virtual const google::protobuf::DescriptorPool* absl_nonnull
  GetDescriptorPool() const = 0;

  virtual google::protobuf::MessageFactory* absl_nonnull GetMessageFactory()
      const = 0;

 private:
  friend class runtime_internal::RuntimeFriendAccess;

  // Private: callable only by Runtime itself and the friend
  // runtime_internal::RuntimeFriendAccess.
  // NOTE(review): presumably identifies the concrete implementation so the
  // friend can downcast safely; confirm upstream.
  virtual NativeTypeId GetNativeTypeId() const = 0;
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_H_
================================================
FILE: runtime/runtime_builder.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_BUILDER_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_BUILDER_H_

// NOTE(review): the header names of the next two #include lines, and the
// template arguments of absl::StatusOr / std::shared_ptr / std::unique_ptr
// below, are missing in this copy (likely lost during text extraction);
// restore them from upstream.
#include
#include

#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/statusor.h"
#include "runtime/function_registry.h"
#include "runtime/runtime.h"
#include "runtime/runtime_options.h"
#include "runtime/type_registry.h"
#include "google/protobuf/descriptor.h"

namespace cel {

// Forward declare for friend access to avoid requiring a link dependency on the
// standard implementation and some extensions.
namespace runtime_internal {
class RuntimeFriendAccess;
}  // namespace runtime_internal

class RuntimeBuilder;
absl::StatusOr CreateRuntimeBuilder(
    absl_nonnull std::shared_ptr, const RuntimeOptions&);

// RuntimeBuilder provides mutable accessors to configure a new runtime.
//
// Instances of this class are consumed when built.
class RuntimeBuilder {
 public:
  // Move-only
  RuntimeBuilder(const RuntimeBuilder&) = delete;
  RuntimeBuilder& operator=(const RuntimeBuilder&) = delete;
  RuntimeBuilder(RuntimeBuilder&&) = default;
  RuntimeBuilder& operator=(RuntimeBuilder&&) = default;

  // Mutable access to the type registry of the runtime under construction.
  // Must not be called after Build(); debug builds check this with
  // ABSL_DCHECK.
  TypeRegistry& type_registry() {
    ABSL_DCHECK(runtime_ != nullptr);
    return *type_registry_;
  }

  // Mutable access to the function registry of the runtime under
  // construction. Must not be called after Build(); debug builds check this
  // with ABSL_DCHECK.
  FunctionRegistry& function_registry() {
    ABSL_DCHECK(runtime_ != nullptr);
    return *function_registry_;
  }

  // Returns the built runtime.
  //
  // The builder is left in an undefined state after this call and cannot be
  // reused: the runtime is moved out, leaving `runtime_` null, so the
  // registry accessors above would fail their DCHECKs.
  absl::StatusOr> Build() && { return std::move(runtime_); }

 private:
  friend class runtime_internal::RuntimeFriendAccess;
  friend absl::StatusOr CreateRuntimeBuilder(
      absl_nonnull std::shared_ptr, const RuntimeOptions&);

  // Constructor for a new runtime builder.
  //
  // It's assumed that the type registry and function registry are managed by
  // the runtime.
  //
  // CEL users should use one of the factory functions for a new builder.
  // See standard_runtime_builder_factory.h and runtime_builder_factory.h
  RuntimeBuilder(TypeRegistry& type_registry,
                 FunctionRegistry& function_registry,
                 std::unique_ptr runtime)
      : type_registry_(&type_registry),
        function_registry_(&function_registry),
        runtime_(std::move(runtime)) {}

  Runtime& runtime() { return *runtime_; }

  // Non-owning; per the constructor contract these are assumed to point into
  // registries managed by `runtime_`.
  TypeRegistry* type_registry_;
  FunctionRegistry* function_registry_;
  std::unique_ptr runtime_;
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_BUILDER_H_
================================================
FILE: runtime/runtime_builder_factory.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/runtime_builder_factory.h"

// NOTE(review): the header names of the next two #include lines, and the
// template arguments of absl::StatusOr / std::shared_ptr /
// internal::NoopDeleteFor / std::make_shared / std::make_unique below, are
// missing in this copy (likely lost during text extraction); restore them
// from upstream.
#include
#include

#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/statusor.h"
#include "internal/noop_delete.h"
#include "internal/status_macros.h"
#include "runtime/internal/runtime_env.h"
#include "runtime/internal/runtime_impl.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/descriptor.h"

namespace cel {

using ::cel::runtime_internal::RuntimeEnv;
using ::cel::runtime_internal::RuntimeImpl;

// Borrowing overload. Wraps the caller-owned `descriptor_pool` in a
// std::shared_ptr whose deleter is internal::NoopDeleteFor (a no-op, per its
// name), then delegates to the shared_ptr overload. The caller keeps ownership
// and must keep the pool alive as long as the runtime uses it.
absl::StatusOr CreateRuntimeBuilder(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    const RuntimeOptions& options) {
  ABSL_DCHECK(descriptor_pool != nullptr);
  return CreateRuntimeBuilder(
      std::shared_ptr(
          descriptor_pool, internal::NoopDeleteFor()),
      options);
}

// Shared-ownership overload. It:
//   1. creates a RuntimeEnv over `descriptor_pool`;
//   2. initializes the environment, returning early if Initialize() fails;
//   3. creates a RuntimeImpl with `options` and applies `options.container`;
//   4. returns a RuntimeBuilder exposing that RuntimeImpl's type and function
//      registries.
absl::StatusOr CreateRuntimeBuilder(
    absl_nonnull std::shared_ptr descriptor_pool,
    const RuntimeOptions& options) {
  // TODO(uncreated-issue/57): add an internal API for adding extensions that
  // need to downcast to the runtime impl.
  // TODO(uncreated-issue/56): add API for attaching an issue listener
  // (replacing the vector overloads).
  ABSL_DCHECK(descriptor_pool != nullptr);
  auto environment = std::make_shared(std::move(descriptor_pool));
  CEL_RETURN_IF_ERROR(environment->Initialize());
  auto runtime_impl =
      std::make_unique(std::move(environment), options);
  runtime_impl->expr_builder().set_container(options.container);

  // The builder stores non-owning pointers to these registries; the registries
  // themselves are obtained from (and stay with) runtime_impl.
  auto& type_registry = runtime_impl->type_registry();
  auto& function_registry = runtime_impl->function_registry();

  return RuntimeBuilder(type_registry, function_registry,
                        std::move(runtime_impl));
}

}  // namespace cel
================================================
FILE: runtime/runtime_builder_factory.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_BUILDER_FACTORY_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_BUILDER_FACTORY_H_

// NOTE(review): the header name of the next #include, and the template
// arguments of absl::StatusOr / std::shared_ptr below, are missing in this
// copy (likely lost during text extraction); restore them from upstream.
#include

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/descriptor.h"

namespace cel {

// Creates an unconfigured builder using the default Runtime implementation.
//
// The provided descriptor pool is used when dealing with `google.protobuf.Any`
// messages, as well as for implementing struct creation syntax
// `foo.Bar{my_field: 1}`. The descriptor pool must outlive the resulting
// RuntimeBuilder, the `Runtime` it creates, and any `Program` that the
// `Runtime` creates. The descriptor pool must include the minimally necessary
// descriptors required by CEL. Those are the following:
// - google.protobuf.NullValue
// - google.protobuf.BoolValue
// - google.protobuf.Int32Value
// - google.protobuf.Int64Value
// - google.protobuf.UInt32Value
// - google.protobuf.UInt64Value
// - google.protobuf.FloatValue
// - google.protobuf.DoubleValue
// - google.protobuf.BytesValue
// - google.protobuf.StringValue
// - google.protobuf.Any
// - google.protobuf.Duration
// - google.protobuf.Timestamp
//
// This is provided for environments that only use a subset of the CEL standard
// builtins. Most users should prefer CreateStandardRuntimeBuilder.
//
// Callers must register appropriate builtins.
absl::StatusOr CreateRuntimeBuilder(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool
        ABSL_ATTRIBUTE_LIFETIME_BOUND,
    const RuntimeOptions& options);

// Same as above, but takes the descriptor pool as a std::shared_ptr that is
// handed on to the runtime's environment rather than borrowed. The minimum
// descriptor requirements listed above still apply.
absl::StatusOr CreateRuntimeBuilder(
    absl_nonnull std::shared_ptr descriptor_pool,
    const RuntimeOptions& options);

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_BUILDER_FACTORY_H_
================================================
FILE: runtime/runtime_issue.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_ISSUE_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_ISSUE_H_

// NOTE(review): the header name of the next #include is missing in this copy;
// std::move is used below, so it is presumably <utility> -- confirm upstream.
#include

#include "absl/status/status.h"

namespace cel {

// Represents an issue with a given CEL expression.
//
// The error details are represented as an absl::Status for compatibility
// reasons, but users should not depend on this.
class RuntimeIssue {
 public:
  // Severity of the RuntimeIssue.
  //
  // Can be used to determine whether to continue program planning or return
  // early.
  enum class Severity {
    // The issue may lead to runtime errors in evaluation.
    kWarning = 0,
    // The expression is invalid or unsupported.
    kError = 1,
    // Arbitrary max value above Error.
    // NOTE(review): presumably exists so that switches over Severity keep a
    // default case and new severities can be added; confirm upstream.
    kNotForUseWithExhaustiveSwitchStatements = 15
  };

  // Code for well-known runtime error kinds.
  enum class ErrorCode {
    // Overload not provided for given function call signature.
    kNoMatchingOverload,
    // Field access refers to unknown field for given type.
    kNoSuchField,
    // Other error outside the canonical set.
    kOther,
  };

  // Creates an issue with Severity::kError.
  static RuntimeIssue CreateError(absl::Status status,
                                  ErrorCode error_code = ErrorCode::kOther) {
    return RuntimeIssue(std::move(status), Severity::kError, error_code);
  }

  // Creates an issue with Severity::kWarning.
  static RuntimeIssue CreateWarning(absl::Status status,
                                    ErrorCode error_code = ErrorCode::kOther) {
    return RuntimeIssue(std::move(status), Severity::kWarning, error_code);
  }

  RuntimeIssue(const RuntimeIssue& other) = default;
  RuntimeIssue& operator=(const RuntimeIssue& other) = default;
  RuntimeIssue(RuntimeIssue&& other) = default;
  RuntimeIssue& operator=(RuntimeIssue&& other) = default;

  Severity severity() const { return severity_; }

  ErrorCode error_code() const { return error_code_; }

  // Returns the underlying status. The rvalue-qualified overload moves the
  // status out of this issue instead of copying it.
  const absl::Status& ToStatus() const& { return status_; }
  absl::Status ToStatus() && { return std::move(status_); }

 private:
  // Instances are created through CreateError / CreateWarning.
  RuntimeIssue(absl::Status status, Severity severity, ErrorCode error_code)
      : status_(std::move(status)),
        error_code_(error_code),
        severity_(severity) {}

  absl::Status status_;
  ErrorCode error_code_;
  Severity severity_;
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_ISSUE_H_
================================================
FILE: runtime/runtime_options.h
================================================
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_OPTIONS_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_OPTIONS_H_

// NOTE(review): the header name of the next #include is missing in this copy;
// std::string is used below, so it is presumably <string> -- confirm upstream.
#include

#include "absl/base/attributes.h"

namespace cel {

// Options for unknown processing.
enum class UnknownProcessingOptions {
  // No unknown processing.
  kDisabled,
  // Only attributes supported.
  kAttributeOnly,
  // Attributes and functions supported. Function results are dependent on the
  // logic for handling unknown_attributes, so clients must opt in to both.
  kAttributeAndFunction
};

// Options for handling unset wrapper types on field access.
enum class ProtoWrapperTypeOptions {
  // Default: legacy behavior following proto semantics (unset behaves as though
  // it is set to default value).
  kUnsetProtoDefault,
  // CEL spec behavior, unset wrapper is treated as a null value when accessed.
  kUnsetNull,
};

// LINT.IfChange
// Interpreter options for controlling evaluation and builtin functions.
//
// Members should provide simple parameters for configuring core features and
// built-ins.
//
// Optimizations or features that have a heavy footprint should be added via an
// extension API.
struct RuntimeOptions {
  // Default container for resolving variables, types, and functions.
  // Follows protobuf namespace rules.
  std::string container = "";

  // Level of unknown support enabled.
  UnknownProcessingOptions unknown_processing =
      UnknownProcessingOptions::kDisabled;

  // NOTE(review): undocumented upstream; presumably turns accesses to
  // attributes marked as missing into errors -- confirm before relying on it.
  bool enable_missing_attribute_errors = false;

  // Enable timestamp duration overflow checks.
  //
  // The CEL-Spec indicates that overflow should occur outside the range of
  // string-representable timestamps, and at the limit of durations which can be
  // expressed with a single int64 value.
  bool enable_timestamp_duration_overflow_errors = false;

  // Enable short-circuiting of the logical operator evaluation. If enabled,
  // AND, OR, and TERNARY do not evaluate the entire expression once the
  // resulting value is known from the left-hand side.
  bool short_circuiting = true;

  // Enable comprehension expressions (e.g. exists, all).
  bool enable_comprehension = true;

  // Set maximum number of iterations in the comprehension expressions if
  // comprehensions are enabled. The limit applies globally per evaluation,
  // including nested loops. Use value 0 to disable the upper bound.
  int comprehension_max_iterations = 10000;

  // Enable list append within comprehensions. Note, this option is not safe
  // with hand-rolled ASTs.
  bool enable_comprehension_list_append = false;

  // Enable mutable map construction within comprehensions. Note, this option is
  // not safe with hand-rolled ASTs.
  bool enable_comprehension_mutable_map = false;

  // Enable RE2 match() overload.
  bool enable_regex = true;

  // Set maximum program size for RE2 regex if regex overload is enabled.
  // Evaluates to an error if a regex exceeds it. Use value 0 to disable the
  // upper bound.
  int regex_max_program_size = 0;

  // Enable string() overloads.
  bool enable_string_conversion = true;

  // Enable string concatenation overload.
  bool enable_string_concat = true;

  // Enable list concatenation overload.
  bool enable_list_concat = true;

  // Enable list membership overload.
  bool enable_list_contains = true;

  // Treat builder warnings as fatal errors.
  bool fail_on_warnings = true;

  // Enable the resolution of qualified type identifiers as type values instead
  // of field selections.
  //
  // This toggle may cause certain identifiers which overlap with CEL built-in
  // types or with protobuf message types linked into the binary to be resolved
  // as static type values rather than as per-eval variables.
  bool enable_qualified_type_identifiers = false;

  // Enable heterogeneous comparisons (e.g. support for cross-type comparisons).
  ABSL_DEPRECATED(
      "The ability to disable heterogeneous equality is being removed in the "
      "near future")
  bool enable_heterogeneous_equality = true;

  // Enables unwrapping proto wrapper types to null if unset. e.g. if an
  // expression accesses a field of type google.protobuf.Int64Value that is
  // unset, that will result in a Null cel value, as opposed to returning the
  // cel representation of the proto defined default int64: 0.
  bool enable_empty_wrapper_null_unboxing = false;

  // Enable lazy cel.bind alias initialization.
  //
  // This is now always enabled. Setting this option has no effect. It will be
  // removed in a later update.
  bool enable_lazy_bind_initialization = true;

  // Enable recursive planning with a maximum recursion depth for evaluable
  // programs.
  //
  // This limit is proportional to the maximum number of recursive Evaluate
  // calls that a single expression program might require while evaluating. This
  // is coarse -- the actual C++ stack requirements will vary depending on the
  // expression.
  //
  // This does not account for re-entrant evaluation in a client's extension
  // function (i.e. a CEL function that calls Evaluate on another CEL program).
  //
  // If the limit is exceeded, the planner will return an error instead of
  // planning the program.
  //
  // -1 means unbounded.
  // 0 means disabled (using a heap-based stack machine instead), which is the
  // default.
  int max_recursion_depth = 0;

  // Enable tracing support for recursively planned programs.
  //
  // Unlike the stack machine implementation, supporting tracing can affect
  // performance whether or not tracing is requested for a given evaluation.
  bool enable_recursive_tracing = false;

  // Enable fast implementations for some CEL standard functions.
  //
  // Uses a custom implementation for some functions in the CEL standard,
  // bypassing normal dispatching logic and safety checks for functions.
  //
  // This prevents extending or disabling these functions in most cases. The
  // expression planner will make a best effort attempt to check if custom
  // overloads have been added for these functions, and will attempt to use them
  // if they exist.
  //
  // Currently applies to !_, @not_strictly_false, _==_, _!=_, @in
  bool enable_fast_builtins = true;

  // When enabled, string(double) will format the double with enough precision
  // to ensure that the original double value can be recovered exactly.
  //
  // If available, will use the `std::to_chars` standard library function to
  // perform the conversion to generate the shortest representation.
  //
  // Otherwise, will fall back to formatting with the worst-case required
  // precision.
  //
  // If disabled, will use the legacy behavior of rounding to 6 decimal places.
  bool enable_precision_preserving_double_format = true;
};
// LINT.ThenChange(//depot/google3/eval/public/cel_options.h)

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_RUNTIME_OPTIONS_H_
================================================
FILE: runtime/standard/BUILD
================================================
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

# Provides registrars for CEL standard definitions.
# TODO(uncreated-issue/41): CEL users shouldn't need to use these directly; instead
# they should prefer to use RegisterBuiltins when available.
package(
    # Under active development, not yet being released.
    default_visibility = ["//visibility:public"],
)

cc_library(
    name = "comparison_functions",
    srcs = [
        "comparison_functions.cc",
    ],
    hdrs = [
        "comparison_functions.h",
    ],
    deps = [
        "//base:builtins",
        "//base:function_adapter",
        "//common:value",
        "//internal:number",
        "//internal:status_macros",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/time",
    ],
)

cc_test(
    name = "comparison_functions_test",
    size = "small",
    srcs = [
        "comparison_functions_test.cc",
    ],
    deps = [
        ":comparison_functions",
        "//base:builtins",
        "//common:kind",
        "//internal:testing",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "container_membership_functions",
    srcs = [
        "container_membership_functions.cc",
    ],
    hdrs = [
        "container_membership_functions.h",
    ],
    deps = [
        "//base:builtins",
        "//base:function_adapter",
        "//common:value",
        "//internal:number",
        "//internal:status_macros",
        "//runtime:function_registry",
        "//runtime:register_function_helper",
        "//runtime:runtime_options",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "container_membership_functions_test",
    size = "small",
    srcs = [
        "container_membership_functions_test.cc",
    ],
    deps = [
        ":container_membership_functions",
        "//base:builtins",
        "//common:function_descriptor",
        "//common:kind",
        "//internal:testing",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "equality_functions",
    srcs = ["equality_functions.cc"],
    hdrs = ["equality_functions.h"],
    deps = [
        "//base:builtins",
        "//base:function_adapter",
        "//common:value",
        "//common:value_kind",
        "//internal:number",
        "//internal:status_macros",
        "//runtime:function_registry",
        "//runtime:register_function_helper",
        "//runtime:runtime_options",
        "//runtime/internal:errors",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/functional:function_ref",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:optional",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "equality_functions_test",
    size = "small",
    srcs = [
        "equality_functions_test.cc",
    ],
    deps = [
        ":equality_functions",
        "//base:builtins",
        "//common:function_descriptor",
        "//common:kind",
        "//internal:testing",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/status:status_matchers",
    ],
)

cc_library(
    name = "logical_functions",
    srcs = [
        "logical_functions.cc",
    ],
    hdrs = [
        "logical_functions.h",
    ],
    deps = [
        "//base:builtins",
        "//base:function_adapter",
        "//common:value",
        "//internal:status_macros",
        "//runtime:function_registry",
        "//runtime:register_function_helper",
        "//runtime:runtime_options",
        "//runtime/internal:errors",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
)

cc_test(
    name = "logical_functions_test",
    size = "small",
    srcs = [
        "logical_functions_test.cc",
    ],
    deps = [
        ":logical_functions",
        "//base:builtins",
        "//common:function_descriptor",
        "//common:kind",
        "//common:value",
        "//internal:testing",
        "//internal:testing_descriptor_pool",
        "//internal:testing_message_factory",
        "//runtime:function",
        "//runtime:function_overload_reference",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "container_functions",
    srcs = ["container_functions.cc"],
    hdrs = ["container_functions.h"],
    deps = [
        "//base:builtins",
        "//base:function_adapter",
        "//common:value",
        "//internal:status_macros",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "container_functions_test",
    size = "small",
    srcs = [
        "container_functions_test.cc",
    ],
    deps = [
        ":container_functions",
        "//base:builtins",
        "//common:function_descriptor",
        "//internal:testing",
    ],
)

cc_library(
    name = "type_conversion_functions",
    srcs = ["type_conversion_functions.cc"],
    hdrs = ["type_conversion_functions.h"],
    deps = [
        "//base:builtins",
        "//base:function_adapter",
        "//common:value",
        "//internal:overflow",
        "//internal:status_macros",
        "//internal:time",
        "//internal:utf8",
        "//runtime:function",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/time",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "type_conversion_functions_test",
    size = "small",
    srcs = [
        "type_conversion_functions_test.cc",
    ],
    deps = [
        ":type_conversion_functions",
        "//base:builtins",
        "//common:function_descriptor",
        "//internal:testing",
    ],
)

cc_library(
    name = "arithmetic_functions",
    srcs = ["arithmetic_functions.cc"],
    hdrs = ["arithmetic_functions.h"],
    deps = [
        "//base:builtins",
        "//base:function_adapter",
        "//common:value",
        "//internal:overflow",
        "//internal:status_macros",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
)

cc_test(
    name = "arithmetic_functions_test",
    size = "small",
    srcs = [
        "arithmetic_functions_test.cc",
    ],
    deps = [
        ":arithmetic_functions",
        "//base:builtins",
        "//common:function_descriptor",
        "//internal:testing",
    ],
)

cc_library(
    name = "time_functions",
    srcs = ["time_functions.cc"],
    hdrs = ["time_functions.h"],
    deps = [
        "//base:builtins",
        "//base:function_adapter",
        "//common:value",
        "//internal:overflow",
        "//internal:status_macros",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
    ],
)

cc_test(
    name = "time_functions_test",
    size = "small",
    srcs = [
        "time_functions_test.cc",
    ],
    deps = [
        ":time_functions",
        "//base:builtins",
        "//common:function_descriptor",
        "//internal:testing",
    ],
)

cc_library(
    name = "string_functions",
    srcs = ["string_functions.cc"],
    hdrs = ["string_functions.h"],
    deps = [
        "//base:builtins",
        "//base:function_adapter",
        "//common:value",
        "//internal:status_macros",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "string_functions_test",
    size = "small",
    srcs = [
        "string_functions_test.cc",
    ],
    deps = [
        ":string_functions",
        "//base:builtins",
        "//common:function_descriptor",
        "//internal:testing",
    ],
)

cc_library(
    name = "regex_functions",
    srcs = ["regex_functions.cc"],
    hdrs = ["regex_functions.h"],
    deps = [
        "//base:builtins",
        "//base:function_adapter",
        "//common:value",
        "//internal:re2_options",
        "//internal:status_macros",
        "//runtime:function_registry",
        "//runtime:runtime_options",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_googlesource_code_re2//:re2",
    ],
)

# NOTE(review): unlike the other tests in this package, no `size` is set here,
# so Bazel's default ("medium") applies.
cc_test(
    name = "regex_functions_test",
    srcs = ["regex_functions_test.cc"],
    deps = [
        ":regex_functions",
        "//base:builtins",
        "//common:function_descriptor",
        "//internal:testing",
    ],
)
================================================
FILE: runtime/standard/arithmetic_functions.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/standard/arithmetic_functions.h" #include #include #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "base/builtins.h" #include "base/function_adapter.h" #include "common/value.h" #include "internal/overflow.h" #include "internal/status_macros.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" namespace cel { namespace { // Template functions providing arithmetic operations template Value Add(Type v0, Type v1); template <> Value Add(int64_t v0, int64_t v1) { auto sum = cel::internal::CheckedAdd(v0, v1); if (!sum.ok()) { return ErrorValue(sum.status()); } return IntValue(*sum); } template <> Value Add(uint64_t v0, uint64_t v1) { auto sum = cel::internal::CheckedAdd(v0, v1); if (!sum.ok()) { return ErrorValue(sum.status()); } return UintValue(*sum); } template <> Value Add(double v0, double v1) { return DoubleValue(v0 + v1); } template Value Sub(Type v0, Type v1); template <> Value Sub(int64_t v0, int64_t v1) { auto diff = cel::internal::CheckedSub(v0, v1); if (!diff.ok()) { return ErrorValue(diff.status()); } return IntValue(*diff); } template <> Value Sub(uint64_t v0, uint64_t v1) { auto diff = cel::internal::CheckedSub(v0, v1); if (!diff.ok()) { return ErrorValue(diff.status()); } return UintValue(*diff); } template <> Value Sub(double v0, double v1) { return DoubleValue(v0 - v1); } template Value Mul(Type v0, Type v1); template <> Value Mul(int64_t v0, int64_t v1) { auto prod = cel::internal::CheckedMul(v0, v1); if (!prod.ok()) { return ErrorValue(prod.status()); } return 
IntValue(*prod); } template <> Value Mul(uint64_t v0, uint64_t v1) { auto prod = cel::internal::CheckedMul(v0, v1); if (!prod.ok()) { return ErrorValue(prod.status()); } return UintValue(*prod); } template <> Value Mul(double v0, double v1) { return DoubleValue(v0 * v1); } template Value Div(Type v0, Type v1); // Division operations for integer types should check for // division by 0 template <> Value Div(int64_t v0, int64_t v1) { auto quot = cel::internal::CheckedDiv(v0, v1); if (!quot.ok()) { return ErrorValue(quot.status()); } return IntValue(*quot); } // Division operations for integer types should check for // division by 0 template <> Value Div(uint64_t v0, uint64_t v1) { auto quot = cel::internal::CheckedDiv(v0, v1); if (!quot.ok()) { return ErrorValue(quot.status()); } return UintValue(*quot); } template <> Value Div(double v0, double v1) { static_assert(std::numeric_limits::is_iec559, "Division by zero for doubles must be supported"); // For double, division will result in +/- inf return DoubleValue(v0 / v1); } // Modulo operation template Value Modulo(Type v0, Type v1); // Modulo operations for integer types should check for // division by 0 template <> Value Modulo(int64_t v0, int64_t v1) { auto mod = cel::internal::CheckedMod(v0, v1); if (!mod.ok()) { return ErrorValue(mod.status()); } return IntValue(*mod); } template <> Value Modulo(uint64_t v0, uint64_t v1) { auto mod = cel::internal::CheckedMod(v0, v1); if (!mod.ok()) { return ErrorValue(mod.status()); } return UintValue(*mod); } // Helper method // Registers all arithmetic functions for template parameter type. 
template absl::Status RegisterArithmeticFunctionsForType(FunctionRegistry& registry) { using FunctionAdapter = cel::BinaryFunctionAdapter; CEL_RETURN_IF_ERROR(registry.Register( FunctionAdapter::CreateDescriptor(cel::builtin::kAdd, false), FunctionAdapter::WrapFunction(&Add))); CEL_RETURN_IF_ERROR(registry.Register( FunctionAdapter::CreateDescriptor(cel::builtin::kSubtract, false), FunctionAdapter::WrapFunction(&Sub))); CEL_RETURN_IF_ERROR(registry.Register( FunctionAdapter::CreateDescriptor(cel::builtin::kMultiply, false), FunctionAdapter::WrapFunction(&Mul))); return registry.Register( FunctionAdapter::CreateDescriptor(cel::builtin::kDivide, false), FunctionAdapter::WrapFunction(&Div)); } } // namespace absl::Status RegisterArithmeticFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType(registry)); CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType(registry)); CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType(registry)); // Modulo CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter::CreateDescriptor( cel::builtin::kModulo, false), BinaryFunctionAdapter::WrapFunction( &Modulo))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter::CreateDescriptor( cel::builtin::kModulo, false), BinaryFunctionAdapter::WrapFunction( &Modulo))); // Negation group CEL_RETURN_IF_ERROR( registry.Register(UnaryFunctionAdapter::CreateDescriptor( cel::builtin::kNeg, false), UnaryFunctionAdapter::WrapFunction( [](int64_t value) -> Value { auto inv = cel::internal::CheckedNegation(value); if (!inv.ok()) { return ErrorValue(inv.status()); } return IntValue(*inv); }))); return registry.Register( UnaryFunctionAdapter::CreateDescriptor(cel::builtin::kNeg, false), UnaryFunctionAdapter::WrapFunction( [](double value) -> double { return -value; })); } } // namespace cel ================================================ FILE: runtime/standard/arithmetic_functions.h 
================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_ARITHMETIC_FUNCTIONS_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_ARITHMETIC_FUNCTIONS_H_ #include "absl/status/status.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" namespace cel { // Register builtin arithmetic operators: // _+_ (addition), _-_ (subtraction), -_ (negation), _/_ (division), // _*_ (multiplication), _%_ (modulo) // // Most users should use RegisterBuiltinFunctions, which includes these // definitions. absl::Status RegisterArithmeticFunctions(FunctionRegistry& registry, const RuntimeOptions& options); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_ARITHMETIC_FUNCTIONS_H_ ================================================ FILE: runtime/standard/arithmetic_functions_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "runtime/standard/arithmetic_functions.h" #include #include "base/builtins.h" #include "common/function_descriptor.h" #include "internal/testing.h" namespace cel { namespace { using ::testing::UnorderedElementsAre; MATCHER_P2(MatchesOperatorDescriptor, name, expected_kind, "") { const FunctionDescriptor& descriptor = arg.descriptor; std::vector types{expected_kind, expected_kind}; return descriptor.name() == name && descriptor.receiver_style() == false && descriptor.types() == types; } MATCHER_P(MatchesNegationDescriptor, expected_kind, "") { const FunctionDescriptor& descriptor = arg.descriptor; std::vector types{expected_kind}; return descriptor.name() == builtin::kNeg && descriptor.receiver_style() == false && descriptor.types() == types; } TEST(RegisterArithmeticFunctions, Registered) { FunctionRegistry registry; RuntimeOptions options; ASSERT_OK(RegisterArithmeticFunctions(registry, options)); EXPECT_THAT(registry.FindStaticOverloads(builtin::kAdd, false, {Kind::kAny, Kind::kAny}), UnorderedElementsAre( MatchesOperatorDescriptor(builtin::kAdd, Kind::kInt), MatchesOperatorDescriptor(builtin::kAdd, Kind::kDouble), MatchesOperatorDescriptor(builtin::kAdd, Kind::kUint))); EXPECT_THAT(registry.FindStaticOverloads(builtin::kSubtract, false, {Kind::kAny, Kind::kAny}), UnorderedElementsAre( MatchesOperatorDescriptor(builtin::kSubtract, Kind::kInt), MatchesOperatorDescriptor(builtin::kSubtract, Kind::kDouble), MatchesOperatorDescriptor(builtin::kSubtract, Kind::kUint))); EXPECT_THAT(registry.FindStaticOverloads(builtin::kDivide, false, {Kind::kAny, Kind::kAny}), UnorderedElementsAre( MatchesOperatorDescriptor(builtin::kDivide, Kind::kInt), MatchesOperatorDescriptor(builtin::kDivide, Kind::kDouble), MatchesOperatorDescriptor(builtin::kDivide, Kind::kUint))); EXPECT_THAT(registry.FindStaticOverloads(builtin::kMultiply, false, {Kind::kAny, Kind::kAny}), 
UnorderedElementsAre( MatchesOperatorDescriptor(builtin::kMultiply, Kind::kInt), MatchesOperatorDescriptor(builtin::kMultiply, Kind::kDouble), MatchesOperatorDescriptor(builtin::kMultiply, Kind::kUint))); EXPECT_THAT(registry.FindStaticOverloads(builtin::kModulo, false, {Kind::kAny, Kind::kAny}), UnorderedElementsAre( MatchesOperatorDescriptor(builtin::kModulo, Kind::kInt), MatchesOperatorDescriptor(builtin::kModulo, Kind::kUint))); EXPECT_THAT(registry.FindStaticOverloads(builtin::kNeg, false, {Kind::kAny}), UnorderedElementsAre(MatchesNegationDescriptor(Kind::kInt), MatchesNegationDescriptor(Kind::kDouble))); } // TODO(uncreated-issue/41): move functional parsed expr tests when modern APIs for // evaluator available. } // namespace } // namespace cel ================================================ FILE: runtime/standard/comparison_functions.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "runtime/standard/comparison_functions.h"

#include <cstdint>

#include "absl/status/status.h"
#include "absl/time/time.h"
#include "base/builtins.h"
#include "base/function_adapter.h"
#include "common/value.h"
#include "internal/number.h"
#include "internal/status_macros.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"

namespace cel {
namespace {

using ::cel::internal::Number;

// Comparison template functions. The generic versions use the relational
// operators of `Type`; GreaterThan / GreaterThanOrEqual are expressed by
// swapping the operands of LessThan / LessThanOrEqual.
template <class Type>
bool LessThan(Type t1, Type t2) {
  return (t1 < t2);
}

template <class Type>
bool LessThanOrEqual(Type t1, Type t2) {
  return (t1 <= t2);
}

template <class Type>
bool GreaterThan(Type t1, Type t2) {
  return LessThan(t2, t1);
}

template <class Type>
bool GreaterThanOrEqual(Type t1, Type t2) {
  return LessThanOrEqual(t2, t1);
}

// String value comparison specializations.
template <>
bool LessThan<const StringValue&>(const StringValue& t1,
                                  const StringValue& t2) {
  return t1.Compare(t2) < 0;
}

template <>
bool LessThanOrEqual<const StringValue&>(const StringValue& t1,
                                         const StringValue& t2) {
  return t1.Compare(t2) <= 0;
}

template <>
bool GreaterThan<const StringValue&>(const StringValue& t1,
                                     const StringValue& t2) {
  return t1.Compare(t2) > 0;
}

template <>
bool GreaterThanOrEqual<const StringValue&>(const StringValue& t1,
                                            const StringValue& t2) {
  return t1.Compare(t2) >= 0;
}

// Bytes value comparison specializations.
template <>
bool LessThan<const BytesValue&>(const BytesValue& t1, const BytesValue& t2) {
  return t1.Compare(t2) < 0;
}

template <>
bool LessThanOrEqual<const BytesValue&>(const BytesValue& t1,
                                        const BytesValue& t2) {
  return t1.Compare(t2) <= 0;
}

template <>
bool GreaterThan<const BytesValue&>(const BytesValue& t1,
                                    const BytesValue& t2) {
  return t1.Compare(t2) > 0;
}

template <>
bool GreaterThanOrEqual<const BytesValue&>(const BytesValue& t1,
                                           const BytesValue& t2) {
  return t1.Compare(t2) >= 0;
}

// Duration comparison specializations.
template <>
bool LessThan<absl::Duration>(absl::Duration t1, absl::Duration t2) {
  return absl::operator<(t1, t2);
}

template <>
bool LessThanOrEqual<absl::Duration>(absl::Duration t1, absl::Duration t2) {
  return absl::operator<=(t1, t2);
}

template <>
bool GreaterThan<absl::Duration>(absl::Duration t1, absl::Duration t2) {
  return absl::operator>(t1, t2);
}

template <>
bool GreaterThanOrEqual<absl::Duration>(absl::Duration t1, absl::Duration t2) {
  return absl::operator>=(t1, t2);
}

// Timestamp comparison specializations.
template <>
bool LessThan<absl::Time>(absl::Time t1, absl::Time t2) {
  return absl::operator<(t1, t2);
}

template <>
bool LessThanOrEqual<absl::Time>(absl::Time t1, absl::Time t2) {
  return absl::operator<=(t1, t2);
}

template <>
bool GreaterThan<absl::Time>(absl::Time t1, absl::Time t2) {
  return absl::operator>(t1, t2);
}

template <>
bool GreaterThanOrEqual<absl::Time>(absl::Time t1, absl::Time t2) {
  return absl::operator>=(t1, t2);
}

// Cross-numeric comparisons: operands of different numeric types are both
// wrapped in cel::internal::Number and compared through it.
template <typename T, typename U>
bool CrossNumericLessThan(T t, U u) {
  return Number(t) < Number(u);
}

template <typename T, typename U>
bool CrossNumericGreaterThan(T t, U u) {
  return Number(t) > Number(u);
}

template <typename T, typename U>
bool CrossNumericLessOrEqualTo(T t, U u) {
  return Number(t) <= Number(u);
}

template <typename T, typename U>
bool CrossNumericGreaterOrEqualTo(T t, U u) {
  return Number(t) >= Number(u);
}

// Registers the global <, <=, > and >= overloads for two operands of `Type`.
template <class Type>
absl::Status RegisterComparisonFunctionsForType(
    cel::FunctionRegistry& registry) {
  using FunctionAdapter = BinaryFunctionAdapter<bool, Type, Type>;
  CEL_RETURN_IF_ERROR(registry.Register(
      FunctionAdapter::CreateDescriptor(cel::builtin::kLess, false),
      FunctionAdapter::WrapFunction(LessThan<Type>)));

  CEL_RETURN_IF_ERROR(registry.Register(
      FunctionAdapter::CreateDescriptor(cel::builtin::kLessOrEqual, false),
      FunctionAdapter::WrapFunction(LessThanOrEqual<Type>)));

  CEL_RETURN_IF_ERROR(registry.Register(
      FunctionAdapter::CreateDescriptor(cel::builtin::kGreater, false),
      FunctionAdapter::WrapFunction(GreaterThan<Type>)));

  CEL_RETURN_IF_ERROR(registry.Register(
      FunctionAdapter::CreateDescriptor(cel::builtin::kGreaterOrEqual, false),
      FunctionAdapter::WrapFunction(GreaterThanOrEqual<Type>)));

  return absl::OkStatus();
}

// Registers comparisons between operands of the same type for every orderable
// type: bool, int, uint, double, string, bytes, duration and timestamp.
absl::Status RegisterHomogenousComparisonFunctions(
    cel::FunctionRegistry& registry) {
  CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType<bool>(registry));
  CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType<int64_t>(registry));
  CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType<uint64_t>(registry));
  CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType<double>(registry));
  CEL_RETURN_IF_ERROR(
      RegisterComparisonFunctionsForType<const StringValue&>(registry));
  CEL_RETURN_IF_ERROR(
      RegisterComparisonFunctionsForType<const BytesValue&>(registry));
  CEL_RETURN_IF_ERROR(
      RegisterComparisonFunctionsForType<absl::Duration>(registry));
  CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType<absl::Time>(registry));
  return absl::OkStatus();
}

// Registers the global <, >, >= and <= overloads for operands of types (T, U).
template <typename T, typename U>
absl::Status RegisterCrossNumericComparisons(cel::FunctionRegistry& registry) {
  using FunctionAdapter = BinaryFunctionAdapter<bool, T, U>;
  CEL_RETURN_IF_ERROR(registry.Register(
      FunctionAdapter::CreateDescriptor(cel::builtin::kLess,
                                        /*receiver_style=*/false),
      FunctionAdapter::WrapFunction(&CrossNumericLessThan<T, U>)));
  CEL_RETURN_IF_ERROR(registry.Register(
      FunctionAdapter::CreateDescriptor(cel::builtin::kGreater,
                                        /*receiver_style=*/false),
      FunctionAdapter::WrapFunction(&CrossNumericGreaterThan<T, U>)));
  CEL_RETURN_IF_ERROR(registry.Register(
      FunctionAdapter::CreateDescriptor(cel::builtin::kGreaterOrEqual,
                                        /*receiver_style=*/false),
      FunctionAdapter::WrapFunction(&CrossNumericGreaterOrEqualTo<T, U>)));
  CEL_RETURN_IF_ERROR(registry.Register(
      FunctionAdapter::CreateDescriptor(cel::builtin::kLessOrEqual,
                                        /*receiver_style=*/false),
      FunctionAdapter::WrapFunction(&CrossNumericLessOrEqualTo<T, U>)));
  return absl::OkStatus();
}

// Registers cross-numeric comparisons for every ordered pair of distinct
// numeric types, followed by the same-type comparisons.
absl::Status RegisterHeterogeneousComparisonFunctions(
    cel::FunctionRegistry& registry) {
  // The extra parentheses keep the template argument comma out of the macro
  // argument list.
  CEL_RETURN_IF_ERROR(
      (RegisterCrossNumericComparisons<double, int64_t>(registry)));
  CEL_RETURN_IF_ERROR(
      (RegisterCrossNumericComparisons<double, uint64_t>(registry)));

  CEL_RETURN_IF_ERROR(
      (RegisterCrossNumericComparisons<uint64_t, double>(registry)));
  CEL_RETURN_IF_ERROR(
      (RegisterCrossNumericComparisons<uint64_t, int64_t>(registry)));

  CEL_RETURN_IF_ERROR(
      (RegisterCrossNumericComparisons<int64_t, double>(registry)));
  CEL_RETURN_IF_ERROR(
      (RegisterCrossNumericComparisons<int64_t, uint64_t>(registry)));

  // The same-type overloads are exactly the homogeneous set; reuse it rather
  // than maintaining a second copy of the list.
  CEL_RETURN_IF_ERROR(RegisterHomogenousComparisonFunctions(registry));

  return absl::OkStatus();
}

}  // namespace

absl::Status RegisterComparisonFunctions(FunctionRegistry& registry,
                                         const RuntimeOptions& options) {
  if (options.enable_heterogeneous_equality) {
    CEL_RETURN_IF_ERROR(RegisterHeterogeneousComparisonFunctions(registry));
  } else {
    CEL_RETURN_IF_ERROR(RegisterHomogenousComparisonFunctions(registry));
  }

  return absl::OkStatus();
}

}  // namespace cel

================================================
FILE: runtime/standard/comparison_functions.h
================================================
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_COMPARISON_FUNCTIONS_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_COMPARISON_FUNCTIONS_H_

#include "absl/status/status.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"

namespace cel {

// Register built-in comparison functions (<, <=, >, >=).
//
// Most users should prefer to use RegisterBuiltinFunctions.
// // This is call is included in RegisterBuiltinFunctions -- calling both // RegisterBuiltinFunctions and RegisterComparisonFunctions directly on the same // registry will result in an error. absl::Status RegisterComparisonFunctions(FunctionRegistry& registry, const RuntimeOptions& options); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_COMPARISON_FUNCTIONS_H_ ================================================ FILE: runtime/standard/comparison_functions_test.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "runtime/standard/comparison_functions.h"

#include <array>

#include "absl/strings/str_cat.h"
#include "base/builtins.h"
#include "common/kind.h"
#include "internal/testing.h"

namespace cel {
namespace {

// Matches a FunctionRegistry that has a non receiver-style overload of `name`
// taking two operands of `argument_kind`.
MATCHER_P2(DefinesHomogenousOverload, name, argument_kind,
           absl::StrCat(name, " for ", KindToString(argument_kind))) {
  const cel::FunctionRegistry& registry = arg;
  return !registry
              .FindStaticOverloads(name, /*receiver_style=*/false,
                                   {argument_kind, argument_kind})
              .empty();
}

constexpr std::array<Kind, 8> kOrderableTypes = {
    Kind::kBool,   Kind::kInt64, Kind::kUint64,   Kind::kString,
    Kind::kDouble, Kind::kBytes, Kind::kDuration, Kind::kTimestamp};

TEST(RegisterComparisonFunctionsTest, LessThanDefined) {
  RuntimeOptions default_options;
  FunctionRegistry registry;
  ASSERT_OK(RegisterComparisonFunctions(registry, default_options));
  for (Kind kind : kOrderableTypes) {
    EXPECT_THAT(registry, DefinesHomogenousOverload(builtin::kLess, kind));
  }
}

TEST(RegisterComparisonFunctionsTest, LessThanOrEqualDefined) {
  RuntimeOptions default_options;
  FunctionRegistry registry;
  ASSERT_OK(RegisterComparisonFunctions(registry, default_options));
  for (Kind kind : kOrderableTypes) {
    EXPECT_THAT(registry,
                DefinesHomogenousOverload(builtin::kLessOrEqual, kind));
  }
}

TEST(RegisterComparisonFunctionsTest, GreaterThanDefined) {
  RuntimeOptions default_options;
  FunctionRegistry registry;
  ASSERT_OK(RegisterComparisonFunctions(registry, default_options));
  for (Kind kind : kOrderableTypes) {
    EXPECT_THAT(registry, DefinesHomogenousOverload(builtin::kGreater, kind));
  }
}

TEST(RegisterComparisonFunctionsTest, GreaterThanOrEqualDefined) {
  RuntimeOptions default_options;
  FunctionRegistry registry;
  ASSERT_OK(RegisterComparisonFunctions(registry, default_options));
  for (Kind kind : kOrderableTypes) {
    EXPECT_THAT(registry,
                DefinesHomogenousOverload(builtin::kGreaterOrEqual, kind));
  }
}

// With heterogeneous equality explicitly enabled, the same-type overloads must
// still be present alongside the cross-numeric ones.
TEST(RegisterComparisonFunctionsTest, HeterogeneousKeepsHomogenousOverloads) {
  RuntimeOptions options;
  options.enable_heterogeneous_equality = true;
  FunctionRegistry registry;
  ASSERT_OK(RegisterComparisonFunctions(registry, options));
  for (Kind kind : kOrderableTypes) {
    EXPECT_THAT(registry, DefinesHomogenousOverload(builtin::kLess, kind));
  }
  EXPECT_FALSE(registry
                   .FindStaticOverloads(builtin::kLess,
                                        /*receiver_style=*/false,
                                        {Kind::kInt64, Kind::kDouble})
                   .empty());
}

// TODO(uncreated-issue/41): move functional tests from wrapper library after top-level
// APIs are available for planning and running an expression.

}  // namespace
}  // namespace cel

================================================
FILE: runtime/standard/container_functions.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/standard/container_functions.h"

#include <cstddef>
#include <cstdint>
#include <utility>

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "base/builtins.h"
#include "base/function_adapter.h"
#include "common/value.h"
#include "common/values/list_value_builder.h"
#include "internal/status_macros.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel {
namespace {

// size() for maps: forwards MapValue::Size().
absl::StatusOr<int64_t> MapSizeImpl(const MapValue& value) {
  return value.Size();
}

// size() for lists: forwards ListValue::Size().
absl::StatusOr<int64_t> ListSizeImpl(const ListValue& value) {
  return value.Size();
}

// Concatenation of two ListValues (registered as _+_ when list concatenation
// is enabled).
absl::StatusOr ConcatList( const ListValue& value1, const ListValue& value2, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { CEL_ASSIGN_OR_RETURN(auto size1, value1.Size()); if (size1 == 0) { return value2; } CEL_ASSIGN_OR_RETURN(auto size2, value2.Size()); if (size2 == 0) { return value1; } // TODO(uncreated-issue/50): add option for checking lists have homogenous element // types and use a more specialized list type when possible. auto list_builder = NewListValueBuilder(arena); list_builder->Reserve(size1 + size2); for (size_t i = 0; i < size1; i++) { CEL_ASSIGN_OR_RETURN( Value elem, value1.Get(i, descriptor_pool, message_factory, arena)); CEL_RETURN_IF_ERROR(list_builder->Add(std::move(elem))); } for (size_t i = 0; i < size2; i++) { CEL_ASSIGN_OR_RETURN( Value elem, value2.Get(i, descriptor_pool, message_factory, arena)); CEL_RETURN_IF_ERROR(list_builder->Add(std::move(elem))); } return std::move(*list_builder).Build(); } // AppendList will append the elements in value2 to value1. // // This call will only be invoked within comprehensions where `value1` is an // intermediate result which cannot be directly assigned or co-mingled with a // user-provided list. absl::StatusOr AppendList(ListValue value1, const Value& value2) { // The `value1` object cannot be directly addressed and is an intermediate // variable. Once the comprehension completes this value will in effect be // treated as immutable. 
if (auto mutable_list_value = cel::common_internal::AsMutableListValue(value1); mutable_list_value) { CEL_RETURN_IF_ERROR(mutable_list_value->Append(value2)); return value1; } return absl::InvalidArgumentError("Unexpected call to runtime list append."); } } // namespace absl::Status RegisterContainerFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { // receiver style = true/false // Support both the global and receiver style size() for lists and maps. for (bool receiver_style : {true, false}) { CEL_RETURN_IF_ERROR(registry.Register( cel::UnaryFunctionAdapter, const ListValue&>:: CreateDescriptor(cel::builtin::kSize, receiver_style), UnaryFunctionAdapter, const ListValue&>::WrapFunction(ListSizeImpl))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, const MapValue&>:: CreateDescriptor(cel::builtin::kSize, receiver_style), UnaryFunctionAdapter, const MapValue&>::WrapFunction(MapSizeImpl))); } if (options.enable_list_concat) { CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter< absl::StatusOr, const ListValue&, const ListValue&>::CreateDescriptor(cel::builtin::kAdd, false), BinaryFunctionAdapter, const ListValue&, const ListValue&>::WrapFunction(ConcatList))); } return registry.Register( BinaryFunctionAdapter< absl::StatusOr, ListValue, const Value&>::CreateDescriptor(cel::builtin::kRuntimeListAppend, false), BinaryFunctionAdapter, ListValue, const Value&>::WrapFunction(AppendList)); } } // namespace cel ================================================ FILE: runtime/standard/container_functions.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_CONTAINER_FUNCTIONS_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_CONTAINER_FUNCTIONS_H_ #include "absl/status/status.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" namespace cel { // Register built in container functions. // // Most users should prefer to use RegisterBuiltinFunctions. // // This call is included in RegisterBuiltinFunctions -- calling both // RegisterBuiltinFunctions and RegisterContainerFunctions directly on the same // registry will result in an error. absl::Status RegisterContainerFunctions(FunctionRegistry& registry, const RuntimeOptions& options); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_CONTAINER_FUNCTIONS_H_ ================================================ FILE: runtime/standard/container_functions_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "runtime/standard/container_functions.h"

#include

#include "base/builtins.h"
#include "common/function_descriptor.h"
#include "internal/testing.h"

namespace cel {
namespace {

using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;

// Matches an overload entry (anything exposing `.descriptor`, as returned by
// FunctionRegistry::FindStaticOverloads) whose descriptor has the given name,
// receiver style and ordered argument kinds.
MATCHER_P3(MatchesDescriptor, name, receiver, expected_kinds, "") {
  const FunctionDescriptor& descriptor = arg.descriptor;
  const std::vector& types = expected_kinds;
  return descriptor.name() == name && descriptor.receiver_style() == receiver &&
         descriptor.types() == types;
}

// size() is registered for lists and maps in both global and receiver style.
TEST(RegisterContainerFunctions, RegistersSizeFunctions) {
  FunctionRegistry registry;
  RuntimeOptions options;

  ASSERT_OK(RegisterContainerFunctions(registry, options));

  EXPECT_THAT(
      registry.FindStaticOverloads(builtin::kSize, false, {Kind::kAny}),
      UnorderedElementsAre(MatchesDescriptor(builtin::kSize, false,
                                             std::vector{Kind::kList}),
                           MatchesDescriptor(builtin::kSize, false,
                                             std::vector{Kind::kMap})));
  EXPECT_THAT(
      registry.FindStaticOverloads(builtin::kSize, true, {Kind::kAny}),
      UnorderedElementsAre(MatchesDescriptor(builtin::kSize, true,
                                             std::vector{Kind::kList}),
                           MatchesDescriptor(builtin::kSize, true,
                                             std::vector{Kind::kMap})));
}

// enable_list_concat = true registers a single (list, list) add overload.
TEST(RegisterContainerFunctions, RegisterListConcatEnabled) {
  FunctionRegistry registry;
  RuntimeOptions options;
  options.enable_list_concat = true;

  ASSERT_OK(RegisterContainerFunctions(registry, options));

  EXPECT_THAT(
      registry.FindStaticOverloads(builtin::kAdd, false,
                                   {Kind::kAny, Kind::kAny}),
      UnorderedElementsAre(MatchesDescriptor(
          builtin::kAdd, false, std::vector{Kind::kList, Kind::kList})));
}

// enable_list_concat = false registers no add overload at all.
TEST(RegisterContainerFunctions, RegisterListConcatDisabled) {
  FunctionRegistry registry;
  RuntimeOptions options;
  options.enable_list_concat = false;

  ASSERT_OK(RegisterContainerFunctions(registry, options));

  EXPECT_THAT(registry.FindStaticOverloads(builtin::kAdd, false,
                                           {Kind::kAny, Kind::kAny}),
              IsEmpty());
}

// The runtime list-append helper is always registered for (list, any).
TEST(RegisterContainerFunctions, RegisterRuntimeListAppend) {
  FunctionRegistry registry;
  RuntimeOptions options;

  ASSERT_OK(RegisterContainerFunctions(registry, options));

  EXPECT_THAT(registry.FindStaticOverloads(builtin::kRuntimeListAppend, false,
                                           {Kind::kAny, Kind::kAny}),
              UnorderedElementsAre(MatchesDescriptor(
                  builtin::kRuntimeListAppend, false,
                  std::vector{Kind::kList, Kind::kAny})));
}

// TODO(uncreated-issue/41): move functional parsed expr tests when modern APIs for
// evaluator available.

}  // namespace
}  // namespace cel

================================================
FILE: runtime/standard/container_membership_functions.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/standard/container_membership_functions.h"

#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "base/builtins.h"
#include "base/function_adapter.h"
#include "common/value.h"
#include "internal/number.h"
#include "internal/status_macros.h"
#include "runtime/function_registry.h"
#include "runtime/register_function_helper.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel {
namespace {

using ::cel::internal::Number;

// Every spelling of the membership operator; each gets the same overloads.
static constexpr std::array in_operators = {
    cel::builtin::kIn,            // @in for map and list types.
    cel::builtin::kInFunction,    // deprecated in() -- for backwards compat
    cel::builtin::kInDeprecated,  // deprecated _in_ -- for backwards compat
};

// Homogeneous element comparison used by In: returns true only when `value`
// holds the same kind as `other` (checked via As) and compares equal to it;
// returns false for any other kind.
template bool ValueEquals(const Value& value, T other);

template <>
bool ValueEquals(const Value& value, bool other) {
  if (auto bool_value = As(value); bool_value) {
    return bool_value->NativeValue() == other;
  }
  return false;
}

template <>
bool ValueEquals(const Value& value, int64_t other) {
  if (auto int_value = As(value); int_value) {
    return int_value->NativeValue() == other;
  }
  return false;
}

template <>
bool ValueEquals(const Value& value, uint64_t other) {
  if (auto uint_value = As(value); uint_value) {
    return uint_value->NativeValue() == other;
  }
  return false;
}

template <>
bool ValueEquals(const Value& value, double other) {
  if (auto double_value = As(value); double_value) {
    return double_value->NativeValue() == other;
  }
  return false;
}

template <>
bool ValueEquals(const Value& value, const StringValue& other) {
  if (auto string_value = As(value); string_value) {
    return string_value->Equals(other);
  }
  return false;
}

template <>
bool ValueEquals(const Value& value, const BytesValue& other) {
  if (auto bytes_value = As(value); bytes_value) {
    return bytes_value->Equals(other);
  }
  return false;
}

// Template function implementing CEL in() function
//
// Linearly scans `list` and returns true as soon as an element satisfies
// ValueEquals(element, value); returns false when none does. Errors from
// list.Size() / list.Get() are propagated.
template absl::StatusOr In(
    T value, const ListValue& list,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  CEL_ASSIGN_OR_RETURN(auto size, list.Size());
  Value element;
  // Unsigned index to match the size type and avoid signed/unsigned overflow
  // on very large lists.
  for (size_t i = 0; i < size; ++i) {
    CEL_RETURN_IF_ERROR(
        list.Get(i, descriptor_pool, message_factory, arena, &element));
    if (ValueEquals(element, value)) {
      return true;
    }
  }

  return false;
}

// Implementation for @in operator using heterogeneous equality.
absl::StatusOr HeterogeneousEqualityIn( const Value& value, const ListValue& list, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { return list.Contains(value, descriptor_pool, message_factory, arena); } absl::Status RegisterListMembershipFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { for (absl::string_view op : in_operators) { if (options.enable_heterogeneous_equality) { CEL_RETURN_IF_ERROR( (RegisterHelper, const Value&, const ListValue&>>:: RegisterGlobalOverload(op, &HeterogeneousEqualityIn, registry))); } else { CEL_RETURN_IF_ERROR( (RegisterHelper, bool, const ListValue&>>:: RegisterGlobalOverload(op, In, registry))); CEL_RETURN_IF_ERROR( (RegisterHelper, int64_t, const ListValue&>>:: RegisterGlobalOverload(op, In, registry))); CEL_RETURN_IF_ERROR( (RegisterHelper, uint64_t, const ListValue&>>:: RegisterGlobalOverload(op, In, registry))); CEL_RETURN_IF_ERROR( (RegisterHelper, double, const ListValue&>>:: RegisterGlobalOverload(op, In, registry))); CEL_RETURN_IF_ERROR( (RegisterHelper, const StringValue&, const ListValue&>>:: RegisterGlobalOverload(op, In, registry))); CEL_RETURN_IF_ERROR( (RegisterHelper, const BytesValue&, const ListValue&>>:: RegisterGlobalOverload(op, In, registry))); } } return absl::OkStatus(); } absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { const bool enable_heterogeneous_equality = options.enable_heterogeneous_equality; auto boolKeyInSet = [enable_heterogeneous_equality]( bool key, const MapValue& map_value, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { Value has; CEL_RETURN_IF_ERROR(map_value.Has(BoolValue(key), descriptor_pool, message_factory, arena, &has)); if 
(has.IsTrue()) { return has; } if (enable_heterogeneous_equality) { return BoolValue(false); } return has; }; auto intKeyInSet = [enable_heterogeneous_equality]( int64_t key, const MapValue& map_value, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { Value result; CEL_RETURN_IF_ERROR(map_value.Has(IntValue(key), descriptor_pool, message_factory, arena, &result)); if (enable_heterogeneous_equality) { if (result.IsTrue()) { return result; } Number number = Number::FromInt64(key); if (number.LosslessConvertibleToUint()) { Value result_alt; CEL_RETURN_IF_ERROR(map_value.Has(UintValue(number.AsUint()), descriptor_pool, message_factory, arena, &result_alt)); if (result_alt.IsTrue()) { return result_alt; } } return BoolValue(false); } return result; }; auto stringKeyInSet = [enable_heterogeneous_equality]( const StringValue& key, const MapValue& map_value, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { Value result; CEL_RETURN_IF_ERROR( map_value.Has(key, descriptor_pool, message_factory, arena, &result)); if (result.IsBool()) { return result; } if (enable_heterogeneous_equality) { return BoolValue(false); } return result; }; auto uintKeyInSet = [enable_heterogeneous_equality]( uint64_t key, const MapValue& map_value, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { Value has; CEL_RETURN_IF_ERROR(map_value.Has(UintValue(key), descriptor_pool, message_factory, arena, &has)); if (enable_heterogeneous_equality) { if (has.IsTrue()) { return has; } Value has_alt; Number number = Number::FromUint64(key); if 
(number.LosslessConvertibleToInt()) { CEL_RETURN_IF_ERROR(map_value.Has(IntValue(number.AsInt()), descriptor_pool, message_factory, arena, &has_alt)); if (has.IsTrue()) { return has; } } return BoolValue(false); } return has; }; auto doubleKeyInSet = [](double key, const MapValue& map_value, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { Number number = Number::FromDouble(key); if (number.LosslessConvertibleToInt()) { Value has; CEL_RETURN_IF_ERROR(map_value.Has(IntValue(number.AsInt()), descriptor_pool, message_factory, arena, &has)); if (has.IsTrue()) { return has; } } if (number.LosslessConvertibleToUint()) { Value has; CEL_RETURN_IF_ERROR(map_value.Has(UintValue(number.AsUint()), descriptor_pool, message_factory, arena, &has)); if (has.IsTrue()) { return has; } } return BoolValue(false); }; for (auto op : in_operators) { auto status = RegisterHelper, const StringValue&, const MapValue&>>::RegisterGlobalOverload(op, stringKeyInSet, registry); if (!status.ok()) return status; status = RegisterHelper< BinaryFunctionAdapter, bool, const MapValue&>>:: RegisterGlobalOverload(op, boolKeyInSet, registry); if (!status.ok()) return status; status = RegisterHelper, int64_t, const MapValue&>>:: RegisterGlobalOverload(op, intKeyInSet, registry); if (!status.ok()) return status; status = RegisterHelper, uint64_t, const MapValue&>>:: RegisterGlobalOverload(op, uintKeyInSet, registry); if (!status.ok()) return status; if (enable_heterogeneous_equality) { status = RegisterHelper, double, const MapValue&>>:: RegisterGlobalOverload(op, doubleKeyInSet, registry); if (!status.ok()) return status; } } return absl::OkStatus(); } } // namespace absl::Status RegisterContainerMembershipFunctions( FunctionRegistry& registry, const RuntimeOptions& options) { if (options.enable_list_contains) { 
CEL_RETURN_IF_ERROR(RegisterListMembershipFunctions(registry, options)); } return RegisterMapMembershipFunctions(registry, options); } } // namespace cel ================================================ FILE: runtime/standard/container_membership_functions.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_CONTAINER_MEMBERSHIP_FUNCTIONS_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_CONTAINER_MEMBERSHIP_FUNCTIONS_H_ #include "absl/status/status.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" namespace cel { // Register container membership functions // in and in . // // The in operator follows the same behavior as equality, following the // .enable_heterogeneous_equality option. absl::Status RegisterContainerMembershipFunctions( FunctionRegistry& registry, const RuntimeOptions& options); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_CONTAINER_MEMBERSHIP_FUNCTIONS_H_ ================================================ FILE: runtime/standard/container_membership_functions_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/standard/container_membership_functions.h"

#include
#include

#include "absl/strings/string_view.h"
#include "base/builtins.h"
#include "common/function_descriptor.h"
#include "common/kind.h"
#include "internal/testing.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"

namespace cel {
namespace {

using ::testing::UnorderedElementsAre;

// Matches a pointer-like descriptor entry (dereferenced with `*arg`), such as
// the elements of the per-name lists returned by
// FunctionRegistry::ListFunctions, by name, receiver style and ordered
// argument kinds.
MATCHER_P3(MatchesDescriptor, name, receiver, expected_kinds, "") {
  const FunctionDescriptor& descriptor = *arg;
  const std::vector& types = expected_kinds;
  return descriptor.name() == name && descriptor.receiver_style() == receiver &&
         descriptor.types() == types;
}

// Every spelling of the membership operator; each test checks all of them.
static constexpr std::array kInOperators = {
    builtin::kIn, builtin::kInDeprecated, builtin::kInFunction};

// Homogeneous equality: one list overload per scalar kind, and map overloads
// for int/uint/string/bool keys (no double key overload).
TEST(RegisterContainerMembershipFunctions, RegistersHomogeneousInOperator) {
  FunctionRegistry registry;
  RuntimeOptions options;
  options.enable_heterogeneous_equality = false;

  ASSERT_OK(RegisterContainerMembershipFunctions(registry, options));

  auto overloads = registry.ListFunctions();

  for (absl::string_view operator_name : kInOperators) {
    EXPECT_THAT(
        overloads[operator_name],
        UnorderedElementsAre(
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kInt, Kind::kList}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kUint, Kind::kList}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kDouble, Kind::kList}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kString, Kind::kList}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kBytes, Kind::kList}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kBool, Kind::kList}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kInt, Kind::kMap}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kUint, Kind::kMap}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kString, Kind::kMap}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kBool, Kind::kMap})));
  }
}

// Heterogeneous equality: a single (any, list) overload, and map overloads
// that additionally include double keys.
TEST(RegisterContainerMembershipFunctions, RegistersHeterogeneousInOperation) {
  FunctionRegistry registry;
  RuntimeOptions options;
  options.enable_heterogeneous_equality = true;

  ASSERT_OK(RegisterContainerMembershipFunctions(registry, options));

  auto overloads = registry.ListFunctions();

  for (absl::string_view operator_name : kInOperators) {
    EXPECT_THAT(
        overloads[operator_name],
        UnorderedElementsAre(
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kAny, Kind::kList}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kInt, Kind::kMap}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kUint, Kind::kMap}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kDouble, Kind::kMap}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kString, Kind::kMap}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kBool, Kind::kMap})));
  }
}

// enable_list_contains = false: only map overloads are registered.
// NOTE: the expected (double, map) overload is only registered when
// enable_heterogeneous_equality is set, so this test relies on that option's
// default value rather than setting it explicitly.
TEST(RegisterContainerMembershipFunctions, RegistersInOperatorListsDisabled) {
  FunctionRegistry registry;
  RuntimeOptions options;
  options.enable_list_contains = false;

  ASSERT_OK(RegisterContainerMembershipFunctions(registry, options));

  auto overloads = registry.ListFunctions();

  for (absl::string_view operator_name : kInOperators) {
    EXPECT_THAT(
        overloads[operator_name],
        UnorderedElementsAre(
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kInt, Kind::kMap}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kUint, Kind::kMap}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kDouble, Kind::kMap}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kString, Kind::kMap}),
            MatchesDescriptor(operator_name, false,
                              std::vector{Kind::kBool, Kind::kMap})));
  }
}

// TODO(uncreated-issue/41): move functional parsed expr tests when modern APIs for
// evaluator available.

}  // namespace
}  // namespace cel

================================================
FILE: runtime/standard/equality_functions.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/standard/equality_functions.h"

#include
#include
#include
#include
#include

#include "absl/base/nullability.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "base/builtins.h"
#include "base/function_adapter.h"
#include "common/value.h"
#include "common/value_kind.h"
#include "internal/number.h"
#include "internal/status_macros.h"
#include "runtime/function_registry.h"
#include "runtime/internal/errors.h"
#include "runtime/register_function_helper.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel {
namespace {

using ::cel::builtin::kEqual;
using ::cel::builtin::kInequal;
using ::cel::internal::Number;

// Declaration for the functors for generic equality operator.
// Equal only defined for same-typed values.
// Nullopt is returned if equality is not defined. struct HomogenousEqualProvider { static constexpr bool kIsHeterogeneous = false; absl::StatusOr> operator()( const Value& lhs, const Value& rhs, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; }; // Equal defined between compatible types. // Nullopt is returned if equality is not defined. struct HeterogeneousEqualProvider { static constexpr bool kIsHeterogeneous = true; absl::StatusOr> operator()( const Value& lhs, const Value& rhs, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; }; // Comparison template functions template absl::optional Inequal(Type lhs, Type rhs) { return lhs != rhs; } template <> absl::optional Inequal(const StringValue& lhs, const StringValue& rhs) { return !lhs.Equals(rhs); } template <> absl::optional Inequal(const BytesValue& lhs, const BytesValue& rhs) { return !lhs.Equals(rhs); } template <> absl::optional Inequal(const NullValue&, const NullValue&) { return false; } template <> absl::optional Inequal(const TypeValue& lhs, const TypeValue& rhs) { return lhs.name() != rhs.name(); } template absl::optional Equal(Type lhs, Type rhs) { return lhs == rhs; } template <> absl::optional Equal(const StringValue& lhs, const StringValue& rhs) { return lhs.Equals(rhs); } template <> absl::optional Equal(const BytesValue& lhs, const BytesValue& rhs) { return lhs.Equals(rhs); } template <> absl::optional Equal(const NullValue&, const NullValue&) { return true; } template <> absl::optional Equal(const TypeValue& lhs, const TypeValue& rhs) { return lhs.name() == rhs.name(); } // Equality for lists. Template parameter provides either heterogeneous or // homogenous equality for comparing members. 
template absl::StatusOr> ListEqual( const ListValue& lhs, const ListValue& rhs, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { if (&lhs == &rhs) { return true; } CEL_ASSIGN_OR_RETURN(auto lhs_size, lhs.Size()); CEL_ASSIGN_OR_RETURN(auto rhs_size, rhs.Size()); if (lhs_size != rhs_size) { return false; } for (int i = 0; i < lhs_size; ++i) { CEL_ASSIGN_OR_RETURN(auto lhs_i, lhs.Get(i, descriptor_pool, message_factory, arena)); CEL_ASSIGN_OR_RETURN(auto rhs_i, rhs.Get(i, descriptor_pool, message_factory, arena)); CEL_ASSIGN_OR_RETURN(absl::optional eq, EqualsProvider()(lhs_i, rhs_i, descriptor_pool, message_factory, arena)); if (!eq.has_value() || !*eq) { return eq; } } return true; } // Opaque types only support heterogeneous equality, and by extension that means // optionals. Heterogeneous equality being enabled is enforced by // `EnableOptionalTypes`. absl::StatusOr> OpaqueEqual( const OpaqueValue& lhs, const OpaqueValue& rhs, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { Value result; CEL_RETURN_IF_ERROR( lhs.Equal(rhs, descriptor_pool, message_factory, arena, &result)); if (auto bool_value = result.AsBool(); bool_value) { return bool_value->NativeValue(); } return TypeConversionError(result.GetTypeName(), "bool").NativeValue(); } absl::optional NumberFromValue(const Value& value) { if (value.Is()) { return Number::FromInt64(value.GetInt().NativeValue()); } else if (value.Is()) { return Number::FromUint64(value.GetUint().NativeValue()); } else if (value.Is()) { return Number::FromDouble(value.GetDouble().NativeValue()); } return absl::nullopt; } absl::StatusOr> CheckAlternativeNumericType( const Value& key, const MapValue& rhs, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, 
google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { absl::optional number = NumberFromValue(key); if (!number.has_value()) { return absl::nullopt; } if (!key.IsInt() && number->LosslessConvertibleToInt()) { absl::optional entry; CEL_ASSIGN_OR_RETURN(entry, rhs.Find(IntValue(number->AsInt()), descriptor_pool, message_factory, arena)); if (entry) { return entry; } } if (!key.IsUint() && number->LosslessConvertibleToUint()) { absl::optional entry; CEL_ASSIGN_OR_RETURN(entry, rhs.Find(UintValue(number->AsUint()), descriptor_pool, message_factory, arena)); if (entry) { return entry; } } return absl::nullopt; } // Equality for maps. Template parameter provides either heterogeneous or // homogenous equality for comparing values. template absl::StatusOr> MapEqual( const MapValue& lhs, const MapValue& rhs, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { if (&lhs == &rhs) { return true; } if (lhs.Size() != rhs.Size()) { return false; } CEL_ASSIGN_OR_RETURN(auto iter, lhs.NewIterator()); while (iter->HasNext()) { CEL_ASSIGN_OR_RETURN(auto lhs_key, iter->Next(descriptor_pool, message_factory, arena)); absl::optional entry; CEL_ASSIGN_OR_RETURN( entry, rhs.Find(lhs_key, descriptor_pool, message_factory, arena)); if (!entry && EqualsProvider::kIsHeterogeneous) { CEL_ASSIGN_OR_RETURN( entry, CheckAlternativeNumericType(lhs_key, rhs, descriptor_pool, message_factory, arena)); } if (!entry) { return false; } CEL_ASSIGN_OR_RETURN(auto lhs_value, lhs.Get(lhs_key, descriptor_pool, message_factory, arena)); CEL_ASSIGN_OR_RETURN(absl::optional eq, EqualsProvider()(lhs_value, *entry, descriptor_pool, message_factory, arena)); if (!eq.has_value() || !*eq) { return eq; } } return true; } // Helper for wrapping ==/!= implementations. 
// Name should point to a static constexpr string so the lambda capture is safe.
//
// Adapts a plain comparison `op(lhs, rhs)` returning an optional bool into a
// registry-callable lambda that also takes (and ignores) the descriptor pool,
// message factory and arena. A present result becomes a BoolValue; nullopt
// becomes an ErrorValue carrying a no-matching-overload error for `name`.
template
std::function WrapComparison(Op op, absl::string_view name) {
  return [op = std::move(op), name](
             Type lhs, Type rhs,
             const google::protobuf::DescriptorPool* absl_nonnull
                 descriptor_pool,
             google::protobuf::MessageFactory* absl_nonnull message_factory,
             google::protobuf::Arena* absl_nonnull arena) -> Value {
    absl::optional result = op(lhs, rhs);

    if (result.has_value()) {
      return BoolValue(*result);
    }

    return ErrorValue(
        cel::runtime_internal::CreateNoMatchingOverloadError(name));
  };
}

// Helper method
//
// Registers both equality functions (kInequal via Inequal, kEqual via Equal)
// as global overloads for the template parameter type.
template
absl::Status RegisterEqualityFunctionsForType(cel::FunctionRegistry& registry) {
  using FunctionAdapter =
      cel::RegisterHelper>;
  // Inequality
  CEL_RETURN_IF_ERROR(FunctionAdapter::RegisterGlobalOverload(
      kInequal, WrapComparison(&Inequal, kInequal), registry));

  // Equality
  CEL_RETURN_IF_ERROR(FunctionAdapter::RegisterGlobalOverload(
      kEqual, WrapComparison(&Equal, kEqual), registry));

  return absl::OkStatus();
}

// Wraps a context-aware equality `op` (e.g. ListEqual / MapEqual): errors from
// `op` propagate, an undefined (nullopt) result becomes a no-matching-overload
// ErrorValue for kEqual, otherwise the result is returned as a BoolValue.
template
auto ComplexEquality(Op&& op) {
  return [op = std::forward(op)](
             const Type& t1, const Type& t2,
             const google::protobuf::DescriptorPool* absl_nonnull
                 descriptor_pool,
             google::protobuf::MessageFactory* absl_nonnull message_factory,
             google::protobuf::Arena* absl_nonnull arena)
             -> absl::StatusOr {
    CEL_ASSIGN_OR_RETURN(absl::optional result,
                         op(t1, t2, descriptor_pool, message_factory, arena));
    if (!result.has_value()) {
      return ErrorValue(
          cel::runtime_internal::CreateNoMatchingOverloadError(kEqual));
    }
    return BoolValue(*result);
  };
}

// Same as ComplexEquality, but negates the result and reports undefined
// comparisons against kInequal.
// NOTE(review): takes its operands by value, whereas ComplexEquality takes
// them by const reference.
template
auto ComplexInequality(Op&& op) {
  return [op = std::forward(op)](
             Type t1, Type t2,
             const google::protobuf::DescriptorPool* absl_nonnull
                 descriptor_pool,
             google::protobuf::MessageFactory* absl_nonnull message_factory,
             google::protobuf::Arena* absl_nonnull arena)
             -> absl::StatusOr {
    CEL_ASSIGN_OR_RETURN(absl::optional result,
                         op(t1, t2, descriptor_pool,
                            message_factory, arena));
    if (!result.has_value()) {
      return ErrorValue(
          cel::runtime_internal::CreateNoMatchingOverloadError(kInequal));
    }
    return BoolValue(!*result);
  };
}

// Registers kInequal and kEqual global overloads for `Type` built from a single
// context-aware equality function `op` (see ComplexEquality /
// ComplexInequality).
template
absl::Status RegisterComplexEqualityFunctionsForType(
    absl::FunctionRef>(
        Type, Type,
        const google::protobuf::DescriptorPool* absl_nonnull,
        google::protobuf::MessageFactory* absl_nonnull,
        google::protobuf::Arena* absl_nonnull)>
        op,
    cel::FunctionRegistry& registry) {
  using FunctionAdapter = cel::RegisterHelper<
      BinaryFunctionAdapter, Type, Type>>;
  // Inequality
  CEL_RETURN_IF_ERROR(FunctionAdapter::RegisterGlobalOverload(
      kInequal, ComplexInequality(op), registry));

  // Equality
  CEL_RETURN_IF_ERROR(FunctionAdapter::RegisterGlobalOverload(
      kEqual, ComplexEquality(op), registry));

  return absl::OkStatus();
}

// Registers the homogeneous (same-type only) equality overloads: one
// RegisterEqualityFunctionsForType call per scalar type, then list and map
// equality via ListEqual / MapEqual.
absl::Status RegisterHomogenousEqualityFunctions(
    cel::FunctionRegistry& registry) {
  CEL_RETURN_IF_ERROR(RegisterEqualityFunctionsForType(registry));
  CEL_RETURN_IF_ERROR(RegisterEqualityFunctionsForType(registry));
  CEL_RETURN_IF_ERROR(RegisterEqualityFunctionsForType(registry));
  CEL_RETURN_IF_ERROR(RegisterEqualityFunctionsForType(registry));
  CEL_RETURN_IF_ERROR(
      RegisterEqualityFunctionsForType(registry));
  CEL_RETURN_IF_ERROR(
      RegisterEqualityFunctionsForType(registry));
  CEL_RETURN_IF_ERROR(
      RegisterEqualityFunctionsForType(registry));
  CEL_RETURN_IF_ERROR(RegisterEqualityFunctionsForType(registry));
  CEL_RETURN_IF_ERROR(
      RegisterEqualityFunctionsForType(registry));
  CEL_RETURN_IF_ERROR(
      RegisterEqualityFunctionsForType(registry));

  CEL_RETURN_IF_ERROR(
      RegisterComplexEqualityFunctionsForType(
          &ListEqual, registry));

  CEL_RETURN_IF_ERROR(
      RegisterComplexEqualityFunctionsForType(
          &MapEqual, registry));

  return absl::OkStatus();
}

// Registers (struct, null) and (null, struct) overloads in both argument
// orders: kEqual always returns false and kInequal always returns true.
absl::Status RegisterNullMessageEqualityFunctions(FunctionRegistry& registry) {
  // equals
  CEL_RETURN_IF_ERROR(
      (cel::RegisterHelper<
          BinaryFunctionAdapter>::
           RegisterGlobalOverload(
               kEqual,
               [](const StructValue&, const NullValue&) { return false; },
               registry)));

  CEL_RETURN_IF_ERROR(
      (cel::RegisterHelper<
          BinaryFunctionAdapter>::
           RegisterGlobalOverload(
               kEqual,
               [](const NullValue&, const StructValue&) { return false; },
               registry)));

  // inequals
  CEL_RETURN_IF_ERROR(
      (cel::RegisterHelper<
          BinaryFunctionAdapter>::
           RegisterGlobalOverload(
               kInequal,
               [](const StructValue&, const NullValue&) { return true; },
               registry)));

  return cel::RegisterHelper<
      BinaryFunctionAdapter>::
      RegisterGlobalOverload(
          kInequal, [](const NullValue&, const StructValue&) { return true; },
          registry);
}

// Same-kind equality: returns nullopt if the kinds differ or the kind is not
// handled by the switch (default branch). Lists and maps recurse through
// ListEqual / MapEqual with the template's EqualsProvider; opaque values use
// OpaqueEqual.
template
absl::StatusOr> HomogenousValueEqual(
    const Value& v1, const Value& v2,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  if (v1.kind() != v2.kind()) {
    return absl::nullopt;
  }

  // Compile-time guard against an accidental copy (see the assert message).
  static_assert(std::is_lvalue_reference_v,
                "unexpected value copy");

  switch (v1->kind()) {
    case ValueKind::kBool:
      return Equal(v1.GetBool().NativeValue(),
                   v2.GetBool().NativeValue());
    case ValueKind::kNull:
      return Equal(v1.GetNull(), v2.GetNull());
    case ValueKind::kInt:
      return Equal(v1.GetInt().NativeValue(),
                   v2.GetInt().NativeValue());
    case ValueKind::kUint:
      return Equal(v1.GetUint().NativeValue(),
                   v2.GetUint().NativeValue());
    case ValueKind::kDouble:
      return Equal(v1.GetDouble().NativeValue(),
                   v2.GetDouble().NativeValue());
    case ValueKind::kDuration:
      return Equal(v1.GetDuration().NativeValue(),
                   v2.GetDuration().NativeValue());
    case ValueKind::kTimestamp:
      return Equal(v1.GetTimestamp().NativeValue(),
                   v2.GetTimestamp().NativeValue());
    case ValueKind::kCelType:
      return Equal(v1.GetType(), v2.GetType());
    case ValueKind::kString:
      return Equal(v1.GetString(), v2.GetString());
    case ValueKind::kBytes:
      return Equal(v1.GetBytes(), v2.GetBytes());
    case ValueKind::kList:
      return ListEqual(v1.GetList(), v2.GetList(),
                       descriptor_pool, message_factory, arena);
    case ValueKind::kMap:
      return MapEqual(v1.GetMap(), v2.GetMap(),
                      descriptor_pool, message_factory, arena);
    case ValueKind::kOpaque:
      return OpaqueEqual(v1.GetOpaque(), v2.GetOpaque(), descriptor_pool,
                         message_factory, arena);
    default:
      return absl::nullopt;
  }
}

// Generic kEqual overload backed by runtime_internal::ValueEqualImpl; an
// undefined comparison yields a no-matching-overload ErrorValue.
absl::StatusOr EqualOverloadImpl(
    const Value& lhs, const Value& rhs,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  CEL_ASSIGN_OR_RETURN(absl::optional result,
                       runtime_internal::ValueEqualImpl(
                           lhs, rhs, descriptor_pool, message_factory, arena));
  if (result.has_value()) {
    return BoolValue(*result);
  }
  return ErrorValue(
      cel::runtime_internal::CreateNoMatchingOverloadError(kEqual));
}

// Generic kInequal overload: the negation of ValueEqualImpl's result; an
// undefined comparison yields a no-matching-overload ErrorValue.
absl::StatusOr InequalOverloadImpl(
    const Value& lhs, const Value& rhs,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  CEL_ASSIGN_OR_RETURN(absl::optional result,
                       runtime_internal::ValueEqualImpl(
                           lhs, rhs, descriptor_pool, message_factory, arena));
  if (result.has_value()) {
    return BoolValue(!*result);
  }
  return ErrorValue(
      cel::runtime_internal::CreateNoMatchingOverloadError(kInequal));
}

// Registers one (Value, Value) overload each for kEqual and kInequal.
absl::Status RegisterHeterogeneousEqualityFunctions(
    cel::FunctionRegistry& registry) {
  using Adapter = cel::RegisterHelper<
      BinaryFunctionAdapter, const Value&, const Value&>>;
  CEL_RETURN_IF_ERROR(
      Adapter::RegisterGlobalOverload(kEqual, &EqualOverloadImpl, registry));

  CEL_RETURN_IF_ERROR(Adapter::RegisterGlobalOverload(
      kInequal, &InequalOverloadImpl, registry));

  return absl::OkStatus();
}

absl::StatusOr> HomogenousEqualProvider::operator()(
    const Value& lhs, const Value& rhs,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) const {
  return HomogenousValueEqual(
      lhs, rhs, descriptor_pool, message_factory, arena);
}

absl::StatusOr>
HeterogeneousEqualProvider::operator()(
    const Value& lhs, const Value& rhs,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) const {
  return runtime_internal::ValueEqualImpl(lhs, rhs, descriptor_pool,
                                          message_factory, arena);
}

}  // namespace

namespace runtime_internal {

// Heterogeneous equality:
//   - same kind: structs compare via StructValue::Equal (a non-bool result is
//     treated as false); every other kind goes through HomogenousValueEqual;
//   - different kinds, both numeric (int/uint/double): compared as Numbers;
//   - otherwise nullopt if either side is an error or unknown, else false.
absl::StatusOr> ValueEqualImpl(
    const Value& v1, const Value& v2,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena) {
  if (v1.kind() == v2.kind()) {
    if (v1.IsStruct() && v2.IsStruct()) {
      CEL_ASSIGN_OR_RETURN(
          Value result,
          v1.GetStruct().Equal(v2, descriptor_pool, message_factory, arena));
      if (result.IsBool()) {
        return result.GetBool().NativeValue();
      }
      return false;
    }
    return HomogenousValueEqual(
        v1, v2, descriptor_pool, message_factory, arena);
  }

  absl::optional lhs = NumberFromValue(v1);
  absl::optional rhs = NumberFromValue(v2);

  if (rhs.has_value() && lhs.has_value()) {
    return *lhs == *rhs;
  }

  // TODO(uncreated-issue/6): It's currently possible for the interpreter to create a
  // map containing an Error. Return no matching overload to propagate an error
  // instead of a false result.
  if (v1.IsError() || v1.IsUnknown() || v2.IsError() || v2.IsUnknown()) {
    return absl::nullopt;
  }

  return false;
}

}  // namespace runtime_internal

// Registers kEqual / kInequal according to the options:
//   - heterogeneous equality with fast builtins: registers nothing;
//   - heterogeneous equality otherwise: one generic (Value, Value) overload
//     per operator;
//   - homogeneous equality: per-type overloads plus struct/null overloads.
absl::Status RegisterEqualityFunctions(FunctionRegistry& registry,
                                       const RuntimeOptions& options) {
  if (options.enable_heterogeneous_equality) {
    if (options.enable_fast_builtins) {
      // If enabled, the evaluator provides an implementation that works
      // directly on the value stack.
      return absl::OkStatus();
    }
    // Heterogeneous equality uses one generic overload that delegates to the
    // right equality implementation at runtime.
    CEL_RETURN_IF_ERROR(RegisterHeterogeneousEqualityFunctions(registry));
  } else {
    CEL_RETURN_IF_ERROR(RegisterHomogenousEqualityFunctions(registry));
    CEL_RETURN_IF_ERROR(RegisterNullMessageEqualityFunctions(registry));
  }
  return absl::OkStatus();
}

}  // namespace cel

================================================
FILE: runtime/standard/equality_functions.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_EQUALITY_FUNCTIONS_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_EQUALITY_FUNCTIONS_H_

#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/optional.h"
#include "common/value.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

namespace cel {
namespace runtime_internal {

// Exposed implementation for == operator. This is used to implement other
// runtime functions.
//
// Nullopt is returned if the comparison is undefined (e.g. special value types
// error and unknown).
absl::StatusOr> ValueEqualImpl(
    const Value& v1, const Value& v2,
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    google::protobuf::MessageFactory* absl_nonnull message_factory,
    google::protobuf::Arena* absl_nonnull arena);

}  // namespace runtime_internal

// Register equality functions
// ==, !=
//
// options.enable_heterogeneous_equality controls which flavor of equality is
// used.
//
// For legacy equality (.enable_heterogeneous_equality = false), equality is
// defined between same-typed values only, plus comparisons between a message
// (struct) and null. Comparing two messages is undefined in this mode.
//
// For the CEL specification's definition of equality
// (.enable_heterogeneous_equality = true), equality is defined between most
// types, with false returned if the two different types are incomparable;
// error and unknown operands instead produce a no-matching-overload error.
// If options.enable_fast_builtins is also set, no overloads are registered,
// because the evaluator then implements equality itself.
absl::Status RegisterEqualityFunctions(FunctionRegistry& registry,
                                       const RuntimeOptions& options);

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_EQUALITY_FUNCTIONS_H_

================================================
FILE: runtime/standard/equality_functions_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/standard/equality_functions.h"

#include

#include "absl/status/status_matchers.h"
#include "base/builtins.h"
#include "common/function_descriptor.h"
#include "common/kind.h"
#include "internal/testing.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;

// Matches a descriptor pointer (an element of the per-name lists returned by
// FunctionRegistry::ListFunctions) against an expected name, receiver style,
// and exact list of argument kinds.
MATCHER_P3(MatchesDescriptor, name, receiver, expected_kinds, "") {
  const FunctionDescriptor& descriptor = *arg;
  const std::vector& types = expected_kinds;
  return descriptor.name() == name && descriptor.receiver_style() == receiver &&
         descriptor.types() == types;
}

// Legacy (homogeneous) mode registers one global overload per same-kind
// operand pair, plus the null/struct overloads.
TEST(RegisterEqualityFunctionsHomogeneous, RegistersEqualOperators) {
  FunctionRegistry registry;
  RuntimeOptions options;
  options.enable_heterogeneous_equality = false;

  ASSERT_THAT(RegisterEqualityFunctions(registry, options), IsOk());

  auto overloads = registry.ListFunctions();

  EXPECT_THAT(
      overloads[builtin::kEqual],
      UnorderedElementsAre(
          MatchesDescriptor(builtin::kEqual, false,
                            std::vector{Kind::kList, Kind::kList}),
          MatchesDescriptor(builtin::kEqual, false,
                            std::vector{Kind::kMap, Kind::kMap}),
          MatchesDescriptor(builtin::kEqual, false,
                            std::vector{Kind::kBool, Kind::kBool}),
          MatchesDescriptor(builtin::kEqual, false,
                            std::vector{Kind::kInt, Kind::kInt}),
          MatchesDescriptor(builtin::kEqual, false,
                            std::vector{Kind::kUint, Kind::kUint}),
          MatchesDescriptor(builtin::kEqual, false,
                            std::vector{Kind::kDouble, Kind::kDouble}),
          MatchesDescriptor(
              builtin::kEqual, false,
              std::vector{Kind::kDuration, Kind::kDuration}),
          MatchesDescriptor(
              builtin::kEqual, false,
              std::vector{Kind::kTimestamp, Kind::kTimestamp}),
          MatchesDescriptor(builtin::kEqual, false,
                            std::vector{Kind::kString, Kind::kString}),
          MatchesDescriptor(builtin::kEqual, false,
                            std::vector{Kind::kBytes, Kind::kBytes}),
          MatchesDescriptor(builtin::kEqual, false,
                            std::vector{Kind::kType, Kind::kType}),
          // Structs comparable to null, but struct == struct undefined.
          MatchesDescriptor(builtin::kEqual, false,
                            std::vector{Kind::kStruct, Kind::kNullType}),
          MatchesDescriptor(builtin::kEqual, false,
                            std::vector{Kind::kNullType, Kind::kStruct}),
          MatchesDescriptor(
              builtin::kEqual, false,
              std::vector{Kind::kNullType, Kind::kNullType})));

  EXPECT_THAT(
      overloads[builtin::kInequal],
      UnorderedElementsAre(
          MatchesDescriptor(builtin::kInequal, false,
                            std::vector{Kind::kList, Kind::kList}),
          MatchesDescriptor(builtin::kInequal, false,
                            std::vector{Kind::kMap, Kind::kMap}),
          MatchesDescriptor(builtin::kInequal, false,
                            std::vector{Kind::kBool, Kind::kBool}),
          MatchesDescriptor(builtin::kInequal, false,
                            std::vector{Kind::kInt, Kind::kInt}),
          MatchesDescriptor(builtin::kInequal, false,
                            std::vector{Kind::kUint, Kind::kUint}),
          MatchesDescriptor(builtin::kInequal, false,
                            std::vector{Kind::kDouble, Kind::kDouble}),
          MatchesDescriptor(
              builtin::kInequal, false,
              std::vector{Kind::kDuration, Kind::kDuration}),
          MatchesDescriptor(
              builtin::kInequal, false,
              std::vector{Kind::kTimestamp, Kind::kTimestamp}),
          MatchesDescriptor(builtin::kInequal, false,
                            std::vector{Kind::kString, Kind::kString}),
          MatchesDescriptor(builtin::kInequal, false,
                            std::vector{Kind::kBytes, Kind::kBytes}),
          MatchesDescriptor(builtin::kInequal, false,
                            std::vector{Kind::kType, Kind::kType}),
          // Structs comparable to null, but struct != struct undefined.
          MatchesDescriptor(builtin::kInequal, false,
                            std::vector{Kind::kStruct, Kind::kNullType}),
          MatchesDescriptor(builtin::kInequal, false,
                            std::vector{Kind::kNullType, Kind::kStruct}),
          MatchesDescriptor(
              builtin::kInequal, false,
              std::vector{Kind::kNullType, Kind::kNullType})));
}

// Heterogeneous mode without fast builtins registers a single generic
// (any, any) overload per operator.
TEST(RegisterEqualityFunctionsHeterogeneous, RegistersEqualOperators) {
  FunctionRegistry registry;
  RuntimeOptions options;
  options.enable_heterogeneous_equality = true;
  options.enable_fast_builtins = false;

  ASSERT_THAT(RegisterEqualityFunctions(registry, options), IsOk());

  auto overloads = registry.ListFunctions();

  EXPECT_THAT(
      overloads[builtin::kEqual],
      UnorderedElementsAre(MatchesDescriptor(
          builtin::kEqual, false, std::vector{Kind::kAny, Kind::kAny})));

  EXPECT_THAT(overloads[builtin::kInequal],
              UnorderedElementsAre(MatchesDescriptor(
                  builtin::kInequal, false,
                  std::vector{Kind::kAny, Kind::kAny})));
}

// With fast builtins enabled the evaluator handles equality itself, so no
// overloads are registered.
TEST(RegisterEqualityFunctionsHeterogeneous,
     NotRegisteredWhenFastBuiltinsEnabled) {
  FunctionRegistry registry;
  RuntimeOptions options;
  options.enable_heterogeneous_equality = true;
  options.enable_fast_builtins = true;

  ASSERT_THAT(RegisterEqualityFunctions(registry, options), IsOk());

  auto overloads = registry.ListFunctions();

  EXPECT_THAT(overloads[builtin::kEqual], IsEmpty());
  EXPECT_THAT(overloads[builtin::kInequal], IsEmpty());
}

// TODO(uncreated-issue/41): move functional parsed expr tests when modern APIs for
// evaluator available.

}  // namespace
}  // namespace cel

================================================
FILE: runtime/standard/logical_functions.cc
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/standard/logical_functions.h"

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "base/builtins.h"
#include "base/function_adapter.h"
#include "common/value.h"
#include "internal/status_macros.h"
#include "runtime/function_registry.h"
#include "runtime/internal/errors.h"
#include "runtime/register_function_helper.h"
#include "runtime/runtime_options.h"

namespace cel {
namespace {

using ::cel::runtime_internal::CreateNoMatchingOverloadError;

// Implementation of @not_strictly_false: a bool argument is returned
// unchanged, an error or unknown argument yields true, and any other argument
// yields an ErrorValue carrying a no-matching-overload error.
Value NotStrictlyFalseImpl(const Value& value) {
  if (value.IsBool()) {
    return value;
  }

  if (value.IsError() || value.IsUnknown()) {
    return TrueValue();
  }

  // Should only accept bool, unknown, or error.
  return ErrorValue(CreateNoMatchingOverloadError(builtin::kNotStrictlyFalse));
}

}  // namespace

// Registers logical NOT and @not_strictly_false (the latter under both
// builtin::kNotStrictlyFalse and builtin::kNotStrictlyFalseDeprecated).
// `options` is not consulted by this function.
absl::Status RegisterLogicalFunctions(FunctionRegistry& registry,
                                      const RuntimeOptions& options) {
  // logical NOT
  CEL_RETURN_IF_ERROR(
      (RegisterHelper>::RegisterGlobalOverload(
          builtin::kNot, [](bool value) -> bool { return !value; }, registry)));

  // Strictness
  // Registered as non-strict overloads, presumably so that error and unknown
  // arguments are passed to NotStrictlyFalseImpl (which handles them) rather
  // than being propagated by the evaluator -- confirm against the registry.
  using StrictnessHelper = RegisterHelper>;
  CEL_RETURN_IF_ERROR(StrictnessHelper::RegisterNonStrictOverload(
      builtin::kNotStrictlyFalse, &NotStrictlyFalseImpl, registry));

  CEL_RETURN_IF_ERROR(StrictnessHelper::RegisterNonStrictOverload(
      builtin::kNotStrictlyFalseDeprecated, &NotStrictlyFalseImpl, registry));

  return absl::OkStatus();
}

}  // namespace cel

================================================
FILE: runtime/standard/logical_functions.h
================================================
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_LOGICAL_FUNCTIONS_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_LOGICAL_FUNCTIONS_H_

#include "absl/status/status.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"

namespace cel {

// Register logical operators ! and @not_strictly_false (the latter also under
// its deprecated builtin name).
//
// &&, ||, ?: are special cased by the interpreter (not implemented via the
// function registry).
//
// Most users should use RegisterBuiltinFunctions, which includes these
// definitions.
absl::Status RegisterLogicalFunctions(FunctionRegistry& registry, const RuntimeOptions& options); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_LOGICAL_FUNCTIONS_H_ ================================================ FILE: runtime/standard/logical_functions_test.cc ================================================ // Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/standard/logical_functions.h" #include #include #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "base/builtins.h" #include "common/function_descriptor.h" #include "common/kind.h" #include "common/value.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "internal/testing_message_factory.h" #include "runtime/function.h" #include "runtime/function_overload_reference.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel { namespace { using ::testing::ElementsAre; using ::testing::HasSubstr; using ::testing::Matcher; using ::testing::Truly; MATCHER_P3(DescriptorIs, name, arg_kinds, is_receiver, "") { const FunctionOverloadReference& ref = arg; const FunctionDescriptor& descriptor = ref.descriptor; 
return descriptor.name() == name && descriptor.ShapeMatches(is_receiver, arg_kinds); } MATCHER_P(IsBool, expected, "") { const Value& value = arg; return value->Is() && value.GetBool().NativeValue() == expected; } // TODO(uncreated-issue/48): replace this with a parsed expr when the non-protobuf // parser is available. absl::StatusOr TestDispatchToFunction( const FunctionRegistry& registry, absl::string_view simple_name, absl::Span args, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { std::vector arg_matcher_; arg_matcher_.reserve(args.size()); for (const auto& value : args) { arg_matcher_.push_back(ValueKindToKind(value->kind())); } std::vector refs = registry.FindStaticOverloads( simple_name, /*receiver_style=*/false, arg_matcher_); if (refs.size() != 1) { return absl::InvalidArgumentError("ambiguous overloads"); } return refs[0].implementation.Invoke(args, descriptor_pool, message_factory, arena); } TEST(RegisterLogicalFunctions, NotStrictlyFalseRegistered) { FunctionRegistry registry; RuntimeOptions options; ASSERT_OK(RegisterLogicalFunctions(registry, options)); EXPECT_THAT( registry.FindStaticOverloads(builtin::kNotStrictlyFalse, /*receiver_style=*/false, {Kind::kAny}), ElementsAre(DescriptorIs(builtin::kNotStrictlyFalse, std::vector{Kind::kBool}, false))); } TEST(RegisterLogicalFunctions, LogicalNotRegistered) { FunctionRegistry registry; RuntimeOptions options; ASSERT_OK(RegisterLogicalFunctions(registry, options)); EXPECT_THAT( registry.FindStaticOverloads(builtin::kNot, /*receiver_style=*/false, {Kind::kAny}), ElementsAre( DescriptorIs(builtin::kNot, std::vector{Kind::kBool}, false))); } struct TestCase { using ArgumentFactory = std::function()>; std::string function; ArgumentFactory arguments; absl::StatusOr> result_matcher; }; class LogicalFunctionsTest : public testing::TestWithParam { protected: google::protobuf::Arena 
arena_; }; TEST_P(LogicalFunctionsTest, Runner) { const TestCase& test_case = GetParam(); cel::FunctionRegistry registry; ASSERT_OK(RegisterLogicalFunctions(registry, RuntimeOptions())); std::vector args = test_case.arguments(); absl::StatusOr result = TestDispatchToFunction( registry, test_case.function, args, cel::internal::GetTestingDescriptorPool(), cel::internal::GetTestingMessageFactory(), &arena_); EXPECT_EQ(result.ok(), test_case.result_matcher.ok()); if (!test_case.result_matcher.ok()) { EXPECT_EQ(result.status().code(), test_case.result_matcher.status().code()); EXPECT_THAT(result.status().message(), HasSubstr(test_case.result_matcher.status().message())); } else { ASSERT_TRUE(result.ok()) << "unexpected error" << result.status(); EXPECT_THAT(*result, *test_case.result_matcher); } } INSTANTIATE_TEST_SUITE_P( Cases, LogicalFunctionsTest, testing::ValuesIn(std::vector{ TestCase{builtin::kNot, []() -> std::vector { return {BoolValue(true)}; }, IsBool(false)}, TestCase{builtin::kNot, []() -> std::vector { return {BoolValue(false)}; }, IsBool(true)}, TestCase{builtin::kNot, []() -> std::vector { return {BoolValue(true), BoolValue(false)}; }, absl::InvalidArgumentError("")}, TestCase{builtin::kNotStrictlyFalse, []() -> std::vector { return {BoolValue(true)}; }, IsBool(true)}, TestCase{builtin::kNotStrictlyFalse, []() -> std::vector { return {BoolValue(false)}; }, IsBool(false)}, TestCase{builtin::kNotStrictlyFalse, []() -> std::vector { return {ErrorValue(absl::InternalError("test"))}; }, IsBool(true)}, TestCase{builtin::kNotStrictlyFalse, []() -> std::vector { return {UnknownValue()}; }, IsBool(true)}, TestCase{builtin::kNotStrictlyFalse, []() -> std::vector { return {IntValue(42)}; }, Truly([](const Value& v) { return v->Is() && absl::StrContains( v.GetError().NativeValue().message(), "No matching overloads"); })}, })); } // namespace } // namespace cel ================================================ FILE: runtime/standard/regex_functions.cc 
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/standard/regex_functions.h"

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "base/builtins.h"
#include "base/function_adapter.h"
#include "common/value.h"
#include "internal/re2_options.h"
#include "internal/status_macros.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"
#include "re2/re2.h"

namespace cel {
// NOTE(review): this anonymous namespace is empty and can be removed.
namespace {}  // namespace

// Registers `matches` both as a receiver-style call (str.matches(re)) and as
// a global call (matches(str, re)) when options.enable_regex is set;
// otherwise registers nothing and returns OK.
absl::Status RegisterRegexFunctions(FunctionRegistry& registry,
                                    const RuntimeOptions& options) {
  if (options.enable_regex) {
    // The pattern is compiled on every invocation. If
    // cel::internal::CheckRE2 rejects it (checked against
    // options.regex_max_program_size), the failure is returned as the
    // call's Value result via ErrorValueReturn() rather than as a status.
    auto regex_matches = [max_size = options.regex_max_program_size](
                             const StringValue& target,
                             const StringValue& regex) -> Value {
      RE2 re2(regex.ToString(), cel::internal::MakeRE2Options());
      CEL_RETURN_IF_ERROR(cel::internal::CheckRE2(re2, max_size))
          .With(ErrorValueReturn());
      // PartialMatch: an unanchored pattern may match anywhere in `target`.
      return BoolValue(RE2::PartialMatch(target.ToString(), re2));
    };

    // bind str.matches(re) and matches(str, re)
    for (bool receiver_style : {true, false}) {
      using MatchFnAdapter =
          BinaryFunctionAdapter;
      CEL_RETURN_IF_ERROR(
          registry.Register(MatchFnAdapter::CreateDescriptor(
                                cel::builtin::kRegexMatch, receiver_style),
                            MatchFnAdapter::WrapFunction(regex_matches)));
    }
  }  // if options.enable_regex

  return absl::OkStatus();
}

}  // namespace cel

================================================
FILE: runtime/standard/regex_functions.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_REGEX_FUNCTIONS_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_REGEX_FUNCTIONS_H_

#include "absl/status/status.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"

namespace cel {

// Register builtin regex functions:
//
//    (string).matches(re:string) -> bool
//    matches(string, re:string) -> bool
//
// These are implemented with RE2. Nothing is registered when
// options.enable_regex is false.
//
// Most users should use RegisterBuiltinFunctions, which includes these
// definitions.
absl::Status RegisterRegexFunctions(FunctionRegistry& registry,
                                    const RuntimeOptions& options);

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_REGEX_FUNCTIONS_H_

================================================
FILE: runtime/standard/regex_functions_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include "runtime/standard/regex_functions.h" #include #include "base/builtins.h" #include "common/function_descriptor.h" #include "internal/testing.h" namespace cel { namespace { using ::testing::IsEmpty; using ::testing::UnorderedElementsAre; enum class CallStyle { kFree, kReceiver }; MATCHER_P2(MatchesDescriptor, name, call_style, "") { bool receiver_style; switch (call_style) { case CallStyle::kReceiver: receiver_style = true; break; case CallStyle::kFree: receiver_style = false; break; } const FunctionDescriptor& descriptor = *arg; std::vector types{Kind::kString, Kind::kString}; return descriptor.name() == name && descriptor.receiver_style() == receiver_style && descriptor.types() == types; } TEST(RegisterRegexFunctions, Registered) { FunctionRegistry registry; RuntimeOptions options; ASSERT_OK(RegisterRegexFunctions(registry, options)); auto overloads = registry.ListFunctions(); EXPECT_THAT(overloads[builtin::kRegexMatch], UnorderedElementsAre( MatchesDescriptor(builtin::kRegexMatch, CallStyle::kReceiver), MatchesDescriptor(builtin::kRegexMatch, CallStyle::kFree))); } TEST(RegisterRegexFunctions, NotRegisteredIfDisabled) { FunctionRegistry registry; RuntimeOptions options; options.enable_regex = false; ASSERT_OK(RegisterRegexFunctions(registry, options)); auto overloads = registry.ListFunctions(); EXPECT_THAT(overloads[builtin::kRegexMatch], IsEmpty()); } // TODO(uncreated-issue/41): move functional parsed expr tests when modern APIs for // evaluator available. } // namespace } // namespace cel ================================================ FILE: runtime/standard/string_functions.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/standard/string_functions.h" #include #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "base/builtins.h" #include "base/function_adapter.h" #include "common/value.h" #include "internal/status_macros.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel { namespace { // Concatenation for string type. absl::StatusOr ConcatString( const StringValue& value1, const StringValue& value2, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull arena) { return StringValue::Concat(value1, value2, arena); } // Concatenation for bytes type. 
// The descriptor pool and message factory are unused (unnamed); `arena` is
// passed through to BytesValue::Concat.
absl::StatusOr ConcatBytes(
    const BytesValue& value1, const BytesValue& value2,
    const google::protobuf::DescriptorPool* absl_nonnull,
    google::protobuf::MessageFactory* absl_nonnull,
    google::protobuf::Arena* absl_nonnull arena) {
  return BytesValue::Concat(value1, value2, arena);
}

// Implements `contains`: whether `value` contains `substr`.
bool StringContains(const StringValue& value, const StringValue& substr) {
  return value.Contains(substr);
}

// Implements `endsWith`: whether `value` ends with `suffix`.
bool StringEndsWith(const StringValue& value, const StringValue& suffix) {
  return value.EndsWith(suffix);
}

// Implements `startsWith`: whether `value` starts with `prefix`.
bool StringStartsWith(const StringValue& value, const StringValue& prefix) {
  return value.StartsWith(prefix);
}

// Registers size() for strings and for bytes, each in both global and
// receiver style. The result is the value's Size() as an int64_t.
absl::Status RegisterSizeFunctions(FunctionRegistry& registry) {
  // String size
  auto size_func = [](const StringValue& value) -> int64_t {
    return value.Size();
  };

  // Support global and receiver style size() operations on strings.
  using StrSizeFnAdapter = UnaryFunctionAdapter;
  CEL_RETURN_IF_ERROR(StrSizeFnAdapter::RegisterGlobalOverload(
      cel::builtin::kSize, size_func, registry));

  CEL_RETURN_IF_ERROR(StrSizeFnAdapter::RegisterMemberOverload(
      cel::builtin::kSize, size_func, registry));

  // Bytes size
  auto bytes_size_func = [](const BytesValue& value) -> int64_t {
    return value.Size();
  };

  // Support global and receiver style size() operations on bytes.
  using BytesSizeFnAdapter = UnaryFunctionAdapter;
  CEL_RETURN_IF_ERROR(BytesSizeFnAdapter::RegisterGlobalOverload(
      cel::builtin::kSize, bytes_size_func, registry));

  return BytesSizeFnAdapter::RegisterMemberOverload(cel::builtin::kSize,
                                                    bytes_size_func, registry);
}

// Registers global `+` overloads for string + string and bytes + bytes.
absl::Status RegisterConcatFunctions(FunctionRegistry& registry) {
  using StrCatFnAdapter =
      BinaryFunctionAdapter,
                            const StringValue&, const StringValue&>;
  CEL_RETURN_IF_ERROR(StrCatFnAdapter::RegisterGlobalOverload(
      cel::builtin::kAdd, &ConcatString, registry));

  using BytesCatFnAdapter =
      BinaryFunctionAdapter,
                            const BytesValue&, const BytesValue&>;
  return BytesCatFnAdapter::RegisterGlobalOverload(cel::builtin::kAdd,
                                                   &ConcatBytes, registry);
}

}  // namespace

// Registers contains / endsWith / startsWith in both receiver and global
// style, `+` concatenation only when options.enable_string_concat is true,
// and size() via RegisterSizeFunctions. Stops at the first failed
// registration and returns its status.
absl::Status RegisterStringFunctions(FunctionRegistry& registry,
                                     const RuntimeOptions& options) {
  // Basic substring tests (contains, startsWith, endsWith)
  for (bool receiver_style : {true, false}) {
    auto status =
        BinaryFunctionAdapter::
            Register(cel::builtin::kStringContains, receiver_style,
                     StringContains, registry);
    CEL_RETURN_IF_ERROR(status);

    status =
        BinaryFunctionAdapter::
            Register(cel::builtin::kStringEndsWith, receiver_style,
                     StringEndsWith, registry);
    CEL_RETURN_IF_ERROR(status);

    status =
        BinaryFunctionAdapter::
            Register(cel::builtin::kStringStartsWith, receiver_style,
                     StringStartsWith, registry);
    CEL_RETURN_IF_ERROR(status);
  }

  // string concatenation if enabled
  if (options.enable_string_concat) {
    CEL_RETURN_IF_ERROR(RegisterConcatFunctions(registry));
  }

  return RegisterSizeFunctions(registry);
}

}  // namespace cel

================================================
FILE: runtime/standard/string_functions.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_STRING_FUNCTIONS_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_STRING_FUNCTIONS_H_ #include "absl/status/status.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" namespace cel { // Register builtin string and bytes functions: // _+_ (concatenation), size, contains, startsWith, endsWith // Most users should use RegisterBuiltinFunctions, which includes these // definitions. absl::Status RegisterStringFunctions(FunctionRegistry& registry, const RuntimeOptions& options); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_STRING_FUNCTIONS_H_ ================================================ FILE: runtime/standard/string_functions_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "runtime/standard/string_functions.h"

#include <vector>

#include "base/builtins.h"
#include "common/function_descriptor.h"
#include "internal/testing.h"

namespace cel {
namespace {

using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;

// Whether an overload is expected to be registered as a global function
// (`f(x, y)`) or as a receiver-style member function (`x.f(y)`).
enum class CallStyle { kFree, kReceiver };

// Matches a `const FunctionDescriptor*` whose name, call style and argument
// kinds equal `name`, `call_style` and `expected_kinds`.
MATCHER_P3(MatchesDescriptor, name, call_style, expected_kinds, "") {
  // Derived directly from the enum so the flag is always initialized.
  const bool receiver_style = call_style == CallStyle::kReceiver;
  const FunctionDescriptor& descriptor = *arg;
  return descriptor.name() == name &&
         descriptor.receiver_style() == receiver_style &&
         descriptor.types() == expected_kinds;
}

TEST(RegisterStringFunctions, FunctionsRegistered) {
  FunctionRegistry registry;
  RuntimeOptions options;

  ASSERT_OK(RegisterStringFunctions(registry, options));
  auto overloads = registry.ListFunctions();

  EXPECT_THAT(
      overloads[builtin::kAdd],
      UnorderedElementsAre(
          MatchesDescriptor(builtin::kAdd, CallStyle::kFree,
                            std::vector<Kind>{Kind::kString, Kind::kString}),
          MatchesDescriptor(builtin::kAdd, CallStyle::kFree,
                            std::vector<Kind>{Kind::kBytes, Kind::kBytes})));

  EXPECT_THAT(overloads[builtin::kSize],
              UnorderedElementsAre(
                  MatchesDescriptor(builtin::kSize, CallStyle::kFree,
                                    std::vector<Kind>{Kind::kString}),
                  MatchesDescriptor(builtin::kSize, CallStyle::kFree,
                                    std::vector<Kind>{Kind::kBytes}),
                  MatchesDescriptor(builtin::kSize, CallStyle::kReceiver,
                                    std::vector<Kind>{Kind::kString}),
                  MatchesDescriptor(builtin::kSize, CallStyle::kReceiver,
                                    std::vector<Kind>{Kind::kBytes})));

  EXPECT_THAT(
      overloads[builtin::kStringContains],
      UnorderedElementsAre(
          MatchesDescriptor(builtin::kStringContains, CallStyle::kFree,
                            std::vector<Kind>{Kind::kString, Kind::kString}),
          MatchesDescriptor(builtin::kStringContains, CallStyle::kReceiver,
                            std::vector<Kind>{Kind::kString, Kind::kString})));

  EXPECT_THAT(
      overloads[builtin::kStringStartsWith],
      UnorderedElementsAre(
          MatchesDescriptor(builtin::kStringStartsWith, CallStyle::kFree,
                            std::vector<Kind>{Kind::kString, Kind::kString}),
          MatchesDescriptor(builtin::kStringStartsWith, CallStyle::kReceiver,
                            std::vector<Kind>{Kind::kString, Kind::kString})));

  EXPECT_THAT(
      overloads[builtin::kStringEndsWith],
      UnorderedElementsAre(
          MatchesDescriptor(builtin::kStringEndsWith, CallStyle::kFree,
                            std::vector<Kind>{Kind::kString, Kind::kString}),
          MatchesDescriptor(builtin::kStringEndsWith, CallStyle::kReceiver,
                            std::vector<Kind>{Kind::kString, Kind::kString})));
}

TEST(RegisterStringFunctions, ConcatSkippedWhenDisabled) {
  FunctionRegistry registry;
  RuntimeOptions options;
  options.enable_string_concat = false;

  ASSERT_OK(RegisterStringFunctions(registry, options));
  auto overloads = registry.ListFunctions();

  EXPECT_THAT(overloads[builtin::kAdd], IsEmpty());
}

// TODO(uncreated-issue/41): move functional parsed expr tests when modern APIs for
// evaluator available.

}  // namespace
}  // namespace cel

================================================
FILE: runtime/standard/time_functions.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/standard/time_functions.h"

#include <cstdint>
#include <functional>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/time/civil_time.h"
#include "absl/time/time.h"
#include "base/builtins.h"
#include "base/function_adapter.h"
#include "common/value.h"
#include "internal/overflow.h"
#include "internal/status_macros.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"

namespace cel {
namespace {

// Timestamp

// Breaks `timestamp` down into civil-time fields for timezone `tz` and stores
// the result in `*breakdown`.
//
// `tz` may be:
//   * empty: a default-constructed absl::TimeZone (UTC) is used;
//   * a zone name accepted by absl::LoadTimeZone (e.g. an IANA name);
//   * a fixed offset written as [+-]HH:MM, which is added to `timestamp`
//     before the breakdown is taken.
// Any other value yields an InvalidArgumentError.
absl::Status FindTimeBreakdown(absl::Time timestamp, absl::string_view tz,
                               absl::TimeZone::CivilInfo* breakdown) {
  absl::TimeZone time_zone;

  // Early return if there is no timezone.
  if (tz.empty()) {
    *breakdown = time_zone.At(timestamp);
    return absl::OkStatus();
  }

  // Check to see whether the timezone is an IANA timezone.
  if (absl::LoadTimeZone(tz, &time_zone)) {
    *breakdown = time_zone.At(timestamp);
    return absl::OkStatus();
  }

  // Check for times of the format: [+-]HH:MM and convert them into durations
  // specified as [+-]HHhMMm. absl::LoadTimeZone leaves `time_zone` set to UTC
  // when it fails, so the shifted timestamp is broken down in UTC.
  if (absl::StrContains(tz, ":")) {
    std::string dur = absl::StrCat(tz, "m");
    absl::StrReplaceAll({{":", "h"}}, &dur);
    absl::Duration d;
    if (absl::ParseDuration(dur, &d)) {
      timestamp += d;
      *breakdown = time_zone.At(timestamp);
      return absl::OkStatus();
    }
  }

  // Otherwise, error.
  return absl::InvalidArgumentError("Invalid timezone");
}

// Resolves `timestamp` in timezone `tz` and returns the field selected by
// `extractor_func` as an IntValue, or an ErrorValue if `tz` is invalid.
Value GetTimeBreakdownPart(
    absl::Time timestamp, absl::string_view tz,
    const std::function& extractor_func) {
  absl::TimeZone::CivilInfo breakdown;
  auto status = FindTimeBreakdown(timestamp, tz, &breakdown);

  if (!status.ok()) {
    return ErrorValue(status);
  }

  return IntValue(extractor_func(breakdown));
}

// Calendar year of `timestamp` in `tz`.
Value GetFullYear(absl::Time timestamp, absl::string_view tz) {
  return GetTimeBreakdownPart(timestamp, tz,
                              [](const absl::TimeZone::CivilInfo& breakdown) {
                                return breakdown.cs.year();
                              });
}

// Zero-based month (January is 0).
Value GetMonth(absl::Time timestamp, absl::string_view tz) {
  return GetTimeBreakdownPart(timestamp, tz,
                              [](const absl::TimeZone::CivilInfo& breakdown) {
                                return breakdown.cs.month() - 1;
                              });
}

// Zero-based day of the year.
Value GetDayOfYear(absl::Time timestamp, absl::string_view tz) {
  return GetTimeBreakdownPart(
      timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) {
        return absl::GetYearDay(absl::CivilDay(breakdown.cs)) - 1;
      });
}

// Zero-based day of the month.
Value GetDayOfMonth(absl::Time timestamp, absl::string_view tz) {
  return GetTimeBreakdownPart(timestamp, tz,
                              [](const absl::TimeZone::CivilInfo& breakdown) {
                                return breakdown.cs.day() - 1;
                              });
}

// One-based day of the month.
Value GetDate(absl::Time timestamp, absl::string_view tz) {
  return GetTimeBreakdownPart(timestamp, tz,
                              [](const absl::TimeZone::CivilInfo& breakdown) {
                                return breakdown.cs.day();
                              });
}

// Zero-based day of the week with Sunday as 0.
Value GetDayOfWeek(absl::Time timestamp, absl::string_view tz) {
  return GetTimeBreakdownPart(
      timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) {
        absl::Weekday weekday = absl::GetWeekday(breakdown.cs);

        // Day of week of the date in the resolved timezone, zero-based, zero
        // for Sunday, per the CEL getDayOfWeek definition. absl::Weekday
        // numbers Monday as 0 through Sunday as 6, so rotate by one.
        int weekday_num = static_cast<int>(weekday);
        weekday_num = (weekday_num == 6) ? 0 : weekday_num + 1;
        return weekday_num;
      });
}

// Hour of the day (0-23).
Value GetHours(absl::Time timestamp, absl::string_view tz) {
  return GetTimeBreakdownPart(timestamp, tz,
                              [](const absl::TimeZone::CivilInfo& breakdown) {
                                return breakdown.cs.hour();
                              });
}

// Minute of the hour.
Value GetMinutes(absl::Time timestamp, absl::string_view tz) {
  return GetTimeBreakdownPart(timestamp, tz,
                              [](const absl::TimeZone::CivilInfo& breakdown) {
                                return breakdown.cs.minute();
                              });
}

// Second of the minute.
Value GetSeconds(absl::Time timestamp, absl::string_view tz) {
  return GetTimeBreakdownPart(timestamp, tz,
                              [](const absl::TimeZone::CivilInfo& breakdown) {
                                return breakdown.cs.second();
                              });
}

// Millisecond part of the sub-second component.
Value GetMilliseconds(absl::Time timestamp, absl::string_view tz) {
  return GetTimeBreakdownPart(
      timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) {
        return absl::ToInt64Milliseconds(breakdown.subsecond);
      });
}

// Registers the receiver-style timestamp accessors. Each accessor gets two
// overloads: one taking a timezone string and one without (which passes an
// empty timezone to the helper above). `options` is currently unused.
absl::Status RegisterTimestampFunctions(FunctionRegistry& registry,
                                        const RuntimeOptions& options) {
  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::
          CreateDescriptor(builtin::kFullYear, true),
      BinaryFunctionAdapter::
          WrapFunction([](absl::Time ts, const StringValue& tz) -> Value {
            return GetFullYear(ts, tz.ToString());
          })));

  CEL_RETURN_IF_ERROR(registry.Register(
      UnaryFunctionAdapter::CreateDescriptor(
          builtin::kFullYear, true),
      UnaryFunctionAdapter::WrapFunction(
          [](absl::Time ts) -> Value { return GetFullYear(ts, ""); })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::
          CreateDescriptor(builtin::kMonth, true),
      BinaryFunctionAdapter::
          WrapFunction([](absl::Time ts, const StringValue& tz) -> Value {
            return GetMonth(ts, tz.ToString());
          })));

  CEL_RETURN_IF_ERROR(registry.Register(
      UnaryFunctionAdapter::CreateDescriptor(builtin::kMonth, true),
      UnaryFunctionAdapter::WrapFunction(
          [](absl::Time ts) -> Value { return GetMonth(ts, ""); })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::
          CreateDescriptor(builtin::kDayOfYear, true),
      BinaryFunctionAdapter::
          WrapFunction([](absl::Time ts, const StringValue& tz) -> Value {
            return GetDayOfYear(ts, tz.ToString());
          })));

  CEL_RETURN_IF_ERROR(registry.Register(
      UnaryFunctionAdapter::CreateDescriptor(
          builtin::kDayOfYear, true),
      UnaryFunctionAdapter::WrapFunction(
          [](absl::Time ts) -> Value { return GetDayOfYear(ts, ""); })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::
          CreateDescriptor(builtin::kDayOfMonth, true),
      BinaryFunctionAdapter::
          WrapFunction([](absl::Time ts, const StringValue& tz) -> Value {
            return GetDayOfMonth(ts, tz.ToString());
          })));

  CEL_RETURN_IF_ERROR(registry.Register(
      UnaryFunctionAdapter::CreateDescriptor(
          builtin::kDayOfMonth, true),
      UnaryFunctionAdapter::WrapFunction(
          [](absl::Time ts) -> Value { return GetDayOfMonth(ts, ""); })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::
          CreateDescriptor(builtin::kDate, true),
      BinaryFunctionAdapter::
          WrapFunction([](absl::Time ts, const StringValue& tz) -> Value {
            return GetDate(ts, tz.ToString());
          })));

  CEL_RETURN_IF_ERROR(registry.Register(
      UnaryFunctionAdapter::CreateDescriptor(builtin::kDate, true),
      UnaryFunctionAdapter::WrapFunction(
          [](absl::Time ts) -> Value { return GetDate(ts, ""); })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::
          CreateDescriptor(builtin::kDayOfWeek, true),
      BinaryFunctionAdapter::
          WrapFunction([](absl::Time ts, const StringValue& tz) -> Value {
            return GetDayOfWeek(ts, tz.ToString());
          })));

  CEL_RETURN_IF_ERROR(registry.Register(
      UnaryFunctionAdapter::CreateDescriptor(
          builtin::kDayOfWeek, true),
      UnaryFunctionAdapter::WrapFunction(
          [](absl::Time ts) -> Value { return GetDayOfWeek(ts, ""); })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::
          CreateDescriptor(builtin::kHours, true),
      BinaryFunctionAdapter::
          WrapFunction([](absl::Time ts, const StringValue& tz) -> Value {
            return GetHours(ts, tz.ToString());
          })));

  CEL_RETURN_IF_ERROR(registry.Register(
      UnaryFunctionAdapter::CreateDescriptor(builtin::kHours, true),
      UnaryFunctionAdapter::WrapFunction(
          [](absl::Time ts) -> Value { return GetHours(ts, ""); })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::
          CreateDescriptor(builtin::kMinutes, true),
      BinaryFunctionAdapter::
          WrapFunction([](absl::Time ts, const StringValue& tz) -> Value {
            return GetMinutes(ts, tz.ToString());
          })));

  CEL_RETURN_IF_ERROR(registry.Register(
      UnaryFunctionAdapter::CreateDescriptor(
          builtin::kMinutes, true),
      UnaryFunctionAdapter::WrapFunction(
          [](absl::Time ts) -> Value { return GetMinutes(ts, ""); })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::
          CreateDescriptor(builtin::kSeconds, true),
      BinaryFunctionAdapter::
          WrapFunction([](absl::Time ts, const StringValue& tz) -> Value {
            return GetSeconds(ts, tz.ToString());
          })));

  CEL_RETURN_IF_ERROR(registry.Register(
      UnaryFunctionAdapter::CreateDescriptor(
          builtin::kSeconds, true),
      UnaryFunctionAdapter::WrapFunction(
          [](absl::Time ts) -> Value { return GetSeconds(ts, ""); })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::
          CreateDescriptor(builtin::kMilliseconds, true),
      BinaryFunctionAdapter::
          WrapFunction([](absl::Time ts, const StringValue& tz) -> Value {
            return GetMilliseconds(ts, tz.ToString());
          })));

  return registry.Register(
      UnaryFunctionAdapter::CreateDescriptor(
          builtin::kMilliseconds, true),
      UnaryFunctionAdapter::WrapFunction(
          [](absl::Time ts) -> Value { return GetMilliseconds(ts, ""); }));
}

// Registers timestamp/duration `_+_` and `_-_` operators that use
// cel::internal::CheckedAdd / CheckedSub and return an ErrorValue when the
// checked operation fails.
absl::Status RegisterCheckedTimeArithmeticFunctions(
    FunctionRegistry& registry) {
  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::CreateDescriptor(builtin::kAdd, false),
      BinaryFunctionAdapter, absl::Time, absl::Duration>::
          WrapFunction(
              [](absl::Time t1, absl::Duration d2) -> absl::StatusOr {
                auto sum = cel::internal::CheckedAdd(t1, d2);
                if (!sum.ok()) {
                  return ErrorValue(sum.status());
                }
                return TimestampValue(*sum);
              })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter, absl::Duration,
                            absl::Time>::CreateDescriptor(builtin::kAdd,
                                                          false),
      BinaryFunctionAdapter, absl::Duration, absl::Time>::
          WrapFunction(
              [](absl::Duration d2, absl::Time t1) -> absl::StatusOr {
                auto sum = cel::internal::CheckedAdd(t1, d2);
                if (!sum.ok()) {
                  return ErrorValue(sum.status());
                }
                return TimestampValue(*sum);
              })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter, absl::Duration,
                            absl::Duration>::CreateDescriptor(builtin::kAdd,
                                                              false),
      BinaryFunctionAdapter<
          absl::StatusOr, absl::Duration,
          absl::Duration>::WrapFunction([](absl::Duration d1,
                                           absl::Duration d2)
                                            -> absl::StatusOr {
        auto sum = cel::internal::CheckedAdd(d1, d2);
        if (!sum.ok()) {
          return ErrorValue(sum.status());
        }
        return DurationValue(*sum);
      })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter, absl::Time, absl::Duration>::
          CreateDescriptor(builtin::kSubtract, false),
      BinaryFunctionAdapter, absl::Time, absl::Duration>::
          WrapFunction(
              [](absl::Time t1, absl::Duration d2) -> absl::StatusOr {
                auto diff = cel::internal::CheckedSub(t1, d2);
                if (!diff.ok()) {
                  return ErrorValue(diff.status());
                }
                return TimestampValue(*diff);
              })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter, absl::Time,
                            absl::Time>::CreateDescriptor(builtin::kSubtract,
                                                          false),
      BinaryFunctionAdapter, absl::Time, absl::Time>::
          WrapFunction(
              [](absl::Time t1, absl::Time t2) -> absl::StatusOr {
                auto diff = cel::internal::CheckedSub(t1, t2);
                if (!diff.ok()) {
                  return ErrorValue(diff.status());
                }
                return DurationValue(*diff);
              })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter<
          absl::StatusOr, absl::Duration,
          absl::Duration>::CreateDescriptor(builtin::kSubtract, false),
      BinaryFunctionAdapter<
          absl::StatusOr, absl::Duration,
          absl::Duration>::WrapFunction([](absl::Duration d1,
                                           absl::Duration d2)
                                            -> absl::StatusOr {
        auto diff = cel::internal::CheckedSub(d1, d2);
        if (!diff.ok()) {
          return ErrorValue(diff.status());
        }
        return DurationValue(*diff);
      })));

  return absl::OkStatus();
}

// Registers timestamp/duration `_+_` and `_-_` operators that apply the raw
// absl::Time / absl::Duration arithmetic and wrap the result with the
// Unsafe*Value constructors, without any overflow checking.
absl::Status RegisterUncheckedTimeArithmeticFunctions(
    FunctionRegistry& registry) {
  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::CreateDescriptor(builtin::kAdd,
                                                                false),
      BinaryFunctionAdapter::WrapFunction(
          [](absl::Time t1, absl::Duration d2) -> Value {
            return UnsafeTimestampValue(t1 + d2);
          })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::CreateDescriptor(builtin::kAdd,
                                                                false),
      BinaryFunctionAdapter::WrapFunction(
          [](absl::Duration d2, absl::Time t1) -> Value {
            return UnsafeTimestampValue(t1 + d2);
          })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::CreateDescriptor(builtin::kAdd, false),
      BinaryFunctionAdapter::
          WrapFunction([](absl::Duration d1, absl::Duration d2) -> Value {
            return UnsafeDurationValue(d1 + d2);
          })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::
          CreateDescriptor(builtin::kSubtract, false),
      BinaryFunctionAdapter::WrapFunction(
          [](absl::Time t1, absl::Duration d2) -> Value {
            return UnsafeTimestampValue(t1 - d2);
          })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::CreateDescriptor(
          builtin::kSubtract, false),
      BinaryFunctionAdapter::WrapFunction(
          [](absl::Time t1, absl::Time t2) -> Value {
            return UnsafeDurationValue(t1 - t2);
          })));

  CEL_RETURN_IF_ERROR(registry.Register(
      BinaryFunctionAdapter::
          CreateDescriptor(builtin::kSubtract, false),
      BinaryFunctionAdapter::
          WrapFunction([](absl::Duration d1, absl::Duration d2) -> Value {
            return UnsafeDurationValue(d1 - d2);
          })));

  return absl::OkStatus();
}

// Registers the receiver-style duration accessors. getHours, getMinutes and
// getSeconds convert the whole duration to that unit via absl::ToInt64*;
// getMilliseconds returns only the total milliseconds modulo 1000.
absl::Status RegisterDurationFunctions(FunctionRegistry& registry) {
  // duration breakdown accessor functions
  using DurationAccessorFunction =
      UnaryFunctionAdapter;
  CEL_RETURN_IF_ERROR(registry.Register(
      DurationAccessorFunction::CreateDescriptor(builtin::kHours, true),
      DurationAccessorFunction::WrapFunction(
          [](absl::Duration d) -> int64_t { return absl::ToInt64Hours(d); })));

  CEL_RETURN_IF_ERROR(registry.Register(
      DurationAccessorFunction::CreateDescriptor(builtin::kMinutes, true),
      DurationAccessorFunction::WrapFunction([](absl::Duration d) -> int64_t {
        return absl::ToInt64Minutes(d);
      })));

  CEL_RETURN_IF_ERROR(registry.Register(
      DurationAccessorFunction::CreateDescriptor(builtin::kSeconds, true),
      DurationAccessorFunction::WrapFunction([](absl::Duration d) -> int64_t {
        return absl::ToInt64Seconds(d);
      })));

  return registry.Register(
      DurationAccessorFunction::CreateDescriptor(builtin::kMilliseconds, true),
      DurationAccessorFunction::WrapFunction([](absl::Duration d) -> int64_t {
        constexpr int64_t millis_per_second = 1000L;
        return absl::ToInt64Milliseconds(d) % millis_per_second;
      }));
}

}  // namespace

absl::Status RegisterTimeFunctions(FunctionRegistry& registry,
                                   const RuntimeOptions& options) {
  CEL_RETURN_IF_ERROR(RegisterTimestampFunctions(registry, options));
  CEL_RETURN_IF_ERROR(RegisterDurationFunctions(registry));

  // Special arithmetic operators for Timestamp and Duration
  // TODO(uncreated-issue/37): deprecate unchecked time math functions when clients no
  // longer depend on them.
  if (options.enable_timestamp_duration_overflow_errors) {
    return RegisterCheckedTimeArithmeticFunctions(registry);
  }

  return RegisterUncheckedTimeArithmeticFunctions(registry);
}

}  // namespace cel

================================================
FILE: runtime/standard/time_functions.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_TIME_FUNCTIONS_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_TIME_FUNCTIONS_H_ #include "absl/status/status.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" namespace cel { // Register builtin timestamp and duration functions: // // (timestamp).getFullYear() -> int // (timestamp).getMonth() -> int // (timestamp).getDayOfYear() -> int // (timestamp).getDayOfMonth() -> int // (timestamp).getDayOfWeek() -> int // (timestamp).getDate() -> int // (timestamp).getHours() -> int // (timestamp).getMinutes() -> int // (timestamp).getSeconds() -> int // (timestamp).getMilliseconds() -> int // // (duration).getHours() -> int // (duration).getMinutes() -> int // (duration).getSeconds() -> int // (duration).getMilliseconds() -> int // // _+_(timestamp, duration) -> timestamp // _+_(duration, timestamp) -> timestamp // _+_(duration, duration) -> duration // _-_(timestamp, timestamp) -> duration // _-_(timestamp, duration) -> timestamp // _-_(duration, duration) -> duration // // Most users should use RegisterBuiltinFunctions, which includes these // definitions. absl::Status RegisterTimeFunctions(FunctionRegistry& registry, const RuntimeOptions& options); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_TIME_FUNCTIONS_H_ ================================================ FILE: runtime/standard/time_functions_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "runtime/standard/time_functions.h" #include #include "base/builtins.h" #include "common/function_descriptor.h" #include "internal/testing.h" namespace cel { namespace { using ::testing::UnorderedElementsAre; MATCHER_P3(MatchesOperatorDescriptor, name, expected_kind1, expected_kind2, "") { const FunctionDescriptor& descriptor = *arg; std::vector types{expected_kind1, expected_kind2}; return descriptor.name() == name && descriptor.receiver_style() == false && descriptor.types() == types; } MATCHER_P2(MatchesTimeAccessor, name, kind, "") { const FunctionDescriptor& descriptor = *arg; std::vector types{kind}; return descriptor.name() == name && descriptor.receiver_style() == true && descriptor.types() == types; } MATCHER_P2(MatchesTimezoneTimeAccessor, name, kind, "") { const FunctionDescriptor& descriptor = *arg; std::vector types{kind, Kind::kString}; return descriptor.name() == name && descriptor.receiver_style() == true && descriptor.types() == types; } TEST(RegisterTimeFunctions, MathOperatorsRegistered) { FunctionRegistry registry; RuntimeOptions options; ASSERT_OK(RegisterTimeFunctions(registry, options)); auto registered_functions = registry.ListFunctions(); EXPECT_THAT(registered_functions[builtin::kAdd], UnorderedElementsAre( MatchesOperatorDescriptor(builtin::kAdd, Kind::kDuration, Kind::kDuration), MatchesOperatorDescriptor(builtin::kAdd, Kind::kTimestamp, Kind::kDuration), MatchesOperatorDescriptor(builtin::kAdd, Kind::kDuration, Kind::kTimestamp))); EXPECT_THAT(registered_functions[builtin::kSubtract], UnorderedElementsAre( MatchesOperatorDescriptor(builtin::kSubtract, Kind::kDuration, Kind::kDuration), MatchesOperatorDescriptor(builtin::kSubtract, Kind::kTimestamp, Kind::kDuration), MatchesOperatorDescriptor( builtin::kSubtract, Kind::kTimestamp, Kind::kTimestamp))); } TEST(RegisterTimeFunctions, AccessorsRegistered) { FunctionRegistry 
registry; RuntimeOptions options; ASSERT_OK(RegisterTimeFunctions(registry, options)); auto registered_functions = registry.ListFunctions(); EXPECT_THAT( registered_functions[builtin::kFullYear], UnorderedElementsAre( MatchesTimeAccessor(builtin::kFullYear, Kind::kTimestamp), MatchesTimezoneTimeAccessor(builtin::kFullYear, Kind::kTimestamp))); EXPECT_THAT( registered_functions[builtin::kDate], UnorderedElementsAre( MatchesTimeAccessor(builtin::kDate, Kind::kTimestamp), MatchesTimezoneTimeAccessor(builtin::kDate, Kind::kTimestamp))); EXPECT_THAT( registered_functions[builtin::kMonth], UnorderedElementsAre( MatchesTimeAccessor(builtin::kMonth, Kind::kTimestamp), MatchesTimezoneTimeAccessor(builtin::kMonth, Kind::kTimestamp))); EXPECT_THAT( registered_functions[builtin::kDayOfYear], UnorderedElementsAre( MatchesTimeAccessor(builtin::kDayOfYear, Kind::kTimestamp), MatchesTimezoneTimeAccessor(builtin::kDayOfYear, Kind::kTimestamp))); EXPECT_THAT( registered_functions[builtin::kDayOfMonth], UnorderedElementsAre( MatchesTimeAccessor(builtin::kDayOfMonth, Kind::kTimestamp), MatchesTimezoneTimeAccessor(builtin::kDayOfMonth, Kind::kTimestamp))); EXPECT_THAT( registered_functions[builtin::kDayOfWeek], UnorderedElementsAre( MatchesTimeAccessor(builtin::kDayOfWeek, Kind::kTimestamp), MatchesTimezoneTimeAccessor(builtin::kDayOfWeek, Kind::kTimestamp))); EXPECT_THAT( registered_functions[builtin::kHours], UnorderedElementsAre( MatchesTimeAccessor(builtin::kHours, Kind::kTimestamp), MatchesTimezoneTimeAccessor(builtin::kHours, Kind::kTimestamp), MatchesTimeAccessor(builtin::kHours, Kind::kDuration))); EXPECT_THAT( registered_functions[builtin::kMinutes], UnorderedElementsAre( MatchesTimeAccessor(builtin::kMinutes, Kind::kTimestamp), MatchesTimezoneTimeAccessor(builtin::kMinutes, Kind::kTimestamp), MatchesTimeAccessor(builtin::kMinutes, Kind::kDuration))); EXPECT_THAT( registered_functions[builtin::kSeconds], UnorderedElementsAre( MatchesTimeAccessor(builtin::kSeconds, 
Kind::kTimestamp), MatchesTimezoneTimeAccessor(builtin::kSeconds, Kind::kTimestamp), MatchesTimeAccessor(builtin::kSeconds, Kind::kDuration))); EXPECT_THAT( registered_functions[builtin::kMilliseconds], UnorderedElementsAre( MatchesTimeAccessor(builtin::kMilliseconds, Kind::kTimestamp), MatchesTimezoneTimeAccessor(builtin::kMilliseconds, Kind::kTimestamp), MatchesTimeAccessor(builtin::kMilliseconds, Kind::kDuration))); } // TODO(uncreated-issue/41): move functional parsed expr tests when modern APIs for // evaluator available. } // namespace } // namespace cel ================================================ FILE: runtime/standard/type_conversion_functions.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "runtime/standard/type_conversion_functions.h" #include #include #include // NOLINT (required for std::to_chars_result) #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "base/builtins.h" #include "base/function_adapter.h" #include "common/value.h" #include "internal/overflow.h" #include "internal/status_macros.h" #include "internal/time.h" #include "internal/utf8.h" #include "runtime/function.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" #if defined(_LIBCPP_VERSION) && _LIBCPP_VERSION >= 14000 && \ !defined(__APPLE__) || \ defined(__GNUC__) && __GNUC__ >= 13 || \ defined(_MSC_VER) && _MSC_VER >= 1920 #define _CEL_CHAR_CONV_DOUBLE_TO_CHARS 1 #endif namespace cel { namespace { using ::cel::internal::EncodeDurationToJson; using ::cel::internal::EncodeTimestampToJson; using ::cel::internal::MaxTimestamp; using ::cel::internal::MinTimestamp; Value FormatDouble(double v, const Function::InvokeContext& context) { google::protobuf::Arena* arena = context.arena(); #if defined(CEL_NO_CHARCONV_DOUBLE_TO_CHARS) || \ !defined(_CEL_CHAR_CONV_DOUBLE_TO_CHARS) // Fallback to absl::StrFormat. Slower and handles edge cases around precision // differently but safe and covers most cases. 
return StringValue::From(absl::StrFormat("%.17g", v), arena); #else constexpr int kBufSize = 32; char buf[kBufSize]; std::to_chars_result result = std::to_chars(buf, buf + kBufSize, v, std::chars_format::general); if (result.ec != std::errc()) { return cel::ErrorValue(absl::InvalidArgumentError(absl::StrCat( "double format error: ", std::make_error_code(result.ec).message()))); } absl::string_view out(buf, result.ptr - buf); return StringValue::From(out, arena); #endif } Value LegacyFormatDouble(double v, const Function::InvokeContext& context) { return StringValue::From(absl::StrCat(v), context.arena()); } absl::Status RegisterBoolConversionFunctions(FunctionRegistry& registry, const RuntimeOptions&) { // bool -> bool absl::Status status = UnaryFunctionAdapter::RegisterGlobalOverload( cel::builtin::kBool, [](bool v) { return v; }, registry); CEL_RETURN_IF_ERROR(status); // string -> bool return UnaryFunctionAdapter::RegisterGlobalOverload( cel::builtin::kBool, [](const StringValue& v) -> Value { if ((v == "true") || (v == "True") || (v == "TRUE") || (v == "t") || (v == "1")) { return TrueValue(); } else if ((v == "false") || (v == "FALSE") || (v == "False") || (v == "f") || (v == "0")) { return FalseValue(); } else { return ErrorValue(absl::InvalidArgumentError( "Type conversion error from 'string' to 'bool'")); } }, registry); } absl::Status RegisterIntConversionFunctions(FunctionRegistry& registry, const RuntimeOptions&) { // bool -> int absl::Status status = UnaryFunctionAdapter::RegisterGlobalOverload( cel::builtin::kInt, [](bool v) { return static_cast(v); }, registry); CEL_RETURN_IF_ERROR(status); // double -> int status = UnaryFunctionAdapter::RegisterGlobalOverload( cel::builtin::kInt, [](double v) -> Value { auto conv = cel::internal::CheckedDoubleToInt64(v); if (!conv.ok()) { return ErrorValue(conv.status()); } return IntValue(*conv); }, registry); CEL_RETURN_IF_ERROR(status); // int -> int status = UnaryFunctionAdapter::RegisterGlobalOverload( 
cel::builtin::kInt, [](int64_t v) { return v; }, registry);
  CEL_RETURN_IF_ERROR(status);

  // string -> int: text that absl::SimpleAtoi rejects yields an
  // InvalidArgument error value (not a registration failure).
  status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kInt,
      [](const StringValue& s) -> Value {
        int64_t result;
        if (!absl::SimpleAtoi(s.ToString(), &result)) {
          return ErrorValue(
              absl::InvalidArgumentError("cannot convert string to int"));
        }
        return IntValue(result);
      },
      registry);
  CEL_RETURN_IF_ERROR(status);

  // time -> int: seconds since the Unix epoch, via absl::ToUnixSeconds.
  status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kInt, [](absl::Time t) { return absl::ToUnixSeconds(t); },
      registry);
  CEL_RETURN_IF_ERROR(status);

  // uint -> int: a non-OK result from CheckedUint64ToInt64 (presumably an
  // out-of-range value) is returned as an error value.
  return UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kInt,
      [](uint64_t v) -> Value {
        auto conv = cel::internal::CheckedUint64ToInt64(v);
        if (!conv.ok()) {
          return ErrorValue(conv.status());
        }
        return IntValue(*conv);
      },
      registry);
}

// Registers the string() conversion overloads for bytes, bool, double, int,
// string, uint, duration and timestamp arguments.
//
// When `options.enable_string_conversion` is false nothing is registered and
// OK is returned. Otherwise returns the first registration error, if any.
absl::Status RegisterStringConversionFunctions(FunctionRegistry& registry,
                                               const RuntimeOptions& options) {
  // May be optionally disabled to reduce potential allocs.
  if (!options.enable_string_conversion) {
    return absl::OkStatus();
  }

  // bytes -> string: bytes rejected by internal::Utf8IsValid produce a
  // "malformed UTF-8 bytes" error value.
  absl::Status status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kString,
      [](const BytesValue& value) -> Value {
        auto valid = value.NativeValue([](const auto& value) -> bool {
          return internal::Utf8IsValid(value);
        });
        if (!valid) {
          return ErrorValue(
              absl::InvalidArgumentError("malformed UTF-8 bytes"));
        }
        return StringValue(value.ToString());
      },
      registry);
  CEL_RETURN_IF_ERROR(status);

  // bool -> string
  status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kString,
      [](bool value) -> StringValue {
        return StringValue(value ? "true" : "false");
      },
      registry);
  CEL_RETURN_IF_ERROR(status);

  // double -> string: the formatter is chosen once, at registration time,
  // from options.enable_precision_preserving_double_format.
  status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kString,
      (options.enable_precision_preserving_double_format ? &FormatDouble
                                                         : &LegacyFormatDouble),
      registry);
  CEL_RETURN_IF_ERROR(status);

  // int -> string
  status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kString,
      [](int64_t value) -> StringValue {
        return StringValue(absl::StrCat(value));
      },
      registry);
  CEL_RETURN_IF_ERROR(status);

  // string -> string (identity)
  status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kString,
      [](StringValue value) -> StringValue { return value; }, registry);
  CEL_RETURN_IF_ERROR(status);

  // uint -> string
  status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kString,
      [](uint64_t value) -> StringValue {
        return StringValue(absl::StrCat(value));
      },
      registry);
  CEL_RETURN_IF_ERROR(status);

  // duration -> string: encoded with EncodeDurationToJson; an encoding
  // failure is returned as an error value.
  status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kString,
      [](absl::Duration value) -> Value {
        auto encode = EncodeDurationToJson(value);
        if (!encode.ok()) {
          return ErrorValue(encode.status());
        }
        return StringValue(*encode);
      },
      registry);
  CEL_RETURN_IF_ERROR(status);

  // timestamp -> string: encoded with EncodeTimestampToJson; an encoding
  // failure is returned as an error value.
  return UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kString,
      [](absl::Time value) -> Value {
        auto encode = EncodeTimestampToJson(value);
        if (!encode.ok()) {
          return ErrorValue(encode.status());
        }
        return StringValue(*encode);
      },
      registry);
}

// Registers the uint() conversion overloads for double, int, string and uint
// arguments. Returns the first registration error, if any.
absl::Status RegisterUintConversionFunctions(FunctionRegistry& registry,
                                             const RuntimeOptions&) {
  // double -> uint: a non-OK result from CheckedDoubleToUint64 is returned
  // as an error value.
  absl::Status status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kUint,
      [](double v) -> Value {
        auto conv = cel::internal::CheckedDoubleToUint64(v);
        if (!conv.ok()) {
          return ErrorValue(conv.status());
        }
        return UintValue(*conv);
      },
      registry);
  CEL_RETURN_IF_ERROR(status);

  // int -> uint: a non-OK result from CheckedInt64ToUint64 is returned as an
  // error value.
  status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kUint,
      [](int64_t v) -> Value {
        auto conv = cel::internal::CheckedInt64ToUint64(v);
        if (!conv.ok()) {
          return ErrorValue(conv.status());
        }
        return UintValue(*conv);
      },
      registry);
  CEL_RETURN_IF_ERROR(status);

  // string -> uint: text that absl::SimpleAtoi rejects yields an
  // InvalidArgument error value.
  status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kUint,
      [](const StringValue& s) -> Value {
        uint64_t result;
        if (!absl::SimpleAtoi(s.ToString(), &result)) {
          return ErrorValue(
              absl::InvalidArgumentError("cannot convert string to uint"));
        }
        return UintValue(result);
      },
      registry);
  CEL_RETURN_IF_ERROR(status);

  // uint -> uint (identity)
  return UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kUint, [](uint64_t v) { return v; }, registry);
}

// Registers the bytes() conversion overloads for bytes and string arguments.
absl::Status RegisterBytesConversionFunctions(FunctionRegistry& registry,
                                              const RuntimeOptions&) {
  // bytes -> bytes (identity)
  absl::Status status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kBytes,
      [](BytesValue value) -> BytesValue { return value; }, registry);
  CEL_RETURN_IF_ERROR(status);

  // string -> bytes: copies the string's contents into a BytesValue.
  return UnaryFunctionAdapter, const StringValue&>::
      RegisterGlobalOverload(
          cel::builtin::kBytes,
          [](const StringValue& value) {
            return BytesValue(value.ToString());
          },
          registry);
}

// Registers the double() conversion overloads for double, int, string and
// uint arguments. Returns the first registration error, if any.
absl::Status RegisterDoubleConversionFunctions(FunctionRegistry& registry,
                                               const RuntimeOptions&) {
  // double -> double (identity)
  absl::Status status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kDouble, [](double v) { return v; }, registry);
  CEL_RETURN_IF_ERROR(status);

  // int -> double
  status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kDouble, [](int64_t v) { return static_cast(v); },
      registry);
  CEL_RETURN_IF_ERROR(status);

  // string -> double: parsed with absl::SimpleAtod; rejected text yields an
  // InvalidArgument error value.
  status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kDouble,
      [](const StringValue& s) -> Value {
        double result;
        if (absl::SimpleAtod(s.ToString(), &result)) {
          return DoubleValue(result);
        } else {
          return ErrorValue(absl::InvalidArgumentError(
              "cannot convert string to double"));
        }
      },
      registry);
  CEL_RETURN_IF_ERROR(status);

  // uint -> double
  return UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kDouble, [](uint64_t v) { return static_cast(v); },
      registry);
}

// Implements duration(string): parses `dur_str` with absl::ParseDuration and
// then checks it with internal::ValidateDuration. A parse failure or a failed
// validation is returned as an error value.
Value CreateDurationFromString(const StringValue& dur_str) {
  absl::Duration d;
  if (!absl::ParseDuration(dur_str.ToString(), &d)) {
    return ErrorValue(
        absl::InvalidArgumentError("String to Duration conversion failed"));
  }

  auto status = internal::ValidateDuration(d);
  if (!status.ok()) {
    return ErrorValue(std::move(status));
  }

  return DurationValue(d);
}

// Registers the timestamp() and duration() conversion overloads:
// duration(string), duration(duration), timestamp(int), timestamp(string) and
// timestamp(timestamp). Returns the first registration error, if any.
absl::Status RegisterTimeConversionFunctions(FunctionRegistry& registry,
                                             const RuntimeOptions& options) {
  // duration() conversion from string.
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter::RegisterGlobalOverload(
          cel::builtin::kDuration, CreateDurationFromString, registry)));

  // Copied into a local so the [=] lambdas below capture just this bool.
  bool enable_timestamp_duration_overflow_errors =
      options.enable_timestamp_duration_overflow_errors;

  // timestamp conversion from int (seconds since the Unix epoch). The
  // [MinTimestamp(), MaxTimestamp()] range check only runs when overflow
  // errors are enabled; otherwise the value is wrapped unchecked.
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter::RegisterGlobalOverload(
          cel::builtin::kTimestamp,
          [=](int64_t epoch_seconds) -> Value {
            absl::Time ts = absl::FromUnixSeconds(epoch_seconds);
            if (enable_timestamp_duration_overflow_errors) {
              if (ts < MinTimestamp() || ts > MaxTimestamp()) {
                return ErrorValue(absl::OutOfRangeError("timestamp overflow"));
              }
            }
            return UnsafeTimestampValue(ts);
          },
          registry)));

  // timestamp -> timestamp
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter::RegisterGlobalOverload(
          cel::builtin::kTimestamp,
          [](absl::Time value) -> Value { return TimestampValue(value); },
          registry)));

  // duration -> duration
  CEL_RETURN_IF_ERROR(
      (UnaryFunctionAdapter::RegisterGlobalOverload(
          cel::builtin::kDuration,
          [](absl::Duration value) -> Value { return DurationValue(value); },
          registry)));

  // timestamp() conversion from string, parsed as absl::RFC3339_full. The
  // range check follows the same option as timestamp(int) above.
  return UnaryFunctionAdapter::
      RegisterGlobalOverload(
          cel::builtin::kTimestamp,
          [=](const StringValue& time_str) -> Value {
            absl::Time ts;
            if (!absl::ParseTime(absl::RFC3339_full, time_str.ToString(), &ts,
                                 nullptr)) {
              return ErrorValue(absl::InvalidArgumentError(
                  "String to Timestamp conversion failed"));
            }
            if (enable_timestamp_duration_overflow_errors) {
              if (ts < MinTimestamp() || ts > MaxTimestamp()) {
                return ErrorValue(absl::OutOfRangeError("timestamp overflow"));
              }
            }
            return UnsafeTimestampValue(ts);
          },
          registry);
}

}  // namespace

// Registers every type conversion group: bool, bytes, double, int, string
// (subject to options.enable_string_conversion), uint, and
// timestamp/duration. Then registers the dyn() identity and type() functions.
// Stops at, and returns, the first registration error.
absl::Status RegisterTypeConversionFunctions(FunctionRegistry& registry,
                                             const RuntimeOptions& options) {
  CEL_RETURN_IF_ERROR(RegisterBoolConversionFunctions(registry, options));

  CEL_RETURN_IF_ERROR(RegisterBytesConversionFunctions(registry, options));

  CEL_RETURN_IF_ERROR(RegisterDoubleConversionFunctions(registry, options));

  CEL_RETURN_IF_ERROR(RegisterIntConversionFunctions(registry, options));

  CEL_RETURN_IF_ERROR(RegisterStringConversionFunctions(registry, options));

  CEL_RETURN_IF_ERROR(RegisterUintConversionFunctions(registry, options));

  CEL_RETURN_IF_ERROR(RegisterTimeConversionFunctions(registry, options));

  // dyn() identity function.
  // TODO(issues/102): strip dyn() function references at type-check time.
  absl::Status status = UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kDyn, [](const Value& value) -> Value { return value; },
      registry);
  CEL_RETURN_IF_ERROR(status);

  // type(dyn) -> type: reports the argument's runtime type.
  return UnaryFunctionAdapter::RegisterGlobalOverload(
      cel::builtin::kType,
      [](const Value& value) { return TypeValue(value.GetRuntimeType()); },
      registry);
}

}  // namespace cel
================================================
FILE: runtime/standard/type_conversion_functions.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_TYPE_CONVERSION_FUNCTIONS_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_TYPE_CONVERSION_FUNCTIONS_H_ #include "absl/status/status.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" namespace cel { // Register builtin type conversion functions: // dyn, int, uint, double, timestamp, duration, string, bytes, type // // Most users should use RegisterBuiltinFunctions, which includes these // definitions. absl::Status RegisterTypeConversionFunctions(FunctionRegistry& registry, const RuntimeOptions& options); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_TYPE_CONVERSION_FUNCTIONS_H_ ================================================ FILE: runtime/standard/type_conversion_functions_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "runtime/standard/type_conversion_functions.h"

#include 

#include "base/builtins.h"
#include "common/function_descriptor.h"
#include "internal/testing.h"

namespace cel {
namespace {

using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;

// Matches a registry overload entry whose descriptor has the given `name`,
// the given receiver-style flag, and exactly one parameter of kind
// `expected_kind`.
MATCHER_P3(MatchesUnaryDescriptor, name, receiver, expected_kind, "") {
  const FunctionDescriptor& descriptor = arg.descriptor;
  std::vector types{expected_kind};
  return descriptor.name() == name && descriptor.receiver_style() == receiver &&
         descriptor.types() == types;
}

// Each test registers the conversions and queries the non-receiver overloads
// of one builtin. The {Kind::kAny} argument list is presumably meant to match
// every unary overload. Each test then checks the exact set of parameter
// kinds.

TEST(RegisterTypeConversionFunctions, RegisterBoolConversionFunctions) {
  FunctionRegistry registry;
  RuntimeOptions options;

  ASSERT_OK(RegisterTypeConversionFunctions(registry, options));

  EXPECT_THAT(
      registry.FindStaticOverloads(builtin::kBool, false, {Kind::kAny}),
      UnorderedElementsAre(
          MatchesUnaryDescriptor(builtin::kBool, false, Kind::kBool),
          MatchesUnaryDescriptor(builtin::kBool, false, Kind::kString)));
}

TEST(RegisterTypeConversionFunctions, RegisterIntConversionFunctions) {
  FunctionRegistry registry;
  RuntimeOptions options;

  ASSERT_OK(RegisterTypeConversionFunctions(registry, options));

  EXPECT_THAT(
      registry.FindStaticOverloads(builtin::kInt, false, {Kind::kAny}),
      UnorderedElementsAre(
          MatchesUnaryDescriptor(builtin::kInt, false, Kind::kInt),
          MatchesUnaryDescriptor(builtin::kInt, false, Kind::kDouble),
          MatchesUnaryDescriptor(builtin::kInt, false, Kind::kUint),
          MatchesUnaryDescriptor(builtin::kInt, false, Kind::kBool),
          MatchesUnaryDescriptor(builtin::kInt, false, Kind::kString),
          MatchesUnaryDescriptor(builtin::kInt, false, Kind::kTimestamp)));
}

TEST(RegisterTypeConversionFunctions, RegisterUintConversionFunctions) {
  FunctionRegistry registry;
  RuntimeOptions options;

  ASSERT_OK(RegisterTypeConversionFunctions(registry, options));

  EXPECT_THAT(
      registry.FindStaticOverloads(builtin::kUint, false, {Kind::kAny}),
      UnorderedElementsAre(
          MatchesUnaryDescriptor(builtin::kUint, false, Kind::kInt),
          MatchesUnaryDescriptor(builtin::kUint, false, Kind::kDouble),
          MatchesUnaryDescriptor(builtin::kUint, false, Kind::kUint),
          MatchesUnaryDescriptor(builtin::kUint, false, Kind::kString)));
}

TEST(RegisterTypeConversionFunctions, RegisterDoubleConversionFunctions) {
  FunctionRegistry registry;
  RuntimeOptions options;

  ASSERT_OK(RegisterTypeConversionFunctions(registry, options));

  EXPECT_THAT(
      registry.FindStaticOverloads(builtin::kDouble, false, {Kind::kAny}),
      UnorderedElementsAre(
          MatchesUnaryDescriptor(builtin::kDouble, false, Kind::kInt),
          MatchesUnaryDescriptor(builtin::kDouble, false, Kind::kDouble),
          MatchesUnaryDescriptor(builtin::kDouble, false, Kind::kUint),
          MatchesUnaryDescriptor(builtin::kDouble, false, Kind::kString)));
}

// string() overloads exist only when enable_string_conversion is set.
TEST(RegisterTypeConversionFunctions, RegisterStringConversionFunctions) {
  FunctionRegistry registry;
  RuntimeOptions options;
  options.enable_string_conversion = true;

  ASSERT_OK(RegisterTypeConversionFunctions(registry, options));

  EXPECT_THAT(
      registry.FindStaticOverloads(builtin::kString, false, {Kind::kAny}),
      UnorderedElementsAre(
          MatchesUnaryDescriptor(builtin::kString, false, Kind::kBool),
          MatchesUnaryDescriptor(builtin::kString, false, Kind::kInt),
          MatchesUnaryDescriptor(builtin::kString, false, Kind::kDouble),
          MatchesUnaryDescriptor(builtin::kString, false, Kind::kUint),
          MatchesUnaryDescriptor(builtin::kString, false, Kind::kString),
          MatchesUnaryDescriptor(builtin::kString, false, Kind::kBytes),
          MatchesUnaryDescriptor(builtin::kString, false, Kind::kDuration),
          MatchesUnaryDescriptor(builtin::kString, false, Kind::kTimestamp)));
}

// With enable_string_conversion cleared, registration succeeds but adds no
// string() overloads.
TEST(RegisterTypeConversionFunctions,
     RegisterStringConversionFunctionsDisabled) {
  FunctionRegistry registry;
  RuntimeOptions options;
  options.enable_string_conversion = false;

  ASSERT_OK(RegisterTypeConversionFunctions(registry, options));

  EXPECT_THAT(
      registry.FindStaticOverloads(builtin::kString, false, {Kind::kAny}),
      IsEmpty());
}

TEST(RegisterTypeConversionFunctions, RegisterBytesConversionFunctions) {
  FunctionRegistry registry;
  RuntimeOptions options;

  ASSERT_OK(RegisterTypeConversionFunctions(registry, options));

  EXPECT_THAT(
      registry.FindStaticOverloads(builtin::kBytes, false, {Kind::kAny}),
      UnorderedElementsAre(
          MatchesUnaryDescriptor(builtin::kBytes, false, Kind::kBytes),
          MatchesUnaryDescriptor(builtin::kBytes, false, Kind::kString)));
}

// Covers both timestamp() and duration() overload sets.
TEST(RegisterTypeConversionFunctions, RegisterTimeConversionFunctions) {
  FunctionRegistry registry;
  RuntimeOptions options;

  ASSERT_OK(RegisterTypeConversionFunctions(registry, options));

  EXPECT_THAT(
      registry.FindStaticOverloads(builtin::kTimestamp, false, {Kind::kAny}),
      UnorderedElementsAre(
          MatchesUnaryDescriptor(builtin::kTimestamp, false, Kind::kInt),
          MatchesUnaryDescriptor(builtin::kTimestamp, false, Kind::kString),
          MatchesUnaryDescriptor(builtin::kTimestamp, false,
                                 Kind::kTimestamp)));

  EXPECT_THAT(
      registry.FindStaticOverloads(builtin::kDuration, false, {Kind::kAny}),
      UnorderedElementsAre(
          MatchesUnaryDescriptor(builtin::kDuration, false, Kind::kString),
          MatchesUnaryDescriptor(builtin::kDuration, false, Kind::kDuration)));
}

// dyn() and type() each have a single overload taking any value.
TEST(RegisterTypeConversionFunctions, RegisterMetaTypeConversionFunctions) {
  FunctionRegistry registry;
  RuntimeOptions options;

  ASSERT_OK(RegisterTypeConversionFunctions(registry, options));

  EXPECT_THAT(registry.FindStaticOverloads(builtin::kDyn, false, {Kind::kAny}),
              UnorderedElementsAre(
                  MatchesUnaryDescriptor(builtin::kDyn, false, Kind::kAny)));

  EXPECT_THAT(registry.FindStaticOverloads(builtin::kType, false, {Kind::kAny}),
              UnorderedElementsAre(
                  MatchesUnaryDescriptor(builtin::kType, false, Kind::kAny)));
}

// TODO(uncreated-issue/41): move functional parsed expr tests when modern APIs for
// evaluator available.

}  // namespace
}  // namespace cel
================================================
FILE: runtime/standard_functions.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/standard_functions.h" #include "absl/status/status.h" #include "internal/status_macros.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" #include "runtime/standard/arithmetic_functions.h" #include "runtime/standard/comparison_functions.h" #include "runtime/standard/container_functions.h" #include "runtime/standard/container_membership_functions.h" #include "runtime/standard/equality_functions.h" #include "runtime/standard/logical_functions.h" #include "runtime/standard/regex_functions.h" #include "runtime/standard/string_functions.h" #include "runtime/standard/time_functions.h" #include "runtime/standard/type_conversion_functions.h" namespace cel { absl::Status RegisterStandardFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { CEL_RETURN_IF_ERROR(RegisterArithmeticFunctions(registry, options)); CEL_RETURN_IF_ERROR(RegisterComparisonFunctions(registry, options)); CEL_RETURN_IF_ERROR(RegisterContainerFunctions(registry, options)); CEL_RETURN_IF_ERROR(RegisterContainerMembershipFunctions(registry, options)); CEL_RETURN_IF_ERROR(RegisterLogicalFunctions(registry, options)); CEL_RETURN_IF_ERROR(RegisterRegexFunctions(registry, options)); CEL_RETURN_IF_ERROR(RegisterStringFunctions(registry, options)); CEL_RETURN_IF_ERROR(RegisterTimeFunctions(registry, options)); CEL_RETURN_IF_ERROR(RegisterEqualityFunctions(registry, options)); return RegisterTypeConversionFunctions(registry, options); } } // namespace cel ================================================ FILE: runtime/standard_functions.h 
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_FUNCTIONS_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_FUNCTIONS_H_

#include "absl/status/status.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"

namespace cel {

// Register all CEL standard definitions.
//
// Registers, in order: arithmetic, comparison, container, container
// membership, logical, regex, string, time, equality and type conversion
// functions. Registration stops at, and returns, the first error. `options`
// is forwarded to every group and may enable or disable some overloads.
//
// See
// https://github.com/google/cel-spec/blob/master/doc/langdef.md#standard-definitions
absl::Status RegisterStandardFunctions(FunctionRegistry& registry,
                                       const RuntimeOptions& options);

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_FUNCTIONS_H_
================================================
FILE: runtime/standard_runtime_builder_factory.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/standard_runtime_builder_factory.h"

#include 
#include 

#include "absl/base/nullability.h"
#include "absl/log/absl_check.h"
#include "absl/status/statusor.h"
#include "internal/noop_delete.h"
#include "internal/status_macros.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_builder_factory.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_functions.h"
#include "google/protobuf/descriptor.h"

namespace cel {

// Borrowed-pool overload. Wraps `descriptor_pool` in a std::shared_ptr whose
// deleter is internal::NoopDeleteFor (a no-op deleter, per its name; confirm).
// The caller presumably keeps ownership and must keep the pool alive while the
// builder uses it. Then forwards to the shared-pointer overload.
absl::StatusOr CreateStandardRuntimeBuilder(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
    const RuntimeOptions& options) {
  ABSL_DCHECK(descriptor_pool != nullptr);
  return CreateStandardRuntimeBuilder(
      std::shared_ptr(
          descriptor_pool, internal::NoopDeleteFor()),
      options);
}

// Shared-pool overload. Creates a plain builder with CreateRuntimeBuilder,
// then registers the CEL standard functions (RegisterStandardFunctions) into
// its function registry. An error from either step is returned unchanged.
absl::StatusOr CreateStandardRuntimeBuilder(
    absl_nonnull std::shared_ptr descriptor_pool,
    const RuntimeOptions& options) {
  ABSL_DCHECK(descriptor_pool != nullptr);
  CEL_ASSIGN_OR_RETURN(
      auto builder, CreateRuntimeBuilder(std::move(descriptor_pool), options));
  CEL_RETURN_IF_ERROR(
      RegisterStandardFunctions(builder.function_registry(), options));
  return builder;
}

}  // namespace cel
================================================
FILE: runtime/standard_runtime_builder_factory.h
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_RUNTIME_BUILDER_FACTORY_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_RUNTIME_BUILDER_FACTORY_H_

#include 

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/descriptor.h"

namespace cel {

// Create a builder preconfigured with CEL standard definitions.
//
// See `CreateRuntimeBuilder` for a description of the requirements related to
// `descriptor_pool`.
//
// This overload borrows `descriptor_pool`. It must be non-null, and as marked
// by ABSL_ATTRIBUTE_LIFETIME_BOUND it must outlive the returned builder.
absl::StatusOr CreateStandardRuntimeBuilder(
    const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool
        ABSL_ATTRIBUTE_LIFETIME_BOUND,
    const RuntimeOptions& options);

// As above, but shares ownership of a non-null `descriptor_pool` through
// std::shared_ptr.
absl::StatusOr CreateStandardRuntimeBuilder(
    absl_nonnull std::shared_ptr descriptor_pool,
    const RuntimeOptions& options);

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_RUNTIME_STANDARD_RUNTIME_BUILDER_FACTORY_H_
================================================
FILE: runtime/standard_runtime_builder_factory_test.cc
================================================
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "runtime/standard_runtime_builder_factory.h"

#include 
#include 
#include 
#include 
#include 

#include "cel/expr/syntax.pb.h"
#include "absl/base/no_destructor.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "base/builtins.h"
#include "common/source.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "extensions/bindings_ext.h"
#include "extensions/protobuf/runtime_adapter.h"
#include "internal/testing.h"
#include "parser/macro_registry.h"
#include "parser/parser.h"
#include "parser/standard_macros.h"
#include "runtime/activation.h"
#include "runtime/internal/runtime_impl.h"
#include "runtime/runtime.h"
#include "runtime/runtime_issue.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"

namespace cel {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::extensions::ProtobufRuntimeAdapter;
using ::cel::test::BoolValueIs;
using ::cel::test::IntValueIs;
using ::cel::expr::ParsedExpr;
using ::google::api::expr::parser::Parse;
using ::testing::ElementsAre;
using ::testing::HasSubstr;
using ::testing::TestWithParam;
using ::testing::Truly;

// Returns a lazily built, never-destroyed macro registry containing the
// standard macros plus the cel.bind() extension macros.
const cel::MacroRegistry& GetMacros() {
  static absl::NoDestructor macros([]() {
    MacroRegistry registry;
    ABSL_CHECK_OK(cel::RegisterStandardMacros(registry, {}));
    for (const auto& macro : extensions::bindings_macros()) {
      ABSL_CHECK_OK(registry.RegisterMacro(macro));
    }
    return registry;
  }());
  return *macros;
}

// Parses `expression` using the macros from GetMacros().
absl::StatusOr ParseWithTestMacros(absl::string_view expression) {
  auto src = cel::NewSource(expression, "");
  ABSL_CHECK_OK(src.status());
  return Parse(**src, GetMacros());
}

// Planning must fail when the expression is deeper than max_recursion_depth.
TEST(StandardRuntimeTest, RecursionLimitExceeded) {
  RuntimeOptions opts;
  opts.max_recursion_depth = 1;

  ASSERT_OK_AND_ASSIGN(auto builder,
                       CreateStandardRuntimeBuilder(
                           google::protobuf::DescriptorPool::generated_pool(), opts));
  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros("1 + 2"));

  EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("Maximum recursion depth of 1 exceeded")));
}

// An expression within the limit plans as a recursive program and evaluates.
TEST(StandardRuntimeTest, RecursionUnderLimit) {
  RuntimeOptions opts;
  opts.max_recursion_depth = 2;

  ASSERT_OK_AND_ASSIGN(auto builder,
                       CreateStandardRuntimeBuilder(
                           google::protobuf::DescriptorPool::generated_pool(), opts));
  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros("1 + 2"));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr program,
                       ProtobufRuntimeAdapter::CreateProgram(*runtime, expr));

  // Whether the implementation is recursive shouldn't affect observable
  // behavior, but it does have performance implications (it will skip
  // allocating a value stack).
  EXPECT_TRUE(runtime_internal::TestOnly_IsRecursiveImpl(program.get()));

  google::protobuf::Arena arena;
  Activation activation;

  ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));
  EXPECT_THAT(result, IntValueIs(3));
}

// Depth accounting includes the sub-expressions bound by cel.bind().
TEST(StandardRuntimeTest, RecursionLimitTracksLazyExpressions) {
  RuntimeOptions opts;
  opts.max_recursion_depth = 8;

  ASSERT_OK_AND_ASSIGN(auto builder,
                       CreateStandardRuntimeBuilder(
                           google::protobuf::DescriptorPool::generated_pool(), opts));
  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros(R"cel( cel.bind(a, 4 + (3 + (2 + 1)), cel.bind(b, 7 + (6 + (5 + a)), 9 + (8 + b) ) ))cel"));

  EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("Maximum recursion depth of 8 exceeded")));
}

// One parameterized case: an expression expected to evaluate to a bool, plus
// an optional hook that populates the activation before evaluation.
struct EvaluateResultTestCase {
  std::string name;
  std::string expression;
  bool expected_result;
  std::function activation_builder;

  template friend void AbslStringify(S& sink,
                                     const EvaluateResultTestCase& tc) {
    sink.Append(tc.name);
  }
};

class StandardRuntimeTest : public TestWithParam {
 public:
  const EvaluateResultTestCase& GetTestCase() { return GetParam(); }
};

// Default options: the planner is expected not to use the recursive program.
TEST_P(StandardRuntimeTest, Defaults) {
  RuntimeOptions opts;
  const EvaluateResultTestCase& test_case = GetTestCase();

  ASSERT_OK_AND_ASSIGN(auto builder,
                       CreateStandardRuntimeBuilder(
                           google::protobuf::DescriptorPool::generated_pool(), opts));
  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr,
                       ParseWithTestMacros(test_case.expression));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr program,
                       ProtobufRuntimeAdapter::CreateProgram(*runtime, expr));

  EXPECT_FALSE(runtime_internal::TestOnly_IsRecursiveImpl(program.get()));

  google::protobuf::Arena arena;
  Activation activation;
  if (test_case.activation_builder != nullptr) {
    ASSERT_THAT(test_case.activation_builder(activation), IsOk());
  }

  ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));
  EXPECT_THAT(result, BoolValueIs(test_case.expected_result))
      << test_case.expression;
}

// max_recursion_depth = -1: the planner is expected to use the recursive
// program.
TEST_P(StandardRuntimeTest, Recursive) {
  RuntimeOptions opts;
  opts.max_recursion_depth = -1;
  const EvaluateResultTestCase& test_case = GetTestCase();

  ASSERT_OK_AND_ASSIGN(auto builder,
                       CreateStandardRuntimeBuilder(
                           google::protobuf::DescriptorPool::generated_pool(), opts));
  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr,
                       ParseWithTestMacros(test_case.expression));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr program,
                       ProtobufRuntimeAdapter::CreateProgram(*runtime, expr));

  // Whether the implementation is recursive shouldn't affect observable
  // behavior, but it does have performance implications (it will skip
  // allocating a value stack).
  EXPECT_TRUE(runtime_internal::TestOnly_IsRecursiveImpl(program.get()));

  google::protobuf::Arena arena;
  Activation activation;
  if (test_case.activation_builder != nullptr) {
    ASSERT_THAT(test_case.activation_builder(activation), IsOk());
  }

  ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));
  EXPECT_THAT(result, BoolValueIs(test_case.expected_result))
      << test_case.expression;
}

// enable_fast_builtins with the default (non-recursive) planner.
TEST_P(StandardRuntimeTest, FastBuiltins) {
  RuntimeOptions opts;
  opts.enable_fast_builtins = true;
  const EvaluateResultTestCase& test_case = GetTestCase();

  ASSERT_OK_AND_ASSIGN(auto builder,
                       CreateStandardRuntimeBuilder(
                           google::protobuf::DescriptorPool::generated_pool(), opts));
  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr,
                       ParseWithTestMacros(test_case.expression));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr program,
                       ProtobufRuntimeAdapter::CreateProgram(*runtime, expr));

  EXPECT_FALSE(runtime_internal::TestOnly_IsRecursiveImpl(program.get()));

  google::protobuf::Arena arena;
  Activation activation;
  if (test_case.activation_builder != nullptr) {
    ASSERT_THAT(test_case.activation_builder(activation), IsOk());
  }

  ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));
  EXPECT_THAT(result, BoolValueIs(test_case.expected_result))
      << test_case.expression;
}

// enable_fast_builtins combined with the recursive planner.
TEST_P(StandardRuntimeTest, RecursiveFastBuiltins) {
  RuntimeOptions opts;
  opts.enable_fast_builtins = true;
  opts.max_recursion_depth = -1;
  const EvaluateResultTestCase& test_case = GetTestCase();

  ASSERT_OK_AND_ASSIGN(auto builder,
                       CreateStandardRuntimeBuilder(
                           google::protobuf::DescriptorPool::generated_pool(), opts));
  ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

  ASSERT_OK_AND_ASSIGN(ParsedExpr expr,
                       ParseWithTestMacros(test_case.expression));

  ASSERT_OK_AND_ASSIGN(std::unique_ptr program,
                       ProtobufRuntimeAdapter::CreateProgram(*runtime, expr));

  // Whether the implementation is recursive shouldn't affect observable
  // behavior, but it does have performance implications (it will skip
  // allocating a value stack).
  EXPECT_TRUE(runtime_internal::TestOnly_IsRecursiveImpl(program.get()));

  google::protobuf::Arena arena;
  Activation activation;
  if (test_case.activation_builder != nullptr) {
    ASSERT_THAT(test_case.activation_builder(activation), IsOk());
  }

  ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));
  EXPECT_THAT(result, BoolValueIs(test_case.expected_result))
      << test_case.expression;
}

INSTANTIATE_TEST_SUITE_P(
    Basic, StandardRuntimeTest,
    testing::ValuesIn(std::vector{
        {"int_identifier", "int_var == 42", true,
         [](Activation& activation) {
           activation.InsertOrAssignValue("int_var", cel::IntValue(42));
           return absl::OkStatus();
         }},
        {"logic_and_true", "true && 1 < 2", true},
        {"logic_and_false", "true && 1 > 2", false},
        {"logic_or_true", "false || 1 < 2", true},
        // Must use `||`: with `&&` this case never exercised logical-or.
        {"logic_or_false", "false || 1 > 2", false},
        {"ternary_true_cond", "(1 < 2 ? 'yes' : 'no') == 'yes'", true},
        {"ternary_false_cond", "(1 > 2 ? 'yes' : 'no') == 'no'", true},
        {"list_index", "['a', 'b', 'c', 'd'][1] == 'b'", true},
        {"map_index_bool", "{true: 1, false: 2}[false] == 2", true},
        {"map_index_string", "{'abc': 123}['abc'] == 123", true},
        {"map_index_int", "{1: 2, 2: 4}[2] == 4", true},
        {"map_index_uint", "{1u: 1, 2u: 2}[1u] == 1", true},
        {"map_index_coerced_double", "{1: 2, 2: 4}[2.0] == 4", true},
    }));

INSTANTIATE_TEST_SUITE_P(
    Equality, StandardRuntimeTest,
    testing::ValuesIn(std::vector{
        {"eq_bool_bool_true", "false == false", true},
        {"eq_bool_bool_false", "false == true", false},
        {"eq_int_int_true", "-1 == -1", true},
        {"eq_int_int_false", "-1 == 1", false},
        {"eq_uint_uint_true", "2u == 2u", true},
        {"eq_uint_uint_false", "2u == 3u", false},
        {"eq_double_double_true", "2.4 == 2.4", true},
        {"eq_double_double_false", "2.4 == 3.3", false},
        {"eq_string_string_true", "'abc' == 'abc'", true},
        {"eq_string_string_false", "'abc' == 'def'", false},
        {"eq_bytes_bytes_true", "b'abc' == b'abc'", true},
        {"eq_bytes_bytes_false", "b'abc' == b'def'", false},
        {"eq_duration_duration_true", "duration('15m') == duration('15m')",
         true},
        {"eq_duration_duration_false", "duration('15m') == duration('1h')",
         false},
        {"eq_timestamp_timestamp_true",
         "timestamp('1970-01-01T00:02:00Z') == "
         "timestamp('1970-01-01T00:02:00Z')",
         true},
        {"eq_timestamp_timestamp_false",
         "timestamp('1970-01-01T00:02:00Z') == "
         "timestamp('2020-01-01T00:02:00Z')",
         false},
        {"eq_null_null_true", "null == null", true},
        {"eq_list_list_true", "[1, 2, 3] == [1, 2, 3]", true},
        {"eq_list_list_false", "[1, 2, 3] == [1, 2, 3, 4]", false},
        {"eq_map_map_true", "{1: 2, 2: 4} == {1: 2, 2: 4}", true},
        {"eq_map_map_false", "{1: 2, 2: 4} == {1: 2, 2: 5}", false},

        {"neq_bool_bool_true", "false != false", false},
        {"neq_bool_bool_false", "false != true", true},
        {"neq_int_int_true", "-1 != -1", false},
        {"neq_int_int_false", "-1 != 1", true},
        {"neq_uint_uint_true", "2u != 2u", false},
        {"neq_uint_uint_false", "2u != 3u", true},
        {"neq_double_double_true", "2.4 != 2.4", false},
        {"neq_double_double_false", "2.4 != 3.3", true},
        {"neq_string_string_true", "'abc' != 'abc'", false},
        {"neq_string_string_false", "'abc' != 'def'", true},
        {"neq_bytes_bytes_true", "b'abc' != b'abc'", false},
        {"neq_bytes_bytes_false", "b'abc' != b'def'", true},
        {"neq_duration_duration_true", "duration('15m') != duration('15m')",
         false},
        {"neq_duration_duration_false", "duration('15m') != duration('1h')",
         true},
        {"neq_timestamp_timestamp_true",
         "timestamp('1970-01-01T00:02:00Z') != "
         "timestamp('1970-01-01T00:02:00Z')",
         false},
        {"neq_timestamp_timestamp_false",
         "timestamp('1970-01-01T00:02:00Z') != "
         "timestamp('2020-01-01T00:02:00Z')",
         true},
        {"neq_null_null_true", "null != null", false},
        {"neq_list_list_true", "[1, 2, 3] != [1, 2, 3]", false},
        {"neq_list_list_false", "[1, 2, 3] != [1, 2, 3, 4]", true},
        {"neq_map_map_true", "{1: 2, 2: 4} != {1: 2, 2: 4}", false},
        {"neq_map_map_false", "{1: 2, 2: 4} != {1: 2, 2: 5}", true}}));

INSTANTIATE_TEST_SUITE_P(
    ArithmeticFunctions, StandardRuntimeTest,
    testing::ValuesIn(std::vector{
        {"lt_int_int_true", "-1 < 2", true},
        {"lt_int_int_false", "2 < -1", false},
        {"lt_double_double_true", "-1.1 < 2.2", true},
        {"lt_double_double_false", "2.2 < -1.1", false},
        {"lt_uint_uint_true", "1u < 2u", true},
        {"lt_uint_uint_false", "2u < 1u", false},
        {"lt_string_string_true", "'abc' < 'def'", true},
        {"lt_string_string_false", "'def' < 'abc'", false},
        {"lt_duration_duration_true", "duration('1s') < duration('2s')", true},
        {"lt_duration_duration_false", "duration('2s') < duration('1s')",
         false},
        {"lt_timestamp_timestamp_true", "timestamp(1) < timestamp(2)", true},
        {"lt_timestamp_timestamp_false", "timestamp(2) < timestamp(1)", false},

        {"gt_int_int_false", "-1 > 2", false},
        {"gt_int_int_true", "2 > -1", true},
        {"gt_double_double_false", "-1.1 > 2.2", false},
        {"gt_double_double_true", "2.2 > -1.1", true},
        {"gt_uint_uint_false", "1u > 2u", false},
        {"gt_uint_uint_true", "2u > 1u", true},
        {"gt_string_string_false", "'abc' > 'def'", false},
        {"gt_string_string_true", "'def' > 'abc'", true},
        {"gt_duration_duration_false", "duration('1s') > duration('2s')",
         false},
        {"gt_duration_duration_true", "duration('2s') > duration('1s')", true},
        {"gt_timestamp_timestamp_false", "timestamp(1) > timestamp(2)", false},
        {"gt_timestamp_timestamp_true", "timestamp(2) > timestamp(1)", true},

        {"le_int_int_true", "-1 <= -1", true},
        {"le_int_int_false", "2 <= -1", false},
        {"le_double_double_true", "-1.1 <= -1.1", true},
        {"le_double_double_false", "2.2 <= -1.1", false},
        {"le_uint_uint_true", "1u <= 1u", true},
        {"le_uint_uint_false", "2u <= 1u", false},
        {"le_string_string_true", "'abc' <= 'abc'", true},
        {"le_string_string_false", "'def' <= 'abc'", false},
        {"le_duration_duration_true", "duration('1s') <= duration('1s')",
         true},
        {"le_duration_duration_false", "duration('2s') <= duration('1s')",
         false},
        {"le_timestamp_timestamp_true", "timestamp(1) <= timestamp(1)", true},
        {"le_timestamp_timestamp_false", "timestamp(2) <= timestamp(1)", false},

        {"ge_int_int_false", "-1 >= 2", false},
        {"ge_int_int_true", "2 >= 2", true},
        {"ge_double_double_false", "-1.1 >= 2.2", false},
        {"ge_double_double_true", "2.2 >= 2.2", true},
        {"ge_uint_uint_false", "1u >= 2u", false},
        {"ge_uint_uint_true", "2u >= 2u", true},
        {"ge_string_string_false", "'abc' >= 'def'", false},
        {"ge_string_string_true", "'abc' >= 'abc'", true},
        {"ge_duration_duration_false", "duration('1s') >= duration('2s')",
         false},
        {"ge_duration_duration_true", "duration('1s') >= duration('1s')",
         true},
        {"ge_timestamp_timestamp_false", "timestamp(1) >= timestamp(2)", false},
        {"ge_timestamp_timestamp_true", "timestamp(1) >= timestamp(1)", true},

        {"sum_int_int", "1 + 2 == 3", true},
        {"sum_uint_uint", "3u + 4u == 7", true},
        {"sum_double_double", "1.0 + 2.5 == 3.5", true},
        {"sum_duration_duration",
         "duration('2m') + duration('30s') == duration('150s')", true},
        {"sum_time_duration",
         "timestamp(0) + duration('2m') == "
         "timestamp('1970-01-01T00:02:00Z')",
         true},

        {"difference_int_int", "1 - 2 == -1", true},
        {"difference_uint_uint", "4u - 3u == 1u", true},
        {"difference_double_double", "1.0 - 2.5 == -1.5", true},
        {"difference_duration_duration",
         "duration('5m') - duration('45s') == duration('4m15s')", true},
        {"difference_time_time",
         "timestamp(10) - timestamp(0) == duration('10s')", true},
        {"difference_time_duration",
         "timestamp(0) - duration('2m') == "
         "timestamp('1969-12-31T23:58:00Z')",
         true},

        {"multiplication_int_int", "2 * 3 == 6", true},
        {"multiplication_uint_uint", "2u * 3u == 6u", true},
        {"multiplication_double_double", "2.5 * 3.0 == 7.5", true},

        {"division_int_int", "6 / 3 == 2", true},
        {"division_uint_uint", "8u / 4u == 2u", true},
        {"division_double_double", "1.0 / 0.0 == double('inf')", true},

        {"modulo_int_int", "6 % 4 == 2", true},
        {"modulo_uint_uint", "8u % 5u == 3u", true},
    }));

INSTANTIATE_TEST_SUITE_P(
    Macros, StandardRuntimeTest,
    testing::ValuesIn(std::vector{
        {"map", "[1, 2, 3, 4].map(x, x * x)[3] == 16", true},
        {"filter", "[1, 2, 3, 4].filter(x, x <
4).size() == 3", true}, {"exists", "[1, 2, 3, 4].exists(x, x < 4)", true}, {"all", "[1, 2, 3, 4].all(x, x < 5)", true}})); INSTANTIATE_TEST_SUITE_P( StringFunctions, StandardRuntimeTest, testing::ValuesIn(std::vector{ {"string_contains", "'tacocat'.contains('acoca')", true}, {"string_contains_global", "contains('tacocat', 'dog')", false}, {"string_ends_with", "'abcdefg'.endsWith('efg')", true}, {"string_ends_with_global", "endsWith('abcdefg', 'fgh')", false}, {"string_starts_with", "'abcdefg'.startsWith('abc')", true}, {"string_starts_with_global", "startsWith('abcd', 'bcd')", false}, {"string_size", "'Hello World! 😀'.size() == 14", true}, {"string_size_global", "size('Hello world!') == 12", true}, {"bytes_size", "b'0123'.size() == 4", true}, {"bytes_size_global", "size(b'😀') == 4", true}})); INSTANTIATE_TEST_SUITE_P( RegExFunctions, StandardRuntimeTest, testing::ValuesIn(std::vector{ {"matches_string_re", "'127.0.0.1'.matches(r'127\\.\\d+\\.\\d+\\.\\d+')", true}, {"matches_string_re_global", "matches('192.168.0.1', r'127\\.\\d+\\.\\d+\\.\\d+')", false}})); INSTANTIATE_TEST_SUITE_P( TimeFunctions, StandardRuntimeTest, testing::ValuesIn(std::vector{ {"timestamp_get_full_year", "timestamp('2001-02-03T04:05:06.007Z').getFullYear() == 2001", true}, {"timestamp_get_date", "timestamp('2001-02-03T04:05:06.007Z').getDate() == 3", true}, {"timestamp_get_hours", "timestamp('2001-02-03T04:05:06.007Z').getHours() == 4", true}, {"timestamp_get_minutes", "timestamp('2001-02-03T04:05:06.007Z').getMinutes() == 5", true}, {"timestamp_get_seconds", "timestamp('2001-02-03T04:05:06.007Z').getSeconds() == 6", true}, {"timestamp_get_milliseconds", "timestamp('2001-02-03T04:05:06.007Z').getMilliseconds() == 7", true}, // Zero based indexing {"timestamp_get_month", "timestamp('2001-02-03T04:05:06.007Z').getMonth() == 1", true}, {"timestamp_get_day_of_year", "timestamp('2001-02-03T04:05:06.007Z').getDayOfYear() == 33", true}, {"timestamp_get_day_of_month", 
"timestamp('2001-02-03T04:05:06.007Z').getDayOfMonth() == 2", true}, {"timestamp_get_day_of_week", "timestamp('2001-02-03T04:05:06.007Z').getDayOfWeek() == 6", true}, {"duration_get_hours", "duration('10h20m30s40ms').getHours() == 10", true}, {"duration_get_minutes", "duration('10h20m30s40ms').getMinutes() == 20 + 600", true}, {"duration_get_seconds", "duration('10h20m30s40ms').getSeconds() == 30 + 20 * 60 + 10 * 60 " "* " "60", true}, {"duration_get_milliseconds", "duration('10h20m30s40ms').getMilliseconds() == 40", true}, })); INSTANTIATE_TEST_SUITE_P( TypeConversionFunctions, StandardRuntimeTest, testing::ValuesIn(std::vector{ {"string_timestamp", "string(timestamp(1)) == '1970-01-01T00:00:01Z'", true}, {"string_duration", "string(duration('10m30s')) == '630s'", true}, {"string_int", "string(-1) == '-1'", true}, {"string_uint", "string(1u) == '1'", true}, {"string_double", "string(double('inf')) == 'inf'", true}, {"string_double_nan", "string(double('nan')) == 'nan'", true}, {"string_bytes", R"(string(b'\xF0\x9F\x98\x80') == '😀')", true}, {"string_string", "string('hello!') == 'hello!'", true}, {"bytes_bytes", "bytes(b'123') == b'123'", true}, {"bytes_string", "bytes('😀') == b'\xF0\x9F\x98\x80'", true}, {"timestamp", "timestamp(1) == timestamp('1970-01-01T00:00:01Z')", true}, {"duration", "duration('10h') == duration('600m')", true}, {"double_string", "double('1.0') == 1.0", true}, {"double_string_precision", "double('0.14285714285714285') == 1.0 / 7.0", true}, {"double_string_nan", "double('nan') != double('nan')", true}, {"double_int", "double(1) == 1.0", true}, {"double_uint", "double(1u) == 1.0", true}, {"double_double", "double(1.0) == 1.0", true}, {"uint_string", "uint('1') == 1u", true}, {"uint_int", "uint(1) == 1u", true}, {"uint_uint", "uint(1u) == 1u", true}, {"uint_double", "uint(1.1) == 1u", true}, {"int_string", "int('-1') == -1", true}, {"int_int", "int(-1) == -1", true}, {"int_uint", "int(1u) == 1", true}, {"int_double", "int(-1.1) == -1", true}, 
{"int_timestamp", "int(timestamp('1969-12-31T23:30:00Z')) == -1800", true}, })); INSTANTIATE_TEST_SUITE_P( ContainerFunctions, StandardRuntimeTest, testing::ValuesIn(std::vector{ // Containers {"map_size", "{'abc': 1, 'def': 2}.size() == 2", true}, {"map_in", "'abc' in {'abc': 1, 'def': 2}", true}, {"map_in_numeric", "1.0 in {1u: 1, 2u: 2}", true}, {"list_size", "[1, 2, 3, 4].size() == 4", true}, {"list_size_global", "size([1, 2, 3]) == 3", true}, {"list_concat", "[1, 2] + [3, 4] == [1, 2, 3, 4]", true}, {"list_in", "'a' in ['a', 'b', 'c', 'd']", true}, {"list_in_numeric", "3u in [1.1, 2.3, 3.0, 4.4]", true}})); TEST(StandardRuntimeTest, RuntimeIssueSupport) { RuntimeOptions options; options.fail_on_warnings = false; ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( google::protobuf::DescriptorPool::generated_pool(), options)); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros("unregistered_function(1)")); std::vector issues; ASSERT_OK_AND_ASSIGN( std::unique_ptr program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr, {&issues})); EXPECT_THAT(issues, ElementsAre(Truly([](const RuntimeIssue& issue) { return issue.severity() == RuntimeIssue::Severity::kWarning && issue.error_code() == RuntimeIssue::ErrorCode::kNoMatchingOverload; }))); } { ASSERT_OK_AND_ASSIGN( ParsedExpr expr, ParseWithTestMacros( "unregistered_function(1) || unregistered_function(2)")); std::vector issues; ASSERT_OK_AND_ASSIGN( std::unique_ptr program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr, {&issues})); EXPECT_THAT( issues, ElementsAre( Truly([](const RuntimeIssue& issue) { return issue.severity() == RuntimeIssue::Severity::kWarning && issue.error_code() == RuntimeIssue::ErrorCode::kNoMatchingOverload; }), Truly([](const RuntimeIssue& issue) { return issue.severity() == RuntimeIssue::Severity::kWarning && issue.error_code() == RuntimeIssue::ErrorCode::kNoMatchingOverload; }))); } { 
ASSERT_OK_AND_ASSIGN( ParsedExpr expr, ParseWithTestMacros( "unregistered_function(1) || unregistered_function(2) || true")); std::vector issues; ASSERT_OK_AND_ASSIGN( std::unique_ptr program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr, {&issues})); EXPECT_THAT( issues, ElementsAre( Truly([](const RuntimeIssue& issue) { return issue.severity() == RuntimeIssue::Severity::kWarning && issue.error_code() == RuntimeIssue::ErrorCode::kNoMatchingOverload; }), Truly([](const RuntimeIssue& issue) { return issue.severity() == RuntimeIssue::Severity::kWarning && issue.error_code() == RuntimeIssue::ErrorCode::kNoMatchingOverload; }))); google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(auto result, program->Evaluate(&arena, activation)); EXPECT_TRUE(result->Is() && result.GetBool().NativeValue()); } } enum class EvalStrategy { kIterative, kRecursive }; class StandardRuntimeEvalStrategyTest : public ::testing::TestWithParam {}; // Check that calls to specialized builtins are validated. 
TEST_P(StandardRuntimeEvalStrategyTest, InvalidBuiltinBoolOp) { EvalStrategy eval_strategy = GetParam(); RuntimeOptions options; if (eval_strategy == EvalStrategy::kRecursive) { options.max_recursion_depth = -1; } else { options.max_recursion_depth = 0; } ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( google::protobuf::DescriptorPool::generated_pool(), options)); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); ParsedExpr expr; expr.mutable_expr()->mutable_call_expr()->set_function(cel::builtin::kOr); auto* arg = expr.mutable_expr()->mutable_call_expr()->add_args(); arg->mutable_const_expr()->set_bool_value(true); EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_P(StandardRuntimeEvalStrategyTest, InvalidBuiltinTernaryOp) { EvalStrategy eval_strategy = GetParam(); RuntimeOptions options; if (eval_strategy == EvalStrategy::kRecursive) { options.max_recursion_depth = -1; } else { options.max_recursion_depth = 0; } ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( google::protobuf::DescriptorPool::generated_pool(), options)); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); ParsedExpr expr; expr.mutable_expr()->mutable_call_expr()->set_function( cel::builtin::kTernary); expr.mutable_expr() ->mutable_call_expr() ->add_args() ->mutable_const_expr() ->set_bool_value(true); expr.mutable_expr() ->mutable_call_expr() ->add_args() ->mutable_const_expr() ->set_bool_value(true); expr.mutable_expr() ->mutable_call_expr() ->add_args() ->mutable_const_expr() ->set_bool_value(true); expr.mutable_expr() ->mutable_call_expr() ->add_args() ->mutable_const_expr() ->set_bool_value(true); EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_P(StandardRuntimeEvalStrategyTest, InvalidBuiltinIndex) { EvalStrategy eval_strategy = GetParam(); RuntimeOptions options; if (eval_strategy == 
EvalStrategy::kRecursive) { options.max_recursion_depth = -1; } else { options.max_recursion_depth = 0; } ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( google::protobuf::DescriptorPool::generated_pool(), options)); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); ParsedExpr expr; expr.mutable_expr()->mutable_call_expr()->set_function(cel::builtin::kIndex); auto* arg = expr.mutable_expr()->mutable_call_expr()->add_args(); arg->mutable_list_expr() ->add_elements() ->mutable_const_expr() ->set_int64_value(1); EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_P(StandardRuntimeEvalStrategyTest, InvalidBuiltinEq) { EvalStrategy eval_strategy = GetParam(); RuntimeOptions options; if (eval_strategy == EvalStrategy::kRecursive) { options.max_recursion_depth = -1; } else { options.max_recursion_depth = 0; } ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( google::protobuf::DescriptorPool::generated_pool(), options)); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); ParsedExpr expr; expr.mutable_expr()->mutable_call_expr()->set_function(cel::builtin::kEqual); auto* arg = expr.mutable_expr()->mutable_call_expr()->add_args(); arg->mutable_list_expr() ->add_elements() ->mutable_const_expr() ->set_int64_value(1); EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_P(StandardRuntimeEvalStrategyTest, InvalidBuiltinIn) { EvalStrategy eval_strategy = GetParam(); RuntimeOptions options; if (eval_strategy == EvalStrategy::kRecursive) { options.max_recursion_depth = -1; } else { options.max_recursion_depth = 0; } ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( google::protobuf::DescriptorPool::generated_pool(), options)); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); ParsedExpr expr; 
expr.mutable_expr()->mutable_call_expr()->set_function(cel::builtin::kIn); auto* arg = expr.mutable_expr()->mutable_call_expr()->add_args(); arg->mutable_list_expr() ->add_elements() ->mutable_const_expr() ->set_int64_value(1); EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_P(StandardRuntimeEvalStrategyTest, PrecisionPreservingDoubleFormat) { EvalStrategy eval_strategy = GetParam(); RuntimeOptions options; if (eval_strategy == EvalStrategy::kRecursive) { options.max_recursion_depth = -1; } else { options.max_recursion_depth = 0; } options.enable_precision_preserving_double_format = true; ASSERT_OK_AND_ASSIGN(auto builder, CreateStandardRuntimeBuilder( google::protobuf::DescriptorPool::generated_pool(), options)); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); // Note: the string format isn't guaranteed to be shortest since we don't have // to_chars support on all compilers, but it should still be reversible. const absl::string_view kCases[] = {"double(string(1.0/7.0)) == 1.0/7.0", "double(string(0.45)) == 0.45"}; google::protobuf::Arena arena; Activation activation; for (const auto& test_case : kCases) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros(test_case)); ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); ASSERT_OK_AND_ASSIGN(auto result, program->Evaluate(&arena, activation)); EXPECT_TRUE(result->Is() && result.GetBool().NativeValue()); } } INSTANTIATE_TEST_SUITE_P( StandardRuntimeEvalStrategyTest, StandardRuntimeEvalStrategyTest, testing::Values(EvalStrategy::kIterative, EvalStrategy::kRecursive), [](const auto& info) -> std::string { return info.param == EvalStrategy::kIterative ? 
"Iterative" : "Recursive"; }); } // namespace } // namespace cel ================================================ FILE: runtime/type_registry.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/type_registry.h" #include #include #include #include #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "common/value.h" #include "runtime/internal/legacy_runtime_type_provider.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel { TypeRegistry::TypeRegistry( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nullable message_factory) : type_provider_(descriptor_pool), legacy_type_provider_( std::make_shared( descriptor_pool, message_factory)) { RegisterEnum("google.protobuf.NullValue", {{"NULL_VALUE", 0}}); } void TypeRegistry::RegisterEnum(absl::string_view enum_name, std::vector enumerators) { { absl::MutexLock lock(enum_value_table_mutex_); enum_value_table_.reset(); } enum_types_[enum_name] = Enumeration{std::string(enum_name), std::move(enumerators)}; } std::shared_ptr> TypeRegistry::GetEnumValueTable() const { { absl::ReaderMutexLock lock(enum_value_table_mutex_); if (enum_value_table_ != nullptr) { return enum_value_table_; } } 
absl::MutexLock lock(enum_value_table_mutex_); if (enum_value_table_ != nullptr) { return enum_value_table_; } std::shared_ptr> result = std::make_shared>(); auto& enum_value_map = *result; for (auto iter = enum_types_.begin(); iter != enum_types_.end(); ++iter) { absl::string_view enum_name = iter->first; const auto& enum_type = iter->second; for (const auto& enumerator : enum_type.enumerators) { auto key = absl::StrCat(enum_name, ".", enumerator.name); enum_value_map[key] = cel::IntValue(enumerator.number); } } enum_value_table_ = result; return result; } } // namespace cel ================================================ FILE: runtime/type_registry.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_TYPE_REGISTRY_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_TYPE_REGISTRY_H_ #include #include #include #include #include "absl/base/nullability.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "base/type_provider.h" #include "common/type.h" #include "common/value.h" #include "runtime/internal/legacy_runtime_type_provider.h" #include "runtime/internal/runtime_type_provider.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" namespace cel { class TypeRegistry; namespace runtime_internal { const RuntimeTypeProvider& GetRuntimeTypeProvider( const TypeRegistry& type_registry); const absl_nonnull std::shared_ptr& GetLegacyRuntimeTypeProvider(const TypeRegistry& type_registry); // Returns a memoized table of fully qualified enum values. // // This is populated when first requested. std::shared_ptr> GetEnumValueTable(const TypeRegistry& type_registry); } // namespace runtime_internal // TypeRegistry manages composing TypeProviders used with a Runtime. // // It provides a single effective type provider to be used in a ValueManager. class TypeRegistry { public: // Representation for a custom enum constant. struct Enumerator { std::string name; int64_t number; }; struct Enumeration { std::string name; std::vector enumerators; }; TypeRegistry() : TypeRegistry(google::protobuf::DescriptorPool::generated_pool(), google::protobuf::MessageFactory::generated_factory()) {} TypeRegistry(const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nullable message_factory); // Neither moveable nor copyable. 
TypeRegistry(const TypeRegistry& other) = delete; TypeRegistry& operator=(TypeRegistry& other) = delete; TypeRegistry(TypeRegistry&& other) = delete; TypeRegistry& operator=(TypeRegistry&& other) = delete; // Registers a type such that it can be accessed by name, i.e. `type(foo) == // my_type`. Where `my_type` is the type being registered. absl::Status RegisterType(const OpaqueType& type) { return type_provider_.RegisterType(type); } // Register a custom enum type. // // This adds the enum to the set consulted at plan time to identify constant // enum values. void RegisterEnum(absl::string_view enum_name, std::vector enumerators); const absl::flat_hash_map& resolveable_enums() const { return enum_types_; } // Returns the effective type provider. const TypeProvider& GetComposedTypeProvider() const { return type_provider_; } private: friend const runtime_internal::RuntimeTypeProvider& runtime_internal::GetRuntimeTypeProvider(const TypeRegistry& type_registry); friend const absl_nonnull std::shared_ptr& runtime_internal::GetLegacyRuntimeTypeProvider( const TypeRegistry& type_registry); friend std::shared_ptr> runtime_internal::GetEnumValueTable(const TypeRegistry& type_registry); std::shared_ptr> GetEnumValueTable() const; runtime_internal::RuntimeTypeProvider type_provider_; absl_nonnull std::shared_ptr legacy_type_provider_; absl::flat_hash_map enum_types_; // memoized fully qualified enumerator names. // // populated when requested. // // In almost all cases, this is built once and never updated, but we can't // guarantee that with the current CelExpressionBuilder API. // // The cases when invalidation may occur are likely already race conditions, // but we provide basic thread safety to avoid issues with sanitizers. 
mutable std::shared_ptr> enum_value_table_ ABSL_GUARDED_BY(enum_value_table_mutex_); mutable absl::Mutex enum_value_table_mutex_; }; namespace runtime_internal { inline const RuntimeTypeProvider& GetRuntimeTypeProvider( const TypeRegistry& type_registry) { return type_registry.type_provider_; } inline const absl_nonnull std::shared_ptr& GetLegacyRuntimeTypeProvider(const TypeRegistry& type_registry) { return type_registry.legacy_type_provider_; } inline std::shared_ptr> GetEnumValueTable(const TypeRegistry& type_registry) { return type_registry.GetEnumValueTable(); } } // namespace runtime_internal } // namespace cel #endif // THIRD_PARTY_CEL_CPP_RUNTIME_TYPE_REGISTRY_H_ ================================================ FILE: testing/testrunner/BUILD ================================================ load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") package( default_testonly = True, default_visibility = ["//visibility:public"], ) licenses(["notice"]) cc_library( name = "cel_test_context", hdrs = ["cel_test_context.h"], deps = [ ":cel_expression_source", "//common:value", "//compiler", "//eval/public:cel_expression", "//runtime", "//runtime:activation", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_cel_spec//proto/cel/expr:value_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "runner_lib", srcs = ["runner_lib.cc"], hdrs = ["runner_lib.h"], deps = [ ":cel_expression_source", ":cel_test_context", ":coverage_index", ":coverage_reporting", "//checker:validation_result", "//common:ast", "//common:ast_proto", "//common:value", "//common/internal:value_conversion", "//eval/public:activation", "//eval/public:cel_expression", "//eval/public:cel_value", 
"//eval/public:transform_utility", "//internal:status_macros", "//internal:testing_no_main", "//runtime", "//runtime:activation", "@com_google_absl//absl/functional:overload", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_cel_spec//proto/cel/expr:value_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", "@com_google_protobuf//:differencer", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "cel_test_factories", hdrs = ["cel_test_factories.h"], deps = [ ":cel_test_context", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", ], ) cc_test( name = "runner_lib_test", srcs = ["runner_lib_test.cc"], args = [ "--test_cel_file_path=$(location //testing/testrunner/resources:test.cel)", ], data = [ "//testing/testrunner/resources:test.cel", ], deps = [ ":cel_expression_source", ":cel_test_context", ":coverage_index", ":runner_lib", "//checker:type_checker_builder", "//checker:validation_result", "//common:ast_proto", "//common:decl", "//common:type", "//common:value", "//compiler", "//compiler:compiler_factory", "//compiler:standard_library", "//eval/public:builtin_func_registrar", "//eval/public:cel_expr_builder_factory", "//eval/public:cel_expression", "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", "//runtime", "//runtime:activation", "//runtime:runtime_builder", "//runtime:standard_runtime_builder_factory", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", 
"@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", "@com_google_protobuf//:protobuf", ], ) cc_library( name = "coverage_reporting", srcs = ["coverage_reporting.cc"], hdrs = ["coverage_reporting.h"], deps = [ ":coverage_index", "//internal:testing_no_main", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], ) cc_library( name = "runner", srcs = ["runner_bin.cc"], deps = [ ":cel_expression_source", ":cel_test_context", ":cel_test_factories", ":coverage_index", ":coverage_reporting", ":runner_lib", "//eval/public:cel_expression", "//internal:status_macros", "//internal:testing_no_main", "//runtime", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", "@com_google_protobuf//:protobuf", ], alwayslink = True, ) cc_library( name = "cel_expression_source", hdrs = ["cel_expression_source.h"], deps = ["@com_google_cel_spec//proto/cel/expr:checked_cc_proto"], ) cc_library( name = "coverage_index", srcs = ["coverage_index.cc"], hdrs = ["coverage_index.h"], deps = [ "//common:ast", "//common:value", "//eval/compiler:cel_expression_builder_flat_impl", "//eval/compiler:instrumentation", "//eval/public:cel_expression", "//internal:casts", "//runtime", "//runtime/internal:runtime_impl", "//tools:cel_unparser", "//tools:navigable_ast", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", 
"@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", ], ) cc_test( name = "coverage_index_test", srcs = ["coverage_index_test.cc"], deps = [ ":coverage_index", "//checker:type_checker_builder", "//checker:validation_result", "//common:ast", "//common:ast_proto", "//common:decl", "//common:type", "//common:value", "//compiler", "//compiler:compiler_factory", "//compiler:standard_library", "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", "//runtime", "//runtime:activation", "//runtime:runtime_builder", "//runtime:standard_runtime_builder_factory", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) ================================================ FILE: testing/testrunner/cel_cc_test.bzl ================================================ # Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Rules for triggering the cc impl of the CEL test runner.""" load("@bazel_skylib//lib:paths.bzl", "paths") load("@rules_cc//cc:cc_test.bzl", "cc_test") expr_src_type = struct( RAW = "raw", FILE = "file", CHECKED = "checked", ) def cel_cc_test( name, test_suite = "", cel_expr = "", is_raw_expr = False, filegroup = "", deps = [], enable_coverage = False, test_data_path = "", data = [], **kwargs): """trigger the cc impl of the CEL test runner. This rule will generate a cc_test rule. This rule will be used to trigger the cc impl of the cel_test rule. Args: name: str name for the generated artifact test_suite: str label of a file containing a test suite. The file should have a .textproto extension. cel_expr: The CEL expression source. The meaning of this argument depends on `is_raw_expr`. is_raw_expr: bool whether the cel_expr is a raw expression string. If False, cel_expr is treated as a file path. The file type (.cel or .textproto) is inferred from the extension. filegroup: str label of a filegroup containing the test suite, the config and the checked expression. deps: list of dependencies for the cc_test rule. data: list of data dependencies for the cc_test rule. enable_coverage: bool whether to enable coverage collection. test_data_path: absolute path of the directory containing the test files. This is needed only if the test files are not located in the same directory as the BUILD file. **kwargs: additional arguments to pass to the cc_test rule. 
""" data, test_data_path = _update_data_with_test_files( data, filegroup, test_data_path, test_suite, cel_expr, is_raw_expr, ) args = kwargs.pop("args", []) test_data_path = test_data_path.lstrip("/") if test_suite != "": test_suite = test_data_path + "/" + test_suite args.append("--test_suite_path=" + test_suite) args.append("--collect_coverage=" + str(enable_coverage)) if cel_expr != "": expr_source_type = "" expr_source = "" if is_raw_expr: expr_source_type = expr_src_type.RAW expr_source = "\"" + cel_expr + "\"" else: _, ext = paths.split_extension(cel_expr) # The C++ test runner currently only supports parsing expressions from .cel files. # Support for other CEL source types (e.g., .celpolicy, .yaml) is not yet implemented. if ext == ".cel": expr_source_type = expr_src_type.FILE expr_source = test_data_path + "/" + cel_expr else: expr_source_type = expr_src_type.CHECKED expr_source = "$(location " + cel_expr + ")" args.append("--expr_source_type=" + expr_source_type) args.append("--expr_source=" + expr_source) cc_test( name = name, data = data, args = args, deps = ["//testing/testrunner:runner"] + deps, **kwargs ) def _update_data_with_test_files(data, filegroup, test_data_path, test_suite, cel_expr, is_raw_expr): """Updates the data with the test files.""" if filegroup != "": data = data + [filegroup] elif test_data_path != "" and test_data_path != native.package_name(): if test_suite != "": data = data + [test_data_path + ":" + test_suite] if cel_expr != "" and not is_raw_expr: _, ext = paths.split_extension(cel_expr) if ext == ".cel": data = data + [test_data_path + ":" + cel_expr] else: data = data + [cel_expr] else: test_data_path = native.package_name() if test_suite != "": data = data + [test_suite] if cel_expr != "" and not is_raw_expr: data = data + [cel_expr] return data, test_data_path ================================================ FILE: testing/testrunner/cel_expression_source.h ================================================ // Copyright 2025 
Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_CEL_EXPRESSION_SOURCE_H_ #define THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_CEL_EXPRESSION_SOURCE_H_ #include #include #include #include "cel/expr/checked.pb.h" namespace cel::test { // A wrapper class that holds one of three possible sources for a CEL // expression using a std::variant for type safety. class CelExpressionSource { public: // Distinct wrapper types are used for string-based sources to disambiguate // them within the std::variant. struct RawExpression { std::string value; }; struct CelFile { std::string path; }; // The variant holds one of the three possible source types. using SourceVariant = std::variant; // Creates a CelExpressionSource from a compiled // cel::expr::CheckedExpr. static CelExpressionSource FromCheckedExpr( cel::expr::CheckedExpr checked_expr) { return CelExpressionSource(std::move(checked_expr)); } // Creates a CelExpressionSource from a raw CEL expression string. static CelExpressionSource FromRawExpression(std::string raw_expression) { return CelExpressionSource(RawExpression{std::move(raw_expression)}); } // Creates a CelExpressionSource from a file path pointing to a .cel file. static CelExpressionSource FromCelFile(std::string cel_file_path) { return CelExpressionSource(CelFile{std::move(cel_file_path)}); } // Make copyable and movable. 
CelExpressionSource(const CelExpressionSource&) = default; CelExpressionSource& operator=(const CelExpressionSource&) = default; CelExpressionSource(CelExpressionSource&&) = default; CelExpressionSource& operator=(CelExpressionSource&&) = default; // Returns the underlying variant. The caller is expected to use std::visit // to interact with the active value in a type-safe manner. const SourceVariant& source() const { return source_; } private: // A single private constructor enforces creation via the static factories. explicit CelExpressionSource(SourceVariant source) : source_(std::move(source)) {} // A single std::variant member efficiently stores one of the possible states. SourceVariant source_; }; } // namespace cel::test #endif // THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_CEL_EXPRESSION_SOURCE_H_ ================================================ FILE: testing/testrunner/cel_test_context.h ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_TOOLS_TESTRUNNER_CEL_TEST_CONTEXT_H_ #define THIRD_PARTY_CEL_CPP_TOOLS_TESTRUNNER_CEL_TEST_CONTEXT_H_ #include #include #include #include #include "cel/expr/checked.pb.h" #include "cel/expr/value.pb.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "common/value.h" #include "compiler/compiler.h" #include "eval/public/cel_expression.h" #include "runtime/activation.h" #include "runtime/runtime.h" #include "testing/testrunner/cel_expression_source.h" #include "cel/expr/conformance/test/suite.pb.h" #include "google/protobuf/arena.h" namespace cel::test { // The context class for a CEL test, holding configurations needed to evaluate // compiled CEL expressions. class CelTestContext { public: using CelActivationFactoryFn = std::function( const cel::expr::conformance::test::TestCase& test_case, google::protobuf::Arena* arena)>; using AssertFn = std::function; // Creates a CelTestContext using a `CelExpressionBuilder`. // // The `CelExpressionBuilder` helps in setting up the environment for // building the CEL expression. // // Example usage: // // CEL_REGISTER_TEST_CONTEXT_FACTORY( // []() -> absl::StatusOr> { // // SAFE: This setup code now runs when the lambda is invoked at // runtime, // // long after all static initializations are complete. // auto cel_expression_builder = // google::api::expr::runtime::CreateCelExpressionBuilder(); // CelTestContextOptions options; // return CelTestContext::CreateFromCelExpressionBuilder( // std::move(cel_expression_builder), std::move(options)); // }); static std::unique_ptr CreateFromCelExpressionBuilder( std::unique_ptr cel_expression_builder) { return absl::WrapUnique( new CelTestContext(std::move(cel_expression_builder))); } // Creates a CelTestContext using a `cel::Runtime`. // // The `cel::Runtime` is used to evaluate the CEL expression by managing // the state needed to generate Program. 
static std::unique_ptr CreateFromRuntime( std::unique_ptr runtime) { return absl::WrapUnique(new CelTestContext(std::move(runtime))); } const cel::Runtime* absl_nullable runtime() const { return runtime_.get(); } const google::api::expr::runtime::CelExpressionBuilder* absl_nullable cel_expression_builder() const { return cel_expression_builder_.get(); } const cel::Compiler* absl_nullable compiler() const { return compiler_.get(); } const CelExpressionSource* absl_nullable expression_source() const { return expression_source_.get(); } const absl::flat_hash_map& custom_bindings() const { return custom_bindings_; } bool enable_coverage() const { return enable_coverage_; } // Allows the runner to inject the expression source // parsed from command-line flags. void SetExpressionSource(CelExpressionSource source) { expression_source_ = std::make_unique(std::move(source)); } // Allows the runner to inject an optional CEL compiler. void SetCompiler(std::unique_ptr compiler) { compiler_ = std::move(compiler); } // Allows the runner to inject custom bindings. void SetCustomBindings( absl::flat_hash_map custom_bindings) { custom_bindings_ = std::move(custom_bindings); } // Allows the runner to inject a custom activation factory. If not set, an // empty activation will be used. Custom bindings and test case inputs will // be added to the activation returned by the factory. void SetActivationFactory(CelActivationFactoryFn activation_factory) { activation_factory_ = std::move(activation_factory); } // Allows the runner to enable coverage collection. void SetEnableCoverage(bool enable) { enable_coverage_ = enable; } const CelActivationFactoryFn& activation_factory() const { return activation_factory_; } // Allows the runner to inject a custom assertion function. If not set, the // default assertion logic in TestRunner will be used. 
void SetAssertFn(AssertFn assert_fn) { assert_fn_ = std::move(assert_fn); } const AssertFn& assert_fn() const { return assert_fn_; } private: // Delete copy and move constructors. CelTestContext(const CelTestContext&) = delete; CelTestContext& operator=(const CelTestContext&) = delete; CelTestContext(CelTestContext&&) = delete; CelTestContext& operator=(CelTestContext&&) = delete; // Make the constructors private to enforce the use of the factory methods. explicit CelTestContext( std::unique_ptr cel_expression_builder) : cel_expression_builder_(std::move(cel_expression_builder)) {} explicit CelTestContext(std::unique_ptr runtime) : runtime_(std::move(runtime)) {} // An optional CEL compiler. This is required for test cases where // input or output values are themselves CEL expressions that need to be // resolved at runtime or cel expression source is raw string or cel file. std::unique_ptr compiler_ = nullptr; // A map of variable names to values that provides default bindings for the // evaluation. // // These bindings can be considered context-wide defaults. If a variable name // exists in both these custom bindings and in a specific TestCase's input, // the value from the TestCase will take precedence and override this one. // This logic is handled by the test runner when it constructs the final // activation. absl::flat_hash_map custom_bindings_; // The source for the CEL expression to be evaluated in the test. std::unique_ptr expression_source_; // This helps in setting up the environment for building the CEL // expression. Users should either provide a runtime, or the // CelExpressionBuilder. std::unique_ptr cel_expression_builder_; // The runtime is used to evaluate the CEL expression by managing the state // needed to generate Program. Users should either provide a runtime, or the // CelExpressionBuilder. std::unique_ptr runtime_; CelActivationFactoryFn activation_factory_; AssertFn assert_fn_; // Whether to enable coverage collection. 
bool enable_coverage_ = false; }; } // namespace cel::test #endif // THIRD_PARTY_CEL_CPP_TOOLS_TESTRUNNER_CEL_TEST_CONTEXT_H_ ================================================ FILE: testing/testrunner/cel_test_factories.h ================================================ // Copyright 2025 Google LLC. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_CEL_TEST_FACTORIES_H_ #define THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_CEL_TEST_FACTORIES_H_ #include #include #include #include "absl/base/no_destructor.h" #include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "testing/testrunner/cel_test_context.h" #include "cel/expr/conformance/test/suite.pb.h" namespace cel::test { namespace internal { using CelTestContextFactoryFn = std::function>()>; using CelTestSuiteFactoryFn = std::function; // Returns the factory function for creating a CelTestContext. inline CelTestContextFactoryFn& GetCelTestContextFactory() { static absl::NoDestructor factory; return *factory; } // Sets the factory function for creating a CelTestContext. Only one factory // function can be set. Usage details can be found in cel_test_context.h. inline bool SetCelTestContextFactory(CelTestContextFactoryFn factory) { ABSL_DCHECK(GetCelTestContextFactory() == nullptr) << "CelTestContextFactory is already set."; GetCelTestContextFactory() = std::move(factory); return true; } // Returns the factory function for creating a CelTestSuite. 
inline CelTestSuiteFactoryFn& GetCelTestSuiteFactory() { static absl::NoDestructor factory; return *factory; } // Sets the factory function for creating a CelTestSuite. Only one factory // function can be set. inline bool SetCelTestSuiteFactory(CelTestSuiteFactoryFn factory) { ABSL_DCHECK(GetCelTestSuiteFactory() == nullptr) << "CelTestSuiteFactory is already set."; GetCelTestSuiteFactory() = std::move(factory); return true; } } // namespace internal // Register cel test context factories from a function or lambda. // // The return value of `factory_fn` should be a // `absl::StatusOr>>`. #define CEL_REGISTER_TEST_CONTEXT_FACTORY(factory_fn) \ namespace { \ const bool kTestContextFactoryRegistrationResult_##__LINE__ = \ ::cel::test::internal::SetCelTestContextFactory(factory_fn); \ } // Register cel test suite factory from a function or lambda. This is used to // provide a custom test suite to the test runner which is useful for cases // where the test suite is dynamically generated or where the test suite needs // to be generated from a user provided source. // // The return value of `factory_fn` should be a // `::cel::expr::conformance::test::TestSuite`. #define CEL_REGISTER_TEST_SUITE_FACTORY(factory_fn) \ namespace { \ const bool kTestSuiteFactoryRegistrationResult_##__LINE__ = \ ::cel::test::internal::SetCelTestSuiteFactory(factory_fn); \ } } // namespace cel::test #endif // THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_CEL_TEST_FACTORIES_H_ ================================================ FILE: testing/testrunner/coverage_index.cc ================================================ // Copyright 2025 Google LLC. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "testing/testrunner/coverage_index.h" #include #include #include #include #include #include #include #include "cel/expr/syntax.pb.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "common/ast.h" #include "common/value.h" #include "eval/compiler/cel_expression_builder_flat_impl.h" #include "eval/compiler/instrumentation.h" #include "eval/public/cel_expression.h" #include "internal/casts.h" #include "runtime/internal/runtime_impl.h" #include "runtime/runtime.h" #include "tools/cel_unparser.h" #include "tools/navigable_ast.h" namespace cel::test { namespace { using ::cel::expr::CheckedExpr; using ::cel::expr::Type; using ::google::api::expr::runtime::CelExpressionBuilder; using ::google::api::expr::runtime::Instrumentation; using ::google::api::expr::runtime::InstrumentationFactory; std::string EscapeSpecialCharacters(absl::string_view expr_text) { return absl::StrReplaceAll(expr_text, {{"\\\"", "\""}, {"\"", "\\\""}, {"\n", "\\n"}, {"||", " \\| \\| "}, {"<", "\\<"}, {">", "\\>"}, {"{", "\\{"}, {"}", "\\}"}}); } std::string KindToString(const NavigableProtoAstNode& node) { if (node.parent_relation() != ChildKind::kUnspecified && node.parent()->expr()->has_comprehension_expr()) { const cel::expr::Expr::Comprehension& comp = node.parent()->expr()->comprehension_expr(); if (node.expr()->id() == 
comp.iter_range().id()) return "IterRange"; if (node.expr()->id() == comp.accu_init().id()) return "AccuInit"; if (node.expr()->id() == comp.loop_condition().id()) return "LoopCondition"; if (node.expr()->id() == comp.loop_step().id()) return "LoopStep"; if (node.expr()->id() == comp.result().id()) return "Result"; } return absl::StrCat(NodeKindName(node.node_kind()), " Node"); } const Type* absl_nullable FindCheckerType(const CheckedExpr& expr, int64_t expr_id) { if (auto it = expr.type_map().find(expr_id); it != expr.type_map().end()) { return &it->second; } return nullptr; } bool InferredBooleanNode(const CheckedExpr& checked_expr, const NavigableProtoAstNode& node) { int64_t node_id = node.expr()->id(); const auto* checker_type = FindCheckerType(checked_expr, node_id); if (checker_type != nullptr) { return checker_type->has_primitive() && checker_type->primitive() == Type::BOOL; } return false; } void TraverseAndCalculateCoverage( const CheckedExpr& checked_expr, const NavigableProtoAstNode& node, const absl::flat_hash_map& stats_map, bool log_unencountered, std::string preceeding_tabs, CoverageIndex::CoverageReport& report, std::string& dot_graph) { int64_t node_id = node.expr()->id(); const CoverageIndex::NodeCoverageStats& stats = stats_map.at(node_id); report.nodes++; absl::StatusOr unparsed = google::api::expr::Unparse(*node.expr()); std::string expr_text = unparsed.ok() ? 
*unparsed : "unparse_failed"; bool is_interesting_bool_node = stats.is_boolean_node && !node.expr()->has_const_expr() && (!node.expr()->has_call_expr() || node.expr()->call_expr().function() != "cel.@block"); absl::string_view node_coverage_style = kUncoveredNodeStyle; if (stats.covered) { if (is_interesting_bool_node) { if (stats.has_true_branch && stats.has_false_branch) { node_coverage_style = kCompletelyCoveredNodeStyle; } else { node_coverage_style = kPartiallyCoveredNodeStyle; } } else { node_coverage_style = kCompletelyCoveredNodeStyle; } } std::string escaped_expr_text = EscapeSpecialCharacters(expr_text); dot_graph += absl::StrFormat( "%d [shape=record, %s, label=\"{<1> exprID: %d | <2> %s} | <3> %s\"];\n", node_id, node_coverage_style, node_id, KindToString(node), escaped_expr_text); bool node_covered = stats.covered; if (node_covered) { report.covered_nodes++; } else if (log_unencountered) { if (is_interesting_bool_node) { report.unencountered_nodes.push_back( absl::StrCat("Expression ID ", node_id, " ('", expr_text, "')")); } log_unencountered = false; } if (is_interesting_bool_node) { report.branches += 2; if (stats.has_true_branch) { report.covered_boolean_outcomes++; } else if (log_unencountered) { report.unencountered_branches.push_back( absl::StrCat("\n", preceeding_tabs, "Expression ID ", node_id, " ('", expr_text, "'): Never evaluated to 'true'")); preceeding_tabs += "\t\t"; } if (stats.has_false_branch) { report.covered_boolean_outcomes++; } else if (log_unencountered) { report.unencountered_branches.push_back( absl::StrCat("\n", preceeding_tabs, "Expression ID ", node_id, " ('", expr_text, "'): Never evaluated to 'false'")); preceeding_tabs += "\t\t"; } } for (const auto* child : node.children()) { dot_graph += absl::StrFormat("%d -> %d;\n", node_id, child->expr()->id()); TraverseAndCalculateCoverage(checked_expr, *child, stats_map, log_unencountered, preceeding_tabs, report, dot_graph); } } int32_t GetLineNumber(const cel::expr::SourceInfo& 
source_info, int32_t offset) { auto line_it = std::upper_bound(source_info.line_offsets().begin(), source_info.line_offsets().end(), offset); return std::distance(source_info.line_offsets().begin(), line_it) + 1; } } // namespace void CoverageIndex::RecordCoverage(int64_t node_id, const cel::Value& value) { NodeCoverageStats& stats = node_coverage_stats_[node_id]; stats.covered = true; if (node_coverage_stats_[node_id].is_boolean_node && value.IsBool()) { if (value.AsBool()->NativeValue()) { stats.has_true_branch = true; } else { stats.has_false_branch = true; } } } void CoverageIndex::Init(const cel::expr::CheckedExpr& checked_expr) { checked_expr_ = checked_expr; navigable_ast_ = NavigableProtoAst::Build(checked_expr_.expr()); for (const auto& node : navigable_ast_.Root().DescendantsPreorder()) { NodeCoverageStats stats; stats.is_boolean_node = InferredBooleanNode(checked_expr_, node); node_coverage_stats_[node.expr()->id()] = stats; } } CoverageIndex::CoverageReport CoverageIndex::GetCoverageReport() const { CoverageReport report; if (node_coverage_stats_.empty()) { return report; } std::string dot_graph = std::string(kDigraphHeader); TraverseAndCalculateCoverage(checked_expr_, navigable_ast_.Root(), node_coverage_stats_, true, "", report, dot_graph); dot_graph += "}\n"; report.dot_graph = dot_graph; report.cel_expression = google::api::expr::Unparse(checked_expr_).value_or(""); return report; } void CoverageIndex::WriteLCOV(absl::string_view path) { std::ofstream file(std::string(path).c_str()); if (!file.is_open()) { return; } // Maps instrumented line numbers to whether they are covered. 
std::map lines; const auto& positions = checked_expr_.source_info().positions(); for (const auto& [node_id, stats] : node_coverage_stats_) { auto it = positions.find(node_id); if (it == positions.end()) continue; int line_num = GetLineNumber(checked_expr_.source_info(), it->second); bool& covered = lines[line_num]; covered = covered || stats.covered; } file << "SF:" << checked_expr_.source_info().location() << "\n"; for (auto& [line_num, covered] : lines) { file << "DA:" << line_num << "," << (covered ? 1 : 0) << "\n"; } file << "end_of_record\n"; } InstrumentationFactory InstrumentationFactoryForCoverage( CoverageIndex& coverage_index) { return [&](const cel::Ast& ast) -> Instrumentation { return [&](int64_t node_id, const cel::Value& value) -> absl::Status { coverage_index.RecordCoverage(node_id, value); return absl::OkStatus(); }; }; } absl::Status EnableCoverageInRuntime(cel::Runtime& runtime, CoverageIndex& coverage_index) { auto& runtime_impl = cel::internal::down_cast(runtime); runtime_impl.expr_builder().AddProgramOptimizer( google::api::expr::runtime::CreateInstrumentationExtension( InstrumentationFactoryForCoverage(coverage_index))); return absl::OkStatus(); } absl::Status EnableCoverageInCelExpressionBuilder( CelExpressionBuilder& cel_expression_builder, CoverageIndex& coverage_index) { auto& cel_expression_builder_impl = cel::internal::down_cast< google::api::expr::runtime::CelExpressionBuilderFlatImpl&>( cel_expression_builder); cel_expression_builder_impl.flat_expr_builder().AddProgramOptimizer( google::api::expr::runtime::CreateInstrumentationExtension( InstrumentationFactoryForCoverage(coverage_index))); return absl::OkStatus(); } } // namespace cel::test ================================================ FILE: testing/testrunner/coverage_index.h ================================================ // Copyright 2025 Google LLC. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_COVERAGE_INDEX_H_ #define THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_COVERAGE_INDEX_H_ #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "common/value.h" #include "eval/public/cel_expression.h" #include "runtime/runtime.h" #include "tools/navigable_ast.h" namespace cel::test { inline constexpr absl::string_view kDigraphHeader = "digraph {\n"; inline constexpr absl::string_view kUncoveredNodeStyle = R"(color="indianred2", style=filled)"; inline constexpr absl::string_view kPartiallyCoveredNodeStyle = R"(color="lightyellow", style=filled)"; inline constexpr absl::string_view kCompletelyCoveredNodeStyle = R"(color="lightgreen", style=filled)"; // `CoverageIndex` is a utility for tracking expression coverage based on the // Abstract Syntax Tree (AST) of a `cel::expr::CheckedExpr`. // // To use `CoverageIndex`, it must first be initialized with a // `cel::expr::CheckedExpr` using the `Init` method. This allows the // index to build up a representation of all the nodes and potential boolean // branches within the expression. // // The `CoverageIndex` is then integrated with the CEL evaluation process. 
// This is done by enabling coverage either in a `cel::Runtime` or a // `google::api::expr::runtime::CelExpressionBuilder` using the provided helper // functions (`EnableCoverageInRuntime` or // `EnableCoverageInCelExpressionBuilder`). When integrated, the CEL evaluation // engine will call `RecordCoverage` for each visited expression node, allowing // `CoverageIndex` to track which parts of the expression were executed and, // for boolean-producing nodes, which branches were taken (true/false). // // After evaluation, a `CoverageReport` can be generated, summarizing the // executed nodes and branches, and highlighting any unencountered parts of // the expression. class CoverageIndex { public: struct NodeCoverageStats { bool is_boolean_node = false; bool covered = false; bool has_true_branch = false; bool has_false_branch = false; }; struct CoverageReport { std::string cel_expression; int64_t nodes = 0; int64_t covered_nodes = 0; int64_t branches = 0; int64_t covered_boolean_outcomes = 0; std::vector unencountered_nodes; std::vector unencountered_branches; std::string dot_graph; }; // Initializes the coverage index with the given checked expression. // // The coverage index will be initialized with an entry for each node in the // AST. void Init(const cel::expr::CheckedExpr& checked_expr); // Records coverage for the given node. // // The coverage index will be updated with the coverage information for the // given node. void RecordCoverage(int64_t node_id, const cel::Value& value); // Returns a coverage report for the given checked expression. CoverageReport GetCoverageReport() const; // Writes the coverage in LCOV format to the given path. void WriteLCOV(absl::string_view path); private: absl::flat_hash_map node_coverage_stats_; NavigableProtoAst navigable_ast_; cel::expr::CheckedExpr checked_expr_; }; // Enables coverage tracking within the provided `cel::Runtime`. // Note: This function ties the `runtime` instance to a single expression. 
// Do not reuse this `runtime` instance with multiple expressions when coverage // is enabled, as the `coverage_index` will accumulate results across different // expressions, leading to incorrect coverage reports. absl::Status EnableCoverageInCelExpressionBuilder( google::api::expr::runtime::CelExpressionBuilder& cel_expression_builder, CoverageIndex& coverage_index); // Enables coverage tracking within the provided `CelExpressionBuilder`. // Note: This function ties the `cel_expression_builder` instance to a single // expression. Do not reuse this `cel_expression_builder` instance with // multiple expressions when coverage is enabled, as the `coverage_index` will // accumulate results across different expressions, leading to incorrect // coverage reports. absl::Status EnableCoverageInRuntime(cel::Runtime& runtime, CoverageIndex& coverage_index); } // namespace cel::test #endif // THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_COVERAGE_INDEX_H_ ================================================ FILE: testing/testrunner/coverage_index_test.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "testing/testrunner/coverage_index.h" #include #include #include #include #include #include "cel/expr/syntax.pb.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" #include "common/ast_proto.h" #include "common/decl.h" #include "common/type.h" #include "common/value.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" #include "compiler/standard_library.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "runtime/activation.h" #include "runtime/runtime.h" #include "runtime/runtime_builder.h" #include "runtime/standard_runtime_builder_factory.h" #include "google/protobuf/arena.h" namespace cel::test { namespace { using ::absl_testing::IsOk; using ::cel::expr::CheckedExpr; absl::StatusOr> CreateTestRuntime() { CEL_ASSIGN_OR_RETURN(cel::RuntimeBuilder standard_runtime_builder, cel::CreateStandardRuntimeBuilder( cel::internal::GetTestingDescriptorPool(), {})); return std::move(standard_runtime_builder).Build(); } TEST(CoverageIndexTest, RecordCoverageWithErrorDoesNotCrash) { ASSERT_OK_AND_ASSIGN( std::unique_ptr compiler_builder, cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool())); ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()), IsOk()); ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable( cel::MakeVariableDecl("x", cel::IntType())), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, std::move(compiler_builder)->Build()); ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, compiler->Compile("1/x > 1")); CheckedExpr checked_expr; ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), IsOk()); CoverageIndex coverage_index; coverage_index.Init(checked_expr);
ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, CreateTestRuntime()); ASSERT_THAT(EnableCoverageInRuntime(*const_cast(runtime.get()), coverage_index), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, cel::CreateAstFromCheckedExpr(checked_expr)); ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); cel::Activation activation; activation.InsertOrAssignValue("x", cel::IntValue(0)); google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(cel::Value result, program->Evaluate(&arena, activation)); EXPECT_TRUE(result.IsError()); } TEST(CoverageIndexTest, WriteLCOV) { ASSERT_OK_AND_ASSIGN( std::unique_ptr compiler_builder, cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool())); ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()), IsOk()); ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable( cel::MakeVariableDecl("x", cel::BoolType())), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, std::move(compiler_builder)->Build()); const absl::string_view kSrc = R"(x ?
true : false )"; ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, compiler->Compile(kSrc)); CheckedExpr checked_expr; ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), IsOk()); checked_expr.mutable_source_info()->set_location("test.cel"); CoverageIndex coverage_index; coverage_index.Init(checked_expr); ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, CreateTestRuntime()); ASSERT_THAT(EnableCoverageInRuntime(*const_cast(runtime.get()), coverage_index), IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, cel::CreateAstFromCheckedExpr(checked_expr)); ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); cel::Activation activation; activation.InsertOrAssignValue("x", cel::BoolValue(true)); google::protobuf::Arena arena; ASSERT_OK_AND_ASSIGN(cel::Value result, program->Evaluate(&arena, activation)); EXPECT_TRUE(result.GetBool().NativeValue()); std::string temp_file = absl::StrCat(testing::TempDir(), "/coverage.lcov"); coverage_index.WriteLCOV(temp_file); std::ifstream f(temp_file); std::stringstream buffer; buffer << f.rdbuf(); std::string content = buffer.str(); // Verify content. // We expect "test.cel" to be the source file. EXPECT_THAT(content, testing::HasSubstr("SF:test.cel")); // Line 1 (x ?) should be covered. EXPECT_THAT(content, testing::HasSubstr("DA:1,1")); // Line 2 (true) should be covered. EXPECT_THAT(content, testing::HasSubstr("DA:2,1")); // Line 3 (false) should be uncovered. EXPECT_THAT(content, testing::HasSubstr("DA:3,0")); // Line 4 (empty) should not be instrumented.
EXPECT_THAT(content, testing::Not(testing::HasSubstr("DA:4,"))); EXPECT_THAT(content, testing::HasSubstr("end_of_record")); } } // namespace } // namespace cel::test ================================================ FILE: testing/testrunner/coverage_reporting.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "testing/testrunner/coverage_reporting.h" #include #include #include #include #include #include "absl/log/absl_log.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "internal/testing.h" #include "testing/testrunner/coverage_index.h" namespace cel::test { // Reports the coverage collected in `coverage_index_` as Google Test // properties, and echoes each entry to stdout on its own line. void CoverageReportingEnvironment::TearDown() { const CoverageIndex::CoverageReport coverage_report = coverage_index_.GetCoverageReport(); testing::Test::RecordProperty("CEL Expression", coverage_report.cel_expression); // Every stdout entry is newline-terminated; otherwise consecutive entries // would be printed run together on a single line. std::cout << "CEL Expression: " << coverage_report.cel_expression << '\n'; if (coverage_report.nodes == 0) { testing::Test::RecordProperty("CEL Coverage", "No coverage stats found"); std::cout << "CEL Coverage: " << "No coverage stats found" << '\n'; return; } // Log Node Coverage results double node_coverage = static_cast(coverage_report.covered_nodes) / static_cast(coverage_report.nodes) * 100.0; std::string node_coverage_string = absl::StrFormat("%.2f%% (%d out of %d nodes covered)",
node_coverage, coverage_report.covered_nodes, coverage_report.nodes); testing::Test::RecordProperty("AST Node Coverage", node_coverage_string); std::cout << "AST Node Coverage: " << node_coverage_string << '\n'; if (!coverage_report.unencountered_nodes.empty()) { const std::string unencountered_nodes = absl::StrJoin(coverage_report.unencountered_nodes, "\n"); testing::Test::RecordProperty( "Interesting Unencountered Nodes", unencountered_nodes); std::cout << "Interesting Unencountered Nodes: " << unencountered_nodes << '\n'; } // Log Branch Coverage results (reported as 0% when there are no branch // outcomes to cover). double branch_coverage = 0.0; if (coverage_report.branches > 0) { branch_coverage = static_cast(coverage_report.covered_boolean_outcomes) / static_cast(coverage_report.branches) * 100.0; } std::string branch_coverage_string = absl::StrFormat( "%.2f%% (%d out of %d branch outcomes covered)", branch_coverage, coverage_report.covered_boolean_outcomes, coverage_report.branches); testing::Test::RecordProperty("AST Branch Coverage", branch_coverage_string); std::cout << "AST Branch Coverage: " << branch_coverage_string << '\n'; if (!coverage_report.unencountered_branches.empty()) { const std::string unencountered_branches = absl::StrJoin(coverage_report.unencountered_branches, "\n"); testing::Test::RecordProperty( "Interesting Unencountered Branch Paths", unencountered_branches); std::cout << "Interesting Unencountered Branch Paths: " << unencountered_branches << '\n'; } if (!coverage_report.dot_graph.empty()) { WriteDotGraphToArtifact(coverage_report.dot_graph); } } // Writes `dot_graph` to OUTPUTS/cel_test_coverage/coverage_graph.txt, where // OUTPUTS is $TEST_UNDECLARED_OUTPUTS_DIR, or "cel_artifacts" when unset. // Failures are logged as warnings rather than failing the test run. void CoverageReportingEnvironment::WriteDotGraphToArtifact( absl::string_view dot_graph) { // Save DOT graph to file in TEST_UNDECLARED_OUTPUTS_DIR or default dir const char* outputs_dir_env = std::getenv("TEST_UNDECLARED_OUTPUTS_DIR"); // For non-Bazel/Blaze users, we write to a subdirectory under the current // working directory. // NOMUTANTS --cel_artifacts is for non-Bazel/Blaze users only so not // needed to test in our case. std::string outputs_dir = (outputs_dir_env == nullptr) ?
"cel_artifacts" : outputs_dir_env; // mkdir() does not create missing parent directories, so make sure // `outputs_dir` exists first. This matters for the default "cel_artifacts" // directory, which nothing else creates; an already-existing directory // (EEXIST) is not an error. if (mkdir(outputs_dir.c_str(), 0755) != 0 && errno != EEXIST) { const int outputs_dir_errno = errno; ABSL_LOG(WARNING) << "Failed to create directory: " << outputs_dir << " (reason: " << strerror(outputs_dir_errno) << ")"; return; } std::string coverage_dir = absl::StrCat(outputs_dir, "/cel_test_coverage"); // Creates the directory to store CEL test coverage artifacts. // The second argument, `0755`, sets the directory's permissions in octal // format, which is a standard for file system operations. It grants: // - Owner: read, write, and execute permissions (7 = 4+2+1). // - Group: read and execute permissions (5 = 4+1). // - Others: read and execute permissions (5 = 4+1). // This gives the owner full control while allowing other users to access // the generated artifacts. int mkdir_result = mkdir(coverage_dir.c_str(), 0755); // Capture errno immediately so later library calls cannot overwrite it // before it is inspected or logged. const int mkdir_errno = errno; // If mkdir fails, it sets the global 'errno' variable to an error code // indicating the reason. We check this code to specifically ignore the // EEXIST error, which just means the directory already exists (this is not // a real failure we need to warn about). if (mkdir_result == 0 || mkdir_errno == EEXIST) { std::string graph_path = absl::StrCat(coverage_dir, "/coverage_graph.txt"); std::ofstream out(graph_path); if (out.is_open()) { out << dot_graph; out.close(); } else { ABSL_LOG(WARNING) << "Failed to open file for writing: " << graph_path; } } else { ABSL_LOG(WARNING) << "Failed to create directory: " << coverage_dir << " (reason: " << strerror(mkdir_errno) << ")"; } } } // namespace cel::test ================================================ FILE: testing/testrunner/coverage_reporting.h ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_COVERAGE_REPORTING_H_ #define THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_COVERAGE_REPORTING_H_ #include "absl/strings/string_view.h" #include "internal/testing.h" #include "testing/testrunner/coverage_index.h" namespace cel::test { // A Google Test Environment that reports CEL coverage results in its TearDown // phase. // // This class formats the statistics returned by // `CoverageIndex::GetCoverageReport` and logs them as test properties. // // Only a reference to `coverage_index` is stored, so the index must outlive // this environment, including its TearDown. class CoverageReportingEnvironment : public testing::Environment { public: explicit CoverageReportingEnvironment(CoverageIndex& coverage_index) : coverage_index_(coverage_index) {} // Called by the Google Test framework after all tests have run. void TearDown() override; private: // Helper function to write the DOT graph to a test artifact file. void WriteDotGraphToArtifact(absl::string_view dot_graph); CoverageIndex& coverage_index_; }; } // namespace cel::test #endif // THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_COVERAGE_REPORTING_H_ ================================================ FILE: testing/testrunner/resources/BUILD ================================================ package(default_visibility = ["//visibility:public"]) exports_files( [ "test.cel", ], ) filegroup( name = "resources", srcs = glob([ "*.textproto", ]), ) ================================================ FILE: testing/testrunner/resources/simple_tests.textproto ================================================ # proto-file: google3/third_party/cel/spec/proto/cel/expr/conformance/test/suite.proto # proto-message: cel.expr.conformance.test.TestSuite name: "simple_tests" description: "Simple tests to validate the test runner." sections: { name: "simple_map_operations" description: "Tests for map operations." tests: { name: "literal_and_sum" description: "Test that a map can be created and values can be accessed."
input: { key: "x" value { value { int64_value: 1 } } } input { key: "y" value { value { int64_value: 2 } } } output { result_value { bool_value: true } } } tests: { name: "literal_and_sum_2_5" description: "Test that a map can be created and values can be accessed." input: { key: "x" value { value { int64_value: 2 } } } input { key: "y" value { value { int64_value: 5 } } } output { result_value { bool_value: false } } } } ================================================ FILE: testing/testrunner/resources/test.cel ================================================ x-y ================================================ FILE: testing/testrunner/resources/test_environment.textproto ================================================ # proto-file: third_party/cel/go/tools/compilecli/compile_input.proto # proto-message: Environment declarations: { name: "x" ident: { type: { primitive: INT64 } } } declarations: { name: "y" ident: { type: { primitive: INT64 } } } ================================================ FILE: testing/testrunner/runner_bin.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // This binary is a test runner for CEL tests. It is used to run CEL tests // written in the CEL test suite format.
#include #include #include #include #include #include #include #include #include #include "cel/expr/checked.pb.h" #include "absl/flags/flag.h" #include "absl/log/absl_check.h" #include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "eval/public/cel_expression.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "runtime/runtime.h" #include "testing/testrunner/cel_expression_source.h" #include "testing/testrunner/cel_test_context.h" #include "testing/testrunner/cel_test_factories.h" #include "testing/testrunner/coverage_index.h" #include "testing/testrunner/coverage_reporting.h" #include "testing/testrunner/runner_lib.h" #include "cel/expr/conformance/test/suite.pb.h" #include "google/protobuf/text_format.h" ABSL_FLAG(std::string, test_suite_path, "", "The path to the file containing the test suite to run."); ABSL_FLAG(std::string, expr_source_type, "", "The kind of expression source: 'raw', 'file', or 'checked'."); ABSL_FLAG(std::string, expr_source, "", "The value of the CEL expression source. For 'raw', it's the " "expression string.
For 'file' and 'checked', it's the file path."); ABSL_FLAG(bool, collect_coverage, false, "Whether to collect code coverage."); namespace { using ::cel::expr::conformance::test::TestCase; using ::cel::expr::conformance::test::TestSuite; using ::cel::test::CelExpressionSource; using ::cel::test::CelTestContext; using ::cel::test::CoverageIndex; using ::cel::test::TestRunner; using ::cel::expr::CheckedExpr; using ::google::api::expr::runtime::CelExpressionBuilder; class CelTest : public testing::Test { public: explicit CelTest(std::shared_ptr test_runner, const TestCase& test_case) : test_runner_(std::move(test_runner)), test_case_(test_case) {} void TestBody() override { test_runner_->RunTest(test_case_); } private: std::shared_ptr test_runner_; TestCase test_case_; }; // Registers one dynamically created Google Test per test case in // `test_suite`, named "SECTION/TEST" under the suite's name. absl::Status RegisterTests(const TestSuite& test_suite, const std::shared_ptr& test_runner) { for (const auto& section : test_suite.sections()) { for (const TestCase& test_case : section.tests()) { testing::RegisterTest( test_suite.name().c_str(), absl::StrCat(section.name(), "/", test_case.name()).c_str(), nullptr, nullptr, __FILE__, __LINE__, // The factory runs later, when RUN_ALL_TESTS() executes the test. // Capture the shared_ptr by value so every factory co-owns the runner // instead of referring to the caller's variable. [test_runner, test_case]() -> CelTest* { return new CelTest(test_runner, test_case); }); } } return absl::OkStatus(); } absl::StatusOr ReadFileToString(absl::string_view file_path) { std::ifstream file_stream{std::string(file_path)}; if (!file_stream.is_open()) { return absl::NotFoundError( absl::StrCat("Unable to open file: ", file_path)); } std::stringstream buffer; buffer << file_stream.rdbuf(); return buffer.str(); } template absl::StatusOr ReadTextProtoFromFile(absl::string_view file_path) { CEL_ASSIGN_OR_RETURN(std::string contents, ReadFileToString(file_path)); T message; if (!google::protobuf::TextFormat::ParseFromString(contents, &message)) { return absl::InternalError(absl::StrCat( "Failed to parse text-format proto from file: ", file_path)); } return message; } absl::StatusOr ReadBinaryProtoFromFile( absl::string_view file_path) { CheckedExpr
message; std::ifstream file_stream{std::string(file_path), std::ios::binary}; if (!file_stream.is_open()) { return absl::NotFoundError( absl::StrCat("Unable to open file: ", file_path)); } if (!message.ParseFromIstream(&file_stream)) { return absl::InternalError( absl::StrCat("Failed to parse binary proto from file: ", file_path)); } return message; } TestSuite ReadTestSuiteFromPath(absl::string_view test_suite_path) { absl::StatusOr test_suite_or = ReadTextProtoFromFile(test_suite_path); if (!test_suite_or.ok()) { ABSL_LOG(FATAL) << "Failed to load test suite from " << test_suite_path << ": " << test_suite_or.status(); } return *std::move(test_suite_or); } // Reads a CheckedExpr from `file_path`, choosing the format by extension: // text format for .textproto / .textpb, binary wire format for .binarypb / // .pb. Any other extension yields an InvalidArgumentError. absl::StatusOr ReadCheckedExprFromFile( absl::string_view file_path) { // Keep the accepted extensions in sync with the ones listed in the error // message below. if (absl::EndsWith(file_path, ".textproto") || absl::EndsWith(file_path, ".textpb")) { return ReadTextProtoFromFile(file_path); } if (absl::EndsWith(file_path, ".binarypb") || absl::EndsWith(file_path, ".pb")) { return ReadBinaryProtoFromFile(file_path); } return absl::InvalidArgumentError(absl::StrCat( "Unknown file extension for checked expression. ", "Please use .textproto, .textpb, .pb, or .binarypb: ", file_path)); } TestSuite GetTestSuite() { std::string test_suite_path = absl::GetFlag(FLAGS_test_suite_path); if (!test_suite_path.empty()) { return ReadTestSuiteFromPath(test_suite_path); } // If no test suite path is provided, use the factory function to get the // test suite after checking if the factory function is empty or not. std::function test_suite_factory = cel::test::internal::GetCelTestSuiteFactory(); if (test_suite_factory == nullptr) { ABSL_LOG(FATAL) << "No CEL test suite provided.
Please provide a test suite using " "either the bzl macro or the CEL_REGISTER_TEST_SUITE_FACTORY " "preprocessor macro."; } return test_suite_factory(); } void UpdateWithExpressionFromCommandLineFlags( CelTestContext& cel_test_context) { if (absl::GetFlag(FLAGS_expr_source).empty()) { return; } constexpr absl::string_view kRawExpressionKind = "raw"; constexpr absl::string_view kFileExpressionKind = "file"; constexpr absl::string_view kCheckedExpressionKind = "checked"; std::string kind = absl::GetFlag(FLAGS_expr_source_type); std::string value = absl::GetFlag(FLAGS_expr_source); std::optional expression_source_from_flags; if (kind == kRawExpressionKind) { expression_source_from_flags = CelExpressionSource::FromRawExpression(value); } else if (kind == kFileExpressionKind) { expression_source_from_flags = CelExpressionSource::FromCelFile(value); } else if (kind == kCheckedExpressionKind) { absl::StatusOr checked_expr = ReadCheckedExprFromFile(value); if (!checked_expr.ok()) { ABSL_LOG(FATAL) << "Failed to read checked expression from file: " << checked_expr.status(); } expression_source_from_flags = CelExpressionSource::FromCheckedExpr(std::move(*checked_expr)); } else { ABSL_LOG(FATAL) << "Unknown expression kind: " << kind; } // Check for conflicting expression sources. if (cel_test_context.expression_source() != nullptr) { ABSL_LOG(FATAL) << "Expression source can only be set once and is currently set via " "the factory."; } if (expression_source_from_flags.has_value()) { cel_test_context.SetExpressionSource( std::move(*expression_source_from_flags)); } } } // namespace int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); // Create a test context using the factory function returned by the global // factory function provider which was initialized by the user.
absl::StatusOr> cel_test_context_or = cel::test::internal::GetCelTestContextFactory()(); if (!cel_test_context_or.ok()) { ABSL_LOG(FATAL) << "Failed to create CEL test context from factory: " << cel_test_context_or.status(); } std::unique_ptr cel_test_context = std::move(cel_test_context_or.value()); // We manually enable coverage here instead of just setting the // `enable_coverage` flag on the context. This is intentional and necessary // for this binary's reporting model. // // This binary needs a single coverage report for all tests run. // We create `coverage_index` here, local to the `main` function, so its // lifetime spans the entire test run. // // We must pass this specific instance to the // `CoverageReportingEnvironment`, which Google Test calls after all // dynamically registered tests are finished. // // If we just set the `enable_coverage` flag, the `TestRunner`'s // constructor (as used in our `cc_test` files) would create its own // internal `CoverageIndex`. That internal index would be destroyed // with the `TestRunner` and would not populate the `coverage_index` // instance needed by our global reporter. // // This manual approach ensures all tests populate the same `coverage_index` // (the one local to `main`), which is then ready for the final report. cel::test::CoverageIndex coverage_index; if (absl::GetFlag(FLAGS_collect_coverage)) { if (cel_test_context->runtime() != nullptr) { ABSL_CHECK_OK(cel::test::EnableCoverageInRuntime( const_cast(*cel_test_context->runtime()), coverage_index)); } else if (cel_test_context->cel_expression_builder() != nullptr) { ABSL_CHECK_OK(cel::test::EnableCoverageInCelExpressionBuilder( const_cast( *cel_test_context->cel_expression_builder()), coverage_index)); } } // Update the context with an expression from flags, if provided. // This will FATAL if an expression is set by both the factory and flags.
UpdateWithExpressionFromCommandLineFlags(*cel_test_context); auto test_runner = std::make_shared(std::move(cel_test_context)); ABSL_CHECK_OK(RegisterTests(GetTestSuite(), test_runner)); // Make sure the checked expression exists during the entire test run since // the ast references it during coverage collection at teardown. absl::StatusOr checked_expr = test_runner->GetCheckedExpr(); if (!checked_expr.ok()) { ABSL_LOG(FATAL) << "Failed to get checked expression: " << checked_expr.status(); } if (absl::GetFlag(FLAGS_collect_coverage)) { coverage_index.Init(*checked_expr); testing::AddGlobalTestEnvironment( new cel::test::CoverageReportingEnvironment(coverage_index)); } return RUN_ALL_TESTS(); } ================================================ FILE: testing/testrunner/runner_lib.cc ================================================ // Copyright 2025 Google LLC. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "testing/testrunner/runner_lib.h" #include #include #include #include #include #include #include "cel/expr/eval.pb.h" #include "absl/functional/overload.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "checker/validation_result.h" #include "common/ast.h" #include "common/ast_proto.h" #include "common/internal/value_conversion.h" #include "common/value.h" #include "eval/public/activation.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_value.h" #include "eval/public/transform_utility.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "runtime/activation.h" #include "runtime/runtime.h" #include "testing/testrunner/cel_expression_source.h" #include "testing/testrunner/cel_test_context.h" #include "testing/testrunner/coverage_index.h" #include "cel/expr/conformance/test/suite.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "google/protobuf/util/field_comparator.h" #include "google/protobuf/util/message_differencer.h" namespace cel::test { namespace { using ::cel::expr::conformance::test::InputValue; using ::cel::expr::conformance::test::TestCase; using ::cel::expr::conformance::test::TestOutput; using ::cel::expr::CheckedExpr; using ::google::api::expr::runtime::CelExpression; using ::google::api::expr::runtime::ValueToCelValue; using ::google::api::expr::runtime::Activation; using LegacyCelValue = ::google::api::expr::runtime::CelValue; using ValueProto = ::cel::expr::Value; absl::StatusOr ReadFileToString(absl::string_view file_path) { std::ifstream file_stream{std::string(file_path)}; if (!file_stream.is_open()) { return absl::NotFoundError( absl::StrCat("Unable to open file: ", file_path)); } std::stringstream buffer; buffer << file_stream.rdbuf(); return buffer.str(); } absl::StatusOr
Compile(absl::string_view expression, const CelTestContext& context) { const auto* compiler = context.compiler(); if (compiler == nullptr) { return absl::InvalidArgumentError( "A compiler must be provided to compile a raw expression or .cel " "file."); } CEL_ASSIGN_OR_RETURN(ValidationResult validation_result, compiler->Compile(expression)); if (!validation_result.IsValid()) { return absl::InternalError(validation_result.FormatError()); } CheckedExpr checked_expr; CEL_RETURN_IF_ERROR( AstToCheckedExpr(*validation_result.GetAst(), &checked_expr)); return checked_expr; } absl::StatusOr> Plan( const CheckedExpr& checked_expr, const cel::Runtime* runtime) { std::unique_ptr ast; CEL_ASSIGN_OR_RETURN(ast, cel::CreateAstFromCheckedExpr(checked_expr)); if (ast == nullptr) { return absl::InternalError("No expression provided for testing."); } return runtime->CreateProgram(std::move(ast)); } // The two helpers below use the generated pool/factory when a legacy // CelExpressionBuilder is configured, and otherwise the runtime's own. They // assume that a runtime is set whenever no CelExpressionBuilder is (the // runtime pointer is dereferenced unconditionally in that case). const google::protobuf::DescriptorPool* GetDescriptorPool(const CelTestContext& context) { return context.cel_expression_builder() != nullptr ? google::protobuf::DescriptorPool::generated_pool() : context.runtime()->GetDescriptorPool(); } google::protobuf::MessageFactory* GetMessageFactory(const CelTestContext& context) { return context.cel_expression_builder() != nullptr ?
google::protobuf::MessageFactory::generated_factory() : context.runtime()->GetMessageFactory(); } absl::StatusOr EvalWithModernBindings( const CheckedExpr& checked_expr, const CelTestContext& context, const cel::Activation& activation, google::protobuf::Arena* arena) { CEL_ASSIGN_OR_RETURN(std::unique_ptr program, Plan(checked_expr, context.runtime())); return program->Evaluate(arena, activation); } absl::StatusOr EvalWithLegacyBindings( const CheckedExpr& checked_expr, const CelTestContext& context, const Activation& activation, google::protobuf::Arena* arena) { const auto* builder = context.cel_expression_builder(); CEL_ASSIGN_OR_RETURN(std::unique_ptr sub_expression, builder->CreateExpression(&checked_expr)); CEL_ASSIGN_OR_RETURN(LegacyCelValue legacy_result, sub_expression->Evaluate(activation, arena)); ValueProto result_proto; CEL_RETURN_IF_ERROR(CelValueToValue(legacy_result, &result_proto)); return FromExprValue(result_proto, GetDescriptorPool(context), GetMessageFactory(context), arena); } absl::StatusOr ResolveValue(const InputValue& input_value, const CelTestContext& context, google::protobuf::Arena* arena) { return FromExprValue(input_value.value(), GetDescriptorPool(context), GetMessageFactory(context), arena); } absl::StatusOr ResolveExpr(absl::string_view expr, const CelTestContext& context, google::protobuf::Arena* arena) { CEL_ASSIGN_OR_RETURN(CheckedExpr checked_expr, Compile(expr, context)); if (context.runtime() != nullptr) { cel::Activation empty_activation; return EvalWithModernBindings(checked_expr, context, empty_activation, arena); } else { Activation empty_activation; return EvalWithLegacyBindings(checked_expr, context, empty_activation, arena); } } absl::StatusOr ResolveInputValue(const InputValue& input_value, const CelTestContext& context, google::protobuf::Arena* arena) { switch (input_value.kind_case()) { case InputValue::kValue: { return ResolveValue(input_value, context, arena); } case InputValue::kExpr: { return
ResolveExpr(input_value.expr(), context, arena); } default: return absl::InvalidArgumentError("Unknown InputValue kind."); } } absl::Status AddCustomBindingsToModernActivation(const CelTestContext& context, cel::Activation& activation, google::protobuf::Arena* arena) { for (const auto& binding : context.custom_bindings()) { CEL_ASSIGN_OR_RETURN(cel::Value value, FromExprValue(/*value_proto=*/binding.second, GetDescriptorPool(context), GetMessageFactory(context), arena)); activation.InsertOrAssignValue(/*name=*/binding.first, std::move(value)); } return absl::OkStatus(); } absl::Status AddTestCaseBindingsToModernActivation( const TestCase& test_case, const CelTestContext& context, cel::Activation& activation, google::protobuf::Arena* arena) { for (const auto& binding : test_case.input()) { CEL_ASSIGN_OR_RETURN( cel::Value value, ResolveInputValue(/*input_value=*/binding.second, context, arena)); activation.InsertOrAssignValue(/*name=*/binding.first, std::move(value)); } return absl::OkStatus(); } absl::StatusOr GetActivation(const CelTestContext& context, const TestCase& test_case, google::protobuf::Arena* arena) { if (context.activation_factory() != nullptr) { return context.activation_factory()(test_case, arena); } return cel::Activation(); } absl::StatusOr CreateModernActivationFromBindings( const TestCase& test_case, const CelTestContext& context, google::protobuf::Arena* arena) { CEL_ASSIGN_OR_RETURN(cel::Activation activation, GetActivation(context, test_case, arena)); CEL_RETURN_IF_ERROR( AddCustomBindingsToModernActivation(context, activation, arena)); CEL_RETURN_IF_ERROR(AddTestCaseBindingsToModernActivation(test_case, context, activation, arena)); return activation; } absl::Status AddCustomBindingsToLegacyActivation(const CelTestContext& context, Activation& activation, google::protobuf::Arena* arena) { for (const auto& binding : context.custom_bindings()) { CEL_ASSIGN_OR_RETURN( LegacyCelValue value, ValueToCelValue(/*value_proto=*/binding.second, arena));
activation.InsertValue(/*name=*/binding.first, value); } return absl::OkStatus(); } absl::Status AddTestCaseBindingsToLegacyActivation( const TestCase& test_case, const CelTestContext& context, Activation& activation, google::protobuf::Arena* arena) { auto* message_factory = GetMessageFactory(context); auto* descriptor_pool = GetDescriptorPool(context); for (const auto& binding : test_case.input()) { CEL_ASSIGN_OR_RETURN( cel::Value resolved_cel_value, ResolveInputValue(/*input_value=*/binding.second, context, arena)); CEL_ASSIGN_OR_RETURN(ValueProto value_proto, ToExprValue(resolved_cel_value, descriptor_pool, message_factory, arena)); CEL_ASSIGN_OR_RETURN(LegacyCelValue value, ValueToCelValue(value_proto, arena)); activation.InsertValue(/*name=*/binding.first, value); } return absl::OkStatus(); } absl::StatusOr CreateLegacyActivationFromBindings( const TestCase& test_case, const CelTestContext& context, google::protobuf::Arena* arena) { Activation activation; CEL_RETURN_IF_ERROR( AddCustomBindingsToLegacyActivation(context, activation, arena)); CEL_RETURN_IF_ERROR(AddTestCaseBindingsToLegacyActivation(test_case, context, activation, arena)); return activation; } // Compares two cel.expr.Value protos for equivalence: NaN compares equal to // NaN, map entries are matched by key regardless of order, and fields set to // their default value compare equal to unset fields (EQUIVALENT). // NOTE(review): the comparator and differencer are shared function-local // statics and MessageDifferencer::Compare is non-const; this is presumably // only safe because assertions run on a single thread - confirm. bool IsEqual(const ValueProto& expected, const ValueProto& actual) { static auto* kFieldComparator = []() { auto* field_comparator = new google::protobuf::util::DefaultFieldComparator(); field_comparator->set_treat_nan_as_equal(true); return field_comparator; }(); static auto* kDifferencer = []() { auto* differencer = new google::protobuf::util::MessageDifferencer(); differencer->set_message_field_comparison( google::protobuf::util::MessageDifferencer::EQUIVALENT); differencer->set_field_comparator(kFieldComparator); const auto* descriptor = cel::expr::MapValue::descriptor(); const auto* entries_field = descriptor->FindFieldByName("entries"); const auto* key_field = entries_field->message_type()->FindFieldByName("key"); differencer->TreatAsMap(entries_field, key_field); return differencer; }(); return
kDifferencer->Compare(expected, actual); } // Matches a computed value proto (`arg`) against the `expected` one. MATCHER_P(MatchesValue, expected, "") { return IsEqual(expected, arg); } } // namespace void TestRunner::AssertValue(const cel::Value& computed, const TestOutput& output, google::protobuf::Arena* arena) { if (computed.IsError()) { ADD_FAILURE() << "Expected value but got error: " << computed.DebugString(); return; } ValueProto expected_value_proto; const auto* descriptor_pool = GetDescriptorPool(*test_context_); auto* message_factory = GetMessageFactory(*test_context_); if (output.has_result_value()) { expected_value_proto = output.result_value(); } else if (output.has_result_expr()) { InputValue input_value; input_value.set_expr(output.result_expr()); ASSERT_OK_AND_ASSIGN(cel::Value resolved_cel_value, ResolveInputValue(input_value, *test_context_, arena)); ASSERT_OK_AND_ASSIGN(expected_value_proto, ToExprValue(resolved_cel_value, descriptor_pool, message_factory, arena)); } ValueProto computed_expr_value; ASSERT_OK_AND_ASSIGN( computed_expr_value, ToExprValue(computed, descriptor_pool, message_factory, arena)); EXPECT_THAT(computed_expr_value, MatchesValue(expected_value_proto)); } void TestRunner::AssertError(const cel::Value& computed, const TestOutput& output) { if (!computed.IsError()) { ADD_FAILURE() << "Expected error but got value: " << computed.DebugString(); return; } absl::Status computed_status = computed.AsError()->ToStatus(); // We selected the first error in the set for comparison because there is only // one runtime error that is reported even if there are multiple errors in the // critical path.
ASSERT_EQ(output.eval_error().errors_size(), 1) << "Expected exactly one error but got: " << output.eval_error().errors_size(); ASSERT_EQ(computed_status.message(), output.eval_error().errors(0).message()); } void TestRunner::Assert(const cel::Value& computed, const TestCase& test_case, google::protobuf::Arena* arena) { if (test_context_->assert_fn()) { test_context_->assert_fn()(computed, test_case, arena); return; } // Bind by reference: the expected output is only read, so copying the proto // is unnecessary. const TestOutput& output = test_case.output(); if (output.has_result_value() || output.has_result_expr()) { AssertValue(computed, output, arena); } else if (output.has_eval_error()) { AssertError(computed, output); } else if (output.has_unknown()) { ADD_FAILURE() << "Unknown assertions not implemented yet."; } else { ADD_FAILURE() << "Unexpected output kind."; } } absl::StatusOr TestRunner::EvalWithRuntime( const CheckedExpr& checked_expr, const TestCase& test_case, google::protobuf::Arena* arena) { CEL_ASSIGN_OR_RETURN( cel::Activation activation, CreateModernActivationFromBindings(test_case, *test_context_, arena)); return EvalWithModernBindings(checked_expr, *test_context_, activation, arena); } absl::StatusOr TestRunner::EvalWithCelExpressionBuilder( const CheckedExpr& checked_expr, const TestCase& test_case, google::protobuf::Arena* arena) { CEL_ASSIGN_OR_RETURN( Activation activation, CreateLegacyActivationFromBindings(test_case, *test_context_, arena)); return EvalWithLegacyBindings(checked_expr, *test_context_, activation, arena); } absl::StatusOr TestRunner::GetCheckedExpr() const { const CelExpressionSource* source_ptr = test_context_->expression_source(); if (source_ptr == nullptr) { return absl::InvalidArgumentError("No expression source provided."); } return std::visit( absl::Overload([](const cel::expr::CheckedExpr& v) -> absl::StatusOr { return v; }, [this](const CelExpressionSource::RawExpression& v) -> absl::StatusOr { return Compile(v.value, *test_context_); }, [this](const CelExpressionSource::CelFile& v) -> absl::StatusOr {
CEL_ASSIGN_OR_RETURN(std::string contents, ReadFileToString(v.path)); return Compile(contents, *test_context_); }), source_ptr->source()); } absl::Status TestRunner::EnableCoverage() { if (test_context_ != nullptr && test_context_->enable_coverage()) { coverage_index_ = std::make_unique(); if (test_context_->runtime() != nullptr) { auto* runtime = const_cast(test_context_->runtime()); CEL_RETURN_IF_ERROR(EnableCoverageInRuntime(*runtime, *coverage_index_)); } else if (test_context_->cel_expression_builder() != nullptr) { auto* builder = const_cast( test_context_->cel_expression_builder()); CEL_RETURN_IF_ERROR( EnableCoverageInCelExpressionBuilder(*builder, *coverage_index_)); } } return absl::OkStatus(); } void TestRunner::RunTest(const TestCase& test_case) { // The arena has to be declared in RunTest because cel::Value returned by // EvalWithRuntime or EvalWithCelExpressionBuilder might contain pointers to // the arena. The arena has to be alive during the assertion. google::protobuf::Arena arena; ASSERT_THAT(EnableCoverage(), absl_testing::IsOk()); ASSERT_OK_AND_ASSIGN(CheckedExpr checked_expr, GetCheckedExpr()); if (coverage_index_) { coverage_index_->Init(checked_expr); } if (test_context_->runtime() != nullptr) { ASSERT_OK_AND_ASSIGN(cel::Value result, EvalWithRuntime(checked_expr, test_case, &arena)); ASSERT_NO_FATAL_FAILURE(Assert(result, test_case, &arena)); } else if (test_context_->cel_expression_builder() != nullptr) { ASSERT_OK_AND_ASSIGN( cel::Value result, EvalWithCelExpressionBuilder(checked_expr, test_case, &arena)); ASSERT_NO_FATAL_FAILURE(Assert(result, test_case, &arena)); } } } // namespace cel::test ================================================ FILE: testing/testrunner/runner_lib.h ================================================ // Copyright 2025 Google LLC. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_RUNNER_LIBRARY_H_
#define THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_RUNNER_LIBRARY_H_

// NOTE(review): the two bare `#include` directives below, and the template
// argument lists of `std::unique_ptr` / `absl::StatusOr` in this header,
// appear to have lost their `<...>` contents in this copy of the file;
// restore them from upstream.
#include
#include

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "common/value.h"
#include "testing/testrunner/cel_test_context.h"
#include "testing/testrunner/coverage_index.h"
#include "testing/testrunner/coverage_reporting.h"
#include "cel/expr/conformance/test/suite.pb.h"
#include "google/protobuf/arena.h"

namespace cel::test {

// The test runner class for running CEL tests.
//
// A runner owns a CelTestContext that supplies the evaluation backend (a
// cel::Runtime or a legacy CelExpressionBuilder), the expression source, and
// optional hooks (custom bindings, activation factory, assert function).
// Results are reported through gtest assertions, not return values.
class TestRunner {
 public:
  // Takes ownership of `test_context`.
  explicit TestRunner(std::unique_ptr test_context)
      : test_context_(std::move(test_context)) {}

  // Automatically reports coverage results.
  // When coverage was enabled (coverage_index_ is non-null), wraps the index
  // in a CoverageReportingEnvironment and calls its TearDown() to emit the
  // report.
  ~TestRunner() {
    if (coverage_index_) {
      CoverageReportingEnvironment reporter(*coverage_index_);
      reporter.TearDown();
    }
  }

  // Evaluates the checked expression in the test case, performs the
  // assertions against the expected result.
  // Bindings come from the context (activation factory, custom bindings) and
  // the test case inputs. The expectation is test_case.output() unless the
  // context provides its own assert function. Failures are reported as gtest
  // failures.
  void RunTest(const cel::expr::conformance::test::TestCase& test_case);

  // Returns the checked expression for the test case.
  // A precompiled CheckedExpr source is returned as-is; a raw expression or
  // .cel file source is compiled first. Fails with InvalidArgument when the
  // context has no expression source.
  absl::StatusOr GetCheckedExpr() const;

 private:
  // Evaluates with the context's cel::Runtime.
  absl::StatusOr EvalWithRuntime(
      const cel::expr::CheckedExpr& checked_expr,
      const cel::expr::conformance::test::TestCase& test_case,
      google::protobuf::Arena* arena);

  // Evaluates with the context's legacy CelExpressionBuilder.
  absl::StatusOr EvalWithCelExpressionBuilder(
      const cel::expr::CheckedExpr& checked_expr,
      const cel::expr::conformance::test::TestCase& test_case,
      google::protobuf::Arena* arena);

  // Checks `computed` against the test case's expected output (or delegates
  // to the context's custom assert function).
  void Assert(const cel::Value& computed,
              const cel::expr::conformance::test::TestCase& test_case,
              google::protobuf::Arena* arena);

  // Checks a value result against `result_value` / `result_expr`.
  void AssertValue(const cel::Value& computed,
                   const cel::expr::conformance::test::TestOutput& output,
                   google::protobuf::Arena* arena);

  // Checks an error result against `eval_error`.
  void AssertError(const cel::Value& computed,
                   const cel::expr::conformance::test::TestOutput& output);

  // When the context requests coverage, creates coverage_index_ and attaches
  // it to the configured runtime or expression builder.
  absl::Status EnableCoverage();

  // Configuration and backend for every test case run by this runner.
  std::unique_ptr test_context_;
  // Null unless coverage has been enabled; reported by the destructor.
  std::unique_ptr coverage_index_;
};

} // namespace cel::test

#endif // THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_RUNNER_LIBRARY_H_
================================================ FILE: testing/testrunner/runner_lib_test.cc ================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "testing/testrunner/runner_lib.h"

// NOTE(review): the four bare `#include` directives below (and template
// argument lists such as `ParseTextProtoOrDie<...>`, `absl::StatusOr<...>`,
// `std::unique_ptr<...>` throughout this file) appear to have lost their
// `<...>` contents in this copy; restore them from upstream.
#include
#include
#include
#include

#include "gtest/gtest-spi.h"
#include "absl/container/flat_hash_map.h"
#include "absl/flags/flag.h"
#include "absl/log/absl_check.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "checker/type_checker_builder.h"
#include "checker/validation_result.h"
#include "common/ast_proto.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/value.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_expr_builder_factory.h"
#include "eval/public/cel_expression.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "runtime/activation.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "testing/testrunner/cel_expression_source.h"
#include "testing/testrunner/cel_test_context.h"
#include "testing/testrunner/coverage_index.h"
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
#include "cel/expr/conformance/test/suite.pb.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

// Path of the .cel file used by the CelFileWithCompiler* tests; those tests
// fail immediately if it is empty.
ABSL_FLAG(std::string, test_cel_file_path, "",
          "Path to the .cel file for testing");

namespace cel::test {
namespace {

using ::cel::expr::conformance::proto3::TestAllTypes;
using ::cel::expr::conformance::test::TestCase;
using ::cel::expr::CheckedExpr;
using ::google::api::expr::runtime::CelExpressionBuilder;
using ValueProto = ::cel::expr::Value;
using ::testing::EndsWith;
using ::testing::HasSubstr;
using ::testing::Not;
using ::testing::StartsWith;

// Parses `text_proto` as a text-format message of type T; aborts via
// ABSL_CHECK if parsing fails.
template T ParseTextProtoOrDie(absl::string_view text_proto) {
  T result;
  ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(text_proto, &result));
  return result;
}

// Counts occurrences of `substr` in `text`. Overlapping occurrences are
// counted, because the search resumes one character after each match.
// NOTE(review): not referenced in the portion of this file reviewed here;
// presumably used by later tests.
int CountSubstrings(absl::string_view text, absl::string_view substr) {
  int count = 0;
  size_t pos = 0;
  while ((pos = text.find(substr, pos)) != absl::string_view::npos) {
    ++count;
    ++pos;
  }
  return count;
}

// Builds a compiler over the testing descriptor pool with the standard
// library and two int variables, `x` and `y`.
absl::StatusOr> CreateBasicCompiler() {
  CEL_ASSIGN_OR_RETURN(
      std::unique_ptr builder,
      cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool()));
  CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary()));
  cel::TypeCheckerBuilder& checker_builder = builder->GetCheckerBuilder();
  CEL_RETURN_IF_ERROR(
      checker_builder.AddVariable(cel::MakeVariableDecl("x", cel::IntType())));
  CEL_RETURN_IF_ERROR(
      checker_builder.AddVariable(cel::MakeVariableDecl("y", cel::IntType())));
  return std::move(builder)->Build();
}

// Builds a standard runtime over the testing descriptor pool (the `{}`
// argument presumably selects default runtime options — confirm).
absl::StatusOr> CreateTestRuntime() {
  CEL_ASSIGN_OR_RETURN(cel::RuntimeBuilder standard_runtime_builder,
                       cel::CreateStandardRuntimeBuilder(
                           cel::internal::GetTestingDescriptorPool(), {}));
  return std::move(standard_runtime_builder).Build();
}

// Creates a legacy CelExpressionBuilder with the builtin functions registered.
absl::StatusOr> CreateTestCelExpressionBuilder() {
  auto builder = google::api::expr::runtime::CreateCelExpressionBuilder();
  CEL_RETURN_IF_ERROR(google::api::expr::runtime::RegisterBuiltinFunctions(
      builder->GetRegistry()));
  return builder;
}

// Creates a static, singleton instance of the basic compiler to be shared
// across tests, avoiding repeated setup costs.
// Aborts (ABSL_QCHECK_OK) if the compiler cannot be built; the instance is
// released from its unique_ptr and intentionally never freed.
const cel::Compiler& DefaultCompiler() {
  static const cel::Compiler* instance = []() {
    absl::StatusOr> s = CreateBasicCompiler();
    ABSL_QCHECK_OK(s.status());
    return s->release();
  }();
  return *instance;
}

// Selects which evaluation backend a parameterized test exercises.
enum class RuntimeApi { kRuntime, kBuilder };

// Parameterized test fixture for tests that are run against both the Runtime
// and the CelExpressionBuilder backends.
class TestRunnerParamTest : public ::testing::TestWithParam {
 protected:
  // Helper to create the appropriate CelTestContext based on the test
  // parameter.
  absl::StatusOr> CreateTestContext() {
    if (GetParam() == RuntimeApi::kRuntime) {
      CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, CreateTestRuntime());
      return CelTestContext::CreateFromRuntime(std::move(runtime));
    }
    CEL_ASSIGN_OR_RETURN(std::unique_ptr builder,
                         CreateTestCelExpressionBuilder());
    return CelTestContext::CreateFromCelExpressionBuilder(std::move(builder));
  }
};

// With x=1 and y=2 the map literal evaluates to {'sum': 3, 'literal': 3},
// which matches the expected map value.
TEST_P(TestRunnerParamTest, BasicTestReportsSuccess) {
  ASSERT_OK_AND_ASSIGN(
      cel::ValidationResult validation_result,
      DefaultCompiler().Compile("{'sum': x + y, 'literal': 3}"));
  CheckedExpr checked_expr;
  ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
              absl_testing::IsOk());
  TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { value { int64_value: 1 } } } input { key: "y" value { value { int64_value: 2 } } } output { result_value { map_value { entries { key { string_value: "literal" } value { int64_value: 3 } } entries { key { string_value: "sum" } value { int64_value: 3 } } } } } )pb");
  ASSERT_OK_AND_ASSIGN(std::unique_ptr context, CreateTestContext());
  context->SetExpressionSource(
      CelExpressionSource::FromCheckedExpr(std::move(checked_expr)));
  TestRunner test_runner(std::move(context));
  EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case));
}

// `x + y == 3` evaluates to true for x=1, y=2, but the test case expects
// false, so RunTest must report a non-fatal failure.
TEST_P(TestRunnerParamTest, BasicTestReportsFailure) {
  ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
                       DefaultCompiler().Compile("x + y == 3"));
  CheckedExpr checked_expr;
  ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
              absl_testing::IsOk());
  TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { value { int64_value: 1 } } } input { key: "y" value { value { int64_value: 2 } } } output { result_value { bool_value: false } } )pb");
  ASSERT_OK_AND_ASSIGN(std::unique_ptr context, CreateTestContext());
  context->SetExpressionSource(
      CelExpressionSource::FromCheckedExpr(std::move(checked_expr)));
  TestRunner test_runner(std::move(context));
  EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case),
                          "bool_value: true");  // computed true, expected false
}

// Inputs and the expected output are CEL expressions: x = 1 + 1, y = 10 - 7,
// so x + y == 5 == 7 - 2. Resolving them requires the compiler set below.
TEST_P(TestRunnerParamTest, DynamicInputAndOutputReportsSuccess) {
  ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
                       DefaultCompiler().Compile("x + y"));
  CheckedExpr checked_expr;
  ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
              absl_testing::IsOk());
  TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { expr: "1 + 1" } } input { key: "y" value { expr: "10 - 7" } } output { result_expr: "7 - 2" } )pb");
  ASSERT_OK_AND_ASSIGN(std::unique_ptr context, CreateTestContext());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, CreateBasicCompiler());
  context->SetCompiler(std::move(compiler));
  context->SetExpressionSource(
      CelExpressionSource::FromCheckedExpr(std::move(checked_expr)));
  TestRunner test_runner(std::move(context));
  EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case));
}

// Same inputs as above (x + y computes 5), but the expected result_expr is
// "10", so RunTest must report a non-fatal failure.
TEST_P(TestRunnerParamTest, DynamicInputAndOutputReportsFailure) {
  ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
                       DefaultCompiler().Compile("x + y"));
  CheckedExpr checked_expr;
  ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
              absl_testing::IsOk());
  TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { expr: "1 + 1" } } input { key: "y" value { expr: "10 - 7" } } output { result_expr: "10" } )pb");
  ASSERT_OK_AND_ASSIGN(std::unique_ptr context, CreateTestContext());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, CreateBasicCompiler());
  context->SetCompiler(std::move(compiler));
  context->SetExpressionSource(
      CelExpressionSource::FromCheckedExpr(std::move(checked_expr)));
  TestRunner test_runner(std::move(context));
  EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case),
                          "int64_value: 5");  // computed 5, expected 10
}

// The expression source is raw text ("x - y"), compiled by the context's
// compiler: 10 - 3 == 7.
TEST_P(TestRunnerParamTest, RawExpressionWithCompilerReportsSuccess) {
  TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { value { int64_value: 10 } } } input { key: "y" value { value { int64_value: 3 } } } output { result_value { int64_value: 7 } } )pb");
  ASSERT_OK_AND_ASSIGN(std::unique_ptr context, CreateTestContext());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, CreateBasicCompiler());
  context->SetCompiler(std::move(compiler));
  context->SetExpressionSource(CelExpressionSource::FromRawExpression("x - y"));
  TestRunner test_runner(std::move(context));
  EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case));
}

// "x - y" computes 7 but the test case expects 100.
TEST_P(TestRunnerParamTest, RawExpressionWithCompilerReportsFailure) {
  TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { value { int64_value: 10 } } } input { key: "y" value { value { int64_value: 3 } } } output { result_value { int64_value: 100 } } )pb");
  ASSERT_OK_AND_ASSIGN(std::unique_ptr context, CreateTestContext());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, CreateBasicCompiler());
  context->SetCompiler(std::move(compiler));
  context->SetExpressionSource(CelExpressionSource::FromRawExpression("x - y"));
  TestRunner test_runner(std::move(context));
  EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case),
                          "int64_value: 7");  // computed 7, expected 100
}

// The expression source is the .cel file named by --test_cel_file_path; with
// x=10, y=3 it is expected to evaluate to 7.
TEST_P(TestRunnerParamTest, CelFileWithCompilerReportsSuccess) {
  const std::string cel_file_path = absl::GetFlag(FLAGS_test_cel_file_path);
  ASSERT_FALSE(cel_file_path.empty())
      << "Flag --test_cel_file_path must be set";
  TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { value { int64_value: 10 } } } input { key: "y" value { value { int64_value: 3 } } } output { result_value { int64_value: 7 } } )pb");
  ASSERT_OK_AND_ASSIGN(std::unique_ptr context, CreateTestContext());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, CreateBasicCompiler());
  context->SetCompiler(std::move(compiler));
  context->SetExpressionSource(CelExpressionSource::FromCelFile(cel_file_path));
  TestRunner test_runner(std::move(context));
  EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case));
}

// Same .cel file and inputs as above (computes 7), but the test case expects
// 123.
TEST_P(TestRunnerParamTest, CelFileWithCompilerReportsFailure) {
  const std::string cel_file_path = absl::GetFlag(FLAGS_test_cel_file_path);
  ASSERT_FALSE(cel_file_path.empty())
      << "Flag --test_cel_file_path must be set";
  TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { value { int64_value: 10 } } } input { key: "y" value { value { int64_value: 3 } } } output { result_value { int64_value: 123 } } )pb");
  ASSERT_OK_AND_ASSIGN(std::unique_ptr context, CreateTestContext());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, CreateBasicCompiler());
  context->SetCompiler(std::move(compiler));
  context->SetExpressionSource(CelExpressionSource::FromCelFile(cel_file_path));
  TestRunner test_runner(std::move(context));
  EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case),
                          "int64_value: 7");  // computed 7, expected 123
}

// `y` comes from the context's custom bindings rather than the test case
// inputs: 10 + 5 == 15.
TEST_P(TestRunnerParamTest, BasicTestWithCustomBindingsSucceeds) {
  ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
                       DefaultCompiler().Compile("x + y"));
  CheckedExpr checked_expr;
  ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
              absl_testing::IsOk());
  TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { value { int64_value: 10 } } } output { result_value { int64_value: 15 } } )pb");
  ASSERT_OK_AND_ASSIGN(std::unique_ptr context, CreateTestContext());
  absl::flat_hash_map bindings;
  bindings["y"] = ParseTextProtoOrDie(R"pb(int64_value: 5)pb");
  context->SetCustomBindings(std::move(bindings));
  context->SetExpressionSource(
      CelExpressionSource::FromCheckedExpr(std::move(checked_expr)));
  TestRunner test_runner(std::move(context));
  EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case));
}

// Custom binding y=5 with x=10 computes 15, but the test case expects 999.
TEST_P(TestRunnerParamTest, BasicTestWithCustomBindingsReportsFailure) {
  ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
                       DefaultCompiler().Compile("x + y"));
  CheckedExpr checked_expr;
  ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
              absl_testing::IsOk());
  TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { value { int64_value: 10 } } } output { result_value { int64_value: 999 } } )pb");
  ASSERT_OK_AND_ASSIGN(std::unique_ptr context, CreateTestContext());
  absl::flat_hash_map bindings;
  bindings["y"] = ParseTextProtoOrDie(R"pb(int64_value: 5)pb");
  context->SetCustomBindings(std::move(bindings));
  context->SetExpressionSource(
      CelExpressionSource::FromCheckedExpr(std::move(checked_expr)));
  TestRunner test_runner(std::move(context));
  EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case),
                          "int64_value: 15");  // computed 15, expected 999.
}

INSTANTIATE_TEST_SUITE_P(TestRunnerTests, TestRunnerParamTest,
                         ::testing::Values(RuntimeApi::kRuntime,
                                           RuntimeApi::kBuilder));

// The `x` input is a CEL expression, but the context has no compiler, so
// RunTest is expected to fail fatally with `expected_error`.
TEST(TestRunnerStandaloneTest, DynamicInputWithoutCompilerFails) {
  const std::string expected_error =
      "INVALID_ARGUMENT: A compiler must be provided to compile a raw "
      "expression or .cel file.";
  // NOTE(review): all setup lives inside the EXPECT_FATAL_FAILURE statement,
  // presumably because gtest-spi does not let that statement reference local
  // variables of the enclosing function (only the substring argument,
  // `expected_error`, is declared outside). Keep this shape when editing.
  EXPECT_FATAL_FAILURE(
      {
        // Create a compiler.
        ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, CreateBasicCompiler());
        ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
                             compiler->Compile("x + y"));
        CheckedExpr checked_expr;
        ASSERT_THAT(
            cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
            absl_testing::IsOk());
        TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { expr: "1 + 1" } } input { key: "y" value { value { int64_value: 2 } } } output { result_value { int64_value: 3 } } )pb");
        // Create the expression builder.
        ASSERT_OK_AND_ASSIGN(auto builder, CreateTestCelExpressionBuilder());
        // Create the TestRunner without the compiler.
        std::unique_ptr context =
            CelTestContext::CreateFromCelExpressionBuilder(
                /*cel_expression_builder=*/std::move(builder));
        context->SetExpressionSource(
            CelExpressionSource::FromCheckedExpr(std::move(checked_expr)));
        TestRunner test_runner(std::move(context));
        test_runner.RunTest(test_case);
      },
      expected_error);
}

// A message-typed input (TestAllTypes) must be resolvable through the
// descriptor pool the runtime was built with.
TEST(TestRunnerStandaloneTest,
     RuntimeUsesRuntimePoolToResolveCustomProtoLiteral) {
  // Create a custom CompilerBuilder.
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr builder,
      cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool()));
  ASSERT_THAT(builder->AddLibrary(cel::StandardCompilerLibrary()),
              absl_testing::IsOk());
  cel::TypeCheckerBuilder& checker_builder = builder->GetCheckerBuilder();
  ASSERT_THAT(checker_builder.AddVariable(cel::MakeVariableDecl(
                  "custom_var", cel::MessageType(TestAllTypes::descriptor()))),
              absl_testing::IsOk());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, std::move(builder)->Build());

  // Compile the expression.
  ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
                       compiler->Compile("custom_var.single_int32 == 123"));
  CheckedExpr checked_expr;
  ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
              absl_testing::IsOk());

  // Create a runtime configured with the testing descriptor pool.
  ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, CreateTestRuntime());

  // Define the test case. The important part is the "custom_var" input,
  // which forces 'ResolveValue' to run on a custom type. This succeeds because
  // the testing descriptor pool (used by CreateTestRuntime()) is configured
  // to contain the TestAllTypes descriptor.
  TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "custom_var" value { value { object_value { [type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes] { single_int32: 123 } } } } } output { result_value { bool_value: true } } )pb");

  std::unique_ptr context =
      CelTestContext::CreateFromRuntime(std::move(runtime));
  context->SetExpressionSource(
      CelExpressionSource::FromCheckedExpr(std::move(checked_expr)));
  TestRunner test_runner(std::move(context));
  EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case));
}

// With no expression source, RunTest is expected to fail fatally with the
// InvalidArgument status returned by GetCheckedExpr().
TEST(TestRunnerStandaloneTest, RunTestFailsWhenNoExpressionSourceIsProvided) {
  const std::string expected_error =
      "INVALID_ARGUMENT: No expression source provided.";
  // Setup is kept inside the EXPECT_FATAL_FAILURE statement, as in
  // DynamicInputWithoutCompilerFails.
  EXPECT_FATAL_FAILURE(
      {
        // Create a runtime.
        ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, CreateTestRuntime());
        TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { value { int64_value: 10 } } } input { key: "y" value { value { int64_value: 3 } } } output { result_value { int64_value: 123 } } )pb");
        ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, CreateBasicCompiler());
        // Create a TestRunner but without an expression source.
        std::unique_ptr context =
            CelTestContext::CreateFromRuntime(std::move(runtime));
        context->SetCompiler(std::move(compiler));
        TestRunner test_runner(std::move(context));
        test_runner.RunTest(test_case);
      },
      expected_error);
}

// `y` is never bound, so evaluation yields an error; the test case expects
// exactly that error message.
TEST(TestRunnerStandaloneTest, BasicTestWithErrorAssertion) {
  // Compile the expression.
  ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
                       DefaultCompiler().Compile("x + y"));
  CheckedExpr checked_expr;
  ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
              absl_testing::IsOk());
  // Create a runtime.
  ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, CreateTestRuntime());
  TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { value { int64_value: 1 } } } output { eval_error { errors { message: "No value with name \"y\" found in Activation" } } } )pb");
  std::unique_ptr context =
      CelTestContext::CreateFromRuntime(std::move(runtime));
  context->SetExpressionSource(
      CelExpressionSource::FromCheckedExpr(std::move(checked_expr)));
  TestRunner test_runner(std::move(context));
  EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case));
}

// `1 + 1` evaluates to a value while an error is expected, so AssertError
// must report "Expected error but got value".
TEST(TestRunnerStandaloneTest, BasicTestFailsWhenExpectingErrorButGotValue) {
  // Compile the expression.
  ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
                       DefaultCompiler().Compile("1 + 1"));
  CheckedExpr checked_expr;
  ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
              absl_testing::IsOk());
  // Create a runtime.
  ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, CreateTestRuntime());
  TestCase test_case = ParseTextProtoOrDie(R"pb( output { eval_error { errors { message: "No value with name \"y\" found in Activation" } } } )pb");
  std::unique_ptr context =
      CelTestContext::CreateFromRuntime(std::move(runtime));
  context->SetExpressionSource(
      CelExpressionSource::FromCheckedExpr(std::move(checked_expr)));
  TestRunner test_runner(std::move(context));
  EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case),
                          "Expected error but got value");
}

// The activation factory binds x=10 and y=5. The same TestRunner is then
// reused for a second test case whose input x=4 replaces the factory's x.
TEST(TestRunnerStandaloneTest, BasicTestWithActivationFactorySucceeds) {
  ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
                       DefaultCompiler().Compile("x + y"));
  CheckedExpr checked_expr;
  ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
              absl_testing::IsOk());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, CreateTestRuntime());
  std::unique_ptr context =
      CelTestContext::CreateFromRuntime(std::move(runtime));
  context->SetActivationFactory(
      [](const TestCase& test_case,
         google::protobuf::Arena* arena) -> absl::StatusOr {
        cel::Activation activation;
        activation.InsertOrAssignValue("x", cel::IntValue(10));
        activation.InsertOrAssignValue("y", cel::IntValue(5));
        return activation;
      });
  context->SetExpressionSource(
      CelExpressionSource::FromCheckedExpr(std::move(checked_expr)));
  TestCase test_case = ParseTextProtoOrDie(R"pb( output { result_value { int64_value: 15 } } )pb");
  TestRunner test_runner(std::move(context));
  EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case));

  // Input bindings should override values set by the activation factory.
  // x=4 (input) + y=5 (factory) == 9.
  test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { value { int64_value: 4 } } } output { result_value { int64_value: 9 } } )pb");
  EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case));
}

// The expected output (102) would fail the default assertion since 1 + 1 is
// 2. The test passes only because the custom assert function, which checks
// for 2, replaces the default comparison.
TEST(TestRunnerStandaloneTest, CustomAssertFnIsUsed) {
  // Compile the expression.
  ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
                       DefaultCompiler().Compile("1 + 1"));
  CheckedExpr checked_expr;
  ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
              absl_testing::IsOk());
  // Create a runtime.
  ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, CreateTestRuntime());
  // Set the output to a value that would fail the default assertion.
  TestCase test_case = ParseTextProtoOrDie(R"pb( output { result_value { int64_value: 102 } } )pb");
  std::unique_ptr context =
      CelTestContext::CreateFromRuntime(std::move(runtime));
  context->SetAssertFn([&](const cel::Value& computed,
                           const TestCase& test_case,
                           google::protobuf::Arena* arena) {
    ASSERT_TRUE(computed.Is());
    EXPECT_EQ(computed.As().value(), 2);
  });
  context->SetExpressionSource(
      CelExpressionSource::FromCheckedExpr(std::move(checked_expr)));
  TestRunner test_runner(std::move(context));
  EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case));
}

// Coverage is attached to the runtime directly (not via the context's
// enable_coverage flag), and Init() is called before RunTest(). With x=2 and
// y=0, `x > 1` is true, `y > 1` is false, and the conjunction is false, giving
// 3 covered boolean outcomes out of 6.
TEST(CoverageTest, RuntimeCoverage) {
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr compiler_builder,
      cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool()));
  ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()),
              absl_testing::IsOk());
  ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable(
                  cel::MakeVariableDecl("x", cel::IntType())),
              absl_testing::IsOk());
  ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable(
                  cel::MakeVariableDecl("y", cel::IntType())),
              absl_testing::IsOk());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler,
                       std::move(compiler_builder)->Build());
  ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
                       compiler->Compile("x > 1 && y > 1"));
  CheckedExpr checked_expr;
  ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
              absl_testing::IsOk());
  TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { value { int64_value: 2 } } } input { key: "y" value { value { int64_value: 0 } } } output { result_value { bool_value: false } } )pb");
  CoverageIndex coverage_index;
  ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, CreateTestRuntime());
  ASSERT_THAT(EnableCoverageInRuntime(*const_cast(runtime.get()),
                                      coverage_index),
              absl_testing::IsOk());
  std::unique_ptr context =
      CelTestContext::CreateFromRuntime(std::move(runtime));
  context->SetExpressionSource(
      CelExpressionSource::FromCheckedExpr(checked_expr));
  TestRunner test_runner(std::move(context));
  coverage_index.Init(checked_expr);
  EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case));
  CoverageIndex::CoverageReport report = coverage_index.GetCoverageReport();

  EXPECT_GT(report.nodes, 0);
  EXPECT_GT(report.covered_nodes, 0);
  EXPECT_EQ(report.branches, 6);
  EXPECT_EQ(report.covered_boolean_outcomes, 3);
  EXPECT_THAT(
      report.unencountered_branches,
      ::testing::ElementsAre(
          HasSubstr("\nExpression ID 7 ('x > 1 && y > 1'): Never "
                    "evaluated to 'true'"),
          HasSubstr(
              "\n\t\tExpression ID 2 ('x > 1'): Never evaluated to 'false'"),
          HasSubstr(
              "\n\t\tExpression ID 5 ('y > 1'): Never evaluated to 'true'")));
}

// Same expression on the legacy builder with x=0: `x > 1` is false and the
// test expects `y > 1` never to be evaluated (presumably due to `&&`
// short-circuiting), leaving 2 covered boolean outcomes.
TEST(CoverageTest, BuilderCoverage) {
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr compiler_builder,
      cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool()));
  ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()),
              absl_testing::IsOk());
  ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable(
                  cel::MakeVariableDecl("x", cel::IntType())),
              absl_testing::IsOk());
  ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable(
                  cel::MakeVariableDecl("y", cel::IntType())),
              absl_testing::IsOk());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler,
                       std::move(compiler_builder)->Build());
  ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
                       compiler->Compile("x > 1 && y > 1"));
  CheckedExpr checked_expr;
  ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
              absl_testing::IsOk());
  TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { value { int64_value: 0 } } } input { key: "y" value { value { int64_value: 2 } } } output { result_value { bool_value: false } } )pb");
  CoverageIndex coverage_index;
  ASSERT_OK_AND_ASSIGN(std::unique_ptr builder,
                       CreateTestCelExpressionBuilder());
  ASSERT_THAT(EnableCoverageInCelExpressionBuilder(*builder, coverage_index),
              absl_testing::IsOk());
  std::unique_ptr context =
      CelTestContext::CreateFromCelExpressionBuilder(std::move(builder));
  context->SetExpressionSource(
      CelExpressionSource::FromCheckedExpr(checked_expr));
  TestRunner test_runner(std::move(context));
  coverage_index.Init(checked_expr);
  EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case));
  CoverageIndex::CoverageReport report = coverage_index.GetCoverageReport();

  EXPECT_GT(report.nodes, 0);
  EXPECT_GT(report.covered_nodes, 0);
  EXPECT_EQ(report.branches, 6);
  EXPECT_EQ(report.covered_boolean_outcomes, 2);
  EXPECT_THAT(report.unencountered_nodes,
              ::testing::UnorderedElementsAre(HasSubstr("y > 1")));
  EXPECT_THAT(
      report.unencountered_branches,
      ::testing::UnorderedElementsAre(HasSubstr("Never evaluated to 'true'"),
                                      HasSubstr("Never evaluated to 'true'")));
}

// Same setup as RuntimeCoverage; checks the structure, node labels, and
// coverage styling of the generated DOT graph.
TEST(CoverageTest, DotGraphIsGeneratedForRuntime) {
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr compiler_builder,
      cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool()));
  ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()),
              absl_testing::IsOk());
  ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable(
                  cel::MakeVariableDecl("x", cel::IntType())),
              absl_testing::IsOk());
  ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable(
                  cel::MakeVariableDecl("y", cel::IntType())),
              absl_testing::IsOk());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler,
                       std::move(compiler_builder)->Build());
  ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
                       compiler->Compile("x > 1 && y > 1"));
  CheckedExpr checked_expr;
  ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
              absl_testing::IsOk());
  TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { value { int64_value: 2 } } } input { key: "y" value { value { int64_value: 0 } } } output { result_value { bool_value: false } } )pb");
  CoverageIndex coverage_index;
  ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, CreateTestRuntime());
  ASSERT_THAT(EnableCoverageInRuntime(*const_cast(runtime.get()),
                                      coverage_index),
              absl_testing::IsOk());
  std::unique_ptr context =
      CelTestContext::CreateFromRuntime(std::move(runtime));
  context->SetExpressionSource(
      CelExpressionSource::FromCheckedExpr(checked_expr));
  TestRunner test_runner(std::move(context));
  coverage_index.Init(checked_expr);
  EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case));
  CoverageIndex::CoverageReport report = coverage_index.GetCoverageReport();
  absl::string_view dot_graph = report.dot_graph;

  // Check for graph structure
  EXPECT_THAT(dot_graph, StartsWith(kDigraphHeader));
  EXPECT_THAT(dot_graph, EndsWith("}\n"));
  EXPECT_THAT(dot_graph, HasSubstr("->"));
  EXPECT_THAT(dot_graph, HasSubstr("shape=record"));

  // Check for the existence of complete labels for key nodes. The expression
  // IDs (7, 2, 5) are the same ones RuntimeCoverage reports for this
  // expression.
  EXPECT_THAT(dot_graph,
              HasSubstr("label=\"{<1> exprID: 7 | <2> Call Node} | "
                        "<3> x \\> 1 && y \\> 1\""));
  EXPECT_THAT(
      dot_graph,
      HasSubstr("label=\"{<1> exprID: 2 | <2> Call Node} | <3> x \\> 1\""));
  EXPECT_THAT(
      dot_graph,
      HasSubstr("label=\"{<1> exprID: 5 | <2> Call Node} | <3> y \\> 1\""));

  // Check for coverage styles
  EXPECT_THAT(dot_graph, HasSubstr(kCompletelyCoveredNodeStyle));
  EXPECT_THAT(dot_graph, HasSubstr(kPartiallyCoveredNodeStyle));
  EXPECT_THAT(dot_graph, Not(HasSubstr(kUncoveredNodeStyle)));
}

TEST(CoverageTest, DotGraphIsGeneratedForComprehension) {
  ASSERT_OK_AND_ASSIGN(
      std::unique_ptr compiler_builder,
      cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool()));
  ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()),
              absl_testing::IsOk());
  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler,
                       std::move(compiler_builder)->Build());
  ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result,
                       compiler->Compile("[1, 2, 3].all(i, i > 0)"));
  CheckedExpr checked_expr;
  ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr),
              absl_testing::IsOk());
  // Test case expects 'true' since all elements are > 0.
TestCase test_case = ParseTextProtoOrDie(R"pb( output { result_value { bool_value: true } } )pb"); CoverageIndex coverage_index; ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, CreateTestRuntime()); ASSERT_THAT(EnableCoverageInRuntime(*const_cast(runtime.get()), coverage_index), absl_testing::IsOk()); std::unique_ptr context = CelTestContext::CreateFromRuntime(std::move(runtime)); context->SetExpressionSource( CelExpressionSource::FromCheckedExpr(checked_expr)); TestRunner test_runner(std::move(context)); coverage_index.Init(checked_expr); EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); CoverageIndex::CoverageReport report = coverage_index.GetCoverageReport(); absl::string_view dot_graph = report.dot_graph; // Assert that the specific kinds for comprehension nodes are present in the // generated graph. EXPECT_THAT(dot_graph, HasSubstr("IterRange")); EXPECT_THAT(dot_graph, HasSubstr("AccuInit")); EXPECT_THAT(dot_graph, HasSubstr("LoopCondition")); EXPECT_THAT(dot_graph, HasSubstr("LoopStep")); EXPECT_THAT(dot_graph, HasSubstr("Result")); // The expression is fully evaluated, so no nodes should be uncovered. EXPECT_THAT(dot_graph, Not(HasSubstr(kUncoveredNodeStyle))); } TEST(CoverageTest, PartiallyCoveredBooleanNodeIsStyledCorrectly) { // This test is designed to kill a mutant that incorrectly styles partially // covered boolean nodes as completely covered. It uses a short-circuiting // expression to ensure that some boolean nodes are only evaluated one way // (e.g., only to 'true'), making them partially covered. 
ASSERT_OK_AND_ASSIGN( std::unique_ptr compiler_builder, cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool())); ASSERT_THAT(compiler_builder->AddLibrary(cel::StandardCompilerLibrary()), absl_testing::IsOk()); ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable( cel::MakeVariableDecl("x", cel::IntType())), absl_testing::IsOk()); ASSERT_THAT(compiler_builder->GetCheckerBuilder().AddVariable( cel::MakeVariableDecl("y", cel::IntType())), absl_testing::IsOk()); ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, std::move(compiler_builder)->Build()); ASSERT_OK_AND_ASSIGN( cel::ValidationResult validation_result, compiler->Compile("{'sum': x + y, 'literal': 3}.sum == 3 || x == y")); CheckedExpr checked_expr; ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), absl_testing::IsOk()); TestCase test_case = ParseTextProtoOrDie(R"pb( input { key: "x" value { value { int64_value: 1 } } } input { key: "y" value { value { int64_value: 2 } } } output { result_value { bool_value: true } } )pb"); CoverageIndex coverage_index; ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, CreateTestRuntime()); ASSERT_THAT(EnableCoverageInRuntime(*const_cast(runtime.get()), coverage_index), absl_testing::IsOk()); std::unique_ptr context = CelTestContext::CreateFromRuntime(std::move(runtime)); context->SetExpressionSource( CelExpressionSource::FromCheckedExpr(checked_expr)); TestRunner test_runner(std::move(context)); coverage_index.Init(checked_expr); EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); CoverageIndex::CoverageReport report = coverage_index.GetCoverageReport(); // With x=1, y=2, the left side of '||' is true, so the right side ('x == y') // is short-circuited and never evaluated. // - The '||' node and the '==' node are partially covered (only 'true'). // - The 'x == y' branch (and its children) are uncovered. // - All other evaluated nodes are fully covered. 
EXPECT_EQ(CountSubstrings(report.dot_graph, kPartiallyCoveredNodeStyle), 2); EXPECT_EQ(CountSubstrings(report.dot_graph, kUncoveredNodeStyle), 3); EXPECT_EQ(CountSubstrings(report.dot_graph, kCompletelyCoveredNodeStyle), 9); } } // namespace } // namespace cel::test ================================================ FILE: testing/testrunner/user_tests/BUILD ================================================ load("@rules_cc//cc:cc_library.bzl", "cc_library") load("//testing/testrunner:cel_cc_test.bzl", "cel_cc_test") package(default_visibility = ["//visibility:public"]) cc_library( name = "simple_user_test", testonly = True, srcs = ["simple.cc"], deps = [ "//checker:type_checker_builder", "//checker:validation_result", "//common:ast_proto", "//common:decl", "//common:type", "//compiler", "//compiler:compiler_factory", "//compiler:standard_library", "//internal:status_macros", "//internal:testing_descriptor_pool", "//runtime", "//runtime:runtime_builder", "//runtime:standard_runtime_builder_factory", "//testing/testrunner:cel_expression_source", "//testing/testrunner:cel_test_context", "//testing/testrunner:cel_test_factories", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_protobuf//:protobuf", ], alwayslink = True, ) cc_library( name = "raw_expression_user_test", testonly = True, srcs = ["raw_expression_test.cc"], deps = [ "//checker:type_checker_builder", "//common:decl", "//common:type", "//compiler", "//compiler:compiler_factory", "//compiler:standard_library", "//internal:status_macros", "//internal:testing_descriptor_pool", "//runtime", "//runtime:runtime_builder", "//runtime:standard_runtime_builder_factory", "//testing/testrunner:cel_expression_source", "//testing/testrunner:cel_test_context", "//testing/testrunner:cel_test_factories", 
"@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", "@com_google_protobuf//:protobuf", ], alwayslink = True, ) cc_library( name = "raw_expr_and_cel_file_test", testonly = True, srcs = ["raw_expr_and_cel_file_test.cc"], deps = [ "//checker:type_checker_builder", "//common:decl", "//common:type", "//compiler", "//compiler:compiler_factory", "//compiler:standard_library", "//internal:status_macros", "//internal:testing_descriptor_pool", "//runtime", "//runtime:runtime_builder", "//runtime:standard_runtime_builder_factory", "//testing/testrunner:cel_test_context", "//testing/testrunner:cel_test_factories", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", "@com_google_protobuf//:protobuf", ], alwayslink = True, ) cc_library( name = "checked_expr_user_test", testonly = True, srcs = ["checked_expr_test.cc"], deps = [ "//internal:status_macros", "//internal:testing_descriptor_pool", "//runtime", "//runtime:runtime_builder", "//runtime:standard_runtime_builder_factory", "//testing/testrunner:cel_test_context", "//testing/testrunner:cel_test_factories", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", "@com_google_protobuf//:protobuf", ], alwayslink = True, ) cel_cc_test( name = "simple_test", enable_coverage = True, filegroup = "//testing/testrunner/resources", test_data_path = "//testing/testrunner/resources", test_suite = "simple_tests.textproto", deps = [ ":simple_user_test", ], ) cel_cc_test( name = "simple_test_with_custom_test_suite", 
enable_coverage = True, filegroup = "//testing/testrunner/resources", test_data_path = "//testing/testrunner/resources", deps = [ ":simple_user_test", ], ) cel_cc_test( name = "raw_expression_test_with_custom_test_suite", enable_coverage = True, deps = [ ":raw_expression_user_test", ], ) cel_cc_test( name = "subtraction_raw_expr_test", cel_expr = "x - y", is_raw_expr = True, deps = [ ":raw_expr_and_cel_file_test", ], ) cel_cc_test( name = "subtraction_cel_file_test", cel_expr = "test.cel", test_data_path = "//testing/testrunner/resources", deps = [ ":raw_expr_and_cel_file_test", ], ) ================================================ FILE: testing/testrunner/user_tests/checked_expr_test.cc ================================================ // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// NOTE(review): this extracted copy has lost all `<...>` text: the targets of
// the two bare `#include` lines below, and the template argument lists on
// `template`, `absl::StatusOr`, `std::unique_ptr` and the
// `ParseTextProtoOrDie` calls (whose T cannot be deduced from the string
// argument). Restore them from upstream before building.
#include
#include
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "internal/status_macros.h"
#include "internal/testing_descriptor_pool.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "testing/testrunner/cel_test_context.h"
#include "testing/testrunner/cel_test_factories.h"
#include "cel/expr/conformance/test/suite.pb.h"
#include "google/protobuf/text_format.h"

namespace cel::testing {

using ::cel::test::CelTestContext;

// Parses `text_proto` as a text-format protobuf message of type T and returns
// it; aborts via ABSL_CHECK when the text does not parse.
template T ParseTextProtoOrDie(absl::string_view text_proto) {
  T result;
  ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(text_proto, &result));
  return result;
}

// Registers a suite with one section and one test: x = 10, y = 5, expected
// int result 5. No expression is set in this file. The context factory below
// only provides a runtime, and its comment says the test runner injects the
// CheckedExpr source.
CEL_REGISTER_TEST_SUITE_FACTORY([]() {
  return ParseTextProtoOrDie(R"pb( name: "cli_expression_tests" description: "Tests designed for expressions passed via CLI flags." sections: { name: "subtraction_test" description: "Tests subtraction of two variables." tests: { name: "variable_subtraction" description: "Test that subtraction of two variables works." input: { key: "x" value { value { int64_value: 10 } } } input { key: "y" value { value { int64_value: 5 } } } output { result_value { int64_value: 5 } } } } )pb");
});

// Registers a context factory that builds a standard runtime over the testing
// descriptor pool and wraps it in a CelTestContext with no compiler attached.
CEL_REGISTER_TEST_CONTEXT_FACTORY(
    []() -> absl::StatusOr> {
      ABSL_LOG(INFO) << "Creating runtime-only test context for CheckedExpr";
      // Create a runtime.
      CEL_ASSIGN_OR_RETURN(cel::RuntimeBuilder runtime_builder,
                           cel::CreateStandardRuntimeBuilder(
                               cel::internal::GetTestingDescriptorPool(), {}));
      CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime,
                           std::move(runtime_builder).Build());
      // Create the context with the runtime, but no compiler.
      // The test runner will inject the CheckedExpr source later.
      return CelTestContext::CreateFromRuntime(std::move(runtime));
    });

}  // namespace cel::testing
================================================
FILE: testing/testrunner/user_tests/raw_expr_and_cel_file_test.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): as extracted, the two bare `#include` lines and the `<...>`
// template argument lists in this file (template, StatusOr, unique_ptr,
// ParseTextProtoOrDie calls) are missing; restore them from upstream.
#include
#include
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "checker/type_checker_builder.h"
#include "common/decl.h"
#include "common/type.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "internal/status_macros.h"
#include "internal/testing_descriptor_pool.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "testing/testrunner/cel_test_context.h"
#include "testing/testrunner/cel_test_factories.h"
#include "cel/expr/conformance/test/suite.pb.h"
#include "google/protobuf/text_format.h"

namespace cel::testing {

using ::cel::test::CelTestContext;

// Parses `text_proto` as a text-format protobuf message of type T and returns
// it; aborts via ABSL_CHECK when the text does not parse.
template T ParseTextProtoOrDie(absl::string_view text_proto) {
  T result;
  ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(text_proto, &result));
  return result;
}

// Registers a suite with one test: x = 10, y = 5, expected int result 5.
// The expression itself is not set in this file. The cel_cc_test targets that
// depend on this library (subtraction_raw_expr_test,
// subtraction_cel_file_test) pass it through their `cel_expr` attribute.
CEL_REGISTER_TEST_SUITE_FACTORY([]() {
  return ParseTextProtoOrDie(R"pb( name: "cli_expression_tests" description: "Tests designed for expressions passed via CLI flags."
sections: { name: "subtraction_test" description: "Tests subtraction of two variables." tests: { name: "variable_subtraction" description: "Test that subtraction of two variables works." input: { key: "x" value { value { int64_value: 10 } } } input { key: "y" value { value { int64_value: 5 } } } output { result_value { int64_value: 5 } } } } )pb");
});

// Registers a context factory that:
//   1. builds a compiler with the standard library and int variables x and y;
//   2. builds a standard runtime over the testing descriptor pool;
//   3. returns a CelTestContext over that runtime with the compiler attached.
// No expression source is set here.
CEL_REGISTER_TEST_CONTEXT_FACTORY(
    []() -> absl::StatusOr> {
      ABSL_LOG(INFO) << "Creating test context for raw_expr and cel_file";
      // Create a compiler.
      CEL_ASSIGN_OR_RETURN(
          std::unique_ptr builder,
          cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool()));
      CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary()));
      cel::TypeCheckerBuilder& checker_builder = builder->GetCheckerBuilder();
      CEL_RETURN_IF_ERROR(checker_builder.AddVariable(
          cel::MakeVariableDecl("x", cel::IntType())));
      CEL_RETURN_IF_ERROR(checker_builder.AddVariable(
          cel::MakeVariableDecl("y", cel::IntType())));
      CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler,
                           std::move(builder)->Build());
      // Create a runtime.
      CEL_ASSIGN_OR_RETURN(cel::RuntimeBuilder runtime_builder,
                           cel::CreateStandardRuntimeBuilder(
                               cel::internal::GetTestingDescriptorPool(), {}));
      CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime,
                           std::move(runtime_builder).Build());
      std::unique_ptr context =
          CelTestContext::CreateFromRuntime(std::move(runtime));
      // Attach the compiler so the expression supplied by the test target
      // can be compiled.
      context->SetCompiler(std::move(compiler));
      return context;
    });

}  // namespace cel::testing
================================================
FILE: testing/testrunner/user_tests/raw_expression_test.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): this extracted copy has lost all `<...>` text: the targets of
// the three bare `#include` lines, and the template argument lists on
// `template`, `absl::StatusOr`, `std::unique_ptr` and the
// `ParseTextProtoOrDie` call (T is not deducible). Restore them from upstream.
#include
#include
#include
#include "absl/log/absl_check.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "checker/type_checker_builder.h"
#include "common/decl.h"
#include "common/type.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "internal/status_macros.h"
#include "internal/testing_descriptor_pool.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "testing/testrunner/cel_expression_source.h"
#include "testing/testrunner/cel_test_context.h"
#include "testing/testrunner/cel_test_factories.h"
#include "cel/expr/conformance/test/suite.pb.h"
#include "google/protobuf/text_format.h"

namespace cel::testing {

using ::cel::test::CelTestContext;

// Parses `text_proto` as a text-format protobuf message of type T and returns
// it; aborts via ABSL_CHECK when the text does not parse.
template T ParseTextProtoOrDie(absl::string_view text_proto) {
  T result;
  ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(text_proto, &result));
  return result;
}

// Registers a suite that exercises raw CEL expressions in test data:
//   - input x is the raw expression `1 + 1`;
//   - input y is the int value 8;
//   - the expected output is the raw expression `5 * 2`.
CEL_REGISTER_TEST_SUITE_FACTORY([]() {
  return ParseTextProtoOrDie(R"pb( name: "raw_expression_tests" description: "Tests for validating support for raw CEL expressions in test inputs and outputs." sections: { name: "raw_expression_io" description: "A section for tests with raw CEL expressions in inputs and outputs." tests: { name: "eval_input_and_output" description: "Test that a raw CEL expression can be provided as both an input and an expected output."
input: { key: "x" value { expr: "1 + 1" } } input: { key: "y" value { value { int64_value: 8 } } } output { result_expr: "5 * 2" } } } )pb");
});

// Registers a context factory that:
//   1. builds a compiler with the standard library and int variables x and y;
//   2. builds a standard runtime over the testing descriptor pool;
//   3. returns a context holding both, with the raw expression `x + y` as the
//      expression under test.
// With the suite's inputs, x + y = 2 + 8 = 10, which equals the expected 5 * 2.
// The compiler is presumably also what evaluates the raw `expr` input and the
// `result_expr` output; confirm against the test runner.
CEL_REGISTER_TEST_CONTEXT_FACTORY(
    []() -> absl::StatusOr> {
      // Create a compiler.
      CEL_ASSIGN_OR_RETURN(
          std::unique_ptr builder,
          cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool()));
      CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary()));
      cel::TypeCheckerBuilder& checker_builder = builder->GetCheckerBuilder();
      CEL_RETURN_IF_ERROR(checker_builder.AddVariable(
          cel::MakeVariableDecl("x", cel::IntType())));
      CEL_RETURN_IF_ERROR(checker_builder.AddVariable(
          cel::MakeVariableDecl("y", cel::IntType())));
      CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, builder->Build());
      // Create a runtime.
      CEL_ASSIGN_OR_RETURN(cel::RuntimeBuilder runtime_builder,
                           cel::CreateStandardRuntimeBuilder(
                               cel::internal::GetTestingDescriptorPool(), {}));
      CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime,
                           std::move(runtime_builder).Build());
      std::unique_ptr context =
          CelTestContext::CreateFromRuntime(std::move(runtime));
      context->SetCompiler(std::move(compiler));
      context->SetExpressionSource(
          test::CelExpressionSource::FromRawExpression("x + y"));
      return context;
    });

}  // namespace cel::testing
================================================
FILE: testing/testrunner/user_tests/simple.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// NOTE(review): this extracted copy has lost all `<...>` text: the targets of
// the two bare `#include` lines, and the template argument lists on
// `template`, `absl::StatusOr`, `std::unique_ptr` and the
// `ParseTextProtoOrDie` call (T is not deducible). Restore them from upstream.
#include
#include
#include "cel/expr/checked.pb.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "checker/type_checker_builder.h"
#include "checker/validation_result.h"
#include "common/ast_proto.h"
#include "common/decl.h"
#include "common/type.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "internal/status_macros.h"
#include "internal/testing_descriptor_pool.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "testing/testrunner/cel_expression_source.h"
#include "testing/testrunner/cel_test_context.h"
#include "testing/testrunner/cel_test_factories.h"
#include "google/protobuf/text_format.h"

namespace cel::testing {

using ::cel::test::CelTestContext;
using ::cel::expr::CheckedExpr;

// Parses `text_proto` as a text-format protobuf message of type T and returns
// it; aborts via ABSL_CHECK when the text does not parse.
template T ParseTextProtoOrDie(absl::string_view text_proto) {
  T result;
  ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(text_proto, &result));
  return result;
}

// Registers a suite with one test: x = 1, y = 2, expected bool result true.
CEL_REGISTER_TEST_SUITE_FACTORY([]() {
  return ParseTextProtoOrDie(R"pb( name: "custom_test_suite_tests" description: "Simple tests to validate the test runner." sections: { name: "simple_map_operations" description: "Tests for map operations." tests: { name: "literal_and_sum" description: "Test that a map can be created and values can be accessed." input: { key: "x" value { value { int64_value: 1 } } } input { key: "y" value { value { int64_value: 2 } } } output { result_value { bool_value: true } } } } )pb");
});

// Registers a context factory that compiles a fixed expression ahead of time
// and hands the runner a CheckedExpr plus a standard runtime. No compiler is
// attached to the returned context.
// With x = 1 and y = 2, `{'sum': x + y, ...}.sum == 3` is true, which matches
// the suite's expected `true`.
CEL_REGISTER_TEST_CONTEXT_FACTORY(
    []() -> absl::StatusOr> {
      ABSL_LOG(INFO) << "Creating test context";
      // Create a compiler.
      CEL_ASSIGN_OR_RETURN(
          std::unique_ptr builder,
          cel::NewCompilerBuilder(cel::internal::GetTestingDescriptorPool()));
      CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary()));
      cel::TypeCheckerBuilder& checker_builder = builder->GetCheckerBuilder();
      CEL_RETURN_IF_ERROR(checker_builder.AddVariable(
          cel::MakeVariableDecl("x", cel::IntType())));
      CEL_RETURN_IF_ERROR(checker_builder.AddVariable(
          cel::MakeVariableDecl("y", cel::IntType())));
      CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, builder->Build());
      // Compile the expression.
      CEL_ASSIGN_OR_RETURN(
          cel::ValidationResult validation_result,
          compiler->Compile("{'sum': x + y, 'literal': 3}.sum == 3 || x == y"));
      CheckedExpr checked_expr;
      // NOTE(review): GetAst() is dereferenced without first checking the
      // validation result. This relies on the hard-coded expression
      // type-checking cleanly. Presumably GetAst() is null when checking
      // reports errors; consider returning an error in that case.
      CEL_RETURN_IF_ERROR(
          cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr));
      // Create a runtime.
      CEL_ASSIGN_OR_RETURN(cel::RuntimeBuilder runtime_builder,
                           cel::CreateStandardRuntimeBuilder(
                               cel::internal::GetTestingDescriptorPool(), {}));
      CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime,
                           std::move(runtime_builder).Build());
      std::unique_ptr context =
          CelTestContext::CreateFromRuntime(std::move(runtime));
      context->SetExpressionSource(
          test::CelExpressionSource::FromCheckedExpr(std::move(checked_expr)));
      return context;
    });

}  // namespace cel::testing
================================================
FILE: testutil/BUILD
================================================
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library")
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

package(default_visibility = ["//visibility:public"])

licenses(["notice"])

# Expression printer (testutil/expr_printer.h). :baseline_tests builds on it.
cc_library(
    name = "expr_printer",
    srcs = ["expr_printer.cc"],
    hdrs = ["expr_printer.h"],
    deps = [
        "//common:ast",
        "//common:ast_proto",
        "//common:constant",
        "//common:expr",
        "//internal:strings",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/log:absl_log",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
    ],
)

# Unit tests for :expr_printer.
cc_test(
    name = "expr_printer_test",
    srcs = ["expr_printer_test.cc"],
    deps = [
        ":expr_printer",
        "//common:expr",
        "//internal:testing",
        "//parser",
        "//parser:options",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/strings",
    ],
)

# Header-only test utilities (util.h).
cc_library(
    name = "util",
    testonly = True,
    hdrs = [
        "util.h",
    ],
    deps = ["//internal:proto_matchers"],
)

# Formats ASTs and CheckedExpr protos into the textual "baseline" format used
# to compare CEL implementations (see baseline_tests.h).
cc_library(
    name = "baseline_tests",
    testonly = True,
    srcs = ["baseline_tests.cc"],
    hdrs = ["baseline_tests.h"],
    deps = [
        ":expr_printer",
        "//common:ast",
        "//common:expr",
        "//extensions/protobuf:ast_converters",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
    ],
)

# Unit tests for :baseline_tests.
cc_test(
    name = "baseline_tests_test",
    srcs = ["baseline_tests_test.cc"],
    deps = [
        ":baseline_tests",
        "//common:ast",
        "//internal:testing",
        "@com_google_protobuf//:protobuf",
    ],
)

# Proto library for test_json_names.proto. No C++ proto wrapper is declared
# for it in this file.
proto_library(
    name = "test_json_names_proto",
    srcs = ["test_json_names.proto"],
)
================================================
FILE: testutil/baseline_tests.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "testutil/baseline_tests.h" #include #include #include #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "common/ast.h" #include "common/expr.h" #include "extensions/protobuf/ast_converters.h" #include "testutil/expr_printer.h" namespace cel::test { namespace { std::string FormatReference(const cel::Reference& r) { if (r.overload_id().empty()) { return r.name(); } return absl::StrJoin(r.overload_id(), "|"); } class TypeAdorner : public ExpressionAdorner { public: explicit TypeAdorner(const Ast& ast) : ast_(ast) {} std::string Adorn(const Expr& e) const override { std::string s; auto t = ast_.type_map().find(e.id()); if (t != ast_.type_map().end()) { absl::StrAppend(&s, "~", FormatTypeSpec(t->second)); } if (const auto r = ast_.reference_map().find(e.id()); r != ast_.reference_map().end()) { absl::StrAppend(&s, "^", FormatReference(r->second)); } return s; } std::string AdornStructField(const StructExprField& e) const override { return ""; } std::string AdornMapEntry(const MapExprEntry& e) const override { return ""; } private: const Ast& ast_; }; } // namespace std::string FormatBaselineAst(const Ast& ast) { TypeAdorner adorner(ast); ExprPrinter printer(adorner); return printer.Print(ast.root_expr()); } std::string FormatBaselineCheckedExpr( const cel::expr::CheckedExpr& checked) { auto ast = cel::extensions::CreateAstFromCheckedExpr(checked); if (!ast.ok()) { return ast.status().ToString(); } return FormatBaselineAst(**ast); } } // namespace cel::test ================================================ FILE: 
testutil/baseline_tests.h
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Utilities for baseline tests. Baseline files are textual reports in a common
// format that can be used to compare the output of each of the libraries.
//
// The protobuf ast format is a bit tricky to compare directly (e.g.
// renumberings do not change the meaning of the expression), so we use a custom
// format that compares well with simple string comparisons.
//
// Example:
// ```
// Source: Foo(a.b)
// declare a {
// variable map(string,dyn)
// }
// declare Foo {
// function foo_string(string) -> string
// function foo_int(int) -> int
// }
// =========>
// Foo(
// a~map(string,dyn)^a.b~dyn
// )~dyn^foo_string|foo_int
//
//
// ```
//
// In the output, `~` introduces a node's checked type and `^` its resolved
// reference: either the referenced name, or the candidate overload ids joined
// with `|`.

#ifndef THIRD_PARTY_CEL_CPP_TESTUTIL_BASELINE_TESTS_H_
#define THIRD_PARTY_CEL_CPP_TESTUTIL_BASELINE_TESTS_H_

#include

#include "cel/expr/checked.pb.h"
#include "common/ast.h"

namespace cel::test {

// Returns a string representation of the AST that matches the baseline format
// used in tests across the CEL libraries.
//
// Each expression is followed by `~<type>` when the AST's type map has an
// entry for its id, and then by `^<reference>` when its reference map does.
// Nodes with neither entry are printed unadorned.
std::string FormatBaselineAst(const Ast& ast);

// Returns a string representation of the protobuf AST that matches the baseline
// format used in tests across the CEL libraries.
//
// The proto is first converted to an Ast. If that conversion fails, the
// conversion status text is returned instead of a rendering.
std::string FormatBaselineCheckedExpr( const cel::expr::CheckedExpr& checked); } // namespace cel::test #endif // THIRD_PARTY_CEL_CPP_TESTUTIL_BASELINE_TEST_H_ ================================================ FILE: testutil/baseline_tests_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or astied. // See the License for the specific language governing permissions and // limitations under the License. #include "testutil/baseline_tests.h" #include #include #include "common/ast.h" #include "internal/testing.h" #include "google/protobuf/text_format.h" namespace cel::test { namespace { using ::cel::expr::CheckedExpr; TEST(FormatBaselineAst, Basic) { Ast ast; ast.mutable_root_expr().mutable_ident_expr().set_name("foo"); ast.mutable_root_expr().set_id(1); ast.mutable_type_map()[1] = TypeSpec(PrimitiveType::kInt64); ast.mutable_reference_map()[1].set_name("foo"); EXPECT_EQ(FormatBaselineAst(ast), "foo~int^foo"); } TEST(FormatBaselineAst, NoType) { Ast ast; ast.mutable_root_expr().mutable_ident_expr().set_name("foo"); ast.mutable_root_expr().set_id(1); ast.mutable_reference_map()[1].set_name("foo"); EXPECT_EQ(FormatBaselineAst(ast), "foo^foo"); } TEST(FormatBaselineAst, NoReference) { Ast ast; ast.mutable_root_expr().mutable_ident_expr().set_name("foo"); ast.mutable_root_expr().set_id(1); ast.mutable_type_map()[1] = TypeSpec(PrimitiveType::kInt64); EXPECT_EQ(FormatBaselineAst(ast), "foo~int"); } TEST(FormatBaselineAst, MutlipleReferences) { Ast ast; 
ast.mutable_root_expr().mutable_call_expr().set_function("_+_"); ast.mutable_root_expr().set_id(1); ast.mutable_type_map()[1] = TypeSpec(DynTypeSpec()); ast.mutable_reference_map()[1].mutable_overload_id().push_back( "add_timestamp_duration"); ast.mutable_reference_map()[1].mutable_overload_id().push_back( "add_duration_duration"); { auto& arg1 = ast.mutable_root_expr().mutable_call_expr().add_args(); arg1.mutable_ident_expr().set_name("a"); arg1.set_id(2); ast.mutable_type_map()[2] = TypeSpec(DynTypeSpec()); ast.mutable_reference_map()[2].set_name("a"); } { auto& arg2 = ast.mutable_root_expr().mutable_call_expr().add_args(); arg2.mutable_ident_expr().set_name("b"); arg2.set_id(3); ast.mutable_type_map()[3] = TypeSpec(WellKnownTypeSpec::kDuration); ast.mutable_reference_map()[3].set_name("b"); } EXPECT_EQ(FormatBaselineAst(ast), "_+_(\n" " a~dyn^a,\n" " b~google.protobuf.Duration^b\n" ")~dyn^add_timestamp_duration|add_duration_duration"); } TEST(FormatBaselineCheckedExpr, MutlipleReferences) { CheckedExpr checked; ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr { id: 1 call_expr { function: "_+_" args { id: 2 ident_expr { name: "a" } } args { id: 3 ident_expr { name: "b" } } } } type_map { key: 1 value { dyn {} } } type_map { key: 2 value { dyn {} } } type_map { key: 3 value { well_known: DURATION } } reference_map { key: 1 value { overload_id: "add_timestamp_duration" overload_id: "add_duration_duration" } } reference_map { key: 2 value { name: "a" } } reference_map { key: 3 value { name: "b" } } )pb", &checked)); EXPECT_EQ(FormatBaselineCheckedExpr(checked), "_+_(\n" " a~dyn^a,\n" " b~google.protobuf.Duration^b\n" ")~dyn^add_timestamp_duration|add_duration_duration"); } struct TestCase { TypeSpec type; std::string expected_string; }; class FormatBaselineTypeSpecTest : public testing::TestWithParam {}; TEST_P(FormatBaselineTypeSpecTest, Runner) { Ast ast; ast.mutable_root_expr().set_id(1); 
ast.mutable_root_expr().mutable_ident_expr().set_name("x"); ast.mutable_type_map()[1] = GetParam().type; EXPECT_EQ(FormatBaselineAst(ast), GetParam().expected_string); } INSTANTIATE_TEST_SUITE_P( Types, FormatBaselineTypeSpecTest, ::testing::Values( TestCase{TypeSpec(PrimitiveType::kBool), "x~bool"}, TestCase{TypeSpec(PrimitiveType::kInt64), "x~int"}, TestCase{TypeSpec(PrimitiveType::kUint64), "x~uint"}, TestCase{TypeSpec(PrimitiveType::kDouble), "x~double"}, TestCase{TypeSpec(PrimitiveType::kString), "x~string"}, TestCase{TypeSpec(PrimitiveType::kBytes), "x~bytes"}, TestCase{TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)), "x~wrapper(bool)"}, TestCase{TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)), "x~wrapper(int)"}, TestCase{TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)), "x~wrapper(uint)"}, TestCase{TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)), "x~wrapper(double)"}, TestCase{TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)), "x~wrapper(string)"}, TestCase{TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)), "x~wrapper(bytes)"}, TestCase{TypeSpec(WellKnownTypeSpec::kAny), "x~google.protobuf.Any"}, TestCase{TypeSpec(WellKnownTypeSpec::kDuration), "x~google.protobuf.Duration"}, TestCase{TypeSpec(WellKnownTypeSpec::kTimestamp), "x~google.protobuf.Timestamp"}, TestCase{TypeSpec(DynTypeSpec()), "x~dyn"}, TestCase{TypeSpec(NullTypeSpec()), "x~null"}, TestCase{TypeSpec(UnsetTypeSpec()), "x~*error*"}, TestCase{TypeSpec(MessageTypeSpec("com.example.Type")), "x~com.example.Type"}, TestCase{TypeSpec(AbstractType("optional_type", {TypeSpec(PrimitiveType::kInt64)})), "x~optional_type(int)"}, TestCase{TypeSpec(std::make_unique()), "x~type"}, TestCase{TypeSpec(std::make_unique(PrimitiveType::kInt64)), "x~type(int)"}, TestCase{TypeSpec(ParamTypeSpec("T")), "x~T"}, TestCase{TypeSpec(MapTypeSpec( std::make_unique(PrimitiveType::kString), std::make_unique(PrimitiveType::kString))), "x~map(string, string)"}, TestCase{TypeSpec(ListTypeSpec( 
std::make_unique(PrimitiveType::kString))), "x~list(string)"})); } // namespace } // namespace cel::test ================================================ FILE: testutil/expr_printer.cc ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "testutil/expr_printer.h" #include #include #include #include "absl/base/no_destructor.h" #include "absl/log/absl_log.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_format.h" #include "common/ast.h" #include "common/ast_proto.h" #include "common/constant.h" #include "common/expr.h" #include "internal/strings.h" namespace cel::test { namespace { class EmptyAdornerImpl : public ExpressionAdorner { public: std::string Adorn(const Expr& e) const override { return ""; } std::string AdornStructField(const StructExprField& e) const override { return ""; } std::string AdornMapEntry(const MapExprEntry& e) const override { return ""; } }; class StringBuilder { public: explicit StringBuilder(const ExpressionAdorner& adorner) : adorner_(adorner), line_start_(true), indent_(0) {} std::string Print(const Expr& expr) { AppendExpr(expr); return s_; } private: void AppendExpr(const Expr& e) { switch (e.kind_case()) { case ExprKindCase::kConstant: Append(FormatLiteral(e.const_expr())); break; case ExprKindCase::kIdentExpr: Append(e.ident_expr().name()); break; case ExprKindCase::kSelectExpr: AppendSelect(e.select_expr()); 
break; case ExprKindCase::kCallExpr: AppendCall(e.call_expr()); break; case ExprKindCase::kListExpr: AppendList(e.list_expr()); break; case ExprKindCase::kMapExpr: AppendMap(e.map_expr()); break; case ExprKindCase::kStructExpr: AppendStruct(e.struct_expr()); break; case ExprKindCase::kComprehensionExpr: AppendComprehension(e.comprehension_expr()); break; default: break; } Append(adorner_.Adorn(e)); } void AppendSelect(const SelectExpr& sel) { AppendExpr(sel.operand()); Append("."); Append(sel.field()); if (sel.test_only()) { Append("~test-only~"); } } void AppendCall(const CallExpr& call) { if (call.has_target()) { AppendExpr(call.target()); s_ += "."; } Append(call.function()); if (call.args().empty()) { Append("()"); return; } Append("("); Indent(); AppendLine(); for (int i = 0; i < call.args().size(); ++i) { const auto& arg = call.args()[i]; if (i > 0) { Append(","); AppendLine(); } AppendExpr(arg); } AppendLine(); Unindent(); Append(")"); } void AppendList(const ListExpr& list) { if (list.elements().empty()) { Append("[]"); return; } Append("["); AppendLine(); Indent(); for (int i = 0; i < list.elements().size(); ++i) { const auto& elem = list.elements()[i]; if (i > 0) { Append(","); AppendLine(); } if (elem.optional()) { Append("?"); } AppendExpr(elem.expr()); } AppendLine(); Unindent(); Append("]"); } void AppendStruct(const StructExpr& obj) { Append(obj.name()); if (obj.fields().empty()) { Append("{}"); return; } Append("{"); AppendLine(); Indent(); for (int i = 0; i < obj.fields().size(); ++i) { const auto& entry = obj.fields()[i]; if (i > 0) { Append(","); AppendLine(); } if (entry.optional()) { Append("?"); } Append(entry.name()); Append(":"); AppendExpr(entry.value()); Append(adorner_.AdornStructField(entry)); } AppendLine(); Unindent(); Append("}"); } void AppendMap(const MapExpr& obj) { if (obj.entries().empty()) { Append("{}"); return; } Append("{"); AppendLine(); Indent(); for (int i = 0; i < obj.entries().size(); ++i) { const auto& entry = 
obj.entries()[i]; if (i > 0) { Append(","); AppendLine(); } if (entry.optional()) { Append("?"); } AppendExpr(entry.key()); Append(":"); AppendExpr(entry.value()); Append(adorner_.AdornMapEntry(entry)); } AppendLine(); Unindent(); Append("}"); } void AppendComprehension(const ComprehensionExpr& comprehension) { Append("__comprehension__("); Indent(); AppendLine(); Append("// Variable"); AppendLine(); Append(comprehension.iter_var()); Append(","); AppendLine(); Append("// Target"); AppendLine(); AppendExpr(comprehension.iter_range()); Append(","); AppendLine(); Append("// Accumulator"); AppendLine(); Append(comprehension.accu_var()); Append(","); AppendLine(); Append("// Init"); AppendLine(); AppendExpr(comprehension.accu_init()); Append(","); AppendLine(); Append("// LoopCondition"); AppendLine(); AppendExpr(comprehension.loop_condition()); Append(","); AppendLine(); Append("// LoopStep"); AppendLine(); AppendExpr(comprehension.loop_step()); Append(","); AppendLine(); Append("// Result"); AppendLine(); AppendExpr(comprehension.result()); Append(")"); Unindent(); } void Append(const std::string& s) { if (line_start_) { line_start_ = false; for (int i = 0; i < indent_; ++i) { s_ += " "; } } s_ += s; } void AppendLine() { s_ += "\n"; line_start_ = true; } void Indent() { ++indent_; } void Unindent() { if (indent_ >= 0) { --indent_; } else { ABSL_LOG(ERROR) << "ExprPrinter indent underflow"; } } std::string FormatLiteral(const Constant& c) { switch (c.kind_case()) { case ConstantKindCase::kBool: return absl::StrFormat("%s", c.bool_value() ? "true" : "false"); case ConstantKindCase::kBytes: return cel::internal::FormatDoubleQuotedBytesLiteral(c.bytes_value()); case ConstantKindCase::kDouble: { std::string s = absl::StrFormat("%f", c.double_value()); // remove trailing zeros, i.e., convert 1.600000 to just 1.6 without // forcing a specific precision. There seems to be no flag to get this // directly from absl::StrFormat. 
auto idx = std::find_if_not(s.rbegin(), s.rend(), [](const char c) { return c == '0'; }); s.erase(idx.base(), s.end()); if (absl::EndsWith(s, ".")) { s += '0'; } return s; } case ConstantKindCase::kInt: return absl::StrFormat("%d", c.int_value()); case ConstantKindCase::kString: return cel::internal::FormatDoubleQuotedStringLiteral(c.string_value()); case ConstantKindCase::kUint: return absl::StrFormat("%uu", c.uint_value()); case ConstantKindCase::kNull: return "null"; default: return "<>"; } } std::string s_; const ExpressionAdorner& adorner_; bool line_start_; int indent_; }; } // namespace const ExpressionAdorner& EmptyAdorner() { static absl::NoDestructor kInstance; return *kInstance; } std::string ExprPrinter::PrintProto(const cel::expr::Expr& expr) const { StringBuilder w(adorner_); absl::StatusOr> ast = CreateAstFromParsedExpr(expr); if (!ast.ok()) { return std::string(ast.status().message()); } return w.Print(ast.value()->root_expr()); } std::string ExprPrinter::Print(const Expr& expr) const { StringBuilder w(adorner_); return w.Print(expr); } } // namespace cel::test ================================================ FILE: testutil/expr_printer.h ================================================ // Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef THIRD_PARTY_CEL_CPP_TESTUTIL_EXPR_PRINTER_H_ #define THIRD_PARTY_CEL_CPP_TESTUTIL_EXPR_PRINTER_H_ #include #include "cel/expr/syntax.pb.h" #include "common/expr.h" namespace cel::test { // Interface for adding additional information to an expression during // printing. class ExpressionAdorner { public: virtual ~ExpressionAdorner() = default; virtual std::string Adorn(const Expr& e) const = 0; virtual std::string AdornStructField(const StructExprField& e) const = 0; virtual std::string AdornMapEntry(const MapExprEntry& e) const = 0; }; // Default implementation of the ExpressionAdorner which does nothing. const ExpressionAdorner& EmptyAdorner(); // Helper class for printing an expression AST to a human readable, but detailed // and consistently formatted string. // // Note: this implementation is recursive and is not suitable for printing // arbitrarily large expressions. class ExprPrinter { public: ExprPrinter() : adorner_(EmptyAdorner()) {} explicit ExprPrinter(const ExpressionAdorner& adorner) : adorner_(adorner) {} std::string PrintProto(const cel::expr::Expr& expr) const; std::string Print(const Expr& expr) const; private: const ExpressionAdorner& adorner_; }; } // namespace cel::test #endif // THIRD_PARTY_CEL_CPP_TESTUTIL_EXPR_PRINTER_H_ ================================================ FILE: testutil/expr_printer_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "testutil/expr_printer.h"

#include

#include "absl/base/no_destructor.h"
#include "absl/strings/str_cat.h"
#include "common/expr.h"
#include "internal/testing.h"
#include "parser/options.h"
#include "parser/parser.h"

namespace cel::test {
namespace {

using ::google::api::expr::parser::Parse;

// NOTE(review): the multi-line expected strings below appear to have had their
// newlines/indentation collapsed during extraction; confirm against upstream.

// Adorner that tags every expression, struct field and map entry with
// "#<id>", so the expected strings show exactly where each annotation lands.
class TestAdorner : public ExpressionAdorner {
 public:
  // Shared instance, constructed on first use and never destroyed.
  static const TestAdorner& Get() {
    static absl::NoDestructor kInstance;
    return *kInstance;
  }

  std::string Adorn(const Expr& e) const override {
    return absl::StrCat("#", e.id());
  }

  std::string AdornStructField(const StructExprField& e) const override {
    return absl::StrCat("#", e.id());
  }

  std::string AdornMapEntry(const MapExprEntry& e) const override {
    return absl::StrCat("#", e.id());
  }
};

TEST(ExprPrinterTest, Identifier) {
  Expr expr;
  expr.mutable_ident_expr().set_name("foo");
  expr.set_id(1);

  ExprPrinter printer(TestAdorner::Get());
  EXPECT_EQ(printer.Print(expr), ("foo#1"));
}

TEST(ExprPrinterTest, ConstantString) {
  Expr expr;
  expr.mutable_const_expr().set_string_value("foo");
  expr.set_id(1);

  ExprPrinter printer(TestAdorner::Get());
  EXPECT_EQ(printer.Print(expr), (R"("foo"#1)"));
}

TEST(ExprPrinterTest, ConstantBytes) {
  Expr expr;
  expr.mutable_const_expr().set_bytes_value("foo");
  expr.set_id(1);

  ExprPrinter printer(TestAdorner::Get());
  EXPECT_EQ(printer.Print(expr), (R"(b"foo"#1)"));
}

TEST(ExprPrinterTest, ConstantInt) {
  Expr expr;
  expr.mutable_const_expr().set_int_value(1);
  expr.set_id(1);

  ExprPrinter printer(TestAdorner::Get());
  EXPECT_EQ(printer.Print(expr), (R"(1#1)"));
}

// Unsigned literals carry a "u" suffix.
TEST(ExprPrinterTest, ConstantUint) {
  Expr expr;
  expr.mutable_const_expr().set_uint_value(1);
  expr.set_id(1);

  ExprPrinter printer(TestAdorner::Get());
  EXPECT_EQ(printer.Print(expr), (R"(1u#1)"));
}

// Doubles are printed without trailing zeros.
TEST(ExprPrinterTest, ConstantDouble) {
  Expr expr;
  expr.mutable_const_expr().set_double_value(1.1);
  expr.set_id(1);

  ExprPrinter printer(TestAdorner::Get());
  EXPECT_EQ(printer.Print(expr), (R"(1.1#1)"));
}

TEST(ExprPrinterTest, ConstantBool) {
  Expr expr;
  expr.mutable_const_expr().set_bool_value(true);
  expr.set_id(1);

  ExprPrinter printer(TestAdorner::Get());
  EXPECT_EQ(printer.Print(expr), (R"(true#1)"));
}

TEST(ExprPrinterTest, Call) {
  Expr expr;
  expr.mutable_call_expr().set_function("foo");
  expr.set_id(1);
  {
    Expr& arg1 = expr.mutable_call_expr().add_args();
    arg1.mutable_const_expr().set_int_value(1);
    arg1.set_id(2);
  }
  {
    Expr& arg2 = expr.mutable_call_expr().add_args();
    arg2.mutable_const_expr().set_int_value(2);
    arg2.set_id(3);
  }

  ExprPrinter printer(TestAdorner::Get());
  EXPECT_EQ(printer.Print(expr), (R"(foo( 1#2, 2#3 )#1)"));
}

// The receiver is printed before the function name, separated by ".".
TEST(ExprPrinterTest, ReceiverCall) {
  Expr expr;
  expr.mutable_call_expr().set_function("foo");
  expr.set_id(1);
  {
    Expr& target = expr.mutable_call_expr().mutable_target();
    target.mutable_const_expr().set_string_value("bar");
    target.set_id(2);
  }
  {
    Expr& arg2 = expr.mutable_call_expr().add_args();
    arg2.mutable_const_expr().set_int_value(2);
    arg2.set_id(3);
  }

  ExprPrinter printer(TestAdorner::Get());
  EXPECT_EQ(printer.Print(expr), (R"("bar"#2.foo( 2#3 )#1)"));
}

// Optional list elements are prefixed with "?".
TEST(ExprPrinterTest, List) {
  Expr expr;
  expr.set_id(1);
  {
    ListExprElement& arg1 = expr.mutable_list_expr().add_elements();
    arg1.set_optional(true);
    arg1.mutable_expr().set_id(2);
    arg1.mutable_expr().mutable_const_expr().set_int_value(1);
  }
  {
    ListExprElement& arg2 = expr.mutable_list_expr().add_elements();
    arg2.set_optional(false);
    arg2.mutable_expr().set_id(3);
    arg2.mutable_expr().mutable_const_expr().set_int_value(2);
  }

  ExprPrinter printer(TestAdorner::Get());
  EXPECT_EQ(printer.Print(expr), (R"([ ?1#2, 2#3 ]#1)"));
}

// Each map entry's own adornment (#2, #5) follows its value's adornment.
TEST(ExprPrinterTest, Map) {
  Expr expr;
  expr.set_id(1);
  {
    MapExprEntry& entry = expr.mutable_map_expr().add_entries();
    entry.set_id(2);
    entry.set_optional(true);
    entry.mutable_key().set_id(3);
    entry.mutable_key().mutable_const_expr().set_string_value("k1");
    entry.mutable_value().set_id(4);
    entry.mutable_value().mutable_const_expr().set_string_value("v1");
  }
  {
    MapExprEntry& entry = expr.mutable_map_expr().add_entries();
    entry.set_id(5);
    entry.set_optional(false);
    entry.mutable_key().set_id(6);
    entry.mutable_key().mutable_const_expr().set_string_value("k2");
    entry.mutable_value().set_id(7);
    entry.mutable_value().mutable_const_expr().set_string_value("v2");
  }

  ExprPrinter printer(TestAdorner::Get());
  EXPECT_EQ(printer.Print(expr),
            (R"({ ?"k1"#3:"v1"#4#2, "k2"#6:"v2"#7#5 }#1)"));
}

// Each struct field's own adornment (#2, #4) follows its value's adornment.
TEST(ExprPrinterTest, Struct) {
  Expr expr;
  expr.set_id(1);
  auto& struct_expr = expr.mutable_struct_expr();
  struct_expr.set_name("Foo");
  {
    StructExprField& field1 = struct_expr.add_fields();
    field1.set_optional(true);
    field1.set_id(2);
    field1.set_name("field1");
    field1.mutable_value().set_id(3);
    field1.mutable_value().mutable_const_expr().set_int_value(1);
  }
  {
    StructExprField& field2 = struct_expr.add_fields();
    field2.set_optional(false);
    field2.set_id(4);
    field2.set_name("field2");
    field2.mutable_value().set_id(5);
    field2.mutable_value().mutable_const_expr().set_int_value(1);
  }

  ExprPrinter printer(TestAdorner::Get());
  EXPECT_EQ(printer.Print(expr), (R"(Foo{ ?field1:1#3#2, field2:1#5#4 }#1)"));
}

TEST(ExprPrinterTest, Comprehension) {
  Expr expr;
  expr.set_id(1);
  expr.mutable_comprehension_expr().set_iter_var("x");
  expr.mutable_comprehension_expr().set_accu_var("@result");
  auto& range = expr.mutable_comprehension_expr().mutable_iter_range();
  range.set_id(2);
  range.mutable_ident_expr().set_name("range");
  auto& accu_init = expr.mutable_comprehension_expr().mutable_accu_init();
  accu_init.set_id(3);
  accu_init.mutable_ident_expr().set_name("accu_init");
  auto& loop_condition =
      expr.mutable_comprehension_expr().mutable_loop_condition();
  loop_condition.set_id(4);
  loop_condition.mutable_ident_expr().set_name("loop_condition");
  auto& loop_step = expr.mutable_comprehension_expr().mutable_loop_step();
  loop_step.set_id(5);
  loop_step.mutable_ident_expr().set_name("loop_step");
  auto& result = expr.mutable_comprehension_expr().mutable_result();
  result.set_id(6);
  result.mutable_ident_expr().set_name("result");

  ExprPrinter printer(TestAdorner::Get());
  EXPECT_EQ(printer.Print(expr), R"(__comprehension__( // Variable x, // Target range#2, // Accumulator @result, // Init accu_init#3, // LoopCondition loop_condition#4, // LoopStep loop_step#5, // Result result#6)#1)");
}

// End-to-end: parse source text, then print the parsed proto via PrintProto.
TEST(ExprPrinterTest, Proto) {
  ParserOptions options;
  options.enable_optional_syntax = true;
  options.enable_hidden_accumulator_var = true;
  ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse(R"cel( "foo".startsWith("bar") || [1, ?2, 3].exists(x, x in {?"b": "foo"}) || Foo{ byte_value: b'bytes', bool_value: false, uint_value: 1u, double_value: 1.1, }.bar )cel", "", options));

  ExprPrinter printer(TestAdorner::Get());
  EXPECT_EQ(printer.PrintProto(parsed_expr.expr()), R"ast(_||_( _||_( "foo"#1.startsWith( "bar"#3 )#2, __comprehension__( // Variable x, // Target [ 1#5, ?2#6, 3#7 ]#4, // Accumulator @result, // Init false#16, // LoopCondition @not_strictly_false( !_( @result#17 )#18 )#19, // LoopStep _||_( @result#20, @in( x#10, { ?"b"#14:"foo"#15#13 }#12 )#11 )#21, // Result @result#22)#23 )#24, Foo{ byte_value:b"bytes"#27#26, bool_value:false#29#28, uint_value:1u#31#30, double_value:1.1#33#32 }#25.bar#34 )#35)ast");
}

}  // namespace
}  // namespace cel::test
================================================
FILE: testutil/test_json_names.proto
================================================
edition = "2024";

package cel.cpp.testutil;

option features.enforce_naming_style = STYLE_LEGACY;

// This proto tests json_name options
message TestJsonNames {
  int32 int32_snake_case_json_name = 1
      [json_name = "int32_snake_case_json_name"];
  int64 int64_camel_case_json_name = 2 [json_name = "int64CamelCaseJsonName"];
  // No json_name option: uses the default JSON name.
  uint32 uint32_default_json_name = 3;
  uint64 uint64_custom_json_name = 4 [json_name = "uint64-custom-json-name"];
  // Collides with normal field name.
  // (Its json_name equals the proto field name of `single_string` below.)
  string string_json_name_shadows = 5 [json_name = "single_string"];
  string single_string = 6;

  // protoc should fail on cases like these
  // double double_json_shadow_default = 7 [json_name = "doubleJsonDefault"]
  // double double_json_default = 8;
  // double double_json_swapped_a = 7 [json_name = "double_json_swapped_b"];
  // double double_json_swapped_b = 8 [json_name = "double_json_swapped_a"];

  extensions 100 to 199;
}

extend TestJsonNames {
  int32 int32_snake_case_ext = 100;
  int64 int64CamelCaseExt = 101;
}
================================================
FILE: testutil/util.h
================================================
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_TESTUTIL_UTIL_H_
#define THIRD_PARTY_CEL_CPP_TESTUTIL_UTIL_H_

#include "internal/proto_matchers.h"

namespace google::api::expr::testutil {

// alias for old namespace
// prefer using cel::internal::test::EqualsProto.
using ::cel::internal::test::EqualsProto;

}  // namespace google::api::expr::testutil

#endif  // THIRD_PARTY_CEL_CPP_TESTUTIL_UTIL_H_
================================================
FILE: tools/BUILD
================================================
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

# All targets in this package are publicly visible.
package(default_visibility = ["//visibility:public"])

licenses(["notice"])

cc_library(
    name = "cel_field_extractor",
    srcs = ["cel_field_extractor.cc"],
    hdrs = ["cel_field_extractor.h"],
    deps = [
        ":navigable_ast",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
    ],
)

cc_test(
    name = "cel_field_extractor_test",
    srcs = ["cel_field_extractor_test.cc"],
    deps = [
        ":cel_field_extractor",
        "//internal:testing",
        "//parser",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status:statusor",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
    ],
)

cc_library(
    name = "cel_unparser",
    srcs = [
        "cel_unparser.cc",
    ],
    hdrs = [
        "cel_unparser.h",
    ],
    deps = [
        "//common:operators",
        "//internal:status_macros",
        "//internal:strings",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:duration_cc_proto",
        "@com_google_protobuf//:timestamp_cc_proto",
        "@com_googlesource_code_re2//:re2",
    ],
)

cc_test(
    name = "cel_unparser_test",
    srcs = ["cel_unparser_test.cc"],
    deps = [
        ":cel_unparser",
        "//internal:proto_matchers",
        "//internal:testing",
        "//parser",
        "//parser:options",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "flatbuffers_backed_impl",
    srcs = [
        "flatbuffers_backed_impl.cc",
    ],
    hdrs = [
        "flatbuffers_backed_impl.h",
    ],
    deps = [
        "//eval/public:cel_value",
        "@com_github_google_flatbuffers//:flatbuffers",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)

# Needs the generated flatbuffers reflection output at runtime (data dep).
cc_test(
    name = "flatbuffers_backed_impl_test",
    size = "small",
    srcs = [
        "flatbuffers_backed_impl_test.cc",
    ],
    data = [
        "//tools/testdata:flatbuffers_reflection_out",
    ],
    deps = [
        ":flatbuffers_backed_impl",
        "//internal:status_macros",
        "//internal:testing",
        "@com_github_google_flatbuffers//:flatbuffers",
    ],
)

cc_library(
    name = "navigable_ast",
    srcs = ["navigable_ast.cc"],
    hdrs = ["navigable_ast.h"],
    deps = [
        "//common/ast:navigable_ast_internal",
        "//eval/public:ast_traverse",
        "//eval/public:ast_visitor",
        "//eval/public:ast_visitor_base",
        "//eval/public:source_position",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/memory",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
    ],
)

cc_test(
    name = "navigable_ast_test",
    srcs = ["navigable_ast_test.cc"],
    deps = [
        ":navigable_ast",
        "//base:builtins",
        "//internal:testing",
        "//parser",
        "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
    ],
)

# Records per-node evaluation results for a checked expression (true / false /
# error outcomes for boolean nodes) to report branch coverage.
cc_library(
    name = "branch_coverage",
    srcs = ["branch_coverage.cc"],
    hdrs = ["branch_coverage.h"],
    deps = [
        ":navigable_ast",
        "//common:value",
        "//eval/internal:interop",
        "//eval/public:cel_value",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/types:variant",
        "@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)

# Needs the coverage test data files at runtime (data dep).
cc_test(
    name = "branch_coverage_test",
    srcs = ["branch_coverage_test.cc"],
    data = [
        "//tools/testdata:coverage_testdata",
    ],
    deps = [
        ":branch_coverage",
        ":navigable_ast",
        "//base:builtins",
        "//common:value",
        "//eval/public:activation",
        "//eval/public:builtin_func_registrar",
        "//eval/public:cel_expr_builder_factory",
        "//eval/public:cel_expression",
        "//eval/public:cel_value",
        "//internal:proto_file_util",
        "//internal:testing",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_library(
    name = "descriptor_pool_builder",
    srcs = ["descriptor_pool_builder.cc"],
    hdrs = ["descriptor_pool_builder.h"],
    deps = [
        "//common:minimal_descriptor_database",
        "//internal:status_macros",
        "@com_google_absl//absl/base:nullability",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@com_google_protobuf//:protobuf",
    ],
)

cc_test(
    name = "descriptor_pool_builder_test",
    srcs = ["descriptor_pool_builder_test.cc"],
    deps = [
        ":descriptor_pool_builder",
        "//internal:testing",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto",
        "@com_google_protobuf//:protobuf",
    ],
)
================================================
FILE: tools/branch_coverage.cc
================================================
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "tools/branch_coverage.h" #include #include #include "cel/expr/checked.pb.h" #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/functional/overload.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "absl/types/variant.h" #include "common/value.h" #include "eval/internal/interop.h" #include "eval/public/cel_value.h" #include "tools/navigable_ast.h" #include "google/protobuf/arena.h" namespace cel { namespace { using ::cel::expr::CheckedExpr; using ::cel::expr::Type; using ::google::api::expr::runtime::CelValue; const absl::Status& UnsupportedConversionError() { static absl::NoDestructor kErr( absl::StatusCode::kInternal, "Conversion to legacy type unsupported."); return *kErr; } // Constant literal. // // These should be handled separately from variable parts of the AST to not // inflate / deflate coverage wrt variable inputs. struct ConstantNode {}; // A boolean node. // // Branching in CEL is mostly determined by boolean subexpression results, so // specify intercepted values. struct BoolNode { int result_true; int result_false; int result_error; }; // Catch all for other nodes. struct OtherNode { int result_error; }; // Representation for coverage of an AST node. 
struct CoverageNode { int evaluate_count; absl::variant kind; }; const Type* absl_nullable FindCheckerType(const CheckedExpr& expr, int64_t expr_id) { if (auto it = expr.type_map().find(expr_id); it != expr.type_map().end()) { return &it->second; } return nullptr; } class BranchCoverageImpl : public BranchCoverage { public: explicit BranchCoverageImpl(const CheckedExpr& expr) : expr_(expr) {} // Implement public interface. void Record(int64_t expr_id, const Value& value) override { auto value_or = interop_internal::ToLegacyValue(&arena_, value); if (!value_or.ok()) { // TODO(uncreated-issue/65): Use pointer identity for UnsupportedConversionError // as a sentinel value. The legacy CEL value just wraps the error pointer. // This can be removed after the value migration is complete. RecordImpl(expr_id, CelValue::CreateError(&UnsupportedConversionError())); } else { return RecordImpl(expr_id, *value_or); } } void RecordLegacyValue(int64_t expr_id, const CelValue& value) override { return RecordImpl(expr_id, value); } BranchCoverage::NodeCoverageStats StatsForNode( int64_t expr_id) const override; const NavigableProtoAst& ast() const override; const CheckedExpr& expr() const override; // Initializes the coverage implementation. This should be called by the // factory function (synchronously). // // Other mutation operations must be synchronized since we don't have control // of when the instrumented expressions get called. void Init(); private: friend class BranchCoverage; void RecordImpl(int64_t expr_id, const CelValue& value); // Infer it the node is boolean typed. Check the type map if available. // Otherwise infer typing based on built-in functions. 
bool InferredBoolType(const NavigableProtoAstNode& node) const; CheckedExpr expr_; NavigableProtoAst ast_; mutable absl::Mutex coverage_nodes_mu_; absl::flat_hash_map coverage_nodes_ ABSL_GUARDED_BY(coverage_nodes_mu_); absl::flat_hash_set unexpected_expr_ids_ ABSL_GUARDED_BY(coverage_nodes_mu_); google::protobuf::Arena arena_; }; BranchCoverage::NodeCoverageStats BranchCoverageImpl::StatsForNode( int64_t expr_id) const { BranchCoverage::NodeCoverageStats stats{ /*is_boolean=*/false, /*evaluation_count=*/0, /*error_count=*/0, /*boolean_true_count=*/0, /*boolean_false_count=*/0, }; absl::MutexLock lock(coverage_nodes_mu_); auto it = coverage_nodes_.find(expr_id); if (it != coverage_nodes_.end()) { const CoverageNode& coverage_node = it->second; stats.evaluation_count = coverage_node.evaluate_count; absl::visit(absl::Overload([&](const ConstantNode& cov) {}, [&](const OtherNode& cov) { stats.error_count = cov.result_error; }, [&](const BoolNode& cov) { stats.is_boolean = true; stats.boolean_true_count = cov.result_true; stats.boolean_false_count = cov.result_false; stats.error_count = cov.result_error; }), coverage_node.kind); return stats; } return stats; } const NavigableProtoAst& BranchCoverageImpl::ast() const { return ast_; } const CheckedExpr& BranchCoverageImpl::expr() const { return expr_; } bool BranchCoverageImpl::InferredBoolType( const NavigableProtoAstNode& node) const { int64_t expr_id = node.expr()->id(); const auto* checker_type = FindCheckerType(expr_, expr_id); if (checker_type != nullptr) { return checker_type->has_primitive() && checker_type->primitive() == Type::BOOL; } return false; } void BranchCoverageImpl::Init() ABSL_NO_THREAD_SAFETY_ANALYSIS { ast_ = NavigableProtoAst::Build(expr_.expr()); for (const NavigableProtoAstNode& node : ast_.Root().DescendantsPreorder()) { int64_t expr_id = node.expr()->id(); CoverageNode& coverage_node = coverage_nodes_[expr_id]; coverage_node.evaluate_count = 0; if (node.node_kind() == NodeKind::kConstant) { 
coverage_node.kind = ConstantNode{}; } else if (InferredBoolType(node)) { coverage_node.kind = BoolNode{0, 0, 0}; } else { coverage_node.kind = OtherNode{0}; } } } void BranchCoverageImpl::RecordImpl(int64_t expr_id, const CelValue& value) { absl::MutexLock lock(coverage_nodes_mu_); auto it = coverage_nodes_.find(expr_id); if (it == coverage_nodes_.end()) { unexpected_expr_ids_.insert(expr_id); it = coverage_nodes_.insert({expr_id, CoverageNode{0, {}}}).first; if (value.IsBool()) { it->second.kind = BoolNode{0, 0, 0}; } } CoverageNode& coverage_node = it->second; coverage_node.evaluate_count++; bool is_error = value.IsError() && // Filter conversion errors for evaluator internal types. // TODO(uncreated-issue/65): RecordImpl operates on legacy values so // special case conversion errors. This error is really just a // sentinel value and doesn't need to round-trip between // legacy and legacy types. value.ErrorOrDie() != &UnsupportedConversionError(); absl::visit(absl::Overload([&](ConstantNode& node) {}, [&](OtherNode& cov) { if (is_error) { cov.result_error++; } }, [&](BoolNode& cov) { if (value.IsBool()) { bool held_value = value.BoolOrDie(); if (held_value) { cov.result_true++; } else { cov.result_false++; } } else if (is_error) { cov.result_error++; } }), coverage_node.kind); } } // namespace std::unique_ptr CreateBranchCoverage(const CheckedExpr& expr) { auto result = std::make_unique(expr); result->Init(); return result; } } // namespace cel ================================================ FILE: tools/branch_coverage.h ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_TOOLS_BRANCH_COVERAGE_H_ #define THIRD_PARTY_CEL_CPP_TOOLS_BRANCH_COVERAGE_H_ #include #include #include "cel/expr/checked.pb.h" #include "absl/base/attributes.h" #include "common/value.h" #include "eval/public/cel_value.h" #include "tools/navigable_ast.h" namespace cel { // Interface for BranchCoverage collection utility. // // This provides a factory for instrumentation that collects coverage // information over multiple executions of a CEL expression. This does not // provide any mechanism for de-duplicating multiple CheckedExpr instances // that represent the same expression within or across processes. // // The default implementation is thread safe. // // TODO(uncreated-issue/65): add support for interesting aggregate stats. 
class BranchCoverage { public: struct NodeCoverageStats { bool is_boolean; int evaluation_count; int boolean_true_count; int boolean_false_count; int error_count; }; virtual ~BranchCoverage() = default; virtual void Record(int64_t expr_id, const Value& value) = 0; virtual void RecordLegacyValue( int64_t expr_id, const google::api::expr::runtime::CelValue& value) = 0; virtual NodeCoverageStats StatsForNode(int64_t expr_id) const = 0; virtual const NavigableProtoAst& ast() const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; virtual const cel::expr::CheckedExpr& expr() const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0; }; std::unique_ptr CreateBranchCoverage( const cel::expr::CheckedExpr& expr); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_TOOLS_BRANCH_COVERAGE_H_ ================================================ FILE: tools/branch_coverage_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "tools/branch_coverage.h"

#include <cstdint>
#include <string>

#include "absl/base/no_destructor.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "base/builtins.h"
#include "common/value.h"
#include "eval/public/activation.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_expr_builder_factory.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_value.h"
#include "internal/proto_file_util.h"
#include "internal/testing.h"
#include "tools/navigable_ast.h"
#include "google/protobuf/arena.h"

namespace cel {
namespace {

using ::cel::internal::test::ReadTextProtoFromFile;
using ::cel::expr::CheckedExpr;
using ::google::api::expr::runtime::Activation;
using ::google::api::expr::runtime::CelValue;
using ::google::api::expr::runtime::CreateCelExpressionBuilder;
using ::google::api::expr::runtime::RegisterBuiltinFunctions;

// int1 < int2 &&
// (43 > 42) &&
// !(bool1 || bool2) &&
// 4 / int_divisor >= 1 &&
// (ternary_c ? ternary_t : ternary_f)
constexpr char kCoverageExamplePath[] =
    "tools/testdata/coverage_example.textproto";

const CheckedExpr& TestExpression() {
  static absl::NoDestructor<CheckedExpr> expression([]() {
    CheckedExpr value;
    ABSL_CHECK_OK(ReadTextProtoFromFile(kCoverageExamplePath, value));
    return value;
  }());
  return *expression;
}

std::string FormatNodeStats(const BranchCoverage::NodeCoverageStats& stats) {
  return absl::Substitute(
      "is_bool: $0; evaluated: $1; bool_true: $2; bool_false: $3; error: $4",
      stats.is_boolean, stats.evaluation_count, stats.boolean_true_count,
      stats.boolean_false_count, stats.error_count);
}

google::api::expr::runtime::CelEvaluationListener EvaluationListenerForCoverage(
    BranchCoverage* coverage) {
  return [coverage](int64_t id, const CelValue& value,
                    google::protobuf::Arena* /*arena*/) {
    coverage->RecordLegacyValue(id, value);
    return absl::OkStatus();
  };
}

// Returns the first node (in preorder) that is a call to `function`, or
// nullptr if there is none. Always returning a defined value keeps the
// callers' ASSERT_NE checks from reading an uninitialized pointer.
const NavigableProtoAstNode* FindFirstCall(const NavigableProtoAst& ast,
                                           absl::string_view function) {
  for (const auto& node : ast.Root().DescendantsPreorder()) {
    if (node.node_kind() == NodeKind::kCall &&
        node.expr()->call_expr().function() == function) {
      return &node;
    }
  }
  return nullptr;
}

MATCHER_P(MatchesNodeStats, expected, "") {
  const BranchCoverage::NodeCoverageStats& actual = arg;

  *result_listener << "\n";
  *result_listener << "Expected: " << FormatNodeStats(expected);
  *result_listener << "\n";
  *result_listener << "Got: " << FormatNodeStats(actual);

  return actual.is_boolean == expected.is_boolean &&
         actual.evaluation_count == expected.evaluation_count &&
         actual.boolean_true_count == expected.boolean_true_count &&
         actual.boolean_false_count == expected.boolean_false_count &&
         actual.error_count == expected.error_count;
}

MATCHER(NodeStatsIsBool, "") {
  const BranchCoverage::NodeCoverageStats& actual = arg;

  *result_listener << "\n";
  *result_listener << "Expected: " << FormatNodeStats({true, 0, 0, 0, 0});
  *result_listener << "\n";
  *result_listener << "Got: " << FormatNodeStats(actual);

  return actual.is_boolean;
}

TEST(BranchCoverage, DefaultsForUntrackedId) {
  auto coverage = CreateBranchCoverage(TestExpression());

  using Stats = BranchCoverage::NodeCoverageStats;

  EXPECT_THAT(coverage->StatsForNode(99),
              MatchesNodeStats(Stats{/*is_boolean=*/false,
                                     /*evaluation_count=*/0,
                                     /*boolean_true_count=*/0,
                                     /*boolean_false_count=*/0,
                                     /*error_count=*/0}));
}

TEST(BranchCoverage, Record) {
  auto coverage = CreateBranchCoverage(TestExpression());

  int64_t root_id = coverage->expr().expr().id();

  coverage->Record(root_id, cel::BoolValue(false));

  using Stats = BranchCoverage::NodeCoverageStats;

  EXPECT_THAT(coverage->StatsForNode(root_id),
              MatchesNodeStats(Stats{/*is_boolean=*/true,
                                     /*evaluation_count=*/1,
                                     /*boolean_true_count=*/0,
                                     /*boolean_false_count=*/1,
                                     /*error_count=*/0}));
}

TEST(BranchCoverage, RecordUnexpectedId) {
  auto coverage = CreateBranchCoverage(TestExpression());

  int64_t unexpected_id = 99;

  coverage->Record(unexpected_id, cel::BoolValue(false));

  using Stats = BranchCoverage::NodeCoverageStats;

  EXPECT_THAT(coverage->StatsForNode(unexpected_id),
              MatchesNodeStats(Stats{/*is_boolean=*/true,
                                     /*evaluation_count=*/1,
                                     /*boolean_true_count=*/0,
                                     /*boolean_false_count=*/1,
                                     /*error_count=*/0}));
}

TEST(BranchCoverage, IncrementsCounters) {
  auto coverage = CreateBranchCoverage(TestExpression());

  EXPECT_TRUE(static_cast<bool>(coverage->ast()));

  auto builder = CreateCelExpressionBuilder();

  ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry()));

  // int1 < int2 &&
  // (43 > 42) &&
  // !(bool1 || bool2) &&
  // 4 / int_divisor >= 1 &&
  // (ternary_c ? ternary_t : ternary_f)
  ASSERT_OK_AND_ASSIGN(auto program,
                       builder->CreateExpression(&TestExpression()));

  google::protobuf::Arena arena;
  Activation activation;
  activation.InsertValue("bool1", CelValue::CreateBool(false));
  activation.InsertValue("bool2", CelValue::CreateBool(false));

  activation.InsertValue("int1", CelValue::CreateInt64(42));
  activation.InsertValue("int2", CelValue::CreateInt64(43));

  activation.InsertValue("int_divisor", CelValue::CreateInt64(4));

  activation.InsertValue("ternary_c", CelValue::CreateBool(true));
  activation.InsertValue("ternary_t", CelValue::CreateBool(true));
  activation.InsertValue("ternary_f", CelValue::CreateBool(false));

  ASSERT_OK_AND_ASSIGN(
      auto result,
      program->Trace(activation, &arena,
                     EvaluationListenerForCoverage(coverage.get())));

  EXPECT_TRUE(result.IsBool() && result.BoolOrDie() == true);

  using Stats = BranchCoverage::NodeCoverageStats;
  const NavigableProtoAst& ast = coverage->ast();
  auto root_node_stats = coverage->StatsForNode(ast.Root().expr()->id());

  EXPECT_THAT(root_node_stats, MatchesNodeStats(Stats{/*is_boolean=*/true,
                                                      /*evaluation_count=*/1,
                                                      /*boolean_true_count=*/1,
                                                      /*boolean_false_count=*/0,
                                                      /*error_count=*/0}));

  const NavigableProtoAstNode* ternary =
      FindFirstCall(ast, cel::builtin::kTernary);
  ASSERT_NE(ternary, nullptr);

  auto ternary_node_stats = coverage->StatsForNode(ternary->expr()->id());
  // Ternary gets optimized to conditional jumps, so it isn't instrumented
  // directly in stack machine impl.
  EXPECT_THAT(ternary_node_stats, NodeStatsIsBool());

  const auto* false_node = ternary->children().at(2);
  auto false_node_stats = coverage->StatsForNode(false_node->expr()->id());

  EXPECT_THAT(false_node_stats,
              MatchesNodeStats(Stats{/*is_boolean=*/true,
                                     /*evaluation_count=*/0,
                                     /*boolean_true_count=*/0,
                                     /*boolean_false_count=*/0,
                                     /*error_count=*/0}));

  const NavigableProtoAstNode* not_expr = FindFirstCall(ast, cel::builtin::kNot);
  ASSERT_NE(not_expr, nullptr);
  const NavigableProtoAstNode* not_arg_expr = not_expr->children().at(0);
  ASSERT_NE(not_arg_expr, nullptr);

  auto not_expr_node_stats = coverage->StatsForNode(not_arg_expr->expr()->id());

  EXPECT_THAT(not_expr_node_stats,
              MatchesNodeStats(Stats{/*is_boolean=*/true,
                                     /*evaluation_count=*/1,
                                     /*boolean_true_count=*/0,
                                     /*boolean_false_count=*/1,
                                     /*error_count=*/0}));

  const NavigableProtoAstNode* div_expr =
      FindFirstCall(ast, cel::builtin::kDivide);
  ASSERT_NE(div_expr, nullptr);

  auto div_expr_stats = coverage->StatsForNode(div_expr->expr()->id());

  EXPECT_THAT(div_expr_stats, MatchesNodeStats(Stats{/*is_boolean=*/false,
                                                     /*evaluation_count=*/1,
                                                     /*boolean_true_count=*/0,
                                                     /*boolean_false_count=*/0,
                                                     /*error_count=*/0}));
}

TEST(BranchCoverage, AccumulatesAcrossRuns) {
  auto coverage = CreateBranchCoverage(TestExpression());

  EXPECT_TRUE(static_cast<bool>(coverage->ast()));

  auto builder = CreateCelExpressionBuilder();

  ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry()));

  // int1 < int2 &&
  // (43 > 42) &&
  // !(bool1 || bool2) &&
  // 4 / int_divisor >= 1 &&
  // (ternary_c ? ternary_t : ternary_f)
  ASSERT_OK_AND_ASSIGN(auto program,
                       builder->CreateExpression(&TestExpression()));

  google::protobuf::Arena arena;
  Activation activation;
  activation.InsertValue("bool1", CelValue::CreateBool(false));
  activation.InsertValue("bool2", CelValue::CreateBool(false));

  activation.InsertValue("int1", CelValue::CreateInt64(42));
  activation.InsertValue("int2", CelValue::CreateInt64(43));

  activation.InsertValue("int_divisor", CelValue::CreateInt64(4));

  activation.InsertValue("ternary_c", CelValue::CreateBool(true));
  activation.InsertValue("ternary_t", CelValue::CreateBool(true));
  activation.InsertValue("ternary_f", CelValue::CreateBool(false));

  ASSERT_OK_AND_ASSIGN(
      auto result,
      program->Trace(activation, &arena,
                     EvaluationListenerForCoverage(coverage.get())));

  EXPECT_TRUE(result.IsBool() && result.BoolOrDie() == true);

  activation.RemoveValueEntry("ternary_c");
  activation.RemoveValueEntry("ternary_f");

  activation.InsertValue("ternary_c", CelValue::CreateBool(false));
  activation.InsertValue("ternary_f", CelValue::CreateBool(false));

  ASSERT_OK_AND_ASSIGN(
      result, program->Trace(activation, &arena,
                             EvaluationListenerForCoverage(coverage.get())));

  EXPECT_TRUE(result.IsBool() && result.BoolOrDie() == false)
      << result.DebugString();

  using Stats = BranchCoverage::NodeCoverageStats;
  const NavigableProtoAst& ast = coverage->ast();
  auto root_node_stats = coverage->StatsForNode(ast.Root().expr()->id());

  EXPECT_THAT(root_node_stats, MatchesNodeStats(Stats{/*is_boolean=*/true,
                                                      /*evaluation_count=*/2,
                                                      /*boolean_true_count=*/1,
                                                      /*boolean_false_count=*/1,
                                                      /*error_count=*/0}));

  const NavigableProtoAstNode* ternary =
      FindFirstCall(ast, cel::builtin::kTernary);
  ASSERT_NE(ternary, nullptr);

  auto ternary_node_stats = coverage->StatsForNode(ternary->expr()->id());
  // Ternary gets optimized into conditional jumps for stack machine plan.
  EXPECT_THAT(ternary_node_stats, NodeStatsIsBool());

  const auto* false_node = ternary->children().at(2);
  auto false_node_stats = coverage->StatsForNode(false_node->expr()->id());

  EXPECT_THAT(false_node_stats,
              MatchesNodeStats(Stats{/*is_boolean=*/true,
                                     /*evaluation_count=*/1,
                                     /*boolean_true_count=*/0,
                                     /*boolean_false_count=*/1,
                                     /*error_count=*/0}));
}

TEST(BranchCoverage, CountsErrors) {
  auto coverage = CreateBranchCoverage(TestExpression());

  EXPECT_TRUE(static_cast<bool>(coverage->ast()));

  auto builder = CreateCelExpressionBuilder();

  ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry()));

  // int1 < int2 &&
  // (43 > 42) &&
  // !(bool1 || bool2) &&
  // 4 / int_divisor >= 1 &&
  // (ternary_c ? ternary_t : ternary_f)
  ASSERT_OK_AND_ASSIGN(auto program,
                       builder->CreateExpression(&TestExpression()));

  google::protobuf::Arena arena;
  Activation activation;
  activation.InsertValue("bool1", CelValue::CreateBool(false));
  activation.InsertValue("bool2", CelValue::CreateBool(false));

  activation.InsertValue("int1", CelValue::CreateInt64(42));
  activation.InsertValue("int2", CelValue::CreateInt64(43));

  // Division by zero makes the `4 / int_divisor` node produce an error.
  activation.InsertValue("int_divisor", CelValue::CreateInt64(0));

  activation.InsertValue("ternary_c", CelValue::CreateBool(true));
  activation.InsertValue("ternary_t", CelValue::CreateBool(false));
  activation.InsertValue("ternary_f", CelValue::CreateBool(false));

  ASSERT_OK_AND_ASSIGN(
      auto result,
      program->Trace(activation, &arena,
                     EvaluationListenerForCoverage(coverage.get())));

  EXPECT_TRUE(result.IsBool() && result.BoolOrDie() == false);

  using Stats = BranchCoverage::NodeCoverageStats;
  const NavigableProtoAst& ast = coverage->ast();
  auto root_node_stats = coverage->StatsForNode(ast.Root().expr()->id());

  EXPECT_THAT(root_node_stats, MatchesNodeStats(Stats{/*is_boolean=*/true,
                                                      /*evaluation_count=*/1,
                                                      /*boolean_true_count=*/0,
                                                      /*boolean_false_count=*/1,
                                                      /*error_count=*/0}));

  const NavigableProtoAstNode* div_expr =
      FindFirstCall(ast, cel::builtin::kDivide);
  ASSERT_NE(div_expr, nullptr);

  auto div_expr_stats = coverage->StatsForNode(div_expr->expr()->id());

  EXPECT_THAT(div_expr_stats, MatchesNodeStats(Stats{/*is_boolean=*/false,
                                                     /*evaluation_count=*/1,
                                                     /*boolean_true_count=*/0,
                                                     /*boolean_false_count=*/0,
                                                     /*error_count=*/1}));
}

}  // namespace
}  // namespace cel

================================================
FILE: tools/cel_field_extractor.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tools/cel_field_extractor.h"

#include <algorithm>
#include <string>
#include <vector>

#include "cel/expr/syntax.pb.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tools/navigable_ast.h"

namespace cel {

namespace {

// Returns true if the identifier `node` is bound by an enclosing
// comprehension. That is the case when it names the comprehension's
// `iter_var`, `iter_var2` or `accu_var` and appears in a part of the
// comprehension where those variables are in scope.
//
// A comprehension's iteration range and accumulator initializer are evaluated
// in the enclosing scope. Identifiers reached through them are therefore not
// bound by that comprehension. For example, in `e.exists(e, e > 0)` the
// receiver `e` is a free variable.
bool IsComprehensionDefinedField(const cel::NavigableProtoAstNode& node) {
  const std::string& ident_name = node.expr()->ident_expr().name();
  const cel::NavigableProtoAstNode* child = &node;
  for (const cel::NavigableProtoAstNode* ancestor = node.parent();
       ancestor != nullptr; child = ancestor, ancestor = ancestor->parent()) {
    if (ancestor->node_kind() != cel::NodeKind::kComprehension) {
      continue;
    }
    const auto& comprehension = ancestor->expr()->comprehension_expr();
    // The navigable AST points into the same Expr tree, so the subexpression
    // pointers identify which comprehension operand `child` is.
    if (child->expr() == &comprehension.iter_range() ||
        child->expr() == &comprehension.accu_init()) {
      continue;
    }
    // iter_var2 is empty for single-variable comprehensions; don't let it
    // match an (invalid) empty identifier.
    const bool iter_var2_match = !comprehension.iter_var2().empty() &&
                                 ident_name == comprehension.iter_var2();
    if (ident_name == comprehension.iter_var() || iter_var2_match ||
        ident_name == comprehension.accu_var()) {
      return true;
    }
  }
  return false;
}

}  // namespace

absl::flat_hash_set<std::string> ExtractFieldPaths(
    const cel::expr::Expr& expr) {
  NavigableProtoAst ast = NavigableProtoAst::Build(expr);

  absl::flat_hash_set<std::string> field_paths;
  std::vector<std::string> fields_in_scope;

  // Preorder traversal works because the select nodes (in a well-formed
  // expression) always have only one operand, so its operand is visited
  // next in the loop iteration (which results in the path being extended,
  // completed, or discarded if uninteresting).
  for (const cel::NavigableProtoAstNode& node :
       ast.Root().DescendantsPreorder()) {
    if (node.node_kind() == cel::NodeKind::kSelect) {
      fields_in_scope.push_back(node.expr()->select_expr().field());
      continue;
    }
    if (node.node_kind() == cel::NodeKind::kIdent &&
        !IsComprehensionDefinedField(node)) {
      fields_in_scope.push_back(node.expr()->ident_expr().name());
      // Fields were collected outermost select first; the path reads from
      // the root identifier outwards.
      std::reverse(fields_in_scope.begin(), fields_in_scope.end());
      field_paths.insert(absl::StrJoin(fields_in_scope, "."));
    }
    fields_in_scope.clear();
  }

  return field_paths;
}

}  // namespace cel

================================================
FILE: tools/cel_field_extractor.h
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_TOOLS_CEL_FIELD_EXTRACTOR_H
#define THIRD_PARTY_CEL_CPP_TOOLS_CEL_FIELD_EXTRACTOR_H

#include <string>

#include "cel/expr/syntax.pb.h"
#include "absl/container/flat_hash_set.h"

namespace cel {

// ExtractFieldPaths attempts to extract the set of unique field selection
// paths rooted at top-level identifiers (e.g. "request.user.id").
//
// One possible use case for this function is to determine which fields of a
// serialized message are referenced by a CEL query, enabling partial
// deserialization for performance optimization.
//
// Implementation notes:
// The extraction logic focuses on identifying chains of `Select` operations
// that terminate with a primary identifier node (`IdentExpr`). For example,
// in the expression `message.field.subfield == 10`, the path
// "message.field.subfield" would be extracted.
//
// Identifiers defined locally within CEL comprehension expressions (e.g.,
// comprehension variable aliases defined by `iter_var`, `iter_var2`,
// `accu_var` in the AST) are NOT included. Example:
// `list.exists(elem, elem.field == 'value')` would return {"list"} only.
// Identifiers in a comprehension's iteration range or accumulator
// initializer are evaluated in the enclosing scope, so they are not treated
// as bound by that comprehension: `e.exists(e, e > 0)` returns {"e"}.
//
// Container indexing with the _[_] operator is not considered, but map
// indexing with the select operator is considered. For example:
// `message.map_field.key || message.map_field['foo']` results in
// {'message.map_field.key', 'message.map_field'}
//
// This implementation does not consider type check metadata, so there is no
// understanding of whether the primary identifiers and field accesses
// necessarily map to proto messages or proto field accesses. The extractor
// also does not have any understanding of the type of the leaf of the
// select path.
//
// Example:
// Given the CEL expression:
//   `(request.user.id == 'test' &&
//     request.user.attributes.exists(attr, attr.key == 'role')) ||
//    size(request.items) > 0`
//
// The extracted field paths would be:
// - "request.user.id"
// - "request.user.attributes" (because `attr` is a comprehension variable)
// - "request.items"
absl::flat_hash_set<std::string> ExtractFieldPaths(
    const cel::expr::Expr& expr);

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_TOOLS_CEL_FIELD_EXTRACTOR_H

================================================
FILE: tools/cel_field_extractor_test.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "tools/cel_field_extractor.h" #include #include "cel/expr/syntax.pb.h" #include "absl/container/flat_hash_set.h" #include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "internal/testing.h" #include "parser/parser.h" namespace cel { namespace { using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::testing::IsEmpty; using ::testing::UnorderedElementsAre; absl::flat_hash_set GetExtractedFields( const std::string& cel_query) { absl::StatusOr parsed_expr_or_status = Parse(cel_query); ABSL_CHECK_OK(parsed_expr_or_status); return ExtractFieldPaths(parsed_expr_or_status.value().expr()); } TEST(TestExtractFieldPaths, CelExprWithOneField) { EXPECT_THAT(GetExtractedFields("field_name"), UnorderedElementsAre("field_name")); } TEST(TestExtractFieldPaths, CelExprWithNoWithLiteral) { EXPECT_THAT(GetExtractedFields("'field_name'"), IsEmpty()); } TEST(TestExtractFieldPaths, CelExprWithFunctionCallOnSingleField) { EXPECT_THAT(GetExtractedFields("!boolean_field"), UnorderedElementsAre("boolean_field")); } TEST(TestExtractFieldPaths, CelExprWithSizeFuncCallOnSingleField) { EXPECT_THAT(GetExtractedFields("size(repeated_field)"), UnorderedElementsAre("repeated_field")); } TEST(TestExtractFieldPaths, CelExprWithNestedField) { EXPECT_THAT(GetExtractedFields("message_field.nested_field.nested_field2"), UnorderedElementsAre("message_field.nested_field.nested_field2")); } TEST(TestExtractFieldPaths, CelExprWithNestedFieldAndIndexAccess) { EXPECT_THAT(GetExtractedFields( "repeated_message_field.nested_field[0].nested_field2"), 
UnorderedElementsAre("repeated_message_field.nested_field")); } TEST(TestExtractFieldPaths, CelExprWithMultipleFunctionCalls) { EXPECT_THAT(GetExtractedFields( "(size(repeated_field) > 0 && !boolean_field == true) || " "request.valid == true && request.count == 0"), UnorderedElementsAre("boolean_field", "repeated_field", "request.valid", "request.count")); } TEST(TestExtractFieldPaths, CelExprWithNestedComprehension) { EXPECT_THAT( GetExtractedFields("repeated_field_1.exists(e, e.key == 'one') && " "req.repeated_field_2.exists(x, " "x.y.z == 'val' &&" "x.array.exists(y, y == 'val' && req.bool_field == " "true && x.bool_field == false))"), UnorderedElementsAre("req.repeated_field_2", "req.bool_field", "repeated_field_1")); } TEST(TestExtractFieldPaths, CelExprWithMultipleComprehension) { EXPECT_THAT( GetExtractedFields( "repeated_field_1.exists(e, e.key == 'one' && y.field_1 == 'val') && " "repeated_field_2.exists(y, y.key == 'one' && e.field_2 == 'val')"), UnorderedElementsAre("repeated_field_1", "repeated_field_2", "e.field_2", "y.field_1")); } TEST(TestExtractFieldPaths, CelExprWithListLiteral) { EXPECT_THAT(GetExtractedFields("['a', b, 3].exists(x, x == 1)"), UnorderedElementsAre("b")); } TEST(TestExtractFieldPaths, CelExprWithFunctionCallsAndRepeatedFields) { EXPECT_THAT( GetExtractedFields("data == 'data_1' && field_1 == 'val_1' &&" "(matches(req.field_2, 'val_1') == true) &&" "repeated_field[0].priority >= 200"), UnorderedElementsAre("data", "field_1", "req.field_2", "repeated_field")); } TEST(TestExtractFieldPaths, CelExprWithFunctionOnRepeatedField) { EXPECT_THAT( GetExtractedFields("(contains_data == false && " "data.field_1=='value_1') || " "size(data.nodes) > 0 && " "data.nodes[0].field_2=='value_2'"), UnorderedElementsAre("contains_data", "data.field_1", "data.nodes")); } TEST(TestExtractFieldPaths, CelExprContainingEndsWithFunction) { EXPECT_THAT(GetExtractedFields("data.repeated_field.exists(f, " "f.field_1.field_2.endsWith('val_1')) || " 
"data.field_3.endsWith('val_3')"), UnorderedElementsAre("data.repeated_field", "data.field_3")); } TEST(TestExtractFieldPaths, CelExprWithMatchFunctionInsideComprehensionAndRegexConstants) { EXPECT_THAT(GetExtractedFields("req.field_1.field_2=='val_1' && " "data!=null && req.repeated_field.exists(f, " "f.matches('a100.*|.*h100_80gb.*|.*h200.*'))"), UnorderedElementsAre("req.field_1.field_2", "req.repeated_field", "data")); } TEST(TestExtractFieldPaths, CelExprWithMultipleChecksInComprehension) { EXPECT_THAT( GetExtractedFields("req.field.repeated_field.exists(f, f.key == 'data_1'" " && f.str_value == 'val_1') && " "req.metadata.type == 3"), UnorderedElementsAre("req.field.repeated_field", "req.metadata.type")); } } // namespace } // namespace cel ================================================ FILE: tools/cel_unparser.cc ================================================ // Copyright 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "tools/cel_unparser.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <utility>

#include "cel/expr/syntax.pb.h"
#include "google/protobuf/duration.pb.h"
#include "google/protobuf/timestamp.pb.h"
#include "absl/base/no_destructor.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/operators.h"
#include "internal/status_macros.h"
#include "internal/strings.h"
#include "re2/re2.h"

namespace google::api::expr {

namespace {

using ::cel::expr::CheckedExpr;
using ::cel::expr::Constant;
using ::cel::expr::Expr;
using ::cel::expr::ParsedExpr;
using ::cel::expr::SourceInfo;
using ::google::api::expr::common::CelOperator;
using ::google::api::expr::common::IsOperatorLeftRecursive;
using ::google::api::expr::common::IsOperatorLowerPrecedence;
using ::google::api::expr::common::IsOperatorSamePrecedence;
using ::google::api::expr::common::LookupBinaryOperator;
using ::google::api::expr::common::LookupUnaryOperator;

constexpr absl::string_view kLeftParen = "(";
constexpr absl::string_view kRightParen = ")";
constexpr absl::string_view kLeftBracket = "[";
constexpr absl::string_view kRightBracket = "]";
constexpr absl::string_view kLeftBrace = "{";
constexpr absl::string_view kRightBrace = "}";
constexpr absl::string_view kSpace = " ";
constexpr absl::string_view kDot = ".";
constexpr absl::string_view kColon = ":";
constexpr absl::string_view kComma = ",";
constexpr absl::string_view kBackQuote = "`";
constexpr absl::string_view kQuestionMark = "?";

// Field names that fully match this pattern may be printed without backquotes
// (unless they are reserved, see ReservedFieldIdentifiers()).
static const LazyRE2 kSimpleIdentifierPattern = {R"([a-zA-Z_][a-zA-Z0-9_]*)"};

// Names that match kSimpleIdentifierPattern but still must be backquoted when
// used as a field selector or message-field key. In the CEL grammar `in`,
// `true`, `false` and `null` are keyword / literal tokens rather than
// identifiers, so e.g. `a.null` does not reparse while a.`null` does.
const absl::flat_hash_set<absl::string_view>& ReservedFieldIdentifiers() {
  static const absl::NoDestructor<absl::flat_hash_set<absl::string_view>>
      kReservedFieldIdentifiers([]() {
        return absl::flat_hash_set<absl::string_view>{"in", "true", "false",
                                                      "null"};
      }());
  return *kReservedFieldIdentifiers;
}

// Renders `field` for a select expression or a message-construction entry,
// wrapping it in backquotes when it is reserved or not a simple identifier.
std::string FormatField(absl::string_view field) {
  if (ReservedFieldIdentifiers().contains(field) ||
      !RE2::FullMatch(field, *kSimpleIdentifierPattern)) {
    return absl::StrCat(kBackQuote, field, kBackQuote);
  }
  return std::string(field);
}

// Renders a double constant so that it reparses as the same double value.
//
// absl::StrCat(double) keeps only six significant digits, which both loses
// precision (3.14159265 -> "3.14159") and can drop the fractional part
// entirely (1.0 -> "1"), turning the literal into an int64 on reparse.
// Instead, use the fewest of 15..17 significant digits that round-trips
// (17 always suffices for IEEE-754 binary64) and make sure the result is
// lexically a floating-point literal.
//
// NOTE: snprintf/strtod honor LC_NUMERIC; this assumes the process uses the
// default "C" numeric locale (decimal point '.').
std::string FormatDoubleLiteral(double value) {
  if (!std::isfinite(value)) {
    // CEL has no literal syntax for NaN or infinities; keep the previous
    // rendering for these values.
    return absl::StrCat(value);
  }
  char buffer[32];
  for (int precision = 15; precision <= 17; ++precision) {
    std::snprintf(buffer, sizeof(buffer), "%.*g", precision, value);
    if (precision == 17 || std::strtod(buffer, nullptr) == value) {
      break;
    }
  }
  std::string literal(buffer);
  // "1e+20" is already a valid CEL double literal, but a bare digit sequence
  // such as "1" would be an int64 literal.
  if (literal.find_first_of(".eE") == std::string::npos) {
    literal.append(".0");
  }
  return literal;
}

// Converts a proto AST back into CEL source text. Output is accumulated in
// `output_` by the Visit* methods; each method returns an error instead of
// crashing when the AST does not have the shape it expects.
class Unparser {
 public:
  // Unparses `expr`. Entries in `source_info.macro_calls()` take precedence
  // over the expressions they were expanded into.
  static absl::StatusOr<std::string> Unparse(const Expr& expr,
                                             const SourceInfo& source_info) {
    Unparser unparser(expr, source_info);
    return unparser.DoUnparse();
  }

 private:
  const Expr& expr_;
  const SourceInfo& source_info_;
  std::string output_;

  Unparser(const Expr& expr, const SourceInfo& source_info)
      : expr_(expr), source_info_(source_info) {}

  // Visits the root expression and returns the trimmed output.
  absl::StatusOr<std::string> DoUnparse() {
    CEL_RETURN_IF_ERROR(Visit(expr_));
    absl::StripAsciiWhitespace(&output_);
    return std::move(output_);
  }

  absl::Status Visit(const Expr& expr);
  absl::Status VisitConst(const Constant& expr);
  absl::Status VisitIdent(const Expr::Ident& expr);
  absl::Status VisitSelect(const Expr::Select& expr);
  absl::Status VisitOptSelect(const Expr::Call& expr);
  absl::Status VisitCall(const Expr::Call& expr);
  absl::Status VisitCreateList(const Expr::CreateList& expr);
  absl::Status VisitCreateStruct(const Expr::CreateStruct& expr);
  absl::Status VisitComprehension(const Expr::Comprehension& expr);
  absl::Status VisitAllMacro(const Expr::Comprehension& expr);
  absl::Status VisitExistsMacro(const Expr::Comprehension& expr);
  absl::Status VisitExistsOneMacro(const Expr::Comprehension& expr);
  absl::Status VisitMapMacro(const Expr::Comprehension& expr);
  absl::Status VisitUnary(const Expr::Call& expr, const std::string& op);
  absl::Status VisitBinary(const Expr::Call& expr, const std::string& op);
  // Visits `expr`, wrapping the output in parentheses when `nested` is true.
  absl::Status VisitMaybeNested(const Expr& expr, bool nested);
  absl::Status VisitIndex(const Expr::Call& expr);
  absl::Status VisitOptIndex(const Expr::Call& expr);
  absl::Status VisitTernary(const Expr::Call& expr);

  bool IsComplexOperatorWithRespectTo(const Expr& expr, const std::string& op);
  bool IsComplexOperator(const Expr& expr);
  // Returns true if the given expression is
  //  - a call expression AND ONE of the following holds:
  //    - a binary operator
  //    - a ternary conditional operator
  bool IsBinaryOrTernaryOperator(const Expr& expr);

  // Appends all arguments to the output.
  template <typename... Ts>
  void Print(Ts&&... args) {
    absl::StrAppend(&output_, std::forward<Ts>(args)...);
  }
};

absl::Status Unparser::Visit(const Expr& expr) {
  // Print a recorded macro call (e.g. `x.all(y, p)`) rather than its expansion.
  auto macro = source_info_.macro_calls().find(expr.id());
  if (macro != source_info_.macro_calls().end()) {
    return Visit(macro->second);
  }
  switch (expr.expr_kind_case()) {
    case Expr::kConstExpr:
      return VisitConst(expr.const_expr());
    case Expr::kIdentExpr:
      return VisitIdent(expr.ident_expr());
    case Expr::kSelectExpr:
      return VisitSelect(expr.select_expr());
    case Expr::kCallExpr:
      return VisitCall(expr.call_expr());
    case Expr::kListExpr:
      return VisitCreateList(expr.list_expr());
    case Expr::kStructExpr:
      return VisitCreateStruct(expr.struct_expr());
    case Expr::kComprehensionExpr:
      return VisitComprehension(expr.comprehension_expr());
    default:
      return absl::InvalidArgumentError(
          absl::StrCat("Unsupported Expr kind: ", expr.expr_kind_case()));
  }
}

absl::Status Unparser::VisitConst(const Constant& expr) {
  switch (expr.constant_kind_case()) {
    case Constant::kStringValue:
      Print(
          cel::internal::FormatDoubleQuotedStringLiteral(expr.string_value()));
      break;
    case Constant::kInt64Value:
      Print(expr.int64_value());
      break;
    case Constant::kUint64Value:
      Print(expr.uint64_value(), "u");
      break;
    case Constant::kBoolValue:
      Print(expr.bool_value() ? "true" : "false");
      break;
    case Constant::kDoubleValue:
      Print(FormatDoubleLiteral(expr.double_value()));
      break;
    case Constant::kNullValue:
      Print("null");
      break;
    case Constant::kBytesValue:
      Print(cel::internal::FormatDoubleQuotedBytesLiteral(expr.bytes_value()));
      break;
    default:
      return absl::InvalidArgumentError(absl::StrCat(
          "Unsupported Constant kind: ", expr.constant_kind_case()));
  }
  return absl::OkStatus();
}

absl::Status Unparser::VisitIdent(const Expr::Ident& expr) {
  Print(expr.name());
  return absl::OkStatus();
}

// Prints `operand.field`, or `has(operand.field)` for a test-only select.
absl::Status Unparser::VisitSelect(const Expr::Select& expr) {
  if (expr.test_only()) {
    Print(CelOperator::HAS, kLeftParen);
  }
  const auto& operand = expr.operand();
  bool nested = !expr.test_only() && IsBinaryOrTernaryOperator(operand);
  CEL_RETURN_IF_ERROR(VisitMaybeNested(operand, nested));
  Print(kDot, FormatField(expr.field()));
  if (expr.test_only()) {
    Print(kRightParen);
  }
  return absl::OkStatus();
}

// Prints `operand.?field` for the optional-select call, whose second argument
// must be the field name as a string constant.
absl::Status Unparser::VisitOptSelect(const Expr::Call& expr) {
  if (expr.args_size() != 2 || !expr.args()[1].has_const_expr() ||
      !expr.args()[1].const_expr().has_string_value()) {
    return absl::InvalidArgumentError(
        absl::StrCat("Unexpected select: ", expr.ShortDebugString()));
  }
  const auto& operand = expr.args()[0];
  bool nested = IsBinaryOrTernaryOperator(operand);
  CEL_RETURN_IF_ERROR(VisitMaybeNested(operand, nested));
  Print(kDot, kQuestionMark,
        FormatField(expr.args()[1].const_expr().string_value()));
  return absl::OkStatus();
}

// Dispatches operator calls to their dedicated printers; any other function is
// printed as `fun(args...)` or `target.fun(args...)`.
absl::Status Unparser::VisitCall(const Expr::Call& expr) {
  const auto& fun = expr.function();
  auto unary_op = LookupUnaryOperator(fun);
  if (unary_op.has_value()) {
    return VisitUnary(expr, *unary_op);
  }
  auto binary_op = LookupBinaryOperator(fun);
  if (binary_op.has_value()) {
    return VisitBinary(expr, *binary_op);
  }
  if (fun == CelOperator::INDEX) {
    return VisitIndex(expr);
  }
  if (fun == CelOperator::OPT_INDEX) {
    return VisitOptIndex(expr);
  }
  if (fun == CelOperator::OPT_SELECT) {
    return VisitOptSelect(expr);
  }
  if (fun == CelOperator::CONDITIONAL) {
    return VisitTernary(expr);
  }

  if (expr.has_target()) {
    bool nested = IsBinaryOrTernaryOperator(expr.target());
    CEL_RETURN_IF_ERROR(VisitMaybeNested(expr.target(), nested));
    Print(kDot);
  }

  Print(fun, kLeftParen);
  for (int i = 0; i < expr.args_size(); i++) {
    if (i > 0) {
      Print(kComma, kSpace);
    }
    CEL_RETURN_IF_ERROR(Visit(expr.args(i)));
  }
  Print(kRightParen);
  return absl::OkStatus();
}

// Prints `[e1, ?e2, ...]`; elements listed in optional_indices get a `?`.
absl::Status Unparser::VisitCreateList(const Expr::CreateList& expr) {
  Print(kLeftBracket);
  for (int i = 0; i < expr.elements_size(); i++) {
    if (i > 0) {
      Print(kComma, kSpace);
    }
    if (std::find(expr.optional_indices().begin(),
                  expr.optional_indices().end(),
                  static_cast<int32_t>(i)) != expr.optional_indices().end()) {
      Print(kQuestionMark);
    }
    CEL_RETURN_IF_ERROR(Visit(expr.elements(i)));
  }
  Print(kRightBracket);
  return absl::OkStatus();
}

// Prints a map literal `{k: v}` or, when message_name is set, a message
// literal `Name{field: v}`. Optional entries are prefixed with `?`.
absl::Status Unparser::VisitCreateStruct(const Expr::CreateStruct& expr) {
  if (!expr.message_name().empty()) {
    Print(expr.message_name());
  }
  Print(kLeftBrace);
  for (int i = 0; i < expr.entries_size(); i++) {
    if (i > 0) {
      Print(kComma, kSpace);
    }

    const auto& e = expr.entries(i);
    if (e.optional_entry()) {
      Print(kQuestionMark);
    }
    switch (e.key_kind_case()) {
      case Expr::CreateStruct::Entry::kFieldKey:
        Print(FormatField(e.field_key()));
        break;
      case Expr::CreateStruct::Entry::kMapKey:
        CEL_RETURN_IF_ERROR(Visit(e.map_key()));
        break;
      default:
        return absl::InvalidArgumentError(
            absl::StrCat("Unexpected struct: ", expr.ShortDebugString()));
    }
    Print(kColon, kSpace);
    CEL_RETURN_IF_ERROR(Visit(e.value()));
  }
  Print(kRightBrace);
  return absl::OkStatus();
}

// Reconstructs a macro call from an expanded comprehension by its shape:
// a `_&&_` loop step is `all`, a `_||_` loop step is `exists`, a call in the
// result is `exists_one`, and anything else is printed as `map`.
absl::Status Unparser::VisitComprehension(const Expr::Comprehension& expr) {
  bool nested = IsComplexOperator(expr.iter_range());
  CEL_RETURN_IF_ERROR(VisitMaybeNested(expr.iter_range(), nested));
  Print(kDot);

  if (expr.loop_step().call_expr().function() == CelOperator::LOGICAL_AND) {
    return VisitAllMacro(expr);
  }

  if (expr.loop_step().call_expr().function() == CelOperator::LOGICAL_OR) {
    return VisitExistsMacro(expr);
  }

  if (expr.result().expr_kind_case() == Expr::kCallExpr) {
    return VisitExistsOneMacro(expr);
  }

  return VisitMapMacro(expr);
}

// Prints `all(var, pred)`; the predicate is the second `_&&_` argument.
absl::Status Unparser::VisitAllMacro(const Expr::Comprehension& expr) {
  if (expr.loop_step().call_expr().args_size() != 2) {
    return absl::InvalidArgumentError(
        absl::StrCat("Unexpected all macro: ", expr.ShortDebugString()));
  }
  Print(CelOperator::ALL, kLeftParen, expr.iter_var(), kComma, kSpace);
  CEL_RETURN_IF_ERROR(Visit(expr.loop_step().call_expr().args(1)));
  Print(kRightParen);
  return absl::OkStatus();
}

// Prints `exists(var, pred)`; the predicate is the second `_||_` argument.
absl::Status Unparser::VisitExistsMacro(const Expr::Comprehension& expr) {
  if (expr.loop_step().call_expr().args_size() != 2) {
    return absl::InvalidArgumentError(
        absl::StrCat("Unexpected exists macro: ", expr.ShortDebugString()));
  }
  Print(CelOperator::EXISTS, kLeftParen, expr.iter_var(), kComma, kSpace);
  CEL_RETURN_IF_ERROR(Visit(expr.loop_step().call_expr().args(1)));
  Print(kRightParen);
  return absl::OkStatus();
}

// Prints `exists_one(var, pred)`; the predicate is the condition (first
// argument) of the three-argument loop step.
absl::Status Unparser::VisitExistsOneMacro(const Expr::Comprehension& expr) {
  if (expr.loop_step().call_expr().args_size() != 3) {
    return absl::InvalidArgumentError(
        absl::StrCat("Unexpected exists one macro: ", expr.ShortDebugString()));
  }
  Print(CelOperator::EXISTS_ONE, kLeftParen, expr.iter_var(), kComma, kSpace);
  CEL_RETURN_IF_ERROR(Visit(expr.loop_step().call_expr().args(0)));
  Print(kRightParen);
  return absl::OkStatus();
}

// Prints `map(var, transform)` or `map(var, filter, transform)`. A filtering
// loop step has the shape `filter ? accu + [transform] : accu`; otherwise the
// step is `accu + [transform]`.
absl::Status Unparser::VisitMapMacro(const Expr::Comprehension& expr) {
  Print(CelOperator::MAP, kLeftParen, expr.iter_var(), kComma, kSpace);
  // Walk the loop step by pointer; copying it would deep-copy the subtree.
  const Expr* step = &expr.loop_step();
  if (step->call_expr().function() == CelOperator::CONDITIONAL) {
    if (step->call_expr().args_size() != 3) {
      return absl::InvalidArgumentError(
          absl::StrCat("Unexpected exists map macro filter step: ",
                       expr.ShortDebugString()));
    }
    CEL_RETURN_IF_ERROR(Visit(step->call_expr().args(0)));
    Print(kComma, kSpace);
    step = &step->call_expr().args(1);
  }

  if (step->call_expr().args_size() != 2 ||
      step->call_expr().args(1).list_expr().elements_size() != 1) {
    return absl::InvalidArgumentError(
        absl::StrCat("Unexpected exists map macro: ", expr.ShortDebugString()));
  }
  CEL_RETURN_IF_ERROR(
      Visit(step->call_expr().args(1).list_expr().elements(0)));
  Print(kRightParen);
  return absl::OkStatus();
}

// Prints `op operand`, parenthesizing operands that are calls with 2+ args.
absl::Status Unparser::VisitUnary(const Expr::Call& expr,
                                  const std::string& op) {
  if (expr.args_size() != 1) {
    return absl::InvalidArgumentError(
        absl::StrCat("Unexpected unary: ", expr.ShortDebugString()));
  }
  Print(op);
  bool nested = IsComplexOperator(expr.args(0));
  return VisitMaybeNested(expr.args(0), nested);
}

absl::Status Unparser::VisitBinary(const Expr::Call& expr,
                                   const std::string& op) {
  if (expr.args_size() != 2) {
    return absl::InvalidArgumentError(
        absl::StrCat("Unexpected binary: ", expr.ShortDebugString()));
  }
  const auto& lhs = expr.args(0);
  const auto& rhs = expr.args(1);
  const auto& fun = expr.function();

  // add parens if the current operator is lower precedence than the lhs expr
  // operator.
  bool lhs_paren = IsComplexOperatorWithRespectTo(lhs, fun);
  // add parens if the current operator is lower precedence than the rhs expr
  // operator, or the same precedence and the operator is left recursive.
  bool rhs_paren = IsComplexOperatorWithRespectTo(rhs, fun);
  if (!rhs_paren && IsOperatorLeftRecursive(fun)) {
    rhs_paren = IsOperatorSamePrecedence(fun, rhs);
  }

  CEL_RETURN_IF_ERROR(VisitMaybeNested(lhs, lhs_paren));
  Print(kSpace, op, kSpace);
  return VisitMaybeNested(rhs, rhs_paren);
}

absl::Status Unparser::VisitMaybeNested(const Expr& expr, bool nested) {
  if (nested) {
    Print(kLeftParen);
  }
  CEL_RETURN_IF_ERROR(Visit(expr));
  if (nested) {
    Print(kRightParen);
  }
  return absl::OkStatus();
}

// Prints `operand[index]`.
absl::Status Unparser::VisitIndex(const Expr::Call& expr) {
  if (expr.args_size() != 2) {
    return absl::InvalidArgumentError(
        absl::StrCat("Unexpected index call: ", expr.ShortDebugString()));
  }
  bool nested = IsBinaryOrTernaryOperator(expr.args(0));
  CEL_RETURN_IF_ERROR(VisitMaybeNested(expr.args(0), nested));
  Print(kLeftBracket);
  CEL_RETURN_IF_ERROR(Visit(expr.args(1)));
  Print(kRightBracket);
  return absl::OkStatus();
}

// Prints `operand[?index]`.
absl::Status Unparser::VisitOptIndex(const Expr::Call& expr) {
  if (expr.args_size() != 2) {
    return absl::InvalidArgumentError(
        absl::StrCat("Unexpected index call: ", expr.ShortDebugString()));
  }
  bool nested = IsBinaryOrTernaryOperator(expr.args(0));
  CEL_RETURN_IF_ERROR(VisitMaybeNested(expr.args(0), nested));
  Print(kLeftBracket);
  Print(kQuestionMark);
  CEL_RETURN_IF_ERROR(Visit(expr.args(1)));
  Print(kRightBracket);
  return absl::OkStatus();
}

// Prints `cond ? a : b`; each operand that is itself a conditional or a call
// with 2+ arguments is parenthesized.
absl::Status Unparser::VisitTernary(const Expr::Call& expr) {
  if (expr.args_size() != 3) {
    return absl::InvalidArgumentError(
        absl::StrCat("Unexpected ternary: ", expr.ShortDebugString()));
  }

  bool nested =
      IsOperatorSamePrecedence(CelOperator::CONDITIONAL, expr.args(0)) ||
      IsComplexOperator(expr.args(0));
  CEL_RETURN_IF_ERROR(VisitMaybeNested(expr.args(0), nested));

  Print(kSpace, kQuestionMark, kSpace);

  nested = IsOperatorSamePrecedence(CelOperator::CONDITIONAL, expr.args(1)) ||
           IsComplexOperator(expr.args(1));
  CEL_RETURN_IF_ERROR(VisitMaybeNested(expr.args(1), nested));

  Print(kSpace, kColon, kSpace);

  nested = IsOperatorSamePrecedence(CelOperator::CONDITIONAL, expr.args(2)) ||
           IsComplexOperator(expr.args(2));
  return VisitMaybeNested(expr.args(2), nested);
}

bool Unparser::IsComplexOperatorWithRespectTo(const Expr& expr,
                                              const std::string& op) {
  // If the arg is not a call with more than one arg, return false.
  if (!expr.has_call_expr() || expr.call_expr().args_size() < 2) {
    return false;
  }
  // Otherwise, return whether the given op has lower precedence than expr
  return IsOperatorLowerPrecedence(op, expr);
}

bool Unparser::IsComplexOperator(const Expr& expr) {
  // If the arg is a call with more than one arg, return true
  return expr.has_call_expr() && expr.call_expr().args_size() >= 2;
}

// Returns true if the given expression is
//  - a call expression AND ONE of the following holds:
//    - a binary operator
//    - a ternary conditional operator
bool Unparser::IsBinaryOrTernaryOperator(const Expr& expr) {
  if (!IsComplexOperator(expr)) {
    return false;
  }
  return LookupBinaryOperator(expr.call_expr().function()).has_value() ||
         IsOperatorSamePrecedence(CelOperator::CONDITIONAL, expr);
}

}  // namespace

absl::StatusOr<std::string> Unparse(const Expr& expr,
                                    const SourceInfo* source_info) {
  // Without SourceInfo no macro calls are recorded, so macros are rebuilt
  // from their expanded comprehensions.
  const SourceInfo& info =
      source_info == nullptr ? SourceInfo::default_instance() : *source_info;
  return Unparser::Unparse(expr, info);
}

absl::StatusOr<std::string> Unparse(const ParsedExpr& parsed_expr) {
  return Unparse(parsed_expr.expr(), &parsed_expr.source_info());
}

absl::StatusOr<std::string> Unparse(const CheckedExpr& checked_expr) {
  return Unparse(checked_expr.expr(), &checked_expr.source_info());
}

}  // namespace google::api::expr
================================================
FILE: tools/cel_unparser.h
================================================
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Provides an unparsing utility that converts an AST back into
// a human readable format.
//
// Input to the unparser is the proto AST (Expr, CheckedExpr, or ParsedExpr).
// The unparser does not check whether the ParsedExpr is syntactically or
// semantically correct, but it performs enough checks to avoid crashing and
// returns an error when the AST does not have the expected shape.

#ifndef THIRD_PARTY_CEL_CPP_TOOLS_UNPARSER_H_
#define THIRD_PARTY_CEL_CPP_TOOLS_UNPARSER_H_

#include

#include "cel/expr/checked.pb.h"
#include "cel/expr/syntax.pb.h"
#include "absl/base/attributes.h"
#include "absl/status/statusor.h"

namespace google::api::expr {

// Unparses the given expression into a human readable cel expression.
//
// When `source_info` is null an empty SourceInfo is used, so no recorded
// macro calls are available and macros are reconstructed from their expanded
// comprehensions.
ABSL_DEPRECATED(
    "Use Unparse(ParsedExpr) to ensure proper unparsing of all CEL "
    "expressions. Note, ParserOptions.add_macro_calls must be set to true "
    "for full fidelity unparsing.")
absl::StatusOr Unparse(
    const cel::expr::Expr& expr,
    const cel::expr::SourceInfo* source_info = nullptr);

// Unparses the ParsedExpr value to a human-readable string.
//
// For the best results ensure that the expression is parsed with
// ParserOptions.add_macro_calls = true.
absl::StatusOr Unparse(
    const cel::expr::ParsedExpr& parsed_expr);

// Unparses the CheckedExpr value to a human-readable string.
//
// For the best results ensure that the expression is parsed with
// ParserOptions.add_macro_calls = true.
absl::StatusOr Unparse(
    const cel::expr::CheckedExpr& checked_expr);

}  // namespace google::api::expr

#endif  // THIRD_PARTY_CEL_CPP_TOOLS_UNPARSER_H_
================================================
FILE: tools/cel_unparser_test.cc
================================================
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tools/cel_unparser.h"

#include

#include "cel/expr/syntax.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "internal/proto_matchers.h"
#include "internal/testing.h"
#include "parser/options.h"
#include "parser/parser.h"
#include "google/protobuf/text_format.h"

namespace google::api::expr {
namespace {

using ::absl_testing::StatusIs;
using ::cel::internal::test::EqualsProto;
using ::cel::expr::Expr;
using ::cel::expr::ParsedExpr;
using ::google::api::expr::parser::Parse;
using ::testing::HasSubstr;
using ::testing::ValuesIn;

// A hand-built AST case.
//   proto_text: a text-format cel.expr.Expr.
//   expr:       either the expected unparsed string, or the expected error.
//               For an error, the status code must match exactly and its
//               message must occur as a substring of the actual message.
struct UnparserTestCaseTextProto {
  std::string proto_text;
  absl::StatusOr expr;
};

class UnparserTestTextProto : public testing::TestWithParam {};

// Unparses the Expr without SourceInfo (no recorded macro calls) and compares
// the result with the expected string or error.
TEST_P(UnparserTestTextProto, Test) {
  auto test_case = GetParam();
  Expr expr;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(test_case.proto_text, &expr));

  absl::StatusOr result = Unparse(expr);
  if (result.ok()) {
    ASSERT_OK(test_case.expr);
    ASSERT_EQ(*(test_case.expr), *result);
  } else {
    ASSERT_THAT(result.status(),
                StatusIs(test_case.expr.status().code(),
                         HasSubstr(test_case.expr.status().message())));
  }
}

// These cases hand-build proto ASTs, so they depend on the exact proto shapes
// (e.g. of expanded macro comprehensions) that the unparser recognizes.
INSTANTIATE_TEST_SUITE_P(
    UnparseCompProto, UnparserTestTextProto,
    ValuesIn(
        {// Empty Expr error
         {"", absl::InvalidArgumentError("Unsupported Expr")},
         // Constants
         {"const_expr{}", absl::InvalidArgumentError("Unsupported Constant")},
         {"const_expr{bool_value: true}", "true"},
         {"const_expr{int64_value: 4}", "4"},
         {"const_expr{uint64_value: 4}", "4u"},
         // Sequences
         {
             R"pb( struct_expr { entries { value { const_expr { uint64_value: 2 } } } })pb",
             absl::InvalidArgumentError("Unexpected struct")},
         {R"pb( list_expr { elements { const_expr { int64_value: 1 } } elements { const_expr { uint64_value: 2 } } } )pb",
          "[1, 2u]"},
         {R"pb( struct_expr { entries { map_key { const_expr { int64_value: 1 } } value { const_expr { uint64_value: 2 } } } entries { map_key { const_expr { int64_value: 2 } } value { const_expr { uint64_value: 3 } } } })pb",
          "{1: 2u, 2: 3u}"},
         // Messages
         {R"pb( struct_expr { message_name: 'TestAllTypes' entries { field_key: 'single_int32' value { const_expr { int64_value: 1 } } } entries { field_key: 'single_int64' value { const_expr { int64_value: 2 } } } } )pb",
          "TestAllTypes{single_int32: 1, single_int64: 2}"},
         // Conditionals
         {R"pb( call_expr { function: '!_' } )pb",
          absl::InvalidArgumentError("Unexpected unary")},
         {R"pb( call_expr { function: '_||_' } )pb",
          absl::InvalidArgumentError("Unexpected binary")},
         {R"pb( call_expr { function: '_[_]' } )pb",
          absl::InvalidArgumentError("Unexpected index")},
         {R"pb( call_expr { function: '_?_:_' } )pb",
          absl::InvalidArgumentError("Unexpected ternary")},
         {R"pb( call_expr { function: '_||_' args { call_expr { function: '_&&_' args { const_expr { bool_value: false } } args { call_expr { function: '!_' args { const_expr { bool_value: true } } } } } } args { const_expr { bool_value: false } } })pb",
          "false && !true || false"},
         {R"pb( call_expr { function: '_&&_' args { const_expr { bool_value: false } } args { call_expr { function: '_||_' args { call_expr { function: '!_' args { const_expr { bool_value: true } } } } args { const_expr { bool_value: false } } } } })pb",
          "false && (!true || false)"},
         {R"pb( call_expr { function: '_?_:_' args { call_expr { function: '_||_' args { call_expr { function: '_&&_' args { const_expr { bool_value: false } } args { call_expr { function: "!_" args { const_expr { bool_value: true } } } } } } args { const_expr { bool_value: false } } } } args { const_expr { int64_value: 2 } } args { const_expr { int64_value: 3 } } })pb",
          "(false && !true || false) ? 2 : 3"},
         {R"pb( call_expr { function: '!_' args { call_expr { function: '!_' args { const_expr { bool_value: true } } } } })pb",
          "!!true"},
         {R"pb( call_expr { function: '_?_:_' args { call_expr { function: '_<_' args { ident_expr { name: 'x' } } args { const_expr { int64_value: 5 } } } } args { ident_expr { name: 'x' } } args { const_expr { int64_value: 5 } } })pb",
          "(x < 5) ? x : 5"},
         {R"pb( call_expr { function: '_?_:_' args { call_expr { function: '_>_' args { ident_expr { name: 'x' } } args { const_expr { int64_value: 5 } } } } args { call_expr { function: '_-_' args { ident_expr { name: 'x' } } args { const_expr { int64_value: 5 } } } } args { const_expr { int64_value: 0 } } })pb",
          "(x > 5) ? (x - 5) : 0"},
         {R"pb( call_expr { function: '_?_:_' args { call_expr { function: '_>_' args { ident_expr { name: 'x' } } args { const_expr { int64_value: 5 } } } } args { call_expr { function: '_?_:_' args { call_expr { function: '_>_' args { ident_expr { name: 'x' } } args { const_expr { int64_value: 10 } } } } args { call_expr { function: '_-_' args { ident_expr { name: 'x' } } args { const_expr { int64_value: 10 } } } } args { const_expr { int64_value: 5 } } } } args { const_expr { int64_value: 0 } } })pb",
          "(x > 5) ? ((x > 10) ? (x - 10) : 5) : 0"},
         {R"pb( call_expr { function: '_in_' args { ident_expr { name: 'a' } } args { ident_expr { name: 'b' } } })pb",
          "a in b"},
         // Calculations
         {R"pb( call_expr { function: '_*_' args { call_expr { function: '_+_' args { const_expr { int64_value: 1 } } args { const_expr { int64_value: 2 } } } } args { const_expr { int64_value: 3 } } })pb",
          "(1 + 2) * 3"},
         {R"pb( call_expr { function: '_+_' args { const_expr { int64_value: 1 } } args { call_expr { function: '_*_' args { const_expr { int64_value: 2 } } args { const_expr { int64_value: 3 } } } } })pb",
          "1 + 2 * 3"},
         {R"pb( call_expr { function: '-_' args { call_expr { function: '_*_' args { const_expr { int64_value: 1 } } args { const_expr { int64_value: 2 } } } } })pb",
          "-(1 * 2)"},
         // Comprehensions
         {R"pb( comprehension_expr { iter_var: 'x' iter_range { list_expr { elements { const_expr { int64_value: 1 } } elements { const_expr { int64_value: 2 } } elements { const_expr { int64_value: 3 } } } } accu_var: 'accu' accu_init { const_expr { bool_value: true } } loop_condition { ident_expr { name: 'accu' } } loop_step { call_expr { function: '_&&_' args { ident_expr { name: 'x' } } args { call_expr { function: '_>_' args { ident_expr { name: 'x' } } args { const_expr { int64_value: 0 } } } } } } result { ident_expr { name: 'accu' } } })pb",
          "[1, 2, 3].all(x, x > 0)"},
         {R"pb( comprehension_expr { iter_var: 'x' iter_range { list_expr { elements { const_expr { int64_value: 1 } } elements { const_expr { int64_value: 2 } } elements { const_expr { int64_value: 3 } } } } accu_var: 'accu' accu_init { const_expr { bool_value: false } } loop_condition { call_expr { function: '!_' args { ident_expr { name: 'accu' } } } } loop_step { call_expr { function: '_||_' args { ident_expr { name: 'x' } } args { call_expr { function: '_>_' args { ident_expr { name: 'x' } } args { const_expr { int64_value: 0 } } } } } } result { ident_expr { name: 'accu' } } })pb",
          "[1, 2, 3].exists(x, x > 0)"},
         {R"pb( comprehension_expr { iter_var: 'x' iter_range { list_expr { elements { const_expr { int64_value: 1 } } elements { const_expr { int64_value: 2 } } elements { const_expr { int64_value: 3 } } } } accu_var: 'accu' accu_init { list_expr {} } loop_condition { const_expr { bool_value: false } } loop_step { call_expr { function: '_?_:_' args { call_expr { function: '_>=_' args { ident_expr { name: 'x' } } args { const_expr { int64_value: 2 } } } } args { call_expr { function: '_+_' args { ident_expr { name: 'accu' } } args { list_expr { elements { call_expr { function: '_*_' args { ident_expr { name: 'x' } } args { const_expr { int64_value: 4 } } } } } } } } args { ident_expr { name: 'accu' } } } } result { ident_expr { name: 'accu' } } })pb",
          "[1, 2, 3].map(x, x >= 2, x * 4)"},
         {R"pb( comprehension_expr { iter_var: 'x' iter_range { list_expr { elements { const_expr { int64_value: 1 } } elements { const_expr { int64_value: 2 } } elements { const_expr { int64_value: 3 } } } } accu_var: 'accu' accu_init { const_expr { int64_value: 0 } } loop_condition { call_expr { function: '_<=_' args { ident_expr { name: 'accu' } } args { const_expr { int64_value: 1 } } } } loop_step { call_expr { function: '_?_:_' args { call_expr { function: '_>=_' args { ident_expr { name: 'x' } } args { const_expr { int64_value: 2 } } } } args { call_expr { function: '_+_' args { ident_expr { name: 'accu' } } args { const_expr { int64_value: 1 } } } } args { ident_expr { name: 'accu' } } } } result { call_expr { function: '_==_' args { ident_expr { name: 'accu' } } args { const_expr { int64_value: 1 } } } } })pb",
          "[1, 2, 3].exists_one(x, x >= 2)"},
         {R"pb( select_expr { operand { call_expr { function: '_[_]' args { ident_expr { name: 'x' } } args { const_expr { string_value: 'a' } } } } field: 'single_int32' test_only: true })pb",
          "has(x[\"a\"].single_int32)"},
         // This is a filter expression, but it is unparsed as
         // map(x, filter_function, x), which evaluates to the same result
         // as filter(x, filter_function).
         {R"pb( comprehension_expr { iter_var: 'x' iter_range { list_expr { elements { const_expr { int64_value: 1 } } elements { const_expr { int64_value: 2 } } elements { const_expr { int64_value: 3 } } } } accu_var: 'accu' accu_init { list_expr {} } loop_condition { const_expr { bool_value: false } } loop_step { call_expr { function: '_?_:_' args { call_expr { function: '_>=_' args { ident_expr { name: 'x' } } args { const_expr { int64_value: 2 } } } } args { call_expr { function: '_+_' args { ident_expr { name: 'accu' } } args { list_expr { elements { ident_expr { name: 'x' } } } } } } args { ident_expr { name: 'accu' } } } } result { ident_expr { name: 'accu' } } })pb",
          "[1, 2, 3].map(x, x >= 2, x)"},
         // Index
         {R"pb( call_expr { function: '_==_' args { select_expr { operand { call_expr { function: '_[_]' args { ident_expr { name: 'x' } } args { const_expr { string_value: 'a' } } } } field: 'single_int32' } } args { const_expr { int64_value: 23 } } })pb",
          "x[\"a\"].single_int32 == 23"},
         {R"pb( call_expr { function: '_[_]' args { call_expr { function: '_[_]' args { ident_expr { name: 'a' } } args { const_expr { int64_value: 1 } } } } args { const_expr { string_value: 'b' } } })pb",
          "a[1][\"b\"]"},
         // Functions
         {R"pb( call_expr { function: '_!=_' args { ident_expr { name: 'x' } } args { const_expr { string_value: 'a' } } })pb",
          "x != \"a\""},
         {R"pb( call_expr { function: '_==_' args { call_expr { function: 'size' args { ident_expr { name: 'x' } } } } args { call_expr { target { ident_expr { name: 'x' } } function: 'size' } } })pb",
          "size(x) == x.size()"},
         // Long string
         {R"pb( list_expr { elements { const_expr { string_value: 'Loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong' } } })pb",
          R"(["Loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong"])"}}));

// A source-text case.
//   expr:           CEL source to parse and then unparse.
//   equiv_expected: if empty, the unparsed text must equal `expr` exactly and
//                   reparsing it must give an identical ParsedExpr; otherwise
//                   the unparsed text must equal `equiv_expected`.
struct UnparserTestCaseTextExpr {
  std::string expr;
  std::string equiv_expected;
};

class UnparserTestTextExpr : public testing::TestWithParam {};

// Parses with macro-call recording, optional syntax and quoted identifiers
// enabled, unparses, and checks the result as described on
// UnparserTestCaseTextExpr.
TEST_P(UnparserTestTextExpr, Test) {
  // NOTE(review): `expr` is unused in this test.
  Expr expr;
  parser::ParserOptions options;
  options.add_macro_calls = true;
  options.enable_optional_syntax = true;
  options.enable_quoted_identifiers = true;
  ASSERT_OK_AND_ASSIGN(ParsedExpr result,
                       Parse(GetParam().expr, "unparser", options));
  ASSERT_OK_AND_ASSIGN(std::string result_expr, Unparse(result));
  if (!GetParam().equiv_expected.empty()) {
    ASSERT_EQ(GetParam().equiv_expected, result_expr);
  } else {
    ASSERT_EQ(GetParam().expr, result_expr);
  }

  if (GetParam().equiv_expected.empty()) {
    // Parse again and confirm the AST is identical.
    ASSERT_OK_AND_ASSIGN(ParsedExpr result2,
                         Parse(result_expr, "unparser", options));
    EXPECT_THAT(result, EqualsProto(result2));
  } else {
    // We cannot compare the original parsed proto and the equivalent expected
    // proto, since the IDs will most likely be different, e.g., due to
    // rebalancing logical expressions.
  }
}

// These test cases check that Unparse(Parse(expr)) is idempotent
// (if there is one string in an entry), or equivalent to some other
// form (if there are two strings in an entry). The latter can occur
// especially due to spacing in the expression, or if the logical
// expression balancer modifies an expression.
INSTANTIATE_TEST_SUITE_P(
    UnparseCompExpr, UnparserTestTextExpr,
    ValuesIn({
        {"a + b - c", ""},
        {"a && b && c && d && e", ""},
        {"a || b && (c || d) && e", ""},
        {"a ? b : c", ""},
        {"a[1][\"b\"]", ""},
        {"x[\"a\"].single_int32 == 23", ""},
        {"a * (b / c) % 0", ""},
        {"a + b * c", ""},
        {"(a + b) * c / (d - e)", ""},
        {"a * b / c % 0", ""},
        {"!true", ""},
        {"-num", ""},
        {"a || b || c || d || e", ""},
        {"-(1 * 2)", ""},
        {"-(1 + 2)", ""},
        {"(x > 5) ? (x - 5) : 0", ""},
        {"size(a ? (b ? c : d) : e)", ""},
        {"a.hello(\"world\")", ""},
        {"zero()", ""},
        {"one(\"a\")", ""},
        {"and(d, 32u)", ""},
        {"max(a, b, 100)", ""},
        {"x != \"a\"", ""},
        {"[]", ""},
        {"[1]", ""},
        {"[\"hello, world\", \"goodbye, world\", \"sure, why not?\"]", ""},
        {"b\"ÿ\"", "b\"\\xc3\\x83\\xc2\\xbf\""},
        {"b'aaa\"bbb'", "b\"aaa\\\"bbb\""},
        {"-42.101", ""},
        {"false", ""},
        {"-405069", ""},
        {"null", ""},
        {"\"hello:\\t'world'\"", ""},
        {"true", ""},
        {"42u", ""},
        {"my_ident", ""},
        {"has(hello.world)", ""},
        {"{}", ""},
        {"{\"a\": a.b.c, b\"b\": bytes(a.b.c)}", ""},
        {"{a: a, b: a.b, c: a.b.c, a ? b : c: false, a || b: true}", ""},
        {"v1alpha1.Expr{}", ""},
        {"v1alpha1.Expr{id: 1, call_expr: v1alpha1.Call_Expr{function: "
         "\"name\"}}",
         ""},
        {"a.b.c", ""},
        {"a[b][c].name", ""},
        {"(a + b).name", ""},
        {"(a ? b : c).name", ""},
        {"(a ? b : c)[0]", ""},
        {"(a1 && a2) ? b : c", ""},
        {"a ? (b1 || b2) : (c1 && c2)", ""},
        {"(a ? b : c).method(d)", ""},
        // The following entries give the equivalent form expected after
        // parsing and unparsing; note the differences in spacing and the
        // simplification of logical expressions.
        {"a+b-c", "a + b - c"},
        {"a ? b : c", "a ? b : c"},
        {"a[ 1 ][\"b\"]", "a[1][\"b\"]"},
        {"(false && !true) || false", "false && !true || false"},
        {"a . b . c", "a.b.c"},

        // here we expect the expression balancer to remove the double negation
        {"!!true", "true"},

        // From protos above
        // Constants
        {"true", ""},
        {"4", ""},
        {"4u", ""},
        // Sequences
        {"[1, 2u]", ""},
        {"{1: 2u, 2: 3u}", ""},
        // Messages
        {"TestAllTypes{single_int32: 1, single_int64: 2}", ""},
        // Conditionals
        {"false && !true || false", ""},
        {"false && (!true || false)", ""},
        {"(false && !true || false) ? 2 : 3", ""},
        {"(x < 5) ? x : 5", ""},
        {"(x > 5) ? (x - 5) : 0", ""},
        {"(x > 5) ? ((x > 10) ? (x - 10) : 5) : 0", ""},
        {"a in b", ""},
        // Calculations
        {"(1 + 2) * 3", ""},
        {"1 + 2 * 3", ""},
        {"-(1 * 2)", ""},
        // Comprehensions
        {"[1, 2, 3].all(x, x > 0)", ""},
        {"[1, 2, 3].exists(x, x > 0)", ""},
        {"[1, 2, 3].map(x, x >= 2, x * 4)", ""},
        {"[1, 2, 3].exists_one(x, x >= 2)", ""},
        {"[[1], [2], [3]].all(x, x.all(y, y >= 2))", ""},
        {"(has(x.y) ? x.y : []).filter(z, z == \"zed\")", ""},
        // Macros
        {"has(x[\"a\"].single_int32)", ""},
        // This is a filter expression, but it is unparsed as
        // map(x, filter_function, x), which evaluates to the same result
        // as filter(x, filter_function).
        {"[1, 2, 3].map(x, x >= 2, x)", ""},
        // Index
        {"x[\"a\"].single_int32 == 23", ""},
        {"a[1][\"b\"]", ""},
        // Functions
        {"x != \"a\"", ""},
        {"size(x) == x.size()", ""},
        // Long string
        {R"(["Loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong"])",
         ""},
        // Optional syntax and backquoted field names
        {"a.?b[?0] && a[?c]", ""},
        {"{?\"key\": value}", ""},
        {"[?a, ?b]", ""},
        {"[?a[?b]]", ""},
        {"Msg{?field: value}", ""},
        {"Msg{`in`: value}", ""},
        {"Msg{?`b.c`: value}", ""},
        {"has(a.`b.c`)", ""},
        {"a.`b/c`", ""},
        {"a.?`b/c`", ""},
    }));

}  // namespace
}  // namespace google::api::expr
================================================
FILE: tools/descriptor_pool_builder.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tools/descriptor_pool_builder.h"

#include <memory>
#include <vector>

#include "google/protobuf/descriptor.pb.h"
#include "absl/base/nullability.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "common/minimal_descriptor_database.h"
#include "internal/status_macros.h"
#include "google/protobuf/descriptor.h"

namespace cel {
namespace {

// Adds every file reachable from `to_resolve` (following `dependency(i)`
// edges) to `builder`, skipping files already present in `resolved`.
//
// `to_resolve` is consumed as a work stack; `resolved` records every file that
// has been visited. Returns the first error from
// `DescriptorPoolBuilder::AddFileDescriptor`, otherwise OK.
absl::Status FindDeps(
    std::vector<const google::protobuf::FileDescriptor*>& to_resolve,
    absl::flat_hash_set<const google::protobuf::FileDescriptor*>& resolved,
    DescriptorPoolBuilder& builder) {
  while (!to_resolve.empty()) {
    const google::protobuf::FileDescriptor* file = to_resolve.back();
    to_resolve.pop_back();
    // insert() both checks and records membership with a single lookup.
    if (!resolved.insert(file).second) {
      continue;
    }
    google::protobuf::FileDescriptorProto file_proto;
    file->CopyTo(&file_proto);
    // Note: order doesn't matter here as long as all the cross references are
    // correct in the final database.
    CEL_RETURN_IF_ERROR(builder.AddFileDescriptor(file_proto));
    for (int i = 0; i < file->dependency_count(); ++i) {
      to_resolve.push_back(file->dependency(i));
    }
  }
  return absl::OkStatus();
}

}  // namespace

// `merged` consults `base` first and `extensions` second, so definitions in
// the base database shadow any added to the builder.
DescriptorPoolBuilder::StateHolder::StateHolder(
    google::protobuf::DescriptorDatabase* base)
    : base(base), merged(base, &extensions), pool(&merged) {}

DescriptorPoolBuilder::DescriptorPoolBuilder()
    : state_(std::make_shared<StateHolder>(
          cel::GetMinimalDescriptorDatabase())) {}

std::shared_ptr<const google::protobuf::DescriptorPool>
DescriptorPoolBuilder::Build() && {
  // Aliasing constructor: the result points at the pool but shares ownership
  // of the whole StateHolder, keeping the backing databases alive.
  auto alias = std::shared_ptr<const google::protobuf::DescriptorPool>(
      state_, &state_->pool);
  // The builder is consumed; it must not be used after this call.
  state_.reset();
  return alias;
}

absl::Status DescriptorPoolBuilder::AddTransitiveDescriptorSet(
    const google::protobuf::Descriptor* absl_nonnull desc) {
  absl::flat_hash_set<const google::protobuf::FileDescriptor*> resolved;
  std::vector<const google::protobuf::FileDescriptor*> to_resolve{
      desc->file()};
  return FindDeps(to_resolve, resolved, *this);
}

absl::Status DescriptorPoolBuilder::AddTransitiveDescriptorSet(
    absl::Span<const google::protobuf::Descriptor* const> descs) {
  absl::flat_hash_set<const google::protobuf::FileDescriptor*> resolved;
  std::vector<const google::protobuf::FileDescriptor*> to_resolve;
  to_resolve.reserve(descs.size());
  for (const google::protobuf::Descriptor* desc : descs) {
    to_resolve.push_back(desc->file());
  }
  return FindDeps(to_resolve, resolved, *this);
}

absl::Status DescriptorPoolBuilder::AddFileDescriptor(
    const google::protobuf::FileDescriptorProto& file) {
  // SimpleDescriptorDatabase::Add reports failure (e.g. a file with the same
  // name already added) by returning false.
  if (!state_->extensions.Add(file)) {
    return absl::InvalidArgumentError(
        absl::StrCat("proto descriptor conflict: ", file.name()));
  }
  return absl::OkStatus();
}

absl::Status DescriptorPoolBuilder::AddFileDescriptorSet(
    const google::protobuf::FileDescriptorSet& files) {
  for (const google::protobuf::FileDescriptorProto& file : files.file()) {
    CEL_RETURN_IF_ERROR(AddFileDescriptor(file));
  }
  return absl::OkStatus();
}

}  // namespace cel

================================================
FILE: tools/descriptor_pool_builder.h
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_TOOLS_DESCRIPTOR_POOL_BUILDER_H_
#define THIRD_PARTY_CEL_CPP_TOOLS_DESCRIPTOR_POOL_BUILDER_H_

#include <memory>
#include <utility>

#include "google/protobuf/descriptor.pb.h"
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/descriptor_database.h"

namespace cel {

// A helper class for building a descriptor pool from a set of proto file
// descriptors. Manages lifetime for the descriptor databases backing
// the pool.
//
// Client must ensure that types are not added multiple times.
//
// Note: in the constructed pool, the definitions for the required types for
// CEL will shadow any added to the builder. Clients should not modify types
// from the google.protobuf package in general, but if they do the behavior of
// the constructed descriptor pool will be inconsistent.
class DescriptorPoolBuilder {
 public:
  DescriptorPoolBuilder();

  // Neither copyable nor movable.
  DescriptorPoolBuilder& operator=(const DescriptorPoolBuilder&) = delete;
  DescriptorPoolBuilder(const DescriptorPoolBuilder&) = delete;
  DescriptorPoolBuilder& operator=(DescriptorPoolBuilder&&) = delete;
  DescriptorPoolBuilder(DescriptorPoolBuilder&&) = delete;

  ~DescriptorPoolBuilder() = default;

  // Returns a shared pointer to the new descriptor pool that manages the
  // underlying descriptor databases backing the pool.
  //
  // Consumes the builder instance. It is unsafe to make any further changes
  // to the descriptor databases after accessing the pool.
  std::shared_ptr<const google::protobuf::DescriptorPool> Build() &&;

  // Utility for adding the transitive dependencies of a message with a linked
  // descriptor.
  absl::Status AddTransitiveDescriptorSet(
      const google::protobuf::Descriptor* absl_nonnull desc);

  // Same as above, for several messages at once; shared dependencies are
  // added only once.
  absl::Status AddTransitiveDescriptorSet(
      absl::Span<const google::protobuf::Descriptor* const> descs);

  // Adds a file descriptor set to the pool. Client must ensure that all
  // dependencies are satisfied and that files are not added multiple times.
  absl::Status AddFileDescriptorSet(
      const google::protobuf::FileDescriptorSet& files);

  // Adds a single proto file descriptor to the pool. Client must ensure
  // that all dependencies are satisfied and that files are not added multiple
  // times. Returns InvalidArgument if the file could not be added (e.g. a
  // file with the same name was already added).
  absl::Status AddFileDescriptor(
      const google::protobuf::FileDescriptorProto& file);

 private:
  struct StateHolder {
    explicit StateHolder(google::protobuf::DescriptorDatabase* base);

    google::protobuf::DescriptorDatabase* base;
    google::protobuf::SimpleDescriptorDatabase extensions;
    google::protobuf::MergedDescriptorDatabase merged;
    google::protobuf::DescriptorPool pool;
  };

  explicit DescriptorPoolBuilder(std::shared_ptr<StateHolder> state)
      : state_(std::move(state)) {}

  std::shared_ptr<StateHolder> state_;
};

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_TOOLS_DESCRIPTOR_POOL_BUILDER_H_

================================================
FILE: tools/descriptor_pool_builder_test.cc
================================================
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tools/descriptor_pool_builder.h"

#include <utility>

#include "google/protobuf/descriptor.pb.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "internal/testing.h"
#include "cel/expr/conformance/proto2/test_all_types.pb.h"
#include "cel/expr/conformance/proto2/test_all_types_extensions.pb.h"
#include "google/protobuf/text_format.h"

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::testing::IsNull;
using ::testing::NotNull;

namespace cel {
namespace {

TEST(DescriptorPoolBuilderTest, IncludesDefaults) {
  DescriptorPoolBuilder builder;

  auto pool = std::move(builder).Build();

  EXPECT_THAT(
      pool->FindMessageTypeByName("cel.expr.conformance.proto2.TestAllTypes"),
      IsNull());
  EXPECT_THAT(pool->FindMessageTypeByName("google.protobuf.Timestamp"),
              NotNull());
  EXPECT_THAT(pool->FindMessageTypeByName("google.protobuf.Any"), NotNull());
}

TEST(DescriptorPoolBuilderTest, AddTransitiveDescriptorSet) {
  DescriptorPoolBuilder builder;

  ASSERT_THAT(builder.AddTransitiveDescriptorSet(
                  cel::expr::conformance::proto2::Proto2ExtensionScopedMessage::
                      descriptor()),
              IsOk());
  auto pool = std::move(builder).Build();
  EXPECT_THAT(
      pool->FindMessageTypeByName("cel.expr.conformance.proto2.TestAllTypes"),
      NotNull());
}

TEST(DescriptorPoolBuilderTest, AddTransitiveDescriptorSetSpan) {
  DescriptorPoolBuilder builder;

  const google::protobuf::Descriptor* descs[] = {
      cel::expr::conformance::proto2::TestAllTypes::descriptor(),
      cel::expr::conformance::proto2::Proto2ExtensionScopedMessage::
          descriptor()};

  ASSERT_THAT(builder.AddTransitiveDescriptorSet(descs), IsOk());
  auto pool = std::move(builder).Build();
  EXPECT_THAT(
      pool->FindMessageTypeByName("cel.expr.conformance.proto2.TestAllTypes"),
      NotNull());
}

TEST(DescriptorPoolBuilderTest, AddFileDescriptorSet) {
  DescriptorPoolBuilder builder;

  google::protobuf::FileDescriptorSet file_set;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
      R"pb(
        name: "foo.proto"
        package: "cel.test"
        dependency: "bar.proto"
        message_type {
          name: "Foo"
          field: {
            name: "bar"
            number: 1
            label: LABEL_OPTIONAL
            type: TYPE_MESSAGE
            type_name: ".cel.test.Bar"
          }
        }
      )pb",
      file_set.add_file()));
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
      R"pb(
        name: "bar.proto"
        package: "cel.test"
        message_type {
          name: "Bar"
          field: {
            name: "baz"
            number: 1
            label: LABEL_OPTIONAL
            type: TYPE_STRING
          }
        }
      )pb",
      file_set.add_file()));

  ASSERT_THAT(builder.AddFileDescriptorSet(file_set), IsOk());
  auto pool = std::move(builder).Build();
  EXPECT_THAT(pool->FindMessageTypeByName("cel.test.Foo"), NotNull());
  EXPECT_THAT(pool->FindMessageTypeByName("cel.test.Bar"), NotNull());
}

TEST(DescriptorPoolBuilderTest, BadRef) {
  DescriptorPoolBuilder builder;

  google::protobuf::FileDescriptorSet file_set;
  // Unfulfilled dependency.
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
      R"pb(
        name: "foo.proto"
        package: "cel.test"
        dependency: "bar.proto"
        message_type {
          name: "Foo"
          field: {
            name: "bar"
            number: 1
            label: LABEL_OPTIONAL
            type: TYPE_MESSAGE
            type_name: ".cel.test.Bar"
          }
        }
      )pb",
      file_set.add_file()));

  // Note: descriptor pool is initialized lazily so this will not lead to an
  // error now, but looking up the message will fail.
  ASSERT_THAT(builder.AddFileDescriptorSet(file_set), IsOk());
  auto pool = std::move(builder).Build();
  EXPECT_THAT(pool->FindMessageTypeByName("cel.test.Foo"), IsNull());
}

TEST(DescriptorPoolBuilderTest, AddFile) {
  DescriptorPoolBuilder builder;

  google::protobuf::FileDescriptorProto file;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
      R"pb(
        name: "bar.proto"
        package: "cel.test"
        message_type {
          name: "Bar"
          field: {
            name: "baz"
            number: 1
            label: LABEL_OPTIONAL
            type: TYPE_STRING
          }
        }
      )pb",
      &file));

  ASSERT_THAT(builder.AddFileDescriptor(file), IsOk());
  // Duplicate file.
  ASSERT_THAT(builder.AddFileDescriptor(file),
              StatusIs(absl::StatusCode::kInvalidArgument));
  // In this specific case, we know that the duplicate is the same so
  // the pool will still be valid.
  auto pool = std::move(builder).Build();
  EXPECT_THAT(pool->FindMessageTypeByName("cel.test.Bar"), NotNull());
}

}  // namespace
}  // namespace cel

================================================
FILE: tools/flatbuffers_backed_impl.cc
================================================
#include "tools/flatbuffers_backed_impl.h"

#include <algorithm>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "eval/public/cel_value.h"
#include "flatbuffers/flatbuffers.h"

namespace google {
namespace api {
namespace expr {
namespace runtime {

namespace {

// Overloads wrapping a widened C++ scalar in the matching CelValue kind.
CelValue CreateValue(int64_t value) { return CelValue::CreateInt64(value); }
CelValue CreateValue(uint64_t value) { return CelValue::CreateUint64(value); }
CelValue CreateValue(double value) { return CelValue::CreateDouble(value); }
CelValue CreateValue(bool value) { return CelValue::CreateBool(value); }

// CelList over a flatbuffers vector of scalars. `T` is the element type stored
// in the buffer; each element is converted to `U` and wrapped via CreateValue.
// An absent vector field (null `list_`) behaves as an empty list.
template <typename T, typename U>
class FlatBuffersListImpl : public CelList {
 public:
  FlatBuffersListImpl(const flatbuffers::Table& table,
                      const reflection::Field& field)
      : list_(table.GetPointer<const flatbuffers::Vector<T>*>(field.offset())) {
  }
  int size() const override { return list_ ? list_->size() : 0; }
  CelValue operator[](int index) const override {
    return CreateValue(static_cast<U>(list_->Get(index)));
  }

 private:
  const flatbuffers::Vector<T>* list_;
};

// CelList of string views into a flatbuffers vector of strings. An absent
// vector field (null `list_`) behaves as an empty list.
class StringListImpl : public CelList {
 public:
  explicit StringListImpl(
      const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>* list)
      : list_(list) {}
  int size() const override { return list_ ? list_->size() : 0; }
  CelValue operator[](int index) const override {
    auto value = list_->Get(index);
    return CelValue::CreateStringView(
        absl::string_view(value->c_str(), value->size()));
  }

 private:
  const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>* list_;
};

// CelList whose elements are FlatBuffersMapImpl views over the tables of a
// flatbuffers vector, all described by `object`. An absent vector field (null
// `list_`) behaves as an empty list.
class ObjectListImpl : public CelList {
 public:
  ObjectListImpl(
      const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::Table>>* list,
      const reflection::Schema& schema, const reflection::Object& object,
      google::protobuf::Arena* arena)
      : arena_(arena), list_(list), schema_(schema), object_(object) {}
  int size() const override { return list_ ? list_->size() : 0; }
  CelValue operator[](int index) const override {
    auto value = list_->Get(index);
    return CelValue::CreateMap(
        google::protobuf::Arena::Create<FlatBuffersMapImpl>(
            arena_, *value, schema_, object_, arena_));
  }

 private:
  google::protobuf::Arena* arena_;
  const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::Table>>* list_;
  const reflection::Schema& schema_;
  const reflection::Object& object_;
};

// CelMap over a vector of flatbuffers tables, keyed by the string field
// `index` of each table. Lookup is a binary search on that field, so the
// vector is assumed to be sorted by it (presumably guaranteed by flatbuffers
// for key fields -- TODO confirm for all producers). An absent vector field
// (null `list_`) behaves as an empty map.
class ObjectStringIndexedMapImpl : public CelMap {
 public:
  ObjectStringIndexedMapImpl(
      const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::Table>>* list,
      const reflection::Schema& schema, const reflection::Object& object,
      const reflection::Field& index, google::protobuf::Arena* arena)
      : arena_(arena),
        list_(list),
        schema_(schema),
        object_(object),
        index_(index) {
    keys_.parent = this;
  }

  int size() const override { return list_ ? list_->size() : 0; }

  // Returns the lookup error for non-string keys, otherwise whether an entry
  // with `key` exists.
  absl::StatusOr<bool> Has(const CelValue& key) const override {
    auto lookup_result = (*this)[key];
    if (!lookup_result.has_value()) {
      return false;
    }
    auto result = *lookup_result;
    if (result.IsError()) {
      return *(result.ErrorOrDie());
    }
    return true;
  }

  // Returns an error value for non-string keys, nullopt when no entry has the
  // key, and otherwise a map view over the matching table.
  absl::optional<CelValue> operator[](CelValue cel_key) const override {
    if (!cel_key.IsString()) {
      return CreateErrorValue(
          arena_, absl::InvalidArgumentError(
                      absl::StrCat("Invalid map key type: '",
                                   CelValue::TypeName(cel_key.type()), "'")));
    }
    // An absent vector field leaves `list_` null. size() already treats that
    // as empty; iterating it would dereference a null pointer.
    if (list_ == nullptr) {
      return absl::nullopt;
    }
    const absl::string_view key = cel_key.StringOrDie().value();
    const auto it = std::lower_bound(
        list_->begin(), list_->end(), key,
        [this](const flatbuffers::Table* t, const absl::string_view key) {
          auto value = flatbuffers::GetFieldS(*t, index_);
          auto sv = value ? absl::string_view(value->c_str(), value->size())
                          : absl::string_view();
          return sv < key;
        });
    if (it != list_->end()) {
      auto value = flatbuffers::GetFieldS(**it, index_);
      auto sv = value ? absl::string_view(value->c_str(), value->size())
                      : absl::string_view();
      if (sv == key) {
        return CelValue::CreateMap(
            google::protobuf::Arena::Create<FlatBuffersMapImpl>(
                arena_, **it, schema_, object_, arena_));
      }
    }
    return absl::nullopt;
  }

  absl::StatusOr<const CelList*> ListKeys() const override { return &keys_; }

 private:
  // Lists the `index_` field of each table, in vector order; a missing key
  // field is reported as the empty string.
  struct KeyList : public CelList {
    int size() const override { return parent->size(); }
    CelValue operator[](int index) const override {
      auto value =
          flatbuffers::GetFieldS(*(parent->list_->Get(index)), parent->index_);
      if (value == nullptr) {
        return CelValue::CreateStringView(absl::string_view());
      }
      return CelValue::CreateStringView(
          absl::string_view(value->c_str(), value->size()));
    }
    ObjectStringIndexedMapImpl* parent;
  };
  google::protobuf::Arena* arena_;
  const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::Table>>* list_;
  const reflection::Schema& schema_;
  const reflection::Object& object_;
  const reflection::Field& index_;
  KeyList keys_;
};

// Detects a "key" field of the type string.
const reflection::Field* findStringKeyField(const reflection::Object& object) { for (const auto field : *object.fields()) { if (field->key() && field->type()->base_type() == reflection::String) { return field; } } return nullptr; } } // namespace absl::StatusOr FlatBuffersMapImpl::Has(const CelValue& key) const { auto lookup_result = (*this)[key]; if (!lookup_result.has_value()) { return false; } auto result = *lookup_result; if (result.IsError()) { return *(result.ErrorOrDie()); } return true; } absl::optional FlatBuffersMapImpl::operator[]( CelValue cel_key) const { if (!cel_key.IsString()) { return CreateErrorValue( arena_, absl::InvalidArgumentError( absl::StrCat("Invalid map key type: '", CelValue::TypeName(cel_key.type()), "'"))); } auto field = keys_.fields->LookupByKey(cel_key.StringOrDie().value().data()); if (field == nullptr) { return absl::nullopt; } switch (field->type()->base_type()) { case reflection::Byte: return CelValue::CreateInt64( flatbuffers::GetFieldI(table_, *field)); case reflection::Short: return CelValue::CreateInt64( flatbuffers::GetFieldI(table_, *field)); case reflection::Int: return CelValue::CreateInt64( flatbuffers::GetFieldI(table_, *field)); case reflection::Long: return CelValue::CreateInt64( flatbuffers::GetFieldI(table_, *field)); case reflection::UByte: return CelValue::CreateUint64( flatbuffers::GetFieldI(table_, *field)); case reflection::UShort: return CelValue::CreateUint64( flatbuffers::GetFieldI(table_, *field)); case reflection::UInt: return CelValue::CreateUint64( flatbuffers::GetFieldI(table_, *field)); case reflection::ULong: return CelValue::CreateUint64( flatbuffers::GetFieldI(table_, *field)); case reflection::Float: return CelValue::CreateDouble( flatbuffers::GetFieldF(table_, *field)); case reflection::Double: return CelValue::CreateDouble( flatbuffers::GetFieldF(table_, *field)); case reflection::Bool: return CelValue::CreateBool( flatbuffers::GetFieldI(table_, *field)); case reflection::String: { auto value = 
flatbuffers::GetFieldS(table_, *field); if (value == nullptr) { return CelValue::CreateStringView(absl::string_view()); } return CelValue::CreateStringView( absl::string_view(value->c_str(), value->size())); } case reflection::Obj: { const auto* field_schema = schema_.objects()->Get(field->type()->index()); const auto* field_table = flatbuffers::GetFieldT(table_, *field); if (field_table == nullptr) { return CelValue::CreateNull(); } if (field_schema) { return CelValue::CreateMap(google::protobuf::Arena::Create( arena_, *field_table, schema_, *field_schema, arena_)); } break; } case reflection::Vector: { switch (field->type()->element()) { case reflection::Byte: case reflection::UByte: { const auto* field_table = flatbuffers::GetFieldAnyV(table_, *field); if (field_table == nullptr) { return CelValue::CreateBytesView(absl::string_view()); } return CelValue::CreateBytesView(absl::string_view( reinterpret_cast(field_table->Data()), field_table->size())); } case reflection::Short: return CelValue::CreateList( google::protobuf::Arena::Create>( arena_, table_, *field)); case reflection::Int: return CelValue::CreateList( google::protobuf::Arena::Create>( arena_, table_, *field)); case reflection::Long: return CelValue::CreateList( google::protobuf::Arena::Create>( arena_, table_, *field)); case reflection::UShort: return CelValue::CreateList( google::protobuf::Arena::Create>( arena_, table_, *field)); case reflection::UInt: return CelValue::CreateList( google::protobuf::Arena::Create>( arena_, table_, *field)); case reflection::ULong: return CelValue::CreateList( google::protobuf::Arena::Create>( arena_, table_, *field)); case reflection::Float: return CelValue::CreateList( google::protobuf::Arena::Create>( arena_, table_, *field)); case reflection::Double: return CelValue::CreateList( google::protobuf::Arena::Create>( arena_, table_, *field)); case reflection::Bool: return CelValue::CreateList( google::protobuf::Arena::Create>( arena_, table_, *field)); case 
reflection::String: return CelValue::CreateList(google::protobuf::Arena::Create( arena_, table_.GetPointer>*>( field->offset()))); case reflection::Obj: { const auto* field_schema = schema_.objects()->Get(field->type()->index()); if (field_schema) { const auto* index = findStringKeyField(*field_schema); if (index) { return CelValue::CreateMap( google::protobuf::Arena::Create( arena_, table_.GetPointer>*>( field->offset()), schema_, *field_schema, *index, arena_)); } else { return CelValue::CreateList(google::protobuf::Arena::Create( arena_, table_.GetPointer>*>( field->offset()), schema_, *field_schema, arena_)); } } break; } default: // Unsupported vector base types return absl::nullopt; } break; } default: // Unsupported types: enums, unions, arrays return absl::nullopt; } return absl::nullopt; } const CelMap* CreateFlatBuffersBackedObject(const uint8_t* flatbuf, const reflection::Schema& schema, google::protobuf::Arena* arena) { return google::protobuf::Arena::Create( arena, *flatbuffers::GetAnyRoot(flatbuf), schema, *schema.root_table(), arena); } } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: tools/flatbuffers_backed_impl.h ================================================ #ifndef THIRD_PARTY_CEL_CPP_TOOLS_FLATBUFFERS_BACKED_IMPL_H_ #define THIRD_PARTY_CEL_CPP_TOOLS_FLATBUFFERS_BACKED_IMPL_H_ #include "eval/public/cel_value.h" #include "flatbuffers/reflection.h" namespace google { namespace api { namespace expr { namespace runtime { class FlatBuffersMapImpl : public CelMap { public: FlatBuffersMapImpl(const flatbuffers::Table& table, const reflection::Schema& schema, const reflection::Object& object, google::protobuf::Arena* arena) : arena_(arena), table_(table), schema_(schema) { keys_.fields = object.fields(); } int size() const override { return keys_.fields->size(); } absl::StatusOr Has(const CelValue& key) const override; absl::optional operator[](CelValue cel_key) 
const override; // Import base class signatures to bypass GCC warning/error. using CelMap::ListKeys; absl::StatusOr ListKeys() const override { return &keys_; } private: struct FieldList : public CelList { int size() const override { return fields->size(); } CelValue operator[](int index) const override { auto name = fields->Get(index)->name(); return CelValue::CreateStringView( absl::string_view(name->c_str(), name->size())); } const flatbuffers::Vector>* fields; }; FieldList keys_; google::protobuf::Arena* arena_; const flatbuffers::Table& table_; const reflection::Schema& schema_; }; // Factory method to instantiate a CelValue on the arena for flatbuffer object // from a reflection schema. const CelMap* CreateFlatBuffersBackedObject(const uint8_t* flatbuf, const reflection::Schema& schema, google::protobuf::Arena* arena); } // namespace runtime } // namespace expr } // namespace api } // namespace google #endif // THIRD_PARTY_CEL_CPP_TOOLS_FLATBUFFERS_BACKED_IMPL_H_ ================================================ FILE: tools/flatbuffers_backed_impl_test.cc ================================================ #include "tools/flatbuffers_backed_impl.h" #include #include "internal/status_macros.h" #include "internal/testing.h" #include "flatbuffers/idl.h" #include "flatbuffers/reflection.h" namespace google { namespace api { namespace expr { namespace runtime { namespace { constexpr char kReflectionBufferPath[] = "tools/testdata/" "flatbuffers.bfbs"; constexpr absl::string_view kByteField = "f_byte"; constexpr absl::string_view kUbyteField = "f_ubyte"; constexpr absl::string_view kShortField = "f_short"; constexpr absl::string_view kUshortField = "f_ushort"; constexpr absl::string_view kIntField = "f_int"; constexpr absl::string_view kUintField = "f_uint"; constexpr absl::string_view kLongField = "f_long"; constexpr absl::string_view kUlongField = "f_ulong"; constexpr absl::string_view kFloatField = "f_float"; constexpr absl::string_view kDoubleField = "f_double"; 
constexpr absl::string_view kBoolField = "f_bool"; constexpr absl::string_view kStringField = "f_string"; constexpr absl::string_view kObjField = "f_obj"; constexpr absl::string_view kUnknownField = "f_unknown"; constexpr absl::string_view kBytesField = "r_byte"; constexpr absl::string_view kUbytesField = "r_ubyte"; constexpr absl::string_view kShortsField = "r_short"; constexpr absl::string_view kUshortsField = "r_ushort"; constexpr absl::string_view kIntsField = "r_int"; constexpr absl::string_view kUintsField = "r_uint"; constexpr absl::string_view kLongsField = "r_long"; constexpr absl::string_view kUlongsField = "r_ulong"; constexpr absl::string_view kFloatsField = "r_float"; constexpr absl::string_view kDoublesField = "r_double"; constexpr absl::string_view kBoolsField = "r_bool"; constexpr absl::string_view kStringsField = "r_string"; constexpr absl::string_view kObjsField = "r_obj"; constexpr absl::string_view kIndexedField = "r_indexed"; const int64_t kNumFields = 27; class FlatBuffersTest : public testing::Test { public: FlatBuffersTest() { EXPECT_TRUE( flatbuffers::LoadFile(kReflectionBufferPath, true, &schema_file_)); flatbuffers::Verifier verifier( reinterpret_cast(schema_file_.data()), schema_file_.size()); EXPECT_TRUE(reflection::VerifySchemaBuffer(verifier)); EXPECT_TRUE(parser_.Deserialize( reinterpret_cast(schema_file_.data()), schema_file_.size())); schema_ = reflection::GetSchema(schema_file_.data()); } const CelMap& loadJson(std::string data) { EXPECT_TRUE(parser_.Parse(data.data())); const CelMap* value = CreateFlatBuffersBackedObject( parser_.builder_.GetBufferPointer(), *schema_, &arena_); EXPECT_NE(nullptr, value); EXPECT_EQ(kNumFields, value->size()); const CelList* keys = value->ListKeys().value(); EXPECT_NE(nullptr, keys); EXPECT_EQ(kNumFields, keys->size()); EXPECT_TRUE((*keys)[2].IsString()); return *value; } protected: std::string schema_file_; flatbuffers::Parser parser_; const reflection::Schema* schema_; google::protobuf::Arena 
arena_; }; TEST_F(FlatBuffersTest, PrimitiveFields) { const CelMap& value = loadJson(R"({ f_byte: -1, f_ubyte: 1, f_short: -2, f_ushort: 2, f_int: -3, f_uint: 3, f_long: -4, f_ulong: 4, f_float: 5.0, f_double: 6.0, f_bool: false, f_string: "test" })"); // byte { auto f = value[CelValue::CreateStringView(kByteField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsInt64()); EXPECT_EQ(-1, f->Int64OrDie()); } { auto uf = value[CelValue::CreateStringView(kUbyteField)]; EXPECT_TRUE(uf.has_value()); EXPECT_TRUE(uf->IsUint64()); EXPECT_EQ(1, uf->Uint64OrDie()); } // short { auto f = value[CelValue::CreateStringView(kShortField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsInt64()); EXPECT_EQ(-2, f->Int64OrDie()); } { auto uf = value[CelValue::CreateStringView(kUshortField)]; EXPECT_TRUE(uf.has_value()); EXPECT_TRUE(uf->IsUint64()); EXPECT_EQ(2, uf->Uint64OrDie()); } // int { auto f = value[CelValue::CreateStringView(kIntField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsInt64()); EXPECT_EQ(-3, f->Int64OrDie()); } { auto uf = value[CelValue::CreateStringView(kUintField)]; EXPECT_TRUE(uf.has_value()); EXPECT_TRUE(uf->IsUint64()); EXPECT_EQ(3, uf->Uint64OrDie()); } // long { auto f = value[CelValue::CreateStringView(kLongField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsInt64()); EXPECT_EQ(-4, f->Int64OrDie()); } { auto uf = value[CelValue::CreateStringView(kUlongField)]; EXPECT_TRUE(uf.has_value()); EXPECT_TRUE(uf->IsUint64()); EXPECT_EQ(4, uf->Uint64OrDie()); } // float and double { auto f = value[CelValue::CreateStringView(kFloatField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsDouble()); EXPECT_EQ(5.0, f->DoubleOrDie()); } { auto f = value[CelValue::CreateStringView(kDoubleField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsDouble()); EXPECT_EQ(6.0, f->DoubleOrDie()); } // bool { auto f = value[CelValue::CreateStringView(kBoolField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsBool()); EXPECT_EQ(false, f->BoolOrDie()); } // string { auto f = 
value[CelValue::CreateStringView(kStringField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsString()); EXPECT_EQ("test", f->StringOrDie().value()); } // bad field type { CelValue bad_field = CelValue::CreateInt64(1); auto f = value[bad_field]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsError()); auto presence = value.Has(bad_field); EXPECT_FALSE(presence.ok()); EXPECT_EQ(presence.status().code(), absl::StatusCode::kInvalidArgument); } // missing field { auto f = value[CelValue::CreateStringView(kUnknownField)]; EXPECT_FALSE(f.has_value()); } } TEST_F(FlatBuffersTest, PrimitiveFieldDefaults) { const CelMap& value = loadJson("{}"); // byte { auto f = value[CelValue::CreateStringView(kByteField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsInt64()); EXPECT_EQ(0, f->Int64OrDie()); } // short { auto f = value[CelValue::CreateStringView(kShortField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsInt64()); EXPECT_EQ(150, f->Int64OrDie()); } // bool { auto f = value[CelValue::CreateStringView(kBoolField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsBool()); EXPECT_EQ(true, f->BoolOrDie()); } // string { auto f = value[CelValue::CreateStringView(kStringField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsString()); EXPECT_EQ("", f->StringOrDie().value()); } } TEST_F(FlatBuffersTest, ObjectField) { const CelMap& value = loadJson(R"({ f_obj: { f_string: "entry", f_int: 16 } })"); CelValue field = CelValue::CreateStringView(kObjField); auto presence = value.Has(field); EXPECT_OK(presence); EXPECT_TRUE(*presence); auto f = value[field]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsMap()); const CelMap& m = *f->MapOrDie(); EXPECT_EQ(2, m.size()); { auto obj_field = CelValue::CreateStringView(kStringField); auto member_presence = m.Has(obj_field); EXPECT_OK(member_presence); EXPECT_TRUE(*member_presence); auto mf = m[obj_field]; EXPECT_TRUE(mf.has_value()); EXPECT_TRUE(mf->IsString()); EXPECT_EQ("entry", mf->StringOrDie().value()); } { auto obj_field = 
CelValue::CreateStringView(kIntField); auto member_presence = m.Has(obj_field); EXPECT_OK(member_presence); EXPECT_TRUE(*member_presence); auto mf = m[obj_field]; EXPECT_TRUE(mf.has_value()); EXPECT_TRUE(mf->IsInt64()); EXPECT_EQ(16, mf->Int64OrDie()); } { std::string undefined = "f_undefined"; CelValue undefined_field = CelValue::CreateStringView(undefined); auto presence = m.Has(undefined_field); EXPECT_OK(presence); EXPECT_FALSE(*presence); auto v = m[undefined_field]; EXPECT_FALSE(v.has_value()); presence = m.Has(CelValue::CreateBool(false)); EXPECT_FALSE(presence.ok()); EXPECT_EQ(presence.status().code(), absl::StatusCode::kInvalidArgument); } } TEST_F(FlatBuffersTest, ObjectFieldDefault) { const CelMap& value = loadJson("{}"); auto f = value[CelValue::CreateStringView(kObjField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsNull()); } TEST_F(FlatBuffersTest, PrimitiveVectorFields) { const CelMap& value = loadJson(R"({ r_byte: [-97], r_ubyte: [97, 98, 99], r_short: [-2], r_ushort: [2], r_int: [-3], r_uint: [3], r_long: [-4], r_ulong: [4], r_float: [5.0], r_double: [6.0], r_bool: [false], r_string: ["test"] })"); // byte { auto f = value[CelValue::CreateStringView(kBytesField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsBytes()); EXPECT_EQ("\x9F", f->BytesOrDie().value()); } { auto uf = value[CelValue::CreateStringView(kUbytesField)]; EXPECT_TRUE(uf.has_value()); EXPECT_TRUE(uf->IsBytes()); EXPECT_EQ("abc", uf->BytesOrDie().value()); } // short { auto f = value[CelValue::CreateStringView(kShortsField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsList()); const CelList& l = *f->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(-2, l[0].Int64OrDie()); } { auto uf = value[CelValue::CreateStringView(kUshortsField)]; EXPECT_TRUE(uf.has_value()); EXPECT_TRUE(uf->IsList()); const CelList& l = *uf->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(2, l[0].Uint64OrDie()); } // int { auto f = value[CelValue::CreateStringView(kIntsField)]; EXPECT_TRUE(f.has_value()); 
EXPECT_TRUE(f->IsList()); const CelList& l = *f->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(-3, l[0].Int64OrDie()); } { auto uf = value[CelValue::CreateStringView(kUintsField)]; EXPECT_TRUE(uf.has_value()); EXPECT_TRUE(uf->IsList()); const CelList& l = *uf->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(3, l[0].Uint64OrDie()); } // long { auto f = value[CelValue::CreateStringView(kLongsField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsList()); const CelList& l = *f->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(-4, l[0].Int64OrDie()); } { auto uf = value[CelValue::CreateStringView(kUlongsField)]; EXPECT_TRUE(uf.has_value()); EXPECT_TRUE(uf->IsList()); const CelList& l = *uf->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(4, l[0].Uint64OrDie()); } // float and double { auto f = value[CelValue::CreateStringView(kFloatsField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsList()); const CelList& l = *f->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(5.0, l[0].DoubleOrDie()); } { auto f = value[CelValue::CreateStringView(kDoublesField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsList()); const CelList& l = *f->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(6.0, l[0].DoubleOrDie()); } // bool { auto f = value[CelValue::CreateStringView(kBoolsField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsList()); const CelList& l = *f->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ(false, l[0].BoolOrDie()); } // string { auto f = value[CelValue::CreateStringView(kStringsField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsList()); const CelList& l = *f->ListOrDie(); EXPECT_EQ(1, l.size()); EXPECT_EQ("test", l[0].StringOrDie().value()); } } TEST_F(FlatBuffersTest, ObjectVectorField) { const CelMap& value = loadJson(R"({ r_obj: [{ f_string: "entry", f_int: 16 },{ f_int: 32 }] })"); auto f = value[CelValue::CreateStringView(kObjsField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsList()); const CelList& l = *f->ListOrDie(); EXPECT_EQ(2, l.size()); { 
EXPECT_TRUE(l[0].IsMap()); const CelMap& m = *l[0].MapOrDie(); EXPECT_EQ(2, m.size()); { CelValue field = CelValue::CreateStringView(kStringField); auto presence = m.Has(field); EXPECT_OK(presence); EXPECT_TRUE(*presence); auto mf = m[field]; EXPECT_TRUE(mf.has_value()); EXPECT_TRUE(mf->IsString()); EXPECT_EQ("entry", mf->StringOrDie().value()); } { CelValue field = CelValue::CreateStringView(kIntField); auto presence = m.Has(field); EXPECT_OK(presence); EXPECT_TRUE(*presence); auto mf = m[field]; EXPECT_TRUE(mf.has_value()); EXPECT_TRUE(mf->IsInt64()); EXPECT_EQ(16, mf->Int64OrDie()); } } { EXPECT_TRUE(l[1].IsMap()); const CelMap& m = *l[1].MapOrDie(); EXPECT_EQ(2, m.size()); { CelValue field = CelValue::CreateStringView(kStringField); auto presence = m.Has(field); EXPECT_OK(presence); // Note, the presence checks on flat buffers seem to only apply to whether // the field is defined. EXPECT_TRUE(*presence); auto mf = m[field]; EXPECT_TRUE(mf.has_value()); EXPECT_TRUE(mf->IsString()); EXPECT_EQ("", mf->StringOrDie().value()); } { CelValue field = CelValue::CreateStringView(kIntField); auto presence = m.Has(field); EXPECT_OK(presence); EXPECT_TRUE(*presence); auto mf = m[field]; EXPECT_TRUE(mf.has_value()); EXPECT_TRUE(mf->IsInt64()); EXPECT_EQ(32, mf->Int64OrDie()); } { std::string undefined = "f_undefined"; CelValue field = CelValue::CreateStringView(undefined); auto presence = m.Has(field); EXPECT_OK(presence); EXPECT_FALSE(*presence); auto mf = m[field]; EXPECT_FALSE(mf.has_value()); } } } TEST_F(FlatBuffersTest, VectorFieldDefaults) { const CelMap& value = loadJson("{}"); for (const auto field : std::vector{ kIntsField, kBoolsField, kStringsField, kObjsField}) { auto f = value[CelValue::CreateStringView(field)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsList()); const CelList& l = *f->ListOrDie(); EXPECT_EQ(0, l.size()); } { auto f = value[CelValue::CreateStringView(kIndexedField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsMap()); const CelMap& m = 
*f->MapOrDie(); EXPECT_EQ(0, m.size()); EXPECT_EQ(0, (*m.ListKeys())->size()); } { auto f = value[CelValue::CreateStringView(kBytesField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsBytes()); EXPECT_EQ("", f->BytesOrDie().value()); } } TEST_F(FlatBuffersTest, IndexedObjectVectorField) { const CelMap& value = loadJson(R"({ r_indexed: [ { f_string: "a", f_int: 16 }, { f_string: "b", f_int: 32 }, { f_string: "c", f_int: 64 }, { f_string: "d", f_int: 128 } ] })"); auto f = value[CelValue::CreateStringView(kIndexedField)]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsMap()); const CelMap& m = *f->MapOrDie(); EXPECT_EQ(4, m.size()); const CelList& l = *m.ListKeys().value(); EXPECT_EQ(4, l.size()); EXPECT_TRUE(l[0].IsString()); EXPECT_TRUE(l[1].IsString()); EXPECT_TRUE(l[2].IsString()); EXPECT_TRUE(l[3].IsString()); std::string a = "a"; std::string b = "b"; std::string c = "c"; std::string d = "d"; EXPECT_EQ(a, l[0].StringOrDie().value()); EXPECT_EQ(b, l[1].StringOrDie().value()); EXPECT_EQ(c, l[2].StringOrDie().value()); EXPECT_EQ(d, l[3].StringOrDie().value()); for (const std::string& key : std::vector{a, b, c, d}) { auto v = m[CelValue::CreateString(&key)]; EXPECT_TRUE(v.has_value()); const CelMap& vm = *v->MapOrDie(); EXPECT_EQ(2, vm.size()); auto vf = vm[CelValue::CreateStringView(kStringField)]; EXPECT_TRUE(vf.has_value()); EXPECT_TRUE(vf->IsString()); EXPECT_EQ(key, vf->StringOrDie().value()); auto vi = vm[CelValue::CreateStringView(kIntField)]; EXPECT_TRUE(vi.has_value()); EXPECT_TRUE(vi->IsInt64()); } { std::string bb = "bb"; std::string dd = "dd"; EXPECT_FALSE(m[CelValue::CreateString(&bb)].has_value()); EXPECT_FALSE(m[CelValue::CreateString(&dd)].has_value()); EXPECT_FALSE( m[CelValue::CreateStringView(absl::string_view())].has_value()); } } TEST_F(FlatBuffersTest, IndexedObjectVectorFieldDefaults) { const CelMap& value = loadJson(R"({ r_indexed: [ { f_string: "", f_int: 16 } ] })"); CelValue field = CelValue::CreateStringView(kIndexedField); auto presence = 
value.Has(field); EXPECT_OK(presence); EXPECT_TRUE(*presence); auto f = value[field]; EXPECT_TRUE(f.has_value()); EXPECT_TRUE(f->IsMap()); const CelMap& m = *f->MapOrDie(); EXPECT_EQ(1, m.size()); const CelList& l = *m.ListKeys().value(); EXPECT_EQ(1, l.size()); EXPECT_TRUE(l[0].IsString()); EXPECT_EQ("", l[0].StringOrDie().value()); CelValue map_field = CelValue::CreateStringView(absl::string_view()); presence = m.Has(map_field); EXPECT_OK(presence); EXPECT_TRUE(*presence); auto v = m[map_field]; EXPECT_TRUE(v.has_value()); std::string undefined = "f_undefined"; CelValue undefined_field = CelValue::CreateStringView(undefined); presence = m.Has(undefined_field); EXPECT_OK(presence); EXPECT_FALSE(*presence); v = m[undefined_field]; EXPECT_FALSE(v.has_value()); presence = m.Has(CelValue::CreateBool(false)); EXPECT_FALSE(presence.ok()); EXPECT_EQ(presence.status().code(), absl::StatusCode::kInvalidArgument); } } // namespace } // namespace runtime } // namespace expr } // namespace api } // namespace google ================================================ FILE: tools/navigable_ast.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "tools/navigable_ast.h"

#include
#include
#include
#include
#include
#include

#include "cel/expr/checked.pb.h"
#include "absl/container/flat_hash_map.h"
#include "absl/functional/any_invocable.h"
#include "absl/memory/memory.h"
#include "common/ast/navigable_ast_internal.h"
#include "eval/public/ast_traverse.h"
#include "eval/public/ast_visitor.h"
#include "eval/public/ast_visitor_base.h"
#include "eval/public/source_position.h"

namespace cel {

namespace {

using ::cel::expr::Expr;
using ::google::api::expr::runtime::AstTraverse;
using ::google::api::expr::runtime::SourcePosition;

using AstNode = NavigableProtoAstNode;
using NavigableAstNodeData = common_internal::NavigableAstNodeData;
using NavigableAstMetadata = common_internal::NavigableAstMetadata;

// Maps the populated `expr_kind` oneof of `expr` to the navigable NodeKind.
//
// Struct expressions are split in two: a non-empty message name denotes a
// message literal (kStruct), an empty one denotes a map literal (kMap).
// An unset oneof, or any case not listed here, yields kUnspecified.
NodeKind GetNodeKind(const Expr& expr) {
  switch (expr.expr_kind_case()) {
    case Expr::kConstExpr:
      return NodeKind::kConstant;
    case Expr::kIdentExpr:
      return NodeKind::kIdent;
    case Expr::kSelectExpr:
      return NodeKind::kSelect;
    case Expr::kCallExpr:
      return NodeKind::kCall;
    case Expr::kListExpr:
      return NodeKind::kList;
    case Expr::kStructExpr:
      return expr.struct_expr().message_name().empty() ? NodeKind::kMap
                                                       : NodeKind::kStruct;
    case Expr::kComprehensionExpr:
      return NodeKind::kComprehension;
    case Expr::EXPR_KIND_NOT_SET:
    default:
      break;
  }
  return NodeKind::kUnspecified;
}

// Get the traversal relationship from parent to the given node.
// Note: these depend on the ast_visitor utility's traversal ordering.
ChildKind GetChildKind(const NavigableAstNodeData& parent_node, size_t child_index) { constexpr size_t kComprehensionRangeArgIndex = google::api::expr::runtime::ITER_RANGE; constexpr size_t kComprehensionInitArgIndex = google::api::expr::runtime::ACCU_INIT; constexpr size_t kComprehensionConditionArgIndex = google::api::expr::runtime::LOOP_CONDITION; constexpr size_t kComprehensionLoopStepArgIndex = google::api::expr::runtime::LOOP_STEP; constexpr size_t kComprehensionResultArgIndex = google::api::expr::runtime::RESULT; switch (parent_node.node_kind) { case NodeKind::kStruct: return ChildKind::kStructValue; case NodeKind::kMap: if (child_index % 2 == 0) { return ChildKind::kMapKey; } return ChildKind::kMapValue; case NodeKind::kList: return ChildKind::kListElem; case NodeKind::kSelect: return ChildKind::kSelectOperand; case NodeKind::kCall: if (child_index == 0 && parent_node.expr->call_expr().has_target()) { return ChildKind::kCallReceiver; } return ChildKind::kCallArg; case NodeKind::kComprehension: switch (child_index) { case kComprehensionRangeArgIndex: return ChildKind::kComprehensionRange; case kComprehensionInitArgIndex: return ChildKind::kComprehensionInit; case kComprehensionConditionArgIndex: return ChildKind::kComprehensionCondition; case kComprehensionLoopStepArgIndex: return ChildKind::kComprehensionLoopStep; case kComprehensionResultArgIndex: return ChildKind::kComprensionResult; default: return ChildKind::kUnspecified; } default: return ChildKind::kUnspecified; } } class NavigableExprBuilderVisitor : public google::api::expr::runtime::AstVisitorBase { public: NavigableExprBuilderVisitor( absl::AnyInvocable()> node_factory, absl::AnyInvocable node_data_accessor) : node_factory_(std::move(node_factory)), node_data_accessor_(std::move(node_data_accessor)), metadata_(std::make_unique()) {} NavigableAstNodeData& NodeDataAt(size_t index) { return node_data_accessor_(*metadata_->nodes[index]); } void PreVisitExpr(const Expr* expr, const SourcePosition* 
position) override { NavigableProtoAstNode* parent = parent_stack_.empty() ? nullptr : metadata_->nodes[parent_stack_.back()].get(); size_t index = metadata_->nodes.size(); metadata_->nodes.push_back(node_factory_()); NavigableProtoAstNode* node = metadata_->nodes[index].get(); auto& node_data = NodeDataAt(index); node_data.parent = parent; node_data.expr = expr; node_data.parent_relation = ChildKind::kUnspecified; node_data.node_kind = GetNodeKind(*expr); node_data.tree_size = 1; node_data.height = 1; node_data.index = index; node_data.child_index = -1; node_data.metadata = metadata_.get(); metadata_->id_to_node.insert({expr->id(), node}); metadata_->expr_to_node.insert({expr, node}); if (!parent_stack_.empty()) { auto& parent_node_data = NodeDataAt(parent_stack_.back()); size_t child_index = parent_node_data.children.size(); parent_node_data.children.push_back(node); node_data.parent_relation = GetChildKind(parent_node_data, child_index); node_data.child_index = child_index; } parent_stack_.push_back(index); } void PostVisitExpr(const Expr* expr, const SourcePosition* position) override { size_t idx = parent_stack_.back(); parent_stack_.pop_back(); metadata_->postorder.push_back(metadata_->nodes[idx].get()); NavigableAstNodeData& node = NodeDataAt(idx); if (!parent_stack_.empty()) { auto& parent_node_data = NodeDataAt(parent_stack_.back()); parent_node_data.tree_size += node.tree_size; parent_node_data.height = std::max(parent_node_data.height, node.height + 1); } } std::unique_ptr Consume() && { return std::move(metadata_); } private: absl::AnyInvocable()> node_factory_; absl::AnyInvocable node_data_accessor_; std::unique_ptr metadata_; std::vector parent_stack_; }; } // namespace NavigableProtoAst NavigableProtoAst::Build(const Expr& expr) { NavigableExprBuilderVisitor visitor( []() { return absl::WrapUnique(new AstNode()); }, [](AstNode& node) -> NavigableAstNodeData& { return node.data_; }); AstTraverse(&expr, /*source_info=*/nullptr, &visitor); return 
NavigableProtoAst(std::move(visitor).Consume()); } } // namespace cel ================================================ FILE: tools/navigable_ast.h ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_TOOLS_NAVIGABLE_AST_H_ #define THIRD_PARTY_CEL_CPP_TOOLS_NAVIGABLE_AST_H_ #include "cel/expr/syntax.pb.h" #include "common/ast/navigable_ast_internal.h" #include "common/ast/navigable_ast_kinds.h" // IWYU pragma: export namespace cel { class NavigableProtoAst; class NavigableProtoAstNode; namespace common_internal { struct ProtoAstTraits { using ExprType = cel::expr::Expr; using AstType = NavigableProtoAst; using NodeType = NavigableProtoAstNode; }; } // namespace common_internal // Wrapper around a CEL AST node that exposes traversal information. class NavigableProtoAstNode : public common_internal::NavigableAstNodeBase< common_internal::ProtoAstTraits> { private: using Base = common_internal::NavigableAstNodeBase; public: // A const Span like type that provides pre-order traversal for a sub tree. // provides .begin() and .end() returning bidirectional iterators to // const AstNode&. using PreorderRange = Base::PreorderRange; // A const Span like type that provides post-order traversal for a sub tree. // provides .begin() and .end() returning bidirectional iterators to // const AstNode&. 
using PostorderRange = Base::PostorderRange; // The parent of this node or nullptr if it is a root. using Base::parent; // The ptr to the backing Expr in the source AST. // // This may dangle if the source AST is mutated or destroyed. using Base::expr; // The index of this node in the parent's children. -1 if this is a root. using Base::child_index; // The type of traversal from parent to this node. using Base::parent_relation; // The type of this node, analogous to Expr::ExprKindCase. using Base::node_kind; // The number of nodes in the tree rooted at this node (including self). using Base::tree_size; // The height of this node in the tree (the number of descendants including // self on the longest path). using Base::height; // The children of this node in their natural order. using Base::children; // Range over the descendants of this node (including self) using preorder // semantics. Each node is visited immediately before all of its descendants. // // example: // for (const cel::NavigableProtoAstNode& node : // ast.Root().DescendantsPreorder()) { // ... // } // // Children are traversed in their natural order: // - call arguments are traversed in order (receiver if present is first) // - list elements are traversed in order // - maps are traversed in order (alternating key, value per entry) // - comprehensions are traversed in the order: range, accu_init, condition, // step, result using Base::DescendantsPreorder; // Range over the descendants of this node (including self) using postorder // semantics. Each node is visited immediately after all of its descendants. using Base::DescendantsPostorder; private: friend class NavigableProtoAst; NavigableProtoAstNode() = default; }; // NavigableExpr provides a view over a CEL AST that allows for generalized // traversal. The traversal structures are eagerly built on construction, // requiring a full traversal of the AST. 
This is intended for use in tools that // might require random access or multiple passes over the AST, amortizing the // cost of building the traversal structures. // // Pointers to AstNodes are owned by this instance and must not outlive it. // // `NavigableAst` and Navigable nodes are independent of the input Expr and may // outlive it, but may contain dangling pointers if the input Expr is modified // or destroyed. class NavigableProtoAst : public common_internal::NavigableAstBase< common_internal::ProtoAstTraits> { private: using Base = common_internal::NavigableAstBase; public: static NavigableProtoAst Build(const cel::expr::Expr& expr); // Default constructor creates an empty instance. // // Operations other than equality are undefined on an empty instance. // // This is intended for composed object construction, a new NavigableProtoAst // should be obtained from the Build factory function. NavigableProtoAst() = default; // Move only. NavigableProtoAst(const NavigableProtoAst&) = delete; NavigableProtoAst& operator=(const NavigableProtoAst&) = delete; NavigableProtoAst(NavigableProtoAst&&) = default; NavigableProtoAst& operator=(NavigableProtoAst&&) = default; // Return ptr to the AST node with id if present. Otherwise returns nullptr. // // If ids are non-unique, the first pre-order node encountered with id is // returned. using Base::FindId; // Return ptr to the AST node representing the given Expr node. using Base::FindExpr; // Returns the root of the AST. using Base::Root; // Return whether the source AST used unique IDs for each node. // // This is typically the case, but older versions of the parsers didn't // guarantee uniqueness for nodes generated by some macros and ASTs modified // outside of CEL's parse/type check may not have unique IDs. 
using Base::IdsAreUnique; private: using Base::Base; }; } // namespace cel #endif // THIRD_PARTY_CEL_CPP_TOOLS_NAVIGABLE_AST_H_ ================================================ FILE: tools/navigable_ast_test.cc ================================================ // Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "tools/navigable_ast.h" #include #include #include "cel/expr/syntax.pb.h" #include "base/builtins.h" #include "internal/testing.h" #include "parser/parser.h" namespace cel { namespace { using ::cel::expr::Expr; using ::google::api::expr::parser::Parse; using ::testing::ElementsAre; using ::testing::IsEmpty; using ::testing::Pair; using ::testing::SizeIs; TEST(NavigableProtoAst, Basic) { Expr const_node; const_node.set_id(1); const_node.mutable_const_expr()->set_int64_value(42); NavigableProtoAst ast = NavigableProtoAst::Build(const_node); EXPECT_TRUE(ast.IdsAreUnique()); const NavigableProtoAstNode& root = ast.Root(); EXPECT_EQ(root.expr(), &const_node); EXPECT_THAT(root.children(), IsEmpty()); EXPECT_TRUE(root.parent() == nullptr); EXPECT_EQ(root.child_index(), -1); EXPECT_EQ(root.node_kind(), NodeKind::kConstant); EXPECT_EQ(root.parent_relation(), ChildKind::kUnspecified); } TEST(NavigableProtoAst, DefaultCtorEmpty) { Expr const_node; const_node.set_id(1); const_node.mutable_const_expr()->set_int64_value(42); NavigableProtoAst ast = NavigableProtoAst::Build(const_node); EXPECT_EQ(ast, ast); NavigableProtoAst empty; 
EXPECT_NE(ast, empty); EXPECT_EQ(empty, empty); EXPECT_TRUE(static_cast(ast)); EXPECT_FALSE(static_cast(empty)); NavigableProtoAst moved = std::move(ast); EXPECT_EQ(ast, empty); EXPECT_FALSE(static_cast(ast)); EXPECT_TRUE(static_cast(moved)); } TEST(NavigableProtoAst, FindById) { Expr const_node; const_node.set_id(1); const_node.mutable_const_expr()->set_int64_value(42); NavigableProtoAst ast = NavigableProtoAst::Build(const_node); const NavigableProtoAstNode& root = ast.Root(); EXPECT_EQ(ast.FindId(const_node.id()), &root); EXPECT_EQ(ast.FindId(-1), nullptr); } MATCHER_P(AstNodeWrapping, expr, "") { const NavigableProtoAstNode* ptr = arg; return ptr != nullptr && ptr->expr() == expr; } TEST(NavigableProtoAst, ToleratesNonUnique) { Expr call_node; call_node.set_id(1); call_node.mutable_call_expr()->set_function(cel::builtin::kNot); Expr* const_node = call_node.mutable_call_expr()->add_args(); const_node->mutable_const_expr()->set_bool_value(false); const_node->set_id(1); NavigableProtoAst ast = NavigableProtoAst::Build(call_node); const NavigableProtoAstNode& root = ast.Root(); EXPECT_EQ(ast.FindId(1), &root); EXPECT_EQ(ast.FindExpr(&call_node), &root); EXPECT_FALSE(ast.IdsAreUnique()); EXPECT_THAT(ast.FindExpr(const_node), AstNodeWrapping(const_node)); } TEST(NavigableProtoAst, FindByExprPtr) { Expr const_node; const_node.set_id(1); const_node.mutable_const_expr()->set_int64_value(42); NavigableProtoAst ast = NavigableProtoAst::Build(const_node); const NavigableProtoAstNode& root = ast.Root(); EXPECT_EQ(ast.FindExpr(&const_node), &root); EXPECT_EQ(ast.FindExpr(&Expr::default_instance()), nullptr); } TEST(NavigableProtoAst, Children) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("1 + 2")); NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); const NavigableProtoAstNode& root = ast.Root(); EXPECT_EQ(root.expr(), &parsed_expr.expr()); EXPECT_THAT(root.children(), SizeIs(2)); EXPECT_TRUE(root.parent() == nullptr); EXPECT_EQ(root.child_index(), -1); 
EXPECT_EQ(root.parent_relation(), ChildKind::kUnspecified); EXPECT_EQ(root.node_kind(), NodeKind::kCall); EXPECT_THAT( root.children(), ElementsAre(AstNodeWrapping(&parsed_expr.expr().call_expr().args(0)), AstNodeWrapping(&parsed_expr.expr().call_expr().args(1)))); ASSERT_THAT(root.children(), SizeIs(2)); const auto* child1 = root.children()[0]; EXPECT_EQ(child1->child_index(), 0); EXPECT_EQ(child1->parent(), &root); EXPECT_EQ(child1->parent_relation(), ChildKind::kCallArg); EXPECT_EQ(child1->node_kind(), NodeKind::kConstant); EXPECT_THAT(child1->children(), IsEmpty()); const auto* child2 = root.children()[1]; EXPECT_EQ(child2->child_index(), 1); } TEST(NavigableProtoAst, UnspecifiedExpr) { Expr expr; expr.set_id(1); NavigableProtoAst ast = NavigableProtoAst::Build(expr); const NavigableProtoAstNode& root = ast.Root(); EXPECT_EQ(root.expr(), &expr); EXPECT_THAT(root.children(), SizeIs(0)); EXPECT_TRUE(root.parent() == nullptr); EXPECT_EQ(root.child_index(), -1); EXPECT_EQ(root.node_kind(), NodeKind::kUnspecified); } TEST(NavigableProtoAst, ParentRelationSelect) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("a.b")); NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); const NavigableProtoAstNode& root = ast.Root(); ASSERT_THAT(root.children(), SizeIs(1)); const auto* child = root.children()[0]; EXPECT_EQ(child->parent_relation(), ChildKind::kSelectOperand); EXPECT_EQ(child->node_kind(), NodeKind::kIdent); } TEST(NavigableProtoAst, ParentRelationCallReceiver) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("a.b()")); NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); const NavigableProtoAstNode& root = ast.Root(); ASSERT_THAT(root.children(), SizeIs(1)); const auto* child = root.children()[0]; EXPECT_EQ(child->parent_relation(), ChildKind::kCallReceiver); EXPECT_EQ(child->node_kind(), NodeKind::kIdent); } TEST(NavigableProtoAst, ParentRelationCreateStruct) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("com.example.Type{field: 
'123'}")); NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); const NavigableProtoAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kStruct); ASSERT_THAT(root.children(), SizeIs(1)); const auto* child = root.children()[0]; EXPECT_EQ(child->parent_relation(), ChildKind::kStructValue); EXPECT_EQ(child->node_kind(), NodeKind::kConstant); } TEST(NavigableProtoAst, ParentRelationCreateMap) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("{'a': 123}")); NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); const NavigableProtoAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kMap); ASSERT_THAT(root.children(), SizeIs(2)); const auto* key = root.children()[0]; const auto* value = root.children()[1]; EXPECT_EQ(key->parent_relation(), ChildKind::kMapKey); EXPECT_EQ(key->node_kind(), NodeKind::kConstant); EXPECT_EQ(value->parent_relation(), ChildKind::kMapValue); EXPECT_EQ(value->node_kind(), NodeKind::kConstant); } TEST(NavigableProtoAst, ParentRelationCreateList) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[123]")); NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); const NavigableProtoAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kList); ASSERT_THAT(root.children(), SizeIs(1)); const auto* child = root.children()[0]; EXPECT_EQ(child->parent_relation(), ChildKind::kListElem); EXPECT_EQ(child->node_kind(), NodeKind::kConstant); } TEST(NavigableProtoAst, ParentRelationComprehension) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1].all(x, x < 2)")); NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); const NavigableProtoAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); ASSERT_THAT(root.children(), SizeIs(5)); const auto* range = root.children()[0]; const auto* init = root.children()[1]; const auto* condition = root.children()[2]; const auto* step = root.children()[3]; const auto* finish = root.children()[4]; 
EXPECT_EQ(range->parent_relation(), ChildKind::kComprehensionRange); EXPECT_EQ(init->parent_relation(), ChildKind::kComprehensionInit); EXPECT_EQ(condition->parent_relation(), ChildKind::kComprehensionCondition); EXPECT_EQ(step->parent_relation(), ChildKind::kComprehensionLoopStep); EXPECT_EQ(finish->parent_relation(), ChildKind::kComprensionResult); } TEST(NavigableProtoAst, DescendantsPostorder) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("1 + (x * 3)")); NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); const NavigableProtoAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kCall); std::vector constants; std::vector node_kinds; for (const NavigableProtoAstNode& node : root.DescendantsPostorder()) { if (node.node_kind() == NodeKind::kConstant) { constants.push_back(node.expr()->const_expr().int64_value()); } node_kinds.push_back(node.node_kind()); } EXPECT_THAT(node_kinds, ElementsAre(NodeKind::kConstant, NodeKind::kIdent, NodeKind::kConstant, NodeKind::kCall, NodeKind::kCall)); EXPECT_THAT(constants, ElementsAre(1, 3)); } TEST(NavigableProtoAst, DescendantsPreorder) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("1 + (x * 3)")); NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); const NavigableProtoAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kCall); std::vector constants; std::vector node_kinds; for (const NavigableProtoAstNode& node : root.DescendantsPreorder()) { if (node.node_kind() == NodeKind::kConstant) { constants.push_back(node.expr()->const_expr().int64_value()); } node_kinds.push_back(node.node_kind()); } EXPECT_THAT(node_kinds, ElementsAre(NodeKind::kCall, NodeKind::kConstant, NodeKind::kCall, NodeKind::kIdent, NodeKind::kConstant)); EXPECT_THAT(constants, ElementsAre(1, 3)); } TEST(NavigableProtoAst, DescendantsPreorderComprehension) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1, 2, 3].map(x, x + 1)")); NavigableProtoAst ast = 
NavigableProtoAst::Build(parsed_expr.expr()); const NavigableProtoAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); std::vector> node_kinds; for (const NavigableProtoAstNode& node : root.DescendantsPreorder()) { node_kinds.push_back( std::make_pair(node.node_kind(), node.parent_relation())); } EXPECT_THAT( node_kinds, ElementsAre(Pair(NodeKind::kComprehension, ChildKind::kUnspecified), Pair(NodeKind::kList, ChildKind::kComprehensionRange), Pair(NodeKind::kConstant, ChildKind::kListElem), Pair(NodeKind::kConstant, ChildKind::kListElem), Pair(NodeKind::kConstant, ChildKind::kListElem), Pair(NodeKind::kList, ChildKind::kComprehensionInit), Pair(NodeKind::kConstant, ChildKind::kComprehensionCondition), Pair(NodeKind::kCall, ChildKind::kComprehensionLoopStep), Pair(NodeKind::kIdent, ChildKind::kCallArg), Pair(NodeKind::kList, ChildKind::kCallArg), Pair(NodeKind::kCall, ChildKind::kListElem), Pair(NodeKind::kIdent, ChildKind::kCallArg), Pair(NodeKind::kConstant, ChildKind::kCallArg), Pair(NodeKind::kIdent, ChildKind::kComprensionResult))); } TEST(NavigableProtoAst, TreeSize) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1, 2, 3].map(x, x + 1)")); NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); const NavigableProtoAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); std::vector> node_kinds; EXPECT_EQ(root.tree_size(), 14); auto it = root.DescendantsPostorder().begin(); EXPECT_EQ(it->tree_size(), 1); } TEST(NavigableProtoAst, Height) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1, 2, 3].map(x, x + 1)")); NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); const NavigableProtoAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); std::vector> node_kinds; EXPECT_EQ(root.height(), 5); auto it = root.DescendantsPostorder().begin(); EXPECT_EQ(it->height(), 1); } TEST(NavigableProtoAst, DescendantsPreorderCreateMap) { ASSERT_OK_AND_ASSIGN(auto 
parsed_expr, Parse("{'key1': 1, 'key2': 2}")); NavigableProtoAst ast = NavigableProtoAst::Build(parsed_expr.expr()); const NavigableProtoAstNode& root = ast.Root(); EXPECT_EQ(root.node_kind(), NodeKind::kMap); std::vector> node_kinds; for (const NavigableProtoAstNode& node : root.DescendantsPreorder()) { node_kinds.push_back( std::make_pair(node.node_kind(), node.parent_relation())); } EXPECT_THAT(node_kinds, ElementsAre(Pair(NodeKind::kMap, ChildKind::kUnspecified), Pair(NodeKind::kConstant, ChildKind::kMapKey), Pair(NodeKind::kConstant, ChildKind::kMapValue), Pair(NodeKind::kConstant, ChildKind::kMapKey), Pair(NodeKind::kConstant, ChildKind::kMapValue))); } } // namespace } // namespace cel ================================================ FILE: tools/testdata/BUILD ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
load( "@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_library_public", ) load("@rules_cc//cc:cc_library.bzl", "cc_library") licenses(["notice"]) package(default_visibility = ["//visibility:public"]) flatbuffer_library_public( name = "flatbuffers_test", srcs = ["flatbuffers.fbs"], outs = ["flatbuffers_generated.h"], language_flag = "-c", reflection_name = "flatbuffers_reflection", ) filegroup( name = "coverage_testdata", srcs = [ "coverage_example.textproto", "exists_macro.textproto", ], ) cc_library( name = "flatbuffers_test_cc", srcs = [":flatbuffers_test"], hdrs = [":flatbuffers_test"], features = ["-parse_headers"], linkstatic = True, deps = ["@com_github_google_flatbuffers//:runtime_cc"], ) ================================================ FILE: tools/testdata/checked_expr_and.textproto ================================================ # proto-file: google3/google/api/expr/checked.proto # proto-message: CheckedExpr # x && y reference_map { key: 1 value { name: "x" } } reference_map { key: 2 value { name: "y" } } reference_map { key: 3 value { overload_id: "logical_and" } } type_map { key: 1 value { primitive: BOOL } } type_map { key: 2 value { primitive: BOOL } } type_map { key: 3 value { primitive: BOOL } } expr { id: 3 call_expr { function: "_&&_" args { id: 1 ident_expr { name: "x" } } args { id: 2 ident_expr { name: "y" } } } } source_info { location: "" line_offsets: 7 positions { key: 1 value: 0 } positions { key: 2 value: 5 } positions { key: 3 value: 2 } } ================================================ FILE: tools/testdata/const_str.textproto ================================================ # proto-file: google3/google/api/expr/checked.proto # proto-message: CheckedExpr type_map { key: 1 value { primitive: STRING } } expr { id: 1 const_expr { string_value: "127.0.0.1" } } source_info { location: "" line_offsets: 12 positions { key: 1 value: 0 } } ================================================ FILE: 
tools/testdata/coverage_example.textproto ================================================ # proto-file: google3/google/api/expr/checked.proto # proto-message: CheckedExpr # # int1 < int2 && # (43 > 42) && # !(bool1 || bool2) && # 4 / int_divisor >= 1 && # (ternary_c ? ternary_t : ternary_f) reference_map: { key: 1 value: { name: "int1" } } reference_map: { key: 2 value: { overload_id: "less_int64" } } reference_map: { key: 3 value: { name: "int2" } } reference_map: { key: 5 value: { overload_id: "greater_int64" } } reference_map: { key: 7 value: { overload_id: "logical_and" } } reference_map: { key: 8 value: { overload_id: "logical_not" } } reference_map: { key: 9 value: { name: "bool1" } } reference_map: { key: 10 value: { name: "bool2" } } reference_map: { key: 11 value: { overload_id: "logical_or" } } reference_map: { key: 12 value: { overload_id: "logical_and" } } reference_map: { key: 14 value: { overload_id: "divide_int64" } } reference_map: { key: 15 value: { name: "int_divisor" } } reference_map: { key: 16 value: { overload_id: "greater_equals_int64" } } reference_map: { key: 18 value: { overload_id: "logical_and" } } reference_map: { key: 19 value: { name: "ternary_c" } } reference_map: { key: 20 value: { overload_id: "conditional" } } reference_map: { key: 21 value: { name: "ternary_t" } } reference_map: { key: 22 value: { name: "ternary_f" } } reference_map: { key: 23 value: { overload_id: "logical_and" } } type_map: { key: 1 value: { primitive: INT64 } } type_map: { key: 2 value: { primitive: BOOL } } type_map: { key: 3 value: { primitive: INT64 } } type_map: { key: 4 value: { primitive: INT64 } } type_map: { key: 5 value: { primitive: BOOL } } type_map: { key: 6 value: { primitive: INT64 } } type_map: { key: 7 value: { primitive: BOOL } } type_map: { key: 8 value: { primitive: BOOL } } type_map: { key: 9 value: { primitive: BOOL } } type_map: { key: 10 value: { primitive: BOOL } } type_map: { key: 11 value: { primitive: BOOL } } type_map: { key: 12 
value: { primitive: BOOL } } type_map: { key: 13 value: { primitive: INT64 } } type_map: { key: 14 value: { primitive: INT64 } } type_map: { key: 15 value: { primitive: INT64 } } type_map: { key: 16 value: { primitive: BOOL } } type_map: { key: 17 value: { primitive: INT64 } } type_map: { key: 18 value: { primitive: BOOL } } type_map: { key: 19 value: { primitive: BOOL } } type_map: { key: 20 value: { primitive: BOOL } } type_map: { key: 21 value: { primitive: BOOL } } type_map: { key: 22 value: { primitive: BOOL } } type_map: { key: 23 value: { primitive: BOOL } } source_info: { location: "" line_offsets: 109 positions: { key: 1 value: 0 } positions: { key: 2 value: 5 } positions: { key: 3 value: 7 } positions: { key: 4 value: 16 } positions: { key: 5 value: 19 } positions: { key: 6 value: 21 } positions: { key: 7 value: 12 } positions: { key: 8 value: 28 } positions: { key: 9 value: 30 } positions: { key: 10 value: 39 } positions: { key: 11 value: 36 } positions: { key: 12 value: 25 } positions: { key: 13 value: 49 } positions: { key: 14 value: 51 } positions: { key: 15 value: 53 } positions: { key: 16 value: 65 } positions: { key: 17 value: 68 } positions: { key: 18 value: 46 } positions: { key: 19 value: 74 } positions: { key: 20 value: 84 } positions: { key: 21 value: 86 } positions: { key: 22 value: 98 } positions: { key: 23 value: 70 } } expr: { id: 18 call_expr: { function: "_&&_" args: { id: 12 call_expr: { function: "_&&_" args: { id: 7 call_expr: { function: "_&&_" args: { id: 2 call_expr: { function: "_<_" args: { id: 1 ident_expr: { name: "int1" } } args: { id: 3 ident_expr: { name: "int2" } } } } args: { id: 5 call_expr: { function: "_>_" args: { id: 4 const_expr: { int64_value: 43 } } args: { id: 6 const_expr: { int64_value: 42 } } } } } } args: { id: 8 call_expr: { function: "!_" args: { id: 11 call_expr: { function: "_||_" args: { id: 9 ident_expr: { name: "bool1" } } args: { id: 10 ident_expr: { name: "bool2" } } } } } } } } args: { id: 23 
call_expr: { function: "_&&_" args: { id: 16 call_expr: { function: "_>=_" args: { id: 14 call_expr: { function: "_/_" args: { id: 13 const_expr: { int64_value: 4 } } args: { id: 15 ident_expr: { name: "int_divisor" } } } } args: { id: 17 const_expr: { int64_value: 1 } } } } args: { id: 20 call_expr: { function: "_?_:_" args: { id: 19 ident_expr: { name: "ternary_c" } } args: { id: 21 ident_expr: { name: "ternary_t" } } args: { id: 22 ident_expr: { name: "ternary_f" } } } } } } } } ================================================ FILE: tools/testdata/exists_macro.textproto ================================================ # proto-file: google3/google/api/expr/checked.proto # proto-message: CheckedExpr # [1].exists(x, x == 1) reference_map: { key: 5 value: { name: "x" } } reference_map: { key: 6 value: { overload_id: "equals" } } reference_map: { key: 9 value: { name: "__result__" } } reference_map: { key: 10 value: { overload_id: "logical_not" } } reference_map: { key: 11 value: { overload_id: "not_strictly_false" } } reference_map: { key: 12 value: { name: "__result__" } } reference_map: { key: 13 value: { overload_id: "logical_or" } } reference_map: { key: 14 value: { name: "__result__" } } type_map: { key: 1 value: { list_type: { elem_type: { primitive: INT64 } } } } type_map: { key: 2 value: { primitive: INT64 } } type_map: { key: 5 value: { primitive: INT64 } } type_map: { key: 6 value: { primitive: BOOL } } type_map: { key: 7 value: { primitive: INT64 } } type_map: { key: 8 value: { primitive: BOOL } } type_map: { key: 9 value: { primitive: BOOL } } type_map: { key: 10 value: { primitive: BOOL } } type_map: { key: 11 value: { primitive: BOOL } } type_map: { key: 12 value: { primitive: BOOL } } type_map: { key: 13 value: { primitive: BOOL } } type_map: { key: 14 value: { primitive: BOOL } } type_map: { key: 15 value: { primitive: BOOL } } source_info: { location: "" line_offsets: 22 positions: { key: 1 value: 0 } positions: { key: 2 value: 1 } positions: { key: 
3 value: 10 } positions: { key: 4 value: 11 } positions: { key: 5 value: 14 } positions: { key: 6 value: 16 } positions: { key: 7 value: 19 } positions: { key: 8 value: 10 } positions: { key: 9 value: 10 } positions: { key: 10 value: 10 } positions: { key: 11 value: 10 } positions: { key: 12 value: 10 } positions: { key: 13 value: 10 } positions: { key: 14 value: 10 } positions: { key: 15 value: 10 } macro_calls: { key: 15 value: { call_expr: { target: { id: 1 list_expr: { elements: { id: 2 const_expr: { int64_value: 1 } } } } function: "exists" args: { id: 4 ident_expr: { name: "x" } } args: { id: 6 call_expr: { function: "_==_" args: { id: 5 ident_expr: { name: "x" } } args: { id: 7 const_expr: { int64_value: 1 } } } } } } } } expr: { id: 15 comprehension_expr: { iter_var: "x" iter_range: { id: 1 list_expr: { elements: { id: 2 const_expr: { int64_value: 1 } } } } accu_var: "__result__" accu_init: { id: 8 const_expr: { bool_value: false } } loop_condition: { id: 11 call_expr: { function: "@not_strictly_false" args: { id: 10 call_expr: { function: "!_" args: { id: 9 ident_expr: { name: "__result__" } } } } } } loop_step: { id: 13 call_expr: { function: "_||_" args: { id: 12 ident_expr: { name: "__result__" } } args: { id: 6 call_expr: { function: "_==_" args: { id: 5 ident_expr: { name: "x" } } args: { id: 7 const_expr: { int64_value: 1 } } } } } } result: { id: 14 ident_expr: { name: "__result__" } } } } ================================================ FILE: tools/testdata/flatbuffers.fbs ================================================ namespace google.api.expr; table Entry { f_string:string; f_int:int; } table IndexedEntry { f_string:string (key); f_int:int; } table TestBuffer { f_byte:byte; f_ubyte:ubyte; f_short:short = 150; f_ushort:ushort; f_int:int; f_uint:uint; f_long:long; f_ulong:ulong; f_float:float; f_double:double; f_bool:bool = true; f_string:string; f_obj:Entry; r_byte:[byte]; r_ubyte:[ubyte]; r_short:[short]; r_ushort:[ushort]; r_int:[int]; 
r_uint:[uint]; r_long:[long]; r_ulong:[ulong]; r_float:[float]; r_double:[double]; r_bool:[bool]; r_string:[string]; r_obj:[Entry]; r_indexed:[IndexedEntry]; } root_type TestBuffer; ================================================ FILE: tools/testdata/macro_multiple_references.textproto ================================================ # proto-file: google3/google/api/expr/checked.proto # proto-message: CheckedExpr # has(msg.old_field) || has(msg.old_field) || # math.least(msg.old_field, msg.old_field) < 0 reference_map: { key: 2 value: { name: "msg" } } reference_map: { key: 6 value: { name: "msg" } } reference_map: { key: 9 value: { overload_id: "logical_or" } } reference_map: { key: 12 value: { name: "msg" } } reference_map: { key: 14 value: { name: "msg" } } reference_map: { key: 16 value: { overload_id: "math_@min_int_int" } } reference_map: { key: 17 value: { overload_id: "less_int64" } } reference_map: { key: 19 value: { overload_id: "logical_or" } } type_map: { key: 2 value: { map_type: { key_type: { primitive: STRING } value_type: { primitive: INT64 } } } } type_map: { key: 4 value: { primitive: BOOL } } type_map: { key: 6 value: { map_type: { key_type: { primitive: STRING } value_type: { primitive: INT64 } } } } type_map: { key: 8 value: { primitive: BOOL } } type_map: { key: 9 value: { primitive: BOOL } } type_map: { key: 12 value: { map_type: { key_type: { primitive: STRING } value_type: { primitive: INT64 } } } } type_map: { key: 13 value: { primitive: INT64 } } type_map: { key: 14 value: { map_type: { key_type: { primitive: STRING } value_type: { primitive: INT64 } } } } type_map: { key: 15 value: { primitive: INT64 } } type_map: { key: 16 value: { primitive: INT64 } } type_map: { key: 17 value: { primitive: BOOL } } type_map: { key: 18 value: { primitive: INT64 } } type_map: { key: 19 value: { primitive: BOOL } } source_info: { location: "" line_offsets: 89 positions: { key: 1 value: 3 } positions: { key: 2 value: 4 } positions: { key: 3 value: 7 } 
positions: { key: 4 value: 3 } positions: { key: 5 value: 25 } positions: { key: 6 value: 26 } positions: { key: 7 value: 29 } positions: { key: 8 value: 25 } positions: { key: 9 value: 19 } positions: { key: 10 value: 44 } positions: { key: 11 value: 54 } positions: { key: 12 value: 55 } positions: { key: 13 value: 58 } positions: { key: 14 value: 70 } positions: { key: 15 value: 73 } positions: { key: 16 value: 54 } positions: { key: 17 value: 85 } positions: { key: 18 value: 87 } positions: { key: 19 value: 41 } macro_calls: { key: 4 value: { call_expr: { function: "has" args: { id: 3 select_expr: { operand: { id: 2 ident_expr: { name: "msg" } } field: "old_field" } } } } } macro_calls: { key: 8 value: { call_expr: { function: "has" args: { id: 7 select_expr: { operand: { id: 6 ident_expr: { name: "msg" } } field: "old_field" } } } } } macro_calls: { key: 16 value: { call_expr: { target: { id: 10 ident_expr: { name: "math" } } function: "least" args: { id: 13 select_expr: { operand: { id: 12 ident_expr: { name: "msg" } } field: "old_field" } } args: { id: 15 select_expr: { operand: { id: 14 ident_expr: { name: "msg" } } field: "old_field" } } } } } } expr: { id: 19 call_expr: { function: "_||_" args: { id: 9 call_expr: { function: "_||_" args: { id: 4 select_expr: { operand: { id: 2 ident_expr: { name: "msg" } } field: "old_field" test_only: true } } args: { id: 8 select_expr: { operand: { id: 6 ident_expr: { name: "msg" } } field: "old_field" test_only: true } } } } args: { id: 17 call_expr: { function: "_<_" args: { id: 16 call_expr: { function: "math.@min" args: { id: 13 select_expr: { operand: { id: 12 ident_expr: { name: "msg" } } field: "old_field" } } args: { id: 15 select_expr: { operand: { id: 14 ident_expr: { name: "msg" } } field: "old_field" } } } } args: { id: 18 const_expr: { int64_value: 0 } } } } } } ================================================ FILE: tools/testdata/macro_nested_macro_call.textproto 
================================================ # proto-file: google3/google/api/expr/checked.proto # proto-message: CheckedExpr # math.least(has(msg.old_field) ? msg.old_field : 0, 1) reference_map: { key: 4 value: { name: "msg" } } reference_map: { key: 7 value: { overload_id: "conditional" } } reference_map: { key: 8 value: { name: "msg" } } reference_map: { key: 12 value: { overload_id: "math_@min_int_int" } } type_map: { key: 4 value: { map_type: { key_type: { primitive: STRING } value_type: { primitive: INT64 } } } } type_map: { key: 6 value: { primitive: BOOL } } type_map: { key: 7 value: { primitive: INT64 } } type_map: { key: 8 value: { map_type: { key_type: { primitive: STRING } value_type: { primitive: INT64 } } } } type_map: { key: 9 value: { primitive: INT64 } } type_map: { key: 10 value: { primitive: INT64 } } type_map: { key: 11 value: { primitive: INT64 } } type_map: { key: 12 value: { primitive: INT64 } } source_info: { location: "" line_offsets: 54 positions: { key: 1 value: 0 } positions: { key: 2 value: 10 } positions: { key: 3 value: 14 } positions: { key: 4 value: 15 } positions: { key: 5 value: 18 } positions: { key: 6 value: 14 } positions: { key: 7 value: 30 } positions: { key: 8 value: 32 } positions: { key: 9 value: 35 } positions: { key: 10 value: 48 } positions: { key: 11 value: 51 } positions: { key: 12 value: 10 } macro_calls: { key: 6 value: { call_expr: { function: "has" args: { id: 5 select_expr: { operand: { id: 4 ident_expr: { name: "msg" } } field: "old_field" } } } } } macro_calls: { key: 12 value: { call_expr: { target: { id: 1 ident_expr: { name: "math" } } function: "least" args: { id: 7 call_expr: { function: "_?_:_" args: { id: 6 } args: { id: 9 select_expr: { operand: { id: 8 ident_expr: { name: "msg" } } field: "old_field" } } args: { id: 10 const_expr: { int64_value: 0 } } } } args: { id: 11 const_expr: { int64_value: 1 } } } } } } expr: { id: 12 call_expr: { function: "math.@min" args: { id: 7 call_expr: { function: 
"_?_:_" args: { id: 6 select_expr: { operand: { id: 4 ident_expr: { name: "msg" } } field: "old_field" test_only: true } } args: { id: 9 select_expr: { operand: { id: 8 ident_expr: { name: "msg" } } field: "old_field" } } args: { id: 10 const_expr: { int64_value: 0 } } } } args: { id: 11 const_expr: { int64_value: 1 } } } } ================================================ FILE: tools/testdata/macro_single_reference.textproto ================================================ # proto-file: google3/google/api/expr/checked.proto # proto-message: CheckedExpr # has(msg.old_field) reference_map: { key: 2 value: { name: "msg" } } type_map: { key: 2 value: { map_type: { key_type: { primitive: STRING } value_type: { primitive: STRING } } } } type_map: { key: 4 value: { primitive: BOOL } } source_info: { location: "" line_offsets: 15 positions: { key: 1 value: 3 } positions: { key: 2 value: 4 } positions: { key: 3 value: 7 } positions: { key: 4 value: 3 } macro_calls: { key: 4 value: { call_expr: { function: "has" args: { id: 3 select_expr: { operand: { id: 2 ident_expr: { name: "msg" } } field: "old_field" } } } } } } expr: { id: 4 select_expr: { operand: { id: 2 ident_expr: { name: "msg" } } field: "old_field" test_only: true } } ================================================ FILE: tools/testdata/msg_new_field.textproto ================================================ # proto-file: google3/google/api/expr/checked.proto # proto-message: CheckedExpr # msg.new_field reference_map: { key: 1 value: { name: "msg" } } type_map: { key: 1 value: { map_type: { key_type: { primitive: STRING } value_type: { primitive: STRING } } } } type_map: { key: 2 value: { primitive: STRING } } source_info: { location: "" line_offsets: 10 positions: { key: 1 value: 0 } positions: { key: 2 value: 3 } } expr: { id: 2 select_expr: { operand: { id: 1 ident_expr: { name: "msg" } } field: "new_field" } } ================================================ FILE: tools/testdata/msg_new_field_int.textproto 
================================================ # proto-file: google3/google/api/expr/checked.proto # proto-message: CheckedExpr # msg.new_field reference_map: { key: 1 value: { name: "msg" } } type_map: { key: 1 value: { map_type: { key_type: { primitive: STRING } value_type: { primitive: INT64 } } } } type_map: { key: 2 value: { primitive: INT64 } } source_info: { location: "" line_offsets: 14 positions: { key: 1 value: 0 } positions: { key: 2 value: 3 } } expr: { id: 2 select_expr: { operand: { id: 1 ident_expr: { name: "msg" } } field: "new_field" } } ================================================ FILE: validator/BUILD ================================================ # Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") package(default_visibility = ["//visibility:public"]) cc_library( name = "validator", srcs = ["validator.cc"], hdrs = ["validator.h"], deps = [ "//checker:type_check_issue", "//checker:validation_result", "//common:ast", "//common:navigable_ast", "//common:source", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", ], ) cc_test( name = "validator_test", srcs = ["validator_test.cc"], deps = [ ":validator", "//checker:type_check_issue", "//common:ast", "//common:expr", "//common:source", "//internal:testing", "@com_google_absl//absl/strings:string_view", ], ) cc_test( name = "timestamp_literal_validator_test", srcs = ["timestamp_literal_validator_test.cc"], deps = [ ":timestamp_literal_validator", ":validator", "//checker:validation_result", "//compiler", "//compiler:compiler_factory", "//compiler:standard_library", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], ) cc_library( name = "timestamp_literal_validator", srcs = ["timestamp_literal_validator.cc"], hdrs = ["timestamp_literal_validator.h"], deps = [ ":validator", "//common:constant", "//common:navigable_ast", "//common:standard_definitions", "//internal:time", "//tools:navigable_ast", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/time", ], ) cc_library( name = "ast_depth_validator", srcs = ["ast_depth_validator.cc"], hdrs = ["ast_depth_validator.h"], deps = [ ":validator", "@com_google_absl//absl/strings", ], ) cc_library( name = "homogeneous_literal_validator", srcs = 
["homogeneous_literal_validator.cc"], hdrs = ["homogeneous_literal_validator.h"], deps = [ ":validator", "//common:ast", "//common:expr", "//common:navigable_ast", "@com_google_absl//absl/strings", ], ) cc_library( name = "regex_validator", srcs = ["regex_validator.cc"], hdrs = ["regex_validator.h"], deps = [ ":validator", "//common:ast", "//common:constant", "//common:expr", "//common:navigable_ast", "//common:standard_definitions", "//internal:re2_options", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_googlesource_code_re2//:re2", ], ) cc_test( name = "homogeneous_literal_validator_test", srcs = ["homogeneous_literal_validator_test.cc"], deps = [ ":homogeneous_literal_validator", ":validator", "//checker:validation_result", "//common:decl", "//common:type", "//compiler", "//compiler:compiler_factory", "//compiler:optional", "//compiler:standard_library", "//extensions:strings", "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], ) cc_test( name = "ast_depth_validator_test", srcs = ["ast_depth_validator_test.cc"], deps = [ ":ast_depth_validator", ":validator", "//checker:type_check_issue", "//compiler", "//compiler:compiler_factory", "//compiler:standard_library", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/log:absl_check", ], ) cc_test( name = "regex_validator_test", srcs = ["regex_validator_test.cc"], deps = [ ":regex_validator", ":validator", "//common:decl", "//common:type", "//compiler", "//compiler:compiler_factory", "//compiler:standard_library", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/status:statusor", ], ) cc_library( name = "comprehension_nesting_validator", srcs = ["comprehension_nesting_validator.cc"], hdrs = ["comprehension_nesting_validator.h"], deps = [ ":validator", "//common:expr", 
"//common:navigable_ast", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], ) cc_test( name = "comprehension_nesting_validator_test", srcs = ["comprehension_nesting_validator_test.cc"], deps = [ ":comprehension_nesting_validator", ":validator", "//compiler", "//compiler:compiler_factory", "//compiler:standard_library", "//extensions:bindings_ext", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/status:statusor", ], ) licenses(["notice"]) ================================================ FILE: validator/ast_depth_validator.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "validator/ast_depth_validator.h" #include "absl/strings/str_cat.h" #include "validator/validator.h" namespace cel { Validation AstDepthValidator(int max_depth) { return Validation([max_depth](ValidationContext& context) { int height = context.navigable_ast().Root().height(); if (height > max_depth) { context.ReportError(absl::StrCat("AST depth ", height, " exceeds maximum of ", max_depth)); return false; } return true; }); } } // namespace cel ================================================ FILE: validator/ast_depth_validator.h ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_AST_DEPTH_VALIDATOR_H_ #define THIRD_PARTY_CEL_CPP_VALIDATOR_AST_DEPTH_VALIDATOR_H_ #include "validator/validator.h" namespace cel { // Returns a `Validation` that checks the AST depth is less than or equal to // max_depth. Validation AstDepthValidator(int max_depth); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_VALIDATOR_AST_DEPTH_VALIDATOR_H_ ================================================ FILE: validator/ast_depth_validator_test.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "validator/ast_depth_validator.h"

#include <memory>
#include <utility>

#include "absl/log/absl_check.h"
#include "checker/type_check_issue.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "validator/validator.h"

namespace cel {
namespace {

using ::testing::Contains;
using ::testing::Eq;
using ::testing::Property;

// Returns a compiler with the standard library installed. Setup failures abort
// the binary because every test in this file depends on a working compiler.
std::unique_ptr<Compiler> CreateCompiler() {
  auto builder_or =
      NewCompilerBuilder(internal::GetSharedTestingDescriptorPool());
  ABSL_CHECK_OK(builder_or);
  auto& builder = *builder_or;
  ABSL_CHECK_OK(builder->AddLibrary(StandardCompilerLibrary()));
  auto compiler_or = builder->Build();
  ABSL_CHECK_OK(compiler_or);
  return std::move(compiler_or).value();
}

TEST(AstDepthValidatorTest, Basic) {
  auto compiler = CreateCompiler();
  ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("1 + 2 + 3"));

  // A generous limit accepts the expression.
  Validator permissive;
  permissive.AddValidation(AstDepthValidator(10));
  auto output = permissive.Validate(*result.GetAst());
  EXPECT_TRUE(output.valid);

  // A limit below the actual depth rejects it with a descriptive message.
  Validator strict;
  strict.AddValidation(AstDepthValidator(2));
  output = strict.Validate(*result.GetAst());
  EXPECT_FALSE(output.valid);
  EXPECT_THAT(output.issues,
              Contains(Property(&TypeCheckIssue::message,
                                Eq("AST depth 3 exceeds maximum of 2"))));
}

TEST(AstDepthValidatorTest, Nested) {
  auto compiler = CreateCompiler();
  ASSERT_OK_AND_ASSIGN(auto result,
                       compiler->Compile("1 + (2 + (3 + (4 + 5)))"));

  Validator permissive;
  permissive.AddValidation(AstDepthValidator(10));
  auto output = permissive.Validate(*result.GetAst());
  EXPECT_TRUE(output.valid);

  Validator strict;
  strict.AddValidation(AstDepthValidator(4));
  output = strict.Validate(*result.GetAst());
  EXPECT_FALSE(output.valid);
  EXPECT_THAT(output.issues,
              Contains(Property(&TypeCheckIssue::message,
                                Eq("AST depth 5 exceeds maximum of 4"))));
}

}  // namespace
}  // namespace cel

================================================
FILE: validator/comprehension_nesting_validator.cc
================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "validator/comprehension_nesting_validator.h" #include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "common/expr.h" #include "common/navigable_ast.h" #include "validator/validator.h" namespace cel { namespace { bool IsEmptyRangeComprehension(const NavigableAstNode& node) { ABSL_DCHECK(node.expr()->has_comprehension_expr()); const auto& comp = node.expr()->comprehension_expr(); return comp.has_iter_range() && comp.iter_range().has_list_expr() && comp.iter_range().list_expr().elements().empty(); } } // namespace Validation ComprehensionNestingLimitValidator(int limit) { return Validation( [limit](ValidationContext& context) -> bool { bool is_valid = true; for (const auto& node : context.navigable_ast().Root().DescendantsPostorder()) { if (node.node_kind() != NodeKind::kComprehension) { continue; } if (IsEmptyRangeComprehension(node)) { continue; } int count = 0; const NavigableAstNode* current = &node; while (current != nullptr) { if (current->node_kind() == NodeKind::kComprehension && !IsEmptyRangeComprehension(*current)) { count++; } current = current->parent(); } if (count > limit) { context.ReportErrorAt( node.expr()->id(), absl::StrCat("comprehension nesting level of ", count, " exceeds limit of ", limit)); is_valid = false; break; } } return is_valid; }, "cel.validator.comprehension_nesting_limit"); } } // 
namespace cel ================================================ FILE: validator/comprehension_nesting_validator.h ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_COMPREHENSION_NESTING_VALIDATOR_H_ #define THIRD_PARTY_CEL_CPP_VALIDATOR_COMPREHENSION_NESTING_VALIDATOR_H_ #include "validator/validator.h" namespace cel { // Returns a `Validation` that checks that comprehensions are not nested beyond // the specified limit. // // Comprehensions with an empty iteration range (e.g. `cel.bind`) do not count // towards the nesting limit. Validation ComprehensionNestingLimitValidator(int limit); } // namespace cel #endif // THIRD_PARTY_CEL_CPP_VALIDATOR_COMPREHENSION_NESTING_VALIDATOR_H_ ================================================ FILE: validator/comprehension_nesting_validator_test.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "validator/comprehension_nesting_validator.h" #include #include #include #include "absl/status/statusor.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" #include "compiler/standard_library.h" #include "extensions/bindings_ext.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "validator/validator.h" namespace cel { namespace { using ::testing::HasSubstr; absl::StatusOr> StdLibCompiler() { CEL_ASSIGN_OR_RETURN( auto builder, NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); CEL_RETURN_IF_ERROR(builder->AddLibrary(StandardCompilerLibrary())); CEL_RETURN_IF_ERROR( builder->AddLibrary(cel::extensions::BindingsCompilerLibrary())); return builder->Build(); } struct TestCase { std::string expression; int limit; bool valid; std::string error_substr = ""; }; using ComprehensionNestingValidatorTest = testing::TestWithParam; TEST_P(ComprehensionNestingValidatorTest, Validate) { const auto& test_case = GetParam(); Validator validator; validator.AddValidation(ComprehensionNestingLimitValidator(test_case.limit)); ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); auto result_or = compiler->Compile(test_case.expression); if (!result_or.ok()) { GTEST_SKIP() << "Expression failed to compile: " << test_case.expression << " " << result_or.status().message(); } auto result = std::move(result_or).value(); validator.UpdateValidationResult(result); EXPECT_EQ(result.IsValid(), test_case.valid) << "Expression: " << test_case.expression << " Limit: " << test_case.limit; if (!test_case.valid) { EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); } } INSTANTIATE_TEST_SUITE_P( ComprehensionNestingValidatorTest, ComprehensionNestingValidatorTest, testing::Values( TestCase{"[1, 2].all(x, x > 0)", 1, true}, TestCase{"[1, 2].all(x, [1, 2].all(y, x > y))", 1, false, "comprehension nesting 
level of 2 exceeds limit of 1"}, TestCase{"[1, 2].all(x, [1, 2].all(y, x > y))", 2, true}, // Empty range comprehension (does not count) TestCase{"[].all(x, [1, 2].all(y, y > 0))", 1, true}, TestCase{"cel.bind(x, [1, 2].all(y, y > 0), [1, 2].all(z, z > 0))", 1, true}, // Nested empty range comprehensions TestCase{"[].all(x, [].all(y, true))", 0, true}, // Deeply nested mixed TestCase{"[1].all(x, [].all(y, [2].all(z, true)))", 1, false, "comprehension nesting level of 2 exceeds limit of 1"}, TestCase{"[1].all(x, [].all(y, [2].all(z, true)))", 2, true})); } // namespace } // namespace cel ================================================ FILE: validator/homogeneous_literal_validator.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "validator/homogeneous_literal_validator.h"

#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "common/ast.h"
#include "common/expr.h"
#include "common/navigable_ast.h"
#include "validator/validator.h"

namespace cel {
namespace {

// Returns true if any call expression enclosing `node` (at any depth) invokes
// one of `exempt_functions`.
bool InExemptFunction(const NavigableAstNode& node,
                      const std::vector<std::string>& exempt_functions) {
  for (const NavigableAstNode* parent = node.parent(); parent != nullptr;
       parent = parent->parent()) {
    if (parent->node_kind() != NodeKind::kCall) {
      continue;
    }
    absl::string_view fn_name = parent->expr()->call_expr().function();
    for (const auto& exempt : exempt_functions) {
      if (exempt == fn_name) {
        return true;
      }
    }
  }
  return false;
}

// Returns true if `t` is `optional_type(T)`. Requiring exactly one type
// parameter keeps `GetOptionalParameter` from indexing an empty list.
bool IsOptional(const TypeSpec& t) {
  return t.has_abstract_type() && t.abstract_type().name() == "optional_type" &&
         t.abstract_type().parameter_types().size() == 1;
}

// Returns `T` for `optional_type(T)`. Requires `IsOptional(t)`.
const TypeSpec& GetOptionalParameter(const TypeSpec& t) {
  return t.abstract_type().parameter_types()[0];
}

// Reports a type mismatch error at expression `id`.
void TypeMismatch(ValidationContext& context, int64_t id,
                  const TypeSpec& expected, const TypeSpec& actual) {
  context.ReportErrorAt(
      id, absl::StrCat("expected type '", FormatTypeSpec(expected),
                       "' but found '", FormatTypeSpec(actual), "'"));
}

// Structural type equivalence used to decide whether literal elements agree.
// A wrapper type is equivalent to its primitive. Lists, maps and abstract
// types are compared element-wise. Error types match anything, so that type
// checking failures are not reported a second time.
bool TypeEquiv(const TypeSpec& a, const TypeSpec& b) {
  if (a == b) {
    return true;
  }
  if (a.has_error() || b.has_error()) {
    // Don't report mismatch if there's an error (type checking failed for the
    // expression).
    return true;
  }
  if (a.has_wrapper() && b.has_primitive()) {
    return a.wrapper() == b.primitive();
  } else if (a.has_primitive() && b.has_wrapper()) {
    return a.primitive() == b.wrapper();
  }
  if (a.has_list_type() && b.has_list_type()) {
    return TypeEquiv(a.list_type().elem_type(), b.list_type().elem_type());
  }
  if (a.has_map_type() && b.has_map_type()) {
    return TypeEquiv(a.map_type().key_type(), b.map_type().key_type()) &&
           TypeEquiv(a.map_type().value_type(), b.map_type().value_type());
  }
  if (a.has_abstract_type() && b.has_abstract_type()) {
    const auto& a_params = a.abstract_type().parameter_types();
    const auto& b_params = b.abstract_type().parameter_types();
    if (a.abstract_type().name() != b.abstract_type().name() ||
        a_params.size() != b_params.size()) {
      return false;
    }
    for (std::size_t i = 0; i < a_params.size(); ++i) {
      if (!TypeEquiv(a_params[i], b_params[i])) {
        return false;
      }
    }
    return true;
  }
  return false;
}

// Returns the type to compare for a literal element with expression `id`:
// its checked type, or `T` when the element is an optional entry (`?x`)
// whose type is `optional_type(T)`.
const TypeSpec& ComparableType(const ValidationContext& context, int64_t id,
                               bool optional) {
  const TypeSpec& actual = context.ast().GetTypeOrDyn(id);
  if (optional && IsOptional(actual)) {
    return GetOptionalParameter(actual);
  }
  return actual;
}

// Checks that every element of the list literal at `node` has a type
// equivalent to the first element's. Reports the first mismatch and returns
// false; returns true otherwise.
bool CheckListLiteral(ValidationContext& context,
                      const NavigableAstNode& node) {
  const TypeSpec* expected = nullptr;
  for (const auto& element : node.expr()->list_expr().elements()) {
    const int64_t id = element.expr().id();
    const TypeSpec& actual = ComparableType(context, id, element.optional());
    if (expected == nullptr) {
      expected = &actual;
      continue;
    }
    if (!TypeEquiv(*expected, actual)) {
      TypeMismatch(context, id, *expected, actual);
      return false;
    }
  }
  return true;
}

// Checks that every key and every value of the map literal at `node` agree
// with the first entry's key and value types. Only the value of an optional
// entry (`?k: v`) is unwrapped. Reports the first mismatch and returns false;
// returns true otherwise.
bool CheckMapLiteral(ValidationContext& context,
                     const NavigableAstNode& node) {
  const TypeSpec* expected_key = nullptr;
  const TypeSpec* expected_value = nullptr;
  for (const auto& entry : node.expr()->map_expr().entries()) {
    const int64_t key_id = entry.key().id();
    const int64_t value_id = entry.value().id();
    const TypeSpec& key_type = context.ast().GetTypeOrDyn(key_id);
    const TypeSpec& value_type =
        ComparableType(context, value_id, entry.optional());
    if (expected_key == nullptr) {
      expected_key = &key_type;
      expected_value = &value_type;
      continue;
    }
    if (!TypeEquiv(*expected_key, key_type)) {
      TypeMismatch(context, key_id, *expected_key, key_type);
      return false;
    }
    if (!TypeEquiv(*expected_value, value_type)) {
      TypeMismatch(context, value_id, *expected_value, value_type);
      return false;
    }
  }
  return true;
}

}  // namespace

Validation HomogeneousLiteralValidator(
    std::vector<std::string> exempt_functions) {
  return Validation([exempt_functions = std::move(exempt_functions)](
                        ValidationContext& context) -> bool {
    bool valid = true;
    for (const auto& node :
         context.navigable_ast().Root().DescendantsPostorder()) {
      const NodeKind kind = node.node_kind();
      if (kind != NodeKind::kList && kind != NodeKind::kMap) {
        continue;
      }
      if (InExemptFunction(node, exempt_functions)) {
        continue;
      }
      // Every literal is checked, so each one reports its own first mismatch.
      const bool homogeneous = kind == NodeKind::kList
                                   ? CheckListLiteral(context, node)
                                   : CheckMapLiteral(context, node);
      if (!homogeneous) {
        valid = false;
      }
    }
    return valid;
  });
}

}  // namespace cel
================================================
FILE: validator/homogeneous_literal_validator.h
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_HOMOGENEOUS_LITERAL_VALIDATOR_H_ #define THIRD_PARTY_CEL_CPP_VALIDATOR_HOMOGENEOUS_LITERAL_VALIDATOR_H_ #include #include #include "validator/validator.h" namespace cel { // Returns a `Validation` that checks that all literals in map or list literals // are the same type. If the list or map is part of an argument to an exempted // function, it is not checked. Validation HomogeneousLiteralValidator( std::vector exempt_functions); inline Validation HomogeneousLiteralValidator() { // Default to exempting the strings extension "format" function. return HomogeneousLiteralValidator({"format"}); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_VALIDATOR_HOMOGENEOUS_LITERAL_VALIDATOR_H_ ================================================ FILE: validator/homogeneous_literal_validator_test.cc ================================================ // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "validator/homogeneous_literal_validator.h"

#include <memory>
#include <string>

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "checker/validation_result.h"
#include "common/decl.h"
#include "common/type.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/optional.h"
#include "compiler/standard_library.h"
#include "extensions/strings.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "validator/validator.h"

namespace cel {
namespace {

using ::testing::HasSubstr;

// Builds a compiler with the standard, optional and strings libraries and a
// `msg` variable of type `cel.expr.conformance.proto3.TestAllTypes`.
absl::StatusOr<std::unique_ptr<Compiler>> StdLibCompiler() {
  CEL_ASSIGN_OR_RETURN(
      auto builder,
      NewCompilerBuilder(internal::GetSharedTestingDescriptorPool()));
  CEL_RETURN_IF_ERROR(builder->AddLibrary(StandardCompilerLibrary()));
  CEL_RETURN_IF_ERROR(builder->AddLibrary(OptionalCompilerLibrary()));
  CEL_RETURN_IF_ERROR(
      builder->AddLibrary(extensions::StringsCompilerLibrary()));
  cel::Type message_type = cel::Type::Message(
      builder->GetCheckerBuilder().descriptor_pool()->FindMessageTypeByName(
          "cel.expr.conformance.proto3.TestAllTypes"));
  CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable(
      MakeVariableDecl("msg", message_type)));
  return builder->Build();
}

struct TestCase {
  std::string expression;
  bool valid;
  // Substring expected in the formatted error when `valid` is false.
  std::string error_substr = "";
};

using HomogeneousLiteralValidatorTest = testing::TestWithParam<TestCase>;

TEST_P(HomogeneousLiteralValidatorTest, Validate) {
  const TestCase& test_case = GetParam();
  Validator validator;
  validator.AddValidation(HomogeneousLiteralValidator());

  ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler());
  ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expression));
  validator.UpdateValidationResult(result);

  EXPECT_EQ(result.IsValid(), test_case.valid);
  if (!test_case.valid) {
    EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr));
  }
}

INSTANTIATE_TEST_SUITE_P(
    HomogeneousLiteralValidatorTest, HomogeneousLiteralValidatorTest,
    testing::Values(
        // Lists
        TestCase{"[1, 2, 3]", true}, TestCase{"['a', 'b', 'c']", true},
        TestCase{"[1, 'a']", false, "expected type 'int' but found 'string'"},
        TestCase{"[1, 2, 'a']", false,
                 "expected type 'int' but found 'string'"},
        TestCase{"[[1], [2]]", true},
        TestCase{"[[1], ['a']]", false,
                 "expected type 'list(int)' but found 'list(string)'"},
        // Dyn casts
        TestCase{"[dyn(1), dyn('a')]", true, ""},
        TestCase{"[dyn(1), 2]", false, "expected type 'dyn' but found 'int'"},
        // Maps
        TestCase{"{1: 'a', 2: 'b'}", true},
        TestCase{"{'a': 1, 'b': 2}", true},
        TestCase{"{1: 'a', 'b': 2}", false,
                 "expected type 'int' but found 'string'"},
        TestCase{"{1: 'a', 2: 3}", false,
                 "expected type 'string' but found 'int'"},
        // Optionals
        TestCase{"[optional.of(1), optional.of(2)]", true},
        TestCase{"[optional.of(1), optional.of('b')]", false,
                 "expected type 'optional_type(int)' but found "
                 "'optional_type(string)'"},
        TestCase{"[?optional.of(1), ?optional.of(2)]", true},
        TestCase{"[?optional.of(1), ?optional.of('a')]", false,
                 "expected type 'int' but found 'string'"},
        TestCase{"{?1: optional.of('a'), ?2: optional.none()}", true},
        TestCase{"{?1: optional.of('a'), ?2: optional.of(1)}", false,
                 "expected type 'string' but found 'int'"},
        // Exempted Functions
        TestCase{"'%v %v'.format([1, 'a'])", true},
        // Mixed Primitives and Wrappers
        TestCase{"[1, msg.single_int64_wrapper]", true},
        TestCase{"[msg.single_int64_wrapper, 1]", true},
        TestCase{"['foo', msg.single_string_wrapper]", true},
        TestCase{"[msg.single_string_wrapper, 'foo']", true},
        TestCase{"{1: msg.single_int64_wrapper, 2: 3}", true},
        TestCase{"{1: 2, 2: msg.single_int64_wrapper}", true},
        TestCase{"[[1], [msg.single_int64_wrapper]]", true},
        TestCase{"[optional.of(1), optional.of(msg.single_int64_wrapper)]",
                 true},
        TestCase{"[1, msg.single_string_wrapper]", false,
                 "expected type 'int' but found 'wrapper(string)'"},
        TestCase{"[msg.single_int64_wrapper, 'foo']", false,
                 "expected type 'wrapper(int)' but found 'string'"},
        TestCase{"[msg.single_int64_wrapper, msg.single_string_wrapper]",
                 false,
                 "expected type 'wrapper(int)' but found 'wrapper(string)'"},
        // Nested
        TestCase{"[1, [2, 'a']]", false,
                 "expected type 'int' but found 'string'"},
        TestCase{"[[1, 2], [3, 4]]", true, ""},
        TestCase{"[{1: 2}, {'foo': 3}]", false,
                 "expected type 'map(int, int)' but found 'map(string, int)'"},
        TestCase{"[{1: 2}, {3: 'foo'}]", false,
                 "expected type 'map(int, int)' but found 'map(int, string)'"},
        TestCase{"[{1: 2}, {3: 4}]", true, ""}));

}  // namespace
}  // namespace cel
================================================
FILE: validator/regex_validator.cc
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "validator/regex_validator.h"

#include <cstddef>
#include <utility>
#include <vector>

#include "absl/log/absl_check.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "common/constant.h"
#include "common/expr.h"
#include "common/navigable_ast.h"
#include "internal/re2_options.h"
#include "validator/validator.h"
#include "re2/re2.h"

namespace cel {
namespace {

// Returns the argument of `call_expr` at `arg_index`, counting the receiver of
// a member-style call as argument 0. Returns nullptr when `arg_index` is
// negative or out of range.
const Expr* FindArgument(const CallExpr& call_expr, int arg_index) {
  if (arg_index < 0) {
    return nullptr;
  }
  std::size_t index = static_cast<std::size_t>(arg_index);
  if (call_expr.has_target()) {
    if (index == 0) {
      return &call_expr.target();
    }
    // Skip past the receiver to index into the explicit arguments.
    --index;
  }
  const auto& args = call_expr.args();
  return index < args.size() ? &args[index] : nullptr;
}

// Validates the pattern argument of the call at `node`. Only string literal
// patterns are checked: they are compiled with RE2, and an error is reported
// if compilation fails. Returns false iff an error was reported.
bool CheckPattern(ValidationContext& context, const NavigableAstNode& node,
                  int arg_index) {
  ABSL_DCHECK(node.expr()->has_call_expr());
  const Expr* pattern_expr = FindArgument(node.expr()->call_expr(), arg_index);
  if (pattern_expr == nullptr || !pattern_expr->has_const_expr()) {
    return true;
  }
  const auto& const_expr = pattern_expr->const_expr();
  if (!const_expr.has_string_value()) {
    return true;
  }

  absl::string_view pattern_string = const_expr.string_value();
  RE2 re(pattern_string, internal::MakeRE2Options());
  if (!re.ok()) {
    context.ReportErrorAt(
        pattern_expr->id(),
        absl::StrCat("invalid regular expression: ", re.error()));
    return false;
  }
  return true;
}

}  // namespace

Validation RegexPatternValidator(
    absl::string_view id, std::vector<RegexPatternValidatorConfig> config) {
  return Validation(
      [config = std::move(config)](ValidationContext& context) -> bool {
        bool result = true;
        for (const auto& node :
             context.navigable_ast().Root().DescendantsPostorder()) {
          if (node.node_kind() != NodeKind::kCall) {
            continue;
          }
          const auto& function = node.expr()->call_expr().function();
          for (const auto& entry : config) {
            if (function != entry.function_name) {
              continue;
            }
            // Only the first configuration naming this function is applied.
            if (!CheckPattern(context, node, entry.pattern_arg_index)) {
              result = false;
            }
            break;
          }
        }
        return result;
      },
      id);
}

}  // namespace cel
================================================
FILE: validator/regex_validator.h
================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_REGEX_VALIDATOR_H_ #define THIRD_PARTY_CEL_CPP_VALIDATOR_REGEX_VALIDATOR_H_ #include #include #include "absl/strings/string_view.h" #include "common/standard_definitions.h" #include "validator/validator.h" namespace cel { // Configuration for the regex pattern validator. struct RegexPatternValidatorConfig { // The resolved function name. std::string function_name; // the index of the pattern argument (counting the receiver as arg 0 if // present). int pattern_arg_index; }; // Returns a `Validation` that checks all calls to the given regex functions // It validates that the specified argument is a valid regular expression if it // is a literal string. Validation RegexPatternValidator( absl::string_view id, std::vector config); // Returns a `Validation` that checks all calls to the CEL `matches` function. // It validates that if the pattern is a literal string, it is a valid regular // expression. 
inline Validation MatchesValidator() { return RegexPatternValidator( "cel.validator.matches", {{std::string(StandardFunctions::kRegexMatch), 1}}); } } // namespace cel #endif // THIRD_PARTY_CEL_CPP_VALIDATOR_REGEX_VALIDATOR_H_ ================================================ FILE: validator/regex_validator_test.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "validator/regex_validator.h" #include #include #include "absl/status/statusor.h" #include "common/decl.h" #include "common/type.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" #include "compiler/standard_library.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "validator/validator.h" namespace cel { namespace { using ::testing::HasSubstr; absl::StatusOr> StdLibCompiler() { CEL_ASSIGN_OR_RETURN( auto builder, NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); builder->AddLibrary(StandardCompilerLibrary()).IgnoreError(); CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( MakeVariableDecl("p", StringType()))); return builder->Build(); } struct TestCase { std::string expression; bool valid; std::string error_substr = ""; }; using MatchesValidatorTest = testing::TestWithParam; TEST_P(MatchesValidatorTest, Validate) { const auto& test_case = GetParam(); Validator validator; validator.AddValidation(MatchesValidator()); ASSERT_OK_AND_ASSIGN(auto 
compiler, StdLibCompiler()); ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expression)); validator.UpdateValidationResult(result); EXPECT_EQ(result.IsValid(), test_case.valid) << "Expression: " << test_case.expression; if (!test_case.valid) { EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); } } INSTANTIATE_TEST_SUITE_P( MatchesValidatorTest, MatchesValidatorTest, testing::Values( // Member calls TestCase{"'hello'.matches('h.*')", true}, TestCase{"'hello'.matches('h[')", false, "invalid regular expression"}, TestCase{"'hello'.matches('h(a|b)')", true}, TestCase{"'hello'.matches('h(a|b')", false, "invalid regular expression"}, // Global calls TestCase{"matches('hello', 'h.*')", true}, TestCase{"matches('hello', 'h[')", false, "invalid regular expression"}, // Non-literal patterns (should not report regex errors) TestCase{"'hello'.matches(p)", true}, TestCase{"'hello'.matches('h' + 'ello')", true}, TestCase{"'hello'.matches(dyn(1))", true}, // Empty pattern TestCase{"'hello'.matches('')", true})); } // namespace } // namespace cel ================================================ FILE: validator/timestamp_literal_validator.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "validator/timestamp_literal_validator.h"

#include "absl/base/no_destructor.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "common/constant.h"
#include "common/navigable_ast.h"
#include "common/standard_definitions.h"
#include "internal/time.h"
#include "tools/navigable_ast.h"
#include "validator/validator.h"

namespace cel {
namespace {

// If `node` is a call to `function` whose single child is a constant literal,
// returns that child. Otherwise returns nullptr.
const NavigableAstNode* LiteralArgument(const NavigableAstNode& node,
                                        absl::string_view function) {
  if (node.node_kind() != NodeKind::kCall ||
      absl::string_view(node.expr()->call_expr().function()) != function) {
    return nullptr;
  }
  if (node.children().size() != 1) {
    // Checker should have already reported an error.
    return nullptr;
  }
  const NavigableAstNode& child = *node.children()[0];
  if (child.node_kind() != NodeKind::kConstant) {
    // Not a literal, so nothing to do.
    return nullptr;
  }
  return &child;
}

// Checks `timestamp(...)` calls with a literal argument. String literals must
// parse as RFC 3339 (`absl::RFC3339_full`), and int literals are treated as
// Unix seconds. Either way, the result must pass `internal::ValidateTimestamp`.
// Returns false if any error was reported.
bool ValidateTimestamps(ValidationContext& context) {
  bool valid = true;
  for (const auto& node :
       context.navigable_ast().Root().DescendantsPostorder()) {
    const NavigableAstNode* child =
        LiteralArgument(node, StandardFunctions::kTimestamp);
    if (child == nullptr) {
      continue;
    }
    const auto id = child->expr()->id();
    const Constant& constant = child->expr()->const_expr();

    absl::Time ts;
    if (constant.has_string_value()) {
      absl::string_view timestamp_str = constant.string_value();
      if (!absl::ParseTime(absl::RFC3339_full, timestamp_str, &ts, nullptr)) {
        context.ReportErrorAt(id, "invalid timestamp literal");
        valid = false;
        continue;
      }
    } else if (constant.has_int_value()) {
      ts = absl::FromUnixSeconds(constant.int_value());
    } else {
      // Checker should have already reported an error.
      continue;
    }

    if (absl::Status status = internal::ValidateTimestamp(ts); !status.ok()) {
      context.ReportErrorAt(
          id, absl::StrCat("invalid timestamp literal: ", status.message()));
      valid = false;
    }
  }
  return valid;
}

// Checks `duration(...)` calls with a string literal argument. The string must
// parse with `absl::ParseDuration` and pass `internal::ValidateDuration`.
// Returns false if any error was reported.
bool ValidateDurations(ValidationContext& context) {
  bool valid = true;
  for (const auto& node :
       context.navigable_ast().Root().DescendantsPostorder()) {
    const NavigableAstNode* child =
        LiteralArgument(node, StandardFunctions::kDuration);
    if (child == nullptr) {
      continue;
    }
    const Constant& constant = child->expr()->const_expr();
    if (!constant.has_string_value()) {
      continue;
    }
    const auto id = child->expr()->id();

    absl::Duration duration;
    absl::string_view duration_str = constant.string_value();
    if (!absl::ParseDuration(duration_str, &duration)) {
      context.ReportErrorAt(id, "invalid duration literal");
      valid = false;
      continue;
    }

    if (absl::Status status = internal::ValidateDuration(duration);
        !status.ok()) {
      context.ReportErrorAt(
          id, absl::StrCat("invalid duration literal: ", status.message()));
      valid = false;
    }
  }
  return valid;
}

}  // namespace

const Validation& TimestampLiteralValidator() {
  static const absl::NoDestructor<Validation> kInstance(
      ValidateTimestamps, "cel.validator.timestamp");
  return *kInstance;
}

// Returns a validator that checks duration literals.
const Validation& DurationLiteralValidator() {
  static const absl::NoDestructor<Validation> kInstance(
      ValidateDurations, "cel.validator.duration");
  return *kInstance;
}

}  // namespace cel
================================================
FILE: validator/timestamp_literal_validator.h
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_TIMESTAMP_LITERAL_VALIDATOR_H_
#define THIRD_PARTY_CEL_CPP_VALIDATOR_TIMESTAMP_LITERAL_VALIDATOR_H_

#include "validator/validator.h"

namespace cel {

// Returns a `Validation` that checks timestamp literals are valid for CEL:
// literal string (RFC 3339) or int (Unix seconds) arguments to `timestamp()`
// must parse and fall within the supported range.
const Validation& TimestampLiteralValidator();

// Returns a `Validation` that checks duration literals are valid for CEL:
// literal string arguments to `duration()` must parse and fall within the
// supported range.
const Validation& DurationLiteralValidator();

}  // namespace cel

#endif  // THIRD_PARTY_CEL_CPP_VALIDATOR_TIMESTAMP_LITERAL_VALIDATOR_H_
================================================
FILE: validator/timestamp_literal_validator_test.cc
================================================
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "validator/timestamp_literal_validator.h" #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "checker/validation_result.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" #include "compiler/standard_library.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "validator/validator.h" namespace cel { namespace { using ::testing::HasSubstr; absl::StatusOr> StdLibCompiler() { auto builder = NewCompilerBuilder(internal::GetSharedTestingDescriptorPool()).value(); builder->AddLibrary(StandardCompilerLibrary()).IgnoreError(); return builder->Build(); } class TimestampLiteralValidatorTest : public ::testing::Test { protected: TimestampLiteralValidatorTest() { validator_.AddValidation(TimestampLiteralValidator()); } std::unique_ptr compiler_; Validator validator_; }; TEST(TimestampLiteralValidatorTest, FormatsIssues) { Validator validator; validator.AddValidation(TimestampLiteralValidator()); ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); ASSERT_OK_AND_ASSIGN(cel::ValidationResult result, compiler->Compile("timestamp('invalid')")); validator.UpdateValidationResult(result); EXPECT_FALSE(result.IsValid()); EXPECT_EQ(result.FormatError(), R"(ERROR: :1:11: invalid timestamp literal | timestamp('invalid') | ..........^)"); } TEST(TimestampLiteralValidatorTest, AccumulatesIssues) { Validator validator; validator.AddValidation(TimestampLiteralValidator()); validator.AddValidation(DurationLiteralValidator()); constexpr 
absl::string_view kExpression = R"cel( [ timestamp('invalid'), timestamp('9999-12-31T23:59:59Z'), timestamp('10000-01-01T00:00:00Z') ].all(t, t - timestamp(0) < duration('10000s') && t - timestamp(0) > duration("invalid") ))cel"; ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); ASSERT_OK_AND_ASSIGN(cel::ValidationResult result, compiler->Compile(kExpression)); validator.UpdateValidationResult(result); EXPECT_FALSE(result.IsValid()); EXPECT_THAT(result.FormatError(), AllOf(HasSubstr("2:17: invalid timestamp literal"), HasSubstr("4:17: invalid timestamp literal"), HasSubstr("7:35: invalid duration literal"))); } struct TestCase { std::string expression; bool valid; std::string error_substr = ""; }; using TimestampLiteralValidatorParameterizedTest = testing::TestWithParam; TEST_P(TimestampLiteralValidatorParameterizedTest, Validate) { const auto& test_case = GetParam(); Validator validator; validator.AddValidation(TimestampLiteralValidator()); validator.AddValidation(DurationLiteralValidator()); ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expression)); validator.UpdateValidationResult(result); EXPECT_EQ(result.IsValid(), test_case.valid); if (!test_case.valid) { EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); } } INSTANTIATE_TEST_SUITE_P( TimestampLiteralValidatorParameterizedTest, TimestampLiteralValidatorParameterizedTest, ::testing::Values( TestCase{"timestamp('2023-01-01T00:00:00Z')", true}, TestCase{"timestamp('9999-12-31T23:59:59Z')", true}, TestCase{"timestamp('invalid')", false, "invalid timestamp literal"}, TestCase{"timestamp('10000-01-01T00:00:00Z')", false, "invalid timestamp literal"}, TestCase{"timestamp(0)", true}, TestCase{"timestamp(-62135596801)", false, "invalid timestamp literal: Timestamp \"0-12-31T23:59:59Z\" " "below minimum allowed timestamp \"1-01-01T00:00:00Z\""}, TestCase{"timestamp(253402300800)", false, "invalid timestamp literal: Timestamp 
" "\"10000-01-01T00:00:00Z\" above maximum allowed timestamp " "\"9999-12-31T23:59:59.999999999Z\""}, TestCase{"duration('1s')", true}, TestCase{"duration('invalid')", false, "invalid duration literal"}, TestCase{"duration('-1000000000000s')", false, "below minimum allowed duration"}, TestCase{"duration('1000000000000s')", false, "above maximum allowed duration"})); } // namespace } // namespace cel ================================================ FILE: validator/validator.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "validator/validator.h"

#include <cstdint>
#include <string>
#include <utility>

#include "absl/log/absl_check.h"
#include "absl/strings/string_view.h"
#include "checker/type_check_issue.h"
#include "checker/validation_result.h"
#include "common/ast.h"
#include "common/source.h"

namespace cel {

namespace {

// Builds a TypeCheckIssue from its parts. Shared by the
// ValidationContext::Report* methods so every reported issue is constructed
// the same way.
TypeCheckIssue MakeIssue(TypeCheckIssue::Severity severity,
                         SourceLocation location, absl::string_view message) {
  return TypeCheckIssue(severity, location, std::string(message));
}

}  // namespace

void Validator::AddValidation(Validation validation) {
  // An empty Validation has nothing to call: fail loudly in debug builds and
  // drop it in optimized builds rather than crash later in Validate().
  ABSL_DCHECK(validation);
  if (!validation) return;
  validations_.push_back(std::move(validation));
}

Validator::ValidationOutput Validator::Validate(const Ast& ast) const {
  ValidationOutput result;
  ValidationContext context(ast);
  // Every validation runs, even after an earlier one fails, so that all
  // issues are collected in a single pass.
  for (const auto& validation : validations_) {
    if (!validation(context)) {
      result.valid = false;
    }
  }
  result.issues = context.ReleaseIssues();
  return result;
}

void Validator::UpdateValidationResult(ValidationResult& in) const {
  if (!in.IsValid() || in.GetAst() == nullptr) {
    // If the result is already decided invalid (or there is no AST to
    // inspect), leave it untouched.
    return;
  }
  ValidationOutput output = Validate(*in.GetAst());
  if (!output.valid) {
    // A failed validation discards the checked AST; the released value itself
    // is not needed here.
    in.ReleaseAst().IgnoreError();
  }
  // Issues (errors and warnings alike) are appended regardless of outcome.
  for (auto& issue : output.issues) {
    in.AddIssue(std::move(issue));
  }
}

void ValidationContext::ReportWarningAt(int64_t id, absl::string_view message) {
  issues_.push_back(MakeIssue(TypeCheckIssue::Severity::kWarning,
                              ast_.ComputeSourceLocation(id), message));
}

void ValidationContext::ReportErrorAt(int64_t id, absl::string_view message) {
  issues_.push_back(MakeIssue(TypeCheckIssue::Severity::kError,
                              ast_.ComputeSourceLocation(id), message));
}

void ValidationContext::ReportWarning(absl::string_view message) {
  // No expression id: the issue carries a default-constructed location.
  issues_.push_back(
      MakeIssue(TypeCheckIssue::Severity::kWarning, SourceLocation{}, message));
}

void ValidationContext::ReportError(absl::string_view message) {
  // No expression id: the issue carries a default-constructed location.
  issues_.push_back(
      MakeIssue(TypeCheckIssue::Severity::kError, SourceLocation{}, message));
}

}  // namespace cel
================================================
FILE: validator/validator.h
================================================
// Copyright 2026
Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_VALIDATOR_H_ #define THIRD_PARTY_CEL_CPP_VALIDATOR_VALIDATOR_H_ #include #include #include #include #include #include "absl/base/attributes.h" #include "absl/functional/any_invocable.h" #include "absl/log/absl_check.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "checker/type_check_issue.h" #include "checker/validation_result.h" #include "common/ast.h" #include "common/navigable_ast.h" namespace cel { // Context for a validation pass. // // Assumed to be scoped to a Validator::Validate() call. Instances must not // outlive the `ast` passed to the constructor. class ValidationContext { public: explicit ValidationContext(const Ast& ast ABSL_ATTRIBUTE_LIFETIME_BOUND) : ast_(ast) {} const Ast& ast() const { return ast_; } const NavigableAst& navigable_ast() const { if (!navigable_ast_) { navigable_ast_ = NavigableAst::Build(ast_.root_expr()); } return navigable_ast_; } void ReportWarningAt(int64_t id, absl::string_view message); void ReportErrorAt(int64_t id, absl::string_view message); void ReportWarning(absl::string_view message); void ReportError(absl::string_view message); std::vector ReleaseIssues() { auto out = std::move(issues_); issues_.clear(); return out; } private: const Ast& ast_; mutable NavigableAst navigable_ast_; std::vector issues_; }; // A single validation to apply to an AST. 
// // May be empty if default constructed or moved from. // use operator bool() to check if the validation is empty. class Validation { public: // Tests the AST reports any issues to the context. // // Returns false if the AST is invalid. // // The same instance is used across Validate() so must be thread safe // (typically stateless). using ImplFunction = absl::AnyInvocable; Validation() = default; explicit Validation(ImplFunction impl); Validation(ImplFunction impl, absl::string_view id); const ImplFunction& impl() const { ABSL_DCHECK(rep_ != nullptr); return rep_->impl; } absl::string_view id() const { ABSL_DCHECK(rep_ != nullptr); return rep_->id; } bool operator()(ValidationContext& context) const { ABSL_DCHECK(rep_ != nullptr); return rep_->impl(context); } explicit operator bool() const { return rep_ != nullptr; } private: struct Rep { ImplFunction impl; // Optional id if supported in environment config. std::string id; }; std::shared_ptr rep_; }; // A validator checks a set of semantic rules for a given AST. class Validator { public: Validator() = default; void AddValidation(Validation validation); absl::Span validations() const { return validations_; } struct ValidationOutput { bool valid = true; std::vector issues; }; // Validates the given AST by applying all of the validations. ValidationOutput Validate(const Ast& ast) const; // Validates the given AST, updating the validation result in place. // // Used to apply validators to the output of the type checker. void UpdateValidationResult(ValidationResult& in) const; private: std::vector validations_; }; // Implementation details. 
inline Validation::Validation(ImplFunction impl) : rep_(std::make_shared( Validation::Rep{std::move(impl)})) {} inline Validation::Validation(ImplFunction impl, absl::string_view id) : rep_(std::make_shared( Validation::Rep{std::move(impl), std::string(id)})) {} } // namespace cel #endif // THIRD_PARTY_CEL_CPP_VALIDATOR_VALIDATOR_H_ ================================================ FILE: validator/validator_test.cc ================================================ // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "validator/validator.h"

#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "checker/type_check_issue.h"
#include "common/ast.h"
#include "common/expr.h"
#include "common/source.h"
#include "internal/testing.h"

namespace cel {
namespace {

using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::Property;
using ::testing::SizeIs;

// All validations run (even after one fails) and their issues are collected
// in reporting order.
TEST(ValidatorTest, AddValidationAndValidate) {
  Validator validator;
  validator.AddValidation(Validation([](ValidationContext& context) {
    context.ReportError("error 1");
    return false;
  }));
  validator.AddValidation(Validation([](ValidationContext& context) {
    context.ReportWarning("warning 1");
    return true;
  }));

  Ast ast;
  auto output = validator.Validate(ast);

  EXPECT_FALSE(output.valid);
  EXPECT_THAT(output.issues,
              ElementsAre(Property(&TypeCheckIssue::message, Eq("error 1")),
                          Property(&TypeCheckIssue::message, Eq("warning 1"))));
  EXPECT_EQ(output.issues[0].severity(), TypeCheckIssue::Severity::kError);
  EXPECT_EQ(output.issues[1].severity(), TypeCheckIssue::Severity::kWarning);
}

// Issues reported against an expression id are positioned using the AST's
// source info.
TEST(ValidatorTest, ReportAt) {
  Validator validator;
  validator.AddValidation(Validation([](ValidationContext& context) {
    context.ReportErrorAt(1, "error at 1");
    context.ReportWarningAt(2, "warning at 2");
    return false;
  }));

  Expr expr;
  expr.set_id(1);
  SourceInfo source_info;
  // With line offsets {15, 25}, offset 10 maps to line 1 column 10 and
  // offset 20 maps to line 2 column 5.
  source_info.mutable_positions()[1] = 10;
  source_info.mutable_positions()[2] = 20;
  source_info.set_line_offsets({15, 25});
  Ast ast(std::move(expr), std::move(source_info));

  auto output = validator.Validate(ast);

  EXPECT_FALSE(output.valid);
  // SizeIs avoids the size_t-vs-int comparison of ASSERT_EQ(size(), 2).
  ASSERT_THAT(output.issues, SizeIs(2));

  EXPECT_EQ(output.issues[0].severity(), TypeCheckIssue::Severity::kError);
  EXPECT_THAT(output.issues[0].message(), Eq("error at 1"));
  EXPECT_EQ(output.issues[0].location().line, 1);
  EXPECT_EQ(output.issues[0].location().column, 10);

  EXPECT_EQ(output.issues[1].severity(), TypeCheckIssue::Severity::kWarning);
  EXPECT_THAT(output.issues[1].message(), Eq("warning at 2"));
  EXPECT_EQ(output.issues[1].location().line, 2);
  EXPECT_EQ(output.issues[1].location().column, 5);
}

}  // namespace
}  // namespace cel